Skip to content

Commit

Permalink
Implement Delta UQ
Browse files Browse the repository at this point in the history
  • Loading branch information
ggeorgakoudis committed Oct 31, 2023
1 parent bb4f2d2 commit 662626b
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 59 deletions.
47 changes: 34 additions & 13 deletions examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <mfem.hpp>
#include <random>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include <vector>
Expand Down Expand Up @@ -80,7 +82,9 @@ int main(int argc, char **argv)
const char *db_config = "";
const char *db_type = "";

const char *uq_policy_opt = "mean";
const char *precision_opt = "double";
AMSDType precision = AMSDType::Double;
const char *uq_policy_opt = "faiss-mean";
int k_nearest = 5;

int seed = 0;
Expand Down Expand Up @@ -112,6 +116,9 @@ int main(int argc, char **argv)
mfem::OptionsParser args(argc, argv);
args.AddOption(&device_name, "-d", "--device", "Device config string");

// set precision
args.AddOption(&precision_opt, "-pr", "--precision", "Set precision (single or double)");

// surrogate model
args.AddOption(&model_path, "-S", "--surrogate", "Path to surrogate model");
args.AddOption(&hdcache_path, "-H", "--hdcache", "Path to hdcache index");
Expand Down Expand Up @@ -197,11 +204,14 @@ int main(int argc, char **argv)
"-uq",
"--uqtype",
"Types of UQ to select from: \n"
"\t 'mean' Uncertainty is computed in comparison against the "
"\t 'faiss-mean' Uncertainty is computed in comparison "
"against the "
"mean distance of k-nearest neighbors\n"
"\t 'max': Uncertainty is computed in comparison with the "
"\t 'faiss-max': Uncertainty is computed in comparison with "
"the "
"k'st cluster \n"
"\t 'deltauq': Uncertainty through DUQ (not supported)\n");
"\t 'deltauq-mean': Uncertainty through DUQ using mean\n"
"\t 'deltauq-max': Uncertainty through DUQ using max\n");

args.AddOption(
&verbose, "-v", "--verbose", "-qu", "--quiet", "Print extra stuff");
Expand Down Expand Up @@ -272,14 +282,25 @@ int main(int argc, char **argv)
dbType = AMSDBType::RMQ;
}

AMSUQPolicy uq_policy = (std::strcmp(uq_policy_opt, "max") == 0)
? AMSUQPolicy::FAISSMax
: AMSUQPolicy::FAISSMean;

if (uq_policy != AMSUQPolicy::FAISSMax)
uq_policy = ((std::strcmp(uq_policy_opt, "deltauq") == 0))
? AMSUQPolicy::DeltaUQ
: AMSUQPolicy::FAISSMean;
AMSUQPolicy uq_policy;

if (strcmp(uq_policy_opt, "faiss-max") == 0)
uq_policy = AMSUQPolicy::FAISS_Max;
else if (strcmp(uq_policy_opt, "faiss-mean") == 0)
uq_policy = AMSUQPolicy::FAISS_Mean;
else if (strcmp(uq_policy_opt, "deltauq-max") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Max;
else if (strcmp(uq_policy_opt, "deltauq-mean") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Mean;
else
throw std::runtime_error("Invalid UQ policy");

if (strcmp(precision_opt, "single") == 0)
precision = AMSDType::Single;
else if (strcmp(precision_opt, "double") == 0)
precision = AMSDType::Double;
else
throw std::runtime_error("Invalid precision");

// set up a randomization seed
srand(seed + rId);
Expand Down Expand Up @@ -434,7 +455,7 @@ int main(int argc, char **argv)
if (lbalance) ams_loadBalance = AMSExecPolicy::BALANCED;

AMSConfig amsConf = {ams_loadBalance,
AMSDType::Double,
precision,
ams_device,
dbType,
callBack,
Expand Down
9 changes: 6 additions & 3 deletions src/include/AMS.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,12 @@ typedef enum { UBALANCED = 0, BALANCED } AMSExecPolicy;
typedef enum { None = 0, CSV, REDIS, HDF5, RMQ } AMSDBType;

typedef enum {
FAISSMean =0,
FAISSMax,
DeltaUQ // Not supported
AMSUQPolicy_BEGIN = 0,
FAISS_Mean,
FAISS_Max,
DeltaUQ_Mean,
DeltaUQ_Max,
AMSUQPolicy_END
} AMSUQPolicy;

typedef struct ams_conf {
Expand Down
21 changes: 11 additions & 10 deletions src/ml/hdcache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class HDCache
const bool m_use_random;
const bool m_use_device;
const int m_knbrs = 0;
const AMSUQPolicy m_policy = AMSUQPolicy::FAISSMean;
const AMSUQPolicy m_policy = AMSUQPolicy::FAISS_Mean;

AMSResourceType defaultRes;

Expand Down Expand Up @@ -215,6 +215,11 @@ class HDCache
return cache;
}

if (uqPolicy != AMSUQPolicy::FAISS_Mean &&
uqPolicy != AMSUQPolicy::FAISS_Max)
THROW(std::invalid_argument,
"Invalid UQ policy for hdcache" + std::to_string(uqPolicy));

DBG(UQModule, "Generating new cache under (%s)", cache_path.c_str())
std::shared_ptr<HDCache<TypeInValue>> new_cache =
std::shared_ptr<HDCache<TypeInValue>>(new HDCache<TypeInValue>(
Expand All @@ -230,7 +235,7 @@ class HDCache
{
static std::string random_path("random");
std::shared_ptr<HDCache<TypeInValue>> cache = find_cache(
random_path, use_device, AMSUQPolicy::FAISSMean, -1, threshold);
random_path, use_device, AMSUQPolicy::FAISS_Mean, -1, threshold);
if (cache) {
DBG(UQModule, "Returning existing cache under (%s)", random_path.c_str())
return cache;
Expand Down Expand Up @@ -553,16 +558,13 @@ class HDCache
// compute means
if (defaultRes == AMSResourceType::HOST) {
for (size_t i = 0; i < ndata; ++i) {
CFATAL(UQModule,
m_policy == AMSUQPolicy::DeltaUQ,
"DeltaUQ is not supported yet");
if (m_policy == AMSUQPolicy::FAISSMean) {
if (m_policy == AMSUQPolicy::FAISS_Mean) {
TypeValue mean_dist = std::accumulate(kdists + i * knbrs,
kdists + (i + 1) * knbrs,
0.) *
ook;
is_acceptable[i] = mean_dist < acceptable_error;
} else if (m_policy == AMSUQPolicy::FAISSMax) {
} else if (m_policy == AMSUQPolicy::FAISS_Max) {
// Take the furtherst cluster as the distance metric
TypeValue max_dist =
*std::max_element(&kdists[i * knbrs],
Expand All @@ -572,9 +574,8 @@ class HDCache
}
} else {
CFATAL(UQModule,
(m_policy == AMSUQPolicy::DeltaUQ) ||
(m_policy == AMSUQPolicy::FAISSMax),
"DeltaUQ is not supported yet");
m_policy == AMSUQPolicy::FAISS_Max,
"FAISS Max on device is not supported yet");

ams::Device::computePredicate(
kdists, is_acceptable, ndata, knbrs, acceptable_error);
Expand Down
83 changes: 71 additions & 12 deletions src/ml/surrogate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#ifndef __AMS_SURROGATE_HPP__
#define __AMS_SURROGATE_HPP__

#include <ATen/core/ivalue.h>
#include <memory>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -37,6 +38,7 @@ class SurrogateModel
private:
const std::string model_path;
const bool _is_cpu;
const bool _isDeltaUq;


#ifdef __ENABLE_TORCH__
Expand Down Expand Up @@ -104,6 +106,30 @@ class SurrogateModel
}
}

PERFFASPECT()
inline void tensorToHostArray(at::Tensor tensor,
long numRows,
long numCols,
TypeInValue** array)
{
// Transpose to get continuous memory and
// perform single memcpy.
tensor = tensor.transpose(1, 0);
if (_is_cpu) {
for (long j = 0; j < numCols; j++) {
auto tmp = tensor[j].contiguous();
TypeInValue* ptr = tmp.data_ptr<TypeInValue>();
HtoHMemcpy(array[j], ptr, sizeof(TypeInValue) * numRows);
}
} else {
for (long j = 0; j < numCols; j++) {
auto tmp = tensor[j].contiguous();
TypeInValue* ptr = tmp.data_ptr<TypeInValue>();
DtoHMemcpy(array[j], ptr, sizeof(TypeInValue) * numRows);
}
}
}

// -------------------------------------------------------------------------
// loading a surrogate model!
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -150,21 +176,36 @@ class SurrogateModel
size_t num_in,
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs)
TypeInValue** outputs,
TypeInValue** outputs_stdev)
{
//torch::NoGradGuard no_grad;
c10::InferenceMode guard(true);
auto input = arrayToTensor(num_elements, num_in, inputs);
input.set_requires_grad(false);
at::Tensor output = module.forward({input}).toTensor().detach();
if (_isDeltaUq) {
assert(outputs_stdev && "Expected non-null outputs_stdev");
// The deltauq surrogate returns a tuple of (outputs, outputs_stdev)
auto output_tuple = module.forward({input}).toTuple();
at::Tensor output_tensor = output_tuple->elements()[0].toTensor().detach();
at::Tensor output_stdev_tensor = output_tuple->elements()[1].toTensor().detach();
tensorToArray(output_tensor, num_elements, num_out, outputs);
tensorToHostArray(output_stdev_tensor,
num_elements,
num_out,
outputs_stdev);
}
else {
at::Tensor output = module.forward({input}).toTensor().detach();
tensorToArray(output, num_elements, num_out, outputs);
}

DBG(Surrogate,
"Evaluate surrogate model (%ld, %ld) -> (%ld, %ld)",
num_elements,
num_in,
num_elements,
num_out);
tensorToArray(output, num_elements, num_out, outputs);
}

#else
Expand All @@ -186,10 +227,11 @@ class SurrogateModel

#endif

SurrogateModel(const char* model_path, bool is_cpu = true)
: model_path(model_path), _is_cpu(is_cpu)
SurrogateModel(const char* model_path,
bool is_cpu = true,
bool is_DeltaUQ = false)
: model_path(model_path), _is_cpu(is_cpu), _isDeltaUq(is_DeltaUQ)
{

if (_is_cpu)
_load<TypeInValue>(model_path, "cpu");
else
Expand Down Expand Up @@ -222,7 +264,8 @@ class SurrogateModel

static std::shared_ptr<SurrogateModel<TypeInValue>> getInstance(
const char* model_path,
bool is_cpu = true)
bool is_cpu = true,
bool is_DeltaUQ = false)
{
auto model =
SurrogateModel<TypeInValue>::instances.find(std::string(model_path));
Expand All @@ -248,7 +291,7 @@ class SurrogateModel
DBG(Surrogate, "Generating new model under (%s)", model_path);
std::shared_ptr<SurrogateModel<TypeInValue>> torch_model =
std::shared_ptr<SurrogateModel<TypeInValue>>(
new SurrogateModel<TypeInValue>(model_path, is_cpu));
new SurrogateModel<TypeInValue>(model_path, is_cpu, is_DeltaUQ));
instances.insert(std::make_pair(std::string(model_path), torch_model));
return torch_model;
};
Expand All @@ -264,9 +307,24 @@ class SurrogateModel
size_t num_in,
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs)
TypeInValue** outputs,
TypeInValue **outputs_stdev = nullptr)
{
_evaluate(num_elements, num_in, num_out, inputs, outputs);
_evaluate(num_elements, num_in, num_out, inputs, outputs, outputs_stdev);
}

PERFFASPECT()
inline void evaluate(long num_elements,
std::vector<const TypeInValue*> inputs,
std::vector<TypeInValue*> outputs,
std::vector<TypeInValue*> outputs_stdev)
{
_evaluate(num_elements,
inputs.size(),
outputs.size(),
static_cast<const TypeInValue**>(inputs.data()),
static_cast<TypeInValue**>(outputs.data()),
static_cast<TypeInValue**>(outputs_stdev.data()));
}

PERFFASPECT()
Expand All @@ -277,8 +335,9 @@ class SurrogateModel
_evaluate(num_elements,
inputs.size(),
outputs.size(),
static_cast< const TypeInValue**>(inputs.data()),
static_cast<TypeInValue**>(outputs.data()));
static_cast<const TypeInValue**>(inputs.data()),
static_cast<TypeInValue**>(outputs.data()),
nullptr);
}

bool is_double() { return (tensorOptions.dtype() == torch::kFloat64); }
Expand Down
Loading

0 comments on commit 662626b

Please sign in to comment.