assert total_steps >= num_timesteps #52

Open
kassasin opened this issue Feb 11, 2025 · 5 comments


@kassasin

kassasin commented Feb 11, 2025

Hello, I recently ran into the following problem:

Cell In[10], line 1
----> 1 make_inference_fn, params, _ = train_fn(
      2     environment=env,
      3     eval_env=registry.load(env_name, config=env_cfg),
      4     wrap_env_fn=wrapper.wrap_for_brax_training,
      5 )
      6 if len(times) > 1: 
      7   print(f"time to jit: {times[1] - times[0]}")

File ~/miniconda3/envs/mujoco_playground/lib/python3.10/site-packages/brax/training/agents/ppo/train.py:634, in train(environment, num_timesteps, episode_length, wrap_env, wrap_env_fn, action_repeat, num_envs, max_devices_per_host, num_eval_envs, learning_rate, entropy_cost, discounting, seed, unroll_length, batch_size, num_minibatches, num_updates_per_batch, num_evals, num_resets_per_eval, normalize_observations, reward_scaling, clipping_epsilon, gae_lambda, deterministic_eval, network_factory, progress_fn, normalize_advantage, eval_env, policy_params_fn, randomization_fn, restore_checkpoint_path, max_grad_norm, madrona_backend, augment_pixels)
    631     policy_params_fn(current_step, make_policy, params)
    633 total_steps = current_step
--> 634 assert total_steps >= num_timesteps
    636 # If there was no mistakes the training_state should still be identical on all
    637 # devices.
    638 pmap.assert_is_replicated(training_state)

AssertionError:  

I think it may be related to the PPO parameters. These are my PPO params:

action_repeat: 1
batch_size: 512
clipping_epsilon: 0.2
discounting: 0.97
entropy_cost: 0.005
episode_length: 5000
learning_rate: 0.0003
max_grad_norm: 1.0
network_factory:
  policy_hidden_layer_sizes: &id001 !!python/tuple
  - 512
  - 256
  - 128
  policy_obs_key: state
  value_hidden_layer_sizes: *id001
  value_obs_key: privileged_state
normalize_observations: true
num_envs: 8192
num_evals: 120
num_minibatches: 32
num_resets_per_eval: 4
num_timesteps: 3000000000
num_updates_per_batch: 5
reward_scaling: 1.0
unroll_length: 100
@vincentzhang
Contributor

Could you print the value for total_steps?

@kassasin
Author

kassasin commented Feb 13, 2025

> Could you print the value for total_steps?

Sorry, I have already restarted. I can run a test and report the details.

@vincentzhang
Contributor

Yeah, thanks. It would be helpful if you could reproduce it.

@kassasin
Author

> Yeah, thanks. It would be helpful if you could reproduce it.

Sorry for the late reply. I finally reproduced the problem.
These are my PPO params:

action_repeat: 1
batch_size: 512
clipping_epsilon: 0.2
discounting: 0.98
entropy_cost: 0.005
episode_length: 4000
learning_rate: 0.0003
max_grad_norm: 1.0
network_factory:
  policy_hidden_layer_sizes: &id001 !!python/tuple
  - 512
  - 256
  - 128
  policy_obs_key: state
  value_hidden_layer_sizes: *id001
  value_obs_key: privileged_state
normalize_observations: true
num_envs: 16384
num_evals: 100
num_minibatches: 32
num_resets_per_eval: 4
num_timesteps: 2200000000
num_updates_per_batch: 5
reward_scaling: 1.0
unroll_length: 80

These are all the recorded time steps:

[0,
 26214400,
 52428800,
 78643200,
 104857600,
 131072000,
 157286400,
 183500800,
 209715200,
 235929600,
 262144000,
 288358400,
 314572800,
 340787200,
 367001600,
 393216000,
 419430400,
 445644800,
 471859200,
 498073600,
 524288000,
 550502400,
 576716800,
 602931200,
 629145600,
 655360000,
 681574400,
 707788800,
 734003200,
 760217600,
 786432000,
 812646400,
 838860800,
 865075200,
 891289600,
 917504000,
 943718400,
 969932800,
 996147200,
 1022361600,
 1048576000,
 1074790400,
 1101004800,
 1127219200,
 1153433600,
 1179648000,
 1205862400,
 1232076800,
 1258291200,
 1284505600,
 1310720000,
 1336934400,
 1363148800,
 1389363200,
 1415577600,
 1441792000,
 1468006400,
 1494220800,
 1520435200,
 1546649600,
 1572864000,
 1599078400,
 1625292800,
 1651507200,
 1677721600,
 1703936000,
 1730150400,
 1756364800,
 1782579200,
 1808793600,
 1835008000,
 1861222400,
 1887436800,
 1913651200,
 1939865600,
 1966080000,
 1992294400,
 2018508800,
 2044723200,
 2070937600,
 2097152000,
 2123366400,
 -2145386496,
 -2119172096,
 -2092957696,
 -2066743296,
 -2040528896,
 -2014314496,
 -1988100096,
 -1961885696,
 -1935671296,
 -1909456896,
 -1883242496,
 -1857028096,
 -1830813696,
 -1804599296,
 -1778384896,
 -1752170496,
 -1725956096,
 -1699741696]

I think you are right.
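
The constant gap of 26214400 between consecutive entries in the list above can be reproduced from the posted parameters. The following is a minimal sketch, assuming the trainer counts environment steps per evaluation interval roughly as shown. The helper `steps_per_eval_interval` and the formula inside it are illustrative assumptions, not code taken from brax, although they do match the logged values.

```python
import math

# PPO settings copied from the comment above.
params = dict(
    num_timesteps=2_200_000_000,
    num_evals=100,
    num_resets_per_eval=4,
    batch_size=512,
    num_minibatches=32,
    unroll_length=80,
    action_repeat=1,
)


def steps_per_eval_interval(p):
    """Assumed accounting: env steps consumed between two consecutive evals."""
    env_steps_per_training_step = (
        p["batch_size"] * p["unroll_length"] * p["num_minibatches"] * p["action_repeat"]
    )
    intervals = max(p["num_evals"] - 1, 1)      # evals after the initial one
    resets = max(p["num_resets_per_eval"], 1)
    # Round up so the run covers at least num_timesteps in total.
    training_steps_per_interval = math.ceil(
        p["num_timesteps"] / (intervals * env_steps_per_training_step * resets)
    )
    increment = training_steps_per_interval * env_steps_per_training_step * resets
    return increment, intervals


increment, intervals = steps_per_eval_interval(params)
print(increment)                           # 26214400
print(increment * intervals)               # 2595225600
print(increment * intervals > 2**31 - 1)   # True
```

Under this assumption the run ends at 2595225600 steps, which is above the requested 2200000000 because of the rounding up, and also above 2147483647, the largest value a signed 32-bit integer can hold.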

@vincentzhang
Contributor

Thanks for the additional information. It does indeed look like an integer overflow problem. There's a suggested fix in google/brax#578. I'll give it a try and post an update here if there is any progress.
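
The negative values in the logged list are consistent with a signed 32-bit counter wrapping around. The sketch below is an illustration in plain Python, assuming the step counter is stored as int32 (JAX's default integer type unless 64-bit mode is enabled). The helper `wrap_int32` is hypothetical and only emulates two's-complement wraparound. It is not the fix proposed in google/brax#578.

```python
INT32_MAX = 2**31 - 1


def wrap_int32(value):
    """Interpret an integer as a two's-complement signed 32-bit value."""
    return (value + 2**31) % 2**32 - 2**31


increment = 26_214_400          # gap between consecutive logged values
num_timesteps = 2_200_000_000   # target from the second set of params

exact = [i * increment for i in range(100)]    # unbounded Python ints
wrapped = [wrap_int32(v) for v in exact]       # what an int32 counter would hold

first_bad = next(i for i, v in enumerate(wrapped) if v < 0)
print(exact[first_bad - 1], wrapped[first_bad])  # 2123366400 -2145386496
print(wrapped[-1])                               # -1699741696
print(wrapped[-1] >= num_timesteps)              # False: the assertion fails
print(exact[-1] >= num_timesteps)                # True with an unbounded counter

# The first config in this thread asks for 3000000000 steps, which an
# int32 counter could never reach at all.
print(3_000_000_000 > INT32_MAX)                 # True
```

The wrapped sequence matches the reported list, including the jump from 2123366400 to -2145386496 and the final value -1699741696, so `assert total_steps >= num_timesteps` fails even though enough steps were actually run.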
