diff --git a/src/operator/contrib/dgl_graph.cc
b/src/operator/contrib/dgl_graph.cc
index ed7caacfdba..ff866375056 100644
--- a/src/operator/contrib/dgl_graph.cc
+++ b/src/operator/contrib/dgl_graph.cc
@@ -413,21 +413,6 @@ static bool CSRNeighborNonUniformSampleType(const
nnvm::NodeAttrs& attrs,
return success;
}
-/*
- * Get src vertex and edge id for a destination vertex
- */
-static void GetSrcList(const dgl_id_t* val_list,
- const dgl_id_t* col_list,
- const dgl_id_t* indptr,
- const dgl_id_t dst_id,
- std::vector<dgl_id_t>* src_list,
- std::vector<dgl_id_t>* edge_list) {
- for (dgl_id_t i = *(indptr+dst_id); i < *(indptr+dst_id+1); ++i) {
- src_list->push_back(col_list[i]);
- edge_list->push_back(val_list[i]);
- }
-}
-
static void RandomSample(size_t set_size,
size_t num,
std::vector<size_t>* out,
@@ -464,34 +449,38 @@ static void NegateSet(const std::vector<size_t> &idxs,
/*
* Uniform sample
*/
-static void GetUniformSample(const std::vector<dgl_id_t>& ver_list,
- const std::vector<dgl_id_t>& edge_list,
+static void GetUniformSample(const dgl_id_t* val_list,
+ const dgl_id_t* col_list,
+ const dgl_id_t* indptr,
+ const dgl_id_t dst_id,
const size_t max_num_neighbor,
std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge,
unsigned int* seed) {
- CHECK_EQ(ver_list.size(), edge_list.size());
+ size_t ver_len = *(indptr+dst_id+1) - *(indptr+dst_id);
- // Copy ver_list to output
+ // Not more neighbors than requested: copy the whole neighbor list to output
- if (ver_list.size() <= max_num_neighbor) {
- for (size_t i = 0; i < ver_list.size(); ++i) {
- out_ver->push_back(ver_list[i]);
- out_edge->push_back(edge_list[i]);
+ if (ver_len <= max_num_neighbor) {
+ for (dgl_id_t i = *(indptr+dst_id); i < *(indptr+dst_id+1); ++i) {
+ out_ver->push_back(col_list[i]);
+ out_edge->push_back(val_list[i]);
}
return;
}
// If we just sample a small number of elements from a large neighbor list.
+ const dgl_id_t* col_ptr = col_list + *(indptr + dst_id);
+ const dgl_id_t* val_ptr = val_list + *(indptr + dst_id);
std::vector<size_t> sorted_idxs;
- if (ver_list.size() > max_num_neighbor * 2) {
+ if (ver_len > max_num_neighbor * 2) {
sorted_idxs.reserve(max_num_neighbor);
- RandomSample(ver_list.size(), max_num_neighbor, &sorted_idxs, seed);
+ RandomSample(ver_len, max_num_neighbor, &sorted_idxs, seed);
std::sort(sorted_idxs.begin(), sorted_idxs.end());
} else {
std::vector<size_t> negate;
- negate.reserve(ver_list.size() - max_num_neighbor);
- RandomSample(ver_list.size(), ver_list.size() - max_num_neighbor,
+ negate.reserve(ver_len - max_num_neighbor);
+ RandomSample(ver_len, ver_len - max_num_neighbor,
&negate, seed);
std::sort(negate.begin(), negate.end());
- NegateSet(negate, ver_list.size(), &sorted_idxs);
+ NegateSet(negate, ver_len, &sorted_idxs);
}
// verify the result.
CHECK_EQ(sorted_idxs.size(), max_num_neighbor);
@@ -499,8 +488,8 @@ static void GetUniformSample(const std::vector<dgl_id_t>&
ver_list,
CHECK_GT(sorted_idxs[i], sorted_idxs[i - 1]);
}
for (auto idx : sorted_idxs) {
- out_ver->push_back(ver_list[idx]);
- out_edge->push_back(edge_list[idx]);
+ out_ver->push_back(col_ptr[idx]);
+ out_edge->push_back(val_ptr[idx]);
}
}
@@ -508,26 +497,30 @@ static void GetUniformSample(const std::vector<dgl_id_t>&
ver_list,
* Non-uniform sample via ArrayHeap
*/
static void GetNonUniformSample(const float* probability,
- const std::vector<dgl_id_t>& ver_list,
- const std::vector<dgl_id_t>& edge_list,
+ const dgl_id_t* val_list,
+ const dgl_id_t* col_list,
+ const dgl_id_t* indptr,
+ const dgl_id_t dst_id,
const size_t max_num_neighbor,
std::vector<dgl_id_t>* out_ver,
std::vector<dgl_id_t>* out_edge,
unsigned int* seed) {
- CHECK_EQ(ver_list.size(), edge_list.size());
+ size_t ver_len = *(indptr+dst_id+1) - *(indptr+dst_id);
- // Copy ver_list to output
+ // Not more neighbors than requested: copy the whole neighbor list to output
- if (ver_list.size() <= max_num_neighbor) {
- for (size_t i = 0; i < ver_list.size(); ++i) {
- out_ver->push_back(ver_list[i]);
- out_edge->push_back(edge_list[i]);
+ if (ver_len <= max_num_neighbor) {
+ for (dgl_id_t i = *(indptr+dst_id); i < *(indptr+dst_id+1); ++i) {
+ out_ver->push_back(col_list[i]);
+ out_edge->push_back(val_list[i]);
}
return;
}
// Make sample
+ const dgl_id_t* col_ptr = col_list + *(indptr + dst_id);
+ const dgl_id_t* val_ptr = val_list + *(indptr + dst_id);
std::vector<size_t> sp_index(max_num_neighbor);
- std::vector<float> sp_prob(ver_list.size());
- for (size_t i = 0; i < ver_list.size(); ++i) {
- sp_prob[i] = probability[ver_list[i]];
+ std::vector<float> sp_prob(ver_len);
+ for (size_t i = 0; i < ver_len; ++i) {
+ sp_prob[i] = probability[col_ptr[i]];
}
ArrayHeap arrayHeap(sp_prob);
arrayHeap.SampleWithoutReplacement(max_num_neighbor, &sp_index, seed);
@@ -535,8 +528,8 @@ static void GetNonUniformSample(const float* probability,
out_edge->resize(max_num_neighbor);
for (size_t i = 0; i < max_num_neighbor; ++i) {
size_t idx = sp_index[i];
- out_ver->at(i) = ver_list[idx];
- out_edge->at(i) = edge_list[idx];
+ out_ver->at(i) = col_ptr[idx];
+ out_edge->at(i) = val_ptr[idx];
}
sort(out_ver->begin(), out_ver->end());
sort(out_edge->begin(), out_edge->end());
@@ -597,14 +590,15 @@ static void SampleSubgraph(const NDArray &csr,
node.level = 0;
node_queue.push(node);
}
- std::vector<dgl_id_t> tmp_src_list;
- std::vector<dgl_id_t> tmp_edge_list;
std::vector<dgl_id_t> tmp_sampled_src_list;
std::vector<dgl_id_t> tmp_sampled_edge_list;
- std::unordered_map<dgl_id_t, neigh_list> neigh_mp;
+ // Maps a vertex id to the index in neighbor_list where its record starts
+ std::unordered_map<dgl_id_t, size_t> neigh_pos;
+ std::vector<dgl_id_t> neighbor_list;
size_t num_edges = 0;
+
while (!node_queue.empty() &&
- sub_vertices_count <= max_num_vertices ) {
+ sub_vertices_count < max_num_vertices) {
ver_node& cur_node = node_queue.front();
dgl_id_t dst_id = cur_node.vertex_id;
if (cur_node.level < num_hops) {
@@ -613,35 +607,42 @@ static void SampleSubgraph(const NDArray &csr,
node_queue.pop();
continue;
}
- tmp_src_list.clear();
- tmp_edge_list.clear();
tmp_sampled_src_list.clear();
tmp_sampled_edge_list.clear();
- GetSrcList(val_list,
- col_list,
- indptr,
- dst_id,
- &tmp_src_list,
- &tmp_edge_list);
if (probability == nullptr) { // uniform-sample
- GetUniformSample(tmp_src_list,
- tmp_edge_list,
+ GetUniformSample(val_list,
+ col_list,
+ indptr,
+ dst_id,
num_neighbor,
&tmp_sampled_src_list,
&tmp_sampled_edge_list,
&time_seed);
} else { // non-uniform-sample
GetNonUniformSample(probability,
- tmp_src_list,
- tmp_edge_list,
+ val_list,
+ col_list,
+ indptr,
+ dst_id,
num_neighbor,
&tmp_sampled_src_list,
&tmp_sampled_edge_list,
&time_seed);
}
- neigh_mp.insert(std::pair<dgl_id_t, neigh_list>(dst_id,
- neigh_list(tmp_sampled_src_list,
- tmp_sampled_edge_list)));
+ CHECK_EQ(tmp_sampled_src_list.size(),
+ tmp_sampled_edge_list.size());
+ size_t pos = neighbor_list.size();
+ neigh_pos[dst_id] = pos;
+ // Record layout: [count, src vertices..., edge ids...]; push the count first
+ neighbor_list.push_back(tmp_sampled_edge_list.size());
+ // Then push the vertices
+ for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
+ neighbor_list.push_back(tmp_sampled_src_list[i]);
+ }
+ // Finally we push the edge list
+ for (size_t i = 0; i < tmp_sampled_edge_list.size(); ++i) {
+ neighbor_list.push_back(tmp_sampled_edge_list[i]);
+ }
num_edges += tmp_sampled_src_list.size();
sub_ver_mp[cur_node.vertex_id] = cur_node.level;
for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
@@ -659,11 +660,9 @@ static void SampleSubgraph(const NDArray &csr,
node_queue.pop();
continue;
}
- tmp_sampled_src_list.clear();
- tmp_sampled_edge_list.clear();
- neigh_mp.insert(std::pair<dgl_id_t, neigh_list>(dst_id,
- neigh_list(tmp_sampled_src_list, // empty vector
- tmp_sampled_edge_list))); // empty vector
+ size_t pos = neighbor_list.size();
+ neigh_pos[dst_id] = pos;
+ neighbor_list.push_back(0);
sub_ver_mp[cur_node.vertex_id] = cur_node.level;
}
sub_vertices_count++;
@@ -720,16 +719,18 @@ static void SampleSubgraph(const NDArray &csr,
size_t collected_nedges = 0;
for (size_t i = 0; i < num_vertices; i++) {
dgl_id_t dst_id = *(out + i);
- auto it = neigh_mp.find(dst_id);
- const auto &edges = it->second.edges;
- const auto &neighs = it->second.neighs;
- CHECK_EQ(edges.size(), neighs.size());
- if (!edges.empty()) {
- std::copy(edges.begin(), edges.end(), val_list_out + collected_nedges);
- std::copy(neighs.begin(), neighs.end(), col_list_out + collected_nedges);
- collected_nedges += edges.size();
+ size_t pos = neigh_pos[dst_id];
+ size_t edge_size = neighbor_list[pos];
+ if (edge_size != 0) {
+ std::copy_n(neighbor_list.begin() + pos + 1,
+ edge_size,
+ col_list_out + collected_nedges);
+ std::copy_n(neighbor_list.begin() + pos + edge_size + 1,
+ edge_size,
+ val_list_out + collected_nedges);
+ collected_nedges += edge_size;
}
- indptr_out[i+1] = indptr_out[i] + edges.size();
+ indptr_out[i+1] = indptr_out[i] + edge_size;
}
for (dgl_id_t i = num_vertices+1; i <= max_num_vertices; ++i) {
indptr_out[i] = indptr_out[i-1];
With regards,
Apache Git Services