Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stochastic taxi (rainy+fickle) #1315

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
124 changes: 106 additions & 18 deletions gymnasium/envs/toy_text/taxi.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,6 @@ class TaxiEnv(Env):
- p - transition probability for the state.
- action_mask - if actions will cause a transition to a new state.

As taxi is not stochastic, the transition probability is always 1.0. Implementing
a transitional probability in line with the Dietterich paper ('The fickle taxi task')
is a TODO.

For some cases, taking an action will have no effect on the state of the episode.
In v0.25.0, ``info["action_mask"]`` contains a np.ndarray for each of the actions specifying
if the action will change the state.
Expand All @@ -143,12 +139,22 @@ class TaxiEnv(Env):
gym.make('Taxi-v3')
```

<a id="is_rainy"></a>`is_rainy=False`: If True, the cab moves in the intended direction with
a probability of 80%; otherwise it moves to the left or to the right of the intended direction,
with an equal probability of 10% for each side.

<a id="fickle_passenger"></a>`fickle_passenger=False`: If True, the passenger has a 30% chance of changing
destinations when the cab has moved one square away from the passenger's source location. Passenger fickleness
is only triggered on the first pickup followed by a successful movement. If the passenger is dropped off at the
source location and picked up again, it is not triggered again.

## References
<a id="taxi_ref"></a>[1] T. G. Dietterich, “Hierarchical Reinforcement Learning with the MAXQ Value Function Decomposition,”
Journal of Artificial Intelligence Research, vol. 13, pp. 227–303, Nov. 2000, doi: 10.1613/jair.639.

## Version History
* v3: Map Correction + Cleaner Domain Description, v0.25.0 action masking added to the reset and step information
- In Gymnasium `1.1.0` the `is_rainy` and `fickle_passenger` arguments were added to align with Dietterich, 2000
* v2: Disallow Taxi start location = goal location, Update Taxi observations in the rollout, Update Taxi reward threshold.
* v1: Remove (3,2) from locs, add passidx<4 check
* v0: Initial version release
Expand All @@ -159,7 +165,12 @@ class TaxiEnv(Env):
"render_fps": 4,
}

def __init__(self, render_mode: Optional[str] = None):
def __init__(
self,
render_mode: Optional[str] = None,
is_rainy: bool = False,
fickle_passenger: bool = False,
):
self.desc = np.asarray(MAP, dtype="c")

self.locs = locs = [(0, 0), (0, 4), (4, 0), (4, 3)]
Expand All @@ -186,20 +197,58 @@ def __init__(self, render_mode: Optional[str] = None):
for action in range(num_actions):
# defaults
new_row, new_col, new_pass_idx = row, col, pass_idx
left_pos = (new_row, new_col)
right_pos = (new_row, new_col)
reward = (
-1
) # default reward when there is no pickup/dropoff
terminated = False
taxi_loc = (row, col)

if action == 0:
new_row = min(row + 1, max_row)
elif action == 1:
new_row = max(row - 1, 0)
if action == 2 and self.desc[1 + row, 2 * col + 2] == b":":
new_col = min(col + 1, max_col)
elif action == 3 and self.desc[1 + row, 2 * col] == b":":
new_col = max(col - 1, 0)
moves = {
0: ((1, 0), (0, -1), (0, 1)), # Down
1: ((-1, 0), (0, -1), (0, 1)), # Up
2: ((0, 1), (1, 0), (-1, 0)), # Right
3: ((0, -1), (1, 0), (-1, 0)), # Left
}

# Check if movement is allowed
if (
action in {0, 1}
or (
action == 2
and self.desc[1 + row, 2 * col + 2] == b":"
)
or (action == 3 and self.desc[1 + row, 2 * col] == b":")
):
dr, dc = moves[action][0]
new_row, new_col = max(0, min(row + dr, max_row)), max(
0, min(col + dc, max_col)
)

left_dr, left_dc = moves[action][1]
left_row, left_col = max(
0, min(row + left_dr, max_row)
), max(0, min(col + left_dc, max_col))
if self.desc[1 + left_row, 2 * left_col + 2] == b":":
left_pos = (left_row, left_col)
else:
left_pos = (
new_row,
new_col,
) # Default to current position if not traversable

right_dr, right_dc = moves[action][2]
right_row, right_col = max(
0, min(row + right_dr, max_row)
), max(0, min(col + right_dc, max_col))
if self.desc[1 + right_row, 2 * right_col] == b":":
right_pos = (right_row, right_col)
else:
right_pos = (
new_row,
new_col,
) # Default to current position if not traversable
elif action == 4: # pickup
if pass_idx < 4 and taxi_loc == locs[pass_idx]:
new_pass_idx = 4
Expand All @@ -214,17 +263,36 @@ def __init__(self, render_mode: Optional[str] = None):
new_pass_idx = locs.index(taxi_loc)
else: # dropoff at wrong location
reward = -10
new_state = self.encode(
intended_state = self.encode(
new_row, new_col, new_pass_idx, dest_idx
)
self.P[state][action].append(
(1.0, new_state, reward, terminated)
)
if action <= 3 and is_rainy:
left_state = self.encode(
left_pos[0], left_pos[1], new_pass_idx, dest_idx
)
right_state = self.encode(
right_pos[0], right_pos[1], new_pass_idx, dest_idx
)
self.P[state][action].append(
(0.8, intended_state, reward, terminated)
)
self.P[state][action].append(
(0.1, left_state, -1, terminated)
)
self.P[state][action].append(
(0.1, right_state, -1, terminated)
)
else:
self.P[state][action].append(
(1.0, intended_state, reward, terminated)
)
self.initial_state_distrib /= self.initial_state_distrib.sum()
self.action_space = spaces.Discrete(num_actions)
self.observation_space = spaces.Discrete(num_states)

self.render_mode = render_mode
self.fickle_passenger = fickle_passenger
self.fickle_step = self.fickle_passenger and self.np_random.random() < 0.3

# pygame utils
self.window = None
Expand Down Expand Up @@ -289,9 +357,28 @@ def step(self, a):
transitions = self.P[self.s][a]
i = categorical_sample([t[0] for t in transitions], self.np_random)
p, s, r, t = transitions[i]
self.s = s
self.lastaction = a

shadow_row, shadow_col, shadow_pass_loc, shadow_dest_idx = self.decode(self.s)
taxi_row, taxi_col, pass_loc, _ = self.decode(s)

# If we are in the fickle step, the passenger has been in the vehicle for at least a step and this step the
# position changed
if (
self.fickle_passenger
and self.fickle_step
and shadow_pass_loc == 4
and (taxi_row != shadow_row or taxi_col != shadow_col)
):
self.fickle_step = False
possible_destinations = [
i for i in range(len(self.locs)) if i != shadow_dest_idx
]
dest_idx = self.np_random.choice(possible_destinations)
s = self.encode(taxi_row, taxi_col, pass_loc, dest_idx)

self.s = s

if self.render_mode == "human":
self.render()
# truncation=False as the time limit is handled by the `TimeLimit` wrapper added during `make`
Expand All @@ -306,6 +393,7 @@ def reset(
super().reset(seed=seed)
self.s = categorical_sample(self.initial_state_distrib, self.np_random)
self.lastaction = None
self.fickle_step = self.fickle_passenger and self.np_random.random() < 0.3
self.taxi_orientation = 0

if self.render_mode == "human":
Expand Down
77 changes: 77 additions & 0 deletions tests/envs/test_env_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,83 @@ def test_taxi_encode_decode():
state, _, _, _, _ = env.step(env.action_space.sample())


def test_taxi_is_rainy():
    """Check transition probabilities in ``env.P`` and in ``step`` with and without rain."""
    # With rain, every movement action (0-3) is split into the intended move (0.8)
    # and two side moves (0.1 each); pickup/dropoff (4, 5) keep a single outcome.
    env = TaxiEnv(is_rainy=True)
    for state_dict in env.P.values():
        for action, transitions in state_dict.items():
            probs = [t[0] for t in transitions]
            if action <= 3:
                # A sum of floats is compared approximately rather than bit-for-bit.
                assert sum(probs) == pytest.approx(1.0)
                assert set(probs) == {0.8, 0.1}
            else:
                assert probs == [1.0]

    env.reset()
    _, _, _, _, info = env.step(0)
    assert info["prob"] in {0.8, 0.1}

    # Without rain, every action has exactly one outcome with probability 1.
    env = TaxiEnv(is_rainy=False)
    for state_dict in env.P.values():
        for transitions in state_dict.values():
            assert [t[0] for t in transitions] == [1.0]

    env.reset()
    _, _, _, _, info = env.step(0)
    assert info["prob"] == 1.0


def test_taxi_disallowed_transitions():
    """Check that no transition, with or without rain, carries the taxi through a wall."""
    # Pairs of (row, col) grid cells, in the coordinates returned by ``env.decode``,
    # that are separated by a wall. Both travel directions are listed.
    # NOTE(review): wall positions are transcribed from the standard Taxi map;
    # confirm against MAP in taxi.py if the layout ever changes.
    disallowed_transitions = {
        ((0, 1), (0, 2)),
        ((0, 2), (0, 1)),
        ((1, 1), (1, 2)),
        ((1, 2), (1, 1)),
        ((3, 0), (3, 1)),
        ((3, 1), (3, 0)),
        ((3, 2), (3, 3)),
        ((3, 3), (3, 2)),
        ((4, 0), (4, 1)),
        ((4, 1), (4, 0)),
        ((4, 2), (4, 3)),
        ((4, 3), (4, 2)),
    }
    # Guard against a vacuous test: a single move changes the column by at most
    # one, so only horizontally adjacent cells can ever appear as a transition.
    for (wall_row_a, wall_col_a), (wall_row_b, wall_col_b) in disallowed_transitions:
        assert wall_row_a == wall_row_b
        assert abs(wall_col_a - wall_col_b) == 1

    for rain in (True, False):
        env = TaxiEnv(is_rainy=rain)
        for state, state_dict in env.P.items():
            start_row, start_col, _, _ = env.decode(state)
            for transitions in state_dict.values():
                for transition in transitions:
                    end_row, end_col, _, _ = env.decode(transition[1])
                    assert (
                        (start_row, start_col),
                        (end_row, end_col),
                    ) not in disallowed_transitions


def test_taxi_fickle_passenger():
    """Check that a fickle passenger changes destination on the first move after pickup."""
    env = TaxiEnv(fickle_passenger=True)
    env.reset(seed=43)
    # ``reset`` arms the fickle step only when its random draw is below 0.3. Arm it
    # explicitly so the test does not depend on what a particular seed produces.
    env.fickle_step = True

    _, _, pass_idx, orig_dest_idx = env.decode(env.s)
    pass_row, pass_col = env.locs[pass_idx]
    # force taxi to passenger location
    env.s = env.encode(pass_row, pass_col, pass_idx, orig_dest_idx)
    # pick up the passenger
    env.step(4)
    # Pickup locations sit on the top or bottom row: move down (0) from the top
    # row, otherwise move up (1), so the taxi position is guaranteed to change.
    state, *_ = env.step(0 if pass_row == 0 else 1)
    _, _, _, dest_idx = env.decode(state)
    # check that passenger has changed their destination
    assert orig_dest_idx != dest_idx
    # the fickle step is consumed once it has fired
    assert not env.fickle_step


@pytest.mark.parametrize(
"env_name",
["Acrobot-v1", "CartPole-v1", "MountainCar-v0", "MountainCarContinuous-v0"],
Expand Down
Loading