
Handle option masking correctly when computing TD targets. #51

Open
Ueva opened this issue May 7, 2024 · 0 comments
Ueva (Owner) commented May 7, 2024

```python
# Compute TD targets.
next_state_values = torch.zeros(self.batch_size, dtype=torch.float32)
with torch.no_grad():
    next_state_values[non_terminal_mask] = self.target(non_terminal_next_states).max(1).values
targets = reward_batch + self.gamma * next_state_values
```

Currently, the values of masked (i.e. unavailable) options are not excluded when computing TD targets for Macro-DQN updates: the `max` above is taken over all options, including options that cannot be initiated in the next state.

We need a function that produces an option mask for given states, usable here and anywhere else option masking is needed. Ideally, it should accept a batch of states and produce a batch of option masks. Since it may be called many times per time step, it should be performant.
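A minimal sketch of what such a batched helper and its use in the TD-target computation might look like. Everything here beyond the snippet above is hypothetical: the `compute_option_masks` name, the assumption that each option exposes a batched `initiation(states)` predicate, and the `ToyOption` class used purely for illustration.

```python
import torch


def compute_option_masks(states, options):
    # Hypothetical batched helper: returns a boolean tensor of shape
    # (batch_size, num_options) where entry (i, j) is True iff option j
    # can be initiated in states[i]. Assumes each option exposes a
    # batched `initiation(states)` predicate returning a bool tensor.
    return torch.stack([option.initiation(states) for option in options], dim=1)


# Illustrative-only option whose initiation set depends on the sign of
# the first state feature.
class ToyOption:
    def __init__(self, positive_only):
        self.positive_only = positive_only

    def initiation(self, states):
        if self.positive_only:
            return states[:, 0] > 0
        return torch.ones(states.shape[0], dtype=torch.bool)


options = [ToyOption(True), ToyOption(False)]
states = torch.tensor([[1.0], [-1.0]])
mask = compute_option_masks(states, options)  # shape (2, 2)

# Masked max for TD targets: Q-values of unavailable options are set to
# -inf before the max, so they can never be chosen as the bootstrap value.
q_values = torch.tensor([[5.0, 2.0], [7.0, 3.0]])
masked_q = q_values.masked_fill(~mask, float("-inf"))
next_state_values = masked_q.max(1).values
```

In the Macro-DQN update, the `masked_fill`-then-`max` step would replace the plain `.max(1).values` on the target network's output, using masks computed for `non_terminal_next_states`.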

@Ueva Ueva added the bug Something isn't working label May 7, 2024
@Ueva Ueva self-assigned this May 7, 2024