This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5dc4e785d9 [Relax] Fix batch normalization computation logic (#18609)
5dc4e785d9 is described below
commit 5dc4e785d98892d4061b5c95d741f44b81261bf4
Author: Asuka <[email protected]>
AuthorDate: Fri Dec 26 20:57:48 2025 +0800
[Relax] Fix batch normalization computation logic (#18609)
Dear reviewers,
**Why**
The previous implementation of batch_norm incorrectly conflated the
computation of mean and variance between training and evaluation modes.
Additionally, for `_native_batch_norm_legit.no_stats`, using
instance_norm to handle normalization ignored the batch dimension,
leading to incorrect behavior.
**How**
This PR includes the following fixes:
1. Corrects the computation logic to properly distinguish between
training and evaluation modes.
2. Ensures the batch dimension is properly accounted for in
`_batch_norm_legit_no_stats`.
**Environment**
GPU: NVIDIA A100-SXM4-80GB
---
.../frontend/torch/exported_program_translator.py | 18 +-
python/tvm/topi/nn/batch_norm.py | 31 +-
.../relax/test_frontend_from_exported_program.py | 6 +-
.../python/relax/test_transform_legalize_ops_nn.py | 490 ++++++++++-----------
4 files changed, 260 insertions(+), 285 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 94df0282c8..b6b9723c13 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -121,6 +121,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
x = self.env[node.args[0]]
channel = int(self.shape_of(x)[1])
dtype = x.struct_info.dtype
+ scale = node.args[1] is not None
+ center = node.args[2] is not None
weight = self.env.get(node.args[1], relax.const(np.ones(channel),
dtype=dtype))
bias = self.env.get(node.args[2], relax.const(np.zeros(channel),
dtype=dtype))
running_mean = self.env.get(node.args[3],
relax.const(np.zeros(channel), dtype=dtype))
@@ -134,10 +136,6 @@ class ExportedProgramImporter(BaseFXGraphImporter):
eps = node.args[6] if len(node.args) > 6 else
node.kwargs.get("eps", 1e-05)
training = False
elif target_name.startswith("_native_batch_norm_legit_functional"):
- momentum = node.args[5] if len(node.args) > 5 else
node.kwargs.get("momentum", 0.1)
- eps = node.args[6] if len(node.args) > 6 else
node.kwargs.get("eps", 1e-05)
- training = True
- else:
ignore_running_stats = (
node.args[5] if len(node.args) > 5 else
node.kwargs.get("track_running_stats", True)
)
@@ -147,6 +145,10 @@ class ExportedProgramImporter(BaseFXGraphImporter):
if track_running_stats:
training = True
+ else:
+ momentum = node.args[5] if len(node.args) > 5 else
node.kwargs.get("momentum", 0.1)
+ eps = node.args[6] if len(node.args) > 6 else
node.kwargs.get("eps", 1e-05)
+ training = True
bn_result = self.block_builder.emit(
relax.op.nn.batch_norm(
@@ -157,6 +159,8 @@ class ExportedProgramImporter(BaseFXGraphImporter):
moving_var=running_var,
axis=1, # Always over channel
epsilon=eps,
+ scale=scale,
+ center=center,
momentum=momentum,
training=training,
)
@@ -197,9 +201,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
bias = self.env.get(node.args[2], relax.const(np.zeros(channel),
dtype=dtype))
eps = node.args[5] if len(node.args) > 5 else node.kwargs.get("eps",
1e-05)
- # Determine axes for instance norm (all spatial dimensions after
channel)
+ # Shared by InstanceNorm (view as [1, N*C, H, W])
+ # and eval-mode BatchNorm without track_running_stats
+ # Determine axes for instance norm (all spatial dimensions after
channel and batch dim)
dim = len(self.shape_of(x))
- axes = list(range(2, dim))
+ axes = [0] + list(range(2, dim))
return self.block_builder.emit(
relax.op.nn.instance_norm(
diff --git a/python/tvm/topi/nn/batch_norm.py b/python/tvm/topi/nn/batch_norm.py
index 8308c93eae..6bd43d29cb 100644
--- a/python/tvm/topi/nn/batch_norm.py
+++ b/python/tvm/topi/nn/batch_norm.py
@@ -110,27 +110,25 @@ def batch_norm(
shape = [1] * len(data.shape)
shape[axis] = data.shape[axis]
-
- reduce_axes = list(range(len(data.shape)))
- reduce_axes.remove(axis)
- shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in
reduce_axes], 1)
-
- data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
- data_mean_rs = topi.reshape(data_mean, shape)
- data_var = (
- topi.sum((data - data_mean_rs) * (data - data_mean_rs),
axis=reduce_axes) / shape_prod
- )
- data_var_rs = topi.reshape(data_var, shape)
+ data_mean = None
+ data_var = None
if training:
+ reduce_axes = list(range(len(data.shape)))
+ reduce_axes.remove(axis)
+ shape_prod = reduce(lambda x, y: x * y, [data.shape[ax] for ax in
reduce_axes], 1)
+ data_mean = topi.sum(data, axis=reduce_axes) / shape_prod
+ data_mean_rs = topi.reshape(data_mean, shape)
+ data_var = (
+ topi.sum((data - data_mean_rs) * (data - data_mean_rs),
axis=reduce_axes) / shape_prod
+ )
+ data_var_rs = topi.reshape(data_var, shape)
+ out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
+ else:
moving_mean_rs = topi.reshape(moving_mean, shape)
moving_var_rs = topi.reshape(moving_var, shape)
-
out = (data - moving_mean_rs) / topi.math.sqrt(moving_var_rs + epsilon)
- else:
- out = (data - data_mean_rs) / topi.math.sqrt(data_var_rs + epsilon)
-
if scale:
out = out * topi.reshape(gamma, shape)
if center:
@@ -138,9 +136,6 @@ def batch_norm(
if training:
assert 0 <= momentum <= 1, "the valid momentum range is [0, 1]."
- data_var = (
- topi.sum((data - data_mean_rs) * (data - data_mean_rs),
axis=reduce_axes) / shape_prod
- )
return [
out,
(1 - momentum) * moving_mean + momentum * data_mean,
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index 7894a9fb6d..9f8842ddcb 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -1913,10 +1913,10 @@ def test_batchnorm2d():
w3,
w4,
axis=1,
- epsilon=0.1,
+ epsilon=1e-5,
center=True,
scale=True,
- momentum=1.0,
+ momentum=0.1,
training=True,
)
lv2: R.Tensor((2, 3, 4, 4), dtype="float32") = lv1[0]
@@ -3607,7 +3607,7 @@ def test_instancenorm2d():
w1,
w2,
channel_axis=1,
- axes=[2, 3],
+ axes=[0, 2, 3],
epsilon=1e-05,
center=True,
scale=True,
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index de2f183a10..e81e1bab2a 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2046,27 +2046,46 @@ def test_batch_norm():
with T.block("root"):
T.reads()
T.writes()
+ x_red = T.alloc_buffer((T.int64(3),))
+ T_divide = T.alloc_buffer((T.int64(3),))
T_reshape = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(1), T.int64(1)))
T_subtract = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_multiply = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_multiply_red = T.alloc_buffer((T.int64(3),))
+ T_divide_1 = T.alloc_buffer((T.int64(3),))
T_reshape_1 = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(1), T.int64(1)))
T_add_3 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1),
T.int64(1)))
compute = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(1),
T.int64(1)))
- T_divide = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_divide_2 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
T_reshape_2 = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(1), T.int64(1)))
- T_multiply = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
+ T_multiply_1 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
T_reshape_3 = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(1), T.int64(1)))
- T_multiply_1 = T.alloc_buffer((T.int64(3),))
- x_red = T.alloc_buffer((T.int64(3),))
- T_divide_1 = T.alloc_buffer((T.int64(3),))
T_multiply_2 = T.alloc_buffer((T.int64(3),))
T_multiply_3 = T.alloc_buffer((T.int64(3),))
- T_reshape_4 = T.alloc_buffer((T.int64(1), T.int64(3),
T.int64(1), T.int64(1)))
- T_subtract_1 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
- T_subtract_2 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
- T_multiply_4 = T.alloc_buffer((T.int64(2), T.int64(3),
T.int64(28), T.int64(28)))
- T_multiply_red = T.alloc_buffer((T.int64(3),))
- T_divide_2 = T.alloc_buffer((T.int64(3),))
+ T_multiply_4 = T.alloc_buffer((T.int64(3),))
T_multiply_5 = T.alloc_buffer((T.int64(3),))
+ for ax0 in range(T.int64(3)):
+ for k0 in range(T.int64(2)):
+ for k2 in range(T.int64(28)):
+ for k3 in range(T.int64(28)):
+ with T.block("x_red"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ v_k0 = T.axis.reduce(T.int64(2), k0)
+ v_k2 = T.axis.reduce(T.int64(28), k2)
+ v_k3 = T.axis.reduce(T.int64(28), k3)
+ T.reads(x[v_k0, v_ax0, v_k2, v_k3])
+ T.writes(x_red[v_ax0])
+ with T.init():
+ x_red[v_ax0] = T.float32(0.0)
+ x_red[v_ax0] = x_red[v_ax0] + x[v_k0,
v_ax0, v_k2, v_k3]
+ for ax0 in range(T.int64(3)):
+ with T.block("T_divide"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads(x_red[v_ax0])
+ T.writes(T_divide[v_ax0])
+ T_divide[v_ax0] = x_red[v_ax0] / T.float32(1568.0)
for ax0 in range(T.int64(1)):
for ax1 in range(T.int64(3)):
for ax2 in range(T.int64(1)):
@@ -2076,9 +2095,9 @@ def test_batch_norm():
v_ax1 = T.axis.spatial(T.int64(3), ax1)
v_ax2 = T.axis.spatial(T.int64(1), ax2)
v_ax3 = T.axis.spatial(T.int64(1), ax3)
- T.reads(moving_mean[(v_ax1 + v_ax2 +
v_ax3) % T.int64(3)])
+ T.reads(T_divide[(v_ax1 + v_ax2 + v_ax3) %
T.int64(3)])
T.writes(T_reshape[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
moving_mean[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
for ax0 in range(T.int64(2)):
for ax1 in range(T.int64(3)):
for ax2 in range(T.int64(28)):
@@ -2091,6 +2110,62 @@ def test_batch_norm():
T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
T.writes(T_subtract[v_ax0, v_ax1, v_ax2,
v_ax3])
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(28)):
+ for ax3 in range(T.int64(28)):
+ with T.block("T_subtract_1"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(28), ax2)
+ v_ax3 = T.axis.spatial(T.int64(28), ax3)
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(28)):
+ for ax3 in range(T.int64(28)):
+ with T.block("T_subtract_2"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(28), ax2)
+ v_ax3 = T.axis.spatial(T.int64(28), ax3)
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(28)):
+ for ax3 in range(T.int64(28)):
+ with T.block("T_multiply"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(28), ax2)
+ v_ax3 = T.axis.spatial(T.int64(28), ax3)
+ T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3]
+ for ax0 in range(T.int64(3)):
+ for k0 in range(T.int64(2)):
+ for k2 in range(T.int64(28)):
+ for k3 in range(T.int64(28)):
+ with T.block("T_multiply_red"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ v_k0 = T.axis.reduce(T.int64(2), k0)
+ v_k2 = T.axis.reduce(T.int64(28), k2)
+ v_k3 = T.axis.reduce(T.int64(28), k3)
+ T.reads(T_multiply[v_k0, v_ax0, v_k2,
v_k3])
+ T.writes(T_multiply_red[v_ax0])
+ with T.init():
+ T_multiply_red[v_ax0] = T.float32(0.0)
+ T_multiply_red[v_ax0] =
T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3]
+ for ax0 in range(T.int64(3)):
+ with T.block("T_divide_1"):
+ v_ax0 = T.axis.spatial(T.int64(3), ax0)
+ T.reads(T_multiply_red[v_ax0])
+ T.writes(T_divide_1[v_ax0])
+ T_divide_1[v_ax0] = T_multiply_red[v_ax0] /
T.float32(1568.0)
for ax0 in range(T.int64(1)):
for ax1 in range(T.int64(3)):
for ax2 in range(T.int64(1)):
@@ -2100,9 +2175,9 @@ def test_batch_norm():
v_ax1 = T.axis.spatial(T.int64(3), ax1)
v_ax2 = T.axis.spatial(T.int64(1), ax2)
v_ax3 = T.axis.spatial(T.int64(1), ax3)
- T.reads(moving_var[(v_ax1 + v_ax2 + v_ax3)
% T.int64(3)])
+ T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3)
% T.int64(3)])
T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] =
moving_var[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
+ T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
for ax0 in range(T.int64(1)):
for ax1 in range(T.int64(3)):
for ax2 in range(T.int64(1)):
@@ -2131,14 +2206,14 @@ def test_batch_norm():
for ax1 in range(T.int64(3)):
for ax2 in range(T.int64(28)):
for ax3 in range(T.int64(28)):
- with T.block("T_divide"):
+ with T.block("T_divide_2"):
v_ax0 = T.axis.spatial(T.int64(2), ax0)
v_ax1 = T.axis.spatial(T.int64(3), ax1)
v_ax2 = T.axis.spatial(T.int64(28), ax2)
v_ax3 = T.axis.spatial(T.int64(28), ax3)
T.reads(T_subtract[v_ax0, v_ax1, v_ax2,
v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_divide[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_divide[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ T.writes(T_divide_2[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
for ax0 in range(T.int64(1)):
for ax1 in range(T.int64(3)):
for ax2 in range(T.int64(1)):
@@ -2155,14 +2230,14 @@ def test_batch_norm():
for ax1 in range(T.int64(3)):
for ax2 in range(T.int64(28)):
for ax3 in range(T.int64(28)):
- with T.block("T_multiply"):
+ with T.block("T_multiply_1"):
v_ax0 = T.axis.spatial(T.int64(2), ax0)
v_ax1 = T.axis.spatial(T.int64(3), ax1)
v_ax2 = T.axis.spatial(T.int64(28), ax2)
v_ax3 = T.axis.spatial(T.int64(28), ax3)
- T.reads(T_divide[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_multiply[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
+ T.reads(T_divide_2[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
for ax0 in range(T.int64(1)):
for ax1 in range(T.int64(3)):
for ax2 in range(T.int64(1)):
@@ -2184,133 +2259,45 @@ def test_batch_norm():
v_ax1 = T.axis.spatial(T.int64(3), ax1)
v_ax2 = T.axis.spatial(T.int64(28), ax2)
v_ax3 = T.axis.spatial(T.int64(28), ax3)
- T.reads(T_multiply[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
- T_add[v_ax0, v_ax1, v_ax2, v_ax3] =
T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
+ T_add[v_ax0, v_ax1, v_ax2, v_ax3] =
T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
for ax0 in range(T.int64(3)):
- with T.block("T_multiply_1"):
+ with T.block("T_multiply_2"):
v_ax0 = T.axis.spatial(T.int64(3), ax0)
T.reads(moving_mean[v_ax0])
- T.writes(T_multiply_1[v_ax0])
- T_multiply_1[v_ax0] = T.float32(0.90000000000000002) *
moving_mean[v_ax0]
- for ax0 in range(T.int64(3)):
- for k0 in range(T.int64(2)):
- for k2 in range(T.int64(28)):
- for k3 in range(T.int64(28)):
- with T.block("x_red"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- v_k0 = T.axis.reduce(T.int64(2), k0)
- v_k2 = T.axis.reduce(T.int64(28), k2)
- v_k3 = T.axis.reduce(T.int64(28), k3)
- T.reads(x[v_k0, v_ax0, v_k2, v_k3])
- T.writes(x_red[v_ax0])
- with T.init():
- x_red[v_ax0] = T.float32(0.0)
- x_red[v_ax0] = x_red[v_ax0] + x[v_k0,
v_ax0, v_k2, v_k3]
- for ax0 in range(T.int64(3)):
- with T.block("T_divide_1"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(x_red[v_ax0])
- T.writes(T_divide_1[v_ax0])
- T_divide_1[v_ax0] = x_red[v_ax0] / T.float32(1568)
+ T.writes(T_multiply_2[v_ax0])
+ T_multiply_2[v_ax0] = T.float32(0.90000000000000002) *
moving_mean[v_ax0]
for ax0 in range(T.int64(3)):
- with T.block("T_multiply_2"):
+ with T.block("T_multiply_3"):
v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_divide_1[v_ax0])
- T.writes(T_multiply_2[v_ax0])
- T_multiply_2[v_ax0] = T.float32(0.10000000000000001) *
T_divide_1[v_ax0]
+ T.reads(T_divide[v_ax0])
+ T.writes(T_multiply_3[v_ax0])
+ T_multiply_3[v_ax0] = T.float32(0.10000000000000001) *
T_divide[v_ax0]
for ax0 in range(T.int64(3)):
with T.block("T_add_2"):
v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0])
+ T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0])
T.writes(T_add_1[v_ax0])
- T_add_1[v_ax0] = T_multiply_1[v_ax0] +
T_multiply_2[v_ax0]
+ T_add_1[v_ax0] = T_multiply_2[v_ax0] +
T_multiply_3[v_ax0]
for ax0 in range(T.int64(3)):
- with T.block("T_multiply_3"):
+ with T.block("T_multiply_4"):
v_ax0 = T.axis.spatial(T.int64(3), ax0)
T.reads(moving_var[v_ax0])
- T.writes(T_multiply_3[v_ax0])
- T_multiply_3[v_ax0] = T.float32(0.90000000000000002) *
moving_var[v_ax0]
- for ax0 in range(T.int64(1)):
- for ax1 in range(T.int64(3)):
- for ax2 in range(T.int64(1)):
- for ax3 in range(T.int64(1)):
- with T.block("T_reshape_4"):
- v_ax0 = T.axis.spatial(T.int64(1), ax0)
- v_ax1 = T.axis.spatial(T.int64(3), ax1)
- v_ax2 = T.axis.spatial(T.int64(1), ax2)
- v_ax3 = T.axis.spatial(T.int64(1), ax3)
- T.reads(T_divide_1[(v_ax1 + v_ax2 + v_ax3)
% T.int64(3)])
- T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_1[(v_ax1 + v_ax2 + v_ax3) % T.int64(3)]
- for ax0 in range(T.int64(2)):
- for ax1 in range(T.int64(3)):
- for ax2 in range(T.int64(28)):
- for ax3 in range(T.int64(28)):
- with T.block("T_subtract_1"):
- v_ax0 = T.axis.spatial(T.int64(2), ax0)
- v_ax1 = T.axis.spatial(T.int64(3), ax1)
- v_ax2 = T.axis.spatial(T.int64(28), ax2)
- v_ax3 = T.axis.spatial(T.int64(28), ax3)
- T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
- for ax0 in range(T.int64(2)):
- for ax1 in range(T.int64(3)):
- for ax2 in range(T.int64(28)):
- for ax3 in range(T.int64(28)):
- with T.block("T_subtract_2"):
- v_ax0 = T.axis.spatial(T.int64(2), ax0)
- v_ax1 = T.axis.spatial(T.int64(3), ax1)
- v_ax2 = T.axis.spatial(T.int64(28), ax2)
- v_ax3 = T.axis.spatial(T.int64(28), ax3)
- T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
- for ax0 in range(T.int64(2)):
- for ax1 in range(T.int64(3)):
- for ax2 in range(T.int64(28)):
- for ax3 in range(T.int64(28)):
- with T.block("T_multiply_4"):
- v_ax0 = T.axis.spatial(T.int64(2), ax0)
- v_ax1 = T.axis.spatial(T.int64(3), ax1)
- v_ax2 = T.axis.spatial(T.int64(28), ax2)
- v_ax3 = T.axis.spatial(T.int64(28), ax3)
- T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
- T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3]
- for ax0 in range(T.int64(3)):
- for k0 in range(T.int64(2)):
- for k2 in range(T.int64(28)):
- for k3 in range(T.int64(28)):
- with T.block("T_multiply_red"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- v_k0 = T.axis.reduce(T.int64(2), k0)
- v_k2 = T.axis.reduce(T.int64(28), k2)
- v_k3 = T.axis.reduce(T.int64(28), k3)
- T.reads(T_multiply_4[v_k0, v_ax0, v_k2,
v_k3])
- T.writes(T_multiply_red[v_ax0])
- with T.init():
- T_multiply_red[v_ax0] = T.float32(0.0)
- T_multiply_red[v_ax0] =
T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3]
- for ax0 in range(T.int64(3)):
- with T.block("T_divide_2"):
- v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_multiply_red[v_ax0])
- T.writes(T_divide_2[v_ax0])
- T_divide_2[v_ax0] = T_multiply_red[v_ax0] /
T.float32(1568)
+ T.writes(T_multiply_4[v_ax0])
+ T_multiply_4[v_ax0] = T.float32(0.90000000000000002) *
moving_var[v_ax0]
for ax0 in range(T.int64(3)):
with T.block("T_multiply_5"):
v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_divide_2[v_ax0])
+ T.reads(T_divide_1[v_ax0])
T.writes(T_multiply_5[v_ax0])
- T_multiply_5[v_ax0] = T.float32(0.10000000000000001) *
T_divide_2[v_ax0]
+ T_multiply_5[v_ax0] = T.float32(0.10000000000000001) *
T_divide_1[v_ax0]
for ax0 in range(T.int64(3)):
with T.block("T_add_3"):
v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0])
+ T.reads(T_multiply_4[v_ax0], T_multiply_5[v_ax0])
T.writes(T_add_2[v_ax0])
- T_add_2[v_ax0] = T_multiply_3[v_ax0] +
T_multiply_5[v_ax0]
+ T_add_2[v_ax0] = T_multiply_4[v_ax0] +
T_multiply_5[v_ax0]
@R.function
def main(x: R.Tensor((2, 3, 28, 28), dtype="float32"), gamma:
R.Tensor((3,), dtype="float32"), beta: R.Tensor((3,), dtype="float32"),
moving_mean: R.Tensor((3,), dtype="float32"), moving_var: R.Tensor((3,),
dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 28, 28), dtype="float32"),
R.Tensor((3,), dtype="float32"), R.Tensor((3,), dtype="float32")):
@@ -2353,27 +2340,46 @@ def test_batch_norm_symbolic():
with T.block("root"):
T.reads()
T.writes()
+ x_red = T.alloc_buffer((h,))
+ T_divide = T.alloc_buffer((h,))
T_reshape = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
T_subtract = T.alloc_buffer((n, h, w, c))
+ T_subtract_1 = T.alloc_buffer((n, h, w, c))
+ T_subtract_2 = T.alloc_buffer((n, h, w, c))
+ T_multiply = T.alloc_buffer((n, h, w, c))
+ T_multiply_red = T.alloc_buffer((h,))
+ T_divide_1 = T.alloc_buffer((h,))
T_reshape_1 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
T_add_3 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
compute = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
- T_divide = T.alloc_buffer((n, h, w, c))
+ T_divide_2 = T.alloc_buffer((n, h, w, c))
T_reshape_2 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
- T_multiply = T.alloc_buffer((n, h, w, c))
+ T_multiply_1 = T.alloc_buffer((n, h, w, c))
T_reshape_3 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
- T_multiply_1 = T.alloc_buffer((c,))
- x_red = T.alloc_buffer((h,))
- T_divide_1 = T.alloc_buffer((h,))
- T_multiply_2 = T.alloc_buffer((h,))
- T_multiply_3 = T.alloc_buffer((c,))
- T_reshape_4 = T.alloc_buffer((T.int64(1), h, T.int64(1),
T.int64(1)))
- T_subtract_1 = T.alloc_buffer((n, h, w, c))
- T_subtract_2 = T.alloc_buffer((n, h, w, c))
- T_multiply_4 = T.alloc_buffer((n, h, w, c))
- T_multiply_red = T.alloc_buffer((h,))
- T_divide_2 = T.alloc_buffer((h,))
+ T_multiply_2 = T.alloc_buffer((c,))
+ T_multiply_3 = T.alloc_buffer((h,))
+ T_multiply_4 = T.alloc_buffer((c,))
T_multiply_5 = T.alloc_buffer((h,))
+ for ax0 in range(h):
+ for k0 in range(n):
+ for k2 in range(w):
+ for k3 in range(c):
+ with T.block("x_red"):
+ v_ax0 = T.axis.spatial(h, ax0)
+ v_k0 = T.axis.reduce(n, k0)
+ v_k2 = T.axis.reduce(w, k2)
+ v_k3 = T.axis.reduce(c, k3)
+ T.reads(x[v_k0, v_ax0, v_k2, v_k3])
+ T.writes(x_red[v_ax0])
+ with T.init():
+ x_red[v_ax0] = T.float32(0.0)
+ x_red[v_ax0] = x_red[v_ax0] + x[v_k0,
v_ax0, v_k2, v_k3]
+ for ax0 in range(h):
+ with T.block("T_divide"):
+ v_ax0 = T.axis.spatial(h, ax0)
+ T.reads(x_red[v_ax0])
+ T.writes(T_divide[v_ax0])
+ T_divide[v_ax0] = x_red[v_ax0] / T.Cast("float32", n *
w * c)
for ax0 in range(T.int64(1)):
for ax1 in range(h):
for ax2 in range(T.int64(1)):
@@ -2383,9 +2389,9 @@ def test_batch_norm_symbolic():
v_ax1 = T.axis.spatial(h, ax1)
v_ax2 = T.axis.spatial(T.int64(1), ax2)
v_ax3 = T.axis.spatial(T.int64(1), ax3)
- T.reads(moving_mean[(v_ax0 * h + v_ax1 +
v_ax2 + v_ax3) % c])
+ T.reads(T_divide[(v_ax0 * h + v_ax1 +
v_ax2 + v_ax3) % h])
T.writes(T_reshape[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
moving_mean[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
for ax0 in range(n):
for ax1 in range(h):
for ax2 in range(w):
@@ -2398,6 +2404,62 @@ def test_batch_norm_symbolic():
T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
T.writes(T_subtract[v_ax0, v_ax1, v_ax2,
v_ax3])
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(n):
+ for ax1 in range(h):
+ for ax2 in range(w):
+ for ax3 in range(c):
+ with T.block("T_subtract_1"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(w, ax2)
+ v_ax3 = T.axis.spatial(c, ax3)
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(n):
+ for ax1 in range(h):
+ for ax2 in range(w):
+ for ax3 in range(c):
+ with T.block("T_subtract_2"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(w, ax2)
+ v_ax3 = T.axis.spatial(c, ax3)
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ for ax0 in range(n):
+ for ax1 in range(h):
+ for ax2 in range(w):
+ for ax3 in range(c):
+ with T.block("T_multiply"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ v_ax1 = T.axis.spatial(h, ax1)
+ v_ax2 = T.axis.spatial(w, ax2)
+ v_ax3 = T.axis.spatial(c, ax3)
+ T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3]
+ for ax0 in range(h):
+ for k0 in range(n):
+ for k2 in range(w):
+ for k3 in range(c):
+ with T.block("T_multiply_red"):
+ v_ax0 = T.axis.spatial(h, ax0)
+ v_k0 = T.axis.reduce(n, k0)
+ v_k2 = T.axis.reduce(w, k2)
+ v_k3 = T.axis.reduce(c, k3)
+ T.reads(T_multiply[v_k0, v_ax0, v_k2,
v_k3])
+ T.writes(T_multiply_red[v_ax0])
+ with T.init():
+ T_multiply_red[v_ax0] = T.float32(0.0)
+ T_multiply_red[v_ax0] =
T_multiply_red[v_ax0] + T_multiply[v_k0, v_ax0, v_k2, v_k3]
+ for ax0 in range(h):
+ with T.block("T_divide_1"):
+ v_ax0 = T.axis.spatial(h, ax0)
+ T.reads(T_multiply_red[v_ax0])
+ T.writes(T_divide_1[v_ax0])
+ T_divide_1[v_ax0] = T_multiply_red[v_ax0] /
T.Cast("float32", n * w * c)
for ax0 in range(T.int64(1)):
for ax1 in range(h):
for ax2 in range(T.int64(1)):
@@ -2407,9 +2469,9 @@ def test_batch_norm_symbolic():
v_ax1 = T.axis.spatial(h, ax1)
v_ax2 = T.axis.spatial(T.int64(1), ax2)
v_ax3 = T.axis.spatial(T.int64(1), ax3)
- T.reads(moving_var[(v_ax0 * h + v_ax1 +
v_ax2 + v_ax3) % c])
+ T.reads(T_divide_1[(v_ax0 * h + v_ax1 +
v_ax2 + v_ax3) % h])
T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] =
moving_var[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % c]
+ T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
for ax0 in range(T.int64(1)):
for ax1 in range(h):
for ax2 in range(T.int64(1)):
@@ -2438,14 +2500,14 @@ def test_batch_norm_symbolic():
for ax1 in range(h):
for ax2 in range(w):
for ax3 in range(c):
- with T.block("T_divide"):
+ with T.block("T_divide_2"):
v_ax0 = T.axis.spatial(n, ax0)
v_ax1 = T.axis.spatial(h, ax1)
v_ax2 = T.axis.spatial(w, ax2)
v_ax3 = T.axis.spatial(c, ax3)
T.reads(T_subtract[v_ax0, v_ax1, v_ax2,
v_ax3], compute[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_divide[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_divide[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
+ T.writes(T_divide_2[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract[v_ax0, v_ax1, v_ax2, v_ax3] / compute[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
for ax0 in range(T.int64(1)):
for ax1 in range(h):
for ax2 in range(T.int64(1)):
@@ -2462,14 +2524,14 @@ def test_batch_norm_symbolic():
for ax1 in range(h):
for ax2 in range(w):
for ax3 in range(c):
- with T.block("T_multiply"):
+ with T.block("T_multiply_1"):
v_ax0 = T.axis.spatial(n, ax0)
v_ax1 = T.axis.spatial(h, ax1)
v_ax2 = T.axis.spatial(w, ax2)
v_ax3 = T.axis.spatial(c, ax3)
- T.reads(T_divide[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_multiply[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
+ T.reads(T_divide_2[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_2[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_2[v_ax0, v_ax1, v_ax2, v_ax3] * T_reshape_2[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
for ax0 in range(T.int64(1)):
for ax1 in range(h):
for ax2 in range(T.int64(1)):
@@ -2491,133 +2553,45 @@ def test_batch_norm_symbolic():
v_ax1 = T.axis.spatial(h, ax1)
v_ax2 = T.axis.spatial(w, ax2)
v_ax3 = T.axis.spatial(c, ax3)
- T.reads(T_multiply[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
+ T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_reshape_3[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
- T_add[v_ax0, v_ax1, v_ax2, v_ax3] =
T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
+ T_add[v_ax0, v_ax1, v_ax2, v_ax3] =
T_multiply_1[v_ax0, v_ax1, v_ax2, v_ax3] + T_reshape_3[T.int64(0), v_ax1,
T.int64(0), T.int64(0)]
for ax0 in range(c):
- with T.block("T_multiply_1"):
+ with T.block("T_multiply_2"):
v_ax0 = T.axis.spatial(c, ax0)
T.reads(moving_mean[v_ax0])
- T.writes(T_multiply_1[v_ax0])
- T_multiply_1[v_ax0] = T.float32(0.90000000000000002) *
moving_mean[v_ax0]
- for ax0 in range(h):
- for k0 in range(n):
- for k2 in range(w):
- for k3 in range(c):
- with T.block("x_red"):
- v_ax0 = T.axis.spatial(h, ax0)
- v_k0 = T.axis.reduce(n, k0)
- v_k2 = T.axis.reduce(w, k2)
- v_k3 = T.axis.reduce(c, k3)
- T.reads(x[v_k0, v_ax0, v_k2, v_k3])
- T.writes(x_red[v_ax0])
- with T.init():
- x_red[v_ax0] = T.float32(0.0)
- x_red[v_ax0] = x_red[v_ax0] + x[v_k0,
v_ax0, v_k2, v_k3]
- for ax0 in range(h):
- with T.block("T_divide_1"):
- v_ax0 = T.axis.spatial(h, ax0)
- T.reads(x_red[v_ax0])
- T.writes(T_divide_1[v_ax0])
- T_divide_1[v_ax0] = x_red[v_ax0] / T.Cast("float32", n
* w * c)
+ T.writes(T_multiply_2[v_ax0])
+ T_multiply_2[v_ax0] = T.float32(0.90000000000000002) *
moving_mean[v_ax0]
for ax0 in range(h):
- with T.block("T_multiply_2"):
+ with T.block("T_multiply_3"):
v_ax0 = T.axis.spatial(h, ax0)
- T.reads(T_divide_1[v_ax0])
- T.writes(T_multiply_2[v_ax0])
- T_multiply_2[v_ax0] = T.float32(0.10000000000000001) *
T_divide_1[v_ax0]
+ T.reads(T_divide[v_ax0])
+ T.writes(T_multiply_3[v_ax0])
+ T_multiply_3[v_ax0] = T.float32(0.10000000000000001) *
T_divide[v_ax0]
for ax0 in range(T.max(c, h)):
with T.block("T_add_2"):
v_ax0 = T.axis.spatial(T.max(c, h), ax0)
- T.reads(T_multiply_1[v_ax0], T_multiply_2[v_ax0])
+ T.reads(T_multiply_2[v_ax0], T_multiply_3[v_ax0])
T.writes(T_add_1[v_ax0])
- T_add_1[v_ax0] = T_multiply_1[v_ax0] +
T_multiply_2[v_ax0]
+ T_add_1[v_ax0] = T_multiply_2[v_ax0] +
T_multiply_3[v_ax0]
for ax0 in range(c):
- with T.block("T_multiply_3"):
+ with T.block("T_multiply_4"):
v_ax0 = T.axis.spatial(c, ax0)
T.reads(moving_var[v_ax0])
- T.writes(T_multiply_3[v_ax0])
- T_multiply_3[v_ax0] = T.float32(0.90000000000000002) *
moving_var[v_ax0]
- for ax0 in range(T.int64(1)):
- for ax1 in range(h):
- for ax2 in range(T.int64(1)):
- for ax3 in range(T.int64(1)):
- with T.block("T_reshape_4"):
- v_ax0 = T.axis.spatial(T.int64(1), ax0)
- v_ax1 = T.axis.spatial(h, ax1)
- v_ax2 = T.axis.spatial(T.int64(1), ax2)
- v_ax3 = T.axis.spatial(T.int64(1), ax3)
- T.reads(T_divide_1[(v_ax0 * h + v_ax1 +
v_ax2 + v_ax3) % h])
- T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_reshape_4[v_ax0, v_ax1, v_ax2, v_ax3] =
T_divide_1[(v_ax0 * h + v_ax1 + v_ax2 + v_ax3) % h]
- for ax0 in range(n):
- for ax1 in range(h):
- for ax2 in range(w):
- for ax3 in range(c):
- with T.block("T_subtract_1"):
- v_ax0 = T.axis.spatial(n, ax0)
- v_ax1 = T.axis.spatial(h, ax1)
- v_ax2 = T.axis.spatial(w, ax2)
- v_ax3 = T.axis.spatial(c, ax3)
- T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
- for ax0 in range(n):
- for ax1 in range(h):
- for ax2 in range(w):
- for ax3 in range(c):
- with T.block("T_subtract_2"):
- v_ax0 = T.axis.spatial(n, ax0)
- v_ax1 = T.axis.spatial(h, ax1)
- v_ax2 = T.axis.spatial(w, ax2)
- v_ax3 = T.axis.spatial(c, ax3)
- T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3],
T_reshape_4[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
- T.writes(T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3] =
x[v_ax0, v_ax1, v_ax2, v_ax3] - T_reshape_4[T.int64(0), v_ax1, T.int64(0),
T.int64(0)]
- for ax0 in range(n):
- for ax1 in range(h):
- for ax2 in range(w):
- for ax3 in range(c):
- with T.block("T_multiply_4"):
- v_ax0 = T.axis.spatial(n, ax0)
- v_ax1 = T.axis.spatial(h, ax1)
- v_ax2 = T.axis.spatial(w, ax2)
- v_ax3 = T.axis.spatial(c, ax3)
- T.reads(T_subtract_1[v_ax0, v_ax1, v_ax2,
v_ax3], T_subtract_2[v_ax0, v_ax1, v_ax2, v_ax3])
- T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_multiply_4[v_ax0, v_ax1, v_ax2, v_ax3] =
T_subtract_1[v_ax0, v_ax1, v_ax2, v_ax3] * T_subtract_2[v_ax0, v_ax1, v_ax2,
v_ax3]
- for ax0 in range(h):
- for k0 in range(n):
- for k2 in range(w):
- for k3 in range(c):
- with T.block("T_multiply_red"):
- v_ax0 = T.axis.spatial(h, ax0)
- v_k0 = T.axis.reduce(n, k0)
- v_k2 = T.axis.reduce(w, k2)
- v_k3 = T.axis.reduce(c, k3)
- T.reads(T_multiply_4[v_k0, v_ax0, v_k2,
v_k3])
- T.writes(T_multiply_red[v_ax0])
- with T.init():
- T_multiply_red[v_ax0] = T.float32(0.0)
- T_multiply_red[v_ax0] =
T_multiply_red[v_ax0] + T_multiply_4[v_k0, v_ax0, v_k2, v_k3]
- for ax0 in range(h):
- with T.block("T_divide_2"):
- v_ax0 = T.axis.spatial(h, ax0)
- T.reads(T_multiply_red[v_ax0])
- T.writes(T_divide_2[v_ax0])
- T_divide_2[v_ax0] = T_multiply_red[v_ax0] /
T.Cast("float32", n * w * c)
+ T.writes(T_multiply_4[v_ax0])
+ T_multiply_4[v_ax0] = T.float32(0.90000000000000002) *
moving_var[v_ax0]
for ax0 in range(h):
with T.block("T_multiply_5"):
v_ax0 = T.axis.spatial(h, ax0)
- T.reads(T_divide_2[v_ax0])
+ T.reads(T_divide_1[v_ax0])
T.writes(T_multiply_5[v_ax0])
- T_multiply_5[v_ax0] = T.float32(0.10000000000000001) *
T_divide_2[v_ax0]
+ T_multiply_5[v_ax0] = T.float32(0.10000000000000001) *
T_divide_1[v_ax0]
for ax0 in range(T.max(c, h)):
with T.block("T_add_3"):
v_ax0 = T.axis.spatial(T.max(c, h), ax0)
- T.reads(T_multiply_3[v_ax0], T_multiply_5[v_ax0])
+ T.reads(T_multiply_4[v_ax0], T_multiply_5[v_ax0])
T.writes(T_add_2[v_ax0])
- T_add_2[v_ax0] = T_multiply_3[v_ax0] +
T_multiply_5[v_ax0]
+ T_add_2[v_ax0] = T_multiply_4[v_ax0] +
T_multiply_5[v_ax0]
@R.function
def main(x: R.Tensor(("n", "h", "w", "c"), dtype="float32"), gamma:
R.Tensor(("c",), dtype="float32"), beta: R.Tensor(("c",), dtype="float32"),
moving_mean: R.Tensor(("c",), dtype="float32"), moving_var: R.Tensor(("c",),
dtype="float32")) -> R.Tuple(R.Tensor(("n", "h", "w", "c"), dtype="float32"),
R.Tensor(("T.max(c, h)",), dtype="float32"), R.Tensor(("T.max(c, h)",),
dtype="float32")):