Skip to content

Commit

Permalink
Optimize CFR (~20% increase in # of iterations) (#47)
Browse files Browse the repository at this point in the history
* Optimize CFR (~20% increase in # of iterations)

* Fix

* Fix
  • Loading branch information
b-inary authored Jan 31, 2024
1 parent 8b266ad commit b676689
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 85 deletions.
116 changes: 69 additions & 47 deletions csrc/src/cfr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@

namespace pokerbot {

HandActionsValues::HandActionsValues(const unsigned num_hands, const unsigned num_actions,
float value)
: data(num_hands * num_actions, value), num_hands_(num_hands), num_actions_(num_actions) {}

CFR::CFR(const Game& game) : game_(game), board_data_cache_(game) {
for (auto& child_values : children_cfvs_) {
child_values.resize(NUM_HANDS_POSTFLOP_3CARDS);
child_values.resize(ceil_to_multiple(NUM_HANDS_POSTFLOP_3CARDS));
}
}

Expand Down Expand Up @@ -80,7 +76,7 @@ void CFR::update_opponent_cfvs_vs_bet() {
const auto& raise_action = actions_.back();

// Opponent's perspective -> Fold means he lose
const auto half_pot = root_->pot_start_round() / 2;
const float half_pot = root_->pot_start_round() * 0.5f;
const float fold_payoff = -(half_pot + root_->bets[1 - player_id_]);
compute_fold_cfvs(opponent_range_raise_fold_, hero_range_raise_, fold_payoff, raise_fold_cfvs_);

Expand All @@ -97,7 +93,7 @@ void CFR::update_hero_cfvs_bet_node() {
const auto& raise_action = actions_.back();

// Hero's perspective -> Opponent folds mean we wins
const auto half_pot = root_->pot_start_round() / 2;
const float half_pot = root_->pot_start_round() * 0.5f;
const float fold_payoff = half_pot + root_->bets[1 - player_id_];
compute_fold_cfvs(hero_range_raise_, opponent_range_raise_fold_, fold_payoff, raise_fold_cfvs_);

Expand All @@ -124,7 +120,7 @@ void CFR::precompute_cfvs_fixed_nodes(const std::array<Range, 2>& ranges) {
continue;
}

const auto half_pot = root_->pot_start_round() / 2;
const float half_pot = root_->pot_start_round() * 0.5f;
if (actions_[a].type == Action::Type::FOLD) {
const float payoff = -(half_pot + root_->bets[player_id_]);
compute_fold_cfvs(ranges[player_id_], ranges[1 - player_id_], payoff, children_cfvs_[a]);
Expand All @@ -142,26 +138,33 @@ void CFR::update_opponent_regrets() {
const float root_value = opponent_fold_strategy_vs_bet_[hand] * raise_fold_cfvs_[hand] +
(1.0f - opponent_fold_strategy_vs_bet_[hand]) * raise_call_cfvs_[hand];

opponent_regrets_vs_bet_(hand, 0) += raise_fold_cfvs_[hand] - root_value;
opponent_regrets_vs_bet_(hand, 1) += raise_call_cfvs_[hand] - root_value;
opponent_regrets_vs_bet_(0, hand) += raise_fold_cfvs_[hand] - root_value;
opponent_regrets_vs_bet_(1, hand) += raise_call_cfvs_[hand] - root_value;

opponent_regrets_vs_bet_(hand, 0) *= regret_discount;
opponent_regrets_vs_bet_(hand, 1) *= regret_discount;
opponent_regrets_vs_bet_(0, hand) *= regret_discount;
opponent_regrets_vs_bet_(1, hand) *= regret_discount;
}
}

void CFR::update_hero_regrets() {
const auto regret_discount = get_linear_cfr_discount_factor();
for (hand_t hand = 0; hand < num_hands_[player_id_]; ++hand) {
float root_value = 0;

static_assert(RANGE_SIZE_MULTIPLE % 16 == 0);
for (hand_t hand = 0; hand < num_hands_[player_id_]; hand += 16) {
std::array<float, 16> root_values = {};

for (unsigned action = 0; action < num_actions(); ++action) {
root_value += strategy_(hand, action) * children_cfvs_[action][hand];
for (unsigned i = 0; i < 16; ++i) {
root_values[i] += strategy_(action, hand + i) * children_cfvs_[action][hand + i];
}
}

for (unsigned action = 0; action < num_actions(); ++action) {
const auto immediate_regret = children_cfvs_[action][hand] - root_value;
regrets_(hand, action) += immediate_regret;
regrets_(hand, action) *= regret_discount;
for (unsigned i = 0; i < 16; ++i) {
const auto immediate_regret = children_cfvs_[action][hand + i] - root_values[i];
regrets_(action, hand + i) += immediate_regret;
regrets_(action, hand + i) *= regret_discount;
}
}
}
}
Expand All @@ -172,8 +175,9 @@ void CFR::update_hero_reaches(const Range& hero_range) {
return;
}

const auto raise_action_index = num_actions() - 1;
for (hand_t hand = 0; hand < num_hands_[player_id_]; ++hand) {
const auto raise_prob = strategy_(hand, num_actions() - 1);
const auto raise_prob = strategy_(raise_action_index, hand);
hero_range_raise_.range[hand] = hero_range.range[hand] * raise_prob;
}
}
Expand All @@ -187,45 +191,63 @@ void CFR::update_opponent_reaches(const Range& opponent_range) {
}

void CFR::update_hero_strategy() {
for (hand_t hand = 0; hand < num_hands_[player_id_]; ++hand) {
float sum_positive_regrets = 0;
static_assert(RANGE_SIZE_MULTIPLE % 16 == 0);
for (hand_t hand = 0; hand < num_hands_[player_id_]; hand += 16) {
std::array<float, 16> sum_positive_regrets = {};

for (unsigned action = 0; action < num_actions(); ++action) {
if (regrets_(hand, action) > 0) {
sum_positive_regrets += regrets_(hand, action);
strategy_(hand, action) = regrets_(hand, action);
} else {
strategy_(hand, action) = 0;
for (unsigned i = 0; i < 16; ++i) {
const float positive_regret = std::max(regrets_(action, hand + i), 0.0f);
sum_positive_regrets[i] += positive_regret;
strategy_(action, hand + i) = positive_regret;
}
}

for (unsigned action = 0; action < num_actions(); ++action) {
strategy_(hand, action) = sum_positive_regrets > 0
? strategy_(hand, action) / sum_positive_regrets
: 1.0f / static_cast<float>(num_actions());
for (unsigned i = 0; i < 16; ++i) {
if (sum_positive_regrets[i] > 0) {
strategy_(action, hand + i) /= sum_positive_regrets[i];
} else {
strategy_(action, hand + i) = 1.0f / static_cast<float>(num_actions());
}
}
}
}

// zero out strategy for hands outside of range just in case
for (hand_t hand = num_hands_[player_id_]; hand < ceil_to_multiple(num_hands_[player_id_], 16);
++hand) {
for (unsigned action = 0; action < num_actions(); ++action) {
strategy_(action, hand) = 0;
}
}
}

void CFR::update_opponent_strategy() {
for (hand_t hand = 0; hand < num_hands_[1 - player_id_]; ++hand) {
float sum_positive_regrets = 0;
static_assert(RANGE_SIZE_MULTIPLE % 16 == 0);
for (hand_t hand = 0; hand < num_hands_[1 - player_id_]; hand += 16) {
std::array<float, 16> sum_positive_regrets = {};

// Fold
if (opponent_regrets_vs_bet_(hand, 0) > 0) {
sum_positive_regrets += opponent_regrets_vs_bet_(hand, 0);
opponent_fold_strategy_vs_bet_[hand] = opponent_regrets_vs_bet_(hand, 0);
} else {
opponent_fold_strategy_vs_bet_[hand] = 0;
for (unsigned i = 0; i < 16; ++i) {
const float positive_regret = std::max(opponent_regrets_vs_bet_(0, hand + i), 0.0f);
sum_positive_regrets[i] += positive_regret;
opponent_fold_strategy_vs_bet_[hand + i] = positive_regret;
}

// Call
if (opponent_regrets_vs_bet_(hand, 1) > 0) {
sum_positive_regrets += opponent_regrets_vs_bet_(hand, 1);
for (unsigned i = 0; i < 16; ++i) {
const float positive_regret = std::max(opponent_regrets_vs_bet_(1, hand + i), 0.0f);
sum_positive_regrets[i] += positive_regret;
}

opponent_fold_strategy_vs_bet_[hand] =
sum_positive_regrets > 0 ? opponent_fold_strategy_vs_bet_[hand] / sum_positive_regrets
: 0.5f;
for (unsigned i = 0; i < 16; ++i) {
if (sum_positive_regrets[i] > 0) {
opponent_fold_strategy_vs_bet_[hand + i] /= sum_positive_regrets[i];
} else {
opponent_fold_strategy_vs_bet_[hand + i] = 0.5f;
}
}
}
}

Expand Down Expand Up @@ -268,12 +290,12 @@ void CFR::solve(const std::array<Range, 2>& ranges, const RoundStatePtr& state,

build_tree();

regrets_ = HandActionsValues(num_hands_[player_id_], num_actions(), 0);
opponent_regrets_vs_bet_ = HandActionsValues(num_hands_[1 - player_id_], 2, 0);
strategy_ = HandActionsValues(num_hands_[player_id_], num_actions(), 1.0f / num_actions());
opponent_fold_strategy_vs_bet_.assign(num_hands_[1 - player_id_], 0.5f);
raise_fold_cfvs_.assign(max_num_hands, 0);
raise_call_cfvs_.assign(max_num_hands, 0);
regrets_ = HandActionsValues(num_actions(), num_hands_[player_id_], 0);
opponent_regrets_vs_bet_ = HandActionsValues(2, num_hands_[1 - player_id_], 0);
strategy_ = HandActionsValues(num_actions(), num_hands_[player_id_], 1.0f / num_actions());
opponent_fold_strategy_vs_bet_.assign(ceil_to_multiple(num_hands_[1 - player_id_]), 0.5f);
raise_fold_cfvs_.assign(ceil_to_multiple(max_num_hands), 0);
raise_call_cfvs_.assign(ceil_to_multiple(max_num_hands), 0);

precompute_cfvs_fixed_nodes(ranges);
update_hero_reaches(ranges[player_id_]);
Expand Down
25 changes: 12 additions & 13 deletions csrc/src/cfr.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,24 @@

namespace pokerbot {

class HandActionsValues {
public:
explicit HandActionsValues(unsigned num_hands, unsigned num_actions, float value);
struct HandActionsValues {
explicit HandActionsValues(unsigned num_actions, unsigned num_hands, float value)
: data(num_actions, std::vector<float>(ceil_to_multiple(num_hands), value)),
num_actions(num_actions),
num_hands(num_hands) {}

HandActionsValues() = default;

float& operator()(const hand_t hand, const unsigned action) {
return data[hand * num_actions_ + action];
}
float& operator()(const unsigned action, const hand_t hand) { return data[action][hand]; }

float const& operator()(const hand_t hand, const unsigned action) const {
return data[hand * num_actions_ + action];
float const& operator()(const unsigned action, const hand_t hand) const {
return data[action][hand];
}

// Store a 2D data of [action, hand] in 1D vector of size num_hands * num_actions
std::vector<float> data;

unsigned num_hands_ = 0;
unsigned num_actions_ = 0;
// Store a 2D data of [action, hand]
std::vector<std::vector<float>> data;
unsigned num_actions = 0;
unsigned num_hands = 0;
};

// NB: Should create this object once and reuse it each time you need to solve a subgame
Expand Down
14 changes: 7 additions & 7 deletions csrc/src/cfr_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ TEST_F(CFRTest, TestRiverNutAirToyGame) {
fmt::print("Board is {} \n", Card::to_string(board_cards));
double bluff_frequency = 0;
for (const auto& bluff : bluffs) {
fmt::print("{} Check = {} / Raise = {} \n", bluff, strategy(Hand(bluff).index(), 0),
strategy(Hand(bluff).index(), 1));
bluff_frequency += strategy(Hand(bluff).index(), 1);
fmt::print("{} Check = {} / Raise = {} \n", bluff, strategy(0, Hand(bluff).index()),
strategy(1, Hand(bluff).index()));
bluff_frequency += strategy(1, Hand(bluff).index());
}
double valuebet_frequency = 0;
for (const auto& value_bet : value_bets) {
fmt::print("{} Check = {} / Raise = {} \n", value_bet, strategy(Hand(value_bet).index(), 0),
strategy(Hand(value_bet).index(), 1));
valuebet_frequency += strategy(Hand(value_bet).index(), 1);
fmt::print("{} Check = {} / Raise = {} \n", value_bet, strategy(0, Hand(value_bet).index()),
strategy(1, Hand(value_bet).index()));
valuebet_frequency += strategy(1, Hand(value_bet).index());
}
double pot_odds = (double)round_state->effective_stack() /
(round_state->effective_stack() + round_state->pot());
Expand Down Expand Up @@ -111,7 +111,7 @@ TEST_F(CFRTest, TestPreflopOpenRaiseStrategy) {
"Tc9h", "Tc8d", "9s8s", "9h8s", "6c5c", "3c2c", "3c2d"}) {
Hand hand(hand_str);
fmt::print("{} Fold = {} / Call = {} / Raise = {} \n", hand.to_string(),
strategy(hand.index(), 0), strategy(hand.index(), 1), strategy(hand.index(), 2));
strategy(0, hand.index()), strategy(1, hand.index()), strategy(2, hand.index()));
}

// TODO Add tests/asserts here
Expand Down
10 changes: 5 additions & 5 deletions csrc/src/main_bot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ Action MainBot::sample_action_and_update_range(const GameInfo& game_info, const
const HandActionsValues& strategy,
const std::vector<Action>& legal_actions,
const float min_prob_sampling) {
if (legal_actions.size() != strategy.num_actions_) {
if (legal_actions.size() != strategy.num_actions) {
throw std::runtime_error("Actions mistmatch");
}

// Get strategy for hand
std::vector<float> probs;
probs.reserve(strategy.num_actions_);
for (unsigned a = 0; a < strategy.num_actions_; ++a) {
probs.push_back(strategy(hand.index(), a));
probs.reserve(strategy.num_actions);
for (unsigned a = 0; a < strategy.num_actions; ++a) {
probs.push_back(strategy(a, hand.index()));
}
// Don't sample if prob is too low due to non-convergence
for (auto& val : probs) {
Expand All @@ -43,7 +43,7 @@ Action MainBot::sample_action_and_update_range(const GameInfo& game_info, const
strategy_str);

for (hand_t i = 0; i < ranges_[hero_id].num_hands(); ++i) {
const auto action_prob = strategy(i, sampled_idx);
const auto action_prob = strategy(sampled_idx, i);
ranges_[hero_id].range[i] = ranges_[hero_id].range[i] * action_prob;
}

Expand Down
10 changes: 1 addition & 9 deletions csrc/src/range.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,10 @@

namespace pokerbot {

namespace {

auto init_random_range() {
std::array<float, NUM_HANDS_POSTFLOP_3CARDS> range{};
Range::Range() : num_cards(NumCards::Two), range(SIZE) {
std::fill_n(range.begin(), NUM_HANDS_POSTFLOP_2CARDS, 1.0);
return range;
}

} // namespace

Range::Range() : num_cards(NumCards::Two), range(init_random_range()) {}

void Range::update_on_board_cards(const Game& game, const std::vector<card_t>& board_cards) {
if (board_cards.empty()) {
return;
Expand Down
21 changes: 19 additions & 2 deletions csrc/src/range.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
#pragma once
#include <array>
#include "definitions.h"
#include "game.h"

#include <stdexcept>
#include <vector>

namespace pokerbot {

inline constexpr unsigned RANGE_SIZE_MULTIPLE = 16;

template <typename T>
inline constexpr T ceil_to_multiple(T n, unsigned multiple = RANGE_SIZE_MULTIPLE) {
if ((multiple & (multiple - 1)) != 0) {
throw std::invalid_argument("multiple must be a power of 2");
}
return (n + multiple - 1) & ~(multiple - 1);
}

struct Range {
// smallest multiple of `RANGE_SIZE_MULTIPLE` that is >= `NUM_HANDS_POSTFLOP_3CARDS`
static constexpr hand_t SIZE = ceil_to_multiple(NUM_HANDS_POSTFLOP_3CARDS);

NumCards num_cards;
std::array<float, NUM_HANDS_POSTFLOP_3CARDS> range;

// `range.size()` may not be equal to `NUM_HANDS_POSTFLOP_3CARDS`!
std::vector<float> range;

/// Initialize uniform random range
Range();
Expand Down
2 changes: 1 addition & 1 deletion csrc/src/runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ inline std::array<std::string, 2> parseArgs(int argc, char* argv[]) {
bool host_flag = false;
for (int i = 1; i < argc; i++) {
std::string arg(argv[i]);
if ((arg == "-h") | (arg == "--host")) {
if (arg == "-h" || arg == "--host") {
host_flag = true;
} else if (arg == "--port") {
// nothing to do
Expand Down
2 changes: 1 addition & 1 deletion scripts/clang_format_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ set -e

ROOT="$(dirname "$0")/../"

clang-format-14 -style=file -i $(find "$ROOT/csrc/src" "$ROOT/csrc/scripts" -name "*.cc" -o -name '*.cpp' -o -name '*.h')
clang-format-14 -style=file -i $(find "$ROOT/csrc/src" "$ROOT/csrc/scripts" "$ROOT/simple_bots" -name "*.cc" -o -name '*.cpp' -o -name '*.h')
1 change: 1 addition & 0 deletions simple_bots/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
cmake_minimum_required(VERSION 3.21)

project(simple-bot)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Expand Down

0 comments on commit b676689

Please sign in to comment.