Skip to content

Commit 110235a

Browse files
authored
Merge pull request LeelaChessZero#8 from Tilps/rescore_tb
Rescore tb
2 parents 8346850 + 2400643 commit 110235a

31 files changed

+1726
-255
lines changed

libs/lczero-common

meson.build

+1-1
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ if not get_option('pext')
439439
add_project_arguments('-DNO_PEXT', language : 'cpp')
440440
endif
441441

442-
executable('lc0', 'src/main.cc',
442+
executable('rescorer', 'src/main.cc',
443443
files, include_directories: includes, dependencies: deps, install: true)
444444

445445

src/main.cc

+6-1
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,13 @@ int main(int argc, const char** argv) {
4646
CommandLine::RegisterMode("uci", "(default) Act as UCI engine");
4747
CommandLine::RegisterMode("selfplay", "Play games with itself");
4848
CommandLine::RegisterMode("benchmark", "Quick benchmark");
49+
CommandLine::RegisterMode("rescore",
50+
"Update data scores with tablebase support");
4951

50-
if (CommandLine::ConsumeCommand("selfplay")) {
52+
if (CommandLine::ConsumeCommand("rescore")) {
53+
RescoreLoop loop;
54+
loop.RunLoop();
55+
} else if (CommandLine::ConsumeCommand("selfplay")) {
5156
// Selfplay mode.
5257
SelfPlayLoop loop;
5358
loop.RunLoop();

src/mcts/node.cc

+16-6
Original file line numberDiff line numberDiff line change
@@ -297,19 +297,21 @@ uint64_t ReverseBitsInBytes(uint64_t v) {
297297
}
298298
} // namespace
299299

300-
V3TrainingData Node::GetV3TrainingData(
301-
GameResult game_result, const PositionHistory& history,
302-
FillEmptyHistory fill_empty_history) const {
303-
V3TrainingData result;
300+
V4TrainingData Node::GetV4TrainingData(GameResult game_result,
301+
const PositionHistory& history,
302+
FillEmptyHistory fill_empty_history,
303+
float best_eval) const {
304+
V4TrainingData result;
304305

305306
// Set version.
306-
result.version = 3;
307+
result.version = 4;
307308

308309
// Populate probabilities.
309310
float total_n = static_cast<float>(GetChildrenVisits());
310311
// Prevent garbage/invalid training data from being uploaded to server.
311312
if (total_n <= 0.0f) throw Exception("Search generated invalid data!");
312-
std::memset(result.probabilities, 0, sizeof(result.probabilities));
313+
// Set illegal moves to have -1 probability.
314+
std::memset(result.probabilities, -1, sizeof(result.probabilities));
313315
for (const auto& child : Edges()) {
314316
result.probabilities[child.edge()->GetMove().as_nn_index()] =
315317
child.GetN() / total_n;
@@ -343,6 +345,14 @@ V3TrainingData Node::GetV3TrainingData(
343345
result.result = 0;
344346
}
345347

348+
// Aggregate evaluation Q.
349+
result.root_q = -GetQ();
350+
result.best_q = best_eval;
351+
352+
// Draw probability of WDL head.
353+
result.root_d = 0.0;
354+
result.best_d = 0.0;
355+
346356
return result;
347357
}
348358

src/mcts/node.h

+3-2
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,10 @@ class Node {
204204
// in depth parameter, and returns true if it was indeed updated.
205205
bool UpdateFullDepth(uint16_t* depth);
206206

207-
V3TrainingData GetV3TrainingData(GameResult result,
207+
V4TrainingData GetV4TrainingData(GameResult result,
208208
const PositionHistory& history,
209-
FillEmptyHistory fill_empty_history) const;
209+
FillEmptyHistory fill_empty_history,
210+
float best_eval) const;
210211

211212
// Returns range for iterating over edges.
212213
ConstIterator Edges() const;

src/neural/blas/network_blas.cc

+100-25
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "neural/network_legacy.h"
2727
#include "neural/shared/activation.h"
2828
#include "neural/shared/batchnorm.h"
29+
#include "neural/shared/policy_map.h"
2930
#include "neural/shared/winograd_filter.h"
3031

3132
#include <algorithm>
@@ -38,7 +39,8 @@ namespace {
3839

3940
class BlasComputation : public NetworkComputation {
4041
public:
41-
BlasComputation(const LegacyWeights& weights, const size_t max_batch_size);
42+
BlasComputation(const LegacyWeights& weights, const size_t max_batch_size,
43+
const bool conv_policy);
4244

4345
virtual ~BlasComputation() {}
4446

@@ -65,12 +67,17 @@ class BlasComputation : public NetworkComputation {
6567
static constexpr auto kWidth = 8;
6668
static constexpr auto kHeight = 8;
6769
static constexpr auto kSquares = kWidth * kHeight;
70+
static constexpr auto kPolicyOutputs = 1858;
71+
// Number of used planes with convolutional policy.
72+
// The real number of planes is higher because of padding.
73+
static constexpr auto kPolicyUsedPlanes = 73;
6874

6975
const LegacyWeights& weights_;
7076
size_t max_batch_size_;
7177
std::vector<InputPlanes> planes_;
7278
std::vector<std::vector<float>> policies_;
7379
std::vector<float> q_values_;
80+
bool conv_policy_;
7481
};
7582

7683
class BlasNetwork : public Network {
@@ -79,7 +86,8 @@ class BlasNetwork : public Network {
7986
virtual ~BlasNetwork(){};
8087

8188
std::unique_ptr<NetworkComputation> NewComputation() override {
82-
return std::make_unique<BlasComputation>(weights_, max_batch_size_);
89+
return std::make_unique<BlasComputation>(weights_, max_batch_size_,
90+
conv_policy_);
8391
}
8492

8593
private:
@@ -88,21 +96,24 @@ class BlasNetwork : public Network {
8896

8997
LegacyWeights weights_;
9098
size_t max_batch_size_;
99+
bool conv_policy_;
91100
};
92101

93102
BlasComputation::BlasComputation(const LegacyWeights& weights,
94-
const size_t max_batch_size)
103+
const size_t max_batch_size,
104+
const bool conv_policy)
95105
: weights_(weights),
96106
max_batch_size_(max_batch_size),
97107
policies_(0),
98-
q_values_(0) {}
108+
q_values_(0),
109+
conv_policy_(conv_policy) {}
99110

100111
void BlasComputation::ComputeBlocking() {
101112
// Retrieve network key dimensions from the weights structure.
102113
const auto num_value_channels = weights_.ip1_val_b.size();
103114
const auto num_value_input_planes = weights_.value.bn_means.size();
104115
const auto num_policy_input_planes = weights_.policy.bn_means.size();
105-
const auto num_output_policy = weights_.ip_pol_b.size();
116+
const auto num_output_policy = kPolicyOutputs;
106117
const auto output_channels = weights_.input.biases.size();
107118

108119
// max_channels is the maximum number of input channels of any
@@ -204,29 +215,60 @@ void BlasComputation::ComputeBlocking() {
204215
}
205216
}
206217

207-
Convolution1::Forward(batch_size, output_channels, num_policy_input_planes,
208-
conv_out, weights_.policy.weights.data(),
209-
policy_buffer.data());
218+
if (conv_policy_) {
219+
// Need to preserve conv_out which is used for value head
220+
convolve3.Forward(batch_size, output_channels, output_channels, conv_out,
221+
&weights_.policy1.weights[0], res);
222+
223+
ApplyBatchNormalization(batch_size, output_channels, &res[0],
224+
weights_.policy1.bn_means.data(),
225+
weights_.policy1.bn_stddivs.data());
226+
227+
convolve3.Forward(batch_size, output_channels, num_policy_input_planes,
228+
res, &weights_.policy.weights[0], policy_buffer.data());
229+
230+
ApplyBatchNormalization(
231+
batch_size, num_policy_input_planes, &policy_buffer.data()[0],
232+
weights_.policy.bn_means.data(), weights_.policy.bn_stddivs.data(),
233+
nullptr, false);
234+
235+
// Mapping from convolutional policy to lc0 policy
236+
for (auto batch = size_t{0}; batch < batch_size; batch++) {
237+
for (auto i = 0; i < kPolicyUsedPlanes * kSquares; i++) {
238+
auto j = kConvPolicyMap[i];
239+
if (j >= 0) {
240+
output_pol[batch * num_output_policy + j] =
241+
policy_buffer[batch * num_policy_input_planes * kSquares + i];
242+
}
243+
}
244+
}
245+
246+
} else {
247+
Convolution1::Forward(
248+
batch_size, output_channels, num_policy_input_planes, conv_out,
249+
weights_.policy.weights.data(), policy_buffer.data());
250+
251+
ApplyBatchNormalization(
252+
batch_size, num_policy_input_planes, &policy_buffer[0],
253+
weights_.policy.bn_means.data(), weights_.policy.bn_stddivs.data());
254+
255+
FullyConnectedLayer::Forward1D(
256+
batch_size, num_policy_input_planes * kSquares, num_output_policy,
257+
policy_buffer.data(), weights_.ip_pol_w.data(),
258+
weights_.ip_pol_b.data(),
259+
false, // Relu Off
260+
output_pol.data());
261+
}
210262

263+
// Value head
211264
Convolution1::Forward(batch_size, output_channels, num_value_input_planes,
212265
conv_out, weights_.value.weights.data(),
213266
value_buffer.data());
214267

215-
ApplyBatchNormalization(batch_size, num_policy_input_planes,
216-
&policy_buffer[0], weights_.policy.bn_means.data(),
217-
weights_.policy.bn_stddivs.data());
218-
219268
ApplyBatchNormalization(batch_size, num_value_input_planes,
220269
&value_buffer[0], weights_.value.bn_means.data(),
221270
weights_.value.bn_stddivs.data());
222271

223-
FullyConnectedLayer::Forward1D(
224-
batch_size, num_policy_input_planes * kSquares, num_output_policy,
225-
policy_buffer.data(), weights_.ip_pol_w.data(),
226-
weights_.ip_pol_b.data(),
227-
false, // Relu Off
228-
output_pol.data());
229-
230272
FullyConnectedLayer::Forward1D(
231273
batch_size, num_value_input_planes * kSquares, num_value_channels,
232274
value_buffer.data(), weights_.ip1_val_w.data(),
@@ -238,8 +280,8 @@ void BlasComputation::ComputeBlocking() {
238280
std::vector<float> policy(num_output_policy);
239281

240282
// Get the moves
241-
SoftmaxActivation(
242-
num_output_policy, &output_pol[j * num_output_policy], policy.data());
283+
SoftmaxActivation(num_output_policy, &output_pol[j * num_output_policy],
284+
policy.data());
243285

244286
policies_.emplace_back(std::move(policy));
245287

@@ -268,6 +310,9 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, const OptionsDict& options)
268310
max_batch_size_ =
269311
static_cast<size_t>(options.GetOrDefault<int>("batch_size", 256));
270312

313+
conv_policy_ = file.format().network_format().policy() ==
314+
pblczero::NetworkFormat::POLICY_CONVOLUTION;
315+
271316
if (max_batch_size_ > kHardMaxBatchSize) {
272317
max_batch_size_ = kHardMaxBatchSize;
273318
}
@@ -298,8 +343,24 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, const OptionsDict& options)
298343
conv2.InvertStddev();
299344
}
300345

301-
weights_.policy.OffsetMeans();
302-
weights_.policy.InvertStddev();
346+
if (conv_policy_) {
347+
weights_.policy1.OffsetMeans();
348+
weights_.policy1.InvertStddev();
349+
350+
weights_.policy1.weights =
351+
WinogradFilterTransformF(weights_.policy1.weights, channels, channels);
352+
auto pol_channels = weights_.policy.biases.size();
353+
weights_.policy.weights = WinogradFilterTransformF(weights_.policy.weights,
354+
pol_channels, channels);
355+
// Move bias to batchnorm
356+
for (auto i = size_t{0}; i < pol_channels; i++) {
357+
weights_.policy.bn_means.emplace_back(-weights_.policy.biases[i]);
358+
weights_.policy.bn_stddivs.emplace_back(1.0f);
359+
}
360+
} else {
361+
weights_.policy.OffsetMeans();
362+
weights_.policy.InvertStddev();
363+
}
303364
weights_.value.OffsetMeans();
304365
weights_.value.InvertStddev();
305366

@@ -346,14 +407,28 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, const OptionsDict& options)
346407
std::unique_ptr<Network> MakeBlasNetwork(const WeightsFile& weights,
347408
const OptionsDict& options) {
348409
if (weights.format().network_format().network() !=
349-
pblczero::NetworkFormat::NETWORK_CLASSICAL &&
410+
pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT &&
350411
weights.format().network_format().network() !=
351-
pblczero::NetworkFormat::NETWORK_SE) {
412+
pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT) {
352413
throw Exception(
353414
"Network format " +
354415
std::to_string(weights.format().network_format().network()) +
355416
" is not supported by BLAS backend.");
356417
}
418+
if (weights.format().network_format().policy() !=
419+
pblczero::NetworkFormat::POLICY_CLASSICAL &&
420+
weights.format().network_format().policy() !=
421+
pblczero::NetworkFormat::POLICY_CONVOLUTION) {
422+
throw Exception("Policy format " +
423+
std::to_string(weights.format().network_format().policy()) +
424+
" is not supported by BLAS backend.");
425+
}
426+
if (weights.format().network_format().value() !=
427+
pblczero::NetworkFormat::VALUE_CLASSICAL) {
428+
throw Exception("Value format " +
429+
std::to_string(weights.format().network_format().value()) +
430+
" is not supported by BLAS backend.");
431+
}
357432
return std::make_unique<BlasNetwork>(weights, options);
358433
}
359434

0 commit comments

Comments
 (0)