Skip to content

Commit

Permalink
Fix example code to adhere to new API
Browse files Browse the repository at this point in the history
  • Loading branch information
koparasy committed Jun 4, 2024
1 parent 194a6b5 commit c8fb053
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 52 deletions.
41 changes: 14 additions & 27 deletions examples/app/eos_ams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,22 @@ AMSEOS<FPType>::AMSEOS(EOS<FPType> *model,
const int mpi_nproc,
const double threshold,
const char *surrogate_path,
const char *uq_path,
const char *db_path)
const char *uq_path)
: model_(model)
{
AMSConfig conf = {exec_policy,
dtype,
res_type,
callBack<FPType>,
(char *)surrogate_path,
(char *)uq_path,
(char *)db_path,
threshold,
uq_policy,
k_nearest,
mpi_task,
mpi_nproc};

wf_ = AMSCreateExecutor(conf);
AMSCAbstrModel model_descr = AMSRegisterAbstractModel("ideal_gas",
uq_policy,
threshold,
surrogate_path,
uq_path,
"ideal_gas",
k_nearest);
wf_ = AMSCreateExecutor(model_descr,
dtype,
res_type,
(AMSPhysicFn)callBack<FPType>,
mpi_task,
mpi_nproc);
}

template <typename FPType>
Expand All @@ -72,24 +70,13 @@ void AMSEOS<FPType>::Eval(const int length,
std::vector<const FPType *> inputs = {density, energy};
std::vector<FPType *> outputs = {pressure, soundspeed2, bulkmod, temperature};

#ifdef __ENABLE_MPI__
AMSDistributedExecute(wf_,
MPI_COMM_WORLD,
(void *)model_,
length,
reinterpret_cast<const void **>(inputs.data()),
reinterpret_cast<void **>(outputs.data()),
inputs.size(),
outputs.size());
#else
AMSExecute(wf_,
(void *)model_,
length,
reinterpret_cast<const void **>(inputs.data()),
reinterpret_cast<void **>(outputs.data()),
inputs.size(),
outputs.size());
#endif
}

template class AMSEOS<double>;
Expand Down
3 changes: 1 addition & 2 deletions examples/app/eos_ams.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ class AMSEOS : public EOS<FPType>
const int mpi_nproc,
const double threshold,
const char *surrogate_path,
const char *uq_path,
const char *db_path);
const char *uq_path);

virtual ~AMSEOS() { delete model_; }

Expand Down
50 changes: 27 additions & 23 deletions examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,27 +179,32 @@ int run(const char *device_name,
CALIPER(CALI_MARK_BEGIN("Setup");)

const bool use_device = std::strcmp(device_name, "cpu") != 0;
AMSDBType dbType = AMSDBType::None;
AMSDBType dbType = AMSDBType::AMS_NONE;
if (std::strcmp(db_type, "csv") == 0) {
dbType = AMSDBType::CSV;
dbType = AMSDBType::AMS_CSV;
} else if (std::strcmp(db_type, "hdf5") == 0) {
dbType = AMSDBType::HDF5;
dbType = AMSDBType::AMS_HDF5;
} else if (std::strcmp(db_type, "rmq") == 0) {
dbType = AMSDBType::RMQ;
dbType = AMSDBType::AMS_RMQ;
}


if (dbType != AMSDBType::AMS_RMQ) {
AMSConfigureFSDatabase(dbType, db_config);
}

AMSUQPolicy uq_policy;

if (strcmp(uq_policy_opt, "faiss-max") == 0)
uq_policy = AMSUQPolicy::FAISS_Max;
uq_policy = AMSUQPolicy::AMS_FAISS_MAX;
else if (strcmp(uq_policy_opt, "faiss-mean") == 0)
uq_policy = AMSUQPolicy::FAISS_Mean;
uq_policy = AMSUQPolicy::AMS_FAISS_MEAN;
else if (strcmp(uq_policy_opt, "deltauq-max") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Max;
uq_policy = AMSUQPolicy::AMS_DELTAUQ_MAX;
else if (strcmp(uq_policy_opt, "deltauq-mean") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Mean;
uq_policy = AMSUQPolicy::AMS_DELTAUQ_MEAN;
else if (strcmp(uq_policy_opt, "random") == 0)
uq_policy = AMSUQPolicy::Random;
uq_policy = AMSUQPolicy::AMS_RANDOM;
else
throw std::runtime_error("Invalid UQ policy");

Expand Down Expand Up @@ -262,11 +267,11 @@ int run(const char *device_name,
// When we are not allocating from parent/root umpire allocator
// we need to inform AMS about the pool allocators.
if (strcmp(pool, "default") != 0) {
AMSSetAllocator(AMSResourceType::HOST, ams_host_alloc.c_str());
AMSSetAllocator(AMSResourceType::AMS_HOST, ams_host_alloc.c_str());

if (use_device) {
AMSSetAllocator(AMSResourceType::DEVICE, ams_device_alloc.c_str());
AMSSetAllocator(AMSResourceType::PINNED, ams_pinned_alloc.c_str());
AMSSetAllocator(AMSResourceType::AMS_DEVICE, ams_device_alloc.c_str());
AMSSetAllocator(AMSResourceType::AMS_PINNED, ams_pinned_alloc.c_str());
}
}

Expand Down Expand Up @@ -335,10 +340,10 @@ int run(const char *device_name,

db_path = (strlen(db_config) > 0) ? db_config : nullptr;

AMSResourceType ams_device = AMSResourceType::HOST;
if (use_device) ams_device = AMSResourceType::DEVICE;
AMSExecPolicy ams_loadBalance = AMSExecPolicy::UBALANCED;
if (lbalance) ams_loadBalance = AMSExecPolicy::BALANCED;
AMSResourceType ams_device = AMSResourceType::AMS_HOST;
if (use_device) ams_device = AMSResourceType::AMS_DEVICE;
AMSExecPolicy ams_loadBalance = AMSExecPolicy::AMS_UBALANCED;
if (lbalance) ams_loadBalance = AMSExecPolicy::AMS_BALANCED;
#else
constexpr bool use_ams = false;
#endif
Expand Down Expand Up @@ -370,8 +375,7 @@ int run(const char *device_name,
wS,
threshold,
surrogate_path,
uq_path,
db_path);
uq_path);

} else
#endif
Expand Down Expand Up @@ -616,7 +620,7 @@ int main(int argc, char **argv)
const char *db_type = "";

const char *precision_opt = "double";
AMSDType precision = AMSDType::Double;
AMSDType precision = AMSDType::AMS_DOUBLE;

const char *uq_policy_opt = "";
int k_nearest = 5;
Expand Down Expand Up @@ -837,16 +841,16 @@ int main(int argc, char **argv)
<< "(Weak Scaling)\n";

if (strcmp(precision_opt, "single") == 0)
precision = AMSDType::Single;
precision = AMSDType::AMS_SINGLE;
else if (strcmp(precision_opt, "double") == 0)
precision = AMSDType::Double;
precision = AMSDType::AMS_DOUBLE;
else {
std::cerr << "Invalid precision " << precision_opt << "\n";
return -1;
}

int ret = 0;
if (precision == AMSDType::Single)
if (precision == AMSDType::AMS_SINGLE)
ret = run<float>(device_name,
db_type,
uq_policy_opt,
Expand All @@ -872,7 +876,7 @@ int main(int argc, char **argv)
db_config,
lbalance,
k_nearest);
else if (precision == AMSDType::Double)
else if (precision == AMSDType::AMS_DOUBLE)
ret = run<double>(device_name,
db_type,
uq_policy_opt,
Expand Down

0 comments on commit c8fb053

Please sign in to comment.