Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add multi chain for the rest of the hmc nuts samplers #3212

Merged
merged 10 commits into from
Jul 26, 2023
56 changes: 56 additions & 0 deletions src/stan/services/sample/fixed_param.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,62 @@ int fixed_param(Model& model, const stan::io::var_context& init,
return error_codes::OK;
}

template <typename Model, typename InitContextPtr, typename InitWriter,
typename SampleWriter, typename DiagnosticWriter>
int fixed_param(Model& model, const std::size_t num_chains,
const std::vector<InitContextPtr>& init,
unsigned int random_seed, unsigned int chain,
double init_radius, int num_samples, int num_thin, int refresh,
callbacks::interrupt& interrupt, callbacks::logger& logger,
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writers,
std::vector<DiagnosticWriter>& diagnostic_writers) {
std::vector<boost::ecuyer1988> rngs;
std::vector<Eigen::VectorXd> cont_vectors;
std::vector<util::mcmc_writer> writers;
std::vector<stan::mcmc::sample> samples;
std::vector<stan::mcmc::fixed_param_sampler> samplers(num_chains);
rngs.reserve(num_chains);
cont_vectors.reserve(num_chains);
writers.reserve(num_chains);
samples.reserve(num_chains);
for (int i = 0; i < num_chains; ++i) {
rngs.push_back(util::create_rng(random_seed, chain + i));
auto cont_vector = util::initialize(model, *init[i], rngs[i], init_radius,
false, logger, init_writer[i]);
cont_vectors.push_back(
Eigen::Map<Eigen::VectorXd>(cont_vector.data(), cont_vector.size()));
samples.emplace_back(cont_vectors[i], 0, 0);
writers.emplace_back(sample_writers[i], diagnostic_writers[i], logger);
// Headers
writers[i].write_sample_names(samples[i], samplers[i], model);
writers[i].write_diagnostic_names(samples[i], samplers[i], model);
}

tbb::parallel_for(
tbb::blocked_range<size_t>(0, num_chains, 1),
[&samplers, &writers, &samples, &model, &rngs, &interrupt, &logger,
num_samples, num_thin, refresh, chain,
num_chains](const tbb::blocked_range<size_t>& r) {
for (size_t i = r.begin(); i != r.end(); ++i) {
auto start = std::chrono::steady_clock::now();
util::generate_transitions(samplers[i], num_samples, 0, num_samples,
num_thin, refresh, true, false, writers[i],
samples[i], model, rngs[i], interrupt,
logger, chain + i, num_chains);
auto end = std::chrono::steady_clock::now();
double sample_delta_t
= std::chrono::duration_cast<std::chrono::milliseconds>(end
- start)
.count()
/ 1000.0;
writers[i].write_timing(0.0, sample_delta_t);
}
},
tbb::simple_partitioner());
return error_codes::OK;
}

} // namespace sample
} // namespace services
} // namespace stan
Expand Down
175 changes: 175 additions & 0 deletions src/stan/services/sample/hmc_nuts_dense_e.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,181 @@ int hmc_nuts_dense_e(Model& model, const stan::io::var_context& init,
sample_writer, diagnostic_writer);
}

/**
* Runs multiple chains of NUTS without adaptation using dense Euclidean metric
* with a pre-specified Euclidean metric.
*
* @tparam Model Model class
* @tparam InitContextPtr A pointer with underlying type derived from
`stan::io::var_context`
* @tparam InitInvContextPtr A pointer with underlying type derived from
`stan::io::var_context`
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model to test (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_inv_metric`, `init_writer`, `sample_writer`, and `diagnostic_writer`
must
* be the same length as this value.
* @param[in] init An std vector of init var contexts for initialization of each
* chain.
* @param[in] init_inv_metric An std vector of var contexts exposing an initial
* diagonal inverse Euclidean metric for each chain (must be positive definite)
* @param[in] random_seed random seed for the random number generator
* @param[in] init_chain_id first chain id. The pseudo random number generator
* will advance by for each chain by an integer sequence from `init_chain_id` to
* `init_chain_id+num_chains-1`
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
* @param[in] num_thin Number to thin the samples
* @param[in] save_warmup Indicates whether to save the warmup iterations
* @param[in] refresh Controls the output
* @param[in] stepsize initial stepsize for discrete evolution
* @param[in] stepsize_jitter uniform random jitter of stepsize
* @param[in] max_depth Maximum tree depth
* @param[in,out] interrupt Callback for interrupts
* @param[in,out] logger Logger for messages
* @param[in,out] init_writer std vector of Writer callbacks for unconstrained
inits of each chain.
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
* information of each chain.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitInvContextPtr,
typename InitWriter, typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_dense_e(Model& model, size_t num_chains,
const std::vector<InitContextPtr>& init,
const std::vector<InitInvContextPtr>& init_inv_metric,
unsigned int random_seed, unsigned int init_chain_id,
double init_radius, int num_warmup, int num_samples,
int num_thin, bool save_warmup, int refresh,
double stepsize, double stepsize_jitter, int max_depth,
callbacks::interrupt& interrupt, callbacks::logger& logger,
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
if (num_chains == 1) {
return hmc_nuts_dense_e(
model, *init[0], *init_inv_metric[0], random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin, save_warmup, refresh,
stepsize, stepsize_jitter, max_depth, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0]);
}
std::vector<boost::ecuyer1988> rngs;
rngs.reserve(num_chains);
std::vector<std::vector<double>> cont_vectors;
cont_vectors.reserve(num_chains);
using sample_t = stan::mcmc::dense_e_nuts<Model, boost::ecuyer1988>;
std::vector<sample_t> samplers;
samplers.reserve(num_chains);
try {
for (int i = 0; i < num_chains; ++i) {
rngs.emplace_back(util::create_rng(random_seed, init_chain_id + i));
cont_vectors.emplace_back(util::initialize(
model, *init[i], rngs[i], init_radius, true, logger, init_writer[i]));
Eigen::MatrixXd inv_metric = util::read_dense_inv_metric(
*init_inv_metric[i], model.num_params_r(), logger);
util::validate_dense_inv_metric(inv_metric, logger);

samplers.emplace_back(model, rngs[i]);
samplers[i].set_metric(inv_metric);
samplers[i].set_nominal_stepsize(stepsize);
samplers[i].set_stepsize_jitter(stepsize_jitter);
samplers[i].set_max_depth(max_depth);
}
} catch (const std::domain_error& e) {
return error_codes::CONFIG;
}
tbb::parallel_for(
tbb::blocked_range<size_t>(0, num_chains, 1),
[num_warmup, num_samples, num_thin, refresh, save_warmup, num_chains,
init_chain_id, &samplers, &model, &rngs, &interrupt, &logger,
&sample_writer, &cont_vectors,
&diagnostic_writer](const tbb::blocked_range<size_t>& r) {
for (size_t i = r.begin(); i != r.end(); ++i) {
util::run_sampler(samplers[i], model, cont_vectors[i], num_warmup,
num_samples, num_thin, refresh, save_warmup,
rngs[i], interrupt, logger, sample_writer[i],
diagnostic_writer[i], init_chain_id + i);
}
},
tbb::simple_partitioner());
return error_codes::OK;
}

/**
* Runs multiple chains of NUTS without adaptation using dense Euclidean metric,
* with identity matrix as initial inv_metric.
*
* @tparam Model Model class
* @tparam InitContextPtr A pointer with underlying type derived from
* `stan::io::var_context`
* @tparam InitWriter A type derived from `stan::callbacks::writer`
* @tparam SamplerWriter A type derived from `stan::callbacks::writer`
* @tparam DiagnosticWriter A type derived from `stan::callbacks::writer`
* @param[in] model Input model to test (with data already instantiated)
* @param[in] num_chains The number of chains to run in parallel. `init`,
* `init_writer`, `sample_writer`, and `diagnostic_writer` must be the same
* length as this value.
* @param[in] init An std vector of init var contexts for initialization of each
* chain.
* @param[in] random_seed random seed for the random number generator
* @param[in] init_chain_id first chain id. The pseudo random number generator
* will advance by for each chain by an integer sequence from `init_chain_id` to
* `init_chain_id+num_chains-1`
* @param[in] init_radius radius to initialize
* @param[in] num_warmup Number of warmup samples
* @param[in] num_samples Number of samples
* @param[in] num_thin Number to thin the samples
* @param[in] save_warmup Indicates whether to save the warmup iterations
* @param[in] refresh Controls the output
* @param[in] stepsize initial stepsize for discrete evolution
* @param[in] stepsize_jitter uniform random jitter of stepsize
* @param[in] max_depth Maximum tree depth
* @param[in,out] interrupt Callback for interrupts
* @param[in,out] logger Logger for messages
* @param[in,out] init_writer std vector of Writer callbacks for unconstrained
* inits of each chain.
* @param[in,out] sample_writer std vector of Writers for draws of each chain.
* @param[in,out] diagnostic_writer std vector of Writers for diagnostic
* information of each chain.
* @return error_codes::OK if successful
*/
template <class Model, typename InitContextPtr, typename InitWriter,
typename SampleWriter, typename DiagnosticWriter>
int hmc_nuts_dense_e(Model& model, size_t num_chains,
const std::vector<InitContextPtr>& init,
unsigned int random_seed, unsigned int init_chain_id,
double init_radius, int num_warmup, int num_samples,
int num_thin, bool save_warmup, int refresh,
double stepsize, double stepsize_jitter, int max_depth,
callbacks::interrupt& interrupt, callbacks::logger& logger,
std::vector<InitWriter>& init_writer,
std::vector<SampleWriter>& sample_writer,
std::vector<DiagnosticWriter>& diagnostic_writer) {
if (num_chains == 1) {
return hmc_nuts_dense_e(model, *init[0], random_seed, init_chain_id,
init_radius, num_warmup, num_samples, num_thin,
save_warmup, refresh, stepsize, stepsize_jitter,
max_depth, interrupt, logger, init_writer[0],
sample_writer[0], diagnostic_writer[0]);
}
std::vector<std::unique_ptr<stan::io::dump>> unit_e_metrics;
unit_e_metrics.reserve(num_chains);
for (size_t i = 0; i < num_chains; ++i) {
unit_e_metrics.emplace_back(std::make_unique<stan::io::dump>(
util::create_unit_e_dense_inv_metric(model.num_params_r())));
}
return hmc_nuts_dense_e(model, num_chains, init, unit_e_metrics, random_seed,
init_chain_id, init_radius, num_warmup, num_samples,
num_thin, save_warmup, refresh, stepsize,
stepsize_jitter, max_depth, interrupt, logger,
init_writer, sample_writer, diagnostic_writer);
}

} // namespace sample
} // namespace services
} // namespace stan
Expand Down
Loading