This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new d5182388a5 [Unity][Op] Gradient functions for high-level Relax
operators (#14527)
d5182388a5 is described below
commit d5182388a5254287f7a82838880129f2ea4f4a21
Author: Chaofan Lin <[email protected]>
AuthorDate: Sat Apr 8 20:48:09 2023 +0800
[Unity][Op] Gradient functions for high-level Relax operators (#14527)
This PR registers gradient functions for many high-level Relax operators.
Similar to Relay, the gradient function is registered as an attribute
`FPrimalGradient` (OpAttr) of corresponding Relax operators. But the function
signature is different from Relay:
```
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(
const Var& orig_var, const Call& orig_call, const Var& output_grad,
const BlockBuilder& ctx)>;
```
- `orig_call` is the original call expr which we want to differentiate.
- `output_grad` is the gradient of RHS.
- `orig_var` is `y`. It is passed in to save some calculations.
- `ctx` is the context which is not used right now. But we believe it is
useful when it comes to dynamic shape cases and when we need to emit some
bindings or do some normalizations.
Also this PR fixes two small problems about op:
- `CumsumAttrs` isn't declared in the Python side.
- A small problem in the implementation about legalizing op `variance`.
Co-authored-by: Yixin Dong <[email protected]>
---
include/tvm/relax/op_attr_types.h | 12 +
python/tvm/relax/op/__init__.py | 4 +
python/tvm/relax/op/_op_gradient.py | 1199 ++++++++++++++++++++
python/tvm/relax/op/base.py | 26 +-
.../legalize_ops => op/grad}/__init__.py | 16 +-
.../__init__.py => op/grad/_ffi_api.py} | 16 +-
python/tvm/relax/op/grad/grad.py | 144 +++
python/tvm/relax/op/op_attrs.py | 5 +
.../tvm/relax/transform/legalize_ops/__init__.py | 1 +
python/tvm/relax/transform/legalize_ops/grad.py | 218 ++++
.../relax/transform/legalize_ops/statistical.py | 2 +-
python/tvm/script/ir_builder/relax/ir.py | 2 +
src/relax/op/tensor/grad.cc | 167 +++
src/relax/op/tensor/grad.h | 66 ++
tests/python/relax/test_op_grad.py | 96 ++
tests/python/relax/test_op_gradient_numeric.py | 794 +++++++++++++
.../relax/test_transform_legalize_ops_grad.py | 337 ++++++
...st_transform_legalize_ops_search_statistical.py | 239 ++--
.../python/relax/test_tvmscript_parser_op_grad.py | 142 +++
19 files changed, 3379 insertions(+), 107 deletions(-)
diff --git a/include/tvm/relax/op_attr_types.h
b/include/tvm/relax/op_attr_types.h
index 413d3e0499..64e5bd89a5 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -58,6 +58,18 @@ using FCallPacked = String;
*/
using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const
Call& call)>;
+/*!
+ * \brief Gradient for a specific op.
+ *
+ * \param orig_var the original var corresponding to orig_call.
+ * \param orig_call the original Call(op) expr.
+ * \param output_grad the gradient of the Expr.
+ * \param ctx the current block builder context.
+ * \return the gradient for each parameter.
+ */
+using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(
+ const Var& orig_var, const Call& orig_call, const Var& output_grad, const
BlockBuilder& ctx)>;
+
} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_OP_ATTR_TYPES_H_
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 39a645ffea..d9af245d79 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -32,10 +32,14 @@ from .set import *
from .ternary import *
from .unary import *
from . import builtin
+from . import grad
from . import image
from . import memory
from . import nn
+# Operator gradient functions
+from . import _op_gradient
+
def _register_op_make():
# pylint: disable=import-outside-toplevel
diff --git a/python/tvm/relax/op/_op_gradient.py
b/python/tvm/relax/op/_op_gradient.py
new file mode 100644
index 0000000000..93206ae1be
--- /dev/null
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -0,0 +1,1199 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-argument, redefined-builtin
+"""Gradient definitions for Relax operators."""
+import functools
+import operator
+from typing import List
+
+from tvm import relax
+from tvm._ffi.base import TVMError
+from tvm.arith import Analyzer
+
+from ..block_builder import BlockBuilder
+from ..expr import Call, Var, Expr, ShapeExpr
+from ...tir import PrimExpr
+
+from .base import register_gradient
+from .binary import less
+from .datatype import astype
+from .grad import (
+ no_grad,
+ nll_loss_backward,
+ max_pool2d_backward,
+ avg_pool2d_backward,
+ take_backward,
+)
+from .index import strided_slice
+from .linear_algebra import matmul
+from .manipulate import (
+ collapse_sum_to,
+ broadcast_to,
+ permute_dims,
+ expand_dims,
+ concat,
+ reshape,
+ split,
+ squeeze,
+ cumsum,
+ flatten,
+)
+from .nn import conv2d_transpose, conv2d
+from .search import where
+from .statistical import sum
+from .unary import cos, exp, log, sin, sqrt
+
+
+# TODO(yixin, chaofan): handle symbolic shape for most of the gradients
+
+
+##################### Utilities #####################
+
+
+def _get_shape(expr: Expr) -> ShapeExpr:
+ """Get the shape from a Tensor expr."""
+ try:
+ shape = expr.struct_info.shape
+ except Exception as error:
+ raise TVMError(
+ f"Get the shape of {expr} failed. Please normalize it first and
ensure it is a Tensor."
+ ) from error
+ return shape
+
+
+def _get_dtype(expr: Expr) -> str:
+ """Get the dtype from a Tensor expr."""
+ try:
+ dtype = expr.struct_info.dtype
+ except Exception as error:
+ raise TVMError(
+ f"Get the dtype of {expr} failed. Please normalize it first and
ensure it is a Tensor."
+ ) from error
+ return dtype
+
+
+def _fit_shape(expr: Expr, expr_shape: ShapeExpr, target: Expr) -> Expr:
+ target_shape = _get_shape(target)
+ expr_sinfo = expr_shape.struct_info
+ target_sinfo = target_shape.struct_info
+ assert isinstance(expr_sinfo, relax.ShapeStructInfo)
+ assert isinstance(target_sinfo, relax.ShapeStructInfo)
+
+ def _check_shape_equal():
+ if len(expr_sinfo.values) != len(target_sinfo.values):
+ return False
+ analyzer = Analyzer()
+ for i, field in enumerate(expr_sinfo.values):
+ if not analyzer.can_prove_equal(field, target_sinfo.values[i]):
+ return False
+ return True
+
+ return expr if _check_shape_equal() else collapse_sum_to(expr,
target_shape)
+
+
+def _get_shape_prod(expr, axis):
+ shape = _get_shape(expr)
+ if axis is None:
+ return functools.reduce(operator.mul, (int(i) for i in shape), 1)
+ else:
+ return functools.reduce(operator.mul, (int(shape[int(i)]) for i in
axis), 1)
+
+
+##################### Binary #####################
+
+
+@register_gradient("relax.add")
+def add_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of add.
+
+ Forward Form:
+ `z = relax.add(x, y)`
+
+ Backward:
+ Returns `[z_grad, z_grad]`.
+ """
+ output_grad_shape = _get_shape(output_grad)
+ return [
+ _fit_shape(output_grad, output_grad_shape, orig_call.args[0]),
+ _fit_shape(output_grad, output_grad_shape, orig_call.args[1]),
+ ]
+
+
+@register_gradient("relax.subtract")
+def subtract_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of subtract.
+
+ Forward Form:
+ `z = relax.subtract(x, y)`
+
+ Backward:
+ Returns `[z_grad, -z_grad]`.
+ """
+ output_grad_shape = _get_shape(output_grad)
+ return [
+ _fit_shape(output_grad, output_grad_shape, orig_call.args[0]),
+ _fit_shape(-output_grad, output_grad_shape, orig_call.args[1]),
+ ]
+
+
+@register_gradient("relax.multiply")
+def multiply_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of multiply.
+
+ Forward Form:
+ `z = relax.multiply(x, y)`
+
+ Backward:
+ Returns `[z_grad * y, z_grad * x]`.
+ """
+ x, y = orig_call.args
+ output_grad_shape = _get_shape(output_grad)
+ return [
+ _fit_shape(output_grad * y, output_grad_shape, x),
+ _fit_shape(output_grad * x, output_grad_shape, y),
+ ]
+
+
+@register_gradient("relax.divide")
+def divide_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of divide.
+
+ Forward Form:
+ `z = relax.divide(x, y)`
+
+ Backward:
+ Returns `[z_grad / y, -z_grad * z / y]`.
+ """
+ x, y = orig_call.args
+ output_grad_shape = _get_shape(output_grad)
+ return [
+ _fit_shape(output_grad / y, output_grad_shape, x),
+ _fit_shape(-output_grad * orig_var / y, output_grad_shape, y),
+ ]
+
+
+@register_gradient("relax.power")
+def power_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of power.
+
+ Forward Form:
+ `z = relax.power(x, y)`
+
+ Backward:
+ Returns `[y * x ** (y-1) * z_grad, z * ln(x) * z_grad]`.
+
+ The gradient w.r.t. the second parameter, y, makes sense only when x >
0.
+ """
+ x, y = orig_call.args
+ output_grad_shape = _get_shape(output_grad)
+ one = relax.const(1, _get_dtype(y))
+ return [
+ _fit_shape(output_grad * y * (x ** (y - one)), output_grad_shape, x),
+ _fit_shape(output_grad * orig_var * log(x), output_grad_shape, y),
+ ]
+
+
+##################### Binary Comparison #####################
+# For comparison operators, the gradients are no_grad
+
+
+@register_gradient("relax.equal")
+def equal_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.greater")
+def greater_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.greater_equal")
+def greater_equal_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.less")
+def less_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.less_equal")
+def less_equal_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.not_equal")
+def not_equal_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+##################### Create #####################
+# For create operators, the gradients are no_grad.
+
+
+@register_gradient("relax.zeros_like")
+def zeros_like_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0])]
+
+
+@register_gradient("relax.ones_like")
+def ones_like_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0])]
+
+
+@register_gradient("relax.full_like")
+def full_like_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+@register_gradient("relax.zeros")
+def zeros_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0])]
+
+
+@register_gradient("relax.ones")
+def ones_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0])]
+
+
+@register_gradient("relax.full")
+def full_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+
+
+##################### Unary #####################
+
+
+@register_gradient("relax.abs")
+def abs_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of abs.
+
+ Forward Form:
+ `y = relax.abs(x)`
+
+ Backward:
+ Returns `[y_grad * where(x < 0, -1, 1)]`.
+ """
+ x = orig_call.args[0]
+ zero = relax.const(0, _get_dtype(x))
+ one = relax.const(1, _get_dtype(x))
+ return [output_grad * where(less(x, zero), -one, one)]
+
+
+@register_gradient("relax.cos")
+def cos_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of cos.
+
+ Forward Form:
+ `y = relax.cos(x)`
+
+ Backward:
+ Returns `[-y_grad * sin(x)]`.
+ """
+ return [-output_grad * sin(orig_call.args[0])]
+
+
+@register_gradient("relax.exp")
+def exp_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of exp.
+
+ Forward Form:
+ `y = relax.exp(x)`
+
+ Backward:
+ Returns `[y_grad * y]`.
+ """
+ return [output_grad * orig_var]
+
+
+@register_gradient("relax.log")
+def log_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of log.
+
+ Forward Form:
+ `y = relax.log(x)`
+
+ Backward:
+ Returns `[y_grad / x]`.
+ """
+ return [output_grad / orig_call.args[0]]
+
+
+@register_gradient("relax.negative")
+def negative_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of negative.
+
+ Forward Form:
+ `y = relax.negative(x)`
+
+ Backward:
+ Returns `[-y_grad]`.
+ """
+ return [-output_grad]
+
+
+@register_gradient("relax.sigmoid")
+def sigmoid_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of sigmoid.
+
+ Forward Form:
+ `y = relax.sigmoid(x)`
+
+ Backward:
+ Returns `[y_grad * y * (1 - y)]`.
+ """
+ one = relax.const(1, _get_dtype(orig_call.args[0]))
+ return [output_grad * orig_var * (one - orig_var)]
+
+
+@register_gradient("relax.sin")
+def sin_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of sin.
+
+ Forward Form:
+ `y = relax.sin(x)`
+
+ Backward:
+ Returns `[y_grad * cos(x)]`.
+ """
+ return [output_grad * cos(orig_call.args[0])]
+
+
+@register_gradient("relax.sqrt")
+def sqrt_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of sqrt.
+
+ Forward Form:
+ `y = relax.sqrt(x)`
+
+ Backward:
+ Returns `[0.5 * y_grad / sqrt(x)]`.
+ """
+ x = orig_call.args[0]
+ cst = relax.const(0.5, _get_dtype(x))
+ return [cst * output_grad / sqrt(x)]
+
+
+@register_gradient("relax.tanh")
+def tanh_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of tanh.
+
+ Forward Form:
+ `y = relax.tanh(x)`
+
+ Backward:
+ Returns `[y_grad * (1 - y * y)]`.
+ """
+ one = relax.const(1, _get_dtype(orig_call.args[0]))
+ return [output_grad * (one - orig_var * orig_var)]
+
+
+##################### Statistical #####################
+
+
+@register_gradient("relax.sum")
+def sum_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of sum.
+
+ Forward Form:
+ `y = relax.sum(x, axis, keepdims)`
+
+ Backward:
+ Returns `[broadcast_to(y_output_grad, x.shape)]`.
+
+ If `keepdims=False`, the summed axis will be added back.
+ """
+ axis = orig_call.attrs.axis
+ keepdims = orig_call.attrs.keepdims
+ if not keepdims and axis:
+ output_grad = expand_dims(output_grad, axis)
+ return [broadcast_to(output_grad, _get_shape(orig_call.args[0]))]
+
+
+@register_gradient("relax.mean")
+def mean_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of mean.
+
+ Forward Form:
+ `y = relax.mean(x, axis, keepdims)`
+
+ Backward:
+ Returns `[broadcast_to(y_output_grad, x.shape) / prod(x.shape[i] for i
in axis)]`.
+
+ If `keepdims=False`, the meaned axis will be added back.
+ """
+ axis = orig_call.attrs.axis
+ keepdims = orig_call.attrs.keepdims
+ output_grad = output_grad / relax.const(
+ _get_shape_prod(orig_call.args[0], axis), _get_dtype(output_grad)
+ )
+ if not keepdims and axis:
+ output_grad = expand_dims(output_grad, axis)
+ return [broadcast_to(output_grad, _get_shape(orig_call.args[0]))]
+
+
+@register_gradient("relax.variance")
+def variance_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of variance.
+
+ Forward Form:
+ `y = relax.variance(x, axis, keepdims)`
+
+ Backward:
+ Returns `[y_grad * 2 / n * (x - mean(x, axis, keepdims=True))]`, where `n`
+ is the product of the lengths of the reduced axes.
+
+ If `keepdims=False`, the reduced axis will be added back.
+ """
+ x = orig_call.args[0]
+ axis = orig_call.attrs.axis
+ keepdims = orig_call.attrs.keepdims
+ shape_prod = _get_shape_prod(x, axis)
+ dtype = _get_dtype(x)
+ grad1 = relax.const(2.0 / shape_prod, dtype) * x
+ grad2 = relax.const(2.0 / shape_prod / shape_prod, dtype) * sum(x, axis,
keepdims=True)
+ if not keepdims and axis:
+ output_grad = expand_dims(output_grad, axis)
+ return [output_grad * (grad1 - grad2)]
+
+
+##################### Manipulate #####################
+
+
+@register_gradient("relax.permute_dims")
+def permute_dims_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of permute_dims.
+
+ Forward Form:
+ `y = relax.permute_dims(x, axes)`
+
+ Backward:
+ Returns grad transposed over the **inverse permutation** of the
original permute_dims axes.
+ """
+ axes = orig_call.attrs.axes
+ if axes:
+ dims = len(axes)
+ new_axes = [0] * dims
+ for i in range(dims):
+ new_axes[int(axes[i])] = i
+ return [permute_dims(output_grad, axes=new_axes)]
+ else:
+ return [permute_dims(output_grad)]
+
+
+@register_gradient("relax.concat")
+def concat_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of concat.
+
+ Forward Form:
+ `y = relax.concat((x1, x2, x3), axis)`
+
+ Backward:
+ Returns `[split(y_output_grad, [x1.shape[axis], x1.shape[axis] +
x2.shape[axis]], axis)]`.
+ """
+ axis = orig_call.attrs.axis
+ assert axis is not None
+ axis = int(axis)
+ split_indices: List[PrimExpr] = []
+ sinfo = orig_call.args[0].struct_info
+ assert isinstance(sinfo, relax.TupleStructInfo)
+ for i in range(len(sinfo.fields) - 1):
+ tensor_sinfo = sinfo.fields[i]
+ assert isinstance(tensor_sinfo, relax.TensorStructInfo)
+ assert tensor_sinfo.shape is not None
+ index = tensor_sinfo.shape[axis]
+ if i > 0:
+ index += split_indices[i - 1]
+ split_indices.append(index)
+ return [split(output_grad, split_indices, axis)]
+
+
+@register_gradient("relax.split")
+def split_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of split.
+
+ Forward Form:
+ `y = relax.split(x, indices, axis)`
+
+ Backward:
+ Returns `[concat(y_output_grad, axis)]`.
+ """
+ axis = orig_call.attrs.axis
+ axis = int(axis)
+ return [concat(output_grad, axis)]
+
+
+@register_gradient("relax.expand_dims")
+def expand_dims_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of expand_dims.
+
+ Forward Form:
+ `y = relax.expand_dims(x, axis)`
+
+ Backward:
+ Returns `[squeeze_dims(y_grad, axis)]`.
+ """
+ return [squeeze(output_grad, orig_call.attrs.axis)]
+
+
+@register_gradient("relax.reshape")
+def reshape_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of reshape.
+
+ Forward Form:
+ `y = relax.reshape(x, new_shape)`
+
+ Backward:
+ Returns `[reshape(y_grad, x.shape), no_grad]`.
+
+ The second parameter, the target ShapeExpr, is not differentiable.
+ """
+ return [
+ reshape(output_grad, _get_shape(orig_call.args[0])),
+ no_grad(orig_call.args[1]),
+ ]
+
+
+@register_gradient("relax.cumsum")
+def cumsum_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of cumsum.
+
+ Forward Form:
+ `y = relax.cumsum(x, axis)`
+
+ Backward:
+ The "reversed" cumsum along the same axis. Implemented with some tricks
for now.
+ """
+
+ axis = orig_call.attrs["axis"]
+ dtype = orig_call.attrs["dtype"]
+ x_shape = _get_shape(orig_call.args[0])
+
+ if axis is not None:
+ axis = int(axis)
+ grad = sum(output_grad, axis, keepdims=True) - cumsum(output_grad,
axis) + output_grad
+ else:
+ grad = reshape(
+ sum(output_grad, keepdims=True) - cumsum(output_grad) +
flatten(output_grad), x_shape
+ )
+
+ if dtype is not None:
+ grad = astype(grad, _get_dtype(orig_call.args[0]))
+
+ return [grad]
+
+
+##################### Index #####################
+
+
+@register_gradient("relax.take")
+def take_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of take.
+
+ Forward Form:
+ `y = relax.take(x, indices, axis)`
+
+ Backward:
+ Returns `[take_backward(y_grad, x, indices, axis), no_grad]`.
+
+ The second parameter, the indices, is not differentiable.
+ """
+
+ axis = orig_call.attrs["axis"]
+
+ return [
+ take_backward(output_grad, orig_call.args[0], orig_call.args[1], axis),
+ no_grad(orig_call.args[1]),
+ ]
+
+
+##################### Search #####################
+
+
+@register_gradient("relax.where")
+def where_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of where.
+
+ Forward Form:
+ `y = relax.where(cond, x1, x2)`
+
+ Backward:
+ Returns `[where(cond, y_grad, 0), where(cond, 0, y_grad)]`.
+
+ The first parameter, the condition, is not differentiable.
+ """
+
+ cond = orig_call.args[0]
+ x1_zero = relax.const(0, _get_dtype(orig_call.args[1]))
+ x2_zero = relax.const(0, _get_dtype(orig_call.args[2]))
+
+ return [
+ no_grad(orig_call.args[0]),
+ where(cond, output_grad, x1_zero),
+ where(cond, x2_zero, output_grad),
+ ]
+
+
+##################### Linear Algebra #####################
+
+
+@register_gradient("relax.matmul")
+def matmul_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of matmul.
+
+ Forward Form:
+ `c = relax.matmul(a, b)`
+
+ Backward:
+ Generally, returns `[c_grad @ b^T, a^T @ c_grad]`.
+
+ Here we only transpose the last two dimensions because of the
definition
+ of batch matmul. Note that ndim=1 should be treated specially.
+ """
+
+ tensor_a, tensor_b = orig_call.args
+
+ a_dim = len(_get_shape(tensor_a))
+ b_dim = len(_get_shape(tensor_b))
+
+ def _transpose_last_two_dim(tensor, ndim):
+ """Helper function for reversing the last two dimensions."""
+ assert ndim > 1
+ return permute_dims(
+ tensor, axes=[i if i < ndim - 2 else 2 * ndim - 3 - i for i in
range(ndim)]
+ )
+
+ if a_dim > 1 and b_dim > 1:
+ a_grad = matmul(output_grad, _transpose_last_two_dim(tensor_b, b_dim))
+ b_grad = matmul(_transpose_last_two_dim(tensor_a, a_dim), output_grad)
+ elif a_dim == 1 and b_dim > 1:
+ a_expand = expand_dims(tensor_a, 1)
+ grad_expand = expand_dims(output_grad, -2)
+ a_grad = matmul(grad_expand, _transpose_last_two_dim(tensor_b, b_dim))
+ b_grad = matmul(a_expand, grad_expand)
+ elif b_dim == 1 and a_dim > 1:
+ b_expand = expand_dims(tensor_b, 0)
+ grad_expand = expand_dims(output_grad, -1)
+ a_grad = matmul(grad_expand, b_expand)
+ b_grad = squeeze(
+ matmul(_transpose_last_two_dim(tensor_a, a_dim), grad_expand),
axis=-1
+ ) # squeeze last dim
+ else:
+ assert a_dim == 1 and b_dim == 1
+ a_grad = output_grad * tensor_b
+ b_grad = output_grad * tensor_a
+
+ output_grad_shape = _get_shape(output_grad)
+
+ return [
+ _fit_shape(a_grad, output_grad_shape, tensor_a),
+ _fit_shape(b_grad, output_grad_shape, tensor_b),
+ ]
+
+
+##################### Datatype #####################
+
+
+@register_gradient("relax.astype")
+def astype_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of astype.
+
+ Forward Form:
+ `y = relax.astype(x, dtype_of_y)`
+
+ Backward:
+ Returns `[astype(y_grad, dtype_of_x)]`.
+ """
+ return [astype(output_grad, _get_dtype(orig_call.args[0]))]
+
+
+##################### Neural network #####################
+
+
+@register_gradient("relax.nn.relu")
+def relu_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of relu.
+
+ Forward Form:
+ `y = relax.relu(x)`
+
+ Backward:
+ Returns `[y_grad * (where(x < 0, 0, 1))]`.
+ """
+ x = orig_call.args[0]
+ one = relax.const(1, _get_dtype(x))
+ zero = relax.const(0, _get_dtype(x))
+ return [where(less(x, zero), zero, one) * output_grad]
+
+
+@register_gradient("relax.nn.softmax")
+def softmax_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of softmax.
+
+ Forward Form:
+ `y = relax.softmax(x, axis)`
+
+ Backward:
+ Returns `[(y_grad - sum(y_grad * y, axis, keepdims=True)) * y]`
+ """
+ return [(output_grad - sum(output_grad * orig_var, orig_call.attrs.axis,
True)) * orig_var]
+
+
+@register_gradient("relax.nn.log_softmax")
+def log_softmax_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of log_softmax.
+
+ Forward Form:
+ `y = relax.log_softmax(x, axis)`
+
+ Backward:
+ Returns `[y_grad - sum(y_output_grad, axis, keepdims=True) *
softmax(x)]`
+ """
+ x_softmax = exp(orig_var)
+ return [(output_grad - sum(output_grad, orig_call.attrs.axis, True) *
x_softmax)]
+
+
+@register_gradient("relax.nn.cross_entropy_with_logits")
+def cross_entropy_with_logits_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of cross_entropy_with_logits.
+
+ Forward Form:
+ `z = relax.nn.cross_entropy_with_logits(x, y)`
+
+ Backward:
+ Returns `[-z_grad * y, -z_grad * x]`.
+ If it has batch_size N, the results should divide by N.
+ """
+ x, y = orig_call.args
+
+ if x.struct_info.ndim > 1:
+ batch_size = int(_get_shape(x)[0])
+ output_grad = output_grad / relax.const(batch_size,
_get_dtype(output_grad))
+
+ return [-output_grad * y, -output_grad * x]
+
+
+# TODO(chaofan, yixin): remove nll_loss_backward and register the gradient
using existing operators
+# This may require one_hot, strided_set, etc.
+@register_gradient("relax.nn.nll_loss")
+def nll_loss_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+):
+ """Gradient of nll_loss.
+
+ Forward Form:
+ `z = relax.nn.nll_loss(predictions, targets, weights, reduction,
ignore_index)`
+
+ Suppose that `out = nll_loss(predictions, targets, weights, "none",
ignore_index)`, and
+ `z = reduction(out)` where reduction is in `["none", "mean", "sum"]`.
+
+ Backward:
+ First find the gradient w.r.t. `out`. Assume it is `out_grad`.
+
+ Generally, the gradient w.r.t. predictions is
+
+ `predictions_grad[n, c, i_1, ..., i_k] = -o * w if c == t else 0`,
where
+ - `o = out_grad[n, i_1, ..., i_k]`,
+ - `w = weights[n, i_1, ..., i_k]`,
+ - `t = targets[n, i_1, ..., i_k]`.
+
+ Additional checks are added if `ignore_index >= 0`, `weights=None`, or
the predictions
+ provided do not have batch.
+
+ The gradient w.r.t. targets and weights are not available.
+ """
+ pred_grad = nll_loss_backward(
+ output_grad,
+ orig_call.args[0],
+ orig_call.args[1],
+ weights=orig_call.args[2] if len(orig_call.args) == 3 else None,
+ reduction=orig_call.attrs.reduction,
+ ignore_index=orig_call.attrs.ignore_index,
+ )
+ if len(orig_call.args) == 2:
+ return [pred_grad, no_grad(orig_call.args[1])]
+
+ return [pred_grad, no_grad(orig_call.args[1]), no_grad(orig_call.args[2])]
+
+
+@register_gradient("relax.nn.conv2d")
+def conv2d_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of conv2d. Now only supports `NCHW` data layout and `OIHW`
kernel layout.
+
+ Forward Form:
+ `y = relax.nn.conv2d(x, weight, strides, padding, dilation, groups,
data_layout, \
+kernel_layout, out_layout, out_dtype)`
+
+ Backward:
+ Returns `[x_grad, weight_grad]`
+ """
+ attrs = orig_call.attrs
+ assert attrs.data_layout == "NCHW", "only support NCHW data layout"
+ assert attrs.kernel_layout == "OIHW", "only support OIHW kernel layout"
+ assert attrs.out_layout == "NCHW", "only support NCHW output layout"
+
+ assert len(attrs.padding) == 4
+ assert len(attrs.strides) == 2
+ assert len(attrs.dilation) == 2
+
+ # calculate output_padding
+ data, weight = orig_call.args
+ batch, out_channel, grad_h, grad_w = _get_shape(orig_var)
+ _, in_channel, in_h, in_w = _get_shape(data)
+ _, _, filter_h, filter_w = _get_shape(weight)
+
+ pad_top, pad_left, pad_bottom, pad_right = attrs.padding
+ stride_h, stride_w = attrs.strides
+ dilation_h, dilation_w = attrs.dilation
+
+ out_h = (grad_h - 1) * stride_h - pad_top - pad_bottom + filter_h
+ out_w = (grad_w - 1) * stride_w - pad_left - pad_right + filter_w
+
+ output_padding = (in_h - out_h, in_w - out_w)
+
+ data_grad = conv2d_transpose( # type: ignore
+ output_grad,
+ orig_call.args[1],
+ attrs.strides,
+ attrs.padding,
+ output_padding,
+ attrs.dilation,
+ attrs.groups,
+ attrs.out_layout,
+ attrs.kernel_layout[1] + attrs.kernel_layout[0] +
attrs.kernel_layout[2:],
+ attrs.data_layout,
+ attrs.out_dtype,
+ )
+
+ if attrs.groups != 1:
+ data = reshape(data, (batch, attrs.groups, in_channel // attrs.groups,
in_h, in_w))
+ data = permute_dims(data, [1, 0, 2, 3, 4])
+ data = reshape(data, (batch * attrs.groups, in_channel //
attrs.groups, in_h, in_w))
+
+ weight_grad = conv2d(
+ data,
+ output_grad,
+ strides=attrs.dilation,
+ padding=attrs.padding,
+ dilation=attrs.strides,
+ groups=attrs.groups,
+ out_dtype=attrs.out_dtype,
+ data_layout="CNHW",
+ kernel_layout="IOHW",
+ out_layout="CNHW",
+ )
+
+ # infer shape of weight_grad
+ weight_grad_h = (in_h - (grad_h - 1) * stride_h - 1 + pad_top +
pad_bottom) // dilation_h + 1
+ weight_grad_w = (in_w - (grad_w - 1) * stride_w - 1 + pad_left +
pad_right) // dilation_w + 1
+
+ assert weight_grad_h >= filter_h
+ assert weight_grad_w >= filter_w
+
+ if weight_grad_h > filter_h or weight_grad_w > filter_w:
+ weight_grad = strided_slice(
+ weight_grad,
+ axes=[0, 1, 2, 3],
+ begin=[0, 0, 0, 0],
+ end=[out_channel, in_channel // attrs.groups, filter_h, filter_w],
+ )
+
+ return [data_grad, weight_grad]
+
+
+@register_gradient("relax.nn.max_pool2d")
+def max_pool2d_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+):
+ """Gradient of max_pool2d.
+
+ Forward Form:
+ `y = relax.nn.max_pool2d(x, pool_size, strides, padding, dilation,
ceil_mode, layout, \
+out_layout)`
+
+ Backward:
+ Returns `[x_grad]`
+ """
+ return [
+ max_pool2d_backward( # type: ignore
+ output_grad,
+ orig_call.args[0],
+ orig_call.attrs.pool_size,
+ orig_call.attrs.strides,
+ orig_call.attrs.padding,
+ orig_call.attrs.dilation,
+ orig_call.attrs.ceil_mode,
+ orig_call.attrs.layout,
+ orig_call.attrs.out_layout,
+ )
+ ]
+
+
+@register_gradient("relax.nn.avg_pool2d")
+def avg_pool2d_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+):
+ """Gradient of avg_pool2d.
+
+ Forward Form:
+ `y = relax.nn.avg_pool2d(x, pool_size, strides, padding, dilation,
ceil_mode, layout, \
+out_layout)`
+
+ Backward:
+ Returns `[x_grad]`
+ """
+ return [
+ avg_pool2d_backward( # type: ignore
+ output_grad,
+ orig_call.args[0],
+ orig_call.attrs.pool_size,
+ orig_call.attrs.strides,
+ orig_call.attrs.padding,
+ orig_call.attrs.dilation,
+ orig_call.attrs.ceil_mode,
+ orig_call.attrs.layout,
+ orig_call.attrs.out_layout,
+ )
+ ]
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index d6e8b29b6d..67f5f57070 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# pylint: disable=redefined-builtin
"""The base Relax operators."""
-from typing import Union, List, Tuple, Optional
+from typing import Union, List, Tuple, Optional, Callable
import tvm
@@ -23,7 +23,7 @@ import tvm.runtime
from tvm.runtime.object import Object
from . import _ffi_api
-from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar
+from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar, Var
from ..expr import Tuple as RxTuple
from ..struct_info import StructInfo, TensorStructInfo
from ...ir import PrimExpr
@@ -33,6 +33,28 @@ from ..utils import args_converter
py_print = print # pylint: disable=invalid-name
def register_gradient(
    op_name: str,
    fgradient: Optional[Callable[[Var, Call, Var, "BlockBuilder"], List[Expr]]] = None,
    level: int = 10,
):
    """Register operator gradient function for a relax operator.

    The gradient function is stored as the "FPrimalGradient" attribute of the op.

    Parameters
    ----------
    op_name: str
        The name of the op.

    fgradient: Optional[function (orig_var: Var, orig_call: Call, output_grad: Var, \
ctx: BlockBuilder) -> partials: List[Expr]]
        The gradient function being used. If it is None, the return value can be
        used as a decorator on the gradient function.

    level: int
        The priority level
    """
    return tvm.ir.register_op_attr(op_name, "FPrimalGradient", fgradient, level)
+
+
def null_value() -> Call:
"""Create a call node that represents a null value object.
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py
b/python/tvm/relax/op/grad/__init__.py
similarity index 71%
copy from python/tvm/relax/transform/legalize_ops/__init__.py
copy to python/tvm/relax/op/grad/__init__.py
index 3e57b815db..844b8ac381 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/op/grad/__init__.py
@@ -14,15 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Legalize high-level operator calls in Relax functions to call_tir."""
-from . import binary
-from . import creation
-from . import datatype
-from . import image
-from . import index
-from . import linear_algebra
-from . import manipulate
-from . import nn
-from . import search
-from . import statistical
-from . import unary
+# pylint: disable=wildcard-import, redefined-builtin
+"""Operators serving for finding gradient of relax operators."""
+
+from .grad import *
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py
b/python/tvm/relax/op/grad/_ffi_api.py
similarity index 71%
copy from python/tvm/relax/transform/legalize_ops/__init__.py
copy to python/tvm/relax/op/grad/_ffi_api.py
index 3e57b815db..9b819dd4df 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/op/grad/_ffi_api.py
@@ -14,15 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Legalize high-level operator calls in Relax functions to call_tir."""
-from . import binary
-from . import creation
-from . import datatype
-from . import image
-from . import index
-from . import linear_algebra
-from . import manipulate
-from . import nn
-from . import search
-from . import statistical
-from . import unary
+"""FFI APIs for tvm.relax.op.grad"""
+import tvm._ffi
+
+tvm._ffi._init_api("relax.op.grad", __name__)
diff --git a/python/tvm/relax/op/grad/grad.py b/python/tvm/relax/op/grad/grad.py
new file mode 100644
index 0000000000..b433dc9c60
--- /dev/null
+++ b/python/tvm/relax/op/grad/grad.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin
"""Operators to implement operator gradients. Used in `_op_gradient.py`.

We are trying to keep grad operators as simple as possible, and hope they are only used for finding
gradients for forward operators. The struct_info inference for grad operators just returns the
struct_info of the input.
"""
+from typing import Optional, Tuple
+
+from . import _ffi_api
+from ...expr import Expr
+
+
def no_grad(input: Expr) -> Expr:
    """No gradient dummy operator w.r.t. the input.

    Parameters
    ----------
    input : relax.Expr
        The corresponding input tensor.

    Returns
    -------
    result : relax.Expr
        The no-gradient representation w.r.t. input.
    """
    result = _ffi_api.no_grad(input)  # type: ignore
    return result
+
+
def nll_loss_backward(
    output_grad: Expr,
    predictions: Expr,
    targets: Expr,
    weights: Optional[Expr] = None,
    reduction: str = "mean",
    ignore_index: int = -100,
) -> Expr:
    """Backward operator of relax.nll_loss. All parameters except output_grad are the same as
    relax.nll_loss. Returns the gradient w.r.t. predictions.

    Parameters
    ----------
    output_grad : relax.Expr
        The gradient w.r.t. the result of nll_loss.

    predictions : relax.Expr
        The predictions passed to the forward nll_loss.

    targets : relax.Expr
        The targets passed to the forward nll_loss.

    weights : Optional[relax.Expr]
        The optional per-class weights passed to the forward nll_loss.

    reduction : str
        The reduction mode of the forward nll_loss.

    ignore_index : int
        The ignored target value of the forward nll_loss.

    Returns
    -------
    result : relax.Expr
        The gradient w.r.t. predictions.
    """
    return _ffi_api.nll_loss_backward(  # type: ignore
        output_grad, predictions, targets, weights, reduction, ignore_index
    )
+
+
def max_pool2d_backward(
    output_grad: Expr,
    data: Expr,
    pool_size: Tuple[int, int] = (1, 1),
    strides: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
    dilation: Tuple[int, int] = (1, 1),
    ceil_mode: bool = False,
    layout: str = "NCHW",
    out_layout: Optional[str] = None,
) -> Expr:
    """Backward operator of relax.max_pool2d. All parameters except output_grad are the same as
    relax.max_pool2d. Returns the gradient w.r.t. data.

    Parameters
    ----------
    output_grad : relax.Expr
        The gradient w.r.t. the result of max_pool2d.

    data : relax.Expr
        The input data of the forward max_pool2d.

    Returns
    -------
    result : relax.Expr
        The gradient w.r.t. data.
    """
    return _ffi_api.max_pool2d_backward(  # type: ignore
        output_grad, data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout
    )
+
+
def avg_pool2d_backward(
    output_grad: Expr,
    data: Expr,
    pool_size: Tuple[int, int] = (1, 1),
    strides: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
    dilation: Tuple[int, int] = (1, 1),
    ceil_mode: bool = False,
    layout: str = "NCHW",
    out_layout: Optional[str] = None,
) -> Expr:
    """Backward operator of relax.avg_pool2d. All parameters except output_grad are the same as
    relax.avg_pool2d. Returns the gradient w.r.t. data.

    Parameters
    ----------
    output_grad : relax.Expr
        The gradient w.r.t. the result of avg_pool2d.

    data : relax.Expr
        The input data of the forward avg_pool2d.

    Returns
    -------
    result : relax.Expr
        The gradient w.r.t. data.
    """
    return _ffi_api.avg_pool2d_backward(  # type: ignore
        output_grad, data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout
    )
+
+
def take_backward(output_grad: Expr, x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr:
    """Backward operator of relax.take. All parameters except output_grad are the same as
    relax.take. Returns the gradient w.r.t. x.

    Parameters
    ----------
    output_grad : relax.Expr
        The gradient w.r.t. the result of take.

    x : relax.Expr
        The source tensor of the forward take.

    indices : relax.Expr
        The indices passed to the forward take.

    axis : Optional[int]
        The axis passed to the forward take.

    Returns
    -------
    result : relax.Expr
        The gradient w.r.t. x.
    """
    return _ffi_api.take_backward(output_grad, x, indices, axis)  # type: ignore
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 2d0fdd14b3..7b002aa9f3 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -142,3 +142,8 @@ class RepeatAttrs(Attrs):
@tvm._ffi.register_object("relax.attrs.TileAttrs")
class TileAttrs(Attrs):
"""Attributes for tile operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.CumsumAttrs")
+class CumsumAttrs(Attrs):
+ """Attributes for cumsum operator"""
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py
b/python/tvm/relax/transform/legalize_ops/__init__.py
index 3e57b815db..8b668e5040 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/transform/legalize_ops/__init__.py
@@ -18,6 +18,7 @@
from . import binary
from . import creation
from . import datatype
+from . import grad
from . import image
from . import index
from . import linear_algebra
diff --git a/python/tvm/relax/transform/legalize_ops/grad.py
b/python/tvm/relax/transform/legalize_ops/grad.py
new file mode 100644
index 0000000000..7fb9b0864d
--- /dev/null
+++ b/python/tvm/relax/transform/legalize_ops/grad.py
@@ -0,0 +1,218 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-argument
"""Default legalization functions for operators that implement operator gradients."""
+import logging
+
+from tvm import te, tir, topi
+from ...block_builder import BlockBuilder
+from ...expr import Call, Expr
+from .common import register_legalize
+
+
@register_legalize("relax.grad.no_grad")
def _no_grad(bb: BlockBuilder, call: Call) -> Expr:
    # no_grad is an identity marker: legalization simply forwards its single input.
    return call.args[0]
+
+
@register_legalize("relax.grad.nll_loss_backward")
def _grad_nll_loss_backward(bb: BlockBuilder, call: Call) -> Expr:
    # topi.sum does not support zero-dim input, so zero-dim tensors are passed
    # through unchanged (summing a scalar is the scalar itself).
    def topi_sum_extend(x):
        return x if x.ndim == 0 else topi.sum(x)

    def te_nll_loss_backward(output_grad, predictions, targets, weights, reduction, ignore_index):
        # handle ignore_index: zero out the weight of the ignored class so it
        # contributes no gradient.
        if ignore_index >= 0:
            weights = te.compute(
                weights.shape,
                lambda i: tir.Select(i == ignore_index, tir.const(0, weights.dtype), weights(i)),
                "weights_new",
            )

        # Per-element weight: the class weight looked up at each target value.
        all_weights = te.compute(targets.shape, lambda *i: weights(targets(*i)), "all_weights")

        # handle reduction: expand output_grad back to the unreduced shape; for
        # "mean" also divide by the total weight used in the forward average.
        if reduction == "sum":
            output_grad = topi.broadcast_to(output_grad, targets.shape)
        elif reduction == "mean":
            output_grad = topi.divide(
                topi.broadcast_to(output_grad, targets.shape), topi_sum_extend(all_weights)
            )

        # handle no batch: predictions is 1-dim (class axis only) and targets is
        # zero-dim, hence the nullary targets()/all_weights()/output_grad() calls.
        if predictions.ndim == 1:
            return te.compute(
                predictions.shape,
                lambda i: tir.Select(
                    i == targets(), -all_weights() * output_grad(), tir.const(0, predictions.dtype)
                ),
                "pred_grad",
            )

        # General case: axis 1 of predictions is the class axis; the gradient is
        # -weight * output_grad at the target class and 0 elsewhere.
        return te.compute(
            predictions.shape,
            lambda *i: tir.Select(
                i[1] == targets(*i[:1], *i[2:]),
                -all_weights(*i[:1], *i[2:]) * output_grad(*i[:1], *i[2:]),
                tir.const(0, predictions.dtype),
            ),
            "pred_grad",
        )

    def te_nll_loss_backward_no_weight(output_grad, predictions, targets, reduction, ignore_index):
        # No explicit weights: use all-ones over the class axis (axis 1 when
        # batched, axis 0 otherwise).
        weight = topi.full(
            (predictions.shape[1] if len(predictions.shape) > 1 else predictions.shape[0],),
            predictions.dtype,
            1.0,
        )
        return te_nll_loss_backward(
            output_grad, predictions, targets, weight, reduction, ignore_index
        )

    # Three args means the optional weights argument was omitted in the call.
    if len(call.args) == 3:
        return bb.call_te(
            te_nll_loss_backward_no_weight,
            *call.args,
            reduction=call.attrs.reduction,
            ignore_index=call.attrs.ignore_index,
        )

    return bb.call_te(
        te_nll_loss_backward,
        *call.args,
        reduction=call.attrs.reduction,
        ignore_index=call.attrs.ignore_index,
        primfunc_name_hint="nll_loss_backward",
    )
+
+
@register_legalize("relax.grad.max_pool2d_backward")
def _grad_max_pool2d_backward(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize max_pool2d_backward through topi.nn.pool_grad (dilation unsupported)."""
    attrs = call.attrs
    if not (len(attrs.dilation) == 2 and all(d == 1 for d in attrs.dilation)):
        logging.info("Dilation is not supported in TOPI pool_grad and is not legalized.")
        return call
    output_grad, data = call.args
    return bb.call_te(
        topi.nn.pool_grad,
        output_grad,
        data,
        kernel=attrs.pool_size,
        stride=attrs.strides,
        padding=attrs.padding,
        pool_type="max",
        ceil_mode=attrs.ceil_mode,
        layout=attrs.layout,
        primfunc_name_hint="max_pool2d_backward",
    )
+
+
@register_legalize("relax.grad.avg_pool2d_backward")
def _grad_avg_pool2d_backward(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize avg_pool2d_backward through topi.nn.pool_grad (dilation unsupported)."""
    attrs = call.attrs
    if not (len(attrs.dilation) == 2 and all(d == 1 for d in attrs.dilation)):
        logging.info("Dilation is not supported in TOPI pool_grad and is not legalized.")
        return call
    output_grad, data = call.args
    return bb.call_te(
        topi.nn.pool_grad,
        output_grad,
        data,
        kernel=attrs.pool_size,
        stride=attrs.strides,
        padding=attrs.padding,
        pool_type="avg",
        ceil_mode=attrs.ceil_mode,
        layout=attrs.layout,
        primfunc_name_hint="avg_pool2d_backward",
    )
+
+
@register_legalize("relax.grad.take_backward")
def _grad_take_backward(bb: BlockBuilder, call: Call) -> Expr:
    # Scatter-add the output gradient back to the positions that take() read from.
    axis = call.attrs.axis
    if axis is not None:
        axis = int(axis)

    def te_take_backward(output_grad, x, indices):
        def gen_ir(output_grad_ptr, x_ptr, indices_ptr, out_ptr):
            # pylint: disable=invalid-name
            ib = tir.ir_builder.create()

            output_grad = ib.buffer_ptr(output_grad_ptr)
            indices = ib.buffer_ptr(indices_ptr)
            out = ib.buffer_ptr(out_ptr)

            # Zero-initialize the whole (flattened) gradient buffer first,
            # because the scatter below only touches indexed positions.
            fused_shape = 1
            for i in x_ptr.shape:
                fused_shape *= i

            with ib.for_range(0, fused_shape) as i:
                out[i] = tir.const(0, dtype=x_ptr.dtype)

            indices_len = indices_ptr.shape[0].value  # must be 1-dim

            if axis is not None:
                # Flatten the axes before/after `axis` so an output_grad element
                # can be addressed as (pre, index-on-axis, post).
                fused_output_grad_shape_pre = 1
                fused_output_grad_shape_nxt = 1
                for i in range(len(output_grad_ptr.shape)):
                    if i < axis:
                        fused_output_grad_shape_pre *= output_grad_ptr.shape[i]
                    elif i > axis:
                        fused_output_grad_shape_nxt *= output_grad_ptr.shape[i]

                x_axis_len = x_ptr.shape[axis].value

                # NOTE(review): "parallel" here binds to for_range's *name*
                # parameter, not its kind — confirm whether a parallel loop
                # kind was intended.
                with ib.for_range(
                    0, fused_output_grad_shape_pre * fused_output_grad_shape_nxt, "parallel"
                ) as fused:
                    i = fused // fused_output_grad_shape_nxt
                    j = fused % fused_output_grad_shape_nxt
                    # Accumulate: duplicate indices must add their gradients.
                    for l in reversed(range(indices_len)):
                        out[
                            i * fused_output_grad_shape_nxt * x_axis_len
                            + indices[l] * fused_output_grad_shape_nxt
                            + j
                        ] += output_grad[
                            i * fused_output_grad_shape_nxt * indices_len
                            + l * fused_output_grad_shape_nxt
                            + j
                        ]
            else:
                # axis=None: take() operated on the flattened input.
                for l in reversed(range(indices_len)):
                    out[indices[l]] += output_grad[l]

            return ib.get()

        shape = x.shape
        out_buf = tir.decl_buffer(shape, x.dtype, "out_buf")

        return te.extern(
            [shape],
            [output_grad, x, indices],
            lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
            dtype=x.dtype,
            out_buffers=[out_buf],
            name="take_backward",
            tag="take_backward",
        )

    return bb.call_te(
        te_take_backward,
        call.args[0],
        call.args[1],
        call.args[2],
        primfunc_name_hint="take_backward",
    )
diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py
b/python/tvm/relax/transform/legalize_ops/statistical.py
index 3307d49f21..d9753d78d0 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -45,7 +45,7 @@ def _te_mean(x: te.Tensor, axis: List[tir.IntImm], keepdims:
bool) -> te.Tensor:
def _te_variance(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor:
    # Keep the reduced dims of the inner mean (keepdims=True) so it broadcasts
    # against x, independent of the keepdims requested for the final result.
    dev = x - _te_mean(x, axis, True)
    return _te_mean(dev * dev, axis, keepdims)
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index 4630a850bf..242a9ff464 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -69,6 +69,7 @@ from tvm.relax.op import (
floor_divide,
full,
full_like,
+ grad,
greater,
greater_equal,
image,
@@ -590,6 +591,7 @@ __all__ = [
"func_ret_struct_info",
"func_ret_value",
"function",
+ "grad",
"greater",
"greater_equal",
"image",
diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc
new file mode 100644
index 0000000000..a3bddd951b
--- /dev/null
+++ b/src/relax/op/tensor/grad.cc
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
/*!
 * \file grad.cc
 * \brief Operators to implement operator gradients.
 */
+
+#include "grad.h"
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+/* relax.grad.no_grad */
+Expr no_grad(Expr input) {
+ static const Op& op = Op::Get("relax.grad.no_grad");
+ return Call(op, {std::move(input)}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.grad.no_grad").set_body_typed(no_grad);
+
+StructInfo InferStructInfoNoGrad(const Call& call, const BlockBuilder& ctx) {
+ return GetStructInfo(call->args[0]);
+}
+
+TVM_REGISTER_OP("relax.grad.no_grad")
+ .set_num_inputs(0)
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoNoGrad);
+
/* relax.grad.nll_loss_backward */
Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optional<Expr> weights,
                       String reduction, int ignore_index) {
  ObjectPtr<NLLLossAttrs> attrs = make_object<NLLLossAttrs>();

  attrs->reduction = reduction;
  attrs->ignore_index = ignore_index;

  static const Op& op = Op::Get("relax.grad.nll_loss_backward");
  // weights is optional: omit it from the call args entirely when undefined,
  // so the legalizer can distinguish the two cases by the number of args.
  if (weights.defined()) {
    return Call(op,
                {std::move(output_grad), std::move(predictions), std::move(targets),
                 std::move(weights.value())},
                Attrs{attrs}, {});
  } else {
    return Call(op, {std::move(output_grad), std::move(predictions), std::move(targets)},
                Attrs{attrs}, {});
  }
}

TVM_REGISTER_GLOBAL("relax.op.grad.nll_loss_backward").set_body_typed(nll_loss_backward);

StructInfo InferStructInfoNLLLossBackward(const Call& call, const BlockBuilder& ctx) {
  // The gradient is w.r.t. predictions, so it shares predictions' struct info.
  return GetStructInfo(call->args[1]);
}

TVM_REGISTER_OP("relax.grad.nll_loss_backward")
    .set_attrs_type<NLLLossAttrs>()
    .set_num_inputs(4)
    .add_argument("output_grad", "Tensor", "The output gradient.")
    .add_argument("predictions", "Tensor", "The prediction tensor.")
    .add_argument("targets", "Tensor", "The target tensor.")
    .add_argument("weights", "Optional<Tensor>", "The weight of each target values.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoNLLLossBackward);
+
/* relax.grad.max_pool2d_backward */
Expr max_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
                         Array<IntImm> strides, Array<IntImm> padding, Array<IntImm> dilation,
                         bool ceil_mode, String layout, Optional<String> out_layout) {
  // Replays the forward max_pool2d attributes onto the backward op.
  auto attrs = make_object<Pool2DAttrs>();
  attrs->pool_size = std::move(pool_size);
  attrs->strides = ConvertIntImmToInt64(strides);
  attrs->padding = ConvertIntImmToInt64(padding);
  attrs->dilation = ConvertIntImmToInt64(dilation);
  attrs->ceil_mode = ceil_mode;
  attrs->layout = layout;
  attrs->out_layout = out_layout.value_or(layout);
  static const Op& op = Op::Get("relax.grad.max_pool2d_backward");
  return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.grad.max_pool2d_backward").set_body_typed(max_pool2d_backward);

StructInfo InferStructInfoMaxPool2DBackward(const Call& call, const BlockBuilder& ctx) {
  // The gradient is w.r.t. data, so it shares data's struct info.
  return GetStructInfo(call->args[1]);
}

TVM_REGISTER_OP("relax.grad.max_pool2d_backward")
    .set_num_inputs(2)
    .add_argument("output_grad", "Tensor", "The output gradient.")
    .add_argument("data", "Tensor", "The input tensor")
    .set_attrs_type<Pool2DAttrs>()
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMaxPool2DBackward);
+
/* relax.grad.avg_pool2d_backward */
Expr avg_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
                         Array<IntImm> strides, Array<IntImm> padding, Array<IntImm> dilation,
                         bool ceil_mode, String layout, Optional<String> out_layout) {
  // Replays the forward avg_pool2d attributes onto the backward op.
  auto attrs = make_object<Pool2DAttrs>();
  attrs->pool_size = std::move(pool_size);
  attrs->strides = ConvertIntImmToInt64(strides);
  attrs->padding = ConvertIntImmToInt64(padding);
  attrs->dilation = ConvertIntImmToInt64(dilation);
  attrs->ceil_mode = ceil_mode;
  attrs->layout = layout;
  attrs->out_layout = out_layout.value_or(layout);
  static const Op& op = Op::Get("relax.grad.avg_pool2d_backward");
  return Call(op, {std::move(output_grad), std::move(data)}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.grad.avg_pool2d_backward").set_body_typed(avg_pool2d_backward);

StructInfo InferStructInfoAvgPool2DBackward(const Call& call, const BlockBuilder& ctx) {
  // The gradient is w.r.t. data, so it shares data's struct info.
  return GetStructInfo(call->args[1]);
}

TVM_REGISTER_OP("relax.grad.avg_pool2d_backward")
    .set_num_inputs(2)
    .add_argument("output_grad", "Tensor", "The output gradient.")
    .add_argument("data", "Tensor", "The input tensor")
    .set_attrs_type<Pool2DAttrs>()
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAvgPool2DBackward);
+
/* relax.grad.take_backward */
// NOTE(review): TakeAttrs is the attrs type of relax.take as well — confirm the
// node type is not already registered elsewhere before keeping this line.
TVM_REGISTER_NODE_TYPE(TakeAttrs);

Expr take_backward(Expr output_grad, Expr x, Expr indices, Optional<Integer> axis) {
  // Reuses TakeAttrs so the backward op carries the same axis as the forward take.
  ObjectPtr<TakeAttrs> attrs = make_object<TakeAttrs>();
  attrs->axis = std::move(axis);

  static const Op& op = Op::Get("relax.grad.take_backward");
  return Call(op, {std::move(output_grad), std::move(x), std::move(indices)}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.grad.take_backward").set_body_typed(take_backward);

StructInfo InferStructInfoTakeBackward(const Call& call, const BlockBuilder& ctx) {
  // The gradient is w.r.t. x, so it shares x's struct info.
  return GetStructInfo(call->args[1]);
}

TVM_REGISTER_OP("relax.grad.take_backward")
    .set_attrs_type<TakeAttrs>()
    .set_num_inputs(3)
    .add_argument("output_grad", "Tensor", "The output gradient.")
    .add_argument("x", "Tensor", "The source tensor.")
    .add_argument("indices", "Tensor", "The indices of the values to extract.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTakeBackward);
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/op/tensor/grad.h b/src/relax/op/tensor/grad.h
new file mode 100644
index 0000000000..886516020d
--- /dev/null
+++ b/src/relax/op/tensor/grad.h
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file grad.h
+ * \brief The functions to make Relax gradient operators.
+ */
+#ifndef TVM_RELAX_OP_TENSOR_GRAD_H_
+#define TVM_RELAX_OP_TENSOR_GRAD_H_
+
+#include <tvm/relax/attrs/index.h>
+#include <tvm/relax/attrs/nn.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
/*!
 * \brief No gradient dummy operator.
 * \param input The corresponding input tensor.
 * \return The no-gradient representation w.r.t. input.
 */
Expr no_grad(Expr input);

/*! \brief Backward operator of relax.nll_loss. All parameters except output_grad are the same as
 * relax.nll_loss. Returns the gradient w.r.t. predictions. */
Expr nll_loss_backward(Expr output_grad, Expr predictions, Expr targets, Optional<Expr> weights,
                       String reduction, int ignore_index);

/*! \brief Backward operator of relax.max_pool2d. All parameters except output_grad are the same as
 * relax.max_pool2d. Returns the gradient w.r.t. data. */
Expr max_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
                         Array<IntImm> strides, Array<IntImm> padding, Array<IntImm> dilation,
                         bool ceil_mode, String layout, Optional<String> out_layout);

/*! \brief Backward operator of relax.avg_pool2d. All parameters except output_grad are the same as
 * relax.avg_pool2d. Returns the gradient w.r.t. data. */
Expr avg_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
                         Array<IntImm> strides, Array<IntImm> padding, Array<IntImm> dilation,
                         bool ceil_mode, String layout, Optional<String> out_layout);

/*! \brief Backward operator of relax.take. All parameters except output_grad are the same as
 * relax.take. Returns the gradient w.r.t. x. */
Expr take_backward(Expr output_grad, Expr x, Expr indices, Optional<Integer> axis);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_OP_TENSOR_GRAD_H_
diff --git a/tests/python/relax/test_op_grad.py
b/tests/python/relax/test_op_grad.py
new file mode 100644
index 0000000000..01c4226d96
--- /dev/null
+++ b/tests/python/relax/test_op_grad.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
def test_op_correctness():
    """Each relax.op.grad.* constructor should build a call to the matching registered op."""
    g = relax.Var("g", R.Tensor((3, 10, 10), "float32"))
    x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32"))
    y = relax.Var("y", R.Tensor((3, 10, 10), "int64"))
    w = relax.Var("w", R.Tensor((5,), "float32"))
    assert relax.op.grad.nll_loss_backward(g, x, y, w).op == Op.get("relax.grad.nll_loss_backward")

    g = relax.Var("g", R.Tensor((3, 3, 8, 8), "float32"))
    x = relax.Var("x", R.Tensor((3, 2, 10, 10), "float32"))
    assert relax.op.grad.max_pool2d_backward(g, x, (3, 3)).op == Op.get(
        "relax.grad.max_pool2d_backward"
    )
    assert relax.op.grad.avg_pool2d_backward(g, x, (3, 3)).op == Op.get(
        "relax.grad.avg_pool2d_backward"
    )
    g = relax.Var("g", R.Tensor((3, 2, 5), "float32"))
    x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
    # NOTE(review): indices is declared float32 here; take indices are normally
    # an integer dtype — confirm whether this was intentional.
    indices = relax.Var("indices", R.Tensor((2,), "float32"))
    assert relax.op.grad.take_backward(g, x, indices, axis=1).op == Op.get(
        "relax.grad.take_backward"
    )
    assert relax.op.grad.no_grad(x).op == Op.get("relax.grad.no_grad")
+
+
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
    """Normalize `call` with `bb` and assert its inferred struct info matches `expected_sinfo`."""
    normalized = bb.normalize(call)
    tvm.ir.assert_structural_equal(normalized.struct_info, expected_sinfo)
+
+
def test_nll_loss_backward_infer_struct_info():
    """nll_loss_backward's inferred struct info must match the predictions tensor."""
    bb = relax.BlockBuilder()

    g = relax.Var("g", R.Tensor((3, 10, 10)))
    x = relax.Var("x", R.Tensor((3, 5, 10, 10), "float32"))
    y = relax.Var("y", R.Tensor((3, 10, 10), "int64"))
    w = relax.Var("w", R.Tensor((5,), "float32"))

    # With and without the optional weights argument.
    _check_inference(bb, relax.op.grad.nll_loss_backward(g, x, y), x.struct_info)
    _check_inference(bb, relax.op.grad.nll_loss_backward(g, x, y, w), x.struct_info)
+
+
def test_max_pool2d_backward_infer_struct_info():
    """max_pool2d_backward's inferred struct info must match the input data tensor."""
    bb = relax.BlockBuilder()

    g = relax.Var("g", R.Tensor((3, 3, 8, 8), "float32"))
    x = relax.Var("x", R.Tensor((3, 2, 10, 10), "float32"))

    _check_inference(bb, relax.op.grad.max_pool2d_backward(g, x, (2, 2)), x.struct_info)
    _check_inference(bb, relax.op.grad.max_pool2d_backward(g, x, (3, 3)), x.struct_info)
+
+
def test_avg_pool2d_backward_infer_struct_info():
    """avg_pool2d_backward's inferred struct info must match the input data tensor."""
    bb = relax.BlockBuilder()

    g = relax.Var("g", R.Tensor((3, 3, 8, 8), "float32"))
    x = relax.Var("x", R.Tensor((3, 2, 10, 10), "float32"))

    _check_inference(bb, relax.op.grad.avg_pool2d_backward(g, x, (2, 2)), x.struct_info)
    _check_inference(bb, relax.op.grad.avg_pool2d_backward(g, x, (3, 3)), x.struct_info)
+
+
def test_take_backward_infer_struct_info():
    """take_backward's inferred struct info must match the source tensor x."""
    bb = relax.BlockBuilder()

    g = relax.Var("g", R.Tensor((3, 2, 5), "float32"))
    x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
    # NOTE(review): indices is declared float32; take indices are normally an
    # integer dtype — confirm whether this was intentional.
    indices = relax.Var("indices", R.Tensor((2,), "float32"))

    _check_inference(bb, relax.op.grad.take_backward(g, x, indices, axis=1), x.struct_info)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_op_gradient_numeric.py
b/tests/python/relax/test_op_gradient_numeric.py
new file mode 100644
index 0000000000..cf2ff777d2
--- /dev/null
+++ b/tests/python/relax/test_op_gradient_numeric.py
@@ -0,0 +1,794 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
from typing import Callable, List, Optional, Tuple, Union
+
+import numpy as np
+import tvm
+from tvm.relax.expr import Call
+from tvm.relax.struct_info import TensorStructInfo, TupleStructInfo
+import tvm.testing
+from tvm import relax
+from tvm.relax.transform import LegalizeOps
+from tvm.testing.utils import check_numerical_grads
+from tvm.ir.op import Op
+
+
def relax_check_gradients(
    op_func: Callable,
    inputs_numpy: List[np.ndarray],
    target: Union[str, tvm.target.Target],
    dev: tvm._ffi.runtime_ctypes.Device,
    tuple_input: bool = False,
    ignore_grads: Optional[List[int]] = None,
    **kwargs,  # attrs for the operator
):
    """Generate the forward and the gradient module. Then run them and check numeric gradients.

    Parameters
    ----------
    op_func : Callable
        The forward operator function. Should be a function in package relax.op.

    inputs_numpy : List[np.ndarray]
        The np array inputs for op_func. inputs_numpy will be transformed into TVM NDArray inside
        this function.

        If op_func takes a tuple of tensors as input, you can set tuple_input as True, and pass the
        tuple input (or list) as inputs_numpy. See test_concat().

    target : Union[str, tvm.target.Target]
        The building target.

    dev : tvm._ffi.runtime_ctypes.Device
        The device to deploy the module.

    tuple_input : bool
        Whether the operator accepts a tuple as input. If true, operator will accept exactly one
        tuple of tensors as input; otherwise, operator accept one or more tensors as input. See
        test_concat(). Default: False.

    ignore_grads : Optional[List[int]]
        Specifies which input we do not need to find gradient. Default: None (treated as empty).

        Sometimes the input is not differentiable, such as shape, boolean values, positions, etc.
        We can specify the index of these inputs to check the gradient of them is no_grad, and
        prevent computing numeric gradient.

    kwargs : Any
        The keyword arguments for the op_func. Will be passed to op_func directly.
    """
    # BUG FIX: the original signature used a mutable default `ignore_grads=[]`
    # (shared across calls). Use the None sentinel instead.
    if ignore_grads is None:
        ignore_grads = []

    func_name = "main"

    # Helper functions
    def _numpy_to_sinfo(data):
        # Lists map to TupleStructInfo; ndarrays map to TensorStructInfo.
        if isinstance(data, list):
            return relax.TupleStructInfo([_numpy_to_sinfo(d) for d in data])
        return relax.TensorStructInfo(data.shape, str(data.dtype))

    def _numpy_to_tvm(data):
        if isinstance(data, list):
            return [_numpy_to_tvm(d) for d in data]
        return tvm.nd.array(data)

    def _tvm_to_numpy(data, ignore_idx=None):
        # BUG FIX: avoid the mutable default `ignore_idx=[]` here as well.
        if ignore_idx is None:
            ignore_idx = []
        if isinstance(data, tvm.ir.Array):
            return [_tvm_to_numpy(d) for i, d in enumerate(data) if i not in ignore_idx]
        if isinstance(data, tvm.runtime.ndarray.NDArray):
            return data.numpy()
        return data

    def _gen_weights(out_sinfo):
        if isinstance(out_sinfo, TupleStructInfo):
            return [_gen_weights(sinfo) for sinfo in out_sinfo.fields]
        else:
            assert isinstance(out_sinfo, TensorStructInfo)
            return np.random.uniform(size=[int(i) for i in out_sinfo.shape]).astype(
                out_sinfo.dtype
            )

    def _is_call_no_grad(expr):
        return isinstance(expr, Call) and expr.op == Op.get("relax.grad.no_grad")

    # Generate parameter relax Vars
    param_vars = [
        relax.Var("x_" + str(i), _numpy_to_sinfo(data)) for i, data in enumerate(inputs_numpy)
    ]

    # Generate the forward call
    if tuple_input:
        t = relax.Tuple(param_vars)
        call = op_func(t, **kwargs)
    else:
        call = op_func(*param_vars, **kwargs)

    # Forward mod
    forward_bb = relax.BlockBuilder()
    with forward_bb.function(func_name, param_vars):
        with forward_bb.dataflow():
            out = forward_bb.emit_output(call)
        forward_bb.emit_func_output(out)
    forward_mod = forward_bb.get()
    forward_lower_mod = LegalizeOps()(forward_mod)
    forward_ex = relax.build(forward_lower_mod, target)
    forward_vm = relax.VirtualMachine(forward_ex, dev)

    # Generate weights
    # In forward process, weights represent the weight of every element of the result of the
    # forward call. The weighted result will be sum(weight * result).
    # If the result is a tuple, weights will be a list, and the weighted result will be
    # sum(i * j for i, j in zip(weights, result))
    # In the gradient process, weights is the output gradient, i.e. the gradient w.r.t. the result.
    out_sinfo = forward_mod[func_name].body.body.struct_info
    weights = _gen_weights(out_sinfo)

    # The inputs of the forward function are inputs_filtered below.
    def forward(*inputs):
        inputs_iter = iter(inputs)
        # Ignored inputs are taken from the captured inputs_numpy instead of
        # the perturbed values supplied by the numeric-gradient driver.
        inputs_tvm = [
            _numpy_to_tvm(next(inputs_iter))
            if i not in ignore_grads
            else _numpy_to_tvm(inputs_numpy[i])
            for i in range(len(inputs_numpy))
        ]
        result = forward_vm[func_name](*inputs_tvm)
        result_numpy = _tvm_to_numpy(result)
        if isinstance(result_numpy, list):
            assert isinstance(weights, list)
            assert len(weights) == len(result_numpy)
            ret = 0
            for i, weight in enumerate(weights):
                ret += np.sum(weight * result_numpy[i])
            return ret
        return np.sum(weights * result_numpy)

    # The gradient function
    assert isinstance(call.op, Op)
    op_grad_func = call.op.get_attr("FPrimalGradient")

    # The parameter Var for gradient
    grad_var = relax.Var("grad", _numpy_to_sinfo(weights))

    # Gradient mod
    grad_bb = relax.BlockBuilder()
    with grad_bb.function(func_name, param_vars + [grad_var]):
        with grad_bb.dataflow():
            orig = grad_bb.emit(call)
            # op_grad_func returns a list of Exprs representing the gradients
            grad_call = op_grad_func(orig, call, grad_var, grad_bb)

            # Check ignore_grads
            for i, grad in enumerate(grad_call):
                if i in ignore_grads:
                    assert _is_call_no_grad(grad), f"The {i}-th gradient should be no_grad"
                else:
                    assert not _is_call_no_grad(grad), f"The {i}-th gradient should not be no_grad"

            if tuple_input:
                # If the input is a tuple, the gradient is also a tuple.
                # The gradient tuple is the first (the only) element of grad_call.
                out = grad_bb.emit_output(grad_call[0])
            else:
                # We need to wrap the list into a relax.Tuple so as to emit it
                out = grad_bb.emit_output(relax.Tuple(grad_call))
        grad_bb.emit_func_output(out)

    grad_mod = grad_bb.get()
    grad_lower_mod = LegalizeOps()(grad_mod)
    grad_ex = relax.build(grad_lower_mod, target)
    grad_vm = relax.VirtualMachine(grad_ex, dev)

    # tvm.runtime.NDArray inputs
    inputs_tvm = [_numpy_to_tvm(i) for i in inputs_numpy]
    weights_tvm = _numpy_to_tvm(weights)
    result_filtered = _tvm_to_numpy(grad_vm[func_name](*inputs_tvm, weights_tvm), ignore_grads)

    # Inputs contained in ignore_grads are removed
    inputs_filtered = [inputs_numpy[i] for i in range(len(inputs_numpy)) if i not in ignore_grads]

    check_numerical_grads(forward, inputs_filtered, result_filtered)
+
+
+##################### Unary #####################
+
+
unary_op_func, can_be_neg = tvm.testing.parameters(
    (relax.op.abs, True),
    (relax.op.cos, True),
    (relax.op.exp, True),
    (relax.op.log, False),
    (relax.op.negative, True),
    (relax.op.sigmoid, True),
    (relax.op.sin, True),
    (relax.op.sqrt, False),
    (relax.op.tanh, True),
)


@tvm.testing.parametrize_targets("llvm")
def test_unary(target, dev, unary_op_func, can_be_neg):
    # log/sqrt are only well-behaved on positive inputs, so keep those away from 0.
    if can_be_neg:
        low, high = -1, 1
    else:
        low, high = 0.1, 1
    data = np.random.uniform(low, high, (3, 3)).astype(np.float32)
    relax_check_gradients(unary_op_func, [data], target, dev)
+
+
+##################### Binary #####################
+
+
(binary_arith_op_func,) = tvm.testing.parameters(
    (relax.op.add,),
    (relax.op.subtract,),
    (relax.op.multiply,),
    (relax.op.divide,),
    (relax.op.power,),
)


@tvm.testing.parametrize_targets("llvm")
def test_binary_arith(target, dev, binary_arith_op_func):
    # Inputs in [1, 2): safe for divide (no zero divisor) and power.
    lhs = np.random.uniform(1, 2, (3, 3)).astype(np.float32)
    rhs = np.random.uniform(1, 2, (3, 3)).astype(np.float32)
    relax_check_gradients(binary_arith_op_func, [lhs, rhs], target, dev)


(binary_cmp_op_func,) = tvm.testing.parameters(
    (relax.op.equal,),
    (relax.op.greater,),
    (relax.op.greater_equal,),
    (relax.op.less,),
    (relax.op.less_equal,),
    (relax.op.not_equal,),
)


@tvm.testing.parametrize_targets("llvm")
def test_binary_cmp(target, dev, binary_cmp_op_func):
    # Comparison results are not differentiable w.r.t. either operand,
    # so both gradients must be no_grad.
    lhs = np.random.uniform(1, 2, (3, 3)).astype(np.float32)
    rhs = np.random.uniform(1, 2, (3, 3)).astype(np.float32)
    relax_check_gradients(binary_cmp_op_func, [lhs, rhs], target, dev, ignore_grads=[0, 1])
+
+
+##################### Create #####################
+
+
(like_op_func,) = tvm.testing.parameters(
    (relax.op.zeros_like,),
    (relax.op.ones_like,),
)


@tvm.testing.parametrize_targets("llvm")
def test_ones_zeros_like(target, dev, like_op_func):
    # The output is a constant tensor, so the input gradient must be no_grad.
    data = np.random.uniform(-1, 1, (3, 3)).astype(np.float32)
    relax_check_gradients(like_op_func, [data], target, dev, ignore_grads=[0])


@tvm.testing.parametrize_targets("llvm")
def test_full_like(target, dev):
    # Neither the template tensor nor the scalar fill value is differentiable.
    data = np.random.uniform(-1, 1, (3, 3)).astype(np.float32)
    fill_value = np.random.uniform(-1, 1, ()).astype(np.float32)
    relax_check_gradients(
        relax.op.full_like, [data, fill_value], target, dev, ignore_grads=[0, 1]
    )


(create_op_func,) = tvm.testing.parameters(
    (relax.op.zeros,),
    (relax.op.ones,),
)


@tvm.testing.parametrize_targets("llvm")
def test_ones_zeros(target, dev, create_op_func):
    # No tensor inputs at all; the only argument (the shape) is no_grad.
    relax_check_gradients(
        create_op_func, [], target, dev, ignore_grads=[0], shape=(3, 3), dtype="float32"
    )
+
+
+##################### Statistical #####################
+
+
@tvm.testing.parametrize_targets("llvm")
def test_sum(target, dev):
    data = np.random.randint(0, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(relax.op.sum, [data], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_sum_with_axis(target, dev):
    data = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32)
    relax_check_gradients(relax.op.sum, [data], target, dev, axis=[1, 3])


@tvm.testing.parametrize_targets("llvm")
def test_sum_keepdims(target, dev):
    data = np.random.randint(0, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(relax.op.sum, [data], target, dev, keepdims=True, axis=1)


@tvm.testing.parametrize_targets("llvm")
def test_mean(target, dev):
    data = np.random.randint(0, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(relax.op.mean, [data], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_mean_with_axis(target, dev):
    data = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32)
    relax_check_gradients(relax.op.mean, [data], target, dev, axis=[1, 3])


@tvm.testing.parametrize_targets("llvm")
def test_mean_keepdims(target, dev):
    data = np.random.randint(0, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(relax.op.mean, [data], target, dev, keepdims=True, axis=1)


@tvm.testing.parametrize_targets("llvm")
def test_variance(target, dev):
    data = np.random.randint(0, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(relax.op.variance, [data], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_variance_with_axis(target, dev):
    data = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32)
    relax_check_gradients(relax.op.variance, [data], target, dev, axis=[1, 3])


@tvm.testing.parametrize_targets("llvm")
def test_variance_keepdims(target, dev):
    data = np.random.randint(0, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(relax.op.variance, [data], target, dev, keepdims=True, axis=1)
+
+
+##################### Manipulate #####################
+
+
@tvm.testing.parametrize_targets("llvm")
def test_reshape(target, dev):
    data_numpy = np.random.randint(0, 16, (2, 3, 5)).astype(np.float32)
    # Gradient index 1 corresponds to the shape argument, which is not differentiable.
    relax_check_gradients(
        relax.op.reshape, [data_numpy], target, dev, ignore_grads=[1], shape=(5, 6)
    )


@tvm.testing.parametrize_targets("llvm")
def test_reshape_infer_dim(target, dev):
    data_numpy = np.random.randint(0, 16, (2, 3, 5)).astype(np.float32)
    relax_check_gradients(
        relax.op.reshape, [data_numpy], target, dev, ignore_grads=[1], shape=(5, 2, 1, -1)
    )


@tvm.testing.parametrize_targets("llvm")
def test_permute_dims(target, dev):
    data_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32)
    relax_check_gradients(relax.op.permute_dims, [data_numpy], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_permute_dims_with_axes(target, dev):
    data_numpy = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32)
    relax_check_gradients(
        relax.op.permute_dims,
        [data_numpy],
        target,
        dev,
        axes=(0, 3, 1, 2),
    )


@tvm.testing.parametrize_targets("llvm")
def test_concat(target, dev):
    # concat takes a single tuple-of-tensors argument, hence tuple_input=True.
    data_numpy1 = np.random.randint(1, 16, (3, 3)).astype(np.float32)
    data_numpy2 = np.random.randint(1, 16, (3, 4)).astype(np.float32)
    data_numpy3 = np.random.randint(1, 16, (3, 5)).astype(np.float32)
    relax_check_gradients(
        relax.op.concat,
        [data_numpy1, data_numpy2, data_numpy3],
        target,
        dev,
        tuple_input=True,
        axis=1,
    )


@tvm.testing.parametrize_targets("llvm")
def test_split_indices(target, dev):
    data_numpy = np.random.randint(1, 16, (3, 12)).astype(np.float32)
    relax_check_gradients(
        relax.op.split,
        [data_numpy],
        target,
        dev,
        indices_or_sections=[3, 7],
        axis=1,
    )


@tvm.testing.parametrize_targets("llvm")
def test_split_section(target, dev):
    data_numpy = np.random.randint(1, 16, (3, 12)).astype(np.float32)
    relax_check_gradients(
        relax.op.split,
        [data_numpy],
        target,
        dev,
        indices_or_sections=3,
        axis=1,
    )


@tvm.testing.parametrize_targets("llvm")
def test_reshape_split_dim(target, dev):
    # BUG FIX: this test was originally also named `test_reshape`, which made it
    # shadow the earlier definition so that one never ran under pytest. Renamed
    # so both reshape tests are collected and executed.
    data_numpy = np.random.randint(1, 16, (3, 4)).astype(np.float32)

    relax_check_gradients(
        relax.op.reshape,
        [data_numpy],
        target,
        dev,
        shape=(3, 2, 2),
        ignore_grads=[1],
    )


@tvm.testing.parametrize_targets("llvm")
def test_cumsum(target, dev):
    data_numpy1 = np.random.randint(1, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(
        relax.op.cumsum,
        [data_numpy1],
        target,
        dev,
        axis=1,
    )


@tvm.testing.parametrize_targets("llvm")
def test_cumsum_no_axis(target, dev):
    data_numpy1 = np.random.randint(1, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(
        relax.op.cumsum,
        [data_numpy1],
        target,
        dev,
    )


@tvm.testing.parametrize_targets("llvm")
def test_expand_dims(target, dev):
    data_numpy = np.random.randint(1, 16, (3, 12)).astype(np.float32)
    relax_check_gradients(relax.op.expand_dims, [data_numpy], target, dev, axis=1)


@tvm.testing.parametrize_targets("llvm")
def test_expand_dims_list(target, dev):
    data_numpy = np.random.randint(1, 16, (3, 12)).astype(np.float32)
    relax_check_gradients(relax.op.expand_dims, [data_numpy], target, dev, axis=(0, 2, 3))
+
+
+##################### Index #####################
+
+
@tvm.testing.parametrize_targets("llvm")
def test_take(target, dev):
    # The integer indices (gradient index 1) are not differentiable.
    data = np.random.uniform(0, 16, size=(2, 3, 4)).astype(np.float32)
    indices = np.array([0, 1])
    relax_check_gradients(
        relax.op.take,
        [data, indices],
        target,
        dev,
        axis=1,
        ignore_grads=[1],
    )


@tvm.testing.parametrize_targets("llvm")
def test_take_no_axis(target, dev):
    data = np.random.uniform(0, 16, size=(5,)).astype(np.float32)
    indices = np.array([1, 3])
    relax_check_gradients(
        relax.op.take,
        [data, indices],
        target,
        dev,
        ignore_grads=[1],
    )
+
+
+##################### Search #####################
+
+
@tvm.testing.parametrize_targets("llvm")
def test_where(target, dev):
    # The boolean condition (gradient index 0) is not differentiable.
    cond = np.random.uniform(0, 1, size=(3, 3)) > 0.5
    lhs = np.random.uniform(0, 16, size=(3, 3)).astype(np.float32)
    rhs = np.random.uniform(0, 16, size=(3, 3)).astype(np.float32)

    relax_check_gradients(
        relax.op.where,
        [cond, lhs, rhs],
        target,
        dev,
        ignore_grads=[0],
    )
+
+
+##################### Linear Algebra #####################
+
+
@tvm.testing.parametrize_targets("llvm")
def test_matmul_2_2(target, dev):
    # Plain 2D x 2D matrix multiplication.
    lhs = np.random.randint(0, 16, (2, 3)).astype(np.float32)
    rhs = np.random.randint(0, 16, (3, 4)).astype(np.float32)
    relax_check_gradients(relax.op.matmul, [lhs, rhs], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_matmul_1_1(target, dev):
    # 1D x 1D degenerates to a dot product.
    lhs = np.random.randint(0, 16, (4,)).astype(np.float32)
    rhs = np.random.randint(0, 16, (4,)).astype(np.float32)
    relax_check_gradients(relax.op.matmul, [lhs, rhs], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_matmul_1_4(target, dev):
    # 1D lhs is broadcast against a 4D rhs.
    lhs = np.random.randint(0, 16, (4,)).astype(np.float32)
    rhs = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32)
    relax_check_gradients(relax.op.matmul, [lhs, rhs], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_matmul_4_1(target, dev):
    # 4D lhs against a 1D rhs.
    lhs = np.random.randint(0, 16, (2, 3, 4, 5)).astype(np.float32)
    rhs = np.random.randint(0, 16, (5,)).astype(np.float32)
    relax_check_gradients(relax.op.matmul, [lhs, rhs], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_matmul_5_4(target, dev):
    # Batched matmul with broadcasting over the leading batch dims.
    lhs = np.random.randint(0, 16, (2, 3, 1, 4, 5)).astype(np.float32)
    rhs = np.random.randint(0, 16, (3, 2, 5, 4)).astype(np.float32)
    relax_check_gradients(
        relax.op.matmul,
        [lhs, rhs],
        target,
        dev,
    )
+
+
+##################### Datatype #####################
+
+
@tvm.testing.parametrize_targets("llvm")
def test_astype(target, dev):
    # Start from float64 so casting to float32 is an actual conversion.
    data = np.random.uniform(0, 16, size=(3, 3)).astype(np.float64)
    relax_check_gradients(relax.op.astype, [data], target, dev, dtype="float32")
+
+
+##################### Neural network #####################
+
+
@tvm.testing.parametrize_targets("llvm")
def test_relu(target, dev):
    # Mix positive and negative entries while keeping magnitudes >= 0.2, so the
    # numeric gradient never straddles relu's non-differentiable point at 0.
    data1_numpy = np.random.uniform(0.2, 1, (3, 3)).astype(np.float32)
    sign = np.random.randint(0, 2, (3, 3)).astype(np.float32) * 2 - 1
    data1_numpy *= sign
    relax_check_gradients(relax.op.nn.relu, [data1_numpy], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_softmax(target, dev):
    data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(relax.op.nn.softmax, [data1_numpy], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_softmax_with_axis(target, dev):
    data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(relax.op.nn.softmax, [data1_numpy], target, dev, axis=1)


@tvm.testing.parametrize_targets("llvm")
def test_log_softmax(target, dev):
    data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(relax.op.nn.log_softmax, [data1_numpy], target, dev)


@tvm.testing.parametrize_targets("llvm")
def test_log_softmax_with_axis(target, dev):
    data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32)
    relax_check_gradients(relax.op.nn.log_softmax, [data1_numpy], target, dev, axis=1)


@tvm.testing.parametrize_targets("llvm")
def test_cross_entropy_with_logits(target, dev):
    data_numpy1 = np.random.randint(1, 16, (3,)).astype(np.float32)
    data_numpy2 = np.random.randint(1, 16, (3,)).astype(np.float32)
    relax_check_gradients(
        relax.op.nn.cross_entropy_with_logits,
        [data_numpy1, data_numpy2],
        target,
        dev,
    )


@tvm.testing.parametrize_targets("llvm")
def test_cross_entropy_with_logits_batch(target, dev):
    data_numpy1 = np.random.randint(1, 16, (2, 3)).astype(np.float32)
    data_numpy2 = np.random.randint(1, 16, (2, 3)).astype(np.float32)
    relax_check_gradients(
        relax.op.nn.cross_entropy_with_logits,
        [data_numpy1, data_numpy2],
        target,
        dev,
    )


# (reduction, weighted, ignore_index) combinations for nll_loss.
# BUG FIX: the original table listed ("mean", True, 1) twice; the accidental
# duplicate entry is removed here.
(nll_reduction, nll_weighted, nll_ignore_index) = tvm.testing.parameters(
    ("mean", True, -1),
    ("sum", True, -1),
    ("none", True, -1),
    ("mean", True, 1),
    ("mean", False, 1),
)


@tvm.testing.parametrize_targets("llvm")
def test_nll_loss(target, dev, nll_reduction, nll_weighted, nll_ignore_index):
    data1_numpy = np.random.randint(0, 16, (2, 3, 4)).astype(np.float32)
    data2_numpy = np.random.randint(0, 3, (2, 4)).astype(np.int64)
    data3_numpy = np.random.randint(0, 16, (3,)).astype(np.float32)

    # Renamed from `input` to avoid shadowing the builtin.
    inputs = [data1_numpy, data2_numpy] + ([data3_numpy] if nll_weighted else [])
    # targets (index 1) and, when present, weights (index 2) are not differentiable.
    ignore_grads = [1] + ([2] if nll_weighted else [])

    relax_check_gradients(
        relax.op.nn.nll_loss,
        inputs,
        target,
        dev,
        ignore_grads=ignore_grads,
        reduction=nll_reduction,
        ignore_index=nll_ignore_index,
    )


(nll_reduction1, nll_weighted1, nll_ignore_index1) = tvm.testing.parameters(
    ("mean", True, -1),
    ("sum", True, -1),
    ("none", True, -1),
)


@tvm.testing.parametrize_targets("llvm")
def test_nll_loss_no_batch(target, dev, nll_reduction1, nll_weighted1, nll_ignore_index1):
    # Same as test_nll_loss but with no batch dimension (scalar target).
    data1_numpy = np.random.randint(0, 16, (3,)).astype(np.float32)
    data2_numpy = np.random.randint(0, 3, ()).astype(np.int64)
    data3_numpy = np.random.randint(1, 16, (3,)).astype(np.float32)

    inputs = [data1_numpy, data2_numpy] + ([data3_numpy] if nll_weighted1 else [])
    ignore_grads = [1] + ([2] if nll_weighted1 else [])

    relax_check_gradients(
        relax.op.nn.nll_loss,
        inputs,
        target,
        dev,
        ignore_grads=ignore_grads,
        reduction=nll_reduction1,
        ignore_index=nll_ignore_index1,
    )
+
+
# (data shape, weight shape, extra conv2d attrs) combinations.
(c2d_shape1, c2d_shape2, c2d_kwargs,) = tvm.testing.parameters(
    (
        (3, 2, 10, 10),
        (3, 2, 3, 3),
        {},
    ),
    (
        (3, 2, 10, 10),
        (3, 2, 1, 2),
        {},
    ),
    (
        (3, 2, 10, 10),
        (3, 2, 3, 3),
        {"strides": (2, 2), "padding": (3, 2), "dilation": (1, 1)},
    ),
    (
        (3, 2, 10, 10),
        (3, 2, 3, 3),
        {"strides": (2, 1), "padding": (2, 2), "dilation": (1, 1)},
    ),
    (
        (3, 6, 10, 10),
        (4, 3, 3, 3),
        {"groups": 2},
    ),
    (
        (3, 2, 10, 10),
        (4, 1, 3, 3),
        {"groups": 2, "strides": (2, 2), "padding": (2, 2), "dilation": (1, 1)},
    ),
)


@tvm.testing.parametrize_targets("llvm")
def test_conv2d(target, dev, c2d_shape1, c2d_shape2, c2d_kwargs):
    # We use float64 to check the correctness of conv2d
    # to avoid possible precision problems
    # NOTE(review): the original comment said "float32", but the data below is
    # cast to float64; the comment is corrected to match the code.
    data1_numpy = np.random.randint(0, 16, c2d_shape1).astype(np.float64)
    data2_numpy = np.random.randint(0, 3, c2d_shape2).astype(np.float64)
    relax_check_gradients(
        relax.op.nn.conv2d,
        [data1_numpy, data2_numpy],
        target,
        dev,
        **c2d_kwargs,
    )


# (pool window, extra pool2d attrs) combinations shared by max/avg pool tests.
(pool_size, pool_kwargs,) = tvm.testing.parameters(
    (
        (3, 3),
        {},
    ),
    (
        (3, 3),
        {"strides": (2, 2), "padding": (1, 2), "dilation": (1, 1)},
    ),
    (
        (5, 5),
        {"strides": (2, 2), "padding": (2, 1), "dilation": (1, 1), "ceil_mode": True},
    ),
)


@tvm.testing.parametrize_targets("llvm")
def test_max_pool2d(target, dev, pool_size, pool_kwargs):
    # float64 input reduces numeric-gradient noise around the pooling maximum.
    data_numpy = np.random.uniform(0, 16, size=(3, 2, 10, 10)).astype(np.float64)
    relax_check_gradients(
        relax.op.nn.max_pool2d,
        [data_numpy],
        target,
        dev,
        pool_size=pool_size,
        **pool_kwargs,
    )


@tvm.testing.parametrize_targets("llvm")
def test_avg_pool2d(target, dev, pool_size, pool_kwargs):
    data_numpy = np.random.uniform(0, 16, size=(3, 2, 10, 10)).astype(np.float64)
    relax_check_gradients(
        relax.op.nn.avg_pool2d,
        [data_numpy],
        target,
        dev,
        pool_size=pool_size,
        **pool_kwargs,
    )


if __name__ == "__main__":
    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py
b/tests/python/relax/test_transform_legalize_ops_grad.py
new file mode 100644
index 0000000000..a92537f0d1
--- /dev/null
+++ b/tests/python/relax/test_transform_legalize_ops_grad.py
@@ -0,0 +1,337 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm.relax.transform import LegalizeOps
+from tvm.script import relax as R, tir as T, ir as I
+import tvm.testing
+from tvm.tir.op import div
+
+
+def test_nll_loss_backward():
+ # fmt: off
+ @tvm.script.ir_module
+ class NLLLossBackward:
+ @R.function
+ def main(output_grad: R.Tensor((), "float32"), predictions:
R.Tensor((2, 3, 4, 5), "float32"), targets: R.Tensor((2, 4, 5), "int64"),
weights: R.Tensor((4,), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"):
+ gv: R.Tensor((2, 3, 4, 5), "float32") =
R.grad.nll_loss_backward(output_grad, predictions, targets, weights,
reduction="mean", ignore_index=-1)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def nll_loss_backward(rxplaceholder: T.Buffer((), "float32"),
rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
"float32"), rxplaceholder_2: T.Buffer((T.int64(2), T.int64(4), T.int64(5)),
"int64"), rxplaceholder_3: T.Buffer((T.int64(4),), "float32"), pred_grad:
T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ all_weights = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5)))
+ T_broadcast_to = T.alloc_buffer((T.int64(2), T.int64(4),
T.int64(5)))
+ all_weights_red = T.alloc_buffer(())
+ T_divide = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5)))
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("all_weights"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(rxplaceholder_3[rxplaceholder_2[v_i0, v_i1,
v_i2]], rxplaceholder_2[v_i0, v_i1, v_i2])
+ T.writes(all_weights[v_i0, v_i1, v_i2])
+ all_weights[v_i0, v_i1, v_i2] =
rxplaceholder_3[rxplaceholder_2[v_i0, v_i1, v_i2]]
+ for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("T_broadcast_to"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(rxplaceholder[()])
+ T.writes(T_broadcast_to[v_ax0, v_ax1, v_ax2])
+ T_broadcast_to[v_ax0, v_ax1, v_ax2] = rxplaceholder[()]
+ for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("all_weights_red"):
+ v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
+ T.reads(all_weights[v_k0, v_k1, v_k2])
+ T.writes(all_weights_red[()])
+ with T.init():
+ all_weights_red[()] = T.float32(0)
+ all_weights_red[()] = all_weights_red[()] +
all_weights[v_k0, v_k1, v_k2]
+ for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("T_divide"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_broadcast_to[v_ax0, v_ax1, v_ax2],
all_weights_red[()])
+ T.writes(T_divide[v_ax0, v_ax1, v_ax2])
+ T_divide[v_ax0, v_ax1, v_ax2] = T_broadcast_to[v_ax0,
v_ax1, v_ax2] / all_weights_red[()]
+ for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ with T.block("pred_grad"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(rxplaceholder_2[v_i0, v_i2, v_i3],
all_weights[v_i0, v_i2, v_i3], T_divide[v_i0, v_i2, v_i3])
+ T.writes(pred_grad[v_i0, v_i1, v_i2, v_i3])
+ pred_grad[v_i0, v_i1, v_i2, v_i3] = T.Select(v_i1 ==
rxplaceholder_2[v_i0, v_i2, v_i3], all_weights[v_i0, v_i2, v_i3] *
T.float32(-1) * T_divide[v_i0, v_i2, v_i3], T.float32(0))
+
+ @R.function
+ def main(output_grad: R.Tensor((), dtype="float32"), predictions:
R.Tensor((2, 3, 4, 5), dtype="float32"), targets: R.Tensor((2, 4, 5),
dtype="int64"), weights: R.Tensor((4,), dtype="float32")) -> R.Tensor((2, 3, 4,
5), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.nll_loss_backward, (output_grad, predictions,
targets, weights), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(NLLLossBackward)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_nll_loss_backward_no_weight():
+ # fmt: off
+ @I.ir_module
+ class NLLLossBackward:
+ @R.function
+ def main(output_grad: R.Tensor((), "float32"), predictions:
R.Tensor((2, 3, 4, 5), "float32"), targets: R.Tensor((2, 4, 5), "int64")) ->
R.Tensor((2, 3, 4, 5), "float32"):
+ gv: R.Tensor((2, 3, 4, 5), "float32") =
R.grad.nll_loss_backward(output_grad, predictions, targets, reduction="mean",
ignore_index=-1)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def te_nll_loss_backward_no_weight(rxplaceholder: T.Buffer((),
"float32"), rxplaceholder_1: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(2), T.int64(4),
T.int64(5)), "int64"), pred_grad: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ T_full = T.alloc_buffer((T.int64(3),))
+ all_weights = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5)))
+ T_broadcast_to = T.alloc_buffer((T.int64(2), T.int64(4),
T.int64(5)))
+ all_weights_red = T.alloc_buffer(())
+ T_divide = T.alloc_buffer((T.int64(2), T.int64(4), T.int64(5)))
+ for ax0 in range(T.int64(3)):
+ with T.block("T_full"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads()
+ T.writes(T_full[v_ax0])
+ T_full[v_ax0] = T.float32(1)
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("all_weights"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(T_full[rxplaceholder_2[v_i0, v_i1, v_i2]],
rxplaceholder_2[v_i0, v_i1, v_i2])
+ T.writes(all_weights[v_i0, v_i1, v_i2])
+ all_weights[v_i0, v_i1, v_i2] =
T_full[rxplaceholder_2[v_i0, v_i1, v_i2]]
+ for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("T_broadcast_to"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(rxplaceholder[()])
+ T.writes(T_broadcast_to[v_ax0, v_ax1, v_ax2])
+ T_broadcast_to[v_ax0, v_ax1, v_ax2] = rxplaceholder[()]
+ for k0, k1, k2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("all_weights_red"):
+ v_k0, v_k1, v_k2 = T.axis.remap("RRR", [k0, k1, k2])
+ T.reads(all_weights[v_k0, v_k1, v_k2])
+ T.writes(all_weights_red[()])
+ with T.init():
+ all_weights_red[()] = T.float32(0)
+ all_weights_red[()] = all_weights_red[()] +
all_weights[v_k0, v_k1, v_k2]
+ for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("T_divide"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_broadcast_to[v_ax0, v_ax1, v_ax2],
all_weights_red[()])
+ T.writes(T_divide[v_ax0, v_ax1, v_ax2])
+ T_divide[v_ax0, v_ax1, v_ax2] = T_broadcast_to[v_ax0,
v_ax1, v_ax2] / all_weights_red[()]
+ for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ with T.block("pred_grad"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(rxplaceholder_2[v_i0, v_i2, v_i3],
all_weights[v_i0, v_i2, v_i3], T_divide[v_i0, v_i2, v_i3])
+ T.writes(pred_grad[v_i0, v_i1, v_i2, v_i3])
+ pred_grad[v_i0, v_i1, v_i2, v_i3] = T.Select(v_i1 ==
rxplaceholder_2[v_i0, v_i2, v_i3], all_weights[v_i0, v_i2, v_i3] *
T.float32(-1) * T_divide[v_i0, v_i2, v_i3], T.float32(0))
+
+ @R.function
+ def main(output_grad: R.Tensor((), dtype="float32"), predictions:
R.Tensor((2, 3, 4, 5), dtype="float32"), targets: R.Tensor((2, 4, 5),
dtype="int64")) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.te_nll_loss_backward_no_weight, (output_grad,
predictions, targets), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(NLLLossBackward)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_nll_loss_backward_no_batch():
+ # fmt: off
+ @tvm.script.ir_module
+ class NLLLossBackward:
+ @R.function
+ def main(output_grad: R.Tensor((), "float32"), predictions:
R.Tensor((4,), "float32"), targets: R.Tensor((), "int64"), weights:
R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
+ gv: R.Tensor((4,), "float32") =
R.grad.nll_loss_backward(output_grad, predictions, targets, weights,
reduction="mean", ignore_index=-1)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(output_grad: R.Tensor((), dtype="float32"), predictions:
R.Tensor((4,), dtype="float32"), targets: R.Tensor((), dtype="int64"), weights:
R.Tensor((4,), dtype="float32")) -> R.Tensor((4,), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.nll_loss_backward, (output_grad, predictions,
targets, weights), out_sinfo=R.Tensor((4,), dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def nll_loss_backward(rxplaceholder: T.Buffer((), "float32"),
rxplaceholder_1: T.Buffer((T.int64(4),), "float32"), rxplaceholder_2:
T.Buffer((), "int64"), rxplaceholder_3: T.Buffer((T.int64(4),), "float32"),
pred_grad: T.Buffer((T.int64(4),), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ all_weights = T.alloc_buffer(())
+ T_broadcast_to = T.alloc_buffer(())
+ T_divide = T.alloc_buffer(())
+ with T.block("all_weights"):
+ vi = T.axis.spatial(T.int64(1), T.int64(0))
+ T.reads(rxplaceholder_3[rxplaceholder_2[()]],
rxplaceholder_2[()])
+ T.writes(all_weights[()])
+ all_weights[()] = rxplaceholder_3[rxplaceholder_2[()]]
+ with T.block("T_broadcast_to"):
+ vi = T.axis.spatial(1, T.int64(0))
+ T.reads(rxplaceholder[()])
+ T.writes(T_broadcast_to[()])
+ T_broadcast_to[()] = rxplaceholder[()]
+ with T.block("T_divide"):
+ vi = T.axis.spatial(1, T.int64(0))
+ T.reads(T_broadcast_to[()], all_weights[()])
+ T.writes(T_divide[()])
+ T_divide[()] = T_broadcast_to[()] / all_weights[()]
+ for i in range(T.int64(4)):
+ with T.block("pred_grad"):
+ v_i = T.axis.spatial(T.int64(4), i)
+ T.reads(rxplaceholder_2[()], all_weights[()], T_divide[()])
+ T.writes(pred_grad[v_i])
+ pred_grad[v_i] = T.Select(v_i == rxplaceholder_2[()],
all_weights[()] * T.float32(-1) * T_divide[()], T.float32(0))
+ # fmt: on
+
+ mod = LegalizeOps()(NLLLossBackward)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_max_pool2d_backward():
+ # fmt: off
+ @tvm.script.ir_module
+ class MaxPool2DBackward:
+ @R.function
+ def main(output_grad: R.Tensor((3, 2, 6, 5), "float32"), data:
R.Tensor((3, 2, 10, 10), "float32")):
+ gv = R.grad.max_pool2d_backward(output_grad, data, (5, 5), (2, 2),
(2, 1, 2, 1), (1, 1), True)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data:
R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10),
dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.max_pool2d_backward, (output_grad, data),
out_sinfo=R.Tensor((3, 2, 10, 10), dtype="float32"))
+ return gv
+
+ @T.prim_func
+ def max_pool2d_backward(rxplaceholder: T.Buffer((T.int64(3),
T.int64(2), T.int64(6), T.int64(5)), "float32"), rxplaceholder_1:
T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"),
T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)),
"float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ pad_temp = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(15),
T.int64(13)))
+ maxpool_grad_argmax_v0 = T.alloc_buffer((T.int64(3), T.int64(2),
T.int64(6), T.int64(5)), "int32")
+ maxpool_grad_argmax_v1 = T.alloc_buffer((T.int64(3), T.int64(2),
T.int64(6), T.int64(5)))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(3), T.int64(2),
T.int64(15), T.int64(13)):
+ with T.block("pad_temp"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2 - T.int64(2),
v_ax3 - T.int64(1)])
+ T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3])
+ pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] =
T.if_then_else(T.int64(2) <= v_ax2 and v_ax2 < T.int64(12) and T.int64(1) <=
v_ax3 and v_ax3 < T.int64(11), rxplaceholder_1[v_ax0, v_ax1, v_ax2 -
T.int64(2), v_ax3 - T.int64(1)], T.float32(-3.4028234663852886e+38))
+ for ax0, ax1, ax2, ax3, dh, dw in T.grid(T.int64(3), T.int64(2),
T.int64(6), T.int64(5), T.int64(5), T.int64(5)):
+ with T.block("maxpool_grad_argmax"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_dh, v_dw =
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, dh, dw])
+ T.reads(pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh,
v_ax3 * T.int64(2) + v_dw])
+ T.writes(maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2,
v_ax3], maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3])
+ with T.init():
+ maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] = -1
+ maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] =
T.float32(-3.4028234663852886e+38)
+ v_maxpool_grad_argmax_v0: T.int64 =
T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0,
v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] or
maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] == pad_temp[v_ax0, v_ax1,
v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] and T.Cast("int64",
maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3]) < v_ax0 * T.int64(390) +
v_ax1 * T.int64(195) + v_ax2 * T.int64(26) + v_dh * T.int64(13) + [...]
+ v_maxpool_grad_argmax_v1: T.float32 =
T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0,
v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw],
maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3], pad_temp[v_ax0, v_ax1,
v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw])
+ maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] =
T.Cast("int32", v_maxpool_grad_argmax_v0)
+ maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] =
v_maxpool_grad_argmax_v1
+ for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2),
T.int64(10), T.int64(10), T.int64(3), T.int64(3)):
+ with T.block("T_pool_grad"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_wh, v_ww =
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, wh, ww])
+ T.reads(maxpool_grad_argmax_v0[v_ax0, v_ax1, div((v_ax2 +
T.int64(2)), T.int64(2)) - v_wh, div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww],
rxplaceholder[v_ax0, v_ax1, div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh,
div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww])
+ T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3])
+ with T.init():
+ T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
+ T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] =
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 <
T.int64(3), T.int64(0), div((v_ax2 - T.int64(3)), T.int64(2)) + T.int64(1)) <=
div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh and T.Select(v_ax3 < T.int64(4),
T.int64(0), div((v_ax3 - T.int64(4)), T.int64(2)) + T.int64(1)) <= div((v_ax3 +
T.int64(1)), T.int64(2)) - v_ww and T.Cast("int64",
maxpool_grad_argmax_v0[v_ax0, v_ax1, div((v_ax2 + T.int64(2)), T.in [...]
+ # fmt: on
+
+ mod = LegalizeOps()(MaxPool2DBackward)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_avg_pool2d_backward():
+ # fmt: off
+ @tvm.script.ir_module
+ class AvgPool2DBackward:
+ @R.function
+ def main(output_grad: R.Tensor((3, 2, 6, 5), "float32"), data:
R.Tensor((3, 2, 10, 10), "float32")):
+ gv = R.grad.avg_pool2d_backward(output_grad, data, (5, 5), (2, 2),
(2, 1, 2, 1), (1, 1), True)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def avg_pool2d_backward(rxplaceholder: T.Buffer((T.int64(3),
T.int64(2), T.int64(6), T.int64(5)), "float32"), rxplaceholder_1:
T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"),
T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)),
"float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2),
T.int64(10), T.int64(10), T.int64(3), T.int64(3)):
+ with T.block("T_pool_grad"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_wh, v_ww =
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, wh, ww])
+ T.reads(rxplaceholder[v_ax0, v_ax1, T.Div((v_ax2 +
T.int64(2)), T.int64(2)) - v_wh, T.Div((v_ax3 + T.int64(1)), T.int64(2)) -
v_ww])
+ T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3])
+ with T.init():
+ T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
+ T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] =
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 <
T.int64(3), T.int64(0), T.Div((v_ax2 - T.int64(3)), T.int64(2)) + T.int64(1))
<= T.Div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh and T.Div((v_ax2 +
T.int64(2)), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4),
T.int64(0), T.Div((v_ax3 - T.int64(4)), T.int64(2)) + T.int64(1)) <=
T.Div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww and T.Div((v_ax [...]
+
+ @R.function
+ def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data:
R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10),
dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.avg_pool2d_backward, (output_grad, data),
out_sinfo=R.Tensor((3, 2, 10, 10), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(AvgPool2DBackward)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_take_backward():
+ # fmt: off
+ @tvm.script.ir_module
+ class TakeBackward:
+ @R.function
+ def main(output_grad: R.Tensor((3, 2, 4), "float32"), x: R.Tensor((3,
4, 5), "float32"), indices: R.Tensor((2,), "int32")):
+ gv = R.grad.take_backward(output_grad, x, indices)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def take_backward(var_rxplaceholder: T.handle, var_rxplaceholder_1:
T.handle, var_rxplaceholder_2: T.handle, out_buf: T.Buffer((T.int64(3),
T.int64(4), T.int64(5)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(3),
T.int64(2), T.int64(4)), offset_factor=1)
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(3),
T.int64(4), T.int64(5)), offset_factor=1)
+ rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2,
(T.int64(2),), "int32", offset_factor=1)
+ with T.block("take_backward"):
+ T.reads(rxplaceholder[T.int64(0):T.int64(3),
T.int64(0):T.int64(2), T.int64(0):T.int64(4)],
rxplaceholder_1[T.int64(0):T.int64(3), T.int64(0):T.int64(4),
T.int64(0):T.int64(5)], rxplaceholder_2[T.int64(0):T.int64(2)])
+ T.writes(out_buf[T.int64(0):T.int64(3), T.int64(0):T.int64(4),
T.int64(0):T.int64(5)])
+ for i in range(T.int64(60)):
+ out_buf[i // T.int64(5) // T.int64(4), i // T.int64(5) %
T.int64(4), i % T.int64(5)] = T.float32(0)
+ out_buf[T.Cast("int64", rxplaceholder_2[T.int64(1)]) //
T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) //
T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) %
T.int64(5)] = out_buf[T.Cast("int64", rxplaceholder_2[T.int64(1)]) //
T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) //
T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(1)]) %
T.int64(5)] + rxplaceholder[T.int64(0), T.int64(0), T.int64(1)]
+ out_buf[T.Cast("int64", rxplaceholder_2[T.int64(0)]) //
T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) //
T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) %
T.int64(5)] = out_buf[T.Cast("int64", rxplaceholder_2[T.int64(0)]) //
T.int64(5) // T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) //
T.int64(5) % T.int64(4), T.Cast("int64", rxplaceholder_2[T.int64(0)]) %
T.int64(5)] + rxplaceholder[T.int64(0), T.int64(0), T.int64(0)]
+
+ @R.function
+ def main(output_grad: R.Tensor((3, 2, 4), dtype="float32"), x:
R.Tensor((3, 4, 5), dtype="float32"), indices: R.Tensor((2,), dtype="int32"))
-> R.Tensor((3, 4, 5), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.take_backward, (output_grad, x, indices),
out_sinfo=R.Tensor((3, 4, 5), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(TakeBackward)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git
a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 682abf2d57..42aa89f8cc 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -17,7 +17,7 @@
import tvm
from tvm.relax.transform import LegalizeOps
-from tvm.script import relax as R, tir as T
+from tvm.script import relax as R, tir as T, ir as I
import tvm.testing
@@ -692,55 +692,52 @@ def test_std():
gv: R.Tensor((), "float32") = R.std(x)
return gv
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
- @R.function
- def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((),
"float32"):
- gv = R.call_tir(Expected.std, (x,), R.Tensor((), dtype="float32"))
- return gv
-
@T.prim_func
def std(rxplaceholder: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "float32"), compute: T.Buffer((), "float32")):
T.func_attr({"tir.noalias": True})
- rxplaceholder_red = T.alloc_buffer([], dtype="float32")
- T_divide = T.alloc_buffer([], dtype="float32")
- T_subtract = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4),
T.int64(5)], dtype="float32")
- T_multiply = T.alloc_buffer([T.int64(2), T.int64(3), T.int64(4),
T.int64(5)], dtype="float32")
- T_multiply_red = T.alloc_buffer([], dtype="float32")
- T_divide_1 = T.alloc_buffer([], dtype="float32")
- for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ # with T.block("root"):
+ rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(1),
T.int64(1), T.int64(1)))
+ T_divide = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1),
T.int64(1)))
+ T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ T_multiply_red = T.alloc_buffer(())
+ T_divide_1 = T.alloc_buffer(())
+ for ax0, ax1, ax2, ax3, k0, k1, k2, k3 in T.grid(T.int64(1),
T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
with T.block("rxplaceholder_red"):
- k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3])
- T.reads(rxplaceholder[k0, k1, k2, k3])
- T.writes(rxplaceholder_red[()])
+ v_ax0, v_ax1, v_ax2, v_ax3, v_k0, v_k1, v_k2, v_k3 =
T.axis.remap("SSSSRRRR", [ax0, ax1, ax2, ax3, k0, k1, k2, k3])
+ T.reads(rxplaceholder[v_k0, v_k1, v_k2, v_k3])
+ T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3])
with T.init():
- rxplaceholder_red[()] = T.float32(0)
- rxplaceholder_red[()] = rxplaceholder_red[()] +
rxplaceholder[k0, k1, k2, k3]
- with T.block("T_divide"):
- vi = T.axis.spatial(1, T.int64(0))
- T.reads(rxplaceholder_red[()])
- T.writes(T_divide[()])
- T_divide[()] = rxplaceholder_red[()] *
T.float32(0.0083333333333333332)
- for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] =
T.float32(0)
+ rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder[v_k0, v_k1, v_k2,
v_k3]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1),
T.int64(1), T.int64(1)):
+ with T.block("T_divide"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_divide[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.0083333333333333332)
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_subtract"):
- ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
- T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide[()])
- T.writes(T_subtract[ax0, ax1, ax2, ax3])
- T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1,
ax2, ax3] - T_divide[()]
- for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_divide[T.int64(0), T.int64(0), T.int64(0), T.int64(0)])
+ T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_divide[T.int64(0), T.int64(0),
T.int64(0), T.int64(0)]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.block("T_multiply"):
- ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
- T.reads(T_subtract[ax0, ax1, ax2, ax3])
- T.writes(T_multiply[ax0, ax1, ax2, ax3])
- T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2,
ax3] * T_subtract[ax0, ax1, ax2, ax3]
- for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0,
v_ax1, v_ax2, v_ax3] * T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]
+ for k0, k1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
with T.block("T_multiply_red"):
- k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3])
- T.reads(T_multiply[k0, k1, k2, k3])
+ v_k0, v_k1, v_k2, v_k3 = T.axis.remap("RRRR", [k0, k1, k2,
k3])
+ T.reads(T_multiply[v_k0, v_k1, v_k2, v_k3])
T.writes(T_multiply_red[()])
with T.init():
T_multiply_red[()] = T.float32(0)
- T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0,
k1, k2, k3]
+ T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0,
v_k1, v_k2, v_k3]
with T.block("T_divide_1"):
vi = T.axis.spatial(1, T.int64(0))
T.reads(T_multiply_red[()])
@@ -751,6 +748,12 @@ def test_std():
T.reads(T_divide_1[()])
T.writes(compute[()])
compute[()] = T.sqrt(T_divide_1[()])
+
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((),
dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.std, (x,), out_sinfo=R.Tensor((),
dtype="float32"))
+ return gv
# fmt: on
mod = LegalizeOps()(Std)
@@ -766,60 +769,54 @@ def test_std_symbolic():
gv: R.Tensor((), "float32") = R.std(x)
return gv
- @tvm.script.ir_module
+ @I.ir_module
class Expected:
- @R.function
- def main(x: R.Tensor(("a", "b", "c", "d"), "float32")) -> R.Tensor((),
"float32"):
- gv = R.call_tir(Expected.std, (x,), R.Tensor((), dtype="float32"))
- return gv
-
@T.prim_func
def std(var_rxplaceholder: T.handle, compute: T.Buffer((), "float32")):
T.func_attr({"tir.noalias": True})
- a = T.int64()
- b = T.int64()
- c = T.int64()
- d = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b, c, d],
dtype="float32")
- rxplaceholder_red = T.alloc_buffer([], dtype="float32")
- T_divide = T.alloc_buffer([], dtype="float32")
- T_subtract = T.alloc_buffer([a, b, c, d], dtype="float32")
- T_multiply = T.alloc_buffer([a, b, c, d], dtype="float32")
- T_multiply_red = T.alloc_buffer([], dtype="float32")
- T_divide_1 = T.alloc_buffer([], dtype="float32")
- for i0, i1, i2, i3 in T.grid(a, b, c, d):
+ a, b, c, d = T.int64(), T.int64(), T.int64(), T.int64()
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (a, b, c, d))
+ # with T.block("root"):
+ rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(1),
T.int64(1), T.int64(1)))
+ T_divide = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(1),
T.int64(1)))
+ T_subtract = T.alloc_buffer((a, b, c, d))
+ T_multiply = T.alloc_buffer((a, b, c, d))
+ T_multiply_red = T.alloc_buffer(())
+ T_divide_1 = T.alloc_buffer(())
+ for ax0, ax1, ax2, ax3, k0, k1, k2, k3 in T.grid(T.int64(1),
T.int64(1), T.int64(1), T.int64(1), a, b, c, d):
with T.block("rxplaceholder_red"):
- k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3])
- T.reads(rxplaceholder[k0, k1, k2, k3])
- T.writes(rxplaceholder_red[()])
+ v_ax0, v_ax1, v_ax2, v_ax3, v_k0, v_k1, v_k2, v_k3 =
T.axis.remap("SSSSRRRR", [ax0, ax1, ax2, ax3, k0, k1, k2, k3])
+ T.reads(rxplaceholder[v_k0, v_k1, v_k2, v_k3])
+ T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3])
with T.init():
- rxplaceholder_red[()] = T.float32(0)
- rxplaceholder_red[()] = rxplaceholder_red[()] +
rxplaceholder[k0, k1, k2, k3]
- with T.block("T_divide"):
- vi = T.axis.spatial(1, T.int64(0))
- T.reads(rxplaceholder_red[()])
- T.writes(T_divide[()])
- T_divide[()] = rxplaceholder_red[()] / T.Cast("float32", a * b
* c * d)
- for i0, i1, i2, i3 in T.grid(a, b, c, d):
+ rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] =
T.float32(0)
+ rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder[v_k0, v_k1, v_k2,
v_k3]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1),
T.int64(1), T.int64(1)):
+ with T.block("T_divide"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_divide[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_divide[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] / T.Cast("float32", a * b * c * d)
+ for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d):
with T.block("T_subtract"):
- ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
- T.reads(rxplaceholder[ax0, ax1, ax2, ax3], T_divide[()])
- T.writes(T_subtract[ax0, ax1, ax2, ax3])
- T_subtract[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1,
ax2, ax3] - T_divide[()]
- for i0, i1, i2, i3 in T.grid(a, b, c, d):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_divide[T.int64(0), T.int64(0), T.int64(0), T.int64(0)])
+ T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_divide[T.int64(0), T.int64(0),
T.int64(0), T.int64(0)]
+ for ax0, ax1, ax2, ax3 in T.grid(a, b, c, d):
with T.block("T_multiply"):
- ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
- T.reads(T_subtract[ax0, ax1, ax2, ax3])
- T.writes(T_multiply[ax0, ax1, ax2, ax3])
- T_multiply[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2,
ax3] * T_subtract[ax0, ax1, ax2, ax3]
- for i0, i1, i2, i3 in T.grid(a, b, c, d):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0,
v_ax1, v_ax2, v_ax3] * T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]
+ for k0, k1, k2, k3 in T.grid(a, b, c, d):
with T.block("T_multiply_red"):
- k0, k1, k2, k3 = T.axis.remap("RRRR", [i0, i1, i2, i3])
- T.reads(T_multiply[k0, k1, k2, k3])
+ v_k0, v_k1, v_k2, v_k3 = T.axis.remap("RRRR", [k0, k1, k2,
k3])
+ T.reads(T_multiply[v_k0, v_k1, v_k2, v_k3])
T.writes(T_multiply_red[()])
with T.init():
T_multiply_red[()] = T.float32(0)
- T_multiply_red[()] = T_multiply_red[()] + T_multiply[k0,
k1, k2, k3]
+ T_multiply_red[()] = T_multiply_red[()] + T_multiply[v_k0,
v_k1, v_k2, v_k3]
with T.block("T_divide_1"):
vi = T.axis.spatial(1, T.int64(0))
T.reads(T_multiply_red[()])
@@ -830,6 +827,16 @@ def test_std_symbolic():
T.reads(T_divide_1[()])
T.writes(compute[()])
compute[()] = T.sqrt(T_divide_1[()])
+
+ @R.function
+ def main(x: R.Tensor(("a", "b", "c", "d"), dtype="float32")) ->
R.Tensor((), dtype="float32"):
+ a = T.int64()
+ b = T.int64()
+ c = T.int64()
+ d = T.int64()
+ cls = Expected
+ gv = R.call_tir(cls.std, (x,), out_sinfo=R.Tensor((),
dtype="float32"))
+ return gv
# fmt: on
mod = LegalizeOps()(Std)
@@ -986,5 +993,77 @@ def test_variance_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_variance_no_keepdims():
+ # fmt: off
+ @tvm.script.ir_module
+ class Variance:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((1, 3, 4,
1), "float32"):
+ gv: R.Tensor((1, 3, 4, 1), "float32") = R.variance(x, [0, 3],
keepdims=False)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def variance(rxplaceholder: T.Buffer((T.int64(2), T.int64(3),
T.int64(4), T.int64(5)), "float32"), T_divide: T.Buffer((T.int64(3),
T.int64(4)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(4), T.int64(1)))
+ T_divide_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4),
T.int64(1)))
+ T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)))
+ T_multiply_red = T.alloc_buffer((T.int64(3), T.int64(4)))
+ for ax0, ax1, ax2, ax3, k0, k3 in T.grid(T.int64(1), T.int64(3),
T.int64(4), T.int64(1), T.int64(2), T.int64(5)):
+ with T.block("rxplaceholder_red"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_k0, v_k3 =
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, k0, k3])
+ T.reads(rxplaceholder[v_k0, v_ax1, v_ax2, v_k3])
+ T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3])
+ with T.init():
+ rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] =
T.float32(0)
+ rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder[v_k0, v_ax1,
v_ax2, v_k3]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3),
T.int64(4), T.int64(1)):
+ with T.block("T_divide"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_divide_1[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_red[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.10000000000000001)
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_subtract"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_divide_1[T.int64(0), v_ax1, v_ax2, T.int64(0)])
+ T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_divide_1[T.int64(0), v_ax1,
v_ax2, T.int64(0)]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0,
v_ax1, v_ax2, v_ax3] * T_subtract[v_ax0, v_ax1, v_ax2, v_ax3]
+ for ax0, ax1, k0, k3 in T.grid(T.int64(3), T.int64(4), T.int64(2),
T.int64(5)):
+ with T.block("T_multiply_red"):
+ v_ax0, v_ax1, v_k0, v_k3 = T.axis.remap("SSRR", [ax0, ax1,
k0, k3])
+ T.reads(T_multiply[v_k0, v_ax0, v_ax1, v_k3])
+ T.writes(T_multiply_red[v_ax0, v_ax1])
+ with T.init():
+ T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+ T_multiply_red[v_ax0, v_ax1] = T_multiply_red[v_ax0,
v_ax1] + T_multiply[v_k0, v_ax0, v_ax1, v_k3]
+ for ax0, ax1 in T.grid(T.int64(3), T.int64(4)):
+ with T.block("T_divide_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(T_multiply_red[v_ax0, v_ax1])
+ T.writes(T_divide[v_ax0, v_ax1])
+ T_divide[v_ax0, v_ax1] = T_multiply_red[v_ax0, v_ax1] *
T.float32(0.10000000000000001)
+
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((1,
3, 4, 1), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.variance, (x,), out_sinfo=R.Tensor((3, 4),
dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Variance)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_grad.py
b/tests/python/relax/test_tvmscript_parser_op_grad.py
new file mode 100644
index 0000000000..9dd1f01fc7
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_grad.py
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script.parser import relax as R
+
+
+def _check(
+ parsed: Union[relax.Function, IRModule],
+ expect: Optional[Union[relax.Function, IRModule]],
+):
+ test = parsed.script(show_meta=True)
+ roundtrip_mod = tvm.script.from_source(test)
+ tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
+ if expect:
+ tvm.ir.assert_structural_equal(parsed, expect)
+
+
+def test_nll_loss_backward():
+ @R.function
+ def foo(
+ output_grad: R.Tensor((3, 10, 10), dtype="float32"),
+ predictions: R.Tensor((3, 5, 10, 10), dtype="float32"),
+ targets: R.Tensor((3, 10, 10), dtype="int64"),
+ weights: R.Tensor((5,), dtype="float32"),
+ ) -> R.Tensor((3, 5, 10, 10), dtype="float32"):
+ gv: R.Tensor((3, 5, 10, 10), dtype="float32") =
R.grad.nll_loss_backward(
+ output_grad, predictions, targets, weights, "mean", -1
+ )
+ return gv
+
+ output_grad = relax.Var("output_grad", R.Tensor((3, 10, 10), "float32"))
+ predictions = relax.Var("predictions", R.Tensor((3, 5, 10, 10), "float32"))
+ targets = relax.Var("targets", R.Tensor((3, 10, 10), "int64"))
+ weights = relax.Var("weights", R.Tensor((5,), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [output_grad, predictions, targets, weights]):
+ gv = bb.emit(
+ relax.op.grad.nll_loss_backward(
+ output_grad, predictions, targets, weights, reduction="mean",
ignore_index=-1
+ )
+ )
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
+def test_nll_loss_backward_no_weights():
+ @R.function
+ def foo(
+ output_grad: R.Tensor((3, 10, 10), dtype="float32"),
+ predictions: R.Tensor((3, 5, 10, 10), dtype="float32"),
+ targets: R.Tensor((3, 10, 10), dtype="int64"),
+ ) -> R.Tensor((3, 5, 10, 10), dtype="float32"):
+ gv: R.Tensor((3, 5, 10, 10), dtype="float32") =
R.grad.nll_loss_backward(
+ output_grad, predictions, targets, reduction="mean",
ignore_index=-1
+ )
+ return gv
+
+ output_grad = relax.Var("output_grad", R.Tensor((3, 10, 10), "float32"))
+ predictions = relax.Var("predictions", R.Tensor((3, 5, 10, 10), "float32"))
+ targets = relax.Var("targets", R.Tensor((3, 10, 10), "int64"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [output_grad, predictions, targets]):
+ gv = bb.emit(
+ relax.op.grad.nll_loss_backward(
+ output_grad, predictions, targets, reduction="mean",
ignore_index=-1
+ )
+ )
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
+def test_max_pool2d_backward():
+ @R.function
+ def foo(
+ output_grad: R.Tensor((3, 2, 6, 5), "float32"), data: R.Tensor((3, 2,
10, 10), "float32")
+ ):
+ gv = R.grad.max_pool2d_backward(
+ output_grad, data, (5, 5), (2, 2), (2, 1, 2, 1), (1, 1), True
+ )
+ return gv
+
+ output_grad = relax.Var("output_grad", R.Tensor((3, 2, 6, 5), "float32"))
+ data = relax.Var("data", R.Tensor((3, 2, 10, 10), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [output_grad, data]):
+ gv = bb.emit(
+ relax.op.grad.max_pool2d_backward(
+ output_grad, data, (5, 5), (2, 2), (2, 1, 2, 1), (1, 1), True
+ )
+ )
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
+def test_avg_pool2d_backward():
+ @R.function
+ def foo(
+ output_grad: R.Tensor((3, 2, 6, 5), "float32"), data: R.Tensor((3, 2,
10, 10), "float32")
+ ):
+ gv = R.grad.avg_pool2d_backward(
+ output_grad, data, (5, 5), (2, 2), (2, 1, 2, 1), (1, 1), True
+ )
+ return gv
+
+ output_grad = relax.Var("output_grad", R.Tensor((3, 2, 6, 5), "float32"))
+ data = relax.Var("data", R.Tensor((3, 2, 10, 10), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [output_grad, data]):
+ gv = bb.emit(
+ relax.op.grad.avg_pool2d_backward(
+ output_grad, data, (5, 5), (2, 2), (2, 1, 2, 1), (1, 1), True
+ )
+ )
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
+if __name__ == "__main__":
+ tvm.testing.main()