|
6 | 6 |
|
7 | 7 | import astropy.units as u
|
8 | 8 | import numpy as np
|
9 |
| -from astropy import utils as astutil |
10 | 9 | from astropy.io import fits
|
11 | 10 | from astropy.modeling import fix_inputs, projections
|
12 | 11 | from astropy.modeling.bounding_box import ModelBoundingBox as Bbox
|
@@ -117,24 +116,18 @@ def __init__(
|
117 | 116 | self._pixel_shape = None
|
118 | 117 |
|
119 | 118 | def _add_units_input(
|
120 |
| - self, arrays: list[np.ndarray], frame: CoordinateFrame | None |
| 119 | + self, arrays: np.ndarray | float, frame: CoordinateFrame | None |
121 | 120 | ) -> tuple[u.Quantity, ...]:
|
122 | 121 | if frame is not None:
|
123 |
| - return tuple( |
124 |
| - u.Quantity(array, unit) |
125 |
| - for array, unit in zip(arrays, frame.unit, strict=False) |
126 |
| - ) |
| 122 | + return frame.add_units(arrays) |
127 | 123 |
|
128 | 124 | return arrays
|
129 | 125 |
|
130 | 126 | def _remove_units_input(
|
131 | 127 | self, arrays: list[u.Quantity], frame: CoordinateFrame | None
|
132 | 128 | ) -> tuple[np.ndarray, ...]:
|
133 | 129 | if frame is not None:
|
134 |
| - return tuple( |
135 |
| - array.to_value(unit) if isinstance(array, u.Quantity) else array |
136 |
| - for array, unit in zip(arrays, frame.unit, strict=False) |
137 |
| - ) |
| 130 | + return frame.remove_units(arrays) |
138 | 131 |
|
139 | 132 | return arrays
|
140 | 133 |
|
@@ -166,10 +159,7 @@ def __call__(
|
166 | 159 | results = self._call_forward(
|
167 | 160 | *args, with_bounding_box=with_bounding_box, fill_value=fill_value, **kwargs
|
168 | 161 | )
|
169 |
| - |
170 | 162 | if with_units:
|
171 |
| - if not astutil.isiterable(results): |
172 |
| - results = (results,) |
173 | 163 | # values are always expected to be arrays or scalars not quantities
|
174 | 164 | results = self._remove_units_input(results, self.output_frame)
|
175 | 165 | high_level = values_to_high_level_objects(*results, low_level_wcs=self)
|
@@ -403,7 +393,9 @@ def outside_footprint(self, world_arrays):
|
403 | 393 | max_ax = axis_range[~m].min()
|
404 | 394 | outside = (coord > min_ax) & (coord < max_ax)
|
405 | 395 | else:
|
406 |
| - coord_ = self._remove_units_input([coord], self.output_frame)[0] |
| 396 | + coord_ = self._remove_quantity_output( |
| 397 | + world_arrays, self.output_frame |
| 398 | + )[idim] |
407 | 399 | outside = (coord_ < min_ax) | (coord_ > max_ax)
|
408 | 400 | if np.any(outside):
|
409 | 401 | if np.isscalar(coord):
|
|
0 commit comments