This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cb3bf4014a [Unity][Training] Enhance op gradient (#14932)
cb3bf4014a is described below
commit cb3bf4014adfbb5d719cbf095dc2f805837ff2a5
Author: Yixin Dong <[email protected]>
AuthorDate: Sat May 27 22:56:11 2023 +0800
[Unity][Training] Enhance op gradient (#14932)
---
python/tvm/relax/op/_op_gradient.py | 74 ++++++++++++++-------------
tests/python/relax/test_transform_gradient.py | 6 +--
2 files changed, 41 insertions(+), 39 deletions(-)
diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py
index 36aed832e7..d57bcc8621 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=unused-argument, redefined-builtin
+# pylint: disable=unused-argument, redefined-builtin, invalid-name
"""Gradient definitions for Relax operators."""
import functools
import operator
@@ -23,6 +23,7 @@ from typing import List
from tvm import relax
from tvm._ffi.base import TVMError
from tvm.arith import Analyzer
+from tvm.relax.struct_info import ShapeStructInfo
from ..block_builder import BlockBuilder
from ..expr import Call, Var, Expr, ShapeExpr
@@ -55,7 +56,7 @@ from .manipulate import (
from .nn import conv2d_transpose, conv2d
from .search import where
from .statistical import sum, cumsum
-from .unary import cos, exp, log, sin, sqrt, sigmoid
+from .unary import cos, exp, log, sin, sigmoid
# TODO(yixin, chaofan): handle symbolic shape for most of the gradients
@@ -86,31 +87,40 @@ def _get_dtype(expr: Expr) -> str:
return dtype
-def _fit_shape(expr: Expr, expr_shape: ShapeExpr, target: Expr) -> Expr:
+def _fit_shape(bb: BlockBuilder, expr: Expr, target: Expr) -> Expr:
+    """When expr and target have the same shape, return expr;
+ otherwise return `collapse_sum_to(expr, target.struct_info.shape)`.
+
+ Will use BlockBuilder to normalize expr first.
+ """
target_shape = _get_shape(target)
- expr_sinfo = expr_shape.struct_info
+ expr_sinfo = _get_shape(bb.normalize(expr)).struct_info
target_sinfo = target_shape.struct_info
- assert isinstance(expr_sinfo, relax.ShapeStructInfo)
- assert isinstance(target_sinfo, relax.ShapeStructInfo)
+ assert isinstance(expr_sinfo, ShapeStructInfo)
+ assert isinstance(target_sinfo, ShapeStructInfo)
- def _check_shape_equal():
- if len(expr_sinfo.values) != len(target_sinfo.values):
+ def _check_shape_equal(lhs: ShapeStructInfo, rhs: ShapeStructInfo):
+ if len(lhs.values) != len(rhs.values):
return False
analyzer = Analyzer()
- for i, field in enumerate(expr_sinfo.values):
- if not analyzer.can_prove_equal(field, target_sinfo.values[i]):
+ for i, field in enumerate(lhs.values):
+ if not analyzer.can_prove_equal(field, rhs.values[i]):
return False
return True
-    return expr if _check_shape_equal() else collapse_sum_to(expr, target_shape)
+ return (
+ expr
+ if _check_shape_equal(expr_sinfo, target_sinfo)
+ else collapse_sum_to(expr, target_shape)
+ )
def _get_shape_prod(expr, axis):
+ # Requires static shape
shape = _get_shape(expr)
if axis is None:
return functools.reduce(operator.mul, (int(i) for i in shape), 1)
- else:
-        return functools.reduce(operator.mul, (int(shape[int(i)]) for i in axis), 1)
+    return functools.reduce(operator.mul, (int(shape[int(i)]) for i in axis), 1)
##################### Binary #####################
@@ -131,10 +141,9 @@ def add_grad(
Backward:
Returns `[z_output_grad, z_grad]`.
"""
- output_grad_shape = _get_shape(output_grad)
return [
- _fit_shape(output_grad, output_grad_shape, orig_call.args[0]),
- _fit_shape(output_grad, output_grad_shape, orig_call.args[1]),
+ _fit_shape(ctx, output_grad, orig_call.args[0]),
+ _fit_shape(ctx, output_grad, orig_call.args[1]),
]
@@ -153,10 +162,9 @@ def subtract_grad(
Backward:
Returns `[z_output_grad, -z_grad]`.
"""
- output_grad_shape = _get_shape(output_grad)
return [
- _fit_shape(output_grad, output_grad_shape, orig_call.args[0]),
- _fit_shape(-output_grad, output_grad_shape, orig_call.args[1]),
+ _fit_shape(ctx, output_grad, orig_call.args[0]),
+ _fit_shape(ctx, -output_grad, orig_call.args[1]),
]
@@ -176,10 +184,9 @@ def multiply_grad(
Returns `[z_grad * y, z_grad * x]`.
"""
x, y = orig_call.args
- output_grad_shape = _get_shape(output_grad)
return [
- _fit_shape(output_grad * y, output_grad_shape, x),
- _fit_shape(output_grad * x, output_grad_shape, y),
+ _fit_shape(ctx, output_grad * y, x),
+ _fit_shape(ctx, output_grad * x, y),
]
@@ -199,10 +206,9 @@ def divide_grad(
Returns `[z_grad / y, -z_grad * z / y]`.
"""
x, y = orig_call.args
- output_grad_shape = _get_shape(output_grad)
return [
- _fit_shape(output_grad / y, output_grad_shape, x),
- _fit_shape(-output_grad * orig_var / y, output_grad_shape, y),
+ _fit_shape(ctx, output_grad / y, x),
+ _fit_shape(ctx, -output_grad * orig_var / y, y),
]
@@ -224,11 +230,10 @@ def power_grad(
    The gradient w.r.t. the second parameter, y, makes sense only when x > 0.
"""
x, y = orig_call.args
- output_grad_shape = _get_shape(output_grad)
one = relax.const(1, _get_dtype(y))
return [
- _fit_shape(output_grad * y * (x ** (y - one)), output_grad_shape, x),
- _fit_shape(output_grad * orig_var * log(x), output_grad_shape, y),
+ _fit_shape(ctx, output_grad * y * (x ** (y - one)), x),
+ _fit_shape(ctx, output_grad * orig_var * log(x), y),
]
@@ -580,11 +585,11 @@ def sqrt_grad(
`y = relax.sqrt(x)`
Backward:
- Returns `[0.5 * y_grad / sqrt(x)]`.
+ Returns `[0.5 * y_grad / y]`.
"""
x = orig_call.args[0]
cst = relax.const(0.5, _get_dtype(x))
- return [cst * output_grad / sqrt(x)]
+ return [cst * output_grad / orig_var]
@register_gradient("relax.tanh")
@@ -714,8 +719,7 @@ def permute_dims_grad(
for i in range(dims):
new_axes[int(axes[i])] = i
return [permute_dims(output_grad, axes=new_axes)]
- else:
- return [permute_dims(output_grad)]
+ return [permute_dims(output_grad)]
@register_gradient("relax.concat")
@@ -983,11 +987,9 @@ def matmul_grad(
a_grad = output_grad * tensor_b
b_grad = output_grad * tensor_a
- output_grad_shape = _get_shape(output_grad)
-
return [
- _fit_shape(a_grad, output_grad_shape, tensor_a),
- _fit_shape(b_grad, output_grad_shape, tensor_b),
+ _fit_shape(ctx, a_grad, tensor_a),
+ _fit_shape(ctx, b_grad, tensor_b),
]
diff --git a/tests/python/relax/test_transform_gradient.py b/tests/python/relax/test_transform_gradient.py
index 1b3d174c13..50063fe385 100644
--- a/tests/python/relax/test_transform_gradient.py
+++ b/tests/python/relax/test_transform_gradient.py
@@ -1138,9 +1138,9 @@ def test_mlp_script():
lv4: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv3)
                out_adjoint: R.Tensor((3, 5), dtype="float32") = R.subtract(logits_adjoint, lv4)
lv0_adjoint: R.Tensor((3, 5), dtype="float32") = out_adjoint
-                lv5: R.Tensor((10, 3), dtype="float32") = R.permute_dims(x, axes=[1, 0])
-                lv6: R.Tensor((10, 5), dtype="float32") = R.matmul(lv5, lv0_adjoint, out_dtype="void")
-                w0_adjoint: R.Tensor((10, 5), dtype="float32") = R.collapse_sum_to(lv6, R.shape([10, 5]))
+                lv5: R.Tensor((5, 10), dtype="float32") = R.permute_dims(w0, axes=[1, 0])
+                lv6: R.Tensor((10, 3), dtype="float32") = R.permute_dims(x, axes=[1, 0])
+                w0_adjoint: R.Tensor((10, 5), dtype="float32") = R.matmul(lv6, lv0_adjoint, out_dtype="void")
                b0_adjoint: R.Tensor((5,), dtype="float32") = R.collapse_sum_to(out_adjoint, R.shape([5]))
R.output(loss, w0_adjoint, b0_adjoint)
return (loss, (w0_adjoint, b0_adjoint))
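
Background on the `_fit_shape` change above: the binary gradients all route their results through `_fit_shape`, which falls back to `collapse_sum_to` when the forward op broadcast one of its operands. A minimal NumPy sketch of that reduction (illustrative only, not part of this commit; `collapse_sum_like` is a made-up stand-in for `relax.op.collapse_sum_to`):

import numpy as np

def collapse_sum_like(grad, ref):
    # Sum the upstream gradient over the axes that broadcasting introduced,
    # so that it matches the shape of the original operand `ref`.
    while grad.ndim > ref.ndim:
        grad = grad.sum(axis=0)
    for axis, size in enumerate(ref.shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

x = np.ones((3, 5))
y = np.ones((1, 5))          # y is broadcast over axis 0 in z = x + y
z_grad = np.ones((3, 5))     # upstream gradient has the broadcast shape
y_grad = collapse_sum_like(z_grad, y)
assert y_grad.shape == y.shape     # (1, 5)
assert (y_grad == 3).all()         # each element accumulated 3 times

In the commit itself, `_fit_shape` emits `collapse_sum_to` only when the arithmetic analyzer cannot prove the two shapes equal; otherwise the gradient expression is returned unchanged.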