This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ec89242fbd [Unity] Update docs for operators (#14659)
ec89242fbd is described below
commit ec89242fbd074acea7f77925eec49cdc21bbcb79
Author: Yixin Dong <[email protected]>
AuthorDate: Thu Apr 20 05:17:24 2023 +0800
[Unity] Update docs for operators (#14659)
This PR:
updates docs for several operators,
adds a default value of 1 for the axis parameter of batch_norm,
renames the module tvm.relax.transform.legalize_ops.creation to
tvm.relax.transform.legalize_ops.create, which aligns with
tvm.relax.op.create, and fixes a previous upstream error in
tests/python/relax/test_transform_legalize_ops_grad.py.
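For illustration, a minimal sketch of the new batch_norm default (relies on
the axis=1 default added here; the variable names and shapes are placeholders):

    from tvm import relax

    # Placeholder NCHW input and per-channel parameters (4 channels).
    data = relax.Var("data", relax.TensorStructInfo([2, 4, 8, 8], "float32"))
    gamma = relax.Var("gamma", relax.TensorStructInfo([4], "float32"))
    beta = relax.Var("beta", relax.TensorStructInfo([4], "float32"))
    moving_mean = relax.Var("moving_mean", relax.TensorStructInfo([4], "float32"))
    moving_var = relax.Var("moving_var", relax.TensorStructInfo([4], "float32"))

    # axis now defaults to 1 (the channel axis for NCHW), so it can be omitted.
    call = relax.op.nn.batch_norm(data, gamma, beta, moving_mean, moving_var)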
---
python/tvm/relax/op/_op_gradient.py | 6 ++--
python/tvm/relax/op/grad/grad.py | 12 ++++----
python/tvm/relax/op/nn/nn.py | 19 +++++++++++--
.../tvm/relax/transform/legalize_ops/__init__.py | 2 +-
.../legalize_ops/{creation.py => create.py} | 0
.../relax/transform/legalize_ops/statistical.py | 4 +++
.../relax/test_transform_legalize_ops_grad.py | 32 ++++++++++------------
7 files changed, 44 insertions(+), 31 deletions(-)
diff --git a/python/tvm/relax/op/_op_gradient.py
b/python/tvm/relax/op/_op_gradient.py
index 9de7370545..b0e37a9418 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -573,7 +573,7 @@ def mean_grad(
Backward:
Returns `[broadcast_to(y_output_grad, x.shape) / prod(x.shape[i] for i
in axis)]`.
- If `keepdims=False`, the meaned axis will be added back.
+ If `keepdims=False`, the reduced axes will be added back.
"""
axis = orig_call.attrs.axis
keepdims = orig_call.attrs.keepdims
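As a concrete check of the rule above, a NumPy sketch (illustrative only,
not the Relax implementation):

    import numpy as np

    x = np.random.rand(3, 4).astype("float32")
    axis = (1,)
    y_grad = np.ones((3,), dtype="float32")  # gradient flowing into mean(x, axis=1)

    # keepdims=False: re-insert the reduced axis, then broadcast and divide
    # by the number of reduced elements.
    expanded = np.expand_dims(y_grad, axis=1)
    x_grad = np.broadcast_to(expanded, x.shape) / np.prod([x.shape[i] for i in axis])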
@@ -749,7 +749,7 @@ def cumsum_grad(
`y = relax.cumsum(x, axis)`
Backward:
- The "reversed" cumsum along the same axis. Implement by some tricks
now.
+ The "reversed" cumsum along the same axis. Implemented by some tricks
now.
"""
axis = orig_call.attrs["axis"]
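One such trick, sketched in NumPy: a reversed cumsum can be built from flip,
forward cumsum, and flip back (illustrative, not the exact Relax lowering):

    import numpy as np

    y_grad = np.array([1.0, 2.0, 3.0, 4.0], dtype="float32")
    # Element i accumulates the output gradients at all positions >= i.
    x_grad = np.flip(np.cumsum(np.flip(y_grad)))  # -> [10., 9., 7., 4.]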
@@ -786,7 +786,7 @@ def take_grad(
`y = relax.take(x, indices, axis)`
Backward:
- Returns .
+ Returns [x_grad, no_grad].
The second parameter, the indices, is not differentiable.
"""
diff --git a/python/tvm/relax/op/grad/grad.py b/python/tvm/relax/op/grad/grad.py
index b433dc9c60..e1f1591876 100644
--- a/python/tvm/relax/op/grad/grad.py
+++ b/python/tvm/relax/op/grad/grad.py
@@ -51,8 +51,8 @@ def nll_loss_backward(
reduction: str = "mean",
ignore_index: int = -100,
) -> Expr:
- """Backward operator of relax.nll_loss. All parameters except output_grad
is the same as
- relax.nll_loss. Returns the gradient w.r.t. predictions.
+ """Backward operator of relax.nn.nll_loss. All parameters except
output_grad is the same as
+ relax.nn.nll_loss. Returns the gradient w.r.t. predictions.
Parameters
----------
@@ -80,8 +80,8 @@ def max_pool2d_backward(
layout: str = "NCHW",
out_layout: Optional[str] = None,
) -> Expr:
- """Backward operator of relax.max_pool2d. All parameters except
output_grad is the same as
- relax.max_pool2d. Returns the gradient w.r.t. data.
+ """Backward operator of relax.nn.max_pool2d. All parameters except
output_grad is the same as
+ relax.nn.max_pool2d. Returns the gradient w.r.t. data.
Parameters
----------
@@ -109,8 +109,8 @@ def avg_pool2d_backward(
layout: str = "NCHW",
out_layout: Optional[str] = None,
) -> Expr:
- """Backward operator of relax.avg_pool2d. All parameters except
output_grad is the same as
- relax.avg_pool2d. Returns the gradient w.r.t. data.
+ """Backward operator of relax.nn.avg_pool2d. All parameters except
output_grad is the same as
+ relax.nn.avg_pool2d. Returns the gradient w.r.t. data.
Parameters
----------
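A minimal usage sketch for one of these backward operators (the pool_size,
strides, and shapes below are assumptions for illustration):

    from tvm import relax

    output_grad = relax.Var("output_grad", relax.TensorStructInfo([1, 3, 16, 16], "float32"))
    data = relax.Var("data", relax.TensorStructInfo([1, 3, 32, 32], "float32"))

    # The non-gradient arguments mirror the forward relax.op.nn.max_pool2d call.
    grad = relax.op.grad.max_pool2d_backward(
        output_grad, data, pool_size=(2, 2), strides=(2, 2)
    )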
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 083bca653a..5483e7c5ee 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -667,6 +667,7 @@ def batch_norm(
) -> Expr:
r"""
Batch normalization layer (Ioffe and Szegedy, 2014).
+
Normalizes the input at each batch, i.e. applies a transformation
that maintains the mean activation close to 0 and the activation
standard deviation close to 1.
@@ -676,6 +677,8 @@ def batch_norm(
data\_mean[i] = mean(data[:,i,:,...]) \\
data\_var[i] = var(data[:,i,:,...])
+ Both *mean* and *var* return a scalar by treating the input as a vector.
+
Then compute the normalized output, which has the same shape as input, as
follows:
.. math::
@@ -683,8 +686,6 @@ def batch_norm(
out[:,i,:,...] = \frac{data[:,i,:,...] -
data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}}
* gamma[i] + beta[i]
- Both *mean* and *var* returns a scalar by treating the input as a vector.
-
Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
have shape *(k,)*.
@@ -703,7 +704,19 @@ def batch_norm(
.. note::
- This operator can be optimized away for inference.
+ This operator has two modes:
+ - Training mode.
+ - Use the mean and var computed from THIS batch to normalize.
+ - Update and then return the running mean and running var.
+ - Inference mode.
+ - Use the running_mean and running_var parameters to normalize.
+ - Do not update the running mean and running var; just return them
unchanged.
+
+ In the legalization stage, this operator is legalized to the training
mode computation by default.
+
+ You can use tvm.relax.transform.DecomposeOpsForInference to decompose
the operator so that it
+ executes the inference mode computation. Similarly, use
+ tvm.relax.transform.DecomposeOpsForTraining to execute the training
mode computation.
Parameters
----------
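A sketch of the decomposition workflow described in the note above
(assuming `mod` is an IRModule containing relax.op.nn.batch_norm):

    import tvm
    from tvm import relax

    # For inference, decompose batch_norm before legalization; otherwise
    # LegalizeOps falls back to the training mode computation by default.
    seq = tvm.transform.Sequential(
        [
            relax.transform.DecomposeOpsForInference(),
            relax.transform.LegalizeOps(),
        ]
    )
    # mod = seq(mod)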
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py
b/python/tvm/relax/transform/legalize_ops/__init__.py
index 8b668e5040..613bd8970f 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/transform/legalize_ops/__init__.py
@@ -16,7 +16,7 @@
# under the License.
"""Legalize high-level operator calls in Relax functions to call_tir."""
from . import binary
-from . import creation
+from . import create
from . import datatype
from . import grad
from . import image
diff --git a/python/tvm/relax/transform/legalize_ops/creation.py
b/python/tvm/relax/transform/legalize_ops/create.py
similarity index 100%
rename from python/tvm/relax/transform/legalize_ops/creation.py
rename to python/tvm/relax/transform/legalize_ops/create.py
diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py
b/python/tvm/relax/transform/legalize_ops/statistical.py
index 71cf1ef808..e1f273bda0 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -47,6 +47,10 @@ def _te_mean(x: te.Tensor, axis: List[tir.IntImm], keepdims:
bool) -> te.Tensor:
def _te_variance(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) ->
te.Tensor:
dev = x - _te_mean(x, axis, True)
return _te_mean(dev * dev, axis, keepdims)
+ # The following one-pass version has better memory locality and performance,
+ # but may cause precision problems, so we use the version above for now:
+ # mean = _te_mean(x, axis, keepdims)
+ # return _te_mean(x * x, axis, keepdims) - mean * mean
@register_legalize("relax.mean")
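The precision concern is easy to reproduce in plain NumPy: with a large mean
offset, the one-pass form E[x*x] - mean*mean cancels catastrophically in
float32 (an illustrative sketch):

    import numpy as np

    x = (1e4 + np.random.rand(10000)).astype("float32")
    mean = x.mean(dtype="float32")
    var_two_pass = ((x - mean) ** 2).mean(dtype="float32")     # accurate, ~1/12
    var_one_pass = (x * x).mean(dtype="float32") - mean * mean  # cancellation; may even be negative
    print(var_two_pass, var_one_pass)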
diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py
b/tests/python/relax/test_transform_legalize_ops_grad.py
index e8f75d83a9..67d0b9194b 100644
--- a/tests/python/relax/test_transform_legalize_ops_grad.py
+++ b/tests/python/relax/test_transform_legalize_ops_grad.py
@@ -14,13 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import pytest
-
import tvm
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R, tir as T, ir as I
import tvm.testing
-from tvm.tir.op import div
def test_nll_loss_backward():
@@ -207,7 +204,6 @@ def test_nll_loss_backward_no_batch():
tvm.ir.assert_structural_equal(mod, Expected)
[email protected]("Regression to be fixed in the generated after merge.")
def test_max_pool2d_backward():
# fmt: off
@tvm.script.ir_module
@@ -219,15 +215,9 @@ def test_max_pool2d_backward():
@I.ir_module
class Expected:
- @R.function
- def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data:
R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10),
dtype="float32"):
- cls = Expected
- gv = R.call_tir(cls.max_pool2d_backward, (output_grad, data),
out_sinfo=R.Tensor((3, 2, 10, 10), dtype="float32"))
- return gv
-
@T.prim_func
- def max_pool2d_backward(rxplaceholder: T.Buffer((T.int64(3),
T.int64(2), T.int64(6), T.int64(5)), "float32"), rxplaceholder_1:
T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"),
T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)),
"float32")):
- T.func_attr({"tir.noalias": True})
+ def max_pool2d_backward(A: T.Buffer((T.int64(3), T.int64(2),
T.int64(6), T.int64(5)), "float32"), B: T.Buffer((T.int64(3), T.int64(2),
T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3),
T.int64(2), T.int64(10), T.int64(10)), "float32")):
+ T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
pad_temp = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(15),
T.int64(13)))
maxpool_grad_argmax_v0 = T.alloc_buffer((T.int64(3), T.int64(2),
T.int64(6), T.int64(5)), "int64")
@@ -235,29 +225,35 @@ def test_max_pool2d_backward():
for ax0, ax1, ax2, ax3 in T.grid(T.int64(3), T.int64(2),
T.int64(15), T.int64(13)):
with T.block("pad_temp"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2 - T.int64(2),
v_ax3 - T.int64(1)])
+ T.reads(B[v_ax0, v_ax1, v_ax2 - T.int64(2), v_ax3 -
T.int64(1)])
T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3])
- pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] =
T.if_then_else(T.int64(2) <= v_ax2 and v_ax2 < T.int64(12) and T.int64(1) <=
v_ax3 and v_ax3 < T.int64(11), rxplaceholder_1[v_ax0, v_ax1, v_ax2 -
T.int64(2), v_ax3 - T.int64(1)], T.float32(-3.4028234663852886e+38))
+ pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] =
T.if_then_else(T.int64(2) <= v_ax2 and v_ax2 < T.int64(12) and T.int64(1) <=
v_ax3 and v_ax3 < T.int64(11), B[v_ax0, v_ax1, v_ax2 - T.int64(2), v_ax3 -
T.int64(1)], T.float32(-3.4028234663852886e+38))
for ax0, ax1, ax2, ax3, dh, dw in T.grid(T.int64(3), T.int64(2),
T.int64(6), T.int64(5), T.int64(5), T.int64(5)):
with T.block("maxpool_grad_argmax"):
v_ax0, v_ax1, v_ax2, v_ax3, v_dh, v_dw =
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, dh, dw])
T.reads(pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh,
v_ax3 * T.int64(2) + v_dw])
T.writes(maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2,
v_ax3], maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3])
with T.init():
- maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] = -1
+ maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] =
T.int64(-1)
maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] =
T.float32(-3.4028234663852886e+38)
- v_maxpool_grad_argmax_v0: T.int64 =
T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0,
v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] or
maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] == pad_temp[v_ax0, v_ax1,
v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] and
maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] < v_ax0 * T.int64(390) +
v_ax1 * T.int64(195) + v_ax2 * T.int64(26) + v_dh * T.int64(13) + v_ax3 *
T.int64( [...]
+ v_maxpool_grad_argmax_v0: T.int64 =
T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0,
v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] or
maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] == pad_temp[v_ax0, v_ax1,
v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] and
maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] < v_ax0 * T.int64(390) +
v_ax1 * T.int64(195) + v_ax2 * T.int64(26) + v_dh * T.int64(13) + v_ax3 *
T.int64(2 [...]
v_maxpool_grad_argmax_v1: T.float32 =
T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0,
v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw],
maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3], pad_temp[v_ax0, v_ax1,
v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw])
maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] =
v_maxpool_grad_argmax_v0
maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] =
v_maxpool_grad_argmax_v1
for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2),
T.int64(10), T.int64(10), T.int64(3), T.int64(3)):
with T.block("T_pool_grad"):
v_ax0, v_ax1, v_ax2, v_ax3, v_wh, v_ww =
T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, wh, ww])
- T.reads(maxpool_grad_argmax_v0[v_ax0, v_ax1, div((v_ax2 +
T.int64(2)), T.int64(2)) - v_wh, div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww],
rxplaceholder[v_ax0, v_ax1, div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh,
div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww])
+ T.reads(maxpool_grad_argmax_v0[v_ax0, v_ax1, T.Div(v_ax2 +
T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww],
A[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 +
T.int64(1), T.int64(2)) - v_ww])
T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3])
with T.init():
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
- T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] =
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 <
T.int64(3), T.int64(0), div((v_ax2 - T.int64(3)), T.int64(2)) + T.int64(1)) <=
div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh and T.Select(v_ax3 < T.int64(4),
T.int64(0), div((v_ax3 - T.int64(4)), T.int64(2)) + T.int64(1)) <= div((v_ax3 +
T.int64(1)), T.int64(2)) - v_ww and T.Cast("int64",
maxpool_grad_argmax_v0[v_ax0, v_ax1, div((v_ax2 + T.int64(2)), T.in [...]
+ T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] =
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 <
T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <=
T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Select(v_ax3 < T.int64(4),
T.int64(0), T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3
+ T.int64(1), T.int64(2)) - v_ww and maxpool_grad_argmax_v0[v_ax0, v_ax1,
T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, [...]
+
+ @R.function
+ def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data:
R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10),
dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(cls.max_pool2d_backward, (output_grad, data),
out_sinfo=R.Tensor((3, 2, 10, 10), dtype="float32"))
+ return gv
# fmt: on
mod = LegalizeOps()(MaxPool2DBackward)
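These legalization tests follow one pattern; a sketch of the check (the
helper name here is hypothetical, the assertion is the standard one used
elsewhere in this file):

    import tvm
    from tvm.relax.transform import LegalizeOps

    def check_legalize(before: tvm.IRModule, expected: tvm.IRModule) -> None:
        # Legalize high-level ops to call_tir, then compare structurally.
        after = LegalizeOps()(before)
        tvm.ir.assert_structural_equal(after, expected)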