Skip to content

Commit

Permalink
Getting state type hint. Bug hunting on mypy ignore in ResourceState …
Browse files Browse the repository at this point in the history
…led to MockEnvironment change.
  • Loading branch information
JamesArruda committed Dec 18, 2024
1 parent a1fbb35 commit 6883c02
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 17 deletions.
21 changes: 16 additions & 5 deletions src/upstage_des/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from warnings import warn

from simpy import Environment as SimpyEnv
from simpy import Event as SimEvent
from simpy.core import SimTime

from upstage_des.geography import INTERSECTION_LOCATION_CALLABLE, EarthProtocol
from upstage_des.units.convert import STANDARD_TIMES, TIME_ALTERNATES, unit_convert
Expand Down Expand Up @@ -147,7 +149,7 @@ class RulesError(UpstageError):
"""Raised by the user when a simulation rule is violated."""


class MockEnvironment:
class MockEnvironment(SimpyEnv):
"""A fake environment that holds the ``now`` property and all-caps attributes."""

def __init__(self, now: float):
Expand All @@ -156,7 +158,16 @@ def __init__(self, now: float):
Args:
now (float): The time the environment is at.
"""
self.now = now
super().__init__(initial_time=now)

@property
def now(self) -> SimTime:
"""The current simulation time."""
return self._now

@now.setter
def now(self, value: SimTime) -> None:
self._now = value

@classmethod
def mock(cls, env: Union[SimpyEnv, "MockEnvironment"]) -> "MockEnvironment":
Expand All @@ -176,11 +187,11 @@ def mock(cls, env: Union[SimpyEnv, "MockEnvironment"]) -> "MockEnvironment":
return mock_env

@classmethod
def run(cls, until: float | int) -> None:
def run(cls, until: SimTime | SimEvent | None = None) -> Any | None:
"""Method stub for playing nice with rehearsal.
Args:
until (float | int): Placeholder
until (SimTime | SimEvent | None): Placeholder
"""
raise UpstageError("You tried to use `run` on a mock environment")

Expand Down Expand Up @@ -395,7 +406,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._new_env: MockEnvironment | None = None
super().__init__(*args, **kwargs)

@property # type: ignore [override]
@property
def env(self) -> SimpyEnv | MockEnvironment:
"""Get the relevant environment.
Expand Down
10 changes: 5 additions & 5 deletions src/upstage_des/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from simpy.resources.resource import Release, Request
from simpy.resources.store import StoreGet, StorePut

from .base import SimulationError, UpstageBase, UpstageError
from .base import MockEnvironment, SimulationError, UpstageBase, UpstageError
from .constants import PLANNING_FACTOR_OBJECT
from .units import unit_convert

Expand Down Expand Up @@ -241,7 +241,7 @@ def as_event(self) -> SIM.Timeout:
Returns:
SIM.Timeout
"""
assert isinstance(self.env, SIM.Environment)
assert not isinstance(self.env, MockEnvironment)
self._simpy_event = self.env.timeout(self._time_to_complete)
return self._simpy_event

Expand Down Expand Up @@ -413,7 +413,7 @@ def as_event(self) -> SIM.Event:
SIM.Event: typically an Any or All
"""
sub_events = [self._make_event(event) for event in self.events]
assert isinstance(self.env, SIM.Environment)
assert not isinstance(self.env, MockEnvironment)
self._simpy_event = self.simpy_equivalent(self.env, sub_events)
return self._simpy_event

Expand Down Expand Up @@ -782,7 +782,7 @@ def __init__(
# yielded on
self._payload: dict[str, Any] = {}
self._auto_reset = auto_reset
assert isinstance(self.env, SIM.Environment)
assert not isinstance(self.env, MockEnvironment)
self._event = SIM.Event(self.env)

def calculate_time_to_complete(self) -> float:
Expand Down Expand Up @@ -837,7 +837,7 @@ def get_payload(self) -> dict[str, tyAny]:

def reset(self) -> None:
"""Reset the event to allow it to be held again."""
assert isinstance(self.env, SIM.Environment)
assert not isinstance(self.env, MockEnvironment)
self._event = SIM.Event(self.env)

def cancel(self) -> None:
Expand Down
17 changes: 15 additions & 2 deletions src/upstage_des/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from collections.abc import Callable
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_args

from simpy import Container, Store

Expand Down Expand Up @@ -186,6 +186,19 @@ def __get__(self, instance: "Actor", objtype: type | None = None) -> ST:
def __set_name__(self, owner: "Actor", name: str) -> None:
self.name = name

def _infer_state(self, instance: "Actor") -> tuple[Any, ...]:
"""Infer types for the state.
Args:
instance (Actor): The actor the state is attached to.
Returns:
le[Any,...]: The state type
"""
state_class = instance._state_defs[self.name]
args = get_args(state_class.__orig_class__)
return args

def has_default(self) -> bool:
"""Check if a default exists.
Expand Down Expand Up @@ -943,7 +956,7 @@ def _make_clone(self, instance: "Actor", copy: T) -> T:
"""
base_class = type(copy)
memory: dict[str, Any] = instance.__dict__[f"_memory_for_{self.name}"]
new = base_class(instance.env, **memory) # type: ignore [arg-type]
new = base_class(instance.env, **memory)
if isinstance(copy, Store) and isinstance(new, Store):
new.items = list(copy.items)
if isinstance(copy, Container) and isinstance(new, Container):
Expand Down
3 changes: 1 addition & 2 deletions src/upstage_des/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import TYPE_CHECKING, Any, TypeVar
from warnings import warn

from simpy import Environment as SimpyEnv
from simpy import Event as SimpyEvent
from simpy import Interrupt, Process

Expand Down Expand Up @@ -530,7 +529,7 @@ def run(self, *, actor: "Actor") -> Generator[SimpyEvent, None, None]:
Generator[SimpyEvent, None, None]: Generator for SimPy event queue.
"""
self.make_decision(actor=actor)
assert isinstance(self.env, SimpyEnv)
assert not isinstance(self.env, MockEnvironment)
yield self.env.timeout(0.0)


Expand Down
2 changes: 2 additions & 0 deletions src/upstage_des/test/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from upstage_des.base import (
STAGE_CONTEXT_VAR,
EnvironmentContext,
MockEnvironment,
NamedUpstageEntity,
UpstageBase,
UpstageError,
Expand All @@ -24,6 +25,7 @@
def test_context() -> None:
with EnvironmentContext() as env:
assert isinstance(env, SIM.Environment)
assert not isinstance(env, MockEnvironment)
env.run(until=3)
assert env.now == 3

Expand Down
6 changes: 3 additions & 3 deletions src/upstage_des/test/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_failures_for_tasks_with_simpy_events() -> None:

class BrokenTask(Task):
def task(self, *, actor: ActorForTest) -> TASK_GEN:
yield self.env.timeout(1.0) # type: ignore [misc, union-attr]
yield self.env.timeout(1.0) # type: ignore [misc]

# msg = "*Task is yielding objects without `as_event`*"
with pytest.raises(SimulationError): # , match=msg):
Expand All @@ -137,8 +137,8 @@ def task(self, *, actor: ActorForTest) -> TASK_GEN:
)
env.run()

# msg = "*'MockEnvironment' object has no attribute 'timeout'*"
with pytest.raises(AttributeError): # , match=msg):
msg = "must be a subclass of BaseEvent"
with pytest.raises(SimulationError, match=msg):
the_task = BrokenTask()
the_task.rehearse(
actor=actor,
Expand Down

0 comments on commit 6883c02

Please sign in to comment.