junrushao1994 commented on a change in pull request #9871:
URL: https://github.com/apache/tvm/pull/9871#discussion_r780590829
##########
File path: python/tvm/tir/schedule/schedule.py
##########
@@ -1733,6 +1733,258 @@ def after_set_scope(
########## Schedule: Blockize & Tensorize ##########
+ @type_checked
+ def blockize(self, loop: LoopRV) -> BlockRV:
+ """Convert the subtree rooted at a specific loop into a block.
+
+ Parameters
+ ----------
+ loop : LoopRV
+ The root of the subtree.
+
+ Returns
+ -------
+ result : BlockRV
+ The new block.
+
+ Examples
+ --------
+
+ Before blockize, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_blockize(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"]
+ ) -> None:
+ for i_0, j_0, i_1, j_1 in T.grid(8, 8, 16, 16):
+ with T.block("B"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1)
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+
+ Create the schedule and do blockize:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_blockize)
+ B = sch.get_block("B")
+ _, _, i1, _ = sch.get_loops(B)
+ sch.blockize(i1)
+ print(sch.mod["main"].script())
+
+ After applying blockize, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def after_blockize(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"]
+ ) -> None:
+ for i_0, j_0 in T.grid(8, 8):
+ with T.block("blockized_B"):
+ vio, vjo = T.axis.remap("SS", [i_0, j_0])
+ T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo *
16 + 16])
+ T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo *
16 + 16])
+ for i_1, j_1 in T.grid(16, 16):
+ with T.block("B"):
+ vi = T.axis.spatial(128, vio * 16 + i_1)
+ vj = T.axis.spatial(128, vjo * 16 + j_1)
+ T.reads(A[vi, vj])
+ T.writes(B[vi, vj])
+ B[vi, vj] = A[vi, vj] * T.float32(2)
+
+ Note
+ ----
+ blockize requires that there is exactly one block under the given loop,
+ and that the bindings of that block are divisible by the subspace
+ represented by the loops starting at the given loop.
+
+ """
+
+ return _ffi_api.ScheduleBlockize(self, loop) # type: ignore # pylint:
disable=no-member
+
+ @type_checked
+ def tensorize(self, loop: LoopRV, tensor_intrin: Union[str, TensorIntrin])
-> None:
+ """Tensorize the computation enclosed by loop with the tensor
intrinsic.
+
+ Parameters
+ ----------
+ loop : LoopRV
+ The loop to be tensorized.
+ tensor_intrin : Union[str, TensorIntrin]
+ The tensor intrin or the name of the tensor intrin.
+
+ Examples
+ --------
+
+ Before tensorize, in TensorIR, the IR is:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def before_tensorize(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+ ) -> None:
+ # body
+ # with T.block("root")
+ for i_0, j_0 in T.grid(8, 8):
+ for i_1_init, j_1_init in T.grid(16, 16):
+ with T.block("init"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1_init)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1_init)
+ T.reads()
+ T.writes(C[vi, vj])
+ C[vi, vj] = T.float32(0)
+ for k_0, i_1, j_1, k_1 in T.grid(8, 16, 16, 16):
+ with T.block("update"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1)
+ vk = T.axis.reduce(128, k_0 * 16 + k_1)
+ T.reads(C[vi, vj], A[vi, vk], B[vj, vk])
+ T.writes(C[vi, vj])
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+ Declare and register the tensor intrinsic:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+ B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+ C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+ with T.block("root"):
+ vi = T.axis.S(16, 0)
+ vj = T.axis.S(16, 0)
+ vk = T.axis.R(16, 0)
+ for i, j, k in T.grid(16, 16, 16):
+ with T.block("update"):
+ vii = T.axis.S(16, vi + i)
+ vjj = T.axis.S(16, vj + j)
+ vkk = T.axis.R(16, vk + k)
+ C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj,
vkk]
+
+
+ @T.prim_func
+ def mma_intrin(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (16, 16), align=128, offset_factor=1)
+ B = T.match_buffer(b, (16, 16), align=128, offset_factor=1)
+ C = T.match_buffer(c, (16, 16), align=128, offset_factor=1)
+
+ with T.block("root"):
+ vi = T.axis.S(16, 0)
+ vj = T.axis.S(16, 0)
+ vk = T.axis.R(16, 0)
+ T.reads(
+ [
+ C[vi : vi + 16, vj : vj + 16],
+ A[vi : vi + 16, vk : vk + 16],
+ B[vj : vj + 16, vk : vk + 16],
+ ]
+ )
+ T.writes(C[vi : vi + 16, vj : vj + 16])
+ T.evaluate(
+ T.tvm_mma_sync(
+ C.data,
+ C.elem_offset // 256,
+ A.data,
+ A.elem_offset // 256,
+ B.data,
+ B.elem_offset // 256,
+ C.data,
+ C.elem_offset // 256,
+ dtype="handle",
+ )
+ )
+
+ tir.TensorIntrin.register("test_mma_intrin", mma_desc, mma_intrin)
+
+ Create the schedule and do tensorize:
+
+ .. code-block:: python
+
+ sch = tir.Schedule(before_tensorize)
+ update = sch.get_block("update")
+ _, _, _, i1, _, _ = sch.get_loops(update)
+ sch.tensorize(i1, "test_mma_intrin")
+ print(sch.mod["main"].script())
+
+ After applying tensorize, the IR becomes:
+
+ .. code-block:: python
+
+ @T.prim_func
+ def after_tensorize(
+ A: T.Buffer[(128, 128), "float32"],
+ B: T.Buffer[(128, 128), "float32"],
+ C: T.Buffer[(128, 128), "float32"],
+ ) -> None:
+ # body
+ # with T.block("root")
+ for i_0, j_0 in T.grid(8, 8):
+ for i_1_init, j_1_init in T.grid(16, 16):
+ with T.block("init"):
+ vi = T.axis.spatial(128, i_0 * 16 + i_1_init)
+ vj = T.axis.spatial(128, j_0 * 16 + j_1_init)
+ T.reads()
+ T.writes(C[vi, vj])
+ C[vi, vj] = T.float32(0)
+ for k_0 in T.serial(8):
+ with T.block("blockized_update"):
+ vio, vjo, vko = T.axis.remap("SSR", [i_0, j_0,
k_0])
+ T.reads(
+ C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo *
16 + 16],
+ A[vio * 16 : vio * 16 + 16, vko * 16 : vko *
16 + 16],
+ B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko *
16 + 16],
+ )
+ T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 :
vjo * 16 + 16])
+ A_1 = T.match_buffer(
+ A[vio * 16 : vio * 16 + 16, vko * 16 : vko *
16 + 16],
+ [16, 16],
+ dtype="float32",
+ offset_factor=1,
+ )
+ B_1 = T.match_buffer(
+ B[vjo * 16 : vjo * 16 + 16, vko * 16 : vko *
16 + 16],
+ [16, 16],
+ dtype="float32",
+ offset_factor=1,
+ )
+ C_1 = T.match_buffer(
+ C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo *
16 + 16],
+ [16, 16],
+ dtype="float32",
+ offset_factor=1,
+ )
+ T.evaluate(
+ T.tvm_mma_sync(
+ C_1.data,
+ C_1.elem_offset // 256,
+ A_1.data,
+ A_1.elem_offset // 256,
+ B_1.data,
+ B_1.elem_offset // 256,
+ C_1.data,
+ C_1.elem_offset // 256,
+ dtype="handle",
+ )
+ )
+
+ """
+ if isinstance(tensor_intrin, str):
+ tensor_intrin = String(tensor_intrin)
Review comment:
Oh, I see. It's because our C++ registry is trying to decide between
String and TensorIntrin, so the explicit String conversion makes sense to me.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]