This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cb3bf4014a [Unity][Training] Enhance op gradient (#14932)
cb3bf4014a is described below

commit cb3bf4014adfbb5d719cbf095dc2f805837ff2a5
Author: Yixin Dong <[email protected]>
AuthorDate: Sat May 27 22:56:11 2023 +0800

    [Unity][Training] Enhance op gradient (#14932)
---
 python/tvm/relax/op/_op_gradient.py           | 74 ++++++++++++++-------------
 tests/python/relax/test_transform_gradient.py |  6 +--
 2 files changed, 41 insertions(+), 39 deletions(-)

diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py
index 36aed832e7..d57bcc8621 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-argument, redefined-builtin
+# pylint: disable=unused-argument, redefined-builtin, invalid-name
 """Gradient definitions for Relax operators."""
 import functools
 import operator
@@ -23,6 +23,7 @@ from typing import List
 from tvm import relax
 from tvm._ffi.base import TVMError
 from tvm.arith import Analyzer
+from tvm.relax.struct_info import ShapeStructInfo
 
 from ..block_builder import BlockBuilder
 from ..expr import Call, Var, Expr, ShapeExpr
@@ -55,7 +56,7 @@ from .manipulate import (
 from .nn import conv2d_transpose, conv2d
 from .search import where
 from .statistical import sum, cumsum
-from .unary import cos, exp, log, sin, sqrt, sigmoid
+from .unary import cos, exp, log, sin, sigmoid
 
 
 # TODO(yixin, chaofan): handle symbolic shape for most of the gradients
@@ -86,31 +87,40 @@ def _get_dtype(expr: Expr) -> str:
     return dtype
 
 
-def _fit_shape(expr: Expr, expr_shape: ShapeExpr, target: Expr) -> Expr:
+def _fit_shape(bb: BlockBuilder, expr: Expr, target: Expr) -> Expr:
+    """When expr and target has the same shape, return expr;
+    otherwise return `collapse_sum_to(expr, target.struct_info.shape)`.
+
+    Will use BlockBuilder to normalize expr first.
+    """
     target_shape = _get_shape(target)
-    expr_sinfo = expr_shape.struct_info
+    expr_sinfo = _get_shape(bb.normalize(expr)).struct_info
     target_sinfo = target_shape.struct_info
-    assert isinstance(expr_sinfo, relax.ShapeStructInfo)
-    assert isinstance(target_sinfo, relax.ShapeStructInfo)
+    assert isinstance(expr_sinfo, ShapeStructInfo)
+    assert isinstance(target_sinfo, ShapeStructInfo)
 
-    def _check_shape_equal():
-        if len(expr_sinfo.values) != len(target_sinfo.values):
+    def _check_shape_equal(lhs: ShapeStructInfo, rhs: ShapeStructInfo):
+        if len(lhs.values) != len(rhs.values):
             return False
         analyzer = Analyzer()
-        for i, field in enumerate(expr_sinfo.values):
-            if not analyzer.can_prove_equal(field, target_sinfo.values[i]):
+        for i, field in enumerate(lhs.values):
+            if not analyzer.can_prove_equal(field, rhs.values[i]):
                 return False
         return True
 
-    return expr if _check_shape_equal() else collapse_sum_to(expr, target_shape)
+    return (
+        expr
+        if _check_shape_equal(expr_sinfo, target_sinfo)
+        else collapse_sum_to(expr, target_shape)
+    )
 
 
 def _get_shape_prod(expr, axis):
+    # Requires static shape
     shape = _get_shape(expr)
     if axis is None:
         return functools.reduce(operator.mul, (int(i) for i in shape), 1)
-    else:
-        return functools.reduce(operator.mul, (int(shape[int(i)]) for i in axis), 1)
+    return functools.reduce(operator.mul, (int(shape[int(i)]) for i in axis), 1)
 
 
 ##################### Binary #####################
@@ -131,10 +141,9 @@ def add_grad(
     Backward:
         Returns `[z_output_grad, z_grad]`.
     """
-    output_grad_shape = _get_shape(output_grad)
     return [
-        _fit_shape(output_grad, output_grad_shape, orig_call.args[0]),
-        _fit_shape(output_grad, output_grad_shape, orig_call.args[1]),
+        _fit_shape(ctx, output_grad, orig_call.args[0]),
+        _fit_shape(ctx, output_grad, orig_call.args[1]),
     ]
 
 
@@ -153,10 +162,9 @@ def subtract_grad(
     Backward:
         Returns `[z_output_grad, -z_grad]`.
     """
-    output_grad_shape = _get_shape(output_grad)
     return [
-        _fit_shape(output_grad, output_grad_shape, orig_call.args[0]),
-        _fit_shape(-output_grad, output_grad_shape, orig_call.args[1]),
+        _fit_shape(ctx, output_grad, orig_call.args[0]),
+        _fit_shape(ctx, -output_grad, orig_call.args[1]),
     ]
 
 
@@ -176,10 +184,9 @@ def multiply_grad(
         Returns `[z_grad * y, z_grad * x]`.
     """
     x, y = orig_call.args
-    output_grad_shape = _get_shape(output_grad)
     return [
-        _fit_shape(output_grad * y, output_grad_shape, x),
-        _fit_shape(output_grad * x, output_grad_shape, y),
+        _fit_shape(ctx, output_grad * y, x),
+        _fit_shape(ctx, output_grad * x, y),
     ]
 
 
@@ -199,10 +206,9 @@ def divide_grad(
         Returns `[z_grad / y,  -z_grad * z / y]`.
     """
     x, y = orig_call.args
-    output_grad_shape = _get_shape(output_grad)
     return [
-        _fit_shape(output_grad / y, output_grad_shape, x),
-        _fit_shape(-output_grad * orig_var / y, output_grad_shape, y),
+        _fit_shape(ctx, output_grad / y, x),
+        _fit_shape(ctx, -output_grad * orig_var / y, y),
     ]
 
 
@@ -224,11 +230,10 @@ def power_grad(
         The gradient w.r.t. the second parameter, y, makes sense only when x > 0.
     """
     x, y = orig_call.args
-    output_grad_shape = _get_shape(output_grad)
     one = relax.const(1, _get_dtype(y))
     return [
-        _fit_shape(output_grad * y * (x ** (y - one)), output_grad_shape, x),
-        _fit_shape(output_grad * orig_var * log(x), output_grad_shape, y),
+        _fit_shape(ctx, output_grad * y * (x ** (y - one)), x),
+        _fit_shape(ctx, output_grad * orig_var * log(x), y),
     ]
 
 
@@ -580,11 +585,11 @@ def sqrt_grad(
         `y = relax.sqrt(x)`
 
     Backward:
-        Returns `[0.5 * y_grad / sqrt(x)]`.
+        Returns `[0.5 * y_grad / y]`.
     """
     x = orig_call.args[0]
     cst = relax.const(0.5, _get_dtype(x))
-    return [cst * output_grad / sqrt(x)]
+    return [cst * output_grad / orig_var]
 
 
 @register_gradient("relax.tanh")
@@ -714,8 +719,7 @@ def permute_dims_grad(
         for i in range(dims):
             new_axes[int(axes[i])] = i
         return [permute_dims(output_grad, axes=new_axes)]
-    else:
-        return [permute_dims(output_grad)]
+    return [permute_dims(output_grad)]
 
 
 @register_gradient("relax.concat")
@@ -983,11 +987,9 @@ def matmul_grad(
         a_grad = output_grad * tensor_b
         b_grad = output_grad * tensor_a
 
-    output_grad_shape = _get_shape(output_grad)
-
     return [
-        _fit_shape(a_grad, output_grad_shape, tensor_a),
-        _fit_shape(b_grad, output_grad_shape, tensor_b),
+        _fit_shape(ctx, a_grad, tensor_a),
+        _fit_shape(ctx, b_grad, tensor_b),
     ]
 
 
diff --git a/tests/python/relax/test_transform_gradient.py b/tests/python/relax/test_transform_gradient.py
index 1b3d174c13..50063fe385 100644
--- a/tests/python/relax/test_transform_gradient.py
+++ b/tests/python/relax/test_transform_gradient.py
@@ -1138,9 +1138,9 @@ def test_mlp_script():
                 lv4: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv3)
                 out_adjoint: R.Tensor((3, 5), dtype="float32") = R.subtract(logits_adjoint, lv4)
                 lv0_adjoint: R.Tensor((3, 5), dtype="float32") = out_adjoint
-                lv5: R.Tensor((10, 3), dtype="float32") = R.permute_dims(x, axes=[1, 0])
-                lv6: R.Tensor((10, 5), dtype="float32") = R.matmul(lv5, lv0_adjoint, out_dtype="void")
-                w0_adjoint: R.Tensor((10, 5), dtype="float32") = R.collapse_sum_to(lv6, R.shape([10, 5]))
+                lv5: R.Tensor((5, 10), dtype="float32") = R.permute_dims(w0, axes=[1, 0])
+                lv6: R.Tensor((10, 3), dtype="float32") = R.permute_dims(x, axes=[1, 0])
+                w0_adjoint: R.Tensor((10, 5), dtype="float32") = R.matmul(lv6, lv0_adjoint, out_dtype="void")
                 b0_adjoint: R.Tensor((5,), dtype="float32") = R.collapse_sum_to(out_adjoint, R.shape([5]))
                 R.output(loss, w0_adjoint, b0_adjoint)
             return (loss, (w0_adjoint, b0_adjoint))

Reply via email to