Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Features/json #67

Merged
merged 8 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ jobs:
- name: Run tests Torch=On FAISS=On HDF5=On AMS
run: |
cd build
make test
source /spack/share/spack/setup-env.sh
spack env activate -p /ams-spack-env
env CTEST_OUTPUT_ON_FAILURE=1 make test
- name: Build CALIPER=Off Torch=Off FAISS=On HDF5=On AMS
shell: bash -l {0}
run: |
Expand Down Expand Up @@ -131,7 +133,9 @@ jobs:
- name: Run tests Torch=Off FAISS=On HDF5=On AMS
run: |
cd build
make test
source /spack/share/spack/setup-env.sh
spack env activate -p /ams-spack-env
env CTEST_OUTPUT_ON_FAILURE=1 make test
- name: Build Torch=Off FAISS=Off HDF5=On AMS
shell: bash -l {0}
run: |
Expand Down Expand Up @@ -168,6 +172,8 @@ jobs:
- name: Run tests Torch=Off FAISS=Off HDF5=On AMS
run: |
cd build
source /spack/share/spack/setup-env.sh
spack env activate -p /ams-spack-env
make test
- name: Build Torch=Off FAISS=Off HDF5=Off AMS
shell: bash -l {0}
Expand Down Expand Up @@ -202,6 +208,8 @@ jobs:
- name: Run tests Torch=Off FAISS=Off HDF5=Off AMS
run: |
cd build
source /spack/share/spack/setup-env.sh
spack env activate -p /ams-spack-env
make test
build-cuda-tests:
Expand Down
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ find_package(UMPIRE REQUIRED
list(APPEND AMS_APP_LIBRARIES umpire)
list(APPEND AMS_APP_INCLUDES ${UMPIRE_INCLUDE_DIR})

find_package(nlohmann_json REQUIRED)
list(APPEND AMS_APP_LIBRARIES nlohmann_json::nlohmann_json)

# ------------------------------------------------------------------------------
find_package(Threads REQUIRED)

Expand Down
46 changes: 16 additions & 30 deletions examples/app/eos_ams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/

#include "eos_ams.hpp"

#include <vector>

#include "eos_ams.hpp"

template <typename FPType>
void callBack(void *cls,
long elements,
Expand Down Expand Up @@ -37,25 +37,22 @@ AMSEOS<FPType>::AMSEOS(EOS<FPType> *model,
const int mpi_nproc,
const double threshold,
const char *surrogate_path,
const char *uq_path,
const char *db_path)
const char *uq_path)
: model_(model)
{
AMSConfig conf = {exec_policy,
dtype,
res_type,
db_type,
callBack<FPType>,
(char *)surrogate_path,
(char *)uq_path,
(char *)db_path,
threshold,
uq_policy,
k_nearest,
mpi_task,
mpi_nproc};

wf_ = AMSCreateExecutor(conf);
AMSCAbstrModel model_descr = AMSRegisterAbstractModel("ideal_gas",
uq_policy,
threshold,
surrogate_path,
uq_path,
"ideal_gas",
k_nearest);
wf_ = AMSCreateExecutor(model_descr,
dtype,
res_type,
(AMSPhysicFn)callBack<FPType>,
mpi_task,
mpi_nproc);
}

template <typename FPType>
Expand All @@ -73,24 +70,13 @@ void AMSEOS<FPType>::Eval(const int length,
std::vector<const FPType *> inputs = {density, energy};
std::vector<FPType *> outputs = {pressure, soundspeed2, bulkmod, temperature};

#ifdef __ENABLE_MPI__
AMSDistributedExecute(wf_,
MPI_COMM_WORLD,
(void *)model_,
length,
reinterpret_cast<const void **>(inputs.data()),
reinterpret_cast<void **>(outputs.data()),
inputs.size(),
outputs.size());
#else
AMSExecute(wf_,
(void *)model_,
length,
reinterpret_cast<const void **>(inputs.data()),
reinterpret_cast<void **>(outputs.data()),
inputs.size(),
outputs.size());
#endif
}

template class AMSEOS<double>;
Expand Down
3 changes: 1 addition & 2 deletions examples/app/eos_ams.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ class AMSEOS : public EOS<FPType>
const int mpi_nproc,
const double threshold,
const char *surrogate_path,
const char *uq_path,
const char *db_path);
const char *uq_path);

virtual ~AMSEOS() { delete model_; }

Expand Down
52 changes: 29 additions & 23 deletions examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,27 +179,34 @@ int run(const char *device_name,
CALIPER(CALI_MARK_BEGIN("Setup");)

const bool use_device = std::strcmp(device_name, "cpu") != 0;
AMSDBType dbType = AMSDBType::None;
AMSDBType dbType = AMSDBType::AMS_NONE;
if (std::strcmp(db_type, "csv") == 0) {
dbType = AMSDBType::CSV;
dbType = AMSDBType::AMS_CSV;
} else if (std::strcmp(db_type, "hdf5") == 0) {
dbType = AMSDBType::HDF5;
dbType = AMSDBType::AMS_HDF5;
} else if (std::strcmp(db_type, "rmq") == 0) {
dbType = AMSDBType::RMQ;
dbType = AMSDBType::AMS_RMQ;
}

if (db_config == nullptr) dbType = AMSDBType::AMS_NONE;


if (dbType != AMSDBType::AMS_RMQ) {
AMSConfigureFSDatabase(dbType, db_config);
}

AMSUQPolicy uq_policy;

if (strcmp(uq_policy_opt, "faiss-max") == 0)
uq_policy = AMSUQPolicy::FAISS_Max;
uq_policy = AMSUQPolicy::AMS_FAISS_MAX;
else if (strcmp(uq_policy_opt, "faiss-mean") == 0)
uq_policy = AMSUQPolicy::FAISS_Mean;
uq_policy = AMSUQPolicy::AMS_FAISS_MEAN;
else if (strcmp(uq_policy_opt, "deltauq-max") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Max;
uq_policy = AMSUQPolicy::AMS_DELTAUQ_MAX;
else if (strcmp(uq_policy_opt, "deltauq-mean") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Mean;
uq_policy = AMSUQPolicy::AMS_DELTAUQ_MEAN;
else if (strcmp(uq_policy_opt, "random") == 0)
uq_policy = AMSUQPolicy::RandomUQ;
uq_policy = AMSUQPolicy::AMS_RANDOM;
else
throw std::runtime_error("Invalid UQ policy");

Expand Down Expand Up @@ -262,11 +269,11 @@ int run(const char *device_name,
// When we are not allocating from parent/root umpire allocator
// we need to inform AMS about the pool allocators.
if (strcmp(pool, "default") != 0) {
AMSSetAllocator(AMSResourceType::HOST, ams_host_alloc.c_str());
AMSSetAllocator(AMSResourceType::AMS_HOST, ams_host_alloc.c_str());

if (use_device) {
AMSSetAllocator(AMSResourceType::DEVICE, ams_device_alloc.c_str());
AMSSetAllocator(AMSResourceType::PINNED, ams_pinned_alloc.c_str());
AMSSetAllocator(AMSResourceType::AMS_DEVICE, ams_device_alloc.c_str());
AMSSetAllocator(AMSResourceType::AMS_PINNED, ams_pinned_alloc.c_str());
}
}

Expand Down Expand Up @@ -335,10 +342,10 @@ int run(const char *device_name,

db_path = (strlen(db_config) > 0) ? db_config : nullptr;

AMSResourceType ams_device = AMSResourceType::HOST;
if (use_device) ams_device = AMSResourceType::DEVICE;
AMSExecPolicy ams_loadBalance = AMSExecPolicy::UBALANCED;
if (lbalance) ams_loadBalance = AMSExecPolicy::BALANCED;
AMSResourceType ams_device = AMSResourceType::AMS_HOST;
if (use_device) ams_device = AMSResourceType::AMS_DEVICE;
AMSExecPolicy ams_loadBalance = AMSExecPolicy::AMS_UBALANCED;
if (lbalance) ams_loadBalance = AMSExecPolicy::AMS_BALANCED;
#else
constexpr bool use_ams = false;
#endif
Expand Down Expand Up @@ -370,8 +377,7 @@ int run(const char *device_name,
wS,
threshold,
surrogate_path,
uq_path,
db_path);
uq_path);

} else
#endif
Expand Down Expand Up @@ -616,7 +622,7 @@ int main(int argc, char **argv)
const char *db_type = "";

const char *precision_opt = "double";
AMSDType precision = AMSDType::Double;
AMSDType precision = AMSDType::AMS_DOUBLE;

const char *uq_policy_opt = "";
int k_nearest = 5;
Expand Down Expand Up @@ -837,16 +843,16 @@ int main(int argc, char **argv)
<< "(Weak Scaling)\n";

if (strcmp(precision_opt, "single") == 0)
precision = AMSDType::Single;
precision = AMSDType::AMS_SINGLE;
else if (strcmp(precision_opt, "double") == 0)
precision = AMSDType::Double;
precision = AMSDType::AMS_DOUBLE;
else {
std::cerr << "Invalid precision " << precision_opt << "\n";
return -1;
}

int ret = 0;
if (precision == AMSDType::Single)
if (precision == AMSDType::AMS_SINGLE)
ret = run<float>(device_name,
db_type,
uq_policy_opt,
Expand All @@ -872,7 +878,7 @@ int main(int argc, char **argv)
db_config,
lbalance,
k_nearest);
else if (precision == AMSDType::Double)
else if (precision == AMSDType::AMS_DOUBLE)
ret = run<double>(device_name,
db_type,
uq_policy_opt,
Expand Down
Loading
Loading