Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
foreverska committed Feb 22, 2025
1 parent 14a0560 commit 68ecf6b
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 5 deletions.
2 changes: 1 addition & 1 deletion gymnasium/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@
)

register(
id="Taxi-v4",
id="Taxi-v3",
entry_point="gymnasium.envs.toy_text.taxi:TaxiEnv",
reward_threshold=8, # optimum = 8.46
max_episode_steps=200,
Expand Down
4 changes: 2 additions & 2 deletions gymnasium/envs/toy_text/taxi.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,11 @@ def step(self, a):
and (taxi_row != shadow_row or taxi_col != shadow_col)
):
self.fickle_step = False
if self.fickle_passenger and np.random.rand() < 0.3:
if self.fickle_passenger and self.np_random.random() < 0.3:
possible_destinations = [
i for i in range(len(self.locs)) if i != shadow_dest_idx
]
dest_idx = np.random.choice(possible_destinations)
dest_idx = self.np_random.choice(possible_destinations)
s = self.encode(taxi_row, taxi_col, pass_loc, dest_idx)

self.s = s
Expand Down
4 changes: 2 additions & 2 deletions tests/envs/registration/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_register_error(env_id):
("blackjack-v1", "Blackjack"),
("Blackjock-v1", "Blackjack"),
("mountaincarcontinuous-v0", "MountainCarContinuous"),
("taxi-v4", "Taxi"),
("taxi-v3", "Taxi"),
("taxi-v30", "Taxi"),
("MyAwesomeNamspce/MyAwesomeVersionedEnv-v1", "MyAwesomeNamespace"),
("MyAwesomeNamspce/MyAwesomeUnversionedEnv", "MyAwesomeNamespace"),
Expand All @@ -121,7 +121,7 @@ def test_env_suggestions(
("CartPole-v12", "`v0`, `v1`", False),
("Blackjack-v10", "`v1`", False),
("MountainCarContinuous-v100", "`v0`", False),
("Taxi-v30", "`v4`", False),
("Taxi-v30", "`v3`", False),
("MyAwesomeNamespace/MyAwesomeVersionedEnv-v6", "`v1`, `v3`, `v5`", False),
("MyAwesomeNamespace/MyAwesomeUnversionedEnv-v6", "", True),
],
Expand Down
8 changes: 8 additions & 0 deletions tests/envs/test_env_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,20 @@ def test_taxi_is_rainy():
assert len(transitions) == 1
assert transitions[0][0] == 1.0

state, _ = env.reset()
_, _, _, _, info = env.step(0)
assert info["prob"] in {0.8, 0.1}

env = TaxiEnv(is_rainy=False)
for state_dict in env.P.values():
for action, transitions in state_dict.items():
assert len(transitions) == 1
assert transitions[0][0] == 1.0

state, _ = env.reset()
_, _, _, _, info = env.step(0)
assert info["prob"] == 1.0


@pytest.mark.parametrize(
"env_name",
Expand Down

0 comments on commit 68ecf6b

Please sign in to comment.