Skip to content

Commit

Permalink
Allow multiple instantions of workflow which use the same model/db. (#13
Browse files Browse the repository at this point in the history
)
  • Loading branch information
koparasy authored Oct 25, 2023
1 parent 3e25d60 commit 13e4e3a
Show file tree
Hide file tree
Showing 8 changed files with 515 additions and 200 deletions.
30 changes: 19 additions & 11 deletions examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ int main(int argc, char **argv)
args.AddOption(&db_config,
"-db",
"--dbconfig",
"Path to directory where applications will store their data (or Path to JSON configuration if RabbitMQ is chosen)",
"Path to directory where applications will store their data "
"(or Path to JSON configuration if RabbitMQ is chosen)",
reqDB);

args.AddOption(&db_type,
Expand All @@ -187,14 +188,19 @@ int main(int argc, char **argv)
"\t 'hdf5': use hdf5 as a back end\n"
"\t 'rmq': use RabbitMQ as a back end\n");

args.AddOption(&k_nearest, "-knn", "--k-nearest-neighbors", "Number of closest neightbors we should look at");
args.AddOption(&k_nearest,
"-knn",
"--k-nearest-neighbors",
"Number of closest neightbors we should look at");

args.AddOption(&uq_policy_opt,
"-uq",
"--uqtype",
"Types of UQ to select from: \n"
"\t 'mean' Uncertainty is computed in comparison against the mean distance of k-nearest neighbors\n"
"\t 'max': Uncertainty is computed in comparison with the k'st cluster \n"
"\t 'mean' Uncertainty is computed in comparison against the "
"mean distance of k-nearest neighbors\n"
"\t 'max': Uncertainty is computed in comparison with the "
"k'st cluster \n"
"\t 'deltauq': Uncertainty through DUQ (not supported)\n");

args.AddOption(
Expand Down Expand Up @@ -266,12 +272,14 @@ int main(int argc, char **argv)
dbType = AMSDBType::RMQ;
}

AMSUQPolicy uq_policy =
(std::strcmp(uq_policy_opt, "max") == 0) ? AMSUQPolicy::FAISSMax: AMSUQPolicy::FAISSMean;
AMSUQPolicy uq_policy = (std::strcmp(uq_policy_opt, "max") == 0)
? AMSUQPolicy::FAISSMax
: AMSUQPolicy::FAISSMean;

if ( uq_policy != AMSUQPolicy::FAISSMax )
if (uq_policy != AMSUQPolicy::FAISSMax)
uq_policy = ((std::strcmp(uq_policy_opt, "deltauq") == 0))
? AMSUQPolicy::DeltaUQ : AMSUQPolicy::FAISSMean;
? AMSUQPolicy::DeltaUQ
: AMSUQPolicy::FAISSMean;

// set up a randomization seed
srand(seed + rId);
Expand Down Expand Up @@ -423,7 +431,7 @@ int main(int argc, char **argv)
AMSResourceType ams_device = AMSResourceType::HOST;
if (use_device) ams_device = AMSResourceType::DEVICE;
AMSExecPolicy ams_loadBalance = AMSExecPolicy::UBALANCED;
if ( lbalance ) ams_loadBalance = AMSExecPolicy::BALANCED;
if (lbalance) ams_loadBalance = AMSExecPolicy::BALANCED;

AMSConfig amsConf = {ams_loadBalance,
AMSDType::Double,
Expand All @@ -437,10 +445,10 @@ int main(int argc, char **argv)
uq_policy,
k_nearest,
rId,
wS };
AMSExecutor wf = AMSCreateExecutor(amsConf);
wS};

for (int mat_idx = 0; mat_idx < num_mats; ++mat_idx) {
AMSExecutor wf = AMSCreateExecutor(amsConf);
workflow[mat_idx] = wf;
}
#endif
Expand Down
Loading

0 comments on commit 13e4e3a

Please sign in to comment.