Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
philqc committed Feb 1, 2024
1 parent d3ddfcd commit 2142e71
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 64 deletions.
4 changes: 2 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# PARAMETERS TO CONTROL THE BEHAVIOR OF THE GAME ENGINE
# DO NOT REMOVE OR RENAME THIS FILE
PLAYER_1_NAME = 'GTOWizard AI'
PLAYER_1_NAME = 'GTOWizard AI A'
PLAYER_1_PATH = './csrc/main_bot'
# NO TRAILING SLASHES ARE ALLOWED IN PATHS
PLAYER_2_NAME = 'GTOWizard AI'
PLAYER_2_NAME = 'GTOWizard AI B'
PLAYER_2_PATH = './csrc/main_bot'
# GAME PROGRESS IS RECORDED HERE
GAME_LOG_FILENAME = 'gamelog'
Expand Down
28 changes: 16 additions & 12 deletions csrc/src/cfr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ CFR::CFR(const Game& game) : game_(game), board_data_cache_(game) {
}
}

void CFR::build_tree() {
void CFR::build_tree(const std::optional<Action> force_raise_root) {
actions_.clear();
actions_.reserve(legal_actions().size() + 1);

Expand All @@ -33,16 +33,20 @@ void CFR::build_tree() {
if (raise_is_legal) {
// Prefer to go all-in when it's less than X * pot-sized-bet
// Otherwise always choose a pot-sized-bet
const auto raise_bounds = root_->raise_bounds();
auto amount_call = std::max(root_->bets[0], root_->bets[1]);
auto pot_plus_call = root_->pot() + std::abs(root_->bets[1] - root_->bets[0]);
auto pot_sized_bet = amount_call + pot_plus_call;

if (raise_bounds[1] < 3 * pot_sized_bet) {
// All-in
actions_.emplace_back(Action::Type::RAISE, raise_bounds[1]);
if (force_raise_root.has_value() && force_raise_root->type == Action::Type::RAISE) {
actions_.emplace_back(*force_raise_root);
} else {
actions_.emplace_back(Action::Type::RAISE, pot_sized_bet);
const auto raise_bounds = root_->raise_bounds();
auto amount_call = std::max(root_->bets[0], root_->bets[1]);
auto pot_plus_call = root_->pot() + std::abs(root_->bets[1] - root_->bets[0]);
auto pot_sized_bet = amount_call + pot_plus_call;

if (raise_bounds[1] < 3 * pot_sized_bet) {
// All-in
actions_.emplace_back(Action::Type::RAISE, raise_bounds[1]);
} else {
actions_.emplace_back(Action::Type::RAISE, pot_sized_bet);
}
}
}

Expand Down Expand Up @@ -278,7 +282,7 @@ void CFR::step(const std::array<Range, 2>& ranges) {

void CFR::solve(const std::array<Range, 2>& ranges, const RoundStatePtr& state,
const unsigned player_id, const float time_budget_ms,
const unsigned max_num_iters) {
const std::optional<Action> force_raise_root, const unsigned max_num_iters) {
const auto start_time = std::chrono::high_resolution_clock::now();
root_ = state;
num_hands_ = {
Expand All @@ -293,7 +297,7 @@ void CFR::solve(const std::array<Range, 2>& ranges, const RoundStatePtr& state,
opponent_range_raise_call_ = ranges[1 - player_id_];
opponent_range_raise_fold_ = ranges[1 - player_id_];

build_tree();
build_tree(force_raise_root);

regrets_ = HandActionsValues(num_actions(), num_hands_[player_id_], 0);
opponent_regrets_vs_bet_ = HandActionsValues(2, num_hands_[1 - player_id_], 0);
Expand Down
5 changes: 3 additions & 2 deletions csrc/src/cfr.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class CFR {
explicit CFR(const Game& game);

void solve(const std::array<Range, 2>& ranges, const RoundStatePtr& state, unsigned player_id,
float time_budget_ms, unsigned max_num_iters = 1000);
float time_budget_ms, std::optional<Action> force_raise_root = std::nullopt,
unsigned max_num_iters = 1000);

// Actions considered at the state
const auto& legal_actions() const { return actions_; }
Expand All @@ -46,7 +47,7 @@ class CFR {
private:
void step(const std::array<Range, 2>& ranges);

void build_tree();
void build_tree(std::optional<Action> force_raise_root);

[[nodiscard]] float get_linear_cfr_discount_factor() const;

Expand Down
3 changes: 2 additions & 1 deletion csrc/src/cfr_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ TEST_F(CFRTest, TestRiverNutAirToyGame) {
auto round_state =
std::make_shared<RoundState>(BB_POS, false, bids, bets, stacks, hands, board_cards, nullptr);

cfr.solve(ranges, std::static_pointer_cast<const RoundState>(round_state), hero_id, 4000, 500);
cfr.solve(ranges, std::static_pointer_cast<const RoundState>(round_state), hero_id, 4000,
std::nullopt, 500);
const auto& strategy = cfr.strategy();

fmt::print("Board is {} \n", Card::to_string(board_cards));
Expand Down
145 changes: 106 additions & 39 deletions csrc/src/main_bot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ namespace pokerbot {
// Builds a bot with value-initialized game, auctioneer and time manager.
// The CFR solver is constructed from `game_`, and the random generator
// `gen_` is seeded once from std::random_device.
MainBot::MainBot()
: game_(), auctioneer_(), cfr_(game_), time_manager_(), gen_(std::random_device{}()) {}

Action MainBot::sample_action_and_update_range(const GameInfo& game_info, const RoundState& state,
const Hand& hand, const int hero_id,
const HandActionsValues& strategy,
const std::vector<Action>& legal_actions,
const float min_prob_sampling) {
Action MainBot::sample_action(const GameInfo& game_info, const RoundState& state, const Hand& hand,
const int hero_id, const HandActionsValues& strategy,
const std::vector<Action>& legal_actions,
const float min_prob_sampling) {
if (legal_actions.size() != strategy.num_actions) {
throw std::runtime_error("Actions mistmatch");
}
Expand Down Expand Up @@ -42,12 +41,35 @@ Action MainBot::sample_action_and_update_range(const GameInfo& game_info, const
state.bets[0], state.bets[1], sampled_action.to_string(), Card::to_string(state.board_cards),
strategy_str);

for (hand_t i = 0; i < ranges_[hero_id].num_hands(); ++i) {
const auto action_prob = strategy(sampled_idx, i);
ranges_[hero_id].range[i] = ranges_[hero_id].range[i] * action_prob;
return sampled_action;
}

/// Conditions `ranges_[player]` on `action` having been taken.
///
/// Every hand's weight in the player's range is multiplied by the probability
/// that `strategy` assigns to `action` for that hand.
///
/// @param player        Index into `ranges_` of the player who acted.
/// @param strategy      Action probabilities, looked up as `strategy(action_index, hand)`.
/// @param legal_actions Actions the strategy refers to; must have `strategy.num_actions` entries.
/// @param action        The action taken; must equal one of `legal_actions` in type and amount.
/// @throws std::runtime_error if `legal_actions` and `strategy` disagree in size, or if
///         `action` is not found in `legal_actions`.
void MainBot::update_range(const int player, const HandActionsValues& strategy,
                           const std::vector<Action>& legal_actions, const Action& action) {
  if (legal_actions.size() != strategy.num_actions) {
    throw std::runtime_error("Actions mistmatch");
  }

  // Find the strategy row matching `action`; the index equals size() when there is no match.
  unsigned action_idx = 0;
  for (; action_idx < legal_actions.size(); ++action_idx) {
    const auto& candidate = legal_actions[action_idx];
    if (candidate.amount == action.amount && candidate.type == action.type) {
      break;
    }
  }
  if (action_idx == legal_actions.size()) {
    // List every known action in the error to make the mismatch diagnosable.
    std::string known_actions;
    for (const auto& candidate : legal_actions) {
      known_actions += candidate.to_string() + ",";
    }
    throw std::runtime_error(
        fmt::format("Actions mistmatch Two: {} ({})", action.to_string(), known_actions));
  }

  // FIXME ADD UNIFORM RANDOM?
  auto& player_range = ranges_[player];
  for (hand_t i = 0; i < player_range.num_hands(); ++i) {
    player_range.range[i] *= strategy(action_idx, i);
  }
}

void MainBot::handle_new_hand(const GameInfo& game_info, const RoundStatePtr& /*state*/,
Expand All @@ -66,26 +88,61 @@ void MainBot::handle_hand_over(const GameInfo& /*game_info*/,

/// Entry point for choosing the hero's action at `state`.
///
/// If the previous decision node belonged to the other player, the action they
/// took is inferred from how their bet changed between that node and `state`.
/// It is then replayed through `get_action_any_player` with that action supplied
/// (the returned action is ignored). The hero's own action is then computed for
/// `state` and returned.
///
/// @param game_info Match-level information forwarded to the solver helpers.
/// @param state     Current decision node; `previous_state` may be null.
/// @param active    Index of the hero.
/// @return The action chosen for the hero.
Action MainBot::get_action(const GameInfo& game_info, const RoundStatePtr& state,
                           const int active) {
  if (state->previous_state != nullptr) {
    const auto last_decision_node =
        std::dynamic_pointer_cast<const RoundState>(state->previous_state);
    if (last_decision_node != nullptr) {
      // Player who acted at the previous node; this is not necessarily the opponent.
      const auto actor = get_active(last_decision_node->button);
      const bool is_opponent_node = actor != active;
      fmt::print("Last decision node: Opponent = {} - {} \n", is_opponent_node,
                 last_decision_node->to_string());

      // The action itself is not recorded on the state, so reconstruct it from
      // the change in the actor's bet.
      const auto last_action = [&]() -> std::optional<Action> {
        const auto previous_bet = last_decision_node->bets[actor];
        const auto current_bet = state->bets[actor];
        const auto other_bet = state->bets[1 - actor];
        if (current_bet > previous_bet) {
          if (current_bet > other_bet) {
            return Action{Action::Type::RAISE, state->bets[actor]};
          }
          if (current_bet == other_bet) {
            return Action{Action::Type::CALL};
          }
          return std::nullopt;
        }
        if (current_bet == previous_bet) {
          return Action{Action::Type::CHECK};
        }
        // Bet went down (e.g. bets were reset): nothing can be inferred.
        return std::nullopt;
      }();

      if (is_opponent_node && last_action.has_value()) {
        // Return value intentionally unused: the call is made for its effect on
        // the stored ranges.
        get_action_any_player(game_info, last_decision_node, actor, last_action);
      }
    }
  }
  return get_action_any_player(game_info, state, active, std::nullopt);
}

const Hand hero_hand(state->hands[active]);
Action MainBot::get_action_any_player(const GameInfo& game_info, const RoundStatePtr& state,
const int player, std::optional<Action> sampled_action) {
const auto legal_actions = state->legal_actions();

// TODO - Need to update ranges on new board cards
std::optional<Hand> player_hand = std::nullopt;
const bool is_hero_node = !sampled_action.has_value();
if (is_hero_node) {
player_hand = Hand(state->hands[player]);
}

if (ranges::contains(legal_actions, Action::Type::BID)) {
if (is_hero_node && ranges::contains(legal_actions, Action::Type::BID)) {
/// Auction
const auto bid = auctioneer_.get_bid(ranges_[active], ranges_[1 - active], game_,
state->board_cards, hero_hand, state->pot());
const auto bid = auctioneer_.get_bid(ranges_[player], ranges_[1 - player], game_,
state->board_cards, *player_hand, state->pot());
fmt::print("Bidding {} in {}\n", bid, state->pot());
return {Action::Type::BID, bid};
}

if (state->bids[active].has_value() && state->bids[1 - active].has_value()) {
if (is_hero_node && state->bids[player].has_value() && state->bids[1 - player].has_value()) {
/// Right after auction
/// NB - Careful this is called on *every action* after auction happened

auctioneer_.receive_bid(ranges_[active], ranges_[1 - active], *state->bids[active],
*state->bids[1 - active], game_, state->board_cards, state->pot());
auctioneer_.receive_bid(ranges_[player], ranges_[1 - player], *state->bids[player],
*state->bids[1 - player], game_, state->board_cards, state->pot());
}

// NB: Put this constraint after checking if only legal action is bid!
Expand All @@ -95,46 +152,56 @@ Action MainBot::get_action(const GameInfo& game_info, const RoundStatePtr& state
return Action{state->legal_actions().front()};
}

if (state->round() == round::FLOP && state->bets[1 - active] == 0 && active == BB_POS) {
if (state->round() == round::FLOP && state->bets[1 - player] == 0 && player == BB_POS) {
fmt::print("OOP Flop - Checking 100% of the time \n");
return Action{Action::Type::CHECK};
}
if (state->round() == round::TURN && state->bets[1 - active] == 0 && active == BB_POS &&
hero_hand.num_cards() == 2) {
if (is_hero_node && state->round() == round::TURN && state->bets[1 - player] == 0 &&
player == BB_POS && player_hand->num_cards() == 2) {
fmt::print("OOP Turn with 2 cards - Checking 100% of the time \n");
return Action{Action::Type::CHECK};
}

Action sampled_action;
if (state->round() == round::PREFLOP && active == SB_POS &&
state->bets[1 - active] == BIG_BLIND) {
if (!preflop_sb_cached_strategy_.has_value() || !preflop_sb_cached_legal_actions_.has_value()) {
if (is_hero_node && state->round() == round::PREFLOP && player == SB_POS &&
state->bets[1 - player] == BIG_BLIND) {
if (!preflop_sb_cached_strategy_[player].has_value() ||
!preflop_sb_cached_legal_actions_[player].has_value()) {
// Solve with a larger time limit (computed once)
fmt::print("Solving root preflop node for 100ms \n");
cfr_.solve(ranges_, state, active, 100);
preflop_sb_cached_strategy_ = cfr_.strategy();
preflop_sb_cached_legal_actions_ = cfr_.legal_actions();
cfr_.solve(ranges_, state, player, 100, sampled_action);
preflop_sb_cached_strategy_[player] = cfr_.strategy();
preflop_sb_cached_legal_actions_[player] = cfr_.legal_actions();
}
sampled_action = sample_action_and_update_range(game_info, *state, hero_hand, active,
*preflop_sb_cached_strategy_,
*preflop_sb_cached_legal_actions_);
if (is_hero_node) {
sampled_action = sample_action(game_info, *state, *player_hand, player,
*preflop_sb_cached_strategy_[player],
*preflop_sb_cached_legal_actions_[player]);
}
update_range(player, *preflop_sb_cached_strategy_[player],
*preflop_sb_cached_legal_actions_[player], *sampled_action);

} else {
// set time budget
time_manager_.update_action(game_info, state);
const auto time_budget_ms = time_manager_.get_time_budget_ms(game_info, state);
// FIXME.. Could be avoided
ranges_[player].update_on_board_cards(game_, state->board_cards);
ranges_[1 - player].update_on_board_cards(game_, state->board_cards);

// Solve..
// FIXME - On flop/turn, CFVs are calculated as if showdown will be held with the current board
// (i.e., no more chance cards are dealt)
fmt::print("{:.2f} ms allocated for solving with CFR \n", time_budget_ms);
cfr_.solve(ranges_, state, active, time_budget_ms);
sampled_action = sample_action_and_update_range(game_info, *state, hero_hand, active,
cfr_.strategy(), cfr_.legal_actions());
fmt::print("is_hero={} - {:.2f} ms allocated for solving with CFR \n", is_hero_node,
time_budget_ms);
cfr_.solve(ranges_, state, player, time_budget_ms, sampled_action);
if (is_hero_node) {
sampled_action = sample_action(game_info, *state, *player_hand, player, cfr_.strategy(),
cfr_.legal_actions());
}
update_range(player, cfr_.strategy(), cfr_.legal_actions(), *sampled_action);
}

// TODO - Update Villain range somehow..

return sampled_action;
return *sampled_action;
}

} // namespace pokerbot
23 changes: 15 additions & 8 deletions csrc/src/main_bot.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,24 @@ class MainBot {
CFR cfr_;
TimeManager time_manager_;
mutable std::mt19937 gen_;
std::optional<HandActionsValues> preflop_sb_cached_strategy_;
std::optional<std::vector<Action>> preflop_sb_cached_legal_actions_;
std::array<std::optional<HandActionsValues>, 2> preflop_sb_cached_strategy_;
std::array<std::optional<std::vector<Action>>, 2> preflop_sb_cached_legal_actions_;
// FIXME -> ADD THIS?
// We update our opponent's range with a small weight of uniform random range
// As a regularization since our estimated ranges are not that great
// static constexpr float WEIGHT_UNIFORM_RANDOM_RANGE_OPPONENT = 0.05;

// Sample action based on strategy in `cfr_`
// Won't sample any action with prob < `min_prob_sampling`
// Then update hero's range based on sampled action
Action sample_action_and_update_range(const GameInfo& game_info, const RoundState& state,
const Hand& hand, int hero_id,
const HandActionsValues& strategy,
const std::vector<Action>& legal_actions,
float min_prob_sampling = 0.01);
Action sample_action(const GameInfo& game_info, const RoundState& state, const Hand& hand,
int hero_id, const HandActionsValues& strategy,
const std::vector<Action>& legal_actions, float min_prob_sampling = 0.05);

// Multiplies each hand's weight in `ranges_[player]` by the probability that
// `strategy` gives to `action` for that hand.
// Throws std::runtime_error if `legal_actions` and `strategy` differ in size
// or if `action` is not one of `legal_actions`.
void update_range(int player, const HandActionsValues& strategy,
const std::vector<Action>& legal_actions, const Action& action);

// Decision logic shared by both players at `state`.
// An empty `sampled_action` marks a hero node: an action is sampled from the
// solved strategy and returned. A non-empty `sampled_action` is an action that
// was already taken: it is forwarded to the solver and used to update
// `player`'s range.
// NOTE(review): some paths in the definition return early before any range
// update — confirm against the full definition.
Action get_action_any_player(const GameInfo& game_info, const RoundStatePtr& state, int player,
std::optional<Action> sampled_action);
};

} // namespace pokerbot

1 comment on commit 2142e71

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'Arena Benchmark'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 2142e71 Previous: ac01674 Ratio
Results vs. Uniform Random Bot 2.48 bb/hand (1.721) 13.8995 bb/hand (3.5797) 5.60
Results vs. Preflop All-in Bot 0.131 bb/hand (0.8965) 4.235 bb/hand (4.3176) 32.33

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.