26
26
#include " neural/network_legacy.h"
27
27
#include " neural/shared/activation.h"
28
28
#include " neural/shared/batchnorm.h"
29
+ #include " neural/shared/policy_map.h"
29
30
#include " neural/shared/winograd_filter.h"
30
31
31
32
#include < algorithm>
@@ -38,7 +39,8 @@ namespace {
38
39
39
40
class BlasComputation : public NetworkComputation {
40
41
public:
41
- BlasComputation (const LegacyWeights& weights, const size_t max_batch_size);
42
+ BlasComputation (const LegacyWeights& weights, const size_t max_batch_size,
43
+ const bool conv_policy);
42
44
43
45
virtual ~BlasComputation () {}
44
46
@@ -65,12 +67,17 @@ class BlasComputation : public NetworkComputation {
65
67
static constexpr auto kWidth = 8 ;
66
68
static constexpr auto kHeight = 8 ;
67
69
static constexpr auto kSquares = kWidth * kHeight ;
70
+ static constexpr auto kPolicyOutputs = 1858 ;
71
+ // Number of used planes with convolutional policy.
72
+ // The real number of planes is higher because of padding.
73
+ static constexpr auto kPolicyUsedPlanes = 73 ;
68
74
69
75
const LegacyWeights& weights_;
70
76
size_t max_batch_size_;
71
77
std::vector<InputPlanes> planes_;
72
78
std::vector<std::vector<float >> policies_;
73
79
std::vector<float > q_values_;
80
+ bool conv_policy_;
74
81
};
75
82
76
83
class BlasNetwork : public Network {
@@ -79,7 +86,8 @@ class BlasNetwork : public Network {
79
86
virtual ~BlasNetwork (){};
80
87
81
88
std::unique_ptr<NetworkComputation> NewComputation () override {
82
- return std::make_unique<BlasComputation>(weights_, max_batch_size_);
89
+ return std::make_unique<BlasComputation>(weights_, max_batch_size_,
90
+ conv_policy_);
83
91
}
84
92
85
93
private:
@@ -88,21 +96,24 @@ class BlasNetwork : public Network {
88
96
89
97
LegacyWeights weights_;
90
98
size_t max_batch_size_;
99
+ bool conv_policy_;
91
100
};
92
101
93
102
// Constructs a computation bound to |weights| (held by reference; the owning
// BlasNetwork must outlive this object). |max_batch_size| caps how many
// positions are evaluated per forward pass. |conv_policy| selects the
// convolutional policy head instead of the classical fully-connected one.
BlasComputation::BlasComputation(const LegacyWeights& weights,
                                 const size_t max_batch_size,
                                 const bool conv_policy)
    : weights_(weights),
      max_batch_size_(max_batch_size),
      policies_(0),
      q_values_(0),
      conv_policy_(conv_policy) {}
99
110
100
111
void BlasComputation::ComputeBlocking () {
101
112
// Retrieve network key dimensions from the weights structure.
102
113
const auto num_value_channels = weights_.ip1_val_b .size ();
103
114
const auto num_value_input_planes = weights_.value .bn_means .size ();
104
115
const auto num_policy_input_planes = weights_.policy .bn_means .size ();
105
- const auto num_output_policy = weights_. ip_pol_b . size () ;
116
+ const auto num_output_policy = kPolicyOutputs ;
106
117
const auto output_channels = weights_.input .biases .size ();
107
118
108
119
// max_channels is the maximum number of input channels of any
@@ -204,29 +215,60 @@ void BlasComputation::ComputeBlocking() {
204
215
}
205
216
}
206
217
207
- Convolution1::Forward (batch_size, output_channels, num_policy_input_planes,
208
- conv_out, weights_.policy .weights .data (),
209
- policy_buffer.data ());
218
+ if (conv_policy_) {
219
+ // Need to preserve conv_out which is used for value head
220
+ convolve3.Forward (batch_size, output_channels, output_channels, conv_out,
221
+ &weights_.policy1 .weights [0 ], res);
222
+
223
+ ApplyBatchNormalization (batch_size, output_channels, &res[0 ],
224
+ weights_.policy1 .bn_means .data (),
225
+ weights_.policy1 .bn_stddivs .data ());
226
+
227
+ convolve3.Forward (batch_size, output_channels, num_policy_input_planes,
228
+ res, &weights_.policy .weights [0 ], policy_buffer.data ());
229
+
230
+ ApplyBatchNormalization (
231
+ batch_size, num_policy_input_planes, &policy_buffer.data ()[0 ],
232
+ weights_.policy .bn_means .data (), weights_.policy .bn_stddivs .data (),
233
+ nullptr , false );
234
+
235
+ // Mapping from convolutional policy to lc0 policy
236
+ for (auto batch = size_t {0 }; batch < batch_size; batch++) {
237
+ for (auto i = 0 ; i < kPolicyUsedPlanes * kSquares ; i++) {
238
+ auto j = kConvPolicyMap [i];
239
+ if (j >= 0 ) {
240
+ output_pol[batch * num_output_policy + j] =
241
+ policy_buffer[batch * num_policy_input_planes * kSquares + i];
242
+ }
243
+ }
244
+ }
245
+
246
+ } else {
247
+ Convolution1::Forward (
248
+ batch_size, output_channels, num_policy_input_planes, conv_out,
249
+ weights_.policy .weights .data (), policy_buffer.data ());
250
+
251
+ ApplyBatchNormalization (
252
+ batch_size, num_policy_input_planes, &policy_buffer[0 ],
253
+ weights_.policy .bn_means .data (), weights_.policy .bn_stddivs .data ());
254
+
255
+ FullyConnectedLayer::Forward1D (
256
+ batch_size, num_policy_input_planes * kSquares , num_output_policy,
257
+ policy_buffer.data (), weights_.ip_pol_w .data (),
258
+ weights_.ip_pol_b .data (),
259
+ false , // Relu Off
260
+ output_pol.data ());
261
+ }
210
262
263
+ // Value head
211
264
Convolution1::Forward (batch_size, output_channels, num_value_input_planes,
212
265
conv_out, weights_.value .weights .data (),
213
266
value_buffer.data ());
214
267
215
- ApplyBatchNormalization (batch_size, num_policy_input_planes,
216
- &policy_buffer[0 ], weights_.policy .bn_means .data (),
217
- weights_.policy .bn_stddivs .data ());
218
-
219
268
ApplyBatchNormalization (batch_size, num_value_input_planes,
220
269
&value_buffer[0 ], weights_.value .bn_means .data (),
221
270
weights_.value .bn_stddivs .data ());
222
271
223
- FullyConnectedLayer::Forward1D (
224
- batch_size, num_policy_input_planes * kSquares , num_output_policy,
225
- policy_buffer.data (), weights_.ip_pol_w .data (),
226
- weights_.ip_pol_b .data (),
227
- false , // Relu Off
228
- output_pol.data ());
229
-
230
272
FullyConnectedLayer::Forward1D (
231
273
batch_size, num_value_input_planes * kSquares , num_value_channels,
232
274
value_buffer.data (), weights_.ip1_val_w .data (),
@@ -238,8 +280,8 @@ void BlasComputation::ComputeBlocking() {
238
280
std::vector<float > policy (num_output_policy);
239
281
240
282
// Get the moves
241
- SoftmaxActivation (
242
- num_output_policy, &output_pol[j * num_output_policy], policy.data ());
283
+ SoftmaxActivation (num_output_policy, &output_pol[j * num_output_policy],
284
+ policy.data ());
243
285
244
286
policies_.emplace_back (std::move (policy));
245
287
@@ -268,6 +310,9 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, const OptionsDict& options)
268
310
max_batch_size_ =
269
311
static_cast <size_t >(options.GetOrDefault <int >(" batch_size" , 256 ));
270
312
313
+ conv_policy_ = file.format ().network_format ().policy () ==
314
+ pblczero::NetworkFormat::POLICY_CONVOLUTION;
315
+
271
316
if (max_batch_size_ > kHardMaxBatchSize ) {
272
317
max_batch_size_ = kHardMaxBatchSize ;
273
318
}
@@ -298,8 +343,24 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, const OptionsDict& options)
298
343
conv2.InvertStddev ();
299
344
}
300
345
301
- weights_.policy .OffsetMeans ();
302
- weights_.policy .InvertStddev ();
346
+ if (conv_policy_) {
347
+ weights_.policy1 .OffsetMeans ();
348
+ weights_.policy1 .InvertStddev ();
349
+
350
+ weights_.policy1 .weights =
351
+ WinogradFilterTransformF (weights_.policy1 .weights , channels, channels);
352
+ auto pol_channels = weights_.policy .biases .size ();
353
+ weights_.policy .weights = WinogradFilterTransformF (weights_.policy .weights ,
354
+ pol_channels, channels);
355
+ // Move bias to batchnorm
356
+ for (auto i = size_t {0 }; i < pol_channels; i++) {
357
+ weights_.policy .bn_means .emplace_back (-weights_.policy .biases [i]);
358
+ weights_.policy .bn_stddivs .emplace_back (1 .0f );
359
+ }
360
+ } else {
361
+ weights_.policy .OffsetMeans ();
362
+ weights_.policy .InvertStddev ();
363
+ }
303
364
weights_.value .OffsetMeans ();
304
365
weights_.value .InvertStddev ();
305
366
@@ -346,14 +407,28 @@ BlasNetwork::BlasNetwork(const WeightsFile& file, const OptionsDict& options)
346
407
std::unique_ptr<Network> MakeBlasNetwork (const WeightsFile& weights,
347
408
const OptionsDict& options) {
348
409
if (weights.format ().network_format ().network () !=
349
- pblczero::NetworkFormat::NETWORK_CLASSICAL &&
410
+ pblczero::NetworkFormat::NETWORK_CLASSICAL_WITH_HEADFORMAT &&
350
411
weights.format ().network_format ().network () !=
351
- pblczero::NetworkFormat::NETWORK_SE ) {
412
+ pblczero::NetworkFormat::NETWORK_SE_WITH_HEADFORMAT ) {
352
413
throw Exception (
353
414
" Network format " +
354
415
std::to_string (weights.format ().network_format ().network ()) +
355
416
" is not supported by BLAS backend." );
356
417
}
418
+ if (weights.format ().network_format ().policy () !=
419
+ pblczero::NetworkFormat::POLICY_CLASSICAL &&
420
+ weights.format ().network_format ().policy () !=
421
+ pblczero::NetworkFormat::POLICY_CONVOLUTION) {
422
+ throw Exception (" Policy format " +
423
+ std::to_string (weights.format ().network_format ().policy ()) +
424
+ " is not supported by BLAS backend." );
425
+ }
426
+ if (weights.format ().network_format ().value () !=
427
+ pblczero::NetworkFormat::VALUE_CLASSICAL) {
428
+ throw Exception (" Value format " +
429
+ std::to_string (weights.format ().network_format ().value ()) +
430
+ " is not supported by BLAS backend." );
431
+ }
357
432
return std::make_unique<BlasNetwork>(weights, options);
358
433
}
359
434
0 commit comments