Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute DeltaUQ on the device when possible #61

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 115 additions & 13 deletions src/AMSlib/ml/surrogate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include <string>
#include <unordered_map>

#include "AMS.h"
#include "wf/cuda/utilities.cuh"

#ifdef __ENABLE_TORCH__
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
Expand Down Expand Up @@ -176,6 +179,85 @@ class SurrogateModel
_load_torch(model_path, torch::Device(device_name), torch::kFloat32);
}

// -------------------------------------------------------------------------
// compute delta uq predicates
// -------------------------------------------------------------------------
void computeDeltaUQPredicates(AMSUQPolicy uq_policy,
const TypeInValue* __restrict__ outputs_stdev,
bool* __restrict__ predicates,
const size_t nrows,
const size_t ncols,
const double threshold)
{
auto computeDeltaUQMeanPredicatesHost = [&]() {
for (size_t i = 0; i < nrows; ++i) {
double mean = 0.0;
for (size_t j = 0; j < ncols; ++j)
mean += outputs_stdev[j + i * ncols];
mean /= ncols;

predicates[i] = (mean < threshold);
}
};

auto computeDeltaUQMaxPredicatesHost = [&]() {
for (size_t i = 0; i < nrows; ++i) {
predicates[i] = true;
for (size_t j = 0; j < ncols; ++j)
if (outputs_stdev[j + i * ncols] >= threshold) {
predicates[i] = false;
break;
}
}
};

if (uq_policy == AMSUQPolicy::DeltaUQ_Mean) {
if (model_resource == AMSResourceType::DEVICE)
#ifdef __ENABLE_CUDA__
{
DBG(Surrogate, "Compute mean delta uq predicates on device\n");
constexpr int block_size = 256;
int grid_size = divup(nrows, block_size);
computeDeltaUQMeanPredicatesKernel<<<grid_size, block_size>>>(
outputs_stdev, predicates, nrows, ncols, threshold);
// TODO: use combined routine when it lands.
cudaDeviceSynchronize();
CUDACHECKERROR();
}
#else
THROW(std::runtime_error,
"Expected CUDA is enabled when model data are on DEVICE");
#endif
else {
DBG(Surrogate, "Compute mean delta uq predicates on host\n");
computeDeltaUQMeanPredicatesHost();
}
} else if (uq_policy == AMSUQPolicy::DeltaUQ_Max) {
if (model_resource == AMSResourceType::DEVICE)
#ifdef __ENABLE_CUDA__
{
DBG(Surrogate, "Compute max delta uq predicates on device\n");
constexpr int block_size = 256;
int grid_size = divup(nrows, block_size);
computeDeltaUQMaxPredicatesKernel<<<grid_size, block_size>>>(
outputs_stdev, predicates, nrows, ncols, threshold);
// TODO: use combined routine when it lands.
cudaDeviceSynchronize();
CUDACHECKERROR();
}
#else
THROW(std::runtime_error,
"Expected CUDA is enabled when model data are on DEVICE");
#endif
else {
DBG(Surrogate, "Compute max delta uq predicates on host\n");
computeDeltaUQMaxPredicatesHost();
}
} else
THROW(std::runtime_error,
"Invalid uq_policy to compute delta uq predicates");
}

// -------------------------------------------------------------------------
// evaluate a torch model
// -------------------------------------------------------------------------
Expand All @@ -185,7 +267,9 @@ class SurrogateModel
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs,
TypeInValue** outputs_stdev)
AMSUQPolicy uq_policy,
bool* predicates,
double threshold)
{
//torch::NoGradGuard no_grad;
c10::InferenceMode guard(true);
Expand All @@ -195,7 +279,6 @@ class SurrogateModel

input.set_requires_grad(false);
if (_is_DeltaUQ) {
assert(outputs_stdev && "Expected non-null outputs_stdev");
// The deltauq surrogate returns a tuple of (outputs, outputs_stdev)
CALIPER(CALI_MARK_BEGIN("SURROGATE-EVAL");)
auto output_tuple = module.forward({input}).toTuple();
Expand All @@ -206,11 +289,14 @@ class SurrogateModel
at::Tensor output_stdev_tensor =
output_tuple->elements()[1].toTensor().detach();
CALIPER(CALI_MARK_BEGIN("TENSOR_TO_ARRAY");)

computeDeltaUQPredicates(uq_policy,
output_stdev_tensor.data_ptr<TypeInValue>(),
predicates,
num_elements,
num_out,
threshold);
tensorToArray(output_mean_tensor, num_elements, num_out, outputs);
tensorToHostArray(output_stdev_tensor,
num_elements,
num_out,
outputs_stdev);
CALIPER(CALI_MARK_END("TENSOR_TO_ARRAY");)
} else {
CALIPER(CALI_MARK_BEGIN("SURROGATE-EVAL");)
Expand Down Expand Up @@ -247,8 +333,7 @@ class SurrogateModel
long num_in,
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs,
TypeInValue** outputs_stdev)
TypeInValue** outputs)
{
}

Expand Down Expand Up @@ -340,23 +425,36 @@ class SurrogateModel
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs,
TypeInValue** outputs_stdev = nullptr)
AMSUQPolicy uq_policy = AMSUQPolicy::AMSUQPolicy_BEGIN,
bool* predicates = nullptr,
double threshold = 0.0)
{
_evaluate(num_elements, num_in, num_out, inputs, outputs, outputs_stdev);
_evaluate(num_elements,
num_in,
num_out,
inputs,
outputs,
uq_policy,
predicates,
threshold);
}

PERFFASPECT()
inline void evaluate(long num_elements,
std::vector<const TypeInValue*> inputs,
std::vector<TypeInValue*> outputs,
std::vector<TypeInValue*> outputs_stdev)
AMSUQPolicy uq_policy,
bool* predicates,
double threshold)
{
_evaluate(num_elements,
inputs.size(),
outputs.size(),
static_cast<const TypeInValue**>(inputs.data()),
static_cast<TypeInValue**>(outputs.data()),
static_cast<TypeInValue**>(outputs_stdev.data()));
uq_policy,
predicates,
threshold);
}

PERFFASPECT()
Expand All @@ -369,7 +467,9 @@ class SurrogateModel
outputs.size(),
static_cast<const TypeInValue**>(inputs.data()),
static_cast<TypeInValue**>(outputs.data()),
nullptr);
AMSUQPolicy::AMSUQPolicy_BEGIN,
nullptr,
0.0);
}

#ifdef __ENABLE_TORCH__
Expand Down Expand Up @@ -410,6 +510,8 @@ class SurrogateModel
else
_load<TypeInValue>(new_path, "cuda");
}

AMSResourceType getModelResource() const { return model_resource; }
};

template <typename T>
Expand Down
56 changes: 16 additions & 40 deletions src/AMSlib/ml/uq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class UQ

if (uqPolicy == AMSUQPolicy::RandomUQ)
randomUQ = std::make_unique<RandomUQ>(resourceLocation, threshold);

DBG(UQ, "UQ Model is of type %d", uqPolicy)
}

PERFFASPECT()
Expand All @@ -73,48 +75,22 @@ class UQ
{
if ((uqPolicy == AMSUQPolicy::DeltaUQ_Mean) ||
(uqPolicy == AMSUQPolicy::DeltaUQ_Max)) {
CALIPER(CALI_MARK_BEGIN("DELTAUQ");)
const size_t ndims = outputs.size();
std::vector<FPTypeValue *> outputs_stdev(ndims);
// TODO: Enable device-side allocation and predicate calculation.
auto &rm = ams::ResourceManager::getInstance();
for (int dim = 0; dim < ndims; ++dim)
outputs_stdev[dim] =
rm.allocate<FPTypeValue>(totalElements, AMSResourceType::HOST);

CALIPER(CALI_MARK_BEGIN("SURROGATE");)
DBG(Workflow,
"Model exists, I am calling DeltaUQ surrogate (for all data)");
surrogate->evaluate(totalElements, inputs, outputs, outputs_stdev);
CALIPER(CALI_MARK_END("SURROGATE");)
auto &rm = ams::ResourceManager::getInstance();

if (uqPolicy == AMSUQPolicy::DeltaUQ_Mean) {
for (size_t i = 0; i < totalElements; ++i) {
// Use double for increased precision, range in the calculation
double mean = 0.0;
for (size_t dim = 0; dim < ndims; ++dim)
mean += outputs_stdev[dim][i];
mean /= ndims;
p_ml_acceptable[i] = (mean < threshold);
}
} else if (uqPolicy == AMSUQPolicy::DeltaUQ_Max) {
for (size_t i = 0; i < totalElements; ++i) {
bool is_acceptable = true;
for (size_t dim = 0; dim < ndims; ++dim)
if (outputs_stdev[dim][i] >= threshold) {
is_acceptable = false;
break;
}

p_ml_acceptable[i] = is_acceptable;
}
} else {
THROW(std::runtime_error, "Invalid UQ policy");
}

for (int dim = 0; dim < ndims; ++dim)
rm.deallocate(outputs_stdev[dim], AMSResourceType::HOST);
CALIPER(CALI_MARK_END("DELTAUQ");)
CALIPER(CALI_MARK_BEGIN("DELTAUQ SURROGATE");)
DBG(UQ,
"Model exists, I am calling DeltaUQ surrogate [%ld %ld] -> (mu:[%ld "
"%ld], std:[%ld %ld])",
totalElements,
inputs.size(),
totalElements,
outputs.size(),
totalElements,
inputs.size());
surrogate->evaluate(
totalElements, inputs, outputs, uqPolicy, p_ml_acceptable, threshold);
CALIPER(CALI_MARK_END("DELTAUQ SURROGATE");)
} else if (uqPolicy == AMSUQPolicy::FAISS_Mean ||
uqPolicy == AMSUQPolicy::FAISS_Max) {
CALIPER(CALI_MARK_BEGIN("HDCACHE");)
Expand Down
45 changes: 45 additions & 0 deletions src/AMSlib/wf/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,51 @@ inline void DtoHMemcpy(void *dest, void *src, size_t nBytes)
{
cudaMemcpy(dest, src, nBytes, cudaMemcpyDeviceToHost);
}

template <typename scalar_t>
__global__ void computeDeltaUQMeanPredicatesKernel(
const scalar_t *__restrict__ outputs_stdev,
bool *__restrict__ predicates,
const size_t nrows,
const size_t ncols,
const double threshold)
{

size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
size_t stride = blockDim.x * gridDim.x;
// Compute mean over columns, strided loop.
for (size_t i = idx; i < nrows; i += stride) {
double mean = 0.0;
for (size_t j = 0; j < ncols; ++j)
mean += outputs_stdev[j + i * ncols];
mean /= ncols;

predicates[i] = (mean < threshold);
}
}

template <typename scalar_t>
__global__ void computeDeltaUQMaxPredicatesKernel(
const scalar_t *__restrict__ outputs_stdev,
bool *__restrict__ predicates,
const size_t nrows,
const size_t ncols,
const double threshold)
{

size_t idx = blockDim.x * blockIdx.x + threadIdx.x;
size_t stride = blockDim.x * gridDim.x;
// Compute max delta uq over columns, strided loop.
for (size_t i = idx; i < nrows; i += stride) {
predicates[i] = true;
for (size_t j = 0; j < ncols; ++j)
if (outputs_stdev[j + i * ncols] >= threshold) {
predicates[i] = false;
break;
}
}
}

#else
PERFFASPECT()
inline void DtoDMemcpy(void *dest, void *src, size_t nBytes)
Expand Down
Loading
Loading