BullDemonKing commented on a change in pull request #13588: Accelerate DGL csr
neighbor sampling
URL: https://github.com/apache/incubator-mxnet/pull/13588#discussion_r242412722
##########
File path: src/operator/contrib/dgl_graph.cc
##########
@@ -586,123 +563,119 @@ static void SampleSubgraph(const NDArray &csr,
dgl_id_t* out_layer = sub_layer.data().dptr<dgl_id_t>();
// BFS traverse the graph and sample vertices
- dgl_id_t sub_vertices_count = 0;
// <vertex_id, layer_id>
- std::unordered_map<dgl_id_t, int> sub_ver_mp;
- std::queue<ver_node> node_queue;
+ std::unordered_set<dgl_id_t> sub_ver_mp;
+ std::vector<std::pair<dgl_id_t, dgl_id_t> > sub_vers;
+ sub_vers.reserve(num_seeds * 10);
// add seed vertices
for (size_t i = 0; i < num_seeds; ++i) {
- ver_node node;
- node.vertex_id = seed[i];
- node.level = 0;
- node_queue.push(node);
+ auto ret = sub_ver_mp.insert(seed[i]);
+ // If the vertex is inserted successfully.
+ if (ret.second) {
+ sub_vers.emplace_back(seed[i], 0);
+ }
}
- std::vector<dgl_id_t> tmp_src_list;
- std::vector<dgl_id_t> tmp_edge_list;
std::vector<dgl_id_t> tmp_sampled_src_list;
std::vector<dgl_id_t> tmp_sampled_edge_list;
- std::unordered_map<dgl_id_t, neigh_list> neigh_mp;
+ // ver_id, position
+ std::vector<std::pair<dgl_id_t, size_t> > neigh_pos;
+ neigh_pos.reserve(num_seeds);
+ std::vector<dgl_id_t> neighbor_list;
size_t num_edges = 0;
- while (!node_queue.empty() &&
- sub_vertices_count <= max_num_vertices ) {
- ver_node& cur_node = node_queue.front();
- dgl_id_t dst_id = cur_node.vertex_id;
- if (cur_node.level < num_hops) {
- auto ret = sub_ver_mp.find(dst_id);
- if (ret != sub_ver_mp.end()) {
- node_queue.pop();
- continue;
- }
- tmp_src_list.clear();
- tmp_edge_list.clear();
- tmp_sampled_src_list.clear();
- tmp_sampled_edge_list.clear();
- GetSrcList(val_list,
- col_list,
- indptr,
- dst_id,
- &tmp_src_list,
- &tmp_edge_list);
- if (probability == nullptr) { // uniform-sample
- GetUniformSample(tmp_src_list,
- tmp_edge_list,
+
+ // sub_vers is used both as a node collection and a queue.
+ // In the while loop, we iterate over sub_vers and new nodes are added to
the vector.
+ // A vertex in the vector only needs to be accessed once. If there is a
vertex behind idx
+ // isn't in the last level, we will sample its neighbors. If not, the while
loop terminates.
+ size_t idx = 0;
+ while (idx < sub_vers.size() &&
+ sub_ver_mp.size() < max_num_vertices) {
+ dgl_id_t dst_id = sub_vers[idx].first;
+ int cur_node_level = sub_vers[idx].second;
+ idx++;
+ // If the node is in the last level, we don't need to sample neighbors
+ // from this node.
+ if (cur_node_level >= num_hops)
+ continue;
+
+ tmp_sampled_src_list.clear();
+ tmp_sampled_edge_list.clear();
+ dgl_id_t ver_len = *(indptr+dst_id+1) - *(indptr+dst_id);
+ if (probability == nullptr) { // uniform-sample
+ GetUniformSample(val_list + *(indptr + dst_id),
+ col_list + *(indptr + dst_id),
+ ver_len,
num_neighbor,
&tmp_sampled_src_list,
&tmp_sampled_edge_list,
&time_seed);
- } else { // non-uniform-sample
- GetNonUniformSample(probability,
- tmp_src_list,
- tmp_edge_list,
+ } else { // non-uniform-sample
+ GetNonUniformSample(probability,
+ val_list + *(indptr + dst_id),
+ col_list + *(indptr + dst_id),
+ ver_len,
num_neighbor,
&tmp_sampled_src_list,
&tmp_sampled_edge_list,
&time_seed);
- }
- neigh_mp.insert(std::pair<dgl_id_t, neigh_list>(dst_id,
- neigh_list(tmp_sampled_src_list,
- tmp_sampled_edge_list)));
- num_edges += tmp_sampled_src_list.size();
- sub_ver_mp[cur_node.vertex_id] = cur_node.level;
- for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
- auto ret = sub_ver_mp.find(tmp_sampled_src_list[i]);
- if (ret == sub_ver_mp.end()) {
- ver_node new_node;
- new_node.vertex_id = tmp_sampled_src_list[i];
- new_node.level = cur_node.level + 1;
- node_queue.push(new_node);
- }
- }
- } else { // vertex without any neighbor
- auto ret = sub_ver_mp.find(dst_id);
- if (ret != sub_ver_mp.end()) {
- node_queue.pop();
- continue;
- }
- tmp_sampled_src_list.clear();
- tmp_sampled_edge_list.clear();
- neigh_mp.insert(std::pair<dgl_id_t, neigh_list>(dst_id,
- neigh_list(tmp_sampled_src_list, // empty vector
- tmp_sampled_edge_list))); // empty vector
- sub_ver_mp[cur_node.vertex_id] = cur_node.level;
}
- sub_vertices_count++;
- node_queue.pop();
+ CHECK_EQ(tmp_sampled_src_list.size(), tmp_sampled_edge_list.size());
+ size_t pos = neighbor_list.size();
+ neigh_pos.emplace_back(dst_id, pos);
+ // First we push the size of neighbor vector
+ neighbor_list.push_back(tmp_sampled_edge_list.size());
+ // Then push the vertices
+ for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
+ neighbor_list.push_back(tmp_sampled_src_list[i]);
+ }
+ // Finally we push the edge list
+ for (size_t i = 0; i < tmp_sampled_edge_list.size(); ++i) {
+ neighbor_list.push_back(tmp_sampled_edge_list[i]);
+ }
+ num_edges += tmp_sampled_src_list.size();
+ for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
+ // If we have sampled the max number of vertices, we have to stop.
+ if (sub_ver_mp.size() >= max_num_vertices)
+ break;
+ // We need to add the neighbor in the hashtable here. This ensures that
+ // the vertex in the queue is unique. If we see a vertex before, we don't
+ // need to add it to the queue again.
+ auto ret = sub_ver_mp.insert(tmp_sampled_src_list[i]);
+ // If the sampled neighbor is inserted to the map successfully.
+ if (ret.second)
+ sub_vers.emplace_back(tmp_sampled_src_list[i], cur_node_level + 1);
+ }
+ }
+ // Let's check if there is a vertex that we haven't sampled its neighbors.
+ for (; idx < sub_vers.size(); idx++) {
+ if (sub_vers[idx].second < num_hops) {
+ LOG(WARNING)
Review comment:
it works. i can see warning messages after the number of sampled vertices
exceeds the maximal number.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services