This is an automated email from the ASF dual-hosted git repository.

marisa pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e04393  Amendments for gradients (#5941)
2e04393 is described below

commit 2e043937831dbb07e152f4702457ed05ff3cd31e
Author: Thomas Viehmann <tv.c...@beamnet.de>
AuthorDate: Tue Jun 30 05:35:36 2020 +0200

    Amendments for gradients (#5941)
    
    * Amendments for gradients
    
    - We fix the dtype handling of consts in generated gradients.
    - We add a collapse_sum_to operator (sketched below) mirroring collapse_sum_like.
      While collapse_sum_like remains the first choice for general
      definitions (potentially dynamic shapes), switching to collapse_sum_to
      once shapes are static greatly simplifies the graph.
      (That simplification itself is not part of this PR.)
    
    * Fix Broadcast rel description in comment
    
    Thank you, @MarisaKirisame
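    
    A rough usage sketch of the new op (not taken from the commit; it assumes
    an LLVM-enabled TVM build and mirrors the test added below, with
    illustrative names and shapes):
    
        import numpy as np
        import tvm
        from tvm import relay
    
        # Collapse a (3, 4, 5, 6) tensor down to the static target shape (4, 5, 6).
        x = relay.var("x", relay.TensorType((3, 4, 5, 6), "float64"))
        func = relay.Function([x], relay.collapse_sum_to(x, (4, 5, 6)))
    
        data = np.random.uniform(size=(3, 4, 5, 6)).astype("float64")
        intrp = relay.create_executor("debug", ctx=tvm.cpu(), target="llvm")
        res = intrp.evaluate(func)(data)
        np.testing.assert_allclose(res.asnumpy(), data.sum(axis=0), rtol=1e-5)
    
        # With the const dtype fix, gradients of ops such as sqrt or log2 on this
        # float64 input no longer mix in float32 constants like const(0.5).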
---
 python/tvm/relay/op/_tensor_grad.py        | 24 ++++++++++-----
 python/tvm/relay/op/_transform.py          |  1 +
 python/tvm/relay/op/transform.py           | 21 +++++++++++++
 src/relay/op/tensor/transform.cc           | 48 ++++++++++++++++++++++++++++++
 tests/python/relay/test_op_grad_level1.py  | 13 ++++----
 tests/python/relay/test_op_grad_level10.py | 14 +++++----
 tests/python/relay/test_op_grad_level3.py  | 31 +++++++++----------
 tests/python/relay/test_op_level10.py      | 20 +++++++++++++
 8 files changed, 137 insertions(+), 35 deletions(-)

diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 00ea097..2907d72 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -69,7 +69,7 @@ def log2_grad(orig, grad):
     """Returns [grad * 1 / (log(2) * x)]"""
     x = orig.args[0]
     ones = ones_like(x)
-    two = const(2.0)
+    two = const(2.0, dtype=x.checked_type.dtype)
     return [grad * ones / (log(two) * x)]
 
 
@@ -78,7 +78,7 @@ def log10_grad(orig, grad):
     """Returns [grad * 1 / (log(10) * x)]"""
     x = orig.args[0]
     ones = ones_like(x)
-    ten = const(10.0)
+    ten = const(10.0, dtype=x.checked_type.dtype)
     return [grad * ones / (log(ten) * x)]
 
 
@@ -175,8 +175,9 @@ def exp_grad(orig, grad):
 @register_gradient("sqrt")
 def sqrt_grad(orig, grad):
     """Returns [grad * 0.5 * (x ^ -0.5)]"""
-    a = const(0.5)  # (TODO) type?
-    return [grad * a * power(orig.args[0], negative(a))]
+    x = orig.args[0]
+    a = const(0.5, dtype=x.checked_type.dtype)
+    return [grad * a * power(x, negative(a))]
 
 
 @register_gradient("sigmoid")
@@ -261,6 +262,13 @@ def collapse_sum_like_grad(orig, grad):
     return [broadcast_to_like(grad, x), zeros_like(y)]
 
 
+@register_gradient("collapse_sum_to")
+def collapse_sum_to_grad(orig, grad):
+    """Returns [broadcast_to_like(grad, x), 0]"""
+    x, y = orig.args
+    return [broadcast_to_like(grad, x), zeros_like(y)]
+
+
 @register_gradient("abs")
 def abs_grad(orig, grad):
     """Returns grad * (select(x < 0, -1, 1))."""
@@ -284,8 +292,8 @@ def clip_grad(orig, grad):
     x = orig.args[0]
     a_min = orig.attrs.get_int("a_min")
     a_max = orig.attrs.get_int("a_max")
-    a_mins = broadcast_to_like(const(a_min), x)
-    a_maxs = broadcast_to_like(const(a_max), x)
+    a_mins = broadcast_to_like(const(a_min, dtype=x.checked_type.dtype), x)
+    a_maxs = broadcast_to_like(const(a_max, dtype=x.checked_type.dtype), x)
     zeros = zeros_like(x)
     ones = ones_like(x)
     return [where(less(x, a_mins), zeros, where(less(a_maxs, x), zeros, ones * grad))]
@@ -591,7 +599,7 @@ def cross_entropy_grad(orig, grad):
     x, y = orig.args
     shape = shape_of(x)
     batch_size = take(shape, const(0, dtype='int32'), axis=0)
-    grad = grad / batch_size.astype('float32')
+    grad = grad / batch_size.astype(x.checked_type.dtype)
     return [-grad * y / x, -grad * log(x)]
 
 
@@ -600,5 +608,5 @@ def cross_entropy_with_logits_grad(orig, grad):
     x, y = orig.args
     shape = shape_of(x)
     batch_size = take(shape, const(0, dtype='int32'), axis=0)
-    grad = grad / batch_size.astype('float32')
+    grad = grad / batch_size.astype(x.checked_type.dtype)
     return [-grad * y, -grad * x]
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index d104c1b..10238d1 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -57,6 +57,7 @@ _reg.register_injective_schedule("gather_nd")
 _reg.register_injective_schedule("sequence_mask")
 _reg.register_injective_schedule("one_hot")
 _reg.register_reduce_schedule("collapse_sum_like")
+_reg.register_reduce_schedule("collapse_sum_to")
 _reg.register_injective_schedule("unravel_index")
 _reg.register_injective_schedule("sparse_to_dense")
 
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index a37226e..cc9f730 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -660,6 +660,27 @@ def collapse_sum_like(data, collapse_type):
     return _make.collapse_sum_like(data, collapse_type)
 
 
+def collapse_sum_to(data, shape):
+    """Return a summation of data to the specified shape.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input tensor.
+
+    shape : relay.Expr
+        Shape to collapse to.
+
+    Returns
+    -------
+    result : relay.Expr
+        The resulting tensor.
+    """
+    if isinstance(shape, (list, tuple)):
+        shape = const(list(shape), "int32")
+    return _make.collapse_sum_to(data, shape)
+
+
 def split(data, indices_or_sections, axis=0):
     """Split input tensor along axis by sections or indices.
 
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index ee5e291..a07fa9a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1713,6 +1713,54 @@ RELAY_REGISTER_OP("collapse_sum_like")
     .set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
     .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
+// CollapseSumTo: <A, B> -> B where Broadcast(A, B) = A
+bool CollapseSumToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                      const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const InitOpAttrs* param = attrs.as<InitOpAttrs>();
+  const auto* target_shape = types[1].as<TensorTypeNode>();
+  DataType out_dtype = types[0].as<TensorTypeNode>()->dtype;
+
+  const IntImmNode* shape_shape = target_shape->shape[0].as<IntImmNode>();
+  CHECK(shape_shape) << "Parameter shape must have static shape";
+
+  std::vector<IndexExpr> oshape;
+  if (param->shape) {
+    const Array<Integer>& cshape_array = param->shape.value();
+    for (size_t i = 0; i < cshape_array.size(); ++i) {
+      oshape.push_back(cshape_array[i]);
+    }
+  } else {
+    for (int i = 0; i < shape_shape->value; ++i) {
+      oshape.push_back(Any());
+    }
+  }
+  reporter->Assign(types[2], TensorType(oshape, out_dtype));
+  return BroadcastRel({types[0], types[2], types[0]}, 2, Attrs(), reporter);
+}
+
+Expr MakeCollapseSumTo(Expr data, Expr shape) {
+  static const Op& op = Op::Get("collapse_sum_to");
+  auto attrs = make_object<InitOpAttrs>();
+  if (const auto* cshape = shape.as<ConstantNode>()) {
+    attrs->shape = ToVector(cshape->data);
+  }
+  return Call(op, {data, shape}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.collapse_sum_to").set_body_typed(MakeCollapseSumTo);
+
+RELAY_REGISTER_OP("collapse_sum_to")
+    .describe(R"code(Broadcast the first input to match the shape argument.
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("shape", "Tensor", "Target shape.")
+    .set_support_level(4)
+    .add_type_rel("CollapseSumTo", CollapseSumToRel)
+    .set_attr<FTVMCompute>("FTVMCompute", CollapseSumLikeCompute)
+    .set_attr<TOpPattern>("TOpPattern", kCommReduce);
+
 // BroadCastTo: <A, B> -> B where BroadCast(A, B) = B
 bool BroadCastToRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py
index 85506e0..437901e 100644
--- a/tests/python/relay/test_op_grad_level1.py
+++ b/tests/python/relay/test_op_grad_level1.py
@@ -36,9 +36,8 @@ def relu(x):
 
 
 def test_unary_op():
-    def check_single_op(opfunc, ref):
+    def check_single_op(opfunc, ref, dtype):
         shape = (10, 4)
-        dtype = 'float32'
         tp = relay.TensorType(shape, dtype)
         x = relay.var("x", tp)
         y = opfunc(x)
@@ -76,16 +75,17 @@ def test_unary_op():
                         (tvm.relay.acosh, lambda x: 1./ (x**2 - 1.)**(1./2.)),
                         (tvm.relay.asinh, lambda x: 1./ (x**2 + 1.)**(1./2.)),
                         (tvm.relay.atanh, lambda x: -1./ (x**2 - 1.))]:
-        check_single_op(opfunc, ref)
+        for dtype in ('float32', 'float64'):
+            check_single_op(opfunc, ref, dtype)
 
 
 def test_binary_op():
     def inst(vars, sh):
         return [vars.get(s, s) for s in sh]
 
-    def check_binary_op(opfunc, ref):
+    def check_binary_op(opfunc, ref, dtype):
         s = (5, 10, 5)
-        t = relay.TensorType((5, 10, 5))
+        t = relay.TensorType((5, 10, 5), dtype=dtype)
         x = relay.var("x", t)
         y = relay.var("y", t)
         z = opfunc(x, y)
@@ -107,7 +107,8 @@ def test_binary_op():
                         (relay.subtract, lambda x, y: [np.ones_like(x), -np.ones_like(y)]),
                         (relay.multiply, lambda x, y: [y, x]),
                         (relay.divide, lambda x, y: [1 / y, - x / (y**2)])]:
-        check_binary_op(opfunc, ref)
+        for dtype in ('float32', 'float64'):
+            check_binary_op(opfunc, ref, dtype)
 
 
 def test_softmax_grad():
diff --git a/tests/python/relay/test_op_grad_level10.py b/tests/python/relay/test_op_grad_level10.py
index 6e64999..2c749c9 100644
--- a/tests/python/relay/test_op_grad_level10.py
+++ b/tests/python/relay/test_op_grad_level10.py
@@ -21,15 +21,17 @@ from tvm.relay.testing import check_grad
 
 
 def test_cross_entropy_grad():
-    x = relay.var("x", shape=(2, 5))
-    y = relay.var("y", shape=(2, 5))
-    check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1)
+    for dtype in ('float32', 'float64'):
+        x = relay.var("x", shape=(2, 5), dtype=dtype)
+        y = relay.var("y", shape=(2, 5), dtype=dtype)
+        check_grad(relay.Function([x, y], relay.op.nn.cross_entropy(x, y)), eps=0.01, scale=0.1, mean=1)
 
 
 def test_cross_entropy_with_logits_grad():
-    x = relay.var("x", shape=(2, 5))
-    y = relay.var("y", shape=(2, 5))
-    check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
+    for dtype in ('float32', 'float64'):
+        x = relay.var("x", shape=(2, 5), dtype=dtype)
+        y = relay.var("y", shape=(2, 5), dtype=dtype)
+        check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
     
 def test_checkpoint():
     inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)]
diff --git a/tests/python/relay/test_op_grad_level3.py b/tests/python/relay/test_op_grad_level3.py
index b1d0e25..8ca1eae 100644
--- a/tests/python/relay/test_op_grad_level3.py
+++ b/tests/python/relay/test_op_grad_level3.py
@@ -25,21 +25,22 @@ from tvm.relay.transform import gradient
 
 
 def test_clip():
-    ref = (lambda x: np.where(x > 10.0, np.zeros_like(x),
-                     np.where(x < 1.0, np.zeros_like(x), np.ones_like(x))))
-    x = relay.var("x", relay.TensorType((10, 4), "float32"))
-    y = tvm.relay.clip(x, 1.0, 10.0)
-
-    data = np.random.rand(10, 4).astype("float32") * 11.0
-    ref_grad = ref(data)
-    fwd_func = relay.Function([x], y)
-    fwd_func = run_infer_type(fwd_func)
-    bwd_func = run_infer_type(gradient(fwd_func))
-
-    for target, ctx in ctx_list():
-        intrp = relay.create_executor(ctx=ctx, target=target)
-        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
-        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
+    for dtype in ('float32', 'float64'):
+        ref = (lambda x: np.where(x > 10.0, np.zeros_like(x),
+                         np.where(x < 1.0, np.zeros_like(x), np.ones_like(x))))
+        x = relay.var("x", relay.TensorType((10, 4), dtype))
+        y = tvm.relay.clip(x, 1.0, 10.0)
+
+        data = np.random.rand(10, 4).astype(dtype) * 11.0
+        ref_grad = ref(data)
+        fwd_func = relay.Function([x], y)
+        fwd_func = run_infer_type(fwd_func)
+        bwd_func = run_infer_type(gradient(fwd_func))
+
+        for target, ctx in ctx_list():
+            intrp = relay.create_executor(ctx=ctx, target=target)
+            op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+            np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
 
 
 def verify_transpose_grad(d_shape, axes=None):
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 5e57c80..7528267 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -168,6 +168,26 @@ def test_collapse_sum_like():
             op_res = intrp.evaluate(func)(x, y)
             tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
 
+
+def test_collapse_sum_to():
+    shape = (3, 4, 5, 6)
+    shape_to = (4, 5, 6)
+    dtype = "float32"
+    x = relay.Var("x", relay.ty.TensorType(shape , dtype))
+    z = relay.collapse_sum_to(x, shape_to)
+    zz = run_infer_type(z)
+    assert zz.checked_type == relay.ty.TensorType(shape_to, dtype)
+
+    func = relay.Function([x], z)
+    x = np.random.uniform(size=shape).astype(dtype)
+    ref_res = np.sum(x, 0)
+    for target, ctx in ctx_list():
+        for kind in ["graph", "debug"]:
+            intrp = relay.create_executor(kind, ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(x)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
+
 def test_broadcast_to():
     shape = (4, 1, 6)
     shape_like = (3, 4, 5, 6)
