Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE] MueLu: Refactor CoalesceDropFactory_kokkos #12861

Draft
wants to merge 8 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/muelu/src/Graph/Containers/MueLu_LWGraphBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ class LWGraphBase {
return graph;
}

//! Read-only access to the label stored in objectLabel_ (returned by const reference, no copy is made).
const std::string& getObjectLabel() const {
return objectLabel_;
}

private:
//! Underlying graph (with label)
mutable local_graph_type graph_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "MueLu_LWGraphBase.hpp"
#include "MueLu_LWGraph_kokkos_fwd.hpp"
#include "MueLu_LWGraph_fwd.hpp"

namespace MueLu {

Expand All @@ -28,6 +29,9 @@ namespace MueLu {
template <class LocalOrdinal, class GlobalOrdinal, class Node>
class LWGraph_kokkos : public MueLu::LWGraphBase<LocalOrdinal, GlobalOrdinal, Node, false> {
// Specialization of LWGraphBase with its last template argument fixed to false.
// NOTE(review): that flag presumably selects the device (non-host) variant; confirm against LWGraphBase.

// Inherit all constructors of the base class; this class adds none of its own.
using LWGraphBase<LocalOrdinal, GlobalOrdinal, Node, false>::LWGraphBase;

public:
//! Builds and returns a MueLu::LWGraph counterpart of this graph (see the definition in the _def header).
//! The result is a newly created object wrapped in an RCP.
RCP<MueLu::LWGraph<LocalOrdinal, GlobalOrdinal, Node> > copyToHost();
};

} // namespace MueLu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,33 @@
#ifndef MUELU_LWGRAPH_KOKKOS_DEF_HPP
#define MUELU_LWGRAPH_KOKKOS_DEF_HPP

#include "MueLu_LWGraph.hpp"
#include "MueLu_LWGraph_kokkos_decl.hpp"

namespace MueLu {

/*!
  \brief Create a MueLu::LWGraph counterpart of this graph.

  The row offsets, column entries and boundary-node flags are each transferred
  through a mirror view obtained from Kokkos::create_mirror_view followed by a
  Kokkos::deep_copy. The domain map, import map and object label are handed to
  the new graph as returned by the corresponding getters of this object.

  \return RCP to the newly constructed LWGraph.
*/
template <class LocalOrdinal, class GlobalOrdinal, class Node>
RCP<MueLu::LWGraph<LocalOrdinal, GlobalOrdinal, Node> > MueLu::LWGraph_kokkos<LocalOrdinal, GlobalOrdinal, Node>::copyToHost() {
  // Graph type used by the LWGraphBase variant whose last template argument is true.
  using host_graph_type = typename MueLu::LWGraphBase<LocalOrdinal, GlobalOrdinal, Node, true>::local_graph_type;

  auto srcGraph = this->getGraph();

  // Row offsets.
  auto rowOffsetsMirror = Kokkos::create_mirror_view(srcGraph.row_map);
  Kokkos::deep_copy(rowOffsetsMirror, srcGraph.row_map);

  // Column indices.
  auto columnsMirror = Kokkos::create_mirror_view(srcGraph.entries);
  Kokkos::deep_copy(columnsMirror, srcGraph.entries);

  // The graph constructor takes the entries first and the row map second.
  host_graph_type mirroredGraph(columnsMirror, rowOffsetsMirror);

  auto result = rcp(new MueLu::LWGraph<LocalOrdinal, GlobalOrdinal, Node>(mirroredGraph, this->GetDomainMap(), this->GetImportMap(), this->getObjectLabel()));

  // Boundary flags are not part of the constructor; transfer and attach them separately.
  auto srcBoundary    = this->GetBoundaryNodeMap();
  auto boundaryMirror = Kokkos::create_mirror_view(srcBoundary);
  Kokkos::deep_copy(boundaryMirror, srcBoundary);
  result->SetBoundaryNodeMap(boundaryMirror);

  return result;
}

} // namespace MueLu

#endif // MUELU_LWGRAPH_KOKKOS_DEF_HPP
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
#ifndef MUELU_BOUNDARYDETECTION_HPP
#define MUELU_BOUNDARYDETECTION_HPP

#include <cstddef>
#include <tuple>
#include "Kokkos_Core.hpp"
#include "Kokkos_ArithTraits.hpp"
#include "MueLu_LWGraph_kokkos.hpp"
#include "MueLu_Utilities.hpp"
#include "Teuchos_RCP.hpp"
#include "Xpetra_ConfigDefs.hpp"
#include "Xpetra_CrsGraph.hpp"
#include "Xpetra_MultiVector.hpp"

namespace MueLu::BoundaryDetection {

// These functors all assume that the boundaryNodes view has been initialized to false.

// Flags Dirichlet rows of a local matrix, one row per invocation.
//
// For row rlid the functor counts the off-diagonal entries whose magnitude is
// strictly greater than dirichletThreshold. As soon as that running count
// equals dirichletNonzeroThreshold the row is left untouched; if the count
// never equals it, boundaryNodes(rlid) is set to true. The functor never
// writes false.
template <class local_matrix_type>
class PointDirichletFunctor {
 private:
  using scalar_type        = typename local_matrix_type::value_type;
  using local_ordinal_type = typename local_matrix_type::ordinal_type;
  using memory_space       = typename local_matrix_type::memory_space;

  using ATS                 = Kokkos::ArithTraits<scalar_type>;
  using magnitudeType       = typename ATS::magnitudeType;
  using boundary_nodes_view = Kokkos::View<bool*, memory_space>;

  local_matrix_type A;                           // matrix whose rows are inspected
  boundary_nodes_view boundaryNodes;             // output flags, indexed by local row id
  magnitudeType dirichletThreshold;              // entries with magnitude <= this are ignored
  local_ordinal_type dirichletNonzeroThreshold;  // count at which a row stops being a Dirichlet candidate

 public:
  // Stores the matrix, the output view and the two thresholds; no work is done here.
  PointDirichletFunctor(local_matrix_type& A_, boundary_nodes_view boundaryNodes_, magnitudeType dirichletThreshold_, local_ordinal_type dirichletNonzeroThreshold_)
    : A(A_)
    , boundaryNodes(boundaryNodes_)
    , dirichletThreshold(dirichletThreshold_)
    , dirichletNonzeroThreshold(dirichletNonzeroThreshold_) {}

  // Examines local row rlid and sets boundaryNodes(rlid) when it qualifies.
  KOKKOS_FORCEINLINE_FUNCTION
  void operator()(const local_ordinal_type rlid) const {
    auto rowView                      = A.row(rlid);
    local_ordinal_type numSignificant = 0;

    for (local_ordinal_type j = 0; j < rowView.length; ++j) {
      // The diagonal entry never counts.
      if (rlid == static_cast<local_ordinal_type>(rowView.colidx(j)))
        continue;

      // Only entries strictly above the threshold count.
      if (!(ATS::magnitude(rowView.value(j)) > dirichletThreshold))
        continue;

      // Enough significant off-diagonals: not a Dirichlet row, leave the flag alone.
      if (++numSignificant == dirichletNonzeroThreshold)
        return;
    }

    boundaryNodes(rlid) = true;
  }
};

// Flags rows of a local matrix by comparing the row sum against the diagonal.
//
// For row rlid all stored entries are summed and the diagonal entry is
// recorded (zero if the row stores none). The row is flagged when
//   |rowsum| > |diag| * rowSumTol.
// The functor only ever sets boundaryNodes(rlid) to true; it never clears it.
template <class local_matrix_type>
class RowSumFunctor {
 private:
  using scalar_type        = typename local_matrix_type::value_type;
  using local_ordinal_type = typename local_matrix_type::ordinal_type;
  using memory_space       = typename local_matrix_type::memory_space;

  using ATS                 = Kokkos::ArithTraits<scalar_type>;
  using magnitudeType       = typename ATS::magnitudeType;
  using magATS              = Kokkos::ArithTraits<magnitudeType>;
  using boundary_nodes_view = Kokkos::View<bool*, memory_space>;

  local_matrix_type A;                // matrix whose rows are inspected
  boundary_nodes_view boundaryNodes;  // output flags, indexed by local row id
  magnitudeType rowSumTol;            // factor applied to |diag| in the comparison

 public:
  // Stores the matrix, the output view and the tolerance; no work is done here.
  RowSumFunctor(local_matrix_type& A_, boundary_nodes_view boundaryNodes_, magnitudeType rowSumTol_)
    : A(A_)
    , boundaryNodes(boundaryNodes_)
    , rowSumTol(rowSumTol_) {}

  // Applies the row-sum criterion to local row rlid.
  KOKKOS_FORCEINLINE_FUNCTION
  void operator()(const local_ordinal_type rlid) const {
    scalar_type total    = ATS::zero();
    scalar_type diagonal = ATS::zero();

    auto rowView = A.row(rlid);
    for (local_ordinal_type j = 0; j < rowView.length; ++j) {
      auto entry = rowView.value(j);
      total += entry;
      // Remember the diagonal value; every entry, diagonal included, is part of the sum.
      if (static_cast<local_ordinal_type>(rowView.colidx(j)) == rlid)
        diagonal = entry;
    }

    const bool sumDominates = ATS::magnitude(total) > ATS::magnitude(diagonal) * rowSumTol;
    if (sumDominates)
      boundaryNodes(rlid) = true;
  }
};

// Runs up to four boundary-detection functors, in tuple order, on the same index.
//
// FunctorsType is a tuple-like type (std::tuple_size / std::get must work on
// it) whose elements are callable as f(rlid) through a const reference.
// Each element is invoked in place, so the per-row call makes no copy of the
// sub-functors or of the views and matrix handles they hold.
template <class local_matrix_type, class FunctorsType>
class BoundaryFunctor {
 private:
  using scalar_type        = typename local_matrix_type::value_type;
  using local_ordinal_type = typename local_matrix_type::ordinal_type;
  using memory_space       = typename local_matrix_type::memory_space;

  // Only indices 0..3 are dispatched below. Reject larger tuples at compile
  // time instead of silently never running the trailing functors.
  static_assert(std::tuple_size<FunctorsType>::value <= 4,
                "BoundaryFunctor dispatches at most 4 functors; extend operator() before passing more.");

  local_matrix_type A;    // stored for the constructor interface; not read by operator()
  FunctorsType functors;  // the functors to run, in order

 public:
  // Stores copies of the matrix handle and of the functor tuple.
  BoundaryFunctor(local_matrix_type& A_, FunctorsType& functors_)
    : A(A_)
    , functors(functors_) {}

  // Invokes every stored functor on rlid, first to last.
  KOKKOS_INLINE_FUNCTION
  void operator()(const local_ordinal_type rlid) const {
    // std::get on the const tuple yields a const reference; call through it directly.
    if constexpr (std::tuple_size<FunctorsType>::value >= 1) {
      std::get<0>(functors)(rlid);
    }
    if constexpr (std::tuple_size<FunctorsType>::value >= 2) {
      std::get<1>(functors)(rlid);
    }
    if constexpr (std::tuple_size<FunctorsType>::value >= 3) {
      std::get<2>(functors)(rlid);
    }
    if constexpr (std::tuple_size<FunctorsType>::value >= 4) {
      std::get<3>(functors)(rlid);
    }
  }
};

// Flags Dirichlet blocks of a local matrix, one block of blockSize rows per invocation.
//
// A row counts as Dirichlet unless the running count of its off-diagonal
// entries with magnitude strictly greater than dirichletThreshold equals
// dirichletNonzeroThreshold at some point.
//   useGreedyDirichlet == true : the block is flagged if at least one of its rows is Dirichlet.
//   useGreedyDirichlet == false: the block is flagged only if all of its rows are Dirichlet.
//
// Like the other functors in this file, this one relies on boundaryNodes being
// initialized to false and only ever writes true. It never resets an entry, so
// a flag set earlier on the same entry (e.g. by another functor chained through
// BoundaryFunctor) is preserved.
template <class local_matrix_type, bool useGreedyDirichlet>
class VectorDirichletFunctor {
 private:
  using scalar_type        = typename local_matrix_type::value_type;
  using local_ordinal_type = typename local_matrix_type::ordinal_type;
  using memory_space       = typename local_matrix_type::memory_space;

  using ATS                 = Kokkos::ArithTraits<scalar_type>;
  using magnitudeType       = typename ATS::magnitudeType;
  using boundary_nodes_view = Kokkos::View<bool*, memory_space>;

  local_matrix_type A;                           // matrix whose rows are inspected
  local_ordinal_type blockSize;                  // number of consecutive rows forming one block
  boundary_nodes_view boundaryNodes;             // output flags, indexed by block id
  magnitudeType dirichletThreshold;              // entries with magnitude <= this are ignored
  local_ordinal_type dirichletNonzeroThreshold;  // count at which a row stops being a Dirichlet candidate

 public:
  // Stores the matrix, block size, output view and the two thresholds; no work is done here.
  VectorDirichletFunctor(local_matrix_type& A_, local_ordinal_type blockSize_, boundary_nodes_view boundaryNodes_, magnitudeType dirichletThreshold_, local_ordinal_type dirichletNonzeroThreshold_)
    : A(A_)
    , blockSize(blockSize_)
    , boundaryNodes(boundaryNodes_)
    , dirichletThreshold(dirichletThreshold_)
    , dirichletNonzeroThreshold(dirichletNonzeroThreshold_) {}

  // Examines rows [rblid * blockSize, (rblid + 1) * blockSize) and flags block rblid when it qualifies.
  KOKKOS_FORCEINLINE_FUNCTION
  void operator()(const local_ordinal_type rblid) const {
    for (local_ordinal_type rlid = rblid * blockSize; rlid < (rblid + 1) * blockSize; ++rlid) {
      auto row               = A.row(rlid);
      local_ordinal_type nnz = 0;
      bool rowIsDirichlet    = true;
      for (local_ordinal_type k = 0; k < row.length; ++k) {
        auto clid = row.colidx(k);
        auto val  = row.value(k);
        if ((rlid != static_cast<local_ordinal_type>(clid)) && (ATS::magnitude(val) > dirichletThreshold)) {
          ++nnz;
          if (nnz == dirichletNonzeroThreshold) {
            rowIsDirichlet = false;
            break;
          }
        }
      }
      if constexpr (useGreedyDirichlet) {
        // One Dirichlet row is enough to flag the whole block.
        if (rowIsDirichlet) {
          boundaryNodes(rblid) = true;
          return;
        }
      } else {
        // One non-Dirichlet row disqualifies the block; leave the entry as it is.
        if (!rowIsDirichlet) {
          return;
        }
      }
    }
    // Greedy mode gets here only when no row was Dirichlet, so nothing is written.
    // Non-greedy mode gets here only when every row was Dirichlet.
    if constexpr (!useGreedyDirichlet)
      boundaryNodes(rblid) = true;
  }
};

} // namespace MueLu::BoundaryDetection

#endif
Loading
Loading