teacher_evaluate.py
"""Evaluate a trained teacher agent in the specified environment."""
from environment import make_envs
from common import get_config, space2shape
from rl_base.representation import Teacher_Encoder
from rl_base.policy import ActorCriticPolicy as Teacher_Policy
from rl_base.teacher_agent import Teacher_Agent
import argparse
import torch
import os

# Point rendering at display :1 (e.g. a virtual framebuffer on a headless machine).
os.environ['DISPLAY'] = ":1"

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="CoG-Navigation")
    parser.add_argument("--seed", type=int, default=7820)
    parser.add_argument("--config_path", type=str, default="./config/evaluate_config.yaml")
    args = parser.parse_known_args()[0]
    return args

if __name__ == "__main__":
    args = get_args()
    config = get_config(args.config_path)
    envs = make_envs(args.env_id, args.seed, config)
    observation_space = envs.observation_space
    action_space = envs.action_space

    # Shared encoder mapping raw observations to latent features for the teacher.
    representation = Teacher_Encoder(space2shape(observation_space),
                                     None,
                                     torch.nn.init.orthogonal_,
                                     torch.nn.LeakyReLU,
                                     config.device)
    # Actor-critic policy built on top of the shared encoder.
    policy = Teacher_Policy(action_space,
                            representation,
                            config.actor_hidden_size,
                            config.critic_hidden_size,
                            initialize=torch.nn.init.orthogonal_,
                            activation=torch.nn.Tanh,
                            device=config.device)
    optimizer = torch.optim.Adam(policy.parameters(), config.learning_rate, eps=1e-5)
    # Linearly decay the learning rate to 25% of its initial value over the
    # total number of gradient updates implied by the training schedule.
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=0.25,
        total_iters=int(config.training_steps * config.nepoch * config.nminibatch / config.nsteps))
    agent = Teacher_Agent(config, envs, policy, optimizer, lr_scheduler, config.device)
    # Run 100 evaluation episodes with the checkpoint named in the config.
    agent.test(100, config.model_name)
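
# Example invocation (flags mirror the defaults in get_args above; assumes the
# config file and a trained checkpoint named in it already exist):
#   python teacher_evaluate.py --env_id CoG-Navigation --seed 7820 \
#       --config_path ./config/evaluate_config.yaml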