This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 51d4b6bcde [Relax] Batch norm correctness on eval mode (#17752)
51d4b6bcde is described below
commit 51d4b6bcdeef9094e721e9154ae096370a05a465
Author: Hugo Latendresse <[email protected]>
AuthorDate: Tue Mar 25 20:57:11 2025 -0400
[Relax] Batch norm correctness on eval mode (#17752)
Batch_norm is a different operator in training and eval.
The previous interface defaulted to the training mode and required
changing an ingested pytorch program itself to use the eval mode.
This is suboptimal, especially since torch.export explicitly communicates
whether batch_norm should be in training or eval in a given torch program.
This PR automates the selection of training/eval mode in the exported
program translator, and achieves correctness for eval mode.
Future TODO: there is something wrong with batch_norm on training
mode. It does not pass a correctness test when taken straight from
the main branch (there's an issue with tensor dimensions).
I added a note to address later as training mode is probably not high
priority.
---
include/tvm/relax/attrs/nn.h | 2 +
.../frontend/torch/base_fx_graph_translator.py | 2 +-
.../frontend/torch/exported_program_translator.py | 43 +-
python/tvm/relax/op/nn/nn.py | 8 +-
python/tvm/relax/transform/legalize_ops/nn.py | 4 +-
python/tvm/topi/nn/batch_norm.py | 26 +-
src/relax/op/nn/nn.cc | 8 +-
src/relax/op/nn/nn.h | 2 +-
tests/python/relax/test_from_exported_to_cuda.py | 59 +-
.../python/relax/test_transform_legalize_ops_nn.py | 980 ++++++++++++---------
10 files changed, 683 insertions(+), 451 deletions(-)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 8329344174..8f63012e09 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -462,6 +462,7 @@ struct BatchNormAttrs : public
tvm::AttrsNode<BatchNormAttrs> {
bool center;
bool scale;
double momentum;
+ bool training;
TVM_DECLARE_ATTRS(BatchNormAttrs, "relax.attrs.BatchNormAttrs") {
TVM_ATTR_FIELD(axis).describe("The axis along which the normalization is
applied.");
@@ -470,6 +471,7 @@ struct BatchNormAttrs : public
tvm::AttrsNode<BatchNormAttrs> {
"Indicating if the beta offset will be added to the normalized
tensor.");
TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be
multiplied.");
TVM_ATTR_FIELD(momentum).describe("The value used for the moving_mean and
moving_var update.");
+ TVM_ATTR_FIELD(training).describe("Whether we are training (i.e., not in
eval mode).");
}
}; // struct BatchNormAttrs
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4bfdb8c1bc..ecd8665b43 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1106,7 +1106,7 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
return self.env[node.args[0]]
def _copy_(self, node: fx.Node) -> relax.Var:
- # Copies the source tensor's to the destination tensor
+ # Copies the source tensor's into the destination tensor
# In TVM, that means simply returning the source tensor
return self.env[node.args[1]]
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index bf01bd6531..84821d27b5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -45,7 +45,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
########## Neural Network ##########
- def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
+ def _batch_norm(self, node: fx.Node, training) -> relax.Var:
import numpy as np
x = self.env[node.args[0]]
@@ -55,22 +55,43 @@ class ExportedProgramImporter(BaseFXGraphImporter):
bias = self.env.get(node.args[2], relax.const(np.zeros(channel),
dtype=dtype))
running_mean = self.env.get(node.args[3],
relax.const(np.zeros(channel), dtype=dtype))
running_var = self.env.get(node.args[4], relax.const(np.ones(channel),
dtype=dtype))
- momentum = node.args[5] if len(node.args) > 5 else
node.kwargs.get("momentum", 0.1)
- eps = node.args[6] if len(node.args) > 6 else node.kwargs.get("eps",
1e-05)
+ ignore_running_stats = (
+ node.args[5] if len(node.args) > 5 else
node.kwargs.get("track_running_stats", True)
+ )
+ track_running_stats = not ignore_running_stats
+ momentum = node.args[6] if len(node.args) > 6 else
node.kwargs.get("momentum", 0.1)
+ eps = node.args[7] if len(node.args) > 7 else node.kwargs.get("eps",
1e-05)
+
+ if track_running_stats:
+ training = True
return self.block_builder.emit(
relax.op.nn.batch_norm(
- x,
- weight,
- bias,
- running_mean,
- running_var,
- axis=1,
+ data=x,
+ gamma=weight,
+ beta=bias,
+ moving_mean=running_mean,
+ moving_var=running_var,
+ axis=1, # Always over channel
epsilon=eps,
momentum=momentum,
- )
+ training=training,
+ )[0]
)
+ def _batch_norm_legit_functional(self, node: fx.Node) -> relax.Var:
+ # This method is called for batch_norm in training mode
+ # TODO does not have correctness!
+ # TODO we need to store the running mean and variance returned by the
+ # previous call to batch_norm and pass it again
+ training = True
+ return self._batch_norm(node, training)
+
+ def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
+ # This method is called for batch_norm in eval mode
+ training = False
+ return self._batch_norm(node, training)
+
def _group_norm(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
num_groups = node.args[1]
@@ -283,7 +304,9 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# linear algebra
"linalg_vector_norm.default": self._linalg_vector_norm,
# neural network
+ "_native_batch_norm_legit_functional.default":
self._batch_norm_legit_functional,
"_native_batch_norm_legit_no_training.default":
self._batch_norm_legit_no_training,
+ "batch_norm.default": self._batch_norm_legit_no_training,
"adaptive_avg_pool2d.default": self._adaptive_avg_pool2d,
"addmm.default": self._addmm,
"avg_pool2d.default": self._avg_pool2d,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 5a1895cbc1..09a7df5149 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1393,6 +1393,7 @@ def batch_norm(
center: bool = True,
scale: bool = True,
momentum: float = 0.1,
+ training: bool = True,
) -> Expr:
r"""
Batch normalization layer (Ioffe and Szegedy, 2014).
@@ -1481,13 +1482,18 @@ def batch_norm(
momentum : float
The value used for the moving_mean and moving_var update.
+ training : bool
+ A boolean value to indicate whether training or in eval mode. By
default.
+ relax batch_norm is training mode. To transform it to inference mode,
+ can use DecomposeOpsForInference.
+
Returns
-------
result : relax.Expr
The computed result.
"""
return _ffi_api.batch_norm( # type: ignore
- data, gamma, beta, moving_mean, moving_var, axis, epsilon, center,
scale, momentum
+ data, gamma, beta, moving_mean, moving_var, axis, epsilon, center,
scale, momentum, training
)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py
b/python/tvm/relax/transform/legalize_ops/nn.py
index d9fb4701f7..4c8bdbc661 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -551,9 +551,7 @@ def _nn_batch_norm(bb: BlockBuilder, call: Call) -> Expr:
epsilon=call.attrs.epsilon,
center=call.attrs.center,
scale=call.attrs.scale,
- # By default relax batch_norm is training mode.
- # To transform it to inference mode, use DecomposeOpsForInference.
- training=True,
+ training=call.attrs.training,
momentum=call.attrs.momentum,
)
diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py
index 3181efd7da..8308c93eae 100644
--- a/python/tvm/topi/nn/batch_norm.py
+++ b/python/tvm/topi/nn/batch_norm.py
@@ -111,22 +111,26 @@ def batch_norm(
shape = [1] * len(data.shape)
shape[axis] = data.shape[axis]
+ reduce_axes = list(range(len(data.shape)))
+ reduce_axes.remove(axis)
+ shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in
reduce_axes], 1)
+
+ data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
+ data_mean_rs = topi.reshape(data_mean, shape)
+ data_var = (
+ topi.sum((data - data_mean_rs) * (data - data_mean_rs),
axis=reduce_axes) / shape_prod
+ )
+ data_var_rs = topi.reshape(data_var, shape)
+
if training:
- reduce_axes = list(range(len(data.shape)))
- reduce_axes.remove(axis)
- shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in
reduce_axes], 1)
- data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
- data_mean_rs = topi.reshape(data_mean, shape)
- data_var = (
- topi.sum((data - data_mean_rs) * (data - data_mean_rs),
axis=reduce_axes) / shape_prod
- )
- data_var_rs = topi.reshape(data_var, shape)
- out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
- else:
moving_mean_rs = topi.reshape(moving_mean, shape)
moving_var_rs = topi.reshape(moving_var, shape)
+
out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
+ else:
+ out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
+
if scale:
out = out * topi.reshape(gamma, shape)
if center:
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index b4668d65d3..826711538c 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -252,13 +252,14 @@ bool NormCheckDtypeAndShape(const Call& call, const
BlockBuilder& ctx,
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr
moving_var, //
- int axis, double epsilon, bool center, bool scale, double
momentum) {
+ int axis, double epsilon, bool center, bool scale, double
momentum, bool training) {
ObjectPtr<BatchNormAttrs> attrs = make_object<BatchNormAttrs>();
attrs->axis = axis;
attrs->epsilon = epsilon;
attrs->center = center;
attrs->scale = scale;
attrs->momentum = momentum;
+ attrs->training = training;
static const Op& op = Op::Get("relax.nn.batch_norm");
return Call(op,
@@ -266,7 +267,6 @@ Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr
moving_mean, Expr moving_
std::move(moving_var)},
Attrs{attrs}, {});
}
-
TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm);
StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx)
{
@@ -388,7 +388,7 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call,
TVM_REGISTER_OP("relax.nn.layer_norm")
.set_attrs_type<LayerNormAttrs>()
.set_num_inputs(3)
- .add_argument("data", "Tensor", "Input to which batch_norm will be
applied.")
+ .add_argument("data", "Tensor", "Input to which layer_norm will be
applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm)
@@ -500,7 +500,7 @@ InferLayoutOutput InferLayoutGroupNorm(const Call& call,
TVM_REGISTER_OP("relax.nn.group_norm")
.set_attrs_type<GroupNormAttrs>()
.set_num_inputs(3)
- .add_argument("data", "Tensor", "Input to which batch_norm will be
applied.")
+ .add_argument("data", "Tensor", "Input to which group_norm will be
applied.")
.add_argument("gamma", "Tensor", "The gamma scale factor.")
.add_argument("beta", "Tensor", "The beta offset factor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoGroupNorm)
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index a3658fed54..28c14139b9 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -68,7 +68,7 @@ Expr log_softmax(Expr data, int axis);
/*! \brief Compute batch normalization. */
Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr
moving_var, //
- int axis, double epsilon, bool center, bool scale, double
momentum);
+ int axis, double epsilon, bool center, bool scale, double
momentum, bool training);
/*! \brief Compute layer normalization. */
Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double
epsilon, bool center,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py
b/tests/python/relax/test_from_exported_to_cuda.py
index c120eb8981..f7501dd3b5 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -289,6 +289,25 @@ def test_linalg_vector_norm(target, dev):
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3,
target, dev)
[email protected]_targets("cuda")
+def test_batch_norm_prog(target, dev):
+ # Default args, in a pytorch program (to ensure output is in proper type
and format)
+ raw_data = np.random.randn(2, 3, 2, 2).astype(np.float32)
+
+ class BatchNormWrapper(nn.Module):
+ def __init__(self):
+ super(BatchNormWrapper, self).__init__()
+ self.bn = nn.BatchNorm2d(3)
+
+ def forward(self, x):
+ x = self.bn(x)
+ x = x + 1
+ return x
+
+ torch_module = BatchNormWrapper().eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+
@tvm.testing.parametrize_targets("cuda")
def test_split_size(target, dev):
# Test split using the split_size argument such that it is not a divisor
@@ -310,7 +329,46 @@ def test_split_size(target, dev):
return torch.split(x, split_size_or_sections=self.split_size,
dim=self.dim)
torch_module = SplitModelSplitSize(split_size=split_size, dim=dim).eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
[email protected]_targets("cuda")
+def test_batch_norm0(target, dev):
+ # Eval, no momentum, no affine, no running stats
+ raw_data = np.random.randn(8, 3, 4, 4).astype(np.float32)
+ torch_module = nn.BatchNorm2d(
+ 3, eps=1e-02, momentum=0.0, affine=False, track_running_stats=False,
device=None, dtype=None
+ ).eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+
[email protected]_targets("cuda")
+def test_batch_norm1(target, dev):
+ # Eval, with momentum, no affine, with running stats
+ raw_data = np.random.randn(1, 4, 2, 2).astype(np.float32)
+ torch_module = nn.BatchNorm2d(
+ 4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True,
device=None, dtype=None
+ ).eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+
[email protected]_targets("cuda")
+def test_batch_norm2(target, dev):
+ # Eval, with momentum, affine, no running stats
+ raw_data = np.random.randn(3, 4, 2, 2).astype(np.float32)
+ torch_module = nn.BatchNorm2d(
+ 4, eps=1e-05, momentum=0.2, affine=True, track_running_stats=False
+ ).eval()
+ assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
+
+
[email protected]_targets("cuda")
+def test_batch_norm3(target, dev):
+ # Eval, no momentum, affine, with running stats
+ raw_data = np.random.randn(1, 2, 2, 2).astype(np.float32)
+ torch_module = nn.BatchNorm2d(
+ 2, eps=1e-05, momentum=0.0, affine=True, track_running_stats=True
+ ).eval()
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
@@ -335,7 +393,6 @@ def test_split_sections_list(target, dev):
return torch.split(x, split_size_or_sections=self.split_size,
dim=self.dim)
torch_module = SplitModelSectionsList(split_size=sections, dim=dim).eval()
-
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module,
target, dev)
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index d83d0567e4..4ac4b57b91 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1955,212 +1955,289 @@ def test_batch_norm():
@tvm.script.ir_module
class Expected:
@T.prim_func(private=True)
- def batch_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3),),
"float32"), rxplaceholder_2: T.Buffer((T.int64(3),), "float32"),
rxplaceholder_3: T.Buffer((T.int64(3),), "float32"), rxplaceholder_4:
T.Buffer((T.int64(3),), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)), "float32"), T_add_1: T.Buffer((T.int64(3),),
"float32"), T_add_2: T.Buffer((T.int64(3),), "float32")):
- T.func_attr({"tir.noalias": True})
- # with T.block("root"):
- rxplaceholder_red = T.alloc_buffer((T.int64(3),))
- T_divide = T.alloc_buffer((T.int64(3),))
- T_reshape = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1),
T.int64(1)))
- T_subtract = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28),
T.int64(28)))
- T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
- T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
- T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28),
T.int64(28)))
- T_multiply_red = T.alloc_buffer((T.int64(3),))
- T_divide_1 = T.alloc_buffer((T.int64(3),))
- T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1),
T.int64(1)))
- T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1),
T.int64(1)))
- compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1),
T.int64(1)))
- T_divide_2 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(28),
T.int64(28)))
- T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1),
T.int64(1)))
- T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
- T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1),
T.int64(1)))
- T_multiply_2 = T.alloc_buffer((T.int64(3),))
- T_multiply_3 = T.alloc_buffer((T.int64(3),))
- T_multiply_4 = T.alloc_buffer((T.int64(3),))
- T_subtract_3 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
- T_subtract_4 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
- T_multiply_5 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
- T_multiply_red_1 = T.alloc_buffer((T.int64(3),))
- T_divide_3 = T.alloc_buffer((T.int64(3),))
- T_multiply_6 = T.alloc_buffer((T.int64(3),))
- for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28),
T.int64(28)):
- with T.block("rxplaceholder_red"):
- v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0,
k2, k3])
- T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3])
- T.writes(rxplaceholder_red[v_ax0])
- with T.init():
- rxplaceholder_red[v_ax0] = T.float32(0)
- rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] +
rxplaceholder[v_k0, v_ax0, v_k2, v_k3]
- for ax0 in range(T.int64(3)):
- with T.block("T_divide"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(rxplaceholder_red[v_ax0])
- T.writes(T_divide[v_ax0])
- T_divide[v_ax0] = rxplaceholder_red[v_ax0] *
T.float32(0.00063775510204081628)
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3),
T.int64(1), T.int64(1)):
- with T.block("T_reshape"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)])
- T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
- T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax1 +
v_ax2 + v_ax3) % T.int64(3)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(28), T.int64(28)):
- with T.block("T_subtract"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
- T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(28), T.int64(28)):
- with T.block("T_subtract_1"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3])
- T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(28), T.int64(28)):
- with T.block("T_subtract_2"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
- T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(28), T.int64(28)):
- with T.block("T_multiply"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3],
T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
- T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
- T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3]
- for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28),
T.int64(28)):
- with T.block("T_multiply_red"):
- v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0,
k2, k3])
- T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3])
- T.writes(T_multiply_red[v_ax0])
- with T.init():
- T_multiply_red[v_ax0] = T.float32(0)
- T_multiply_red[v_ax0] = T_multiply_red[v_ax0] +
T_multiply[v_k0, v_ax0, v_k2, v_k3]
- for ax0 in range(T.int64(3)):
- with T.block("T_divide_1"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_multiply_red[v_ax0])
- T.writes(T_divide_1[v_ax0])
- T_divide_1[v_ax0] = T_multiply_red[v_ax0] *
T.float32(0.00063775510204081628)
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3),
T.int64(1), T.int64(1)):
- with T.block("T_reshape_1"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)])
- T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
- T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3),
T.int64(1), T.int64(1)):
- with T.block("T_add"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
- T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3])
- T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0,
v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05)
- for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(3), T.int64(1),
T.int64(1)):
- with T.block("compute"):
- v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
- T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3])
- T.writes(compute[v_i0, v_i1, v_i2, v_i3])
- compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0,
v_i1, v_i2, v_i3])
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(28), T.int64(28)):
- with T.block("T_divide_2"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3],
compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3])
- T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0,
v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3),
T.int64(1), T.int64(1)):
- with T.block("T_reshape_2"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) %
T.int64(3)])
- T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3])
- T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(28), T.int64(28)):
- with T.block("T_multiply_1"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3])
- T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3),
T.int64(1), T.int64(1)):
- with T.block("T_reshape_3"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) %
T.int64(3)])
- T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3])
- T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_2[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(28), T.int64(28)):
- with T.block("T_add_1"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
- T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0,
v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
- for ax0 in range(T.int64(3)):
- with T.block("T_multiply_2"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(rxplaceholder_3[v_ax0])
- T.writes(T_multiply_2[v_ax0])
- T_multiply_2[v_ax0] = T.float32(0.90000000000000002) *
rxplaceholder_3[v_ax0]
- for ax0 in range(T.int64(3)):
- with T.block("T_multiply_3"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_divide[v_ax0])
- T.writes(T_multiply_3[v_ax0])
- T_multiply_3[v_ax0] = T.float32(0.10000000000000001) *
T_divide[v_ax0]
- for ax0 in range(T.int64(3)):
- with T.block("T_add_2"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0])
- T.writes(T_add_1[v_ax0])
- T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0]
- for ax0 in range(T.int64(3)):
- with T.block("T_multiply_4"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(rxplaceholder_4[v_ax0])
- T.writes(T_multiply_4[v_ax0])
- T_multiply_4[v_ax0] = T.float32(0.90000000000000002) *
rxplaceholder_4[v_ax0]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(28), T.int64(28)):
- with T.block("T_subtract_3"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3])
- T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(28), T.int64(28)):
- with T.block("T_subtract_4"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3])
- T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(28), T.int64(28)):
- with T.block("T_multiply_5"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3],
T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3])
- T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3])
- T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2,
v_ax3]
- for ax0, k0, k2, k3 in T.grid(T.int64(3), T.int64(2), T.int64(28),
T.int64(28)):
- with T.block("T_multiply_red_1"):
- v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0,
k2, k3])
- T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3])
- T.writes(T_multiply_red_1[v_ax0])
- with T.init():
- T_multiply_red_1[v_ax0] = T.float32(0)
- T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] +
T_multiply_5[v_k0, v_ax0, v_k2, v_k3]
- for ax0 in range(T.int64(3)):
- with T.block("T_divide_3"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_multiply_red_1[v_ax0])
- T.writes(T_divide_3[v_ax0])
- T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] *
T.float32(0.00063775510204081628)
- for ax0 in range(T.int64(3)):
- with T.block("T_multiply_6"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_divide_3[v_ax0])
- T.writes(T_multiply_6[v_ax0])
- T_multiply_6[v_ax0] = T.float32(0.10000000000000001) *
T_divide_3[v_ax0]
- for ax0 in range(T.int64(3)):
- with T.block("T_add_3"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0])
- T.writes(T_add_2[v_ax0])
- T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0]
+ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta:
T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add:
T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ x = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(28),
T.int64(28)))
+ gamma = T.match_buffer(var_gamma, (T.int64(3),))
+ beta = T.match_buffer(var_beta, (T.int64(3),))
+ moving_mean = T.match_buffer(var_moving_mean, (T.int64(3),))
+ moving_var = T.match_buffer(var_moving_var, (T.int64(3),))
+ T_add = T.match_buffer(var_T_add, (T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_add_1 = T.match_buffer(var_T_add_1, (T.int64(3),))
+ T_add_2 = T.match_buffer(var_T_add_2, (T.int64(3),))
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ T_reshape = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(1), T.int64(1)))
+ T_subtract = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(1), T.int64(1)))
+ T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1),
T.int64(1)))
+ compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1),
T.int64(1)))
+ T_divide = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(1), T.int64(1)))
+ T_multiply = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(1), T.int64(1)))
+ T_multiply_1 = T.alloc_buffer((T.int64(3),))
+ x_red = T.alloc_buffer((T.int64(3),))
+ T_divide_1 = T.alloc_buffer((T.int64(3),))
+ T_multiply_2 = T.alloc_buffer((T.int64(3),))
+ T_multiply_3 = T.alloc_buffer((T.int64(3),))
+ T_reshape_4 = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(1), T.int64(1)))
+ T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_multiply_red = T.alloc_buffer((T.int64(3),))
+ T_divide_2 = T.alloc_buffer((T.int64(3),))
+ T_multiply_5 = T.alloc_buffer((T.int64(3),))
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_reshape"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(moving_mean[(v_ax1 + v_ax2 +
v_ax3) % T.int64(3)])
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(28)):
+ for ax3 in range(T.int64(28)):
+ with T.block("T_subtract"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(28), ax2)
+ v_ax3 = T.axis.spatial(T.int64(28), ax3)
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_subtract[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_reshape_1"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(moving_var[(v_ax1 + v_ax2 + v_ax3)
% T.int64(3)])
+ T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] =
moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T.writes(T_add_3[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] =
T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05)
+ for i0 in range(T.int64(1)):
+ for i1 in range(T.int64(3)):
+ for i2 in range(T.int64(1)):
+ for i3 in range(T.int64(1)):
+ with T.block("compute"):
+ v_i0 = T.axis.spatial(T.int64(1), i0)
+ v_i1 = T.axis.spatial(T.int64(3), i1)
+ v_i2 = T.axis.spatial(T.int64(1), i2)
+ v_i3 = T.axis.spatial(T.int64(1), i3)
+ T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3])
+ T.writes(compute[v_i0, v_i1, v_i2, v_i3])
+ compute[v_i0, v_i1, v_i2, v_i3] =
T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3])
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(28)):
+ for ax3 in range(T.int64(28)):
+ with T.block("T_divide"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(28), ax2)
+ v_ax3 = T.axis.spatial(T.int64(28), ax3)
+ T.reads(T_subtract[v_ax0, v_ax1, v_ax2,
v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_divide[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_divide[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_reshape_2"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(gamma[(v_ax1 + v_ax2 + v_ax3) %
T.int64(3)])
+ T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] =
gamma[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(28)):
+ for ax3 in range(T.int64(28)):
+ with T.block("T_multiply"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(28), ax2)
+ v_ax3 = T.axis.spatial(T.int64(28), ax3)
+ T.reads(T_divide[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_reshape_3"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(beta[(v_ax1 + v_ax2 + v_ax3) %
T.int64(3)])
+ T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] =
beta[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(28)):
+ for ax3 in range(T.int64(28)):
+ with T.block("T_add_1"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(28), ax2)
+ v_ax3 = T.axis.spatial(T.int64(28), ax3)
+ T.reads(T_multiply[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_add[v_ax0, v_ax1, v_ax2, v_ax3] =
T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
+ for ax0 in range(T.int64(3)):
+ with T.block("T_multiply_1"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads(moving_mean[v_ax0])
+ T.writes(T_multiply_1[v_ax0])
+ T_multiply_1[v_ax0] = T.float32(0.90000000000000002) *
moving_mean[v_ax0]
+ for ax0 in range(T.int64(3)):
+ for k0 in range(T.int64(2)):
+ for k2 in range(T.int64(28)):
+ for k3 in range(T.int64(28)):
+ with T.block("x_red"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ v_k0 = T.axis.reduce(T.int64(2), k0)
+ v_k2 = T.axis.reduce(T.int64(28), k2)
+ v_k3 = T.axis.reduce(T.int64(28), k3)
+ T.reads(x[v_k0, v_ax0, v_k2, v_k3])
+ T.writes(x_red[v_ax0])
+ with T.init():
+ x_red[v_ax0] = T.float32(0.0)
+ x_red[v_ax0] = x_red[v_ax0] + x[v_k0,
v_ax0, v_k2, v_k3]
+ for ax0 in range(T.int64(3)):
+ with T.block("T_divide_1"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads(x_red[v_ax0])
+ T.writes(T_divide_1[v_ax0])
+ T_divide_1[v_ax0] = x_red[v_ax0] *
T.float32(0.00063775510204081628)
+ for ax0 in range(T.int64(3)):
+ with T.block("T_multiply_2"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads(T_divide_1[v_ax0])
+ T.writes(T_multiply_2[v_ax0])
+ T_multiply_2[v_ax0] = T.float32(0.10000000000000001) *
T_divide_1[v_ax0]
+ for ax0 in range(T.int64(3)):
+ with T.block("T_add_2"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0])
+ T.writes(T_add_1[v_ax0])
+ T_add_1[v_ax0] = T_multiply_1[v_ax0] +
T_multiply_2[v_ax0]
+ for ax0 in range(T.int64(3)):
+ with T.block("T_multiply_3"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads(moving_var[v_ax0])
+ T.writes(T_multiply_3[v_ax0])
+ T_multiply_3[v_ax0] = T.float32(0.90000000000000002) *
moving_var[v_ax0]
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_reshape_4"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3)
% T.int64(3)])
+ T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(28)):
+ for ax3 in range(T.int64(28)):
+ with T.block("T_subtract_1"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(28), ax2)
+ v_ax3 = T.axis.spatial(T.int64(28), ax3)
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(28)):
+ for ax3 in range(T.int64(28)):
+ with T.block("T_subtract_2"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(28), ax2)
+ v_ax3 = T.axis.spatial(T.int64(28), ax3)
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(28)):
+ for ax3 in range(T.int64(28)):
+ with T.block("T_multiply_4"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(28), ax2)
+ v_ax3 = T.axis.spatial(T.int64(28), ax3)
+ T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3]
+ for ax0 in range(T.int64(3)):
+ for k0 in range(T.int64(2)):
+ for k2 in range(T.int64(28)):
+ for k3 in range(T.int64(28)):
+ with T.block("T_multiply_red"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ v_k0 = T.axis.reduce(T.int64(2), k0)
+ v_k2 = T.axis.reduce(T.int64(28), k2)
+ v_k3 = T.axis.reduce(T.int64(28), k3)
+ T.reads(T_multiply_4[v_k0, v_ax0, v_k2,
v_k3])
+ T.writes(T_multiply_red[v_ax0])
+ with T.init():
+ T_multiply_red[v_ax0] = T.float32(0.0)
+ T_multiply_red[v_ax0] =
T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3]
+ for ax0 in range(T.int64(3)):
+ with T.block("T_divide_2"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads(T_multiply_red[v_ax0])
+ T.writes(T_divide_2[v_ax0])
+ T_divide_2[v_ax0] = T_multiply_red[v_ax0] *
T.float32(0.00063775510204081628)
+ for ax0 in range(T.int64(3)):
+ with T.block("T_multiply_5"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads(T_divide_2[v_ax0])
+ T.writes(T_multiply_5[v_ax0])
+ T_multiply_5[v_ax0] = T.float32(0.10000000000000001) *
T_divide_2[v_ax0]
+ for ax0 in range(T.int64(3)):
+ with T.block("T_add_3"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0])
+ T.writes(T_add_2[v_ax0])
+ T_add_2[v_ax0] = T_multiply_3[v_ax0] +
T_multiply_5[v_ax0]
@R.function
def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma:
R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"),
moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,),
dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"),
R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")):
- gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean,
moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"),
R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")])
+ cls = Expected
+ gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean,
moving_var), out_sinfo=[R.Tensor((2, 3, 28, 28), dtype="float32"),
R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")])
return gv
# fmt: on
@@ -2184,230 +2261,295 @@ def test_batch_norm_symbolic():
@tvm.script.ir_module
class Expected:
@T.prim_func(private=True)
- def batch_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1:
T.handle, var_rxplaceholder_2: T.handle, var_rxplaceholder_3: T.handle,
var_rxplaceholder_4: T.handle, var_T_add: T.handle, var_T_add_1: T.handle,
var_T_add_2: T.handle):
- T.func_attr({"tir.noalias": True})
- n = T.int64()
- h = T.int64()
- w = T.int64()
- c = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, (n, h, w, c))
- rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (c,))
- rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (c,))
- rxplaceholder_3 = T.match_buffer(var_rxplaceholder_3, (c,))
- rxplaceholder_4 = T.match_buffer(var_rxplaceholder_4, (c,))
+ def batch_norm(var_x: T.handle, var_gamma: T.handle, var_beta:
T.handle, var_moving_mean: T.handle, var_moving_var: T.handle, var_T_add:
T.handle, var_T_add_1: T.handle, var_T_add_2: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n, h, w, c = T.int64(), T.int64(), T.int64(), T.int64()
+ x = T.match_buffer(var_x, (n, h, w, c))
+ gamma = T.match_buffer(var_gamma, (c,))
+ beta = T.match_buffer(var_beta, (c,))
+ moving_mean = T.match_buffer(var_moving_mean, (c,))
+ moving_var = T.match_buffer(var_moving_var, (c,))
T_add = T.match_buffer(var_T_add, (n, h, w, c))
T_add_1 = T.match_buffer(var_T_add_1, (T.max(c, h),))
T_add_2 = T.match_buffer(var_T_add_2, (T.max(c, h),))
- # with T.block("root"):
- rxplaceholder_red = T.alloc_buffer((h,))
- T_divide = T.alloc_buffer((h,))
- T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
- T_subtract = T.alloc_buffer((n, h, w, c))
- T_subtract_1 = T.alloc_buffer((n, h, w, c))
- T_subtract_2 = T.alloc_buffer((n, h, w, c))
- T_multiply = T.alloc_buffer((n, h, w, c))
- T_multiply_red = T.alloc_buffer((h,))
- T_divide_1 = T.alloc_buffer((h,))
- T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
- T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
- compute = T.alloc_buffer((T.int64(1), h, T.int64(1), T.int64(1)))
- T_divide_2 = T.alloc_buffer((n, h, w, c))
- T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
- T_multiply_1 = T.alloc_buffer((n, h, w, c))
- T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
- T_multiply_2 = T.alloc_buffer((c,))
- T_multiply_3 = T.alloc_buffer((h,))
- T_multiply_4 = T.alloc_buffer((c,))
- T_subtract_3 = T.alloc_buffer((n, h, w, c))
- T_subtract_4 = T.alloc_buffer((n, h, w, c))
- T_multiply_5 = T.alloc_buffer((n, h, w, c))
- T_multiply_red_1 = T.alloc_buffer((h,))
- T_divide_3 = T.alloc_buffer((h,))
- T_multiply_6 = T.alloc_buffer((h,))
- for ax0, k0, k2, k3 in T.grid(h, n, w, c):
- with T.block("rxplaceholder_red"):
- v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0,
k2, k3])
- T.reads(rxplaceholder[v_k0, v_ax0, v_k2, v_k3])
- T.writes(rxplaceholder_red[v_ax0])
- with T.init():
- rxplaceholder_red[v_ax0] = T.float32(0)
- rxplaceholder_red[v_ax0] = rxplaceholder_red[v_ax0] +
rxplaceholder[v_k0, v_ax0, v_k2, v_k3]
- for ax0 in range(h):
- with T.block("T_divide"):
- v_ax0 = T.axis.spatial(h, ax0)
- T.reads(rxplaceholder_red[v_ax0])
- T.writes(T_divide[v_ax0])
- T_divide[v_ax0] = rxplaceholder_red[v_ax0] /
T.Cast("float32", n * w * c)
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1),
T.int64(1)):
- with T.block("T_reshape"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h])
- T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
- T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = T_divide[(v_ax0 *
h + v_ax1 + v_ax2 + v_ax3) % h]
- for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
- with T.block("T_subtract"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3])
- T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
- with T.block("T_subtract_1"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3])
- T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
- with T.block("T_subtract_2"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
- T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
- with T.block("T_multiply"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3],
T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
- T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
- T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3]
- for ax0, k0, k2, k3 in T.grid(h, n, w, c):
- with T.block("T_multiply_red"):
- v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0,
k2, k3])
- T.reads(T_multiply[v_k0, v_ax0, v_k2, v_k3])
- T.writes(T_multiply_red[v_ax0])
- with T.init():
- T_multiply_red[v_ax0] = T.float32(0)
- T_multiply_red[v_ax0] = T_multiply_red[v_ax0] +
T_multiply[v_k0, v_ax0, v_k2, v_k3]
- for ax0 in range(h):
- with T.block("T_divide_1"):
- v_ax0 = T.axis.spatial(h, ax0)
- T.reads(T_multiply_red[v_ax0])
- T.writes(T_divide_1[v_ax0])
- T_divide_1[v_ax0] = T_multiply_red[v_ax0] /
T.Cast("float32", n * w * c)
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1),
T.int64(1)):
- with T.block("T_reshape_1"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) %
h])
- T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
- T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1),
T.int64(1)):
- with T.block("T_add"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3])
- T.writes(T_add_3[v_ax0, v_ax1, v_ax2, v_ax3])
- T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_1[v_ax0,
v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05)
- for i0, i1, i2, i3 in T.grid(T.int64(1), h, T.int64(1),
T.int64(1)):
- with T.block("compute"):
- v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
- T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3])
- T.writes(compute[v_i0, v_i1, v_i2, v_i3])
- compute[v_i0, v_i1, v_i2, v_i3] = T.sqrt(T_add_3[v_i0,
v_i1, v_i2, v_i3])
- for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
- with T.block("T_divide_2"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_subtract[v_ax0, v_ax1, v_ax2, v_ax3],
compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3])
- T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_subtract[v_ax0,
v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1),
T.int64(1)):
- with T.block("T_reshape_2"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 +
v_ax3) % c])
- T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3])
- T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
- for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
- with T.block("T_multiply_1"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3])
- T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), h, T.int64(1),
T.int64(1)):
- with T.block("T_reshape_3"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 +
v_ax3) % c])
- T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3])
- T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_2[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
- for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
- with T.block("T_add_1"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
- T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply_1[v_ax0,
v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)]
- for ax0 in range(c):
- with T.block("T_multiply_2"):
- v_ax0 = T.axis.spatial(c, ax0)
- T.reads(rxplaceholder_3[v_ax0])
- T.writes(T_multiply_2[v_ax0])
- T_multiply_2[v_ax0] = T.float32(0.90000000000000002) *
rxplaceholder_3[v_ax0]
- for ax0 in range(h):
- with T.block("T_multiply_3"):
- v_ax0 = T.axis.spatial(h, ax0)
- T.reads(T_divide[v_ax0])
- T.writes(T_multiply_3[v_ax0])
- T_multiply_3[v_ax0] = T.float32(0.10000000000000001) *
T_divide[v_ax0]
- for ax0 in range(T.max(c, h)):
- with T.block("T_add_2"):
- v_ax0 = T.axis.spatial(T.max(c, h), ax0)
- T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0])
- T.writes(T_add_1[v_ax0])
- T_add_1[v_ax0] = T_multiply_2[v_ax0] + T_multiply_3[v_ax0]
- for ax0 in range(c):
- with T.block("T_multiply_4"):
- v_ax0 = T.axis.spatial(c, ax0)
- T.reads(rxplaceholder_4[v_ax0])
- T.writes(T_multiply_4[v_ax0])
- T_multiply_4[v_ax0] = T.float32(0.90000000000000002) *
rxplaceholder_4[v_ax0]
- for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
- with T.block("T_subtract_3"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3])
- T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
- with T.block("T_subtract_4"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3])
- T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
- for ax0, ax1, ax2, ax3 in T.grid(n, h, w, c):
- with T.block("T_multiply_5"):
- v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
- T.reads(T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3],
T_subtract_4[v_ax0, v_ax1, v_ax2, v_ax3])
- T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3])
- T_multiply_5[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract_3[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_4[v_ax0, v_ax1, v_ax2,
v_ax3]
- for ax0, k0, k2, k3 in T.grid(h, n, w, c):
- with T.block("T_multiply_red_1"):
- v_ax0, v_k0, v_k2, v_k3 = T.axis.remap("SRRR", [ax0, k0,
k2, k3])
- T.reads(T_multiply_5[v_k0, v_ax0, v_k2, v_k3])
- T.writes(T_multiply_red_1[v_ax0])
- with T.init():
- T_multiply_red_1[v_ax0] = T.float32(0)
- T_multiply_red_1[v_ax0] = T_multiply_red_1[v_ax0] +
T_multiply_5[v_k0, v_ax0, v_k2, v_k3]
- for ax0 in range(h):
- with T.block("T_divide_3"):
- v_ax0 = T.axis.spatial(h, ax0)
- T.reads(T_multiply_red_1[v_ax0])
- T.writes(T_divide_3[v_ax0])
- T_divide_3[v_ax0] = T_multiply_red_1[v_ax0] /
T.Cast("float32", n * w * c)
- for ax0 in range(h):
- with T.block("T_multiply_6"):
- v_ax0 = T.axis.spatial(h, ax0)
- T.reads(T_divide_3[v_ax0])
- T.writes(T_multiply_6[v_ax0])
- T_multiply_6[v_ax0] = T.float32(0.10000000000000001) *
T_divide_3[v_ax0]
- for ax0 in range(T.max(c, h)):
- with T.block("T_add_3"):
- v_ax0 = T.axis.spatial(T.max(c, h), ax0)
- T.reads(T_multiply_4[v_ax0], T_multiply_6[v_ax0])
- T.writes(T_add_2[v_ax0])
- T_add_2[v_ax0] = T_multiply_4[v_ax0] + T_multiply_6[v_ax0]
-
- @R.function
- def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma:
R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"),
moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",),
dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"),
R.Tensor(("T.max(c,h)",), dtype="float32"), R.Tensor(("T.max(c,h)",),
dtype="float32")):
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
+ T_subtract = T.alloc_buffer((n, h, w, c))
+ T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
+ T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
+ compute = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
+ T_divide = T.alloc_buffer((n, h, w, c))
+ T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
+ T_multiply = T.alloc_buffer((n, h, w, c))
+ T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
+ T_multiply_1 = T.alloc_buffer((c,))
+ x_red = T.alloc_buffer((h,))
+ T_divide_1 = T.alloc_buffer((h,))
+ T_multiply_2 = T.alloc_buffer((h,))
+ T_multiply_3 = T.alloc_buffer((c,))
+ T_reshape_4 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
+ T_subtract_1 = T.alloc_buffer((n, h, w, c))
+ T_subtract_2 = T.alloc_buffer((n, h, w, c))
+ T_multiply_4 = T.alloc_buffer((n, h, w, c))
+ T_multiply_red = T.alloc_buffer((h,))
+ T_divide_2 = T.alloc_buffer((h,))
+ T_multiply_5 = T.alloc_buffer((h,))
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(h):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_reshape"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(moving_mean[(v_ax0 * h + v_ax1 +
v_ax2 + v_ax3) % c])
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+ for ax0 in range(n):
+ for ax1 in range(h):
+ for ax2 in range(w):
+ for ax3 in range(c):
+ with T.block("T_subtract"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(w, ax2)
+ v_ax3 = T.axis.spatial(c, ax3)
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_subtract[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(h):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_reshape_1"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(moving_var[(v_ax0 * h + v_ax1 +
v_ax2 + v_ax3) % c])
+ T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] =
moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(h):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T.writes(T_add_3[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_add_3[v_ax0, v_ax1, v_ax2, v_ax3] =
T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] + T.float32(1.0000000000000001e-05)
+ for i0 in range(T.int64(1)):
+ for i1 in range(h):
+ for i2 in range(T.int64(1)):
+ for i3 in range(T.int64(1)):
+ with T.block("compute"):
+ v_i0 = T.axis.spatial(T.int64(1), i0)
+ v_i1 = T.axis.spatial(h, i1)
+ v_i2 = T.axis.spatial(T.int64(1), i2)
+ v_i3 = T.axis.spatial(T.int64(1), i3)
+ T.reads(T_add_3[v_i0, v_i1, v_i2, v_i3])
+ T.writes(compute[v_i0, v_i1, v_i2, v_i3])
+ compute[v_i0, v_i1, v_i2, v_i3] =
T.sqrt(T_add_3[v_i0, v_i1, v_i2, v_i3])
+ for ax0 in range(n):
+ for ax1 in range(h):
+ for ax2 in range(w):
+ for ax3 in range(c):
+ with T.block("T_divide"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(w, ax2)
+ v_ax3 = T.axis.spatial(c, ax3)
+ T.reads(T_subtract[v_ax0, v_ax1, v_ax2,
v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_divide[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_divide[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(h):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_reshape_2"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(gamma[(v_ax0 * h + v_ax1 + v_ax2 +
v_ax3) % c])
+ T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] =
gamma[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+ for ax0 in range(n):
+ for ax1 in range(h):
+ for ax2 in range(w):
+ for ax3 in range(c):
+ with T.block("T_multiply"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(w, ax2)
+ v_ax3 = T.axis.spatial(c, ax3)
+ T.reads(T_divide[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(h):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_reshape_3"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(beta[(v_ax0 * h + v_ax1 + v_ax2 +
v_ax3) % c])
+ T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_reshape_3[v_ax0, v_ax1, v_ax2, v_ax3] =
beta[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+ for ax0 in range(n):
+ for ax1 in range(h):
+ for ax2 in range(w):
+ for ax3 in range(c):
+ with T.block("T_add_1"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(w, ax2)
+ v_ax3 = T.axis.spatial(c, ax3)
+ T.reads(T_multiply[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_add[v_ax0, v_ax1, v_ax2, v_ax3] =
T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
+ for ax0 in range(c):
+ with T.block("T_multiply_1"):
+ v_ax0 = T.axis.spatial(c, ax0)
+ T.reads(moving_mean[v_ax0])
+ T.writes(T_multiply_1[v_ax0])
+ T_multiply_1[v_ax0] = T.float32(0.90000000000000002) *
moving_mean[v_ax0]
+ for ax0 in range(h):
+ for k0 in range(n):
+ for k2 in range(w):
+ for k3 in range(c):
+ with T.block("x_red"):
+ v_ax0 = T.axis.spatial(h, ax0)
+ v_k0 = T.axis.reduce(n, k0)
+ v_k2 = T.axis.reduce(w, k2)
+ v_k3 = T.axis.reduce(c, k3)
+ T.reads(x[v_k0, v_ax0, v_k2, v_k3])
+ T.writes(x_red[v_ax0])
+ with T.init():
+ x_red[v_ax0] = T.float32(0.0)
+ x_red[v_ax0] = x_red[v_ax0] + x[v_k0,
v_ax0, v_k2, v_k3]
+ for ax0 in range(h):
+ with T.block("T_divide_1"):
+ v_ax0 = T.axis.spatial(h, ax0)
+ T.reads(x_red[v_ax0])
+ T.writes(T_divide_1[v_ax0])
+ T_divide_1[v_ax0] = x_red[v_ax0] / T.Cast("float32", n
* w * c)
+ for ax0 in range(h):
+ with T.block("T_multiply_2"):
+ v_ax0 = T.axis.spatial(h, ax0)
+ T.reads(T_divide_1[v_ax0])
+ T.writes(T_multiply_2[v_ax0])
+ T_multiply_2[v_ax0] = T.float32(0.10000000000000001) *
T_divide_1[v_ax0]
+ for ax0 in range(T.max(c, h)):
+ with T.block("T_add_2"):
+ v_ax0 = T.axis.spatial(T.max(c, h), ax0)
+ T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0])
+ T.writes(T_add_1[v_ax0])
+ T_add_1[v_ax0] = T_multiply_1[v_ax0] +
T_multiply_2[v_ax0]
+ for ax0 in range(c):
+ with T.block("T_multiply_3"):
+ v_ax0 = T.axis.spatial(c, ax0)
+ T.reads(moving_var[v_ax0])
+ T.writes(T_multiply_3[v_ax0])
+ T_multiply_3[v_ax0] = T.float32(0.90000000000000002) *
moving_var[v_ax0]
+ for ax0 in range(T.int64(1)):
+ for ax1 in range(h):
+ for ax2 in range(T.int64(1)):
+ for ax3 in range(T.int64(1)):
+ with T.block("T_reshape_4"):
+ v_ax0 = T.axis.spatial(T.int64(1), ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(T.int64(1), ax2)
+ v_ax3 = T.axis.spatial(T.int64(1), ax3)
+ T.reads(T_divide_1[(v_ax0 * h + v_ax1 +
v_ax2 + v_ax3) % h])
+ T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
+ for ax0 in range(n):
+ for ax1 in range(h):
+ for ax2 in range(w):
+ for ax3 in range(c):
+ with T.block("T_subtract_1"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(w, ax2)
+ v_ax3 = T.axis.spatial(c, ax3)
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(n):
+ for ax1 in range(h):
+ for ax2 in range(w):
+ for ax3 in range(c):
+ with T.block("T_subtract_2"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(w, ax2)
+ v_ax3 = T.axis.spatial(c, ax3)
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(n):
+ for ax1 in range(h):
+ for ax2 in range(w):
+ for ax3 in range(c):
+ with T.block("T_multiply_4"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(w, ax2)
+ v_ax3 = T.axis.spatial(c, ax3)
+ T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3]
+ for ax0 in range(h):
+ for k0 in range(n):
+ for k2 in range(w):
+ for k3 in range(c):
+ with T.block("T_multiply_red"):
+ v_ax0 = T.axis.spatial(h, ax0)
+ v_k0 = T.axis.reduce(n, k0)
+ v_k2 = T.axis.reduce(w, k2)
+ v_k3 = T.axis.reduce(c, k3)
+ T.reads(T_multiply_4[v_k0, v_ax0, v_k2,
v_k3])
+ T.writes(T_multiply_red[v_ax0])
+ with T.init():
+ T_multiply_red[v_ax0] = T.float32(0.0)
+ T_multiply_red[v_ax0] =
T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3]
+ for ax0 in range(h):
+ with T.block("T_divide_2"):
+ v_ax0 = T.axis.spatial(h, ax0)
+ T.reads(T_multiply_red[v_ax0])
+ T.writes(T_divide_2[v_ax0])
+ T_divide_2[v_ax0] = T_multiply_red[v_ax0] /
T.Cast("float32", n * w * c)
+ for ax0 in range(h):
+ with T.block("T_multiply_5"):
+ v_ax0 = T.axis.spatial(h, ax0)
+ T.reads(T_divide_2[v_ax0])
+ T.writes(T_multiply_5[v_ax0])
+ T_multiply_5[v_ax0] = T.float32(0.10000000000000001) *
T_divide_2[v_ax0]
+ for ax0 in range(T.max(c, h)):
+ with T.block("T_add_3"):
+ v_ax0 = T.axis.spatial(T.max(c, h), ax0)
+ T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0])
+ T.writes(T_add_2[v_ax0])
+ T_add_2[v_ax0] = T_multiply_3[v_ax0] +
T_multiply_5[v_ax0]
+
+ @R.function
+ def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma:
R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"),
moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",),
dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"),
R.Tensor(("T.max(c, h)",), dtype="float32"), R.Tensor(("T.max(c, h)",),
dtype="float32")):
n = T.int64()
h = T.int64()
w = T.int64()
c = T.int64()
- gv = R.call_tir(Expected.batch_norm, (x, gamma, beta, moving_mean,
moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"),
R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),),
dtype="float32")])
+ cls = Expected
+ gv = R.call_tir(cls.batch_norm, (x, gamma, beta, moving_mean,
moving_var), out_sinfo=[R.Tensor((n, h, w, c), dtype="float32"),
R.Tensor((T.max(c, h),), dtype="float32"), R.Tensor((T.max(c, h),),
dtype="float32")])
return gv
- # fmt: on
mod = LegalizeOps()(BatchNorm)
tvm.ir.assert_structural_equal(mod, Expected)