Skip to content

Commit 5111048

Browse files
committed
Add enpassant rules to getCapturesInOrder
1 parent a8b2346 commit 5111048

File tree

6 files changed

+69
-7
lines changed

6 files changed

+69
-7
lines changed

game/game.go

+10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package game
33
import (
44
"errors"
55
"math/bits"
6+
"reflect"
67

78
"github.com/dylhunn/dragontoothmg"
89
)
@@ -226,3 +227,12 @@ func (g *Game) PopMove() error {
226227

227228
return nil
228229
}
230+
231+
func (g *Game) GetEnPassentSquare() uint8 {
232+
value := reflect.ValueOf(g.Position).Elem().FieldByName("enpassant")
233+
if value.Kind() != reflect.Uint8 {
234+
return 0
235+
}
236+
237+
return uint8(value.Uint())
238+
}

game/game_test.go

+20
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,23 @@ func TestPopNoMoves(t *testing.T) {
105105
t.Errorf("PopNoMoves = %v, is not an error", err)
106106
}
107107
}
108+
109+
func TestGame_GetEnPassentSquare(t *testing.T) {
110+
tests := []struct {
111+
name string
112+
game *Game
113+
want uint8
114+
}{
115+
{"Startposition", New(), 0},
116+
{"On c6", NewFromFen("rnbqkbnr/pp3ppp/3p4/2pPp3/4P3/8/PPP2PPP/RNBQKBNR w KQkq c6 0 41"), 42},
117+
{"On b3", NewFromFen("rnbqkbnr/pp3ppp/3p4/3Pp3/1Pp1P3/5P2/P1P3PP/RNBQKBNR b KQkq b3 0 5"), 17},
118+
}
119+
for _, tt := range tests {
120+
t.Run(tt.name, func(t *testing.T) {
121+
122+
if got := tt.game.GetEnPassentSquare(); got != tt.want {
123+
t.Errorf("Game.GetEnPassentSquare() = %v, want %v", got, tt.want)
124+
}
125+
})
126+
}
127+
}

search/search.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ func (s *Search) getCapturesInOrder() []dragontoothmg.Move {
339339
}
340340

341341
for _, move := range s.Game.Position.GenerateLegalMoves() {
342-
if bitboardsOpponent.All&(1<<move.To()) > 0 {
342+
if isCaptureOrPromotionMove(s.Game, move) {
343343
captures = append(captures, move)
344344
}
345345
}

search/search_test.go

-4
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,6 @@ func BenchmarkSearchFullEvaluation9(b *testing.B) { benchmarkSearchFullEvaluati
9090
func BenchmarkSearchFullEvaluation10(b *testing.B) { benchmarkSearchFullEvaluation(10, b) }
9191
func BenchmarkSearchFullEvaluation11(b *testing.B) { benchmarkSearchFullEvaluation(11, b) }
9292

93-
func getMove(moveStr string) dragontoothmg.Move {
94-
move, _ := dragontoothmg.ParseMove(moveStr)
95-
return move
96-
}
9793
func TestSearch_SearchBestMove(t *testing.T) {
9894
type fields struct {
9995
Game *game.Game

search/utils.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,27 @@ import (
55
"go.janniklasrichter.de/axwchessbot/game"
66
)
77

8+
func getMove(moveStr string) dragontoothmg.Move {
9+
move, _ := dragontoothmg.ParseMove(moveStr)
10+
return move
11+
}
12+
813
func (s *Search) getCaptureMVVLVA(move dragontoothmg.Move, bitboardsOwn dragontoothmg.Bitboards, bitboardsOpponent dragontoothmg.Bitboards) (score int) {
914
pieceTypeFrom, _ := getPieceTypeAtPosition(move.From(), bitboardsOwn)
1015
pieceTypeTo, _ := getPieceTypeAtPosition(move.To(), bitboardsOpponent)
1116

1217
return (1200 - s.evaluationProvider.GetPieceTypeValue(pieceTypeTo)) + int(pieceTypeFrom)
1318
}
1419

15-
// TODO: No en_passent_rules checked yet!
1620
func isCaptureOrPromotionMove(game *game.Game, move dragontoothmg.Move) bool {
21+
bitboardsOwn := game.Position.White
1722
bitboardsOpponent := game.Position.Black
1823
if !game.Position.Wtomove {
24+
bitboardsOwn = game.Position.Black
1925
bitboardsOpponent = game.Position.White
2026
}
2127

22-
return bitboardsOpponent.All&(1<<move.To()) > 0 || move.Promote() > 0
28+
return bitboardsOpponent.All&(1<<move.To()) > 0 || move.Promote() > 0 || (bitboardsOwn.Pawns&(1<<move.From()) > 0 && game.GetEnPassentSquare() == move.To())
2329
}
2430

2531
func getPieceTypeAtPosition(position uint8, bitboards dragontoothmg.Bitboards) (pieceType dragontoothmg.Piece, occupied bool) {

search/utils_test.go

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package search
2+
3+
import (
4+
"testing"
5+
6+
"github.com/dylhunn/dragontoothmg"
7+
"go.janniklasrichter.de/axwchessbot/game"
8+
)
9+
10+
func Test_isCaptureOrPromotionMove(t *testing.T) {
11+
type args struct {
12+
game *game.Game
13+
move dragontoothmg.Move
14+
}
15+
tests := []struct {
16+
name string
17+
args args
18+
want bool
19+
}{
20+
{"En Passant", args{game.NewFromFen("rnbqkbnr/pp3ppp/3p4/3Pp3/1Pp1P3/5P2/P1P3PP/RNBQKBNR b KQkq b3 0 5"), getMove("c4b3")}, true},
21+
{"Missed En Passant", args{game.NewFromFen("rnbqkbnr/pp4pp/3p1p2/3Pp3/1Pp1P3/5P1N/P1P3PP/RNBQKB1R b KQkq - 1 6"), getMove("c4b3")}, false},
22+
}
23+
for _, tt := range tests {
24+
t.Run(tt.name, func(t *testing.T) {
25+
if got := isCaptureOrPromotionMove(tt.args.game, tt.args.move); got != tt.want {
26+
t.Errorf("isCaptureOrPromotionMove() = %v, want %v", got, tt.want)
27+
}
28+
})
29+
}
30+
}

0 commit comments

Comments
 (0)