hypercubestart commented on a change in pull request #6336:
URL: https://github.com/apache/incubator-tvm/pull/6336#discussion_r476912921



##########
File path: src/relay/transforms/gradient.cc
##########
@@ -85,7 +85,7 @@ Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
   if (mod.defined() && x) {
     BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
     if (auto* n = base_func.as<FunctionNode>()) {

Review comment:
       Why do we need the if/else here? Doesn't every GlobalVar map to a Function?

##########
File path: src/relay/transforms/gradient.cc
##########
@@ -559,6 +581,39 @@ struct ReverseAD : ExprMutator {
     return ad_vars->at(var_ref);
   }
 
+  Expr VisitExpr_(const GlobalVarNode* op) final {
+    // todo: concatenating string to add attribute seems like a brittle hack.
+    // maybe get module indexed by a rose tree of string?
+    CHECK(mod.defined());
+    auto orig_gv = GetRef<GlobalVar>(op);
+    if (ad_gvars->count(orig_gv) == 0) {
+      GlobalVar gv(op->name_hint + "_grad");
+      (*ad_gvars)[orig_gv] = gv;
+      Function orig_f = 
Downcast<Function>(mod.value()->Lookup(GetRef<GlobalVar>(op)));

Review comment:
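       Suggestion: use `mod.value()->Lookup(orig_gv)` here; `orig_gv` already holds `GetRef<GlobalVar>(op)`, so there is no need to rebuild the reference.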

##########
File path: src/relay/transforms/gradient.cc
##########
@@ -559,6 +581,39 @@ struct ReverseAD : ExprMutator {
     return ad_vars->at(var_ref);
   }
 
+  Expr VisitExpr_(const GlobalVarNode* op) final {
+    // todo: concatenating string to add attribute seems like a brittle hack.
+    // maybe get module indexed by a rose tree of string?
+    CHECK(mod.defined());
+    auto orig_gv = GetRef<GlobalVar>(op);
+    if (ad_gvars->count(orig_gv) == 0) {
+      GlobalVar gv(op->name_hint + "_grad");
+      (*ad_gvars)[orig_gv] = gv;
+      Function orig_f = 
Downcast<Function>(mod.value()->Lookup(GetRef<GlobalVar>(op)));

Review comment:
       Would it be a good idea to apply DeDup here?

##########
File path: tests/python/relay/test_pass_gradient.py
##########
@@ -341,5 +357,19 @@ def test_no_duplication():
     counts = count_ops(gr)
     assert counts['nn.dense'] == 3, "We expect 3 dense (1 forward, two 
backward)"
 
+
+def test_global_function():
+    m = tvm.IRModule()
+    t = relay.TensorType([])
+    x = relay.Var('x', t)
+    d = GlobalVar('double')
+    m[d] = relay.Function([x], x + x)
+    y = relay.Var('y', t)
+    q = GlobalVar('q')
+    m[q] = relay.Function([y], d(d(y)))
+    g = GlobalVar('grad')
+    m[g] = tvm.relay.transform.gradient(q, m)
+

Review comment:
       Please add a type check and a value check for the resulting gradient function.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to