Skip to content

Commit

Permalink
Remove load balance as a policy option from AMSExecute.
Browse files Browse the repository at this point in the history
  • Loading branch information
koparasy committed May 20, 2024
1 parent 2de8a47 commit 81d4b5c
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 95 deletions.
142 changes: 76 additions & 66 deletions src/AMSlib/AMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,7 @@ void _AMSExecute(AMSExecutor executor,
const void **input_data,
void **output_data,
int inputDim,
int outputDim,
MPI_Comm Comm = 0)
int outputDim)
{
long index = static_cast<long>(executor);
if (index >= _amsWrap.executors.size())
Expand All @@ -424,8 +423,7 @@ void _AMSExecute(AMSExecutor executor,
reinterpret_cast<const double **>(input_data),
reinterpret_cast<double **>(output_data),
inputDim,
outputDim,
Comm);
outputDim);
} else if (currExec.first == AMSDType::AMS_SINGLE) {
ams::AMSWorkflow<float> *sWF =
reinterpret_cast<ams::AMSWorkflow<float> *>(currExec.second);
Expand All @@ -434,73 +432,105 @@ void _AMSExecute(AMSExecutor executor,
reinterpret_cast<const float **>(input_data),
reinterpret_cast<float **>(output_data),
inputDim,
outputDim,
Comm);
outputDim);
} else {
throw std::invalid_argument("Data type is not supported by AMSLib!");
return;
}
}

template <typename FPTypeValue>
ams::AMSWorkflow<FPTypeValue> *_AMSCreateExecutor(AMSCAbstrModel model,
AMSDType data_type,
AMSResourceType resource_type,
AMSPhysicFn call_back,
int process_id,
int world_size)
{
static std::once_flag flag;
std::call_once(flag, [&]() {
auto &rm = ams::ResourceManager::getInstance();
rm.init();
});

AMSAbstractModel &model_descr = _amsWrap.get_model(model);

ams::AMSWorkflow<FPTypeValue> *WF =
new ams::AMSWorkflow<FPTypeValue>(call_back,
model_descr.UQPath,
model_descr.SPath,
model_descr.DBLabel,
resource_type,
model_descr.threshold,
model_descr.uqPolicy,
model_descr.nClusters,
process_id,
world_size);
return WF;
}

template <typename FPTypeValue>
AMSExecutor _AMSRegisterExecutor(AMSDType data_type,
ams::AMSWorkflow<FPTypeValue> *workflow)
{
_amsWrap.executors.push_back(
std::make_pair(data_type, static_cast<void *>(workflow)));
return static_cast<AMSExecutor>(_amsWrap.executors.size()) - 1L;
}


#ifdef __cplusplus
extern "C" {
#endif

AMSExecutor AMSCreateExecutor(AMSCAbstrModel model,
AMSExecPolicy exec_policy,
AMSDType data_type,
AMSResourceType resource_type,
AMSPhysicFn call_back,
int process_id,
int world_size)
{
static std::once_flag flag;
std::call_once(flag, [&]() {
auto &rm = ams::ResourceManager::getInstance();
rm.init();
});
if (data_type == AMSDType::AMS_DOUBLE) {
auto *dWF = _AMSCreateExecutor<double>(
model, data_type, resource_type, call_back, process_id, world_size);
return _AMSRegisterExecutor(data_type, dWF);

AMSAbstractModel &model_descr = _amsWrap.get_model(model);
std::cout << "Returing and creating executor from model\n";
model_descr.dump();
} else if (data_type == AMSDType::AMS_SINGLE) {
auto *sWF = _AMSCreateExecutor<float>(
model, data_type, resource_type, call_back, process_id, world_size);
return _AMSRegisterExecutor(data_type, sWF);
} else {
throw std::invalid_argument("Data type is not supported by AMSLib!");
return static_cast<AMSExecutor>(-1);
}
}

#ifdef __AMS_ENABLE_MPI__
AMSExecutor AMSCreateDistributedExecutor(AMSCAbstrModel model,
AMSDType data_type,
AMSResourceType resource_type,
AMSPhysicFn call_back,
MPI_Comm Comm,
int process_id,
int world_size)
{
if (data_type == AMSDType::AMS_DOUBLE) {
ams::AMSWorkflow<double> *dWF =
new ams::AMSWorkflow<double>(call_back,
model_descr.UQPath,
model_descr.SPath,
model_descr.DBLabel,
resource_type,
model_descr.threshold,
model_descr.uqPolicy,
model_descr.nClusters,
process_id,
world_size,
exec_policy);
_amsWrap.executors.push_back(
std::make_pair(data_type, static_cast<void *>(dWF)));
return static_cast<AMSExecutor>(_amsWrap.executors.size()) - 1L;
auto *dWF = _AMSCreateExecutor<double>(
model, data_type, resource_type, call_back, process_id, world_size);
dWF->set_communicator(Comm);
return _AMSRegisterExecutor(data_type, dWF);

} else if (data_type == AMSDType::AMS_SINGLE) {
ams::AMSWorkflow<float> *sWF =
new ams::AMSWorkflow<float>(call_back,
model_descr.UQPath,
model_descr.SPath,
model_descr.DBLabel,
resource_type,
model_descr.threshold,
model_descr.uqPolicy,
model_descr.nClusters,
process_id,
world_size,
exec_policy);
_amsWrap.executors.push_back(
std::make_pair(data_type, static_cast<void *>(sWF)));
return static_cast<AMSExecutor>(_amsWrap.executors.size()) - 1L;
auto *sWF = _AMSCreateExecutor<float>(
model, data_type, resource_type, call_back, process_id, world_size);
sWF->set_communicator(Comm);
return _AMSRegisterExecutor(data_type, sWF);
} else {
throw std::invalid_argument("Data type is not supported by AMSLib!");
return static_cast<AMSExecutor>(-1);
}
}
#endif

void AMSExecute(AMSExecutor executor,
void *probDescr,
Expand Down Expand Up @@ -536,26 +566,6 @@ void AMSDestroyExecutor(AMSExecutor executor)
}
}

#ifdef __ENABLE_MPI__
void AMSDistributedExecute(AMSExecutor executor,
MPI_Comm Comm,
void *probDescr,
const int numElements,
const void **input_data,
void **output_data,
int inputDim,
int outputDim)
{
_AMSExecute(executor,
probDescr,
numElements,
input_data,
output_data,
inputDim,
outputDim,
Comm);
}
#endif

const char *AMSGetAllocatorName(AMSResourceType device)
{
Expand Down Expand Up @@ -594,7 +604,7 @@ AMSCAbstrModel AMSQueryModel(const char *domain_model)
return _amsWrap.get_model_index(domain_model);
}

void configure_ams_fs_database(AMSDBType db_type, const char *db_path)
void AMSConfigureFSDatabase(AMSDBType db_type, const char *db_path)
{
auto &db_instance = ams::db::DBManager::getInstance();
db_instance.instantiate_fs_db(db_type, std::string(db_path));
Expand Down
29 changes: 12 additions & 17 deletions src/AMSlib/include/AMS.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,24 @@ typedef enum {
} AMSUQPolicy;


#warning "I need to remove the exec policy when creating the executor"
AMSExecutor AMSCreateExecutor(AMSCAbstrModel model,
AMSExecPolicy exec_policy,
AMSDType data_type,
AMSResourceType resource_type,
AMSPhysicFn call_back,
int process_id,
int world_size);

#ifdef __AMS_ENABLE_MPI__
AMSExecutor AMSCreateDistributedExecutor(AMSCAbstrModel model,
AMSDType data_type,
AMSResourceType resource_type,
AMSPhysicFn call_back,
MPI_Comm comm,
int process_id,
int world_size);
#endif


AMSCAbstrModel AMSRegisterAbstractModel(const char *domain_name,
AMSUQPolicy uq_policy,
double threshold,
Expand All @@ -90,16 +99,6 @@ AMSCAbstrModel AMSRegisterAbstractModel(const char *domain_name,

AMSCAbstrModel AMSQueryModel(const char *domain_model);

#ifdef __AMS_ENABLE_MPI__
void AMSDistributedExecute(AMSExecutor executor,
MPI_Comm Comm,
void *probDescr,
const int numElements,
const void **input_data,
void **output_data,
int inputDim,
int outputDim);
#endif
void AMSExecute(AMSExecutor executor,
void *probDescr,
const int numElements,
Expand All @@ -110,13 +109,9 @@ void AMSExecute(AMSExecutor executor,

void AMSDestroyExecutor(AMSExecutor executor);

#ifdef __AMS_ENABLE_MPI__
int AMSSetCommunicator(MPI_Comm Comm);
#endif

void AMSSetAllocator(AMSResourceType resource, const char *alloc_name);
const char *AMSGetAllocatorName(AMSResourceType device);
void configure_ams_fs_database(AMSDBType db_type, const char *db_path);
void AMSConfigureFSDatabase(AMSDBType db_type, const char *db_path);

#ifdef __cplusplus
}
Expand Down
51 changes: 42 additions & 9 deletions src/AMSlib/wf/workflow.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#ifndef __AMS_WORKFLOW_HPP__
#define __AMS_WORKFLOW_HPP__

#include <mpi.h>

#include "debug.h"
#ifdef __AMS_ENABLE_CALIPER__
#include <caliper/cali_macros.h>
Expand Down Expand Up @@ -77,7 +79,16 @@ class AMSWorkflow
AMSResourceType appDataLoc;

/** @brief execution policy of the distributed system. Load balance or not. */
const AMSExecPolicy ePolicy;
AMSExecPolicy ePolicy;

#ifdef __AMS_ENABLE_MPI__
/** @brief MPI Communicator for all ranks that call collectively the evaluate function **/
MPI_Comm comm;
#endif


/** @brief Is the evaluate a distributed execution **/
bool isDistributed;

/** \brief Store the data in the database and copies
* data from the GPU to the CPU and then to the database.
Expand Down Expand Up @@ -151,6 +162,9 @@ class AMSWorkflow
: AppCall(nullptr),
DB(nullptr),
appDataLoc(AMSResourceType::AMS_HOST),
#ifdef __AMS_ENABLE_MPI__
comm(MPI_COMM_NULL),
#endif
ePolicy(AMSExecPolicy::AMS_UBALANCED)
{
}
Expand All @@ -164,15 +178,18 @@ class AMSWorkflow
const AMSUQPolicy uq_policy,
const int nClusters,
int _pId = 0,
int _wSize = 1,
AMSExecPolicy policy = AMSExecPolicy::AMS_UBALANCED)
int _wSize = 1)
: AppCall(_AppCall),
domainName(domain_name),
rId(_pId),
wSize(_wSize),
appDataLoc(app_data_loc),
uqPolicy(uq_policy),
ePolicy(policy)
#ifdef __AMS_ENABLE_MPI__
comm(MPI_COMM_NULL),
#endif
ePolicy(AMSExecPolicy::AMS_UBALANCED)

{
DB = nullptr;
auto &dbm = ams::db::DBManager::getInstance();
Expand All @@ -184,6 +201,19 @@ class AMSWorkflow

void set_physics(AMSPhysicFn _AppCall) { AppCall = _AppCall; }

void set_communicator(MPI_Comm communicator) { comm = communicator; }

void set_exec_policy(AMSExecPolicy policy) { ePolicy = policy; }

bool should_load_balance() const
{
#ifdef __AMS_ENABLE_MPI__
return (comm != MPI_COMM_NULL && ePolicy == AMSExecPolicy::AMS_BALANCED);
#else
return false;
#endif
}

~AMSWorkflow() { DBG(Workflow, "Destroying Workflow Handler"); }

/** @brief This is the main entry point of AMSLib and replaces the original
Expand Down Expand Up @@ -237,8 +267,7 @@ class AMSWorkflow
const FPTypeValue **inputs,
FPTypeValue **outputs,
int inputDim,
int outputDim,
MPI_Comm Comm = nullptr)
int outputDim)
{
CALIPER(CALI_MARK_BEGIN("AMSEvaluate");)

Expand Down Expand Up @@ -329,10 +358,14 @@ class AMSWorkflow
void **oPtr = reinterpret_cast<void **>(packedOutputs.data());
long lbElements = packedElements;

// FIXME: I don't like the way we separate code here.
// Simple modification can make it easier to read.
// if (should_load_balance) -> Code for load balancing
// else -> current code
#ifdef __ENABLE_MPI__
CALIPER(CALI_MARK_BEGIN("LOAD BALANCE MODULE");)
AMSLoadBalancer<FPTypeValue> lBalancer(rId, wSize, packedElements, Comm);
if (ePolicy == AMSExecPolicy::AMS_BALANCED && Comm) {
AMSLoadBalancer<FPTypeValue> lBalancer(rId, wSize, packedElements, comm);
if (should_load_balance()) {
lBalancer.init(inputDim, outputDim, appDataLoc);
lBalancer.scatterInputs(packedInputs, appDataLoc);
iPtr = reinterpret_cast<void **>(lBalancer.inputs());
Expand All @@ -351,7 +384,7 @@ class AMSWorkflow

#ifdef __ENABLE_MPI__
CALIPER(CALI_MARK_BEGIN("LOAD BALANCE MODULE");)
if (ePolicy == AMSExecPolicy::AMS_BALANCED && Comm) {
if (should_load_balance()) {
lBalancer.gatherOutputs(packedOutputs, appDataLoc);
}
CALIPER(CALI_MARK_END("LOAD BALANCE MODULE");)
Expand Down
Loading

0 comments on commit 81d4b5c

Please sign in to comment.