This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new e35b7fc [Relay][Training] Make AutoDiff thread through global function. (#6336)
e35b7fc is described below
commit e35b7fc4bcdcfe008c5dfea60c2297b93dbff99e
Author: 雾雨魔理沙 <[email protected]>
AuthorDate: Thu Aug 27 11:32:40 2020 -0700
[Relay][Training] Make AutoDiff thread through global function. (#6336)
* save
* lint
* lint
* fix warning
* fix test
* save
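For context, this change lets `gradient` differentiate a function that calls other global functions when the enclosing `IRModule` is passed in (`DeGlobal` previously returned only the looked-up function's body). A rough usage sketch, mirroring the `test_global_function` test added below; the `double`/`q` names and the (10, 10) shape are illustrative:

import tvm
from tvm import relay
from tvm.relay import GlobalVar
from tvm.relay.transform import gradient

m = tvm.IRModule()
t = relay.TensorType((10, 10), "float32")

# A global function `double`, and a global function `q` that calls it twice.
x = relay.Var("x", t)
double = GlobalVar("double")
m[double] = relay.Function([x], x + x)

y = relay.Var("y", t)
q = GlobalVar("q")
m[q] = relay.Function([y], double(double(y)))

# AD now threads through the calls to `double` because the module is supplied.
g = GlobalVar("grad")
m[g] = gradient(q, m)
# m[g] : fn(t) -> (t, (t,))  -- the forward value plus gradients w.r.t. the inputs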
---
src/printer/doc.cc | 2 +-
src/relay/transforms/gradient.cc | 106 ++++++++++++++++++++++++-------
tests/python/relay/test_pass_gradient.py | 41 +++++++++++-
3 files changed, 124 insertions(+), 25 deletions(-)
diff --git a/src/printer/doc.cc b/src/printer/doc.cc
index d487e3e..ab1eddb 100644
--- a/src/printer/doc.cc
+++ b/src/printer/doc.cc
@@ -129,7 +129,7 @@ Doc Doc::Indent(int indent, Doc doc) {
}
Doc Doc::StrLiteral(const std::string& value, std::string quote) {
- // TODO(M.K.): add escape.
+ // TODO(@M.K.): add escape.
Doc doc;
return doc << quote << value << quote;
}
diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc
index 7894c34..9c47254 100644
--- a/src/relay/transforms/gradient.cc
+++ b/src/relay/transforms/gradient.cc
@@ -72,7 +72,7 @@ Type WithGradientType(const Type&);
Expr FirstOrderGradient(const Expr& e, const Optional<IRModule>& mod);
Type WithGradientType(const Type& t) {
- // TODO(M.K.): stricter checking
+ // TODO(@M.K.): stricter checking
auto ty = t.as<FuncTypeNode>();
CHECK(ty) << "input should be a function";
return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {});
@@ -85,7 +85,7 @@ Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
if (mod.defined() && x) {
BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
- return n->body;
+ return GetRef<Function>(n);
} else {
return e;
}
@@ -338,11 +338,22 @@ Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {
TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient);
+Type bpt = RelayRefType(FuncType({}, TupleType(Array<Type>()), {}, {}));
+
struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleType({t, RelayRefType(t)});
}
+
+ Type VisitType_(const FuncTypeNode* ftn) final {
+ std::vector<Type> arg_types;
+ for (const auto& t : ftn->arg_types) {
+ arg_types.push_back(VisitType(t));
+ }
+ arg_types.push_back(bpt);
+ return FuncType(arg_types, ftn->ret_type, ftn->type_params, ftn->type_constraints);
+ }
};
Type ReverseType(const Type& t) { return ReverseADType()(t); }
@@ -438,12 +449,18 @@ Expr BPEmpty() {
struct ReverseAD : ExprMutator {
using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;
-
+ using ADGlobalVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>;
+ Optional<IRModule> mod;
+ // TODO(@M.K.) refactor AD to always use mod.
Var bp;
std::shared_ptr<ADVarMap> ad_vars;
+ std::shared_ptr<ADGlobalVarMap> ad_gvars;
const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
- explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars) : bp(bp), ad_vars(ad_vars) {}
+ explicit ReverseAD(const Optional<IRModule>& mod, const Var& bp,
+ const std::shared_ptr<ADVarMap>& ad_vars,
+ const std::shared_ptr<ADGlobalVarMap>& ad_gvars)
+ : mod(mod), bp(bp), ad_vars(ad_vars), ad_gvars(ad_gvars) {}
Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
@@ -481,9 +498,8 @@ struct ReverseAD : ExprMutator {
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
// we need a new ReverseAD visitor to avoid clobbering the bp local var
auto dup_bp = ll->Push(BPEmpty());
- ReverseAD dup_diff(dup_bp, ad_vars);
- auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
-
+ auto dup_ad =
+ ll->Push(ReverseAD(mod, dup_bp, ad_vars, ad_gvars)(DeDup(x)));
TransferGrads(call->checked_type(), ret, dup_ad, ll);
ll->Push(Call(RefRead(dup_bp), {}));
return Call(bpv, {});
@@ -518,22 +534,29 @@ struct ReverseAD : ExprMutator {
orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefRead(bp));
- Expr nbp = Function({}, LetList::With([&](LetList* ll) {
- tvm::Array<Expr> rev =
- rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
- CHECK(args.size() == rev.size());
- for (size_t i = 0; i < args.size(); ++i) {
- UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
- }
- return Call(bpv, {});
- }),
- TupleType::Empty(), {});
+ Expr nbp_body = LetList::With([&](LetList* ll) {
+ tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
+ CHECK(args.size() == rev.size());
+ for (size_t i = 0; i < args.size(); ++i) {
+ UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
+ }
+ return Call(bpv, {});
+ });
+ Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
return ret;
});
+ } else if (call->op.as<ConstructorNode>()) {
+ return ExprMutator::VisitExpr_(call);
+ } else {
+ std::vector<Expr> args;
+ for (const auto& arg : call->args) {
+ args.push_back(VisitExpr(arg));
+ }
+ args.push_back(bp);
+ return Call(VisitExpr(call->op), args);
}
- return ExprMutator::VisitExpr_(call);
}
Expr VisitExpr_(const ConstantNode* op) final {
@@ -559,6 +582,39 @@ struct ReverseAD : ExprMutator {
return ad_vars->at(var_ref);
}
+ Expr VisitExpr_(const GlobalVarNode* op) final {
+ // todo: concatenating string to add attribute seems like a brittle hack.
+ // maybe get module indexed by a rose tree of string?
+ CHECK(mod.defined());
+ auto orig_gv = GetRef<GlobalVar>(op);
+ if (ad_gvars->count(orig_gv) == 0) {
+ GlobalVar gv(op->name_hint + "_grad");
+ (*ad_gvars)[orig_gv] = gv;
+ Function orig_f = Downcast<Function>(DeDup(mod.value()->Lookup(orig_gv)));
+ std::vector<Var> params;
+ for (const auto& p : orig_f->params) {
+ params.push_back(Downcast<Var>(VisitExpr(p)));
+ }
+ params.push_back(bp);
+ Expr body = VisitExpr(orig_f->body);
+ Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs);
+ std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl;
+ mod.value()->Add(gv, f);
+ }
+ return ad_gvars->at(orig_gv);
+ }
+
+ Expr VisitExpr_(const FunctionNode* op) final {
+ std::vector<Var> params;
+ for (const auto& var : op->params) {
+ params.push_back(Downcast<Var>(VisitExpr(var)));
+ }
+ auto new_bp = Var("bp", bpt);
+ params.push_back(new_bp);
+ return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body),
+ VisitType(op->ret_type), op->type_params, op->attrs);
+ }
+
Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
};
@@ -604,12 +660,16 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
}
CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
Expr body = LetList::With([&](LetList* ll) {
- Var bp = ll->Push(BPEmpty());
- Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
- std::vector<Expr> args;
+ Var bp = ll->Push(BPEmpty(), bpt);
+ Expr rev = ReverseAD(mod, bp, std::make_shared<ReverseAD::ADVarMap>(),
+ std::make_shared<ReverseAD::ADGlobalVarMap>())(e);
+ std::vector<Expr> normal_args, args;
for (const auto& p : f->params) {
- args.push_back(ll->Push(Pair(p, RefCreate(ZerosLike(p)))));
+ auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p))));
+ normal_args.push_back(x);
+ args.push_back(x);
}
+ args.push_back(bp);
auto c = ll->Push(Call(rev, args));
std::function<void(const Expr&, const Type&)> init_grad;
init_grad = [&](const Expr& e, const Type& t) {
@@ -626,7 +686,7 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
init_grad(c, f->body->checked_type());
ll->Push(Call(RefRead(bp), {}));
std::vector<Expr> ret;
- for (const auto& a : args) {
+ for (const auto& a : normal_args) {
ret.push_back(RefRead(GetField(a, 1)));
}
std::function<Expr(const Expr&, const Type&)> get_final_result;
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 296d3e5..b239ef4 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -21,6 +21,7 @@ import pytest
import tvm
from tvm import te
from tvm import relay
+from tvm.relay import GlobalVar
from tvm.relay.analysis import free_vars, free_type_vars
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
@@ -29,7 +30,7 @@ from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type
import tvm.relay.op as op
-def test_id():
+def test_fo_id():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
@@ -44,6 +45,21 @@ def test_id():
tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
+def test_id():
+ shape = (10, 10)
+ dtype = 'float32'
+ t = relay.TensorType(shape, dtype)
+ x = relay.var("x", t)
+ func = relay.Function([x], x)
+ func = run_infer_type(func)
+ back_func = run_infer_type(gradient(func))
+ assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+ ex = create_executor()
+ x = rand(dtype, *shape)
+ forward, (grad,) = ex.evaluate(back_func)(x)
+ tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
+ tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
+
def test_relu():
shape = (10, 10)
@@ -341,5 +357,28 @@ def test_no_duplication():
counts = count_ops(gr)
assert counts['nn.dense'] == 3, "We expect 3 dense (1 forward, two backward)"
+
+def test_global_function():
+ m = tvm.IRModule()
+ shape = (10, 10)
+ dtype = 'float32'
+ t = relay.TensorType(shape, dtype)
+ x = relay.Var('x', t)
+ d = GlobalVar('double')
+ m[d] = relay.Function([x], x + x)
+ y = relay.Var('y', t)
+ q = GlobalVar('q')
+ m[q] = relay.Function([y], d(d(y)))
+ g = GlobalVar('grad')
+ m[g] = tvm.relay.transform.gradient(q, m)
+ back_func = m[g]
+ assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
+ ex = create_executor(mod=m)
+ x = rand(dtype, *shape)
+ forward, (grad,) = ex.evaluate(back_func)(x)
+ tvm.testing.assert_allclose(forward.asnumpy(), 4 * x.asnumpy())
+ tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
+
+
if __name__ == "__main__":
pytest.main([__file__])