Skip to content

Commit c2ea771

Browse files
authored
Merge pull request LeelaChessZero#8 from LeelaChessZero/master
Merge Master
2 parents cd2d1b5 + 9783710 commit c2ea771

File tree

10 files changed

+55
-39
lines changed

10 files changed

+55
-39
lines changed

src/chess/board.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,7 @@ MoveList ChessBoard::GenerateLegalMoves() const {
969969
return result;
970970
}
971971

972-
void ChessBoard::SetFromFen(const std::string& fen, int* no_capture_ply,
972+
void ChessBoard::SetFromFen(const std::string& fen, int* rule50_ply,
973973
int* moves) {
974974
Clear();
975975
int row = 7;
@@ -980,10 +980,10 @@ void ChessBoard::SetFromFen(const std::string& fen, int* no_capture_ply,
980980
string who_to_move;
981981
string castlings;
982982
string en_passant;
983-
int no_capture_halfmoves;
983+
int rule50_halfmoves;
984984
int total_moves;
985985
fen_str >> board >> who_to_move >> castlings >> en_passant >>
986-
no_capture_halfmoves >> total_moves;
986+
rule50_halfmoves >> total_moves;
987987

988988
if (!fen_str) throw Exception("Bad fen string: " + fen);
989989

@@ -1096,7 +1096,7 @@ void ChessBoard::SetFromFen(const std::string& fen, int* no_capture_ply,
10961096
if (who_to_move == "b" || who_to_move == "B") {
10971097
Mirror();
10981098
}
1099-
if (no_capture_ply) *no_capture_ply = no_capture_halfmoves;
1099+
if (rule50_ply) *rule50_ply = rule50_halfmoves;
11001100
if (moves) *moves = total_moves;
11011101
}
11021102

src/chess/board.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ class ChessBoard {
6767
static const BitBoard kPawnMask;
6868

6969
// Sets position from FEN string.
70-
// If @no_capture_ply and @moves are not nullptr, they are filled with number
70+
// If @rule50_ply and @moves are not nullptr, they are filled with number
7171
// of moves without capture and number of full moves since the beginning of
7272
// the game.
73-
void SetFromFen(const std::string& fen, int* no_capture_ply = nullptr,
73+
void SetFromFen(const std::string& fen, int* rule50_ply = nullptr,
7474
int* moves = nullptr);
7575
// Nullifies the whole structure.
7676
void Clear();

src/chess/position.cc

+12-12
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,17 @@
3232
namespace lczero {
3333

3434
Position::Position(const Position& parent, Move m)
35-
: no_capture_ply_(parent.no_capture_ply_ + 1),
35+
: rule50_ply_(parent.rule50_ply_ + 1),
3636
ply_count_(parent.ply_count_ + 1) {
3737
them_board_ = parent.us_board_;
38-
const bool capture = them_board_.ApplyMove(m);
38+
const bool is_zeroing = them_board_.ApplyMove(m);
3939
us_board_ = them_board_;
4040
us_board_.Mirror();
41-
if (capture) no_capture_ply_ = 0;
41+
if (is_zeroing) rule50_ply_ = 0;
4242
}
4343

44-
Position::Position(const ChessBoard& board, int no_capture_ply, int game_ply)
45-
: no_capture_ply_(no_capture_ply), repetitions_(0), ply_count_(game_ply) {
44+
Position::Position(const ChessBoard& board, int rule50_ply, int game_ply)
45+
: rule50_ply_(rule50_ply), repetitions_(0), ply_count_(game_ply) {
4646
us_board_ = board;
4747
them_board_ = board;
4848
them_board_.Mirror();
@@ -67,17 +67,17 @@ GameResult PositionHistory::ComputeGameResult() const {
6767
}
6868

6969
if (!board.HasMatingMaterial()) return GameResult::DRAW;
70-
if (Last().GetNoCaptureNoPawnPly() >= 100) return GameResult::DRAW;
70+
if (Last().GetRule50Ply() >= 100) return GameResult::DRAW;
7171
if (Last().GetGamePly() >= 450) return GameResult::DRAW;
7272
if (Last().GetRepetitions() >= 2) return GameResult::DRAW;
7373

7474
return GameResult::UNDECIDED;
7575
}
7676

77-
void PositionHistory::Reset(const ChessBoard& board, int no_capture_ply,
77+
void PositionHistory::Reset(const ChessBoard& board, int rule50_ply,
7878
int game_ply) {
7979
positions_.clear();
80-
positions_.emplace_back(board, no_capture_ply, game_ply);
80+
positions_.emplace_back(board, rule50_ply, game_ply);
8181
}
8282

8383
void PositionHistory::Append(Move m) {
@@ -91,14 +91,14 @@ void PositionHistory::Append(Move m) {
9191
int PositionHistory::ComputeLastMoveRepetitions() const {
9292
const auto& last = positions_.back();
9393
// TODO(crem) implement hash/cache based solution.
94-
if (last.GetNoCaptureNoPawnPly() < 4) return 0;
94+
if (last.GetRule50Ply() < 4) return 0;
9595

9696
for (int idx = positions_.size() - 3; idx >= 0; idx -= 2) {
9797
const auto& pos = positions_[idx];
9898
if (pos.GetBoard() == last.GetBoard()) {
9999
return 1 + pos.GetRepetitions();
100100
}
101-
if (pos.GetNoCaptureNoPawnPly() < 2) return 0;
101+
if (pos.GetRule50Ply() < 2) return 0;
102102
}
103103
return 0;
104104
}
@@ -107,7 +107,7 @@ bool PositionHistory::DidRepeatSinceLastZeroingMove() const {
107107
for (auto iter = positions_.rbegin(), end = positions_.rend(); iter != end;
108108
++iter) {
109109
if (iter->GetRepetitions() > 0) return true;
110-
if (iter->GetNoCaptureNoPawnPly() == 0) return false;
110+
if (iter->GetRule50Ply() == 0) return false;
111111
}
112112
return false;
113113
}
@@ -119,7 +119,7 @@ uint64_t PositionHistory::HashLast(int positions) const {
119119
if (!positions--) break;
120120
hash = HashCat(hash, iter->Hash());
121121
}
122-
return HashCat(hash, Last().GetNoCaptureNoPawnPly());
122+
return HashCat(hash, Last().GetRule50Ply());
123123
}
124124

125125
} // namespace lczero

src/chess/position.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class Position {
3838
// From parent position and move.
3939
Position(const Position& parent, Move m);
4040
// From particular position.
41-
Position(const ChessBoard& board, int no_capture_ply, int game_ply);
41+
Position(const ChessBoard& board, int rule50_ply, int game_ply);
4242

4343
uint64_t Hash() const;
4444
bool IsBlackToMove() const { return us_board_.flipped(); }
@@ -54,7 +54,7 @@ class Position {
5454
void SetRepetitions(int repetitions) { repetitions_ = repetitions; }
5555

5656
// Number of ply with no captures and pawn moves.
57-
int GetNoCaptureNoPawnPly() const { return no_capture_ply_; }
57+
int GetRule50Ply() const { return rule50_ply_; }
5858

5959
// Gets board from the point of view of player to move.
6060
const ChessBoard& GetBoard() const { return us_board_; }
@@ -70,7 +70,7 @@ class Position {
7070
ChessBoard them_board_;
7171

7272
// How many half-moves without capture or pawn move was there.
73-
int no_capture_ply_ = 0;
73+
int rule50_ply_ = 0;
7474
// How many repetitions this position had before. For new positions it's 0.
7575
int repetitions_;
7676
// number of half-moves since beginning of the game.
@@ -102,7 +102,7 @@ class PositionHistory {
102102
int GetLength() const { return positions_.size(); }
103103

104104
// Resets the position to a given state.
105-
void Reset(const ChessBoard& board, int no_capture_ply, int game_ply);
105+
void Reset(const ChessBoard& board, int rule50_ply, int game_ply);
106106

107107
// Appends a position to history.
108108
void Append(Move m);

src/mcts/node.cc

+12-7
Original file line numberDiff line numberDiff line change
@@ -380,18 +380,21 @@ V5TrainingData Node::GetV5TrainingData(
380380
// Other params.
381381
if (input_format ==
382382
pblczero::NetworkFormat::INPUT_112_WITH_CANONICALIZATION) {
383-
result.side_to_move = position.GetBoard().en_passant().as_int() >> 56;
383+
result.side_to_move_or_enpassant =
384+
position.GetBoard().en_passant().as_int() >> 56;
384385
if ((transform & FlipTransform) != 0) {
385-
result.side_to_move = ReverseBitsInBytes(result.side_to_move);
386+
result.side_to_move_or_enpassant =
387+
ReverseBitsInBytes(result.side_to_move_or_enpassant);
386388
}
387389
// Send transform in deprecated move count so rescorer can reverse it to
388390
// calculate the actual move list from the input data.
389-
result.deprecated_move_count = transform;
391+
result.invariance_info =
392+
transform | (position.IsBlackToMove() ? (1u << 7) : 0u);
390393
} else {
391-
result.side_to_move = position.IsBlackToMove() ? 1 : 0;
392-
result.deprecated_move_count = 0;
394+
result.side_to_move_or_enpassant = position.IsBlackToMove() ? 1 : 0;
395+
result.invariance_info = 0;
393396
}
394-
result.rule50_count = position.GetNoCaptureNoPawnPly();
397+
result.rule50_count = position.GetRule50Ply();
395398

396399
// Game result.
397400
if (game_result == GameResult::WHITE_WON) {
@@ -468,7 +471,9 @@ bool NodeTree::ResetToPosition(const std::string& starting_fen,
468471
int no_capture_ply;
469472
int full_moves;
470473
starting_board.SetFromFen(starting_fen, &no_capture_ply, &full_moves);
471-
if (gamebegin_node_ && history_.Starting().GetBoard() != starting_board) {
474+
if (gamebegin_node_ &&
475+
(history_.Starting().GetBoard() != starting_board ||
476+
history_.Starting().GetNoCaptureNoPawnPly() != no_capture_ply)) {
472477
// Completely different position.
473478
DeallocateTree();
474479
}

src/mcts/search.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -1138,7 +1138,7 @@ void SearchWorker::ExtendNode(Node* node) {
11381138
return;
11391139
}
11401140

1141-
if (history_.Last().GetNoCaptureNoPawnPly() >= 100) {
1141+
if (history_.Last().GetRule50Ply() >= 100) {
11421142
node->MakeTerminal(GameResult::DRAW);
11431143
return;
11441144
}
@@ -1150,7 +1150,7 @@ void SearchWorker::ExtendNode(Node* node) {
11501150

11511151
// Neither by-position or by-rule termination, but maybe it's a TB position.
11521152
if (search_->syzygy_tb_ && board.castlings().no_legal_castle() &&
1153-
history_.Last().GetNoCaptureNoPawnPly() == 0 &&
1153+
history_.Last().GetRule50Ply() == 0 &&
11541154
(board.ours() | board.theirs()).count() <=
11551155
search_->syzygy_tb_->max_cardinality()) {
11561156
ProbeState state;

src/neural/encoder.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ InputPlanes EncodePositionForNN(
181181
} else {
182182
if (we_are_black) result[kAuxPlaneBase + 4].SetAll();
183183
}
184-
result[kAuxPlaneBase + 5].Fill(history.Last().GetNoCaptureNoPawnPly());
184+
result[kAuxPlaneBase + 5].Fill(history.Last().GetRule50Ply());
185185
// Plane kAuxPlaneBase + 6 used to be movecount plane, now it's all zeros.
186186
// Plane kAuxPlaneBase + 7 is all ones to help NN find board edges.
187187
result[kAuxPlaneBase + 7].SetAll();
@@ -245,7 +245,7 @@ InputPlanes EncodePositionForNN(
245245
if (history_idx > 0) flip = !flip;
246246
// If no capture no pawn is 0, the previous was start of game, capture or
247247
// pawn push, so no need to go back further if stopping early.
248-
if (stop_early && position.GetNoCaptureNoPawnPly() == 0) break;
248+
if (stop_early && position.GetRule50Ply() == 0) break;
249249
}
250250
if (transform != NoTransform) {
251251
// Transform all masks.

src/neural/writer.h

+9-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
*/
2727

2828
#include <zlib.h>
29+
2930
#include <fstream>
31+
3032
#include "utils/cppattributes.h"
3133

3234
#pragma once
@@ -44,9 +46,14 @@ struct V5TrainingData {
4446
uint8_t castling_us_oo;
4547
uint8_t castling_them_ooo;
4648
uint8_t castling_them_oo;
47-
uint8_t side_to_move;
49+
// For input type 3 contains enpassant column as a mask.
50+
uint8_t side_to_move_or_enpassant;
4851
uint8_t rule50_count;
49-
uint8_t deprecated_move_count; // left in to keep 8 int8 fields.
52+
// For input type 3 contains a bit field indicating the transform that was
53+
// used and the original side to move info.
54+
// Side to move is in the top bit, transform in the lower bits.
55+
// In versions prior to v5 this spot contained an unused move count field.
56+
uint8_t invariance_info;
5057
int8_t result;
5158
float root_q;
5259
float best_q;

src/selfplay/game.cc

+5-1
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,11 @@ void SelfPlayGame::WriteTrainingData(TrainingDataWriter* writer) const {
281281
// different approach.
282282
float m_estimate = training_data_.back().best_m + training_data_.size() - 1;
283283
for (auto chunk : training_data_) {
284-
const bool black_to_move = chunk.side_to_move;
284+
bool black_to_move = chunk.side_to_move_or_enpassant;
285+
if (chunk.input_format ==
286+
pblczero::NetworkFormat::INPUT_112_WITH_CANONICALIZATION) {
287+
black_to_move = (chunk.invariance_info & (1u << 7)) != 0;
288+
}
285289
if (game_result_ == GameResult::WHITE_WON) {
286290
chunk.result = black_to_move ? -1 : 1;
287291
} else if (game_result_ == GameResult::BLACK_WON) {

src/syzygy/syzygy.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -1598,7 +1598,7 @@ int SyzygyTablebase::probe_dtz(const Position& pos, ProbeState* result) {
15981598
int min_DTZ = 0xFFFF;
15991599
for (const Move& move : pos.GetBoard().GenerateLegalMoves()) {
16001600
Position next_pos = Position(pos, move);
1601-
const bool zeroing = next_pos.GetNoCaptureNoPawnPly() == 0;
1601+
const bool zeroing = next_pos.GetRule50Ply() == 0;
16021602
// For zeroing moves we want the dtz of the move _before_ doing it,
16031603
// otherwise we will get the dtz of the next move sequence. Search the
16041604
// position after the move to get the score sign (because even in a winning
@@ -1629,7 +1629,7 @@ bool SyzygyTablebase::root_probe(const Position& pos, bool has_repeated,
16291629
ProbeState result;
16301630
auto root_moves = pos.GetBoard().GenerateLegalMoves();
16311631
// Obtain 50-move counter for the root position
1632-
const int cnt50 = pos.GetNoCaptureNoPawnPly();
1632+
const int cnt50 = pos.GetRule50Ply();
16331633
// Check whether a position was repeated since the last zeroing move.
16341634
const bool rep = has_repeated;
16351635
int dtz;
@@ -1640,7 +1640,7 @@ bool SyzygyTablebase::root_probe(const Position& pos, bool has_repeated,
16401640
for (auto& m : root_moves) {
16411641
Position next_pos = Position(pos, m);
16421642
// Calculate dtz for the current move counting from the root position
1643-
if (next_pos.GetNoCaptureNoPawnPly() == 0) {
1643+
if (next_pos.GetRule50Ply() == 0) {
16441644
// In case of a zeroing move, dtz is one of -101/-1/0/1/101
16451645
const WDLScore wdl = static_cast<WDLScore>(-probe_wdl(next_pos, &result));
16461646
dtz = dtz_before_zeroing(wdl);

0 commit comments

Comments
 (0)