ekalda commented on code in PR #11591:
URL: https://github.com/apache/tvm/pull/11591#discussion_r921124849


##########
tests/python/contrib/test_ethosu/test_copy_compute_reordering.py:
##########
@@ -468,5 +468,288 @@ def main() -> None:
     tvm.ir.assert_structural_equal(test_mod, reference_mod, True)
 
 
+def test_reordering_based_on_cycles():
+    # fmt: off
+    @tvm.script.ir_module
+    class ModuleBefore:
+        @T.prim_func
+        def main(placeholder: T.Buffer[(256,), "int8"], placeholder_encoded: 
T.Buffer[(288,), "uint8"], placeholder_encoded_2: T.Buffer[(128,), "uint8"], 
placeholder_encoded_4: T.Buffer[(288,), "uint8"], placeholder_encoded_6: 
T.Buffer[(128,), "uint8"], placeholder_encoded_8: T.Buffer[(144,), "uint8"], 
ethosu_write: T.Buffer[(572,), "int8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": 
"main", "tir.noalias": True})
+            ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused_4 = T.var("int32")
+            nn = T.var("int32")
+            nn_1 = T.var("int32")
+            nn_2 = T.var("int32")
+            nn_3 = T.var("int32")
+            nn_4 = T.var("int32")
+            nn_5 = T.var("int32")
+            nn_6 = T.var("int32")
+            nn_7 = T.var("int32")
+            nn_8 = T.var("int32")
+            nn_9 = T.var("int32")
+            T.preflattened_buffer(placeholder, [1, 8, 8, 4], dtype="int8", 
data=placeholder.data)
+            T.preflattened_buffer(placeholder_encoded, [4, 3, 3, 4], 
dtype="int8")
+            T.preflattened_buffer(placeholder_encoded_2, [4, 3, 3, 1], 
dtype="int8")
+            T.preflattened_buffer(placeholder_encoded_4, [4, 3, 3, 4], 
dtype="int8")
+            T.preflattened_buffer(placeholder_encoded_6, [4, 3, 3, 1], 
dtype="int8")
+            T.preflattened_buffer(placeholder_encoded_8, [4, 1, 3, 4], 
dtype="int8")
+            T.preflattened_buffer(ethosu_write, [1, 13, 11, 4], dtype="int8", 
data=ethosu_write.data)
+            # body
+            placeholder_d_d_global = T.allocate([288], "uint8", "global")
+            ethosu_write_2 = T.allocate([256], "int8", "global")
+            placeholder_d_d_global_2 = T.allocate([128], "uint8", "global")
+            ethosu_write_3 = T.allocate([256], "int8", "global")
+            placeholder_d_d_global_4 = T.allocate([288], "uint8", "global")
+            ethosu_write_4 = T.allocate([256], "int8", "global")
+            ethosu_write_5 = T.allocate([256], "int8", "global")
+            ethosu_write_6 = T.allocate([324], "int8", "global")
+            placeholder_d_global = T.allocate([128], "uint8", "global")
+            ethosu_write_7 = T.allocate([324], "int8", "global")
+            ethosu_write_8 = T.allocate([484], "int8", "global")
+            ethosu_write_9 = T.allocate([484], "int8", "global")
+            ethosu_write_10 = T.allocate([484], "int8", "global")
+            placeholder_global = T.allocate([144], "uint8", "global")
+            with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused, None, 
"DataPar", ""), "pragma_compute_cycles_hint", 2304):
+                T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded[0], 288, placeholder_d_d_global[0], dtype="handle"))
+            with T.attr(T.iter_var(nn, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 320):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 
0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 32, 4, 1, "int8", 8, 
8, 4, 8, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 4, 
1, 3, 3, 1, 1, 1, 1, placeholder_d_d_global[0], 240, T.int8(-1), T.int8(-1), 
12, placeholder_d_d_global[240], 48, T.int8(-1), T.int8(-1), 1, 1, 1, 1, 
"NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+            with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_1, None, 
"DataPar", ""), "pragma_compute_cycles_hint", 576):
+                T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded_2[0], 128, placeholder_d_d_global_2[0], dtype="handle"))
+            with T.attr(T.iter_var(nn_1, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 320):
+                T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 
8, 4, 8, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.59999999999999998), 11, 
"NHWC", 32, 4, 1, "int8", 8, 8, 4, 8, 0, 8, ethosu_write_3[0], 0, 0, 0, 
T.float32(0.26000000000000001), 15, "NHWC", 32, 4, 1, 3, 3, 1, 1, 1, 1, 
placeholder_d_d_global_2[0], 80, 13, placeholder_d_d_global_2[80], 48, 1, 1, 1, 
1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+            with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_2, None, 
"DataPar", ""), "pragma_compute_cycles_hint", 2304):
+                T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded_4[0], 288, placeholder_d_d_global_4[0], dtype="handle"))
+            with T.attr(T.iter_var(nn_2, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 320):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 
0, 8, ethosu_write_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 32, 4, 1, "int8", 
8, 8, 4, 8, 0, 8, ethosu_write_4[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 
4, 1, 3, 3, 1, 1, 1, 1, placeholder_d_d_global_4[0], 240, T.int8(-1), 
T.int8(-1), 12, placeholder_d_d_global_4[240], 48, T.int8(-1), T.int8(-1), 1, 
1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+            with T.attr(T.iter_var(nn_3, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 192):
+                T.evaluate(T.call_extern("ethosu_pooling", "int8", 8, 8, 4, 8, 
0, 8, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1, "int8", 8, 
8, 4, 8, 0, 8, ethosu_write_5[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1, 
"MAX", 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8, 
dtype="handle"))
+            with T.attr(T.iter_var(nn_4, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 300):
+                T.evaluate(T.call_extern("ethosu_pooling", "int8", 8, 8, 4, 8, 
0, 8, ethosu_write_5[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1, "int8", 9, 
9, 4, 9, 0, 9, ethosu_write_6[0], 0, 0, 0, T.float32(1), 0, "NHWC", 36, 4, 1, 
"AVG", 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 10, 10, 8, 
dtype="handle"))
+            with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_3, None, 
"DataPar", ""), "pragma_compute_cycles_hint", 576):
+                T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded_6[0], 128, placeholder_d_global[0], dtype="handle"))
+            with T.attr(T.iter_var(nn_5, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 500):
+                T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 9, 
9, 4, 9, 0, 9, ethosu_write_6[0], 0, 0, 0, T.float32(0.59999999999999998), 11, 
"NHWC", 36, 4, 1, "int8", 9, 9, 4, 9, 0, 9, ethosu_write_7[0], 0, 0, 0, 
T.float32(0.26000000000000001), 15, "NHWC", 36, 4, 1, 3, 3, 1, 1, 1, 1, 
placeholder_d_global[0], 80, 13, placeholder_d_global[80], 48, 1, 1, 1, 1, 
"NONE", 0, 0, "TFL", "NONE", 10, 10, 8, dtype="handle"))
+            with T.attr(T.iter_var(nn_6, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 432):
+                T.evaluate(T.call_extern("ethosu_pooling", "int8", 9, 9, 4, 9, 
0, 9, ethosu_write_7[0], 0, 0, 0, T.float32(1), 0, "NHWC", 36, 4, 1, "int8", 
11, 11, 4, 11, 0, 11, ethosu_write_8[0], 0, 0, 0, T.float32(1), 0, "NHWC", 44, 
4, 1, "MAX", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 12, 12, 
8, dtype="handle"))
+            with T.attr(T.iter_var(nn_7, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 432):
+                T.evaluate(T.call_extern("ethosu_pooling", "int8", 11, 11, 4, 
11, 0, 11, ethosu_write_8[0], 0, 0, 0, T.float32(1), 0, "NHWC", 44, 4, 1, 
"int8", 11, 11, 4, 11, 0, 11, ethosu_write_9[0], 0, 0, 0, T.float32(1), 0, 
"NHWC", 44, 4, 1, "AVG", 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", 
"NONE", 12, 12, 8, dtype="handle"))
+            with T.attr(T.iter_var(nn_8, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 432):
+                T.evaluate(T.call_extern("ethosu_pooling", "int8", 11, 11, 4, 
11, 0, 11, ethosu_write_9[0], 0, 0, 0, T.float32(1), 0, "NHWC", 44, 4, 1, 
"int8", 11, 11, 4, 11, 0, 11, ethosu_write_10[0], 0, 0, 0, T.float32(1), 0, 
"NHWC", 44, 4, 1, "AVG", 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", 
"NONE", 12, 12, 8, dtype="handle"))
+            with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_4, None, 
"DataPar", ""), "pragma_compute_cycles_hint", 768):
+                T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded_8[0], 144, placeholder_global[0], dtype="handle"))
+            T.attr(T.iter_var(nn_9, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 504)
+            T.evaluate(T.call_extern("ethosu_conv2d", "int8", 11, 11, 4, 11, 
0, 11, ethosu_write_10[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 44, 4, 1, 
"int8", 13, 11, 4, 13, 0, 11, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, 
"NHWC", 44, 4, 1, 3, 1, 1, 1, 1, 1, placeholder_global[0], 96, T.int8(-1), 
T.int8(-1), 12, placeholder_global[96], 48, T.int8(-1), T.int8(-1), 1, 1, 1, 1, 
"NONE", 0, 0, "TFL", "NONE", 14, 12, 8, dtype="handle"))
+
+
+
+    @tvm.script.ir_module
+    class ModuleAfter:
+        @T.prim_func
+        def main(placeholder: T.Buffer[(256,), "int8"], placeholder_encoded: 
T.Buffer[(288,), "uint8"], placeholder_encoded_2: T.Buffer[(128,), "uint8"], 
placeholder_encoded_4: T.Buffer[(288,), "uint8"], placeholder_encoded_6: 
T.Buffer[(128,), "uint8"], placeholder_encoded_8: T.Buffer[(144,), "uint8"], 
ethosu_write: T.Buffer[(572,), "int8"]) -> None:
+            # function attr dict
+            T.func_attr({"from_legacy_te_schedule": True, "global_symbol": 
"main", "tir.noalias": True})
+            ax0_ax1_fused_ax2_fused_ax3_fused = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused_1 = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused_2 = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused_3 = T.var("int32")
+            ax0_ax1_fused_ax2_fused_ax3_fused_4 = T.var("int32")
+            nn = T.var("int32")
+            nn_1 = T.var("int32")
+            nn_2 = T.var("int32")
+            nn_3 = T.var("int32")
+            nn_4 = T.var("int32")
+            nn_5 = T.var("int32")
+            nn_6 = T.var("int32")
+            nn_7 = T.var("int32")
+            nn_8 = T.var("int32")
+            nn_9 = T.var("int32")
+            T.preflattened_buffer(placeholder, [1, 8, 8, 4], dtype="int8", 
data=placeholder.data)
+            T.preflattened_buffer(placeholder_encoded, [4, 3, 3, 4], 
dtype="int8", data=placeholder_encoded.data)
+            T.preflattened_buffer(placeholder_encoded_2, [4, 3, 3, 1], 
dtype="int8", data=placeholder_encoded_2.data)
+            T.preflattened_buffer(placeholder_encoded_4, [4, 3, 3, 4], 
dtype="int8", data=placeholder_encoded_4.data)
+            T.preflattened_buffer(placeholder_encoded_6, [4, 3, 3, 1], 
dtype="int8", data=placeholder_encoded_6.data)
+            T.preflattened_buffer(placeholder_encoded_8, [4, 1, 3, 4], 
dtype="int8", data=placeholder_encoded_8.data)
+            T.preflattened_buffer(ethosu_write, [1, 13, 11, 4], dtype="int8", 
data=ethosu_write.data)
+            # body
+            placeholder_d_d_global = T.allocate([288], "uint8", "global")
+            ethosu_write_2 = T.allocate([256], "int8", "global")
+            placeholder_d_d_global_2 = T.allocate([128], "uint8", "global")
+            ethosu_write_3 = T.allocate([256], "int8", "global")
+            placeholder_d_d_global_4 = T.allocate([288], "uint8", "global")
+            ethosu_write_4 = T.allocate([256], "int8", "global")
+            ethosu_write_5 = T.allocate([256], "int8", "global")
+            ethosu_write_6 = T.allocate([324], "int8", "global")
+            placeholder_d_global = T.allocate([128], "uint8", "global")
+            ethosu_write_7 = T.allocate([324], "int8", "global")
+            ethosu_write_8 = T.allocate([484], "int8", "global")
+            ethosu_write_9 = T.allocate([484], "int8", "global")
+            ethosu_write_10 = T.allocate([484], "int8", "global")
+            placeholder_global = T.allocate([144], "uint8", "global")
+            with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused, None, 
"DataPar", ""), "pragma_compute_cycles_hint", 2304):
+                T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded[0], 288, placeholder_d_d_global[0], dtype="handle"))
+            with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_1, None, 
"DataPar", ""), "pragma_compute_cycles_hint", 576):
+                T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded_2[0], 128, placeholder_d_d_global_2[0], dtype="handle"))
+            with T.attr(T.iter_var(nn, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 320):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 
0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 32, 4, 1, "int8", 8, 
8, 4, 8, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 4, 
1, 3, 3, 1, 1, 1, 1, placeholder_d_d_global[0], 240, T.int8(-1), T.int8(-1), 
12, placeholder_d_d_global[240], 48, T.int8(-1), T.int8(-1), 1, 1, 1, 1, 
"NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+            with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_2, None, 
"DataPar", ""), "pragma_compute_cycles_hint", 2304):
+                T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded_4[0], 288, placeholder_d_d_global_4[0], dtype="handle"))
+            with T.attr(T.iter_var(nn_1, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 320):
+                T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 8, 
8, 4, 8, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.59999999999999998), 11, 
"NHWC", 32, 4, 1, "int8", 8, 8, 4, 8, 0, 8, ethosu_write_3[0], 0, 0, 0, 
T.float32(0.26000000000000001), 15, "NHWC", 32, 4, 1, 3, 3, 1, 1, 1, 1, 
placeholder_d_d_global_2[0], 80, 13, placeholder_d_d_global_2[80], 48, 1, 1, 1, 
1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+            with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_3, None, 
"DataPar", ""), "pragma_compute_cycles_hint", 576):
+                T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded_6[0], 128, placeholder_d_global[0], dtype="handle"))
+            with T.attr(T.iter_var(nn_2, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 320):
+                T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 
0, 8, ethosu_write_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 32, 4, 1, "int8", 
8, 8, 4, 8, 0, 8, ethosu_write_4[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 
4, 1, 3, 3, 1, 1, 1, 1, placeholder_d_d_global_4[0], 240, T.int8(-1), 
T.int8(-1), 12, placeholder_d_d_global_4[240], 48, T.int8(-1), T.int8(-1), 1, 
1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8, dtype="handle"))
+            with T.attr(T.iter_var(nn_3, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 192):
+                T.evaluate(T.call_extern("ethosu_pooling", "int8", 8, 8, 4, 8, 
0, 8, ethosu_write_4[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1, "int8", 8, 
8, 4, 8, 0, 8, ethosu_write_5[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1, 
"MAX", 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 8, 8, 8, 
dtype="handle"))
+            with T.attr(T.iter_var(nn_4, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 300):
+                T.evaluate(T.call_extern("ethosu_pooling", "int8", 8, 8, 4, 8, 
0, 8, ethosu_write_5[0], 0, 0, 0, T.float32(1), 0, "NHWC", 32, 4, 1, "int8", 9, 
9, 4, 9, 0, 9, ethosu_write_6[0], 0, 0, 0, T.float32(1), 0, "NHWC", 36, 4, 1, 
"AVG", 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 10, 10, 8, 
dtype="handle"))
+            with T.attr(T.iter_var(nn_5, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 500):
+                T.evaluate(T.call_extern("ethosu_depthwise_conv2d", "int8", 9, 
9, 4, 9, 0, 9, ethosu_write_6[0], 0, 0, 0, T.float32(0.59999999999999998), 11, 
"NHWC", 36, 4, 1, "int8", 9, 9, 4, 9, 0, 9, ethosu_write_7[0], 0, 0, 0, 
T.float32(0.26000000000000001), 15, "NHWC", 36, 4, 1, 3, 3, 1, 1, 1, 1, 
placeholder_d_global[0], 80, 13, placeholder_d_global[80], 48, 1, 1, 1, 1, 
"NONE", 0, 0, "TFL", "NONE", 10, 10, 8, dtype="handle"))
+            with T.attr(T.iter_var(nn_6, None, "DataPar", ""), 
"pragma_compute_cycles_hint", 432):
+                T.evaluate(T.call_extern("ethosu_pooling", "int8", 9, 9, 4, 9, 
0, 9, ethosu_write_7[0], 0, 0, 0, T.float32(1), 0, "NHWC", 36, 4, 1, "int8", 
11, 11, 4, 11, 0, 11, ethosu_write_8[0], 0, 0, 0, T.float32(1), 0, "NHWC", 44, 
4, 1, "MAX", 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 12, 12, 
8, dtype="handle"))
+            with T.attr(T.iter_var(ax0_ax1_fused_ax2_fused_ax3_fused_4, None, 
"DataPar", ""), "pragma_compute_cycles_hint", 768):
+                T.evaluate(T.call_extern("ethosu_copy", 
placeholder_encoded_8[0], 144, placeholder_global[0], dtype="handle"))

Review Comment:
   Because that copy (768 cycles) is already hidden by the combined cycles of
the two following pooling ops (432 + 432 = 864 cycles), and we want to do the
copy as late as possible to keep memory pressure to a minimum.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to