diff --git a/src/operator/contrib/dgl_graph.cc b/src/operator/contrib/dgl_graph.cc
index ed7caacfdba..ff866375056 100644
--- a/src/operator/contrib/dgl_graph.cc
+++ b/src/operator/contrib/dgl_graph.cc
@@ -413,21 +413,6 @@ static bool CSRNeighborNonUniformSampleType(const nnvm::NodeAttrs& attrs,
   return success;
 }
 
-/*
- * Get src vertex and edge id for a destination vertex
- */
-static void GetSrcList(const dgl_id_t* val_list,
-                       const dgl_id_t* col_list,
-                       const dgl_id_t* indptr,
-                       const dgl_id_t dst_id,
-                       std::vector<dgl_id_t>* src_list,
-                       std::vector<dgl_id_t>* edge_list) {
-  for (dgl_id_t i = *(indptr+dst_id); i < *(indptr+dst_id+1); ++i) {
-    src_list->push_back(col_list[i]);
-    edge_list->push_back(val_list[i]);
-  }
-}
-
 static void RandomSample(size_t set_size,
                          size_t num,
                          std::vector<size_t>* out,
@@ -464,34 +449,38 @@ static void NegateSet(const std::vector<size_t> &idxs,
 /*
  * Uniform sample
  */
-static void GetUniformSample(const std::vector<dgl_id_t>& ver_list,
-                             const std::vector<dgl_id_t>& edge_list,
+static void GetUniformSample(const dgl_id_t* val_list,
+                             const dgl_id_t* col_list,
+                             const dgl_id_t* indptr,
+                             const dgl_id_t dst_id,
                              const size_t max_num_neighbor,
                              std::vector<dgl_id_t>* out_ver,
                              std::vector<dgl_id_t>* out_edge,
                              unsigned int* seed) {
-  CHECK_EQ(ver_list.size(), edge_list.size());
+  size_t ver_len = *(indptr+dst_id+1) - *(indptr+dst_id);
   // Copy ver_list to output
-  if (ver_list.size() <= max_num_neighbor) {
-    for (size_t i = 0; i < ver_list.size(); ++i) {
-      out_ver->push_back(ver_list[i]);
-      out_edge->push_back(edge_list[i]);
+  if (ver_len <= max_num_neighbor) {
+    for (dgl_id_t i = *(indptr+dst_id); i < *(indptr+dst_id+1); ++i) {
+      out_ver->push_back(col_list[i]);
+      out_edge->push_back(val_list[i]);
     }
     return;
   }
   // If we just sample a small number of elements from a large neighbor list.
+  const dgl_id_t* col_ptr = col_list + *(indptr + dst_id);
+  const dgl_id_t* val_ptr = val_list + *(indptr + dst_id);
   std::vector<size_t> sorted_idxs;
-  if (ver_list.size() > max_num_neighbor * 2) {
+  if (ver_len > max_num_neighbor * 2) {
     sorted_idxs.reserve(max_num_neighbor);
-    RandomSample(ver_list.size(), max_num_neighbor, &sorted_idxs, seed);
+    RandomSample(ver_len, max_num_neighbor, &sorted_idxs, seed);
     std::sort(sorted_idxs.begin(), sorted_idxs.end());
   } else {
     std::vector<size_t> negate;
-    negate.reserve(ver_list.size() - max_num_neighbor);
-    RandomSample(ver_list.size(), ver_list.size() - max_num_neighbor,
+    negate.reserve(ver_len - max_num_neighbor);
+    RandomSample(ver_len, ver_len - max_num_neighbor,
                  &negate, seed);
     std::sort(negate.begin(), negate.end());
-    NegateSet(negate, ver_list.size(), &sorted_idxs);
+    NegateSet(negate, ver_len, &sorted_idxs);
   }
   // verify the result.
   CHECK_EQ(sorted_idxs.size(), max_num_neighbor);
@@ -499,8 +488,8 @@ static void GetUniformSample(const std::vector<dgl_id_t>& ver_list,
     CHECK_GT(sorted_idxs[i], sorted_idxs[i - 1]);
   }
   for (auto idx : sorted_idxs) {
-    out_ver->push_back(ver_list[idx]);
-    out_edge->push_back(edge_list[idx]);
+    out_ver->push_back(col_ptr[idx]);
+    out_edge->push_back(val_ptr[idx]);
   }
 }
 
@@ -508,26 +497,30 @@ static void GetUniformSample(const std::vector<dgl_id_t>& ver_list,
  * Non-uniform sample via ArrayHeap
  */
 static void GetNonUniformSample(const float* probability,
-                                const std::vector<dgl_id_t>& ver_list,
-                                const std::vector<dgl_id_t>& edge_list,
+                                const dgl_id_t* val_list,
+                                const dgl_id_t* col_list,
+                                const dgl_id_t* indptr,
+                                const dgl_id_t dst_id,
                                 const size_t max_num_neighbor,
                                 std::vector<dgl_id_t>* out_ver,
                                 std::vector<dgl_id_t>* out_edge,
                                 unsigned int* seed) {
-  CHECK_EQ(ver_list.size(), edge_list.size());
+  size_t ver_len = *(indptr+dst_id+1) - *(indptr+dst_id);
   // Copy ver_list to output
-  if (ver_list.size() <= max_num_neighbor) {
-    for (size_t i = 0; i < ver_list.size(); ++i) {
-      out_ver->push_back(ver_list[i]);
-      out_edge->push_back(edge_list[i]);
+  if (ver_len <= max_num_neighbor) {
+    for (dgl_id_t i = *(indptr+dst_id); i < *(indptr+dst_id+1); ++i) {
+      out_ver->push_back(col_list[i]);
+      out_edge->push_back(val_list[i]);
     }
     return;
   }
   // Make sample
+  const dgl_id_t* col_ptr = col_list + *(indptr + dst_id);
+  const dgl_id_t* val_ptr = val_list + *(indptr + dst_id);
   std::vector<size_t> sp_index(max_num_neighbor);
-  std::vector<float> sp_prob(ver_list.size());
-  for (size_t i = 0; i < ver_list.size(); ++i) {
-    sp_prob[i] = probability[ver_list[i]];
+  std::vector<float> sp_prob(ver_len);
+  for (size_t i = 0; i < ver_len; ++i) {
+    sp_prob[i] = probability[col_ptr[i]];
   }
   ArrayHeap arrayHeap(sp_prob);
   arrayHeap.SampleWithoutReplacement(max_num_neighbor, &sp_index, seed);
@@ -535,8 +528,8 @@ static void GetNonUniformSample(const float* probability,
   out_edge->resize(max_num_neighbor);
   for (size_t i = 0; i < max_num_neighbor; ++i) {
     size_t idx = sp_index[i];
-    out_ver->at(i) = ver_list[idx];
-    out_edge->at(i) = edge_list[idx];
+    out_ver->at(i) = col_ptr[idx];
+    out_edge->at(i) = val_ptr[idx];
   }
   sort(out_ver->begin(), out_ver->end());
   sort(out_edge->begin(), out_edge->end());
@@ -597,14 +590,15 @@ static void SampleSubgraph(const NDArray &csr,
     node.level = 0;
     node_queue.push(node);
   }
-  std::vector<dgl_id_t> tmp_src_list;
-  std::vector<dgl_id_t> tmp_edge_list;
   std::vector<dgl_id_t> tmp_sampled_src_list;
   std::vector<dgl_id_t> tmp_sampled_edge_list;
-  std::unordered_map<dgl_id_t, neigh_list> neigh_mp;
+  // ver_id, position
+  std::unordered_map<dgl_id_t, size_t> neigh_pos;
+  std::vector<dgl_id_t> neighbor_list;
   size_t num_edges = 0;
+
   while (!node_queue.empty() &&
-    sub_vertices_count <= max_num_vertices ) {
+    sub_vertices_count < max_num_vertices) {
     ver_node& cur_node = node_queue.front();
     dgl_id_t dst_id = cur_node.vertex_id;
     if (cur_node.level < num_hops) {
@@ -613,35 +607,42 @@ static void SampleSubgraph(const NDArray &csr,
         node_queue.pop();
         continue;
       }
-      tmp_src_list.clear();
-      tmp_edge_list.clear();
       tmp_sampled_src_list.clear();
       tmp_sampled_edge_list.clear();
-      GetSrcList(val_list,
-                 col_list,
-                 indptr,
-                 dst_id,
-                 &tmp_src_list,
-                 &tmp_edge_list);
       if (probability == nullptr) {  // uniform-sample
-        GetUniformSample(tmp_src_list,
-                       tmp_edge_list,
+        GetUniformSample(val_list,
+                       col_list,
+                       indptr,
+                       dst_id,
                        num_neighbor,
                        &tmp_sampled_src_list,
                        &tmp_sampled_edge_list,
                        &time_seed);
       } else {  // non-uniform-sample
         GetNonUniformSample(probability,
-                       tmp_src_list,
-                       tmp_edge_list,
+                       val_list,
+                       col_list,
+                       indptr,
+                       dst_id,
                        num_neighbor,
                        &tmp_sampled_src_list,
                        &tmp_sampled_edge_list,
                        &time_seed);
       }
-      neigh_mp.insert(std::pair<dgl_id_t, neigh_list>(dst_id,
-        neigh_list(tmp_sampled_src_list,
-                   tmp_sampled_edge_list)));
+      CHECK_EQ(tmp_sampled_src_list.size(),
+               tmp_sampled_edge_list.size());
+      size_t pos = neighbor_list.size();
+      neigh_pos[dst_id] = pos;
+      // First we push the size of neighbor vector
+      neighbor_list.push_back(tmp_sampled_edge_list.size());
+      // Then push the vertices
+      for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
+        neighbor_list.push_back(tmp_sampled_src_list[i]);
+      }
+      // Finally we push the edge list
+      for (size_t i = 0; i < tmp_sampled_edge_list.size(); ++i) {
+        neighbor_list.push_back(tmp_sampled_edge_list[i]);
+      }
       num_edges += tmp_sampled_src_list.size();
       sub_ver_mp[cur_node.vertex_id] = cur_node.level;
       for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
@@ -659,11 +660,9 @@ static void SampleSubgraph(const NDArray &csr,
         node_queue.pop();
         continue;
       }
-      tmp_sampled_src_list.clear();
-      tmp_sampled_edge_list.clear();
-      neigh_mp.insert(std::pair<dgl_id_t, neigh_list>(dst_id,
-        neigh_list(tmp_sampled_src_list,      // empty vector
-                   tmp_sampled_edge_list)));  // empty vector
+      size_t pos = neighbor_list.size();
+      neigh_pos[dst_id] = pos;
+      neighbor_list.push_back(0);
       sub_ver_mp[cur_node.vertex_id] = cur_node.level;
     }
     sub_vertices_count++;
@@ -720,16 +719,18 @@ static void SampleSubgraph(const NDArray &csr,
   size_t collected_nedges = 0;
   for (size_t i = 0; i < num_vertices; i++) {
     dgl_id_t dst_id = *(out + i);
-    auto it = neigh_mp.find(dst_id);
-    const auto &edges = it->second.edges;
-    const auto &neighs = it->second.neighs;
-    CHECK_EQ(edges.size(), neighs.size());
-    if (!edges.empty()) {
-      std::copy(edges.begin(), edges.end(), val_list_out + collected_nedges);
-      std::copy(neighs.begin(), neighs.end(), col_list_out + collected_nedges);
-      collected_nedges += edges.size();
+    size_t pos = neigh_pos[dst_id];
+    size_t edge_size = neighbor_list[pos];
+    if (edge_size != 0) {
+      std::copy_n(neighbor_list.begin() + pos + 1,
+                  edge_size,
+                  col_list_out + collected_nedges);
+      std::copy_n(neighbor_list.begin() + pos + edge_size + 1,
+                  edge_size,
+                  val_list_out + collected_nedges);
+      collected_nedges += edge_size;
     }
-    indptr_out[i+1] = indptr_out[i] + edges.size();
+    indptr_out[i+1] = indptr_out[i] + edge_size;
   }
   for (dgl_id_t i = num_vertices+1; i <= max_num_vertices; ++i) {
     indptr_out[i] = indptr_out[i-1];


With regards,
Apache Git Services

Reply via email to