eric-haibin-lin commented on a change in pull request #13588: Accelerate DGL
csr neighbor sampling
URL: https://github.com/apache/incubator-mxnet/pull/13588#discussion_r242288409
##########
File path: src/operator/contrib/dgl_graph.cc
##########
@@ -718,20 +691,37 @@ static void SampleSubgraph(const NDArray &csr,
dgl_id_t* indptr_out = sub_csr.aux_data(0).dptr<dgl_id_t>();
indptr_out[0] = 0;
size_t collected_nedges = 0;
+
+ // Both the out array and neigh_pos are sorted. By scanning the two arrays,
we can see
+ // which vertices have neighbors and which don't.
+ std::sort(neigh_pos.begin(), neigh_pos.end(),
+ [](const std::pair<dgl_id_t, size_t> &a1, const
std::pair<dgl_id_t, size_t> &a2) {
+ return a1.first < a2.first;
+ });
+ size_t idx_with_neigh = 0;
for (size_t i = 0; i < num_vertices; i++) {
dgl_id_t dst_id = *(out + i);
- auto it = neigh_mp.find(dst_id);
- const auto &edges = it->second.edges;
- const auto &neighs = it->second.neighs;
- CHECK_EQ(edges.size(), neighs.size());
- if (!edges.empty()) {
- std::copy(edges.begin(), edges.end(), val_list_out + collected_nedges);
- std::copy(neighs.begin(), neighs.end(), col_list_out + collected_nedges);
- collected_nedges += edges.size();
+ // If a vertex is in sub_ver_mp but not in neigh_pos, this vertex must not
+ // have edges.
+ size_t edge_size = 0;
+ if (idx_with_neigh < neigh_pos.size() && dst_id ==
neigh_pos[idx_with_neigh].first) {
+ size_t pos = neigh_pos[idx_with_neigh].second;
+ CHECK_LT(pos, neighbor_list.size());
+ edge_size = neighbor_list[pos];
+ CHECK_LE(pos + edge_size * 2 + 1, neighbor_list.size());
+
+ std::copy_n(neighbor_list.begin() + pos + 1,
+ edge_size,
+ col_list_out + collected_nedges);
+ std::copy_n(neighbor_list.begin() + pos + edge_size + 1,
+ edge_size,
+ val_list_out + collected_nedges);
+ collected_nedges += edge_size;
+ idx_with_neigh++;
}
- indptr_out[i+1] = indptr_out[i] + edges.size();
+ indptr_out[i+1] = indptr_out[i] + edge_size;
}
- for (dgl_id_t i = num_vertices+1; i <= max_num_vertices; ++i) {
+ for (size_t i = num_vertices+1; i <= max_num_vertices; ++i) {
indptr_out[i] = indptr_out[i-1];
}
}
Review comment:
Also, the current example seed doesn't produce any `-1` in the output. Maybe
we should pick an input that is more representative.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services