Skip to content

Commit

Permalink
update ObservationArray so it doesn't need a concatenate wrapper func
Browse files Browse the repository at this point in the history
  • Loading branch information
yoachim committed Dec 19, 2024
1 parent c876469 commit a04584e
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 43 deletions.
12 changes: 6 additions & 6 deletions rubin_scheduler/scheduler/detailers/detailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np

import rubin_scheduler.scheduler.features as features
from rubin_scheduler.scheduler.utils import IntRounded, ObservationArray, obsarray_concat
from rubin_scheduler.scheduler.utils import IntRounded, ObservationArray
from rubin_scheduler.utils import (
DEFAULT_NSIDE,
_angular_separation,
Expand Down Expand Up @@ -373,7 +373,7 @@ def __init__(
# Make backwards compatible if someone sent in a list
if isinstance(sequence_obs, list):
warnings.warn("sequence_obs should be ObsArray, not list of ObsArray. Concatenating")
sequence_obs = obsarray_concat(sequence_obs)
sequence_obs = np.concatenate(sequence_obs)

self.sequence_obs = sequence_obs

Expand Down Expand Up @@ -401,7 +401,7 @@ def __init__(
def __call__(self, observation_array, conditions):
# Do we need to add the opening sequence?
if (conditions.mjd - self.survey_features["last_matching"].feature["mjd"]) >= self.time_match:
observation_array = obsarray_concat([self.sequence_obs, observation_array])
observation_array = np.concatenate([self.sequence_obs, observation_array])

return observation_array

Expand Down Expand Up @@ -461,7 +461,7 @@ def __call__(self, obs_array, conditions):
else:
good = np.min(np.where(ang_dist == ang_dist.min())[0])
indx = in_band[good]
result = obsarray_concat([obs_array[indx:], obs_array[:indx]])
result = np.concatenate([obs_array[indx:], obs_array[:indx]])
return result


Expand Down Expand Up @@ -529,7 +529,7 @@ def __call__(self, observation_array, conditions):
observation_array["scheduler_note"] = np.char.add(
observation_array["scheduler_note"], ", %s" % tags[1]
)
result = obsarray_concat([paired, observation_array])
result = np.concatenate([paired, observation_array])

return result

Expand Down Expand Up @@ -576,4 +576,4 @@ def __call__(self, obs_array, conditions):
if self.update_note:
sub_arr["scheduler_note"] = np.char.add(sub_arr["scheduler_note"], ", %i" % i)
out_obs.append(sub_arr)
return obsarray_concat(out_obs)
return np.concatenate(out_obs)
2 changes: 1 addition & 1 deletion rubin_scheduler/scheduler/schedulers/core_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def update_conditions(self, conditions_in):
for sur in sl:
scheduled = sur.get_scheduled_obs()
if scheduled is not None:
all_scheduled.append(scheduled)
all_scheduled.append(scheduled.view(np.ndarray))
if len(all_scheduled) == 0:
self.conditions.scheduled_observations = []
else:
Expand Down
4 changes: 2 additions & 2 deletions rubin_scheduler/scheduler/surveys/dd_surveys.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import rubin_scheduler.scheduler.basis_functions as basis_functions
from rubin_scheduler.scheduler import features
from rubin_scheduler.scheduler.surveys import BaseSurvey
from rubin_scheduler.scheduler.utils import ObservationArray, obsarray_concat
from rubin_scheduler.scheduler.utils import ObservationArray
from rubin_scheduler.utils import DEFAULT_NSIDE, ddf_locations, ra_dec2_hpid

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(
self.observations = sequence

# Let's just make this an array for ease of use
self.observations = obsarray_concat(self.observations)
self.observations = np.concatenate(self.observations)
order = np.argsort(self.observations["filter"])
self.observations = self.observations[order]

Expand Down
4 changes: 2 additions & 2 deletions rubin_scheduler/scheduler/surveys/ddf_presched.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from rubin_scheduler.data import get_data_dir
from rubin_scheduler.scheduler.utils import ScheduledObservationArray, obsarray_concat
from rubin_scheduler.scheduler.utils import ScheduledObservationArray
from rubin_scheduler.site_models import Almanac
from rubin_scheduler.utils import SURVEY_START_MJD, calc_season, ddf_locations

Expand Down Expand Up @@ -514,5 +514,5 @@ def generate_ddf_scheduled_obs(
obs["moon_min_distance"] = moon_min_distance
all_scheduled_obs.append(obs)

result = obsarray_concat(all_scheduled_obs)
result = np.concatenate(all_scheduled_obs)
return result
4 changes: 2 additions & 2 deletions rubin_scheduler/scheduler/surveys/field_survey.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ..detailers import AltAz2RaDecDetailer
from ..features import LastObservation, NObsCount
from ..utils import ObservationArray, obsarray_concat
from ..utils import ObservationArray
from . import BaseSurvey


Expand Down Expand Up @@ -167,7 +167,7 @@ def __init__(
self.observations.append(obs)

# Let's just make this an array for ease of use
self.observations = obsarray_concat(self.observations)
self.observations = np.concatenate(self.observations)
order = np.argsort(self.observations["filter"])
self.observations = self.observations[order]

Expand Down
4 changes: 2 additions & 2 deletions rubin_scheduler/scheduler/surveys/scripted_surveys.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from rubin_scheduler.scheduler.surveys import BaseSurvey
from rubin_scheduler.scheduler.utils import ScheduledObservationArray, obsarray_concat
from rubin_scheduler.scheduler.utils import ScheduledObservationArray
from rubin_scheduler.utils import DEFAULT_NSIDE, _angular_separation, _approx_ra_dec2_alt_az

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -318,7 +318,7 @@ def set_script(self, obs_wanted, append=True, add_index=True):
self.id_start = self.script_id_array.max() + 1

if append & (self.obs_wanted is not None):
self.obs_wanted = obsarray_concat([self.obs_wanted, obs_wanted])
self.obs_wanted = np.concatenate([self.obs_wanted, obs_wanted])
self.obs_wanted.sort(order=["mjd", "filter"])
else:
self.obs_wanted = obs_wanted
Expand Down
56 changes: 31 additions & 25 deletions rubin_scheduler/scheduler/utils/observation_array.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
__all__ = (
"obsarray_concat",
"ObservationArray",
"ScheduledObservationArray",
)

import numpy as np

HANDLED_FUNCTIONS = {}


class ObservationArray(np.ndarray):
"""Class to work as an array of observations
Expand Down Expand Up @@ -150,6 +151,35 @@ def tolist(self):

return obs_list

def __array_function__(self, func, types, args, kwargs):
# If we want "standard numpy behavior",
# convert any ObservationArray to ndarray views
if func not in HANDLED_FUNCTIONS:
new_args = []
for arg in args:
if issubclass(arg.__class__, ObservationArray):
new_args.append(arg.view(np.ndarray))
else:
new_args.append(arg)
return func(*new_args, **kwargs)
if not all(issubclass(t, ObservationArray) for t in types):
return NotImplemented
return HANDLED_FUNCTIONS[func](*args, **kwargs)


def implements(numpy_function):
def decorator(func):
HANDLED_FUNCTIONS[numpy_function] = func
return func

return decorator


@implements(np.concatenate)
def concatenate(arrays):
result = arrays[0].__class__(n=sum(len(a) for a in arrays))
return np.concatenate([np.asarray(a) for a in arrays], out=result)


class ScheduledObservationArray(ObservationArray):
"""Make an array to hold pre-scheduling observations
Expand Down Expand Up @@ -236,27 +266,3 @@ def to_observation_array(self):
for key in in_common:
result[key] = self[key]
return result


def obsarray_concat(in_arrays):
"""Concatenate ObservationArray objects.
Can't use np.concatenate because it will no longer
be an array subclass
Parameters
----------
in_arrays : `list` of `ObservationArray` or `ScheduledObservationArray`
"""
# Check if we have ScheduledObservationArray
array_class = ObservationArray
if "observed" in in_arrays[0].dtype.names:
array_class = ScheduledObservationArray

size = 0
for arr in in_arrays:
size += arr.size
# Init empty array of proper class
# to hold output.
out_arr = array_class(n=size)
return np.concatenate(in_arrays, out=out_arr)
5 changes: 2 additions & 3 deletions tests/scheduler/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
ScheduledObservationArray,
SchemaConverter,
make_rolling_footprints,
obsarray_concat,
restore_scheduler,
run_info_table,
season_calc,
Expand Down Expand Up @@ -309,7 +308,7 @@ def test_observation_array(self):
assert len(obs_list[0]) == 1
assert len(obs_list) == n

back = obsarray_concat(obs_list)
back = np.concatenate(obs_list)

assert np.array_equal(back, obs)

Expand All @@ -321,7 +320,7 @@ def test_observation_array(self):
assert len(obs_list[0]) == 1
assert len(obs_list) == n

back = obsarray_concat(obs_list)
back = np.concatenate(obs_list)

assert np.array_equal(back, obs)

Expand Down

0 comments on commit a04584e

Please sign in to comment.