Skip to content

Commit

Permalink
Add tests for FAISS and release memory (#15)
Browse files Browse the repository at this point in the history
* Add faiss-index test and fix faiss-gpu syncrhonization issue.

* Delete HDCache and free example memory

* Delete index only if loaded
  • Loading branch information
koparasy authored Oct 28, 2023
1 parent 13e4e3a commit bb4f2d2
Show file tree
Hide file tree
Showing 8 changed files with 643 additions and 244 deletions.
9 changes: 9 additions & 0 deletions examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,15 @@ int main(int argc, char **argv)
CALIPER(CALI_MARK_END("Cycle");)
MPI_CALL(MPI_Barrier(MPI_COMM_WORLD));
}

delete[] workflow;

// TODO: Add smart-pointers
for (int mat_idx = 0; mat_idx < num_mats; ++mat_idx) {
delete eoses[mat_idx];
eoses[mat_idx] = nullptr;
}

CALIPER(CALI_MARK_END("TimeStepLoop"););
MPI_CALL(MPI_Finalize());
return 0;
Expand Down
33 changes: 25 additions & 8 deletions src/ml/hdcache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,15 @@ class HDCache
return new_cache;
}

~HDCache() { DBG(Surrogate, "Destroying UQ-cache") }
~HDCache()
{
DBG(UQModule, "Deleting UQ-Module");
if (m_index) {
DBG(UQModule, "Deleting HD-Cache");
m_index->reset();
delete m_index;
}
}

//! ------------------------------------------------------------------------
//! simple queries
Expand Down Expand Up @@ -400,6 +408,7 @@ class HDCache
} else {
_evaluate(ndata, data, is_acceptable);
}
DBG(UQModule, "Done with evalution of uq")
}

//! train on data that comes separate features (a vector of pointers)
Expand Down Expand Up @@ -431,6 +440,7 @@ class HDCache
_evaluate(ndata, lin_data, is_acceptable);
ams::ResourceManager::deallocate(lin_data, defaultRes);
}
DBG(UQModule, "Done with evalution of uq");
}

private:
Expand Down Expand Up @@ -529,28 +539,35 @@ class HDCache
for (int start = 0; start < ndata; start += MAGIC_NUMBER) {
unsigned int nElems =
((ndata - start) < MAGIC_NUMBER) ? ndata - start : MAGIC_NUMBER;
DBG(UQModule, "Running for %d elements %d %d", nElems, start, m_dim);
m_index->search(nElems,
&data[start],
&data[start * m_dim],
knbrs,
&kdists[start * knbrs],
&kidxs[start * knbrs]);
}
#ifdef __ENABLE_CUDA__
faiss::gpu::synchronizeAllDevices();
#endif

// compute means
if (defaultRes == AMSResourceType::HOST) {
TypeValue total_dist = 0;
for (size_t i = 0; i < ndata; ++i) {
CFATAL(UQModule,
m_policy == AMSUQPolicy::DeltaUQ,
"DeltaUQ is not supported yet");
if (m_policy == AMSUQPolicy::FAISSMean) {
total_dist =
std::accumulate(kdists + i * knbrs, kdists + (i + 1) * knbrs, 0.);
is_acceptable[i] = (ook * total_dist) < acceptable_error;
TypeValue mean_dist = std::accumulate(kdists + i * knbrs,
kdists + (i + 1) * knbrs,
0.) *
ook;
is_acceptable[i] = mean_dist < acceptable_error;
} else if (m_policy == AMSUQPolicy::FAISSMax) {
// Take the furtherst cluster as the distance metric
total_dist = kdists[i * knbrs + knbrs - 1];
is_acceptable[i] = (total_dist) < acceptable_error;
TypeValue max_dist =
*std::max_element(&kdists[i * knbrs],
&kdists[i * knbrs + knbrs - 1]);
is_acceptable[i] = (max_dist) < acceptable_error;
}
}
} else {
Expand Down
Loading

0 comments on commit bb4f2d2

Please sign in to comment.