Skip to content

Commit 7a794c1

Browse files
Bugfix for the loss of array values
This was largely an issue of the zip(strict=False) causing missed entries. I made this strict=True
1 parent 6c70743 commit 7a794c1

File tree

2 files changed

+29
-14
lines changed

2 files changed

+29
-14
lines changed

gwcs/coordinate_frames.py

+23
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,29 @@ def _native_world_axis_object_components(self):
388388
to be able to get the components in their native order.
389389
"""
390390

391+
def add_units(self, arrays: u.Quantity | np.ndarray | float) -> tuple[u.Quantity]:
392+
"""
393+
Add units to the arrays
394+
"""
395+
return tuple(
396+
u.Quantity(array, unit=unit)
397+
for array, unit in zip(arrays, self.unit, strict=True)
398+
)
399+
400+
def remove_units(
401+
self, arrays: u.Quantity | np.ndarray | float
402+
) -> tuple[np.ndarray]:
403+
"""
404+
Remove units from the input arrays
405+
"""
406+
if self.naxes == 1:
407+
arrays = (arrays,)
408+
409+
return tuple(
410+
array.to_value(unit) if isinstance(array, u.Quantity) else array
411+
for array, unit in zip(arrays, self.unit, strict=True)
412+
)
413+
391414

392415
class CoordinateFrame(BaseCoordinateFrame):
393416
"""

gwcs/wcs/_wcs.py

+6-14
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import astropy.units as u
88
import numpy as np
9-
from astropy import utils as astutil
109
from astropy.io import fits
1110
from astropy.modeling import fix_inputs, projections
1211
from astropy.modeling.bounding_box import ModelBoundingBox as Bbox
@@ -117,24 +116,18 @@ def __init__(
117116
self._pixel_shape = None
118117

119118
def _add_units_input(
120-
self, arrays: list[np.ndarray], frame: CoordinateFrame | None
119+
self, arrays: np.ndarray | float, frame: CoordinateFrame | None
121120
) -> tuple[u.Quantity, ...]:
122121
if frame is not None:
123-
return tuple(
124-
u.Quantity(array, unit)
125-
for array, unit in zip(arrays, frame.unit, strict=False)
126-
)
122+
return frame.add_units(arrays)
127123

128124
return arrays
129125

130126
def _remove_units_input(
131127
self, arrays: list[u.Quantity], frame: CoordinateFrame | None
132128
) -> tuple[np.ndarray, ...]:
133129
if frame is not None:
134-
return tuple(
135-
array.to_value(unit) if isinstance(array, u.Quantity) else array
136-
for array, unit in zip(arrays, frame.unit, strict=False)
137-
)
130+
return frame.remove_units(arrays)
138131

139132
return arrays
140133

@@ -166,10 +159,7 @@ def __call__(
166159
results = self._call_forward(
167160
*args, with_bounding_box=with_bounding_box, fill_value=fill_value, **kwargs
168161
)
169-
170162
if with_units:
171-
if not astutil.isiterable(results):
172-
results = (results,)
173163
# values are always expected to be arrays or scalars not quantities
174164
results = self._remove_units_input(results, self.output_frame)
175165
high_level = values_to_high_level_objects(*results, low_level_wcs=self)
@@ -403,7 +393,9 @@ def outside_footprint(self, world_arrays):
403393
max_ax = axis_range[~m].min()
404394
outside = (coord > min_ax) & (coord < max_ax)
405395
else:
406-
coord_ = self._remove_units_input([coord], self.output_frame)[0]
396+
coord_ = self._remove_quantity_output(
397+
world_arrays, self.output_frame
398+
)[idim]
407399
outside = (coord_ < min_ax) | (coord_ > max_ax)
408400
if np.any(outside):
409401
if np.isscalar(coord):

0 commit comments

Comments
 (0)