ZiyueHuang commented on a change in pull request #8632: [WIP] a user friendly way to use g2c in module and an example of g2c
URL: https://github.com/apache/incubator-mxnet/pull/8632#discussion_r151738034
##########
File path: example/sparse/matrix_fact_parallel_model.py
##########
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+
+def matrix_fact_model_parallel_net(factor_size, num_hidden, max_user, max_item):
+    # set ctx_group attribute to 'dev1' for the symbols created in this scope,
+    # the symbols will be binded to the context that 'dev1' map to in group2ctxs
+    with mx.AttrScope(ctx_group='dev1'):
+        # input
+        user = mx.symbol.Variable('user')
+        item = mx.symbol.Variable('item')
+        # user feature lookup
+        user_weight = mx.symbol.Variable('user_weight', stype='row_sparse')
+        user = mx.symbol.contrib.SparseEmbedding(data=user, weight=user_weight,
+                                                 input_dim=max_user, output_dim=factor_size)
+        # item feature lookup
+        item_weight = mx.symbol.Variable('item_weight', stype='row_sparse')
+        item = mx.symbol.contrib.SparseEmbedding(data=item, weight=item_weight,
+                                                 input_dim=max_item, output_dim=factor_size)
+    # non-linear transformation of user features

Review comment:
@eric-haibin-lin I added some debugging code to `graph_executor.cc` and ran `python matrix_factorization_model_parallel.py`:

```
[00:53:08] src/executor/graph_executor.cc:365: args context
[00:53:08] src/executor/graph_executor.cc:384: nid: 0 ctx.dev_id 0
[00:53:08] src/executor/graph_executor.cc:384: nid: 1 ctx.dev_id 0
[00:53:08] src/executor/graph_executor.cc:384: nid: 3 ctx.dev_id 1
[00:53:08] src/executor/graph_executor.cc:384: nid: 4 ctx.dev_id 1
[00:53:08] src/executor/graph_executor.cc:384: nid: 6 ctx.dev_id 0
[00:53:08] src/executor/graph_executor.cc:384: nid: 7 ctx.dev_id 0
[00:53:08] src/executor/graph_executor.cc:384: nid: 12 ctx.dev_id 1
[00:53:08] src/executor/graph_executor.cc:386: =====================
[00:53:08] src/executor/graph_executor.cc:387: 1 num_forward_outputs
[00:53:08] src/executor/graph_executor.cc:388: 5 g.outputs.size()
[00:53:08] src/executor/graph_executor.cc:389: 7 arg_grad_ctxes.size()
[00:53:08] src/executor/graph_executor.cc:393: arg grads contexts
[00:53:08] src/executor/graph_executor.cc:397: nid 19 ctx 0
[00:53:08] src/executor/graph_executor.cc:397: nid 18 ctx 0
[00:53:08] src/executor/graph_executor.cc:397: nid 18 ctx 1
[00:53:08] src/executor/graph_executor.cc:397: nid 20 ctx 1
[00:53:08] src/executor/graph_executor.cc:399: =====================
[00:53:08] src/executor/graph_executor.cc:409: fail nid 18 ctx 1
[00:53:08] src/executor/graph_executor.cc:423: node 0 var user
[00:53:08] src/executor/graph_executor.cc:423: node 1 var user_weight
[00:53:08] src/executor/graph_executor.cc:425: node 2 _contrib_SparseEmbedding
[00:53:08] src/executor/graph_executor.cc:428: input 0 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 1 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 2 (entry id)
[00:53:08] src/executor/graph_executor.cc:423: node 3 var ufcweight
[00:53:08] src/executor/graph_executor.cc:423: node 4 var ufcbias
[00:53:08] src/executor/graph_executor.cc:425: node 5 FullyConnected
[00:53:08] src/executor/graph_executor.cc:428: input 2 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 3 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 4 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 5 (entry id)
[00:53:08] src/executor/graph_executor.cc:423: node 6 var item
[00:53:08] src/executor/graph_executor.cc:423: node 7 var item_weight
[00:53:08] src/executor/graph_executor.cc:425: node 8 _contrib_SparseEmbedding
[00:53:08] src/executor/graph_executor.cc:428: input 6 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 7 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 8 (entry id)
[00:53:08] src/executor/graph_executor.cc:425: node 9 elemwise_mul
[00:53:08] src/executor/graph_executor.cc:428: input 5 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 8 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 9 (entry id)
[00:53:08] src/executor/graph_executor.cc:425: node 10 sum
[00:53:08] src/executor/graph_executor.cc:428: input 9 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 10 (entry id)
[00:53:08] src/executor/graph_executor.cc:425: node 11 Flatten
[00:53:08] src/executor/graph_executor.cc:428: input 10 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 11 (entry id)
[00:53:08] src/executor/graph_executor.cc:423: node 12 var score
[00:53:08] src/executor/graph_executor.cc:425: node 13 LinearRegressionOutput
[00:53:08] src/executor/graph_executor.cc:428: input 11 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 12 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 13 (entry id)
[00:53:08] src/executor/graph_executor.cc:425: node 14 _backward_LinearRegressionOutput
[00:53:08] src/executor/graph_executor.cc:428: input 12 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 13 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 14 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 15 (entry id)
[00:53:08] src/executor/graph_executor.cc:425: node 15 _backward_copy
[00:53:08] src/executor/graph_executor.cc:428: input 14 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 16 (entry id)
[00:53:08] src/executor/graph_executor.cc:425: node 16 _backward_sum
[00:53:08] src/executor/graph_executor.cc:428: input 16 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 17 (entry id)
[00:53:08] src/executor/graph_executor.cc:425: node 17 _backward_mul
[00:53:08] src/executor/graph_executor.cc:428: input 17 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 5 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 8 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 18 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 19 (entry id)
[00:53:08] src/executor/graph_executor.cc:425: node 18 _backward_FullyConnected
[00:53:08] src/executor/graph_executor.cc:428: input 18 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 2 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 3 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 20 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 21 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 22 (entry id)
[00:53:08] src/executor/graph_executor.cc:425: node 19 _backward_SparseEmbedding
[00:53:08] src/executor/graph_executor.cc:428: input 20 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 0 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 23 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 24 (entry id)
[00:53:08] src/executor/graph_executor.cc:425: node 20 _backward_SparseEmbedding
[00:53:08] src/executor/graph_executor.cc:428: input 19 (entry id)
[00:53:08] src/executor/graph_executor.cc:428: input 6 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 25 (entry id)
[00:53:08] src/executor/graph_executor.cc:432: output 26 (entry id)
[00:53:08] /home/hanfeng/zyh/zyhmxnet/dmlc-core/include/dmlc/./logging.h:308: [00:53:08] src/executor/graph_executor.cc:436: Check failed: device[nid] == devid (0 vs. 1) fullyconnected0_backward device of same output not equal to each other
```

So as you can see, `node 3 var ufcweight` and `node 4 var ufcbias` are both placed on the same device (`ctx.dev_id 1`), but their grads, which are both outputs of node 18 (`_backward_FullyConnected`), are assigned to two different contexts (ctx 0 and ctx 1). This is visible in the output under `arg grads contexts`:

```
[00:53:08] src/executor/graph_executor.cc:393: arg grads contexts
[00:53:08] src/executor/graph_executor.cc:397: nid 19 ctx 0
[00:53:08] src/executor/graph_executor.cc:397: nid 18 ctx 0
[00:53:08] src/executor/graph_executor.cc:397: nid 18 ctx 1
[00:53:08] src/executor/graph_executor.cc:397: nid 20 ctx 1
```
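
To make the conflict explicit, here is a minimal Python sketch that replays the `arg grads contexts` entries from the log and reports any node whose outputs are requested on more than one device. It only illustrates the condition behind the `device[nid] == devid` check; it is not the actual `graph_executor.cc` logic, and the helper name `find_device_conflicts` is hypothetical.

```python
# Illustration only: NOT the real graph_executor.cc code.
# (node id, device id) pairs copied from the "arg grads contexts" log lines.
arg_grad_ctxes = [(19, 0), (18, 0), (18, 1), (20, 1)]


def find_device_conflicts(entries):
    """Return {nid: sorted device ids} for nodes assigned to more than one device.

    `entries` is an iterable of (nid, dev_id) pairs. The function name and
    signature are hypothetical, chosen for this example.
    """
    devices = {}
    for nid, dev_id in entries:
        # Collect every device requested for this node's outputs.
        devices.setdefault(nid, set()).add(dev_id)
    # A node is in conflict when its outputs were asked to live on >1 device.
    return {nid: sorted(devs) for nid, devs in devices.items() if len(devs) > 1}


conflicts = find_device_conflicts(arg_grad_ctxes)
print(conflicts)  # {18: [0, 1]}
```

Only node 18 is reported, which matches the `fail nid 18 ctx 1` line and the failed check on `fullyconnected0_backward`.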