yongwww commented on code in PR #14527:
URL: https://github.com/apache/tvm/pull/14527#discussion_r1160843113


##########
python/tvm/relax/op/_op_gradient.py:
##########
@@ -0,0 +1,1199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument, redefined-builtin
+"""Gradient definitions for Relax operators."""
+import functools
+import operator
+from typing import List
+
+from tvm import relax
+from tvm._ffi.base import TVMError
+from tvm.arith import Analyzer
+
+from ..block_builder import BlockBuilder
+from ..expr import Call, Var, Expr, ShapeExpr
+from ...tir import PrimExpr
+
+from .base import register_gradient
+from .binary import less
+from .datatype import astype
+from .grad import (
+    no_grad,
+    nll_loss_backward,
+    max_pool2d_backward,
+    avg_pool2d_backward,
+    take_backward,
+)
+from .index import strided_slice
+from .linear_algebra import matmul
+from .manipulate import (
+    collapse_sum_to,
+    broadcast_to,
+    permute_dims,
+    expand_dims,
+    concat,
+    reshape,
+    split,
+    squeeze,
+    cumsum,
+    flatten,
+)
+from .nn import conv2d_transpose, conv2d
+from .search import where
+from .statistical import sum
+from .unary import cos, exp, log, sin, sqrt
+
+
+# TODO(yixin, chaofan): handle symbolic shape for most of the gradients
+
+
+##################### Utilities #####################
+
+
+def _get_shape(expr: Expr) -> ShapeExpr:
+    """Get the shape from a Tensor expr."""
+    try:
+        shape = expr.struct_info.shape
+    except Exception as error:
+        raise TVMError(
+            f"Get the shape of {expr} failed. Please normalize it first and 
ensure it is a Tensor."
+        ) from error
+    return shape
+
+
+def _get_dtype(expr: Expr) -> str:
+    """Get the dtype from a Tensor expr."""
+    try:
+        dtype = expr.struct_info.dtype
+    except Exception as error:
+        raise TVMError(
+            f"Get the dtype of {expr} failed. Please normalize it first and 
ensure it is a Tensor."
+        ) from error
+    return dtype
+
+
+def _fit_shape(expr: Expr, expr_shape: ShapeExpr, target: Expr) -> Expr:
+    target_shape = _get_shape(target)
+    expr_sinfo = expr_shape.struct_info
+    target_sinfo = target_shape.struct_info
+    assert isinstance(expr_sinfo, relax.ShapeStructInfo)
+    assert isinstance(target_sinfo, relax.ShapeStructInfo)
+
+    def _check_shape_equal():
+        if len(expr_sinfo.values) != len(target_sinfo.values):
+            return False
+        analyzer = Analyzer()
+        for i, field in enumerate(expr_sinfo.values):
+            if not analyzer.can_prove_equal(field, target_sinfo.values[i]):
+                return False
+        return True
+
+    return expr if _check_shape_equal() else collapse_sum_to(expr, 
target_shape)
+
+
+def _get_shape_prod(expr, axis):
+    shape = _get_shape(expr)
+    if axis is None:
+        return functools.reduce(operator.mul, (int(i) for i in shape), 1)
+    else:
+        return functools.reduce(operator.mul, (int(shape[int(i)]) for i in 
axis), 1)
+
+
+##################### Binary #####################
+
+
+@register_gradient("relax.add")
+def add_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of add.
+
+    Forward Form:
+        `z = relax.add(x, y)`
+
+    Backward:
+        Returns `[z_output_grad, z_grad]`.
+    """
+    output_grad_shape = _get_shape(output_grad)
+    return [
+        _fit_shape(output_grad, output_grad_shape, orig_call.args[0]),
+        _fit_shape(output_grad, output_grad_shape, orig_call.args[1]),
+    ]
+
+
+@register_gradient("relax.subtract")
+def subtract_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of subtract.
+
+    Forward Form:
+        `z = relax.subtract(x, y)`
+
+    Backward:
+        Returns `[z_output_grad, -z_grad]`.
+    """
+    output_grad_shape = _get_shape(output_grad)
+    return [
+        _fit_shape(output_grad, output_grad_shape, orig_call.args[0]),
+        _fit_shape(-output_grad, output_grad_shape, orig_call.args[1]),
+    ]
+
+
+@register_gradient("relax.multiply")
+def multiply_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of multiply.
+
+    Forward Form:
+        `z = relax.multiply(x, y)`
+
+    Backward:
+        Returns `[z_grad * y, z_grad * x]`.
+    """
+    x, y = orig_call.args
+    output_grad_shape = _get_shape(output_grad)
+    return [
+        _fit_shape(output_grad * y, output_grad_shape, x),
+        _fit_shape(output_grad * x, output_grad_shape, y),
+    ]
+
+
+@register_gradient("relax.divide")
+def divide_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of divide.
+
+    Forward Form:
+        `z = relax.divide(x, y)`
+
+    Backward:
+        Returns `[z_grad / y,  -z_grad * z / y]`.
+    """
+    x, y = orig_call.args
+    output_grad_shape = _get_shape(output_grad)
+    return [
+        _fit_shape(output_grad / y, output_grad_shape, x),
+        _fit_shape(-output_grad * orig_var / y, output_grad_shape, y),
+    ]
+
+
+@register_gradient("relax.power")
+def power_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of power.
+
+    Forward Form:
+        `z = relax.power(x, y)`
+
+    Backward:
+        Returns `[y * x ** (y-1) * z_grad, z * ln(x) * z_grad]`.
+
+        The gradient w.r.t. the second parameter, y, makes sense only when x > 
0.
+    """
+    x, y = orig_call.args
+    output_grad_shape = _get_shape(output_grad)
+    one = relax.const(1, _get_dtype(y))
+    return [
+        _fit_shape(output_grad * y * (x ** (y - one)), output_grad_shape, x),
+        _fit_shape(output_grad * orig_var * log(x), output_grad_shape, y),
+    ]
+
+
+##################### Binary Comparison #####################
+# For comparison operators, the gradients are no_grad
+
+
+@register_gradient("relax.equal")
+def equal_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.greater")
+def greater_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.greater_equal")
+def greater_equal_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.less")
+def less_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.less_equal")
+def less_equal_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.not_equal")
+def not_equal_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+##################### Create #####################
+# For create operators, the gradients are no_grad.
+
+
+@register_gradient("relax.zeros_like")
+def zeros_like_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0])]
+
+
+@register_gradient("relax.ones_like")
+def ones_like_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0])]
+
+
+@register_gradient("relax.full_like")
+def full_like_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.zeros")
+def zeros_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0])]
+
+
+@register_gradient("relax.ones")
+def ones_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0])]
+
+
+@register_gradient("relax.full")
+def full_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+##################### Unary #####################
+
+
+@register_gradient("relax.abs")
+def abs_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of abs.
+
+    Forward Form:
+        `y = relax.abs(x)`
+
+    Backward:
+        Returns `[y_grad * where(x < 0, -1, 1)]`.
+    """
+    x = orig_call.args[0]
+    zero = relax.const(0, _get_dtype(x))
+    one = relax.const(1, _get_dtype(x))
+    return [output_grad * where(less(x, zero), -one, one)]
+
+
+@register_gradient("relax.cos")
+def cos_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of cos.
+
+    Forward Form:
+        `y = relax.cos(x)`
+
+    Backward:
+        Returns `[-y_grad * sin(x)]`.
+    """
+    return [-output_grad * sin(orig_call.args[0])]
+
+
+@register_gradient("relax.exp")
+def exp_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of exp.
+
+    Forward Form:
+        `y = relax.exp(x)`
+
+    Backward:
+        Returns `[y_grad * y]`.
+    """
+    return [output_grad * orig_var]
+
+
+@register_gradient("relax.log")
+def log_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of log.
+
+    Forward Form:
+        `y = relax.log(x)`
+
+    Backward:
+        Returns `[y_grad / x]`.
+    """
+    return [output_grad / orig_call.args[0]]
+
+
+@register_gradient("relax.negative")
+def negative_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of negative.
+
+    Forward Form:
+        `y = relax.negative(x)`
+
+    Backward:
+        Returns `[-y_grad]`.
+    """
+    return [-output_grad]
+
+
+@register_gradient("relax.sigmoid")
+def sigmoid_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of sigmoid.
+
+    Forward Form:
+        `y = relax.sigmoid(x)`
+
+    Backward:
+        Returns `[y_grad * y * (1 - y)]`.
+    """
+    one = relax.const(1, _get_dtype(orig_call.args[0]))
+    return [output_grad * orig_var * (one - orig_var)]
+
+
+@register_gradient("relax.sin")
+def sin_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of sin.
+
+    Forward Form:
+        `y = relax.sin(x)`
+
+    Backward:
+        Returns `[y_grad * cos(x)]`.
+    """
+    return [output_grad * cos(orig_call.args[0])]
+
+
+@register_gradient("relax.sqrt")
+def sqrt_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of sqrt.
+
+    Forward Form:
+        `y = relax.sqrt(x)`
+
+    Backward:
+        Returns `[0.5 * y_grad / sqrt(x)]`.
+    """
+    x = orig_call.args[0]
+    cst = relax.const(0.5, _get_dtype(x))
+    return [cst * output_grad / sqrt(x)]
+
+
+@register_gradient("relax.tanh")
+def tanh_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of tanh.
+
+    Forward Form:
+        `y = relax.tanh(x)`
+
+    Backward:
+        Returns `[y_grad * (1 - y * y)]`.
+    """
+    one = relax.const(1, _get_dtype(orig_call.args[0]))
+    return [output_grad * (one - orig_var * orig_var)]
+
+
+##################### Statistical #####################
+
+
+@register_gradient("relax.sum")
+def sum_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of sum.
+
+    Forward Form:
+        `y = relax.sum(x, axis, keepdims)`
+
+    Backward:
+        Returns `[broadcast_to(y_output_grad, x.shape)]`.
+
+        If `keepdims=False`, the summed axis will be added back.
+    """
+    axis = orig_call.attrs.axis
+    keepdims = orig_call.attrs.keepdims
+    if not keepdims and axis:
+        output_grad = expand_dims(output_grad, axis)
+    return [broadcast_to(output_grad, _get_shape(orig_call.args[0]))]
+
+
+@register_gradient("relax.mean")
+def mean_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of sum.

Review Comment:
   Typo: this docstring should read "Gradient of mean." rather than "Gradient of sum."



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to