Skip to content

Commit

Permalink
Compute DeltaUQ predicates on device when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
koparasy authored and ggeorgakoudis committed Apr 9, 2024
1 parent 41687c2 commit 8f45cf5
Show file tree
Hide file tree
Showing 11 changed files with 304 additions and 84 deletions.
129 changes: 117 additions & 12 deletions src/AMSlib/ml/surrogate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include <string>
#include <unordered_map>

#include "AMS.h"
#include "wf/cuda/utilities.cuh"

#ifdef __ENABLE_TORCH__
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
Expand Down Expand Up @@ -176,6 +179,85 @@ class SurrogateModel
_load_torch(model_path, torch::Device(device_name), torch::kFloat32);
}

// -------------------------------------------------------------------------
// compute delta uq predicates
// -------------------------------------------------------------------------
/**
 * Compute the DeltaUQ acceptance predicates from per-output standard
 * deviations, dispatching to host loops or CUDA kernels depending on
 * where the model data live.
 *
 * \param uq_policy      Either AMSUQPolicy::DeltaUQ_Mean (accept a row when
 *                       the mean of its stdev values is below threshold) or
 *                       AMSUQPolicy::DeltaUQ_Max (accept a row only when
 *                       every stdev value is below threshold). Any other
 *                       policy throws.
 * \param outputs_stdev  Row-major [nrows x ncols] stdev matrix (element
 *                       (i, j) at index j + i * ncols). NOTE(review): when
 *                       model_resource == DEVICE this must be a device
 *                       pointer, otherwise a host pointer — TODO confirm
 *                       callers guarantee this.
 * \param predicates     Output array of nrows booleans; same residency
 *                       requirement as outputs_stdev.
 * \param nrows          Number of rows (elements being evaluated).
 * \param ncols          Number of stdev values (output dimensions) per row.
 * \param threshold      Acceptance threshold compared against the mean
 *                       (DeltaUQ_Mean) or each value (DeltaUQ_Max).
 */
void computeDeltaUQPredicates(AMSUQPolicy uq_policy,
const TypeInValue* __restrict__ outputs_stdev,
bool* __restrict__ predicates,
const size_t nrows,
const size_t ncols,
const double threshold)
{
// Host fallback for DeltaUQ_Mean: per-row mean of the stdev values,
// accumulated in double for precision/range, compared against threshold.
auto computeDeltaUQMeanPredicatesHost = [&]() {
for (size_t i = 0; i < nrows; ++i) {
double mean = 0.0;
for (size_t j = 0; j < ncols; ++j)
mean += outputs_stdev[j + i * ncols];
mean /= ncols;

predicates[i] = (mean < threshold);
}
};

// Host fallback for DeltaUQ_Max: a row is acceptable only if no stdev
// value reaches the threshold (early exit on the first violation).
auto computeDeltaUQMaxPredicatesHost = [&]() {
for (size_t i = 0; i < nrows; ++i) {
predicates[i] = true;
for (size_t j = 0; j < ncols; ++j)
if (outputs_stdev[j + i * ncols] >= threshold) {
predicates[i] = false;
break;
}
}
};

if (uq_policy == AMSUQPolicy::DeltaUQ_Mean) {
// NOTE: the #ifdef is placed so that when CUDA is disabled the DEVICE
// branch becomes the THROW statement and the else still binds to this
// if — be careful when editing this structure.
if (model_resource == AMSResourceType::DEVICE)
#ifdef __ENABLE_CUDA__
{
DBG(Surrogate, "Compute mean delta uq predicates on device\n");
constexpr int block_size = 256;
int grid_size = divup(nrows, block_size);
computeDeltaUQMeanPredicatesKernel<<<grid_size, block_size>>>(
outputs_stdev, predicates, nrows, ncols, threshold);
// TODO: use combined routine when it lands.
cudaDeviceSynchronize();
CUDACHECKERROR();
}
#else
THROW(std::runtime_error,
"Expected CUDA is enabled when model data are on DEVICE");
#endif
else {
DBG(Surrogate, "Compute mean delta uq predicates on host\n");
computeDeltaUQMeanPredicatesHost();
}
} else if (uq_policy == AMSUQPolicy::DeltaUQ_Max) {
if (model_resource == AMSResourceType::DEVICE)
#ifdef __ENABLE_CUDA__
{
DBG(Surrogate, "Compute max delta uq predicates on device\n");
constexpr int block_size = 256;
int grid_size = divup(nrows, block_size);
computeDeltaUQMaxPredicatesKernel<<<grid_size, block_size>>>(
outputs_stdev, predicates, nrows, ncols, threshold);
// TODO: use combined routine when it lands.
cudaDeviceSynchronize();
CUDACHECKERROR();
}
#else
THROW(std::runtime_error,
"Expected CUDA is enabled when model data are on DEVICE");
#endif
else {
DBG(Surrogate, "Compute max delta uq predicates on host\n");
computeDeltaUQMaxPredicatesHost();
}
} else
THROW(std::runtime_error,
"Invalid uq_policy to compute delta uq predicates");
}

// -------------------------------------------------------------------------
// evaluate a torch model
// -------------------------------------------------------------------------
Expand All @@ -185,7 +267,9 @@ class SurrogateModel
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs,
TypeInValue** outputs_stdev)
AMSUQPolicy uq_policy,
bool* predicates,
double threshold)
{
//torch::NoGradGuard no_grad;
c10::InferenceMode guard(true);
Expand All @@ -195,7 +279,6 @@ class SurrogateModel

input.set_requires_grad(false);
if (_is_DeltaUQ) {
assert(outputs_stdev && "Expected non-null outputs_stdev");
// The deltauq surrogate returns a tuple of (outputs, outputs_stdev)
CALIPER(CALI_MARK_BEGIN("SURROGATE-EVAL");)
auto output_tuple = module.forward({input}).toTuple();
Expand All @@ -206,11 +289,14 @@ class SurrogateModel
at::Tensor output_stdev_tensor =
output_tuple->elements()[1].toTensor().detach();
CALIPER(CALI_MARK_BEGIN("TENSOR_TO_ARRAY");)

computeDeltaUQPredicates(uq_policy,
output_stdev_tensor.data_ptr<TypeInValue>(),
predicates,
num_elements,
num_out,
threshold);
tensorToArray(output_mean_tensor, num_elements, num_out, outputs);
tensorToHostArray(output_stdev_tensor,
num_elements,
num_out,
outputs_stdev);
CALIPER(CALI_MARK_END("TENSOR_TO_ARRAY");)
} else {
CALIPER(CALI_MARK_BEGIN("SURROGATE-EVAL");)
Expand Down Expand Up @@ -248,7 +334,9 @@ class SurrogateModel
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs,
TypeInValue** outputs_stdev)
AMSUQPolicy uq_policy,
bool* predicates,
double threshold)
{
}

Expand Down Expand Up @@ -340,23 +428,36 @@ class SurrogateModel
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs,
TypeInValue** outputs_stdev = nullptr)
AMSUQPolicy uq_policy = AMSUQPolicy::AMSUQPolicy_BEGIN,
bool* predicates = nullptr,
double threshold = 0.0)
{
_evaluate(num_elements, num_in, num_out, inputs, outputs, outputs_stdev);
_evaluate(num_elements,
num_in,
num_out,
inputs,
outputs,
uq_policy,
predicates,
threshold);
}

PERFFASPECT()
/// Convenience overload: evaluate the surrogate on vectors of input/output
/// column pointers. The input/output dimensionality is taken from the
/// vector sizes; forwards everything to the pointer-based _evaluate.
inline void evaluate(long num_elements,
                     std::vector<const TypeInValue*> inputs,
                     std::vector<TypeInValue*> outputs,
                     AMSUQPolicy uq_policy,
                     bool* predicates,
                     double threshold)
{
  const size_t dims_in = inputs.size();
  const size_t dims_out = outputs.size();
  const TypeInValue** in_ptrs =
      static_cast<const TypeInValue**>(inputs.data());
  TypeInValue** out_ptrs = static_cast<TypeInValue**>(outputs.data());
  _evaluate(num_elements,
            dims_in,
            dims_out,
            in_ptrs,
            out_ptrs,
            uq_policy,
            predicates,
            threshold);
}

PERFFASPECT()
Expand All @@ -369,7 +470,9 @@ class SurrogateModel
outputs.size(),
static_cast<const TypeInValue**>(inputs.data()),
static_cast<TypeInValue**>(outputs.data()),
nullptr);
AMSUQPolicy::AMSUQPolicy_BEGIN,
nullptr,
0.0);
}

#ifdef __ENABLE_TORCH__
Expand Down Expand Up @@ -410,6 +513,8 @@ class SurrogateModel
else
_load<TypeInValue>(new_path, "cuda");
}

AMSResourceType getModelResource() const { return model_resource; }
};

template <typename T>
Expand Down
56 changes: 16 additions & 40 deletions src/AMSlib/ml/uq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class UQ

if (uqPolicy == AMSUQPolicy::RandomUQ)
randomUQ = std::make_unique<RandomUQ>(resourceLocation, threshold);

DBG(UQ, "UQ Model is of type %d", uqPolicy)
}

PERFFASPECT()
Expand All @@ -73,48 +75,22 @@ class UQ
{
if ((uqPolicy == AMSUQPolicy::DeltaUQ_Mean) ||
(uqPolicy == AMSUQPolicy::DeltaUQ_Max)) {
CALIPER(CALI_MARK_BEGIN("DELTAUQ");)
const size_t ndims = outputs.size();
std::vector<FPTypeValue *> outputs_stdev(ndims);
// TODO: Enable device-side allocation and predicate calculation.
auto &rm = ams::ResourceManager::getInstance();
for (int dim = 0; dim < ndims; ++dim)
outputs_stdev[dim] =
rm.allocate<FPTypeValue>(totalElements, AMSResourceType::HOST);

CALIPER(CALI_MARK_BEGIN("SURROGATE");)
DBG(Workflow,
"Model exists, I am calling DeltaUQ surrogate (for all data)");
surrogate->evaluate(totalElements, inputs, outputs, outputs_stdev);
CALIPER(CALI_MARK_END("SURROGATE");)
auto &rm = ams::ResourceManager::getInstance();

if (uqPolicy == AMSUQPolicy::DeltaUQ_Mean) {
for (size_t i = 0; i < totalElements; ++i) {
// Use double for increased precision, range in the calculation
double mean = 0.0;
for (size_t dim = 0; dim < ndims; ++dim)
mean += outputs_stdev[dim][i];
mean /= ndims;
p_ml_acceptable[i] = (mean < threshold);
}
} else if (uqPolicy == AMSUQPolicy::DeltaUQ_Max) {
for (size_t i = 0; i < totalElements; ++i) {
bool is_acceptable = true;
for (size_t dim = 0; dim < ndims; ++dim)
if (outputs_stdev[dim][i] >= threshold) {
is_acceptable = false;
break;
}

p_ml_acceptable[i] = is_acceptable;
}
} else {
THROW(std::runtime_error, "Invalid UQ policy");
}

for (int dim = 0; dim < ndims; ++dim)
rm.deallocate(outputs_stdev[dim], AMSResourceType::HOST);
CALIPER(CALI_MARK_END("DELTAUQ");)
CALIPER(CALI_MARK_BEGIN("DELTAUQ SURROGATE");)
DBG(UQ,
"Model exists, I am calling DeltaUQ surrogate [%ld %ld] -> (mu:[%ld "
"%ld], std:[%ld %ld])",
totalElements,
inputs.size(),
totalElements,
outputs.size(),
totalElements,
inputs.size());
surrogate->evaluate(
totalElements, inputs, outputs, uqPolicy, p_ml_acceptable, threshold);
CALIPER(CALI_MARK_END("DELTAUQ SURROGATE");)
} else if (uqPolicy == AMSUQPolicy::FAISS_Mean ||
uqPolicy == AMSUQPolicy::FAISS_Max) {
CALIPER(CALI_MARK_BEGIN("HDCACHE");)
Expand Down
34 changes: 14 additions & 20 deletions src/AMSlib/wf/cuda/utilities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#ifndef __DEVICE_UTILITIES__
#define __DEVICE_UTILITIES__

#ifdef __ENABLE_CUDA__

#include <curand.h>
#include <curand_kernel.h>
#include <thrust/device_vector.h>
Expand Down Expand Up @@ -288,27 +290,17 @@ int compact(bool cond,
{
int numBlocks = divup(length, blockSize);
auto& rm = ams::ResourceManager::getInstance();
int* d_BlocksCount =
rm.allocate<int>(numBlocks, AMSResourceType::DEVICE);
int* d_BlocksOffset =
rm.allocate<int>(numBlocks, AMSResourceType::DEVICE);
int* d_BlocksCount = rm.allocate<int>(numBlocks, AMSResourceType::DEVICE);
int* d_BlocksOffset = rm.allocate<int>(numBlocks, AMSResourceType::DEVICE);
// determine number of elements in the compacted list
int* h_BlocksCount =
rm.allocate<int>(numBlocks, AMSResourceType::HOST);
int* h_BlocksOffset =
rm.allocate<int>(numBlocks, AMSResourceType::HOST);

T** d_dense =
rm.allocate<T*>(dims, AMSResourceType::DEVICE);
T** d_sparse =
rm.allocate<T*>(dims, AMSResourceType::DEVICE);

rm.registerExternal(dense,
sizeof(T*) * dims,
AMSResourceType::HOST);
rm.registerExternal(sparse,
sizeof(T*) * dims,
AMSResourceType::HOST);
int* h_BlocksCount = rm.allocate<int>(numBlocks, AMSResourceType::HOST);
int* h_BlocksOffset = rm.allocate<int>(numBlocks, AMSResourceType::HOST);

T** d_dense = rm.allocate<T*>(dims, AMSResourceType::DEVICE);
T** d_sparse = rm.allocate<T*>(dims, AMSResourceType::DEVICE);

rm.registerExternal(dense, sizeof(T*) * dims, AMSResourceType::HOST);
rm.registerExternal(sparse, sizeof(T*) * dims, AMSResourceType::HOST);
rm.copy(dense, d_dense);
rm.copy(const_cast<T**>(sparse), d_sparse);
thrust::device_ptr<int> thrustPrt_bCount(d_BlocksCount);
Expand Down Expand Up @@ -468,3 +460,5 @@ void device_compute_predicate(float* data,
}

#endif

#endif
45 changes: 45 additions & 0 deletions src/AMSlib/wf/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,51 @@ inline void DtoHMemcpy(void *dest, void *src, size_t nBytes)
{
cudaMemcpy(dest, src, nBytes, cudaMemcpyDeviceToHost);
}

/**
 * Kernel: per-row mean-of-stdev DeltaUQ predicate.
 *
 * For each row i of the row-major [nrows x ncols] matrix outputs_stdev
 * (element (i, j) at j + i * ncols), computes the mean of the ncols values
 * in double precision and sets predicates[i] = (mean < threshold).
 *
 * Uses a 1-D grid-stride loop, so any 1-D launch configuration (including
 * a single block) covers all rows. outputs_stdev and predicates must be
 * device pointers.
 */
template <typename scalar_t>
__global__ void computeDeltaUQMeanPredicatesKernel(
    const scalar_t *__restrict__ outputs_stdev,
    bool *__restrict__ predicates,
    const size_t nrows,
    const size_t ncols,
    const double threshold)
{
  // Widen to size_t BEFORE multiplying: blockDim.x * blockIdx.x (and
  // blockDim.x * gridDim.x) are unsigned-int products and can overflow
  // for very large launches before being stored in a size_t.
  size_t idx = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  // Compute mean over columns, strided loop.
  for (size_t i = idx; i < nrows; i += stride) {
    double mean = 0.0;  // double accumulator for precision/range
    for (size_t j = 0; j < ncols; ++j)
      mean += outputs_stdev[j + i * ncols];
    mean /= ncols;

    predicates[i] = (mean < threshold);
  }
}

/**
 * Kernel: per-row max-style DeltaUQ predicate.
 *
 * For each row i of the row-major [nrows x ncols] matrix outputs_stdev
 * (element (i, j) at j + i * ncols), sets predicates[i] = true only when
 * every stdev value in the row is strictly below threshold; the column
 * scan exits early on the first value >= threshold.
 *
 * Uses a 1-D grid-stride loop, so any 1-D launch configuration (including
 * a single block) covers all rows. outputs_stdev and predicates must be
 * device pointers.
 */
template <typename scalar_t>
__global__ void computeDeltaUQMaxPredicatesKernel(
    const scalar_t *__restrict__ outputs_stdev,
    bool *__restrict__ predicates,
    const size_t nrows,
    const size_t ncols,
    const double threshold)
{
  // Widen to size_t BEFORE multiplying: blockDim.x * blockIdx.x (and
  // blockDim.x * gridDim.x) are unsigned-int products and can overflow
  // for very large launches before being stored in a size_t.
  size_t idx = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  // Compute max delta uq over columns, strided loop.
  for (size_t i = idx; i < nrows; i += stride) {
    predicates[i] = true;
    for (size_t j = 0; j < ncols; ++j)
      if (outputs_stdev[j + i * ncols] >= threshold) {
        predicates[i] = false;
        break;
      }
  }
}

#else
PERFFASPECT()
inline void DtoDMemcpy(void *dest, void *src, size_t nBytes)
Expand Down
Loading

0 comments on commit 8f45cf5

Please sign in to comment.