Skip to content

Commit

Permalink
Merge pull request #10 from mrc-ide/feature/adaptivePT
Browse files Browse the repository at this point in the history
adaptive PTMCMC
  • Loading branch information
JasonAHendry authored Jan 22, 2025
2 parents 305684b + d0fd5d5 commit 2dc78e5
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 135 deletions.
22 changes: 15 additions & 7 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,11 @@ int main(int argc, char* argv[])
int n_pi_bins = 1000; // No. of WSAF bins in betabinomial lookup

// MCMC parameters
double w_proposal_sd = 0.1; // Titres sampled from ~N(0, w_propsal_sd)
int n_temps = 5; // Number of temperature levels for PT-MCMC
int n_burn_iters = 100; // Number of iterations in burn-in phase
int n_sample_iters = 900; // Number of iterations in sampling phase
int n_temps = 5; // Number of temperature levels for PT-MCMC
double target_acceptance = 0.44; // Target acceptance rate per MCMC rung
int swap_freq = 1; // Number of iterations between proposing Metropolis-coupling swaps

// OPTIONS
// Filter
Expand All @@ -73,6 +76,7 @@ int main(int argc, char* argv[])
->required();
cmd_infer->add_option("-o,--output_dir", output_dir, "Output directory.")
->group("Input and output");

cmd_infer->add_option("-K, --COI", K, "Complexity of infection.")
->group("Model Hyperparameters")
->check(CLI::Range(min_K, max_K));
Expand All @@ -94,12 +98,16 @@ int main(int argc, char* argv[])
cmd_infer->add_option("-b, --n_wsaf_bins", n_pi_bins, "Number of WSAF bins in Betabin lookup table.")
->group("Model Hyperparameters")
->check(CLI::Range(100, 10'000));
cmd_infer->add_option("-w, --w_proposal", w_proposal_sd, "Controls variance in proportion proposals.")
->group("MCMC Parameters")
->check(CLI::PositiveNumber);

cmd_infer->add_option("-t, --temps", n_temps, "Number of temperature levels in PT-MCMC.")
->group("MCMC Parameters")
->check(CLI::Range(5, 100));
cmd_infer->add_option("-B, --burnin", n_burn_iters, "Number of burn-in iterations.")
->group("MCMC Parameters")
->check(CLI::PositiveNumber);
cmd_infer->add_option("-S, --sampling", n_sample_iters, "Number of sampling iterations.")
->group("MCMC Parameters")
->check(CLI::PositiveNumber);

// Parse
CLI11_PARSE(app, argc, argv);
Expand Down Expand Up @@ -146,14 +154,14 @@ int main(int argc, char* argv[])
string K_output_dir = output_dir + "/K" + std::to_string(k);

// Create objects for this COI
Parameters params(k, e_0, e_1, v, rho, G, w_proposal_sd, n_pi_bins);
Parameters params(k, e_0, e_1, v, rho, G, n_pi_bins, target_acceptance, swap_freq);
ProposalEngine proposal_engine(params);
Model model(params, data); // TODO: Stop recreating BetabinArray
//model.print();

// Create MCMC on the heap
cout << " Runnning MCMC..." << endl;
mcmc_ptrs.emplace_back(std::make_unique<MCMC>(params, model, proposal_engine, n_temps));
mcmc_ptrs.emplace_back(std::make_unique<MCMC>(params, model, proposal_engine, n_burn_iters, n_sample_iters, n_temps));
mcmc_ptrs.back()->run();

// Write MCMC outputs
Expand Down
Loading

0 comments on commit 2dc78e5

Please sign in to comment.