eric-haibin-lin commented on a change in pull request #13290: A few operators on graphs stored as CSR URL: https://github.com/apache/incubator-mxnet/pull/13290#discussion_r235162543
########## File path: src/operator/contrib/dgl_graph.cc ########## @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <mxnet/io.h> +#include <mxnet/base.h> +#include <mxnet/ndarray.h> +#include <mxnet/operator.h> +#include <mxnet/operator_util.h> +#include <dmlc/logging.h> +#include <dmlc/optional.h> +#include "../operator_common.h" +#include "../elemwise_op_common.h" +#include "../../imperative/imperative_utils.h" +#include "../subgraph_op_common.h" +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "../tensor/init_op.h" + +namespace mxnet { +namespace op { + + +///////////////////////// Create induced subgraph /////////////////////////// + +struct DGLSubgraphParam : public dmlc::Parameter<DGLSubgraphParam> { + int num_args; + bool return_mapping; + DMLC_DECLARE_PARAMETER(DGLSubgraphParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(2) + .describe("Number of input arguments, including all symbol inputs."); + DMLC_DECLARE_FIELD(return_mapping) + .describe("Return mapping of vid and eid between the subgraph and the parent graph."); + } +}; // struct DGLSubgraphParam + +DMLC_REGISTER_PARAMETER(DGLSubgraphParam); + +static bool DGLSubgraphStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + CHECK_EQ(in_attrs->at(0), kCSRStorage); + for (size_t i = 1; i < in_attrs->size(); i++) + CHECK_EQ(in_attrs->at(i), kDefaultStorage); + + bool success = true; + *dispatch_mode = DispatchMode::kFComputeEx; + for (size_t i = 0; i < out_attrs->size(); i++) { + if (!type_assign(&(*out_attrs)[i], mxnet::kCSRStorage)) + success = false; + } + return success; +} + +static bool DGLSubgraphShape(const nnvm::NodeAttrs& attrs, + std::vector<TShape> *in_attrs, + std::vector<TShape> *out_attrs) { + const DGLSubgraphParam& params = nnvm::get<DGLSubgraphParam>(attrs.parsed); + CHECK_EQ(in_attrs->at(0).ndim(), 2U); + for (size_t i = 1; i < in_attrs->size(); i++) + CHECK_EQ(in_attrs->at(i).ndim(), 1U); + + size_t num_g = params.num_args - 1; + for (size_t i = 0; i < num_g; i++) { + TShape gshape(2); + gshape[0] = in_attrs->at(i + 1)[0]; + gshape[1] = in_attrs->at(i + 1)[0]; + out_attrs->at(i) = gshape; + } + for (size_t i = num_g; i < out_attrs->size(); i++) { + TShape gshape(2); + gshape[0] = in_attrs->at(i - num_g + 1)[0]; + gshape[1] = in_attrs->at(i - num_g + 1)[0]; + out_attrs->at(i) = gshape; + } + return true; +} + +static bool DGLSubgraphType(const nnvm::NodeAttrs& attrs, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + const DGLSubgraphParam& params = nnvm::get<DGLSubgraphParam>(attrs.parsed); + size_t num_g = params.num_args - 1; + for (size_t i = 0; i < num_g; i++) { + CHECK_EQ(in_attrs->at(i + 1), mshadow::kInt64); + } + for (size_t i = 0; i < out_attrs->size(); i++) { + out_attrs->at(i) = in_attrs->at(0); + } + return true; +} + +typedef int64_t dgl_id_t; + +class Bitmap { + const size_t size = 1024 * 1024 * 4; + const size_t mask = size - 1; + std::vector<bool> map; + + size_t hash(dgl_id_t id) const { + return id & mask; + } + public: + Bitmap(const dgl_id_t *vid_data, int64_t len): map(size) { + for (int64_t i = 0; i < len; ++i) { + map[hash(vid_data[i])] = 1; + } + } + + bool test(dgl_id_t id) const { + return map[hash(id)]; + } +}; + +/* + * This uses a hashtable to check if a node is in the given node list. + */ +class HashTableChecker { + std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv; + Bitmap map; + + public: + HashTableChecker(const dgl_id_t *vid_data, int64_t len): map(vid_data, len) { + oldv2newv.reserve(len); + for (int64_t i = 0; i < len; ++i) { + oldv2newv[vid_data[i]] = i; + } + } + + void CollectOnRow(const dgl_id_t col_idx[], const dgl_id_t eids[], size_t row_len, + std::vector<dgl_id_t> *new_col_idx, + std::vector<dgl_id_t> *orig_eids) { + // TODO(zhengda) I need to make sure the column index in each row is sorted. + for (size_t j = 0; j < row_len; ++j) { + const dgl_id_t oldsucc = col_idx[j]; + const dgl_id_t eid = eids[j]; + Collect(oldsucc, eid, new_col_idx, orig_eids); + } + } + + void Collect(const dgl_id_t old_id, const dgl_id_t old_eid, + std::vector<dgl_id_t> *col_idx, + std::vector<dgl_id_t> *orig_eids) { + if (!map.test(old_id)) + return; + + auto it = oldv2newv.find(old_id); + if (it != oldv2newv.end()) { + const dgl_id_t new_id = it->second; + col_idx->push_back(new_id); + if (orig_eids) + orig_eids->push_back(old_eid); + } + } +}; + +static void GetSubgraph(const NDArray &csr_arr, const NDArray &varr, + const NDArray &sub_csr, const NDArray *old_eids) { + const TBlob &data = varr.data(); + int64_t num_vertices = csr_arr.shape()[0]; + const size_t len = varr.shape()[0]; + const dgl_id_t *vid_data = data.dptr<dgl_id_t>(); + HashTableChecker def_check(vid_data, len); + // check if varr is sorted. + std::is_sorted(vid_data, vid_data + len); Review comment: It returns True if it is actually sorted. But the return value is not used anywhere. Is this intended? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
