Skip to content

Commit

Permalink
Unify UQ interfaces under the uq module
Browse files Browse the repository at this point in the history
- Separate random from hdcache
  • Loading branch information
ggeorgakoudis committed Dec 7, 2023
1 parent bb5bac3 commit 1d8deb9
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 154 deletions.
2 changes: 2 additions & 0 deletions examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ int run(const char *device_name,
uq_policy = AMSUQPolicy::DeltaUQ_Max;
else if (strcmp(uq_policy_opt, "deltauq-mean") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Mean;
else if (strcmp(uq_policy_opt, "random") == 0)
uq_policy = AMSUQPolicy::RandomUQ;
else
throw std::runtime_error("Invalid UQ policy");

Expand Down
5 changes: 3 additions & 2 deletions src/AMSlib/include/AMS.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,15 @@ typedef enum { UBALANCED = 0, BALANCED } AMSExecPolicy;
typedef enum { None = 0, CSV, REDIS, HDF5, RMQ } AMSDBType;

// TODO: create a cleaner interface that separates UQ type (FAISS, DeltaUQ) with policy (max, mean).
typedef enum {
enum struct AMSUQPolicy {
AMSUQPolicy_BEGIN = 0,
FAISS_Mean,
FAISS_Max,
DeltaUQ_Mean,
DeltaUQ_Max,
RandomUQ,
AMSUQPolicy_END
} AMSUQPolicy;
};

typedef struct ams_conf {
const AMSExecPolicy ePolicy;
Expand Down
95 changes: 12 additions & 83 deletions src/AMSlib/ml/hdcache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ class HDCache
Index *m_index = nullptr;
const uint8_t m_dim;

const bool m_use_random;
const int m_knbrs = 0;
const AMSUQPolicy m_policy = AMSUQPolicy::FAISS_Mean;

Expand Down Expand Up @@ -96,17 +95,6 @@ class HDCache
//! ------------------------------------------------------------------------
//! constructors
//! ------------------------------------------------------------------------
HDCache(AMSResourceType resource, TypeInValue threshold = 0.5)
: m_index(nullptr),
m_dim(0),
m_use_random(true),
m_knbrs(-1),
cache_location(resource),
acceptable_error(threshold)
{
print();
}

#ifdef __ENABLE_FAISS__
HDCache(const std::string &cache_path,
AMSResourceType resource,
Expand All @@ -115,7 +103,6 @@ class HDCache
TypeInValue threshold = 0.5)
: m_index(load_cache(cache_path)),
m_dim(m_index->d),
m_use_random(false),
m_knbrs(knbrs),
m_policy(uqPolicy),
cache_location(resource),
Expand All @@ -139,7 +126,6 @@ class HDCache
TypeInValue threshold = 0.5)
: m_index(load_cache(cache_path)),
m_dim(0),
m_use_random(false),
m_knbrs(knbrs),
m_policy(uqPolicy),
cache_location(resource),
Expand Down Expand Up @@ -212,7 +198,8 @@ class HDCache
if (uqPolicy != AMSUQPolicy::FAISS_Mean &&
uqPolicy != AMSUQPolicy::FAISS_Max)
THROW(std::invalid_argument,
"Invalid UQ policy for hdcache" + std::to_string(uqPolicy));
"Invalid UQ policy for hdcache" +
std::to_string(static_cast<unsigned int>(uqPolicy)));

DBG(UQModule, "Generating new cache under (%s)", cache_path.c_str())
std::shared_ptr<HDCache<TypeInValue>> new_cache =
Expand All @@ -223,30 +210,6 @@ class HDCache
return new_cache;
}

static std::shared_ptr<HDCache<TypeInValue>> getInstance(
AMSResourceType resource,
float threshold = 0.5)
{
static std::string random_path("random");
std::shared_ptr<HDCache<TypeInValue>> cache = find_cache(
random_path, resource, AMSUQPolicy::FAISS_Mean, -1, threshold);
if (cache) {
DBG(UQModule, "Returning existing cache under (%s)", random_path.c_str())
return cache;
}

DBG(UQModule,
"Generating new cache under (%s, threshold:%f)",
random_path.c_str(),
threshold)
std::shared_ptr<HDCache<TypeInValue>> new_cache =
std::shared_ptr<HDCache<TypeInValue>>(
new HDCache<TypeInValue>(resource, threshold));

instances.insert(std::make_pair(random_path, new_cache));
return new_cache;
}

~HDCache()
{
DBG(UQModule, "Deleting UQ-Module");
Expand All @@ -272,25 +235,21 @@ class HDCache
if (has_index()) {
info = "npoints = " + std::to_string(count());
}
DBG(UQModule,
"HDCache (on_device = %d random = %d %s)",
cache_location,
m_use_random,
info.c_str());
DBG(UQModule, "HDCache (on_device = %d %s)", cache_location, info.c_str());
}

inline bool has_index() const
{
#ifdef __ENABLE_FAISS__
if (!m_use_random) return m_index != nullptr && m_index->is_trained;
return m_index != nullptr && m_index->is_trained;
#endif
return true;
}

inline size_t count() const
{
#ifdef __ENABLE_FAISS__
if (!m_use_random) return m_index->ntotal;
return m_index->ntotal;
#endif
return 0;
}
Expand Down Expand Up @@ -326,8 +285,6 @@ class HDCache
PERFFASPECT()
void add(const size_t ndata, const size_t d, TypeInValue *data)
{
if (m_use_random) return;

DBG(UQModule, "Add %ld %ld points to HDCache", ndata, d);
CFATAL(UQModule, d != m_dim, "Mismatch in data dimensionality!")
CFATAL(UQModule,
Expand All @@ -341,8 +298,6 @@ class HDCache
PERFFASPECT()
void add(const size_t ndata, const std::vector<TypeInValue *> &inputs)
{
if (m_use_random) return;

if (inputs.size() != m_dim)
CFATAL(UQModule,
inputs.size() != m_dim,
Expand All @@ -364,7 +319,6 @@ class HDCache
PERFFASPECT()
void train(const size_t ndata, const size_t d, TypeInValue *data)
{
if (m_use_random) return;
DBG(UQModule, "Add %ld %ld points to HDCache", ndata, d);
CFATAL(UQModule, d != m_dim, "Mismatch in data dimensionality!")
CFATAL(UQModule,
Expand All @@ -379,7 +333,6 @@ class HDCache
PERFFASPECT()
void train(const size_t ndata, const std::vector<TypeInValue *> &inputs)
{
if (m_use_random) return;
TypeValue *lin_data =
data_handler::linearize_features(cache_location, ndata, inputs);
_train(ndata, lin_data);
Expand All @@ -406,15 +359,9 @@ class HDCache
"HDCache does not have a valid and trained index!")
DBG(UQModule, "Evaluating %ld %ld points using HDCache", ndata, d);

CFATAL(UQModule,
(!m_use_random) && (d != m_dim),
"Mismatch in data dimensionality!")
CFATAL(UQModule, (d != m_dim), "Mismatch in data dimensionality!")

if (m_use_random) {
_evaluate(ndata, is_acceptable);
} else {
_evaluate(ndata, data, is_acceptable);
}
_evaluate(ndata, data, is_acceptable);

if (cache_location == AMSResourceType::DEVICE) {
deviceCheckErrors(__FILE__, __LINE__);
Expand Down Expand Up @@ -442,17 +389,13 @@ class HDCache
acceptable_error,
m_policy);
CFATAL(UQModule,
((!m_use_random) && inputs.size() != m_dim),
(inputs.size() != m_dim),
"Mismatch in data dimensionality!")

if (m_use_random) {
_evaluate(ndata, is_acceptable);
} else {
TypeValue *lin_data =
data_handler::linearize_features(cache_location, ndata, inputs);
_evaluate(ndata, lin_data, is_acceptable);
ams::ResourceManager::deallocate(lin_data, cache_location);
}
TypeValue *lin_data =
data_handler::linearize_features(cache_location, ndata, inputs);
_evaluate(ndata, lin_data, is_acceptable);
ams::ResourceManager::deallocate(lin_data, cache_location);
DBG(UQModule, "Done with evalution of uq");
}

Expand Down Expand Up @@ -619,20 +562,6 @@ class HDCache
{
}
#endif
PERFFASPECT()
inline void _evaluate(const size_t ndata, bool *is_acceptable) const
{
if (cache_location == AMSResourceType::DEVICE) {
#ifdef __ENABLE_CUDA__
random_uq_device<<<1, 1>>>(is_acceptable, ndata, acceptable_error);
#else
THROW(std::runtime_error,
"Random-uq is not configured to use device allocations");
#endif
} else {
random_uq_host(is_acceptable, ndata, acceptable_error);
}
}
// -------------------------------------------------------------------------
};

Expand Down
42 changes: 42 additions & 0 deletions src/AMSlib/ml/random_uq.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
* AMSLib Project Developers
*
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/

#ifndef __AMS_RANDOM_UQ_HPP__
#define __AMS_RANDOM_UQ_HPP__

#include "AMS.h"
#include "wf/debug.h"
#include "wf/utils.hpp"

class RandomUQ
{
public:
PERFFASPECT()
inline void evaluate(const size_t ndata, bool *is_acceptable) const
{
if (resourceLocation == AMSResourceType::DEVICE) {
#ifdef __ENABLE_CUDA__
random_uq_device<<<1, 1>>>(is_acceptable, ndata, threshold);
#else
THROW(std::runtime_error,
"Random-uq is not configured to use device allocations");
#endif
} else {
random_uq_host(is_acceptable, ndata, threshold);
}
}
RandomUQ(AMSResourceType resourceLocation, float threshold)
: resourceLocation(resourceLocation), threshold(threshold)
{
}

private:
AMSResourceType resourceLocation;
float threshold;
};

#endif
70 changes: 51 additions & 19 deletions src/AMSlib/ml/uq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,46 @@

#include "AMS.h"
#include "ml/hdcache.hpp"
#include "ml/random_uq.hpp"
#include "ml/surrogate.hpp"
#include "wf/resource_manager.hpp"

template <typename FPTypeValue>
class UQ
{
public:
UQ(AMSResourceType resourceLocation,
const AMSUQPolicy uqPolicy,
const char *uqPath,
const int nClusters,
const char *surrogatePath,
FPTypeValue threshold)
: uqPolicy(uqPolicy), threshold(threshold)
{

if (surrogatePath) {
bool is_DeltaUQ = ((uqPolicy == AMSUQPolicy::DeltaUQ_Max ||
uqPolicy == AMSUQPolicy::DeltaUQ_Mean)
? true
: false);
surrogate = SurrogateModel<FPTypeValue>::getInstance(surrogatePath,
resourceLocation,
is_DeltaUQ);
}

if (uqPath)
hdcache = HDCache<FPTypeValue>::getInstance(
uqPath, resourceLocation, uqPolicy, nClusters, threshold);

if (uqPolicy == AMSUQPolicy::RandomUQ)
randomUQ = std::make_unique<RandomUQ>(resourceLocation, threshold);
}

PERFFASPECT()
static void evaluate(
AMSUQPolicy uqPolicy,
const int totalElements,
std::vector<const FPTypeValue *> &inputs,
std::vector<FPTypeValue *> &outputs,
const std::shared_ptr<HDCache<FPTypeValue>> &hdcache,
const std::shared_ptr<SurrogateModel<FPTypeValue>> &surrogate,
bool *p_ml_acceptable)
void evaluate(const int totalElements,
std::vector<const FPTypeValue *> &inputs,
std::vector<FPTypeValue *> &outputs,
bool *p_ml_acceptable)
{
if ((uqPolicy == AMSUQPolicy::DeltaUQ_Mean) ||
(uqPolicy == AMSUQPolicy::DeltaUQ_Max)) {
Expand All @@ -47,20 +71,20 @@ class UQ
surrogate->evaluate(totalElements, inputs, outputs, outputs_stdev);
CALIPER(CALI_MARK_END("SURROGATE");)

if (uqPolicy == DeltaUQ_Mean) {
if (uqPolicy == AMSUQPolicy::DeltaUQ_Mean) {
for (size_t i = 0; i < totalElements; ++i) {
// Use double for increased precision, range in the calculation
double mean = 0.0;
for (size_t dim = 0; dim < ndims; ++dim)
mean += outputs_stdev[dim][i];
mean /= ndims;
p_ml_acceptable[i] = (mean < _threshold);
p_ml_acceptable[i] = (mean < threshold);
}
} else if (uqPolicy == DeltaUQ_Max) {
} else if (uqPolicy == AMSUQPolicy::DeltaUQ_Max) {
for (size_t i = 0; i < totalElements; ++i) {
bool is_acceptable = true;
for (size_t dim = 0; dim < ndims; ++dim)
if (outputs_stdev[dim][i] >= _threshold) {
if (outputs_stdev[dim][i] >= threshold) {
is_acceptable = false;
break;
}
Expand All @@ -75,7 +99,8 @@ class UQ
ams::ResourceManager::deallocate(outputs_stdev[dim],
AMSResourceType::HOST);
CALIPER(CALI_MARK_END("DELTAUQ");)
} else {
} else if (uqPolicy == AMSUQPolicy::FAISS_Mean ||
uqPolicy == AMSUQPolicy::FAISS_Max) {
CALIPER(CALI_MARK_BEGIN("HDCACHE");)
if (hdcache) hdcache->evaluate(totalElements, inputs, p_ml_acceptable);
CALIPER(CALI_MARK_END("HDCACHE");)
Expand All @@ -84,17 +109,24 @@ class UQ
DBG(Workflow, "Model exists, I am calling surrogate (for all data)");
surrogate->evaluate(totalElements, inputs, outputs);
CALIPER(CALI_MARK_END("SURROGATE");)
} else if (uqPolicy == AMSUQPolicy::RandomUQ) {
CALIPER(CALI_MARK_BEGIN("RANDOM_UQ");)
DBG(Workflow, "Evaluating Random UQ");
randomUQ->evaluate(totalElements, p_ml_acceptable);
CALIPER(CALI_MARK_END("RANDOM_UQ");)
} else {
THROW(std::runtime_error, "Invalid UQ policy");
}
}

PERFFASPECT()
static void setThreshold(FPTypeValue threshold) { _threshold = threshold; }
bool hasSurrogate() { return (surrogate ? true : false); }

private:
static FPTypeValue _threshold;
AMSUQPolicy uqPolicy;
FPTypeValue threshold;
std::unique_ptr<RandomUQ> randomUQ;
std::shared_ptr<HDCache<FPTypeValue>> hdcache;
std::shared_ptr<SurrogateModel<FPTypeValue>> surrogate;
};

template <typename FPTypeValue>
FPTypeValue UQ<FPTypeValue>::_threshold = 0.5;

#endif
Loading

0 comments on commit 1d8deb9

Please sign in to comment.