ekalda commented on code in PR #11591:
URL: https://github.com/apache/tvm/pull/11591#discussion_r921124849
##########
tests/python/contrib/test_ethosu/test_copy_compute_reordering.py:
##########
@@ -468,5 +468,288 @@ def main() -> None:
tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
+def test_reordering_based_on_cycles():
+ # fmt: off
+ @tvm.script.ir_module
+ class ModuleBefore:
+ @T.prim_func
+ def main(placeholder: T.Buffer[(256,), "int8"], placeholder_encoded:
T.Buffer[(288,), "uint8"], placeholder_encoded_2: T.Buffer[(128,), "uint8"],
placeholder_encoded_4: T.Buffer[(288,), "uint8"], placeholder_encoded_6:
T.Buffer[(128,), "uint8"], placeholder_encoded_8: T.Buffer[(144,), "uint8"],
ethosu_write: T.Buffer[(572,), "int8"]) -> None:
+ # function attr dict
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol":
"main", "tir.noalias": True})
+ ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused_4 = T.var("int32")
+ nn = T.var("int32")
+ nn_1 = T.var("int32")
+ nn_2 = T.var("int32")
+ nn_3 = T.var("int32")
+ nn_4 = T.var("int32")
+ nn_5 = T.var("int32")
+ nn_6 = T.var("int32")
+ nn_7 = T.var("int32")
+ nn_8 = T.var("int32")
+ nn_9 = T.var("int32")
+ T.preflattened_buffer(placeholder, [1, 8, 8, 4], dtype="int8",
data=placeholder.data)
+ T.preflattened_buffer(placeholder_encoded, [4, 3, 3, 4],
dtype="int8")
+ T.preflattened_buffer(placeholder_encoded_2, [4, 3, 3, 1],
dtype="int8")
+ T.preflattened_buffer(placeholder_encoded_4, [4, 3, 3, 4],
dtype="int8")
+ T.preflattened_buffer(placeholder_encoded_6, [4, 3, 3, 1],
dtype="int8")
+ T.preflattened_buffer(placeholder_encoded_8, [4, 1, 3, 4],
dtype="int8")
+ T.preflattened_buffer(ethosu_write, [1, 13, 11, 4], dtype="int8",
data=ethosu_write.data)
+ # body
+ placeholder_d_d_global = T.allocate([288], "uint8", "global")
+ ethosu_write_2 = T.allocate([256], "int8", "global")
+ placeholder_d_d_global_2 = T.allocate([128], "uint8", "global")
+ ethosu_write_3 = T.allocate([256], "int8", "global")
+ placeholder_d_d_global_4 = T.allocate([288], "uint8", "global")
+ ethosu_write_4 = T.allocate([256], "int8", "global")
+ ethosu_write_5 = T.allocate([256], "int8", "global")
+ ethosu_write_6 = T.allocate([324], "int8", "global")
+ placeholder_d_global = T.allocate([128], "uint8", "global")
+ ethosu_write_7 = T.allocate([324], "int8", "global")
+ ethosu_write_8 = T.allocate([484], "int8", "global")
+ ethosu_write_9 = T.allocate([484], "int8", "global")
+ ethosu_write_10 = T.allocate([484], "int8", "global")
+ placeholder_global = T.allocate([144], "uint8", "global")
+ with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused, None,
"DataPar", ""), "pragma_compute_cycles_hint", 2304):
+ T.evaluate(T.call_extern("ethosu_copy",
placeholder_encoded[0], 288, placeholder_d_d_global[0], dtype="handle"))
+ with T.attr(T.iter_var(nn, None, "DataPar", ""),
"pragma_compute_cycles_hint", 320):
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8,
0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 32, 4, 1, "int8", 8,
8, 4, 8, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 4,
1, 3, 3, 1, 1, 1, 1, placeholder_d_d_global[0], 240, T.int8(-1), T.int8(-1),
12, placeholder_d_d_global[240], 48, T.int8(-1), T.int8(-1), 1, 1, 1, 1,
"NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+ with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_1, None,
"DataPar", ""), "pragma_compute_cycles_hint", 576):
+ T.evaluate(T.call_extern("ethosu_copy",
placeholder_encoded_2[0], 128, placeholder_d_d_global_2[0], dtype="handle"))
+ with T.attr(T.iter_var(nn_1, None, "DataPar", ""),
"pragma_compute_cycles_hint", 320):
+ T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8,
8, 4, 8, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.59999999999999998), 11,
"NHWC", 32, 4, 1, "int8", 8, 8, 4, 8, 0, 8, ethosu_write_3[0], 0, 0, 0,
T.float32(0.26000000000000001), 15, "NHWC", 32, 4, 1, 3, 3, 1, 1, 1, 1,
placeholder_d_d_global_2[0], 80, 13, placeholder_d_d_global_2[80], 48, 1, 1, 1,
1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+ with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_2, None,
"DataPar", ""), "pragma_compute_cycles_hint", 2304):
+ T.evaluate(T.call_extern("ethosu_copy",
placeholder_encoded_4[0], 288, placeholder_d_d_global_4[0], dtype="handle"))
+ with T.attr(T.iter_var(nn_2, None, "DataPar", ""),
"pragma_compute_cycles_hint", 320):
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8,
0, 8, ethosu_write_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 32, 4, 1, "int8",
8, 8, 4, 8, 0, 8, ethosu_write_4[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32,
4, 1, 3, 3, 1, 1, 1, 1, placeholder_d_d_global_4[0], 240, T.int8(-1),
T.int8(-1), 12, placeholder_d_d_global_4[240], 48, T.int8(-1), T.int8(-1), 1,
1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+ with T.attr(T.iter_var(nn_3, None, "DataPar", ""),
"pragma_compute_cycles_hint", 192):
+ T.evaluate(T.call_extern("ethosu_pooling", "int8", 8, 8, 4, 8,
0, 8, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1, "int8", 8,
8, 4, 8, 0, 8, ethosu_write_5[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1,
"MAX", 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8,
dtype="handle"))
+ with T.attr(T.iter_var(nn_4, None, "DataPar", ""),
"pragma_compute_cycles_hint", 300):
+ T.evaluate(T.call_extern("ethosu_pooling", "int8", 8, 8, 4, 8,
0, 8, ethosu_write_5[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1, "int8", 9,
9, 4, 9, 0, 9, ethosu_write_6[0], 0, 0, 0, T.float32(1), 0, "NHWC", 36, 4, 1,
"AVG", 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 10, 10, 8,
dtype="handle"))
+ with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_3, None,
"DataPar", ""), "pragma_compute_cycles_hint", 576):
+ T.evaluate(T.call_extern("ethosu_copy",
placeholder_encoded_6[0], 128, placeholder_d_global[0], dtype="handle"))
+ with T.attr(T.iter_var(nn_5, None, "DataPar", ""),
"pragma_compute_cycles_hint", 500):
+ T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 9,
9, 4, 9, 0, 9, ethosu_write_6[0], 0, 0, 0, T.float32(0.59999999999999998), 11,
"NHWC", 36, 4, 1, "int8", 9, 9, 4, 9, 0, 9, ethosu_write_7[0], 0, 0, 0,
T.float32(0.26000000000000001), 15, "NHWC", 36, 4, 1, 3, 3, 1, 1, 1, 1,
placeholder_d_global[0], 80, 13, placeholder_d_global[80], 48, 1, 1, 1, 1,
"NONE", 0, 0, "TFL", "NONE", 10, 10, 8, dtype="handle"))
+ with T.attr(T.iter_var(nn_6, None, "DataPar", ""),
"pragma_compute_cycles_hint", 432):
+ T.evaluate(T.call_extern("ethosu_pooling", "int8", 9, 9, 4, 9,
0, 9, ethosu_write_7[0], 0, 0, 0, T.float32(1), 0, "NHWC", 36, 4, 1, "int8",
11, 11, 4, 11, 0, 11, ethosu_write_8[0], 0, 0, 0, T.float32(1), 0, "NHWC", 44,
4, 1, "MAX", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 12, 12,
8, dtype="handle"))
+ with T.attr(T.iter_var(nn_7, None, "DataPar", ""),
"pragma_compute_cycles_hint", 432):
+ T.evaluate(T.call_extern("ethosu_pooling", "int8", 11, 11, 4,
11, 0, 11, ethosu_write_8[0], 0, 0, 0, T.float32(1), 0, "NHWC", 44, 4, 1,
"int8", 11, 11, 4, 11, 0, 11, ethosu_write_9[0], 0, 0, 0, T.float32(1), 0,
"NHWC", 44, 4, 1, "AVG", 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL",
"NONE", 12, 12, 8, dtype="handle"))
+ with T.attr(T.iter_var(nn_8, None, "DataPar", ""),
"pragma_compute_cycles_hint", 432):
+ T.evaluate(T.call_extern("ethosu_pooling", "int8", 11, 11, 4,
11, 0, 11, ethosu_write_9[0], 0, 0, 0, T.float32(1), 0, "NHWC", 44, 4, 1,
"int8", 11, 11, 4, 11, 0, 11, ethosu_write_10[0], 0, 0, 0, T.float32(1), 0,
"NHWC", 44, 4, 1, "AVG", 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL",
"NONE", 12, 12, 8, dtype="handle"))
+ with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_4, None,
"DataPar", ""), "pragma_compute_cycles_hint", 768):
+ T.evaluate(T.call_extern("ethosu_copy",
placeholder_encoded_8[0], 144, placeholder_global[0], dtype="handle"))
+ T.attr(T.iter_var(nn_9, None, "DataPar", ""),
"pragma_compute_cycles_hint", 504)
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 11, 11, 4, 11,
0, 11, ethosu_write_10[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 44, 4, 1,
"int8", 13, 11, 4, 13, 0, 11, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14,
"NHWC", 44, 4, 1, 3, 1, 1, 1, 1, 1, placeholder_global[0], 96, T.int8(-1),
T.int8(-1), 12, placeholder_global[96], 48, T.int8(-1), T.int8(-1), 1, 1, 1, 1,
"NONE", 0, 0, "TFL", "NONE", 14, 12, 8, dtype="handle"))
+
+
+
+ @tvm.script.ir_module
+ class ModuleAfter:
+ @T.prim_func
+ def main(placeholder: T.Buffer[(256,), "int8"], placeholder_encoded:
T.Buffer[(288,), "uint8"], placeholder_encoded_2: T.Buffer[(128,), "uint8"],
placeholder_encoded_4: T.Buffer[(288,), "uint8"], placeholder_encoded_6:
T.Buffer[(128,), "uint8"], placeholder_encoded_8: T.Buffer[(144,), "uint8"],
ethosu_write: T.Buffer[(572,), "int8"]) -> None:
+ # function attr dict
+ T.func_attr({"from_legacy_te_schedule": True, "global_symbol":
"main", "tir.noalias": True})
+ ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
+ ax0_ax1_fused_ax2_fused_ax3_fused_4 = T.var("int32")
+ nn = T.var("int32")
+ nn_1 = T.var("int32")
+ nn_2 = T.var("int32")
+ nn_3 = T.var("int32")
+ nn_4 = T.var("int32")
+ nn_5 = T.var("int32")
+ nn_6 = T.var("int32")
+ nn_7 = T.var("int32")
+ nn_8 = T.var("int32")
+ nn_9 = T.var("int32")
+ T.preflattened_buffer(placeholder, [1, 8, 8, 4], dtype="int8",
data=placeholder.data)
+ T.preflattened_buffer(placeholder_encoded, [4, 3, 3, 4],
dtype="int8", data=placeholder_encoded.data)
+ T.preflattened_buffer(placeholder_encoded_2, [4, 3, 3, 1],
dtype="int8", data=placeholder_encoded_2.data)
+ T.preflattened_buffer(placeholder_encoded_4, [4, 3, 3, 4],
dtype="int8", data=placeholder_encoded_4.data)
+ T.preflattened_buffer(placeholder_encoded_6, [4, 3, 3, 1],
dtype="int8", data=placeholder_encoded_6.data)
+ T.preflattened_buffer(placeholder_encoded_8, [4, 1, 3, 4],
dtype="int8", data=placeholder_encoded_8.data)
+ T.preflattened_buffer(ethosu_write, [1, 13, 11, 4], dtype="int8",
data=ethosu_write.data)
+ # body
+ placeholder_d_d_global = T.allocate([288], "uint8", "global")
+ ethosu_write_2 = T.allocate([256], "int8", "global")
+ placeholder_d_d_global_2 = T.allocate([128], "uint8", "global")
+ ethosu_write_3 = T.allocate([256], "int8", "global")
+ placeholder_d_d_global_4 = T.allocate([288], "uint8", "global")
+ ethosu_write_4 = T.allocate([256], "int8", "global")
+ ethosu_write_5 = T.allocate([256], "int8", "global")
+ ethosu_write_6 = T.allocate([324], "int8", "global")
+ placeholder_d_global = T.allocate([128], "uint8", "global")
+ ethosu_write_7 = T.allocate([324], "int8", "global")
+ ethosu_write_8 = T.allocate([484], "int8", "global")
+ ethosu_write_9 = T.allocate([484], "int8", "global")
+ ethosu_write_10 = T.allocate([484], "int8", "global")
+ placeholder_global = T.allocate([144], "uint8", "global")
+ with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused, None,
"DataPar", ""), "pragma_compute_cycles_hint", 2304):
+ T.evaluate(T.call_extern("ethosu_copy",
placeholder_encoded[0], 288, placeholder_d_d_global[0], dtype="handle"))
+ with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_1, None,
"DataPar", ""), "pragma_compute_cycles_hint", 576):
+ T.evaluate(T.call_extern("ethosu_copy",
placeholder_encoded_2[0], 128, placeholder_d_d_global_2[0], dtype="handle"))
+ with T.attr(T.iter_var(nn, None, "DataPar", ""),
"pragma_compute_cycles_hint", 320):
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8,
0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 32, 4, 1, "int8", 8,
8, 4, 8, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 4,
1, 3, 3, 1, 1, 1, 1, placeholder_d_d_global[0], 240, T.int8(-1), T.int8(-1),
12, placeholder_d_d_global[240], 48, T.int8(-1), T.int8(-1), 1, 1, 1, 1,
"NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+ with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_2, None,
"DataPar", ""), "pragma_compute_cycles_hint", 2304):
+ T.evaluate(T.call_extern("ethosu_copy",
placeholder_encoded_4[0], 288, placeholder_d_d_global_4[0], dtype="handle"))
+ with T.attr(T.iter_var(nn_1, None, "DataPar", ""),
"pragma_compute_cycles_hint", 320):
+ T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8,
8, 4, 8, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.59999999999999998), 11,
"NHWC", 32, 4, 1, "int8", 8, 8, 4, 8, 0, 8, ethosu_write_3[0], 0, 0, 0,
T.float32(0.26000000000000001), 15, "NHWC", 32, 4, 1, 3, 3, 1, 1, 1, 1,
placeholder_d_d_global_2[0], 80, 13, placeholder_d_d_global_2[80], 48, 1, 1, 1,
1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+ with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_3, None,
"DataPar", ""), "pragma_compute_cycles_hint", 576):
+ T.evaluate(T.call_extern("ethosu_copy",
placeholder_encoded_6[0], 128, placeholder_d_global[0], dtype="handle"))
+ with T.attr(T.iter_var(nn_2, None, "DataPar", ""),
"pragma_compute_cycles_hint", 320):
+ T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8,
0, 8, ethosu_write_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 32, 4, 1, "int8",
8, 8, 4, 8, 0, 8, ethosu_write_4[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32,
4, 1, 3, 3, 1, 1, 1, 1, placeholder_d_d_global_4[0], 240, T.int8(-1),
T.int8(-1), 12, placeholder_d_d_global_4[240], 48, T.int8(-1), T.int8(-1), 1,
1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+ with T.attr(T.iter_var(nn_3, None, "DataPar", ""),
"pragma_compute_cycles_hint", 192):
+ T.evaluate(T.call_extern("ethosu_pooling", "int8", 8, 8, 4, 8,
0, 8, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1, "int8", 8,
8, 4, 8, 0, 8, ethosu_write_5[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1,
"MAX", 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8,
dtype="handle"))
+ with T.attr(T.iter_var(nn_4, None, "DataPar", ""),
"pragma_compute_cycles_hint", 300):
+ T.evaluate(T.call_extern("ethosu_pooling", "int8", 8, 8, 4, 8,
0, 8, ethosu_write_5[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1, "int8", 9,
9, 4, 9, 0, 9, ethosu_write_6[0], 0, 0, 0, T.float32(1), 0, "NHWC", 36, 4, 1,
"AVG", 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 10, 10, 8,
dtype="handle"))
+ with T.attr(T.iter_var(nn_5, None, "DataPar", ""),
"pragma_compute_cycles_hint", 500):
+ T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 9,
9, 4, 9, 0, 9, ethosu_write_6[0], 0, 0, 0, T.float32(0.59999999999999998), 11,
"NHWC", 36, 4, 1, "int8", 9, 9, 4, 9, 0, 9, ethosu_write_7[0], 0, 0, 0,
T.float32(0.26000000000000001), 15, "NHWC", 36, 4, 1, 3, 3, 1, 1, 1, 1,
placeholder_d_global[0], 80, 13, placeholder_d_global[80], 48, 1, 1, 1, 1,
"NONE", 0, 0, "TFL", "NONE", 10, 10, 8, dtype="handle"))
+ with T.attr(T.iter_var(nn_6, None, "DataPar", ""),
"pragma_compute_cycles_hint", 432):
+ T.evaluate(T.call_extern("ethosu_pooling", "int8", 9, 9, 4, 9,
0, 9, ethosu_write_7[0], 0, 0, 0, T.float32(1), 0, "NHWC", 36, 4, 1, "int8",
11, 11, 4, 11, 0, 11, ethosu_write_8[0], 0, 0, 0, T.float32(1), 0, "NHWC", 44,
4, 1, "MAX", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 12, 12,
8, dtype="handle"))
+ with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_4, None,
"DataPar", ""), "pragma_compute_cycles_hint", 768):
+ T.evaluate(T.call_extern("ethosu_copy",
placeholder_encoded_8[0], 144, placeholder_global[0], dtype="handle"))
Review Comment:
Because the cost of this copy (768 cycles) is already hidden by the combined
cycles of the two following pooling ops (432 + 432 = 864 cycles), and we want to
do the copy as late as possible to keep memory pressure to a minimum.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]