44#ifndef NIMBLE_MPI_REDUCTION_H
45#define NIMBLE_MPI_REDUCTION_H
59#include <unordered_map>
60#include <unordered_set>
/// Shorthand for std::vector that forwards every template argument unchanged.
template <class... Args>
using vector = std::vector<Args...>;

/// Shorthand for std::unordered_map that forwards every template argument unchanged.
template <class... Args>
using hashmap = std::unordered_map<Args...>;
75 std::vector<ReductionClique_t> cliques;
76 std::vector<int> unfinished;
79 const vector<int>& clique_colors,
80 const vector<int>& clique_ids,
81 const vector<int>& clique_assignment,
82 const vector<int>& global_ids,
83 const mpicontext& context)
85 hashmap<int, vector<int>> clique_index_assignment;
86 clique_index_assignment.reserve(clique_ids.size());
87 int num_ranks = context.get_size();
89 hashmap<int, int> counts;
90 counts.reserve(clique_ids.size());
92 for (
int clique_id : clique_assignment)
93 if (clique_id > num_ranks) counts[clique_id] += 1;
95 for (
int clique_id : clique_ids) clique_index_assignment[clique_id].reserve(counts[clique_id]);
97 for (
int index = 0, max = clique_assignment.size(); index < max; ++index) {
98 int clique_id = clique_assignment[index];
99 if (clique_id > num_ranks) { clique_index_assignment[clique_id].emplace_back(index); }
102 std::vector<MPI_Comm> comms{};
103 comms.reserve(clique_ids.size());
106 for (
int color : clique_colors) {
107 MPI_Comm comm = context.split_by_color(color);
108 if (comm != MPI_COMM_NULL) { comms.push_back(comm); }
112 if (comms.size() != clique_ids.size())
NIMBLE_ABORT(
"**** Error, comms.size() != clique_ids.size().");
114 for (
size_t i = 0; i < clique_ids.size(); ++i) {
115 int clique_id = clique_ids[i];
116 MPI_Comm comm = comms[i];
117 auto& index_list = clique_index_assignment[clique_id];
118 std::sort(index_list.begin(), index_list.end(), [&](
int a,
int b) { return global_ids[a] < global_ids[b]; });
119 this->cliques.emplace_back(std::move(index_list), index_list.size() * 3, comm);
124 ReductionInfo(ReductionInfo&& ri) =
default;
126 operator=(ReductionInfo&& ri) =
default;
127 template <
int field_size,
class Lookup>
129 Reduce(Lookup&& data)
131 for (
auto& clique : cliques) clique.asyncreduce_initialize<field_size>(data);
133 unfinished.resize(cliques.size());
134 std::iota(unfinished.begin(), unfinished.end(), 0);
136 while (!unfinished.empty()) {
138 for (
size_t i = 0; i < unfinished.size(); i += increment) {
139 bool reduceFinished = cliques[unfinished[i]].asyncreduce_finalize<field_size>(data);
141 increment = !reduceFinished;
143 if (reduceFinished) {
144 unfinished[i] = unfinished.back();
145 unfinished.pop_back();
151 GetAllIndices(std::vector<int>& indices, std::vector<int>& min_rank_containing_index)
153 for (
auto& clique : cliques) {
154 int min_mpi_comm_world_rank = clique.GetMPICommWorldRanks()[0];
155 int num_indices = clique.GetNumIndices();
156 int const* clique_indices = clique.GetIndices();
157 for (
int i = 0; i < num_indices; i++) {
158 indices.push_back(clique_indices[i]);
159 min_rank_containing_index.push_back(min_mpi_comm_world_rank);
164 PerformReduction(
double* data,
int field_size)
166 switch (field_size) {
167 case 1: Reduce<1>(data);
break;
168 case 2: Reduce<2>(data);
break;
169 case 3: Reduce<3>(data);
break;
171 std::string fs = std::to_string(field_size);
173 throw std::invalid_argument(
"Bad field size of " + fs);
176 template <
class Lookup>
178 PerformReduction(Lookup& lookup,
int field_size)
180 switch (field_size) {
181 case 1: Reduce<1>(lookup);
break;
182 case 2: Reduce<2>(lookup);
break;
183 case 3: Reduce<3>(lookup);
break;
185 std::string fs = std::to_string(field_size);
187 throw std::invalid_argument(
"Bad field size of " + fs);
// Walks every id list in ids_by_rank and rewrites, through the reference
// returned by clique_lookup(id), the clique value stored for each id. Returns
// generate_clique_id.get_count().
//
// NOTE(review): this text is a source listing whose embedded line numbers skip
// (e.g. 209 -> 212, 216 -> 218, 220 -> 222). Brace-only lines and other short
// lines appear to be missing, so the control flow below is incomplete as shown;
// confirm against the real header before changing any logic.
192template <
class list_of_lists_t,
class F>
194fill_clique_lookup(list_of_lists_t& ids_by_rank, F&& clique_lookup) ->
195 typename std::decay<quanta::transformed_iterated_t<F, quanta::elem_t<list_of_lists_t>>>::type
197 typedef decltype(clique_lookup) lookup_t;
198 typedef quanta::elem_t<list_of_lists_t> inner_list_t;
199 typedef typename std::decay<quanta::transformed_iterated_t<lookup_t, inner_list_t>>::type clique_t;
// Translation table from a clique value found on an id to the value that
// replaces it; it is cleared at the start of every id_list.
201 std::unordered_map<clique_t, clique_t> remapped_cliques{};
// zero is a value-initialised clique_t; entries equal to it take the
// rank_plus_one assignment below.
203 const clique_t zero{};
// NOTE(review): rank_plus_one is value-initialised here and no visible line
// modifies it; an update was presumably on a line lost from this listing - confirm.
204 clique_t rank_plus_one{};
// NOTE(review): presumably the counter starts at len(ids_by_rank) + 1; the
// semantics of quanta::make_counter are not visible here - confirm.
206 auto generate_clique_id = quanta::make_counter<clique_t>(quanta::len(ids_by_rank) + 1);
208 for (
auto& id_list : ids_by_rank) {
209 remapped_cliques.clear();
212 for (
auto&
id : id_list) {
// current_clique is a reference, so the assignments below write through to
// whatever clique_lookup returned for this id.
213 auto& current_clique = clique_lookup(
id);
215 if (current_clique == zero) {
216 current_clique = rank_plus_one;
// NOTE(review): an `else` presumably separated the lookup below from the branch
// above (the listing jumps 216 -> 218) - confirm.
// Reuse the replacement already chosen for this clique value in this id_list.
218 auto remapped_clique_iter = remapped_cliques.find(current_clique);
219 if (remapped_clique_iter != remapped_cliques.end()) {
220 current_clique = remapped_clique_iter->second;
// NOTE(review): presumably the not-found branch (the listing jumps 220 -> 222):
// draw a new id from the counter, record the mapping under the old value, then
// overwrite the entry - confirm.
222 auto new_clique = generate_clique_id();
223 remapped_cliques[current_clique] = new_clique;
224 current_clique = new_clique;
229 return generate_clique_id.get_count();
// Declaration only: takes the raw global ids and the MPI context, both by
// const reference. The definition is not part of this listing.
// NOTE(review): the return type is not visible here (presumably on a line that
// the listing dropped) - confirm against the real header.
233GenerateReductionInfo(
const std::vector<int>& raw_global_ids,
const mpicontext& context);
Definition kokkos_contact_manager.h:49
#define NIMBLE_ABORT(...)
Definition nimble_macros.h:87