Provides functionality to update a model at execution time (#49)
koparasy authored Feb 26, 2024
1 parent 8ddca6c commit 005b737
Showing 10 changed files with 242 additions and 23 deletions.
17 changes: 17 additions & 0 deletions src/AMSlib/ml/surrogate.hpp
@@ -393,6 +393,23 @@ class SurrogateModel
}

bool is_DeltaUQ() { return _is_DeltaUQ; }

void update(std::string new_path)
{
/* This function updates the underlying torch model,
 * replacing it with a new one loaded from new_path. The previous
 * model is destructed automatically.
 *
 * TODO: I decided not to update the model path in the ``instances''
 * map, as we currently expect this change to be transparent to the
 * application user. In any case we should keep track of which model
 * was used at which invocation; this is currently not done.
 */
if (model_resource != AMSResourceType::DEVICE)
_load<TypeInValue>(new_path, "cpu");
else
_load<TypeInValue>(new_path, "cuda");
}
};

template <typename T>
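The new update() makes hot-swapping the surrogate a single call from client code. A minimal usage sketch, assuming a model handle obtained through the existing getInstance() factory; the checkpoint path here is hypothetical:

#include <memory>
#include <string>

#include <ml/surrogate.hpp>  // SurrogateModel, included the same way in the new test

void refresh(std::shared_ptr<ams::SurrogateModel<double>> model)
{
  // Hypothetical location of a checkpoint retrained offline.
  const std::string retrained = "/tmp/ams/retrained_model.pt";
  // Swaps the underlying torch module in place; the previous module is
  // released automatically and the original device (cpu/cuda) is kept.
  model->update(retrained);
}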
25 changes: 20 additions & 5 deletions src/AMSlib/ml/uq.hpp
Expand Up @@ -77,11 +77,10 @@ class UQ
const size_t ndims = outputs.size();
std::vector<FPTypeValue *> outputs_stdev(ndims);
// TODO: Enable device-side allocation and predicate calculation.
-  auto& rm = ams::ResourceManager::getInstance();
+  auto &rm = ams::ResourceManager::getInstance();
for (int dim = 0; dim < ndims; ++dim)
outputs_stdev[dim] =
-      rm.allocate<FPTypeValue>(totalElements,
-                               AMSResourceType::HOST);
+      rm.allocate<FPTypeValue>(totalElements, AMSResourceType::HOST);

CALIPER(CALI_MARK_BEGIN("SURROGATE");)
DBG(Workflow,
@@ -114,8 +113,7 @@ class UQ
}

for (int dim = 0; dim < ndims; ++dim)
-    rm.deallocate(outputs_stdev[dim],
-                  AMSResourceType::HOST);
+    rm.deallocate(outputs_stdev[dim], AMSResourceType::HOST);
CALIPER(CALI_MARK_END("DELTAUQ");)
} else if (uqPolicy == AMSUQPolicy::FAISS_Mean ||
uqPolicy == AMSUQPolicy::FAISS_Max) {
@@ -142,6 +140,23 @@
}
}

void updateModel(std::string model_path, std::string uq_path = "")
{
if (uqPolicy != AMSUQPolicy::RandomUQ &&
uqPolicy != AMSUQPolicy::DeltaUQ_Max &&
uqPolicy != AMSUQPolicy::DeltaUQ_Mean) {
THROW(std::runtime_error, "UQ model does not support update.");
}

if (uqPolicy == AMSUQPolicy::RandomUQ && uq_path != "") {
WARNING(Workflow,
"RandomUQ cannot update hdcache path, ignoring argument")
}

surrogate->update(model_path);
return;
}

bool hasSurrogate() { return (surrogate ? true : false); }

private:
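Since updateModel() throws for the FAISS policies, callers that cannot guarantee the active UQ policy should guard the call. A hedged sketch; the uq reference and the path are assumptions, and the THROW macro is taken to raise the std::runtime_error it names:

#include <iostream>
#include <stdexcept>

template <typename UQType>
void try_update(UQType &uq)
{
  try {
    // Hypothetical path to a retrained checkpoint.
    uq.updateModel("/tmp/ams/retrained_model.pt");
  } catch (const std::runtime_error &e) {
    // FAISS_Mean / FAISS_Max do not support runtime updates in this commit.
    std::cerr << e.what() << "\n";
  }
}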
34 changes: 19 additions & 15 deletions src/AMSlib/wf/basedb.hpp
@@ -112,6 +112,8 @@ class BaseDB
std::vector<TypeValue*>& outputs) = 0;

uint64_t getId() const { return id; }

virtual bool updateModel() { return false; }
};

/**
@@ -835,7 +837,8 @@ struct AMSMsgHeader
uint8_t new_dtype = data_blob[current_offset];
current_offset += sizeof(uint8_t);
// MPI rank (should be 2 bytes)
-  uint16_t new_mpirank = (reinterpret_cast<uint16_t*>(data_blob + current_offset))[0];
+  uint16_t new_mpirank =
+      (reinterpret_cast<uint16_t*>(data_blob + current_offset))[0];
current_offset += sizeof(uint16_t);
// Num elem (should be 4 bytes)
uint32_t new_num_elem;
@@ -1844,18 +1847,19 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
std::find_if(buf.begin(), buf.end(), [&msg_id](const AMSMessage& obj) {
return obj.id() == msg_id;
});
-  CFATAL(RMQPublisherHandler, it == buf.end(),
-         "Failed to deallocate msg #%d: not found",
-         msg_id)
+  CFATAL(RMQPublisherHandler,
+         it == buf.end(),
+         "Failed to deallocate msg #%d: not found",
+         msg_id)
auto& msg = *it;
auto& rm = ams::ResourceManager::getInstance();
try {
rm.deallocate(msg.data(), AMSResourceType::HOST);
} catch (const umpire::util::Exception& e) {
FATAL(RMQPublisherHandler,
"Failed to deallocate #%d (%p)",
msg.id(),
msg.data());
"Failed to deallocate #%d (%p)",
msg.id(),
msg.data());
}
DBG(RMQPublisherHandler, "Deallocated msg #%d (%p)", msg.id(), msg.data())
buf.erase(it);
@@ -1875,9 +1879,9 @@ class RMQPublisherHandler : public AMQP::LibEventHandler
rm.deallocate(dp.data(), AMSResourceType::HOST);
} catch (const umpire::util::Exception& e) {
FATAL(RMQPublisherHandler,
"Failed to deallocate msg #%d (%p)",
dp.id(),
dp.data());
"Failed to deallocate msg #%d (%p)",
dp.id(),
dp.data());
}
}
buffer.clear();
@@ -2308,11 +2312,11 @@ class RabbitMQDB final : public BaseDB<TypeValue>
}));

DBG(RMQPublisher,
"[rank=%d] we have %d buffered messages that will get re-send "
"(starting from msg #%d).",
_rank,
messages.size(),
msg_min.id())
"[rank=%d] we have %d buffered messages that will get re-send "
"(starting from msg #%d).",
_rank,
messages.size(),
msg_min.id())

// Stop the faulty publisher
_publisher->stop();
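The base-class updateModel() hook added above defaults to false, so existing databases are unaffected; a backend that learns about retrained checkpoints (for example over RabbitMQ, as RabbitMQDB could) can override it to signal the workflow. A sketch of such an override with hypothetical member names; the pure virtual store() overloads are omitted:

#include <atomic>

template <typename TypeValue>
class NotifyingDB : public BaseDB<TypeValue> {
  // Hypothetical flag, flipped by a consumer thread when the broker
  // announces a new checkpoint.
  std::atomic<bool> model_pending{false};

public:
  using BaseDB<TypeValue>::BaseDB;

  // Polled by AMSWorkflow once per evaluate(); report and clear the flag
  // atomically so a pending update is consumed exactly once.
  bool updateModel() override { return model_pending.exchange(false); }
};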
5 changes: 5 additions & 0 deletions src/AMSlib/wf/workflow.hpp
@@ -283,6 +283,11 @@ class AMSWorkflow
CALIPER(CALI_MARK_END("AMSEvaluate");)
return;
}

if (DB && DB->updateModel()) {
UQModel->updateModel("");
}

// The predicate with which we will split the data on a later step
bool *p_ml_acceptable = rm.allocate<bool>(totalElements, appDataLoc);

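As committed, the workflow polls the database once per evaluate() and hands UQ::updateModel an empty path, so resolving the new checkpoint location happens elsewhere. If the database tracked that location itself, the hook could be threaded through as in this hypothetical variant (latestModelPath() is not part of this commit):

if (DB && DB->updateModel()) {
  // latestModelPath() is a hypothetical accessor, shown for illustration.
  UQModel->updateModel(DB->latestModelPath());
}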
3 changes: 3 additions & 0 deletions tests/AMSlib/CMakeLists.txt
@@ -38,6 +38,9 @@ if (WITH_TORCH)
add_test(NAME AMSExampleSingleDeltaUQ::HOST COMMAND ams_example --precision single --uqtype deltauq-mean -db ./db -S ${CMAKE_CURRENT_SOURCE_DIR}/tuple-single.torchscript -e 100)
add_test(NAME AMSExampleSingleRandomUQ::HOST COMMAND ams_example --precision single --uqtype random -S ${CMAKE_CURRENT_SOURCE_DIR}/debug_model.pt -e 100)
add_test(NAME AMSExampleDoubleRandomUQ::HOST COMMAND ams_example --precision double --uqtype random -S ${CMAKE_CURRENT_SOURCE_DIR}/debug_model.pt -e 100)

BUILD_TEST(ams_update_model ams_update_model.cpp)
ADDTEST(ams_update_model AMSUpdateModelDouble "double" ${CMAKE_CURRENT_SOURCE_DIR}/ConstantZeroModel_cpu.pt ${CMAKE_CURRENT_SOURCE_DIR}/ConstantOneModel_cpu.pt)
endif()

if(WITH_FAISS)
Binary file added tests/AMSlib/ConstantOneModel_cpu.pt
Binary file not shown.
Binary file added tests/AMSlib/ConstantZeroModel_cpu.pt
Binary file not shown.
112 changes: 112 additions & 0 deletions tests/AMSlib/ams_update_model.cpp
@@ -0,0 +1,112 @@
#include <AMS.h>
#include <ATen/core/interned_strings.h>
#include <c10/core/TensorOptions.h>
#include <torch/types.h>

#include <cstring>
#include <iostream>
#include <ml/surrogate.hpp>
#include <umpire/ResourceManager.hpp>
#include <umpire/Umpire.hpp>
#include <vector>
#include <wf/resource_manager.hpp>

#define SIZE (32L)

template <typename T>
bool inference(SurrogateModel<T> &model,
AMSResourceType resource,
std::string update_path)
{
using namespace ams;

std::vector<const T *> inputs;
std::vector<T *> outputs;
auto &ams_rm = ams::ResourceManager::getInstance();

for (int i = 0; i < 2; i++)
inputs.push_back(ams_rm.allocate<T>(SIZE, resource));

for (int i = 0; i < 4 * 2; i++)
outputs.push_back(ams_rm.allocate<T>(SIZE, resource));

for (int repeat = 0; repeat < 2; repeat++) {
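// Pass 0 evaluates the initially loaded ("one") model; update() then
// swaps in the ("zero") model, so pass 1 must produce different outputs.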
model.evaluate(
SIZE, inputs.size(), 4, inputs.data(), &(outputs.data()[repeat * 4]));
if (repeat == 0) model.update(update_path);
}

// Verify
bool errors = false;
for (int i = 0; i < 4; i++) {
T *first_model_out = outputs[i];
T *second_model_out = outputs[i + 4];
if (resource == AMSResourceType::DEVICE) {
first_model_out = ams_rm.allocate<T>(SIZE, AMSResourceType::HOST);
second_model_out = ams_rm.allocate<T>(SIZE, AMSResourceType::HOST);
ams_rm.copy(outputs[i], first_model_out, SIZE * sizeof(T));
ams_rm.copy(outputs[i + 4], second_model_out, SIZE * sizeof(T));
}

for (int j = 0; j < SIZE; j++) {
if (first_model_out[j] != 1.0) {
errors = true;
std::cout << "One Model " << first_model_out << " " << j << " "
<< first_model_out[j] << "\n";
}
if (second_model_out[j] != 0.0) {
std::cout << "Zero Model " << second_model_out << " " << j << " "
<< second_model_out[j] << "\n";
errors = true;
}
}

if (resource == AMSResourceType::DEVICE) {
ams_rm.deallocate(first_model_out, resource);
ams_rm.deallocate(second_model_out, resource);
}
}

for (int i = 0; i < 2; i++)
ams_rm.deallocate(const_cast<T *>(inputs[i]), resource);

for (int i = 0; i < 4 * 2; i++)
ams_rm.deallocate(outputs[i], resource);

return errors;
}


int main(int argc, char *argv[])
{
using namespace ams;
auto &ams_rm = ams::ResourceManager::getInstance();
int use_device = std::atoi(argv[1]);
char *data_type = argv[2];
char *zero_model = argv[3];
char *one_model = argv[4];
char *swap;

AMSResourceType resource = AMSResourceType::HOST;
if (use_device == 1) {
resource = AMSResourceType::DEVICE;
}


ams_rm.init();
int ret = 0;
if (std::strcmp("double", data_type) == 0) {
std::shared_ptr<SurrogateModel<double>> model =
SurrogateModel<double>::getInstance(one_model, resource);
assert(model->is_double());
ret = inference<double>(*model, resource, zero_model);
} else if (std::strcmp("single", data_type) == 0) {
std::shared_ptr<SurrogateModel<float>> model =
SurrogateModel<float>::getInstance(one_model, resource);
assert(!model->is_double());
ret = inference<float>(*model, resource, zero_model);
}
std::cout << "Zero Model is " << zero_model << "\n";
std::cout << "One Model is " << one_model << "\n";
return ret;
}
64 changes: 64 additions & 0 deletions tests/AMSlib/generate_constant_model.py
@@ -0,0 +1,64 @@
import torch
import os
import sys
import numpy as np
from torch.autograd import Variable
from torch import jit

class ConstantModel(torch.nn.Module):
def __init__(self, inputSize, outputSize, constant):
super(ConstantModel, self).__init__()
self.linear = torch.nn.Linear(inputSize, outputSize)
self.linear.weight.data.fill_(0.0)
self.linear.bias.data.fill_(constant)
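# With the weights zeroed, y = W x + b reduces to y = b, so the traced
# network returns `constant` for every input.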

def forward(self, x):
y = self.linear(x)
return y

def main(args):
inputDim = int(args[1])
outputDim = int(args[2])
device = args[3]
enable_cuda = True
if device == "cuda":
enable_cuda = True
suffix = '_gpu'
elif device == "cpu":
enable_cuda = False
suffix = '_cpu'

model = ConstantModel(inputDim, outputDim, 1.0).double()
if torch.cuda.is_available() and enable_cuda:
model = model.cuda()

model.eval()
with torch.jit.optimized_execution(True):
traced = torch.jit.trace(model, (torch.randn(inputDim, dtype=torch.double), ))
traced.save(f"ConstantOneModel_{suffix}.pt")

model = ConstantModel(inputDim, outputDim, 0.0).double()
if torch.cuda.is_available() and enable_cuda:
model = model.cuda()

model.eval()
with torch.jit.optimized_execution(True):
traced = torch.jit.trace(model, (torch.randn(inputDim, dtype=torch.double), ))
traced.save(f"ConstantZeroModel_{suffix}.pt")

inputs = Variable(torch.from_numpy(np.zeros((1, inputDim))))
zero_model = jit.load(f"ConstantZeroModel_{suffix}.pt")
print("ZeroModel", zero_model(inputs))

one_model = jit.load(f"ConstantOneModel_{suffix}.pt")
print("OneModel", one_model(inputs))




if __name__ == '__main__':
main(sys.argv)




5 changes: 2 additions & 3 deletions tests/AMSlib/torch_model.cpp
@@ -24,7 +24,7 @@ void inference(SurrogateModel<T> &model, AMSResourceType resource)

std::vector<const T *> inputs;
std::vector<T *> outputs;
-  auto& ams_rm = ams::ResourceManager::getInstance();
+  auto &ams_rm = ams::ResourceManager::getInstance();

for (int i = 0; i < 2; i++)
inputs.push_back(ams_rm.allocate<T>(SIZE, resource));
@@ -46,8 +46,7 @@ int main(int argc, char *argv[])
int main(int argc, char *argv[])
{
using namespace ams;
-  auto &rm = umpire::ResourceManager::getInstance();
-  auto& ams_rm = ams::ResourceManager::getInstance();
+  auto &ams_rm = ams::ResourceManager::getInstance();
int use_device = std::atoi(argv[1]);
char *model_path = argv[2];
char *data_type = argv[3];
