# collect.py
"""自我对弈收集数据"""
import random
from collections import deque
import copy
import os
import pickle
import time
from game import Board, Game, move_action2move_id, move_id2move_action, flip_map
from mcts import MCTSPlayer
from config import CONFIG
if CONFIG['use_redis']:
import my_redis, redis
import zip_array
if CONFIG['use_frame'] == 'paddle':
from paddle_net import PolicyValueNet
elif CONFIG['use_frame'] == 'pytorch':
from pytorch_net import PolicyValueNet
else:
print('暂不支持您选择的框架')
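# CONFIG is expected to provide at least the keys used in this file. The
# values below are an illustrative sketch only (assumed, not the project's
# actual defaults; see config.py for the real settings):
#
#   CONFIG = {
#       'use_redis': False,            # share the buffer via redis instead of a pickle file
#       'use_frame': 'pytorch',        # 'paddle' or 'pytorch'
#       'play_out': 1200,              # MCTS simulations per move
#       'c_puct': 5,                   # exploration weight in PUCT
#       'buffer_size': 100000,         # experience-buffer capacity
#       'pytorch_model_path': 'current_policy.pkl',
#       'paddle_model_path': 'current_policy.model',
#       'train_data_buffer_path': 'train_data_buffer.pkl',
#   }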

# The complete self-play data-collection pipeline
class CollectPipeline:

    def __init__(self, init_model=None):
        # init_model is accepted by the callers below; the model itself is
        # (re)loaded before every game via load_model()
        # Chess logic and board
        self.board = Board()
        self.game = Game(self.board)
        # Self-play parameters
        self.temp = 1  # temperature
        self.n_playout = CONFIG['play_out']  # number of MCTS simulations per move
        self.c_puct = CONFIG['c_puct']  # weight of the exploration term u
        self.buffer_size = CONFIG['buffer_size']  # experience-buffer capacity
        self.data_buffer = deque(maxlen=self.buffer_size)
        self.iters = 0
        if CONFIG['use_redis']:
            self.redis_cli = my_redis.get_redis_cli()
    # Load the latest model from the trainer process
    def load_model(self):
        if CONFIG['use_frame'] == 'paddle':
            model_path = CONFIG['paddle_model_path']
        elif CONFIG['use_frame'] == 'pytorch':
            model_path = CONFIG['pytorch_model_path']
        else:
            raise ValueError('The selected framework is not supported yet')
        try:
            self.policy_value_net = PolicyValueNet(model_file=model_path)
            print('Loaded the latest model')
        except Exception:
            # No saved model yet (or it is mid-write): start from scratch
            self.policy_value_net = PolicyValueNet()
            print('Loaded the initial model')
        # policy_value_fn maps a board position to move priors and a value
        # estimate, which MCTSPlayer uses to guide its tree search
        self.mcts_player = MCTSPlayer(self.policy_value_net.policy_value_fn,
                                      c_puct=self.c_puct,
                                      n_playout=self.n_playout,
                                      is_selfplay=1)
    def get_equi_data(self, play_data):
        """Mirror every position left-right, doubling the dataset (and thus
        the training data gained per self-play game)."""
        extend_data = []
        # Each sample is (state, mcts_prob, winner); state has shape [9, 10, 9]
        # (channels, rows, columns)
        for state, mcts_prob, winner in play_data:
            # Original sample
            extend_data.append(zip_array.zip_state_mcts_prob((state, mcts_prob, winner)))
            # Horizontally flipped sample: reverse the column axis on a copy,
            # so no read ever sees an already-flipped column
            state_flip = state[:, :, ::-1].copy()
            mcts_prob_flip = copy.deepcopy(mcts_prob)
            for i in range(len(mcts_prob_flip)):
                # Give each move the probability of its mirrored move
                mcts_prob_flip[i] = mcts_prob[move_action2move_id[flip_map(move_id2move_action[i])]]
            extend_data.append(zip_array.zip_state_mcts_prob((state_flip, mcts_prob_flip, winner)))
        return extend_data
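
    # A quick self-contained sanity check of the flip above (illustrative
    # sketch; assumes numpy, which the board states are built from):
    #
    #   import numpy as np
    #   s = np.arange(9 * 10 * 9).reshape(9, 10, 9)
    #   f = s[:, :, ::-1].copy()
    #   assert (f[:, :, 0] == s[:, :, 8]).all()  # column 0 <-> column 8
    #   assert (f[:, :, ::-1] == s).all()        # flipping twice is identity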
    def collect_selfplay_data(self, n_games=1):
        # Collect self-play data
        for i in range(n_games):
            self.load_model()  # reload the latest model from the trainer
            winner, play_data = self.game.start_self_play(self.mcts_player, temp=self.temp, is_shown=False)
            play_data = list(play_data)
            self.episode_len = len(play_data)
            # Augment the data with mirrored positions
            play_data = self.get_equi_data(play_data)
            if CONFIG['use_redis']:
                while True:
                    try:
                        for d in play_data:
                            self.redis_cli.rpush('train_data_buffer', pickle.dumps(d))
                        # incr returns the post-increment value as an int
                        self.iters = self.redis_cli.incr('iters')
                        print('Data stored')
                        break
                    except Exception:
                        print('Failed to store data, retrying')
                        time.sleep(1)
            else:
                if os.path.exists(CONFIG['train_data_buffer_path']):
                    while True:
                        try:
                            with open(CONFIG['train_data_buffer_path'], 'rb') as data_file:
                                data_dict = pickle.load(data_file)
                            self.data_buffer = deque(maxlen=self.buffer_size)
                            self.data_buffer.extend(data_dict['data_buffer'])
                            self.iters = data_dict['iters']
                            del data_dict
                            self.iters += 1
                            self.data_buffer.extend(play_data)
                            print('Loaded existing data buffer')
                            break
                        except Exception:
                            # The trainer may be writing the file; wait and retry
                            time.sleep(30)
                else:
                    self.data_buffer.extend(play_data)
                    self.iters += 1
                data_dict = {'data_buffer': self.data_buffer, 'iters': self.iters}
                with open(CONFIG['train_data_buffer_path'], 'wb') as data_file:
                    pickle.dump(data_dict, data_file)
        return self.iters
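
    # On the trainer side, the buffer written above is read back symmetrically.
    # A minimal sketch (assumed for illustration; not the project's actual
    # trainer code):
    #
    #   if CONFIG['use_redis']:
    #       cli = my_redis.get_redis_cli()
    #       raw = cli.lrange('train_data_buffer', -CONFIG['buffer_size'], -1)
    #       samples = [pickle.loads(r) for r in raw]
    #   else:
    #       with open(CONFIG['train_data_buffer_path'], 'rb') as f:
    #           samples = pickle.load(f)['data_buffer']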

    def run(self):
        """Start collecting data"""
        try:
            while True:
                iters = self.collect_selfplay_data()
                print('batch i: {}, episode_len: {}'.format(
                    iters, self.episode_len))
        except KeyboardInterrupt:
            print('\n\rquit')

if __name__ == '__main__':
    if CONFIG['use_frame'] == 'paddle':
        collecting_pipeline = CollectPipeline(init_model='current_policy.model')
        collecting_pipeline.run()
    elif CONFIG['use_frame'] == 'pytorch':
        collecting_pipeline = CollectPipeline(init_model='current_policy.pkl')
        collecting_pipeline.run()
    else:
        print('The selected framework is not supported yet')
    print('Data collection finished')
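
# In practice several copies of this script can run concurrently alongside the
# trainer: redis list pushes are atomic, so multiple collectors may append to
# the same 'train_data_buffer' key, e.g.
#
#   python collect.py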