-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDDPG_test.py
65 lines (51 loc) · 1.73 KB
/
DDPG_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import math
import random
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt
import pprint
import highway_env
from DDPG_net import *
import json
from collections import defaultdict
# ---- Evaluation setup -------------------------------------------------------
# Select GPU when available so the loaded policy runs on the right device.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Custom scenario registered by the `highway_env` import above.
env = gym.make("lvxinfei-v0")
env.reset()

# `map_location=device` lets a checkpoint saved on GPU load on a CPU-only
# host (the original call crashed there); previously `device` was unused.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
ddpg = torch.load('./weights_test/ddpg_net0.pth', map_location=device)

max_steps = 1                 # number of evaluation episodes to run
rewards = []                  # unused accumulator, kept for compatibility
batch_size = 32               # unused here (evaluation does no training)
speed = []                    # unused; speeds are logged via info_out instead
info_out = defaultdict(list)  # per-step telemetry, grouped by field name
# ---- Evaluation loop --------------------------------------------------------
# Telemetry fields copied from the env's `info` dict on every step.
# Keys must match exactly what the environment emits.
INFO_KEYS = ('speed', 'x', 'y', 'vx', 'vy', 'sin_h', 'cos_h',
             'vehicle heading', 'road heading')

with torch.no_grad():  # pure inference: no gradients needed
    for step in range(max_steps):
        print("================第{}回合======================================".format(step+1))
        # Flatten the observation into the 1-D vector the policy expects.
        state = torch.flatten(torch.tensor(env.reset()))
        done = False
        while not done:
            action = ddpg.policy_net.get_action(state)
            next_state, reward, done, info = env.step(action)
            # Store selected step info (replaces nine copy-pasted
            # append lines; keys and order are unchanged).
            for key in INFO_KEYS:
                info_out[key].append(info[key])
            # print(info)
            # NOTE(review): `reward` is discarded; the module-level
            # `rewards` list is never filled — confirm whether that
            # accumulation was intended.
            state = torch.flatten(torch.tensor(next_state))
            env.render()
env.close()
# with open("./JSON/v1.json", 'w', encoding='UTF-8') as f:
# f.write(json.dumps(info_out))