This is an automated email from the ASF dual-hosted git repository.
hongyij pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new ff0b99c5ce [Dlight] Scheduling Low batch GEMM using GEMV-like rule (#16579)
ff0b99c5ce is described below
commit ff0b99c5ce4371ec966cd4fa07ae36351faf2a5e
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Feb 21 11:31:54 2024 -0500
[Dlight] Scheduling Low batch GEMM using GEMV-like rule (#16579)
* low batch
* fix
* fix lint
* do dequantize only once
* change default
* add test
* fix lint
* fix lint
---
python/tvm/dlight/gpu/__init__.py | 1 +
python/tvm/dlight/gpu/low_batch_gemv.py | 605 +++++++++++++++++++++++++
src/driver/driver_api.cc | 9 +-
src/tir/transforms/hoist_expression.cc | 9 +-
tests/python/dlight/test_gpu_low_batch_gemv.py | 255 +++++++++++
5 files changed, 876 insertions(+), 3 deletions(-)
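
For context, the new rule plugs into the existing dlight driver. A minimal
sketch of end-to-end usage (the PrimFunc below is illustrative, mirroring the
`before` function of test_batch_gemv added in this commit; bucket size 4 is
the rule's default):

    import tvm
    from tvm import dlight as dl
    from tvm.script import tir as T
    from tvm.target import Target

    @T.prim_func(private=True)
    def gemm(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_O: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        batch_size = T.int64()  # dynamic batch dimension
        A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(4096)), "float16")
        O = T.match_buffer(var_O, (batch_size, T.int64(1), T.int64(4096)), "float16")
        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(4096), T.int64(4096)):
            with T.block("NT_matmul"):
                v0, v1, v2, vk = T.axis.remap("SSSR", [i0, i1, i2, k])
                with T.init():
                    O[v0, v1, v2] = T.float16(0)
                O[v0, v1, v2] = O[v0, v1, v2] + A[v0, v1, vk] * B[v2, vk]

    mod = tvm.IRModule({"main": gemm})
    with Target("metal"):
        # pads the dynamic batch to a multiple of the bucket, then applies
        # the GEMV-like schedule introduced below
        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)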
diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py
index 7db383a161..077fdcaeb0 100644
--- a/python/tvm/dlight/gpu/__init__.py
+++ b/python/tvm/dlight/gpu/__init__.py
@@ -19,6 +19,7 @@ GPU-generic schedule rules.
For CUDA/ROCm/Vulkan/Metal-specific rules, use
`tvm.dlight.cuda/rocm/vulkan/metal` instead
"""
from .gemv import GEMV
+from .low_batch_gemv import LowBatchGEMV
from .fallback import Fallback
from .matmul import Matmul
from .reduction import Reduction
diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py
new file mode 100644
index 0000000000..dfed020853
--- /dev/null
+++ b/python/tvm/dlight/gpu/low_batch_gemv.py
@@ -0,0 +1,605 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A rule for low-batch GEMM / decode-GEMM using GEMV schedule."""
+import re
+from functools import reduce
+from typing import List, Optional, Union, Set
+
+from tvm import DataType, arith, ir, tir
+from tvm.target import Target
+
+from ..base import (
+ BlockInfo,
+ collect_block_iter_vars_used_in_access_region,
+ collect_vars_used_in_prim_expr,
+ is_broadcast_epilogue,
+ normalize_prim_func,
+ try_inline_contiguous_spatial,
+)
+from .base import GPUScheduleRule
+
+
+def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]:
+ # Detect and return `Y` in `X[...] = X[...] + Y`
+ buffer_store = block.body
+ if not isinstance(buffer_store, tir.BufferStore):
+ return None
+ if not isinstance(buffer_store.value, tir.Add):
+ return None
+ if not ir.structural_equal(
+ buffer_store.value.a,
+ tir.BufferLoad(buffer_store.buffer, block.body.indices),
+ map_free_vars=True,
+ ):
+ return None
+ return buffer_store.value.b
+
+
+def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV):
+ loop: tir.For = sch.get(loop_rv)
+    return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent
+
+
+def get_bytes(dtype: Union[DataType, str]) -> int:
+ num = re.findall(r"\d+", dtype)
+ if len(num) != 1:
+ raise ValueError(f"Cannot get bytes from {dtype}")
+ return int(num[0]) // 8
+
+
+def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
+ """Check if the block is a low batch GEMM.
+
+ Parameters
+    ----------
+    sch : tir.Schedule
+        The schedule
+    block_info : BlockInfo
+        The block info to be checked
+
+ Returns
+ -------
+ ret : Optional[List[tir.Buffer]]
+        The vector-like buffers used in the low batch GEMM if it is a low batch GEMM,
+ otherwise None.
+ """
+ block = block_info.block_rv
+ block_stmt = sch.get(block)
+ conditions = []
+ conditions.append(block_info.is_reduction())
+ conditions.append(len(block_stmt.reads) >= 2)
+ conditions.append(len(block_stmt.writes) == 1)
+ conditions.append(_get_reduction_expr(block_stmt) is not None)
+ conditions.append(
+        len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region))
+ > 0
+ )
+ if not all(conditions):
+ return None
+ const_iter_vars = set(
+ iter_var.var
+ for iter_var in block_stmt.iter_vars
+ if isinstance(iter_var.dom.extent, tir.IntImm)
+ )
+ if len(const_iter_vars) == len(block_stmt.iter_vars):
+ return None
+ ret = [
+ read.buffer
+ for read in block_stmt.reads
+ if len(
+            collect_block_iter_vars_used_in_access_region(block_stmt, read.region) & const_iter_vars
+ )
+ < len(const_iter_vars)
+ and len(
+            collect_block_iter_vars_used_in_access_region(block_stmt, read.region) & const_iter_vars
+ )
+ > 0
+ ]
+ return ret if 0 < len(ret) < len(block_stmt.reads) else None
+
+
+def detect_dominant_read(block: tir.Block, const_iter_vars: Set[tir.Var]) -> tir.PrimExpr:
+ """Detect the dominant read indices in the block."""
+ dominant_read = None
+ num_read_iters = -1
+ for buffer_region in block.reads:
+ tir_vars = (
+            collect_block_iter_vars_used_in_access_region(block, buffer_region.region)
+ & const_iter_vars
+ )
+ if num_read_iters < len(tir_vars):
+ num_read_iters = len(tir_vars)
+ dominant_read = buffer_region
+ assert dominant_read is not None
+    (result,) = dominant_read.buffer.offset_of([e.min for e in dominant_read.region])
+ return result
+
+
+def normalize(
+ sch: tir.Schedule,
+ block_info: BlockInfo,
+) -> Optional[bool]:
+ """Normalize the main block."""
+ block_stmt: tir.Block = sch.get(block_info.block_rv)
+ const_iter_vars = set(
+ iter_var.var
+ for iter_var in block_stmt.iter_vars
+ if isinstance(iter_var.dom.extent, tir.IntImm)
+ )
+ dynamic_iter_vars = set(
+        iter_var.var for iter_var in block_stmt.iter_vars if iter_var.var not in const_iter_vars
+ )
+ access = arith.normalize_to_iter_sum(
+ detect_dominant_read(block_stmt, const_iter_vars),
+ input_iters={i.var: i.dom for i in block_stmt.iter_vars},
+ )
+ buffers_use_vars = [
+ collect_block_iter_vars_used_in_access_region(block_stmt, buf.region)
+ for buf in block_stmt.writes
+ ]
+ buffers_use_vars.extend(
+ [
+            collect_block_iter_vars_used_in_access_region(block_stmt, buf.region)
+ for buf in block_stmt.reads
+ ]
+ )
+ if collect_vars_used_in_prim_expr(access.base) & set(
+ iter_var.var for iter_var in block_stmt.iter_vars
+ ):
+ return None
+ iter_to_info = {i.var: i for i in block_info.iters}
+ batch_loops, s_loops, r_loops, c_loops = [], [], [], []
+ inner_axis = access.args[-1].source.source
+ is_inner_reduction = iter_to_info[inner_axis].kind == "R"
+
+ for split_expr in access.args:
+ var = split_expr.source.source
+ info = iter_to_info.get(var)
+ loop = info.loop_rv
+ is_reduction = info.kind == "R"
+ if split_expr.lower_factor > 1:
+ if c_loops:
+ return None
+            loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor])
+            # at the moment we only support grouping the reduction dim
+ if not is_reduction:
+ return None
+ c_loops.append(c_loop)
+ if is_reduction:
+ r_loops.append(loop)
+ elif all([var in buf_vars for buf_vars in buffers_use_vars]):
+ batch_loops.append(loop)
+ else:
+ s_loops.append(loop)
+
+ assert s_loops
+ assert r_loops
+ if not c_loops:
+ c_loops = [sch.add_unit_loop(block_info.block_rv)]
+ dynamic_loops = [iter_to_info[var].loop_rv for var in dynamic_iter_vars]
+ assert len(dynamic_loops) == 1
+ if not batch_loops:
+ batch_loops = [sch.add_unit_loop(block_info.block_rv)]
+ sch.reorder(*dynamic_loops, *batch_loops, *s_loops, *r_loops, *c_loops)
+ sch.fuse(*batch_loops)
+ sch.fuse(*s_loops)
+ sch.fuse(*r_loops)
+ return is_inner_reduction
+
+
+class LowBatchGEMV(GPUScheduleRule):
+ """A rule for low batch GEMM / decode-GEMM."""
+
+ def __init__(self, bucket=4):
+ self.bucket = bucket
+
+    def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements
+ self,
+ func: tir.PrimFunc,
+ target: Target,
+ _: bool,
+ ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
+        if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
+ return None
+ sch = tir.Schedule(func)
+ block_infos = normalize_prim_func(sch)
+
+ reduction_block_infos = [
+            block_info for block_info in block_infos if block_info.is_reduction()
+ ]
+ if len(reduction_block_infos) != 1:
+ return None
+ reduction_block_info = reduction_block_infos[0]
+ vector_input_buffers = is_gemv(sch, reduction_block_info)
+ if vector_input_buffers is None:
+ return None
+ batch_pad = self.bucket
+ pad_value = [
+ iter.dom if isinstance(iter.dom, int) else batch_pad
+ for iter in reduction_block_info.iters
+ ]
+ sch.pad_einsum(reduction_block_info.block_rv, pad_value)
+ block_infos = normalize_prim_func(sch)
+ dequantize_block = None
+ pad_input_block = None
+ for block_info in block_infos:
+ if "dequantize" in block_info.name:
+ dequantize_block = block_info.block_rv
+            elif "pad" in block_info.name and len(sch.get_producers(block_info.block_rv)) == 0:
+ pad_input_block = block_info.block_rv
+ block_infos = [
+ block_info
+ for block_info in block_infos
+            if "pad" not in block_info.name and "dequantize" not in block_info.name
+ ]
+ block_infos = try_inline_contiguous_spatial(sch, block_infos)
+ if len(block_infos) == 1:
+ epilogue = None
+ elif len(block_infos) == 2:
+ epilogue = block_infos[1]
+ if not epilogue.is_injective():
+ return None
+ else:
+ return None
+
+ block_info = block_infos[0]
+ if len(block_info.iters) not in [2, 3]:
+ # either [B, S, R] = [B, S, R] * [B, R]
+ # or [S, R] = [S, R] * [R]
+ return None
+ block = block_info.block_rv
+ vector_input_buffers = is_gemv(sch, block_info)
+ if vector_input_buffers is None:
+ return None
+
+ # Step 1. Normalize the block, merge spatial and reduction iters
+ is_inner_reduction = normalize(sch, block_info)
+ # Step 2. Do the scheduling
+ if is_inner_reduction is None:
+ return None
+ elif is_inner_reduction:
+ self.sch_inner_reduction(
+ sch,
+ target,
+ block,
+ dequantize_block,
+ pad_input_block,
+ vector_input_buffers,
+ epilogue,
+ batch_pad,
+ )
+ return sch
+ else:
+ raise NotImplementedError("Outer reduction is not supported yet")
+
+    def sch_inner_reduction(  # pylint: disable=too-many-arguments, invalid-name, unused-argument
+ self,
+ sch: tir.Schedule,
+ target: Target,
+ block: tir.schedule.BlockRV,
+ dequantize_block: Optional[tir.schedule.BlockRV],
+ pad_input_block: Optional[tir.schedule.BlockRV],
+ vector_input_buffers: List[tir.Buffer],
+ epilogue_info: Optional[BlockInfo],
+ batch_pad: int,
+ ):
+ """Schedule the inner reduction block."""
+
+ def get_max_factor(n, factors):
+ factors = sorted(factors, reverse=True)
+ for factor in factors:
+ if n % factor == 0:
+ return factor
+ return 1
+
+ def apply(
+ sch: tir.Schedule,
+ gemv,
+ TAG_S,
+ TAG_R,
+ TS,
+ TR,
+ TILE_S,
+ TILE_R,
+ VEC_LOAD,
+ VEC_C,
+ LOAD_V_SHARED,
+ LOAD_V_VEC,
+ UNROLL,
+ ):
+ # rfactor: reduce to tx * vec_c
+
+ _, b, s, r, c = sch.get_loops(block=gemv)
+ s = sch.fuse(b, s)
+ r = sch.fuse(r, c)
+            bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S], preserve_unit_iters=True)
+ r, tr, tile_r_vec_n, vec_c = sch.split(
+                r, factors=[None, TR, TILE_R // VEC_C, VEC_C], preserve_unit_iters=True
+ )
+ sch.reorder(r, tile_r_vec_n, tr, vec_c)
+ tr_vec_c = sch.fuse(tr, vec_c)
+ rf = sch.rfactor(tr_vec_c, 0)
+
+ # rfactor: reduce to tx
+ _, bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True)
+ rf2 = sch.rfactor(tr, 0)
+ # bind, vectorize compute
+            batch_loop, bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = sch.get_loops(block=rf)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True)
+ sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c)
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(ts, TAG_S)
+ sch.bind(tr, TAG_R)
+ sch.vectorize(vec_c)
+ by, batch = sch.split(batch_loop, factors=[None, batch_pad])
+ sch.bind(by, "blockIdx.y")
+ sch.reorder(bx, ts, tr, r, batch)
+
+ shared_mem_usage = 0
+ for buf in vector_input_buffers:
+ buf_size = reduce(
+                    lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1)
+ ) * get_bytes(buf.dtype)
+ shared_mem_usage += buf_size
+ LOAD_V_SHARED = (
+ LOAD_V_SHARED
+ and isinstance(shared_mem_usage, tir.IntImm)
+                and shared_mem_usage.value <= target.max_shared_memory_per_block
+ )
+
+ # vectorize load A
+            # (TODO) this is now actually problematic since the number of loops is dependent on the
+ # number of dimensions of A_q
+ if dequantize_block is not None:
+ sch.compute_at(dequantize_block, r, preserve_unit_loops=True)
+ sch.set_scope(dequantize_block, 0, "local")
+
+ s_local, r_local = sch.get_loops(block=dequantize_block)[-2:]
+ s_local, vec_load = sch.split(
+ s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True
+ )
+                sch.reorder(s_local, r_local, vec_load)  # either s_local or r_local should be 1
+ sch.vectorize(vec_load)
+
+ # load vector into shared memory, shape should be the whole vector
+ if LOAD_V_SHARED:
+ assert len(vector_input_buffers) == 1
+                V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared")
+ sch.compute_at(V_shared, tr, preserve_unit_loops=True)
+ l = sch.get_loops(block=V_shared)[-1]
+ loop: tir.For = sch.get(l)
+ if isinstance(loop.extent, tir.IntImm):
+                    # avoid introducing predicates when vector length is too large
+ vec_length = max(
+ min(
+ get_max_factor(
+ (int)(loop.extent),
+                            [TS * TR * 1, TS * TR * 2, TS * TR * 4, TS * TR * 8],
+ )
+ // TS
+ // TR,
+ LOAD_V_VEC,
+ ),
+ 1,
+ )
+ else:
+ vec_length = LOAD_V_VEC
+ if TAG_R == "threadIdx.x":
+ _, ty, tx, vec = sch.split(
+                    l, factors=[None, TS, TR, vec_length], preserve_unit_iters=True
+ )
+ else:
+ _, ty, tx, vec = sch.split(
+                    l, factors=[None, TR, TS, vec_length], preserve_unit_iters=True
+ )
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec)
+ if pad_input_block is not None:
+ sch.compute_inline(pad_input_block)
+
+ # reduce tile_s * tr * vec to tile_s * tr
+ sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+ tr, vec_c, batch_loop, *ts_tile_s = sch.get_loops(block=rf2)[2:]
+ ts_tile_s = sch.fuse(*ts_tile_s)
+            ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+ tile_s, vec_s = sch.split(
+ tile_s,
+ factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])],
+ preserve_unit_iters=True,
+ )
+ sch.reorder(ts, tr, tile_s, batch_loop, vec_s, vec_c)
+ sch.bind(ts, TAG_S)
+ sch.bind(tr, TAG_R)
+ sch.vectorize(vec_s)
+
+ # reduce tile_s * tr to tile_s
+ sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+
+ tr, batch_loop, *ts_tile_s = sch.get_loops(block=gemv)[2:]
+ ts_tile_s = sch.fuse(*ts_tile_s)
+            ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+ sch.reorder(tile_s, batch_loop, ts, tr)
+ sch.bind(ts, TAG_S)
+ sch.bind(tr, TAG_R)
+
+ sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[4])
+ sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+ sch.set_scope(rf, buffer_index=0, storage_scope="local")
+ sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+ unroll_factor = UNROLL
+
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf)[4],
+ ann_key="pragma_auto_unroll_max_step",
+ ann_val=unroll_factor,
+ )
+ sch.annotate(
+                block_or_loop=sch.get_loops(rf)[4], ann_key="pragma_unroll_explicit", ann_val=1
+ )
+
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf2)[4],
+ ann_key="pragma_auto_unroll_max_step",
+ ann_val=unroll_factor,
+ )
+ sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[4], ann_key="pragma_unroll_explicit", ann_val=1
+ )
+
+ if LOAD_V_SHARED:
+ sch.annotate(
+ block_or_loop=sch.get_loops(V_shared)[-4],
+ ann_key="pragma_unroll_explicit",
+ ann_val=unroll_factor,
+ )
+ sch.annotate(
+                    block_or_loop=sch.get_loops(V_shared)[-4], ann_key="pragma_vectorize", ann_val=1
+ )
+
+ epilogue = sch.get_consumers(gemv)
+ # Schedule epilogue
+ if epilogue:
+ epilogue = epilogue[0]
+ if is_broadcast_epilogue(sch, block, epilogue):
+ sch.reverse_compute_at(epilogue, bx)
+ sch.set_scope(block, 0, "shared")
+                    _, _, _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
+ _, tx = sch.split(sch.fuse(*s), factors=[None, TX])
+ sch.bind(tx, "threadIdx.x")
+ else:
+                    sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
+ ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[3:])
+ ts_tile_s = sch.get_loops(epilogue)[-1]
+                    ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+ sch.bind(ts, TAG_S)
+ sch.set_scope(block, 0, "local")
+
+ return sch
+
+        # Specify TS and TR according to the loop extents
+ _, batch, s, r, c = sch.get_loops(block=block)
+ len_batch, len_s, len_r, len_c = (
+ get_extent(sch, batch),
+ get_extent(sch, s),
+ get_extent(sch, r),
+ get_extent(sch, c),
+ )
+ len_S = len_batch * len_s
+ len_R = len_r * len_c
+
+ TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
+ if target.kind.name == "cuda":
+ VEC_C = 4
+ LOAD_V_SHARED = True
+ LOAD_V_VEC = 8
+ UNROLL = 256
+ if isinstance(len_S, int):
+ if len_S > len_R:
+ TS, TR = 4, 64
+ else:
+ TS, TR = 16, 32
+ elif target.kind.name == "metal":
+ # Note that the following tile size is tuned on M2 Ultra for 7B
+ TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
+ VEC_C = 1
+ LOAD_V_SHARED = False
+ LOAD_V_VEC = -1
+ UNROLL = 256
+ if isinstance(len_S, int):
+ if len_S > len_R:
+ TS, TR = 2, 32
+ else:
+ TS, TR = 2, 64
+ elif target.kind.name == "rocm":
+ VEC_C = 4
+ LOAD_V_SHARED = True
+ LOAD_V_VEC = 8
+ UNROLL = 256
+ if isinstance(len_S, int):
+ if len_S > len_R:
+ TS, TR = 1, 128
+ else:
+ TS, TR = 8, 64
+ elif target.kind.name == "opencl" and "android" in str(target.host):
+ TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
+ VEC_C = 8
+ LOAD_V_SHARED = False
+ LOAD_V_VEC = -1
+ UNROLL = 8
+ TS, TR = 2, 32
+ elif target.kind.name == "vulkan":
+ VEC_C = 4
+ LOAD_V_SHARED = True
+ LOAD_V_VEC = 4
+ UNROLL = 256
+ if isinstance(len_S, int):
+ if len_S > len_R:
+ TS, TR = 4, 32
+ else:
+ TS, TR = 16, 32
+ elif target.kind.name == "opencl" and "mali" in str(target.attrs):
+ VEC_C = 8
+ LOAD_V_SHARED = False
+ LOAD_V_VEC = -1
+ UNROLL = 64
+ TS, TR = 1, 64
+ else:
+ VEC_C = 1
+ LOAD_V_SHARED = False
+ LOAD_V_VEC = -1
+ UNROLL = 64
+ TS, TR = 1, 64
+
+ if not isinstance(len_S, int):
+ TS, TR = 1, 64
+
+ while TS * TR > target.max_num_threads:
+ if TS > 1:
+ TS //= 2
+ else:
+ TR //= 2
+
+ TILE_S, TILE_R = (
+ 2,
+ len_c
+ if len_c > 1
+            else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1),
+ )
+ VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C)
+ VEC_LOAD = 1
+ return apply(
+ sch,
+ gemv=block,
+ TAG_S=TAG_S,
+ TAG_R=TAG_R,
+ TS=TS,
+ TR=TR,
+ TILE_S=TILE_S,
+ TILE_R=TILE_R,
+ VEC_LOAD=VEC_LOAD,
+ VEC_C=VEC_C,
+ LOAD_V_SHARED=LOAD_V_SHARED,
+ LOAD_V_VEC=LOAD_V_VEC,
+ UNROLL=UNROLL,
+ )
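
Note on the bucketing: `pad_einsum` above pads the dynamic batch dimension up
to the bucket size, which is why the scheduled IR in the tests below allocates
buffers shaped (batch_size + 3) // 4 * 4. A minimal plain-Python illustration,
assuming the default bucket of 4:

    def padded_batch(batch_size: int, bucket: int = 4) -> int:
        # round the dynamic batch up to a multiple of the bucket
        return (batch_size + bucket - 1) // bucket * bucket

    assert padded_batch(1) == 4 and padded_batch(4) == 4 and padded_batch(5) == 8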
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 4eca8aebd7..bdadb6db0f 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -240,6 +240,10 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
if (use_async_copy) {
pass_list.push_back(tir::transform::LowerAsyncDMA());
}
+ // HoistIfThenElse must be applied before UnrollLoop
+ // because HoistIfThenElse could utilize for loop structure
+ // which might be unrolled in UnrollLoop
+ pass_list.push_back(tir::transform::HoistIfThenElse());
pass_list.push_back(tir::transform::UnrollLoop());
// Add user-defined phase-2 passes
@@ -250,7 +254,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::RemoveNoOp());
pass_list.push_back(tir::transform::RewriteUnsafeSelect());
- pass_list.push_back(tir::transform::HoistIfThenElse());
// Add user-defined phase-3 passes
  pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end());
@@ -586,7 +589,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));
- mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
mixed_pass_list.push_back(tir::transform::ThreadSync("warp"));
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
@@ -604,6 +606,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+ // MergeSharedMemoryAllocations must be applied after SplitHostDevice
+  // because the merged allocation site is at the beginning of each device function
+ mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations());
bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
.value_or(relay::Executor::Create("graph", {}))
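
Both driver changes are ordering constraints on the lowering pipeline. Read as
a hedged sketch in terms of the Python pass handles (illustrative grouping
only; the authoritative lists are CreatePassList and MixedModulePassManager):

    from tvm import tir

    phase2_tail = [
        tir.transform.HoistIfThenElse(),  # runs before UnrollLoop so it can use intact loop structure
        tir.transform.UnrollLoop(),
    ]
    device_tail = [
        tir.transform.SplitHostDevice(),
        tir.transform.MergeSharedMemoryAllocations(),  # merged allocations sit at the top of each device function
    ]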
diff --git a/src/tir/transforms/hoist_expression.cc b/src/tir/transforms/hoist_expression.cc
index 494fd7184f..f0fc90ee32 100644
--- a/src/tir/transforms/hoist_expression.cc
+++ b/src/tir/transforms/hoist_expression.cc
@@ -558,7 +558,14 @@ Pass HoistIfThenElse() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
auto cfg = ctx->GetConfig<HoistIfThenElseConfig>("tir.HoistIfThenElse");
-
+ auto flag = f->GetAttr<Integer>("tir.HoistIfThenElseExprWithBlock");
+ if (flag && flag.value().IntValue() == 1) {
+    HoistExpressionConfig config(static_cast<int>(HoistedConditionals::kUsingBlockVar) |
+                                     static_cast<int>(HoistedConditionals::kIfElseExpr),
+                                 static_cast<int>(HoistedLetBindings::kNone));
+ n->body = ExpressionHoister::Hoist(std::move(n->body), config);
+ return f;
+ }
if (!cfg.defined()) {
cfg = AttrsWithDefaultValues<HoistIfThenElseConfig>();
}
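
The attribute gives functions an explicit opt-in to this hoisting path. A
small sketch (`func` is an assumed tir.PrimFunc; the tests below set the same
attribute via T.func_attr):

    # request hoisting of if_then_else expressions that use block/thread vars
    func = func.with_attr("tir.HoistIfThenElseExprWithBlock", 1)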
diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py
new file mode 100644
index 0000000000..5827b7b810
--- /dev/null
+++ b/tests/python/dlight/test_gpu_low_batch_gemv.py
@@ -0,0 +1,255 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import pytest
+
+import tvm.testing
+from tvm import dlight as dl
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+def test_batch_decode_gemv():
+ # fmt: off
+
+ @T.prim_func(private=True)
+    def before(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T.Buffer((T.int64(4096), T.int64(896)), "float16"), p_lv807: T.handle, p_output0: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True), "tir.HoistIfThenElseExprWithBlock": 1})
+ batch_size = T.int64()
+        lv807 = T.match_buffer(p_lv807, (batch_size, T.int64(1), T.int64(28672)), "float16")
+        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(4096)), "float16")
+ # with T.block("root"):
+ compute = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16")
+        dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16")
+ for i0, i1 in T.grid(T.int64(4096), T.int64(28672)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(lv429[v_i0, v_i1 // T.int64(8)])
+ T.writes(compute[v_i0, v_i1])
+                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv429[v_i0, v_i1 // T.int64(8)], T.Cast("uint32", v_i1 % T.int64(8) * T.int64(4))), T.uint32(15)))
+ for i0, i1 in T.grid(T.int64(4096), T.int64(28672)):
+ with T.block("dequantize"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(compute[v_i0, v_i1], lv430[v_i0, v_i1 // T.int64(32)])
+ T.writes(dequantize_intermediate_intermediate[v_i0, v_i1])
+                dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv430[v_i0, v_i1 // T.int64(32)]
+        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(4096), T.int64(28672)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv807[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_i2, v_k])
+ T.writes(NT_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                NT_matmul_intermediate[v_i0, v_i1, v_i2] = NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv807[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_i2, v_k]
+
+ @T.prim_func(private=True)
+    def expected(lv429: T.Buffer((T.int64(4096), T.int64(3584)), "uint32"), lv430: T.Buffer((T.int64(4096), T.int64(896)), "float16"), p_lv807: T.handle, p_output0: T.handle):
+        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ batch_size = T.int64()
+        lv807 = T.match_buffer(p_lv807, (batch_size, T.int64(1), T.int64(28672)), "float16")
+        NT_matmul_intermediate = T.match_buffer(p_output0, (batch_size, T.int64(1), T.int64(4096)), "float16")
+ # with T.block("root"):
+        dequantize_intermediate_intermediate_local = T.alloc_buffer((T.int64(4096), T.int64(28672)), "float16", scope="local")
+        NT_matmul_intermediate_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        NT_matmul_intermediate_pad_rf_local = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        NT_matmul_intermediate_pad_rf_local_1 = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
+            for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1024), thread="blockIdx.x"):
+                for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
+                    for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
+                        for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)):
+                            for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(1)):
+ with T.block("NT_matmul_rf_init"):
+                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
+                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
+                                    v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init)
+                                    T.reads()
+                                    T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1])
+                                    NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = T.float16(0)
+                        for ax2_fused_u_fused_0 in T.serial(T.int64(56), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax0_0_1, ax1 in T.grid(T.int64(2), T.int64(8)):
+ for ax0_1 in T.vectorized(T.int64(1)):
+ with T.block("dequantize"):
+                                            v0 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + ax0_0_1 + ax0_1)
+                                            v1 = T.axis.spatial(T.int64(28672), ax2_fused_u_fused_0 * T.int64(512) + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * T.int64(8) + ax1)
+                                            T.reads(lv429[v0, v1 // T.int64(8)], lv430[v0, v1 // T.int64(32)])
+                                            T.writes(dequantize_intermediate_intermediate_local[v0, v1])
+                                            dequantize_intermediate_intermediate_local[v0, v1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv429[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv430[v0, v1 // T.int64(32)]
+                            for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(8)):
+                                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(1)):
+ with T.block("NT_matmul_rf_update"):
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
+                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
+                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2)
+                                        vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2])
+                                        T.reads(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1], lv807[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], dequantize_intermediate_intermediate_local[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2])
+                                        T.writes(NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1])
+                                        NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, lv807[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], T.float16(0)) * dequantize_intermediate_intermediate_local[v1, [...]
+                for ax3_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
+                        for ax3_fused_1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2 in range(T.int64(4)):
+ for ax3_fused_1_1 in T.vectorized(T.int64(2)):
+ with T.block("NT_matmul_rf_init"):
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0)
+                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
+                                        T.reads()
+                                        T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
+                                        NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0)
+ for ax1 in range(T.int64(1)):
+ with T.block("NT_matmul_rf_update"):
+                                            vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
+                                            v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                            v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
+                                            T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1])
+                                            T.writes(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
+                                            NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1]
+ for ax2_fused_1, ax1 in T.grid(T.int64(2), T.int64(4)):
+                    for ax2_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
+                        for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
+ with T.block("NT_matmul"):
+                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0)
+                                v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
+                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax2_fused_0 * T.int64(2) + ax2_fused_1)
+                                T.reads(NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
+                                T.writes(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
+ with T.init():
+                                    NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = T.float16(0)
+                                NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1] + NT_matmul_intermediate_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]
+ for ax0 in range(T.int64(4)):
+                    for ax1_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
+ for ax1_fused_1 in range(T.int64(2)):
+ with T.block("NT_matmul_intermediate_pad"):
+                                v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
+                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax1_fused_0 * T.int64(2) + ax1_fused_1)
+                                T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
+                                T.reads(NT_matmul_intermediate_pad_local[v0, T.int64(0), v1])
+                                T.writes(NT_matmul_intermediate[v0, T.int64(0), v1])
+                                NT_matmul_intermediate[v0, T.int64(0), v1] = NT_matmul_intermediate_pad_local[v0, T.int64(0), v1]
+ # fmt: on
+ mod = tvm.IRModule({"main": before})
+ with Target("metal"):
+ mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+ tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
+def test_batch_gemv():
+ N = 4096
+ K = 4096
+ # fmt: off
+ @T.prim_func(private=True)
+    def before(var_A: T.handle, B: T.Buffer((T.int64(N), T.int64(K)), "float16"), var_NT_matmul: T.handle):
+        T.func_attr({"tir.noalias": T.bool(True), "tir.HoistIfThenElseExprWithBlock": 1})
+ batch_size = T.int64()
+        A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(K)), "float16")
+        NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(N)), "float16")
+ # with T.block("root"):
+        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(N), T.int64(K)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
+ T.writes(NT_matmul[v_i0, v_i1, v_i2])
+ with T.init():
+ NT_matmul[v_i0, v_i1, v_i2] = T.float16(0)
+                NT_matmul[v_i0, v_i1, v_i2] = NT_matmul[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k]
+
+ @T.prim_func(private=True)
+    def expected(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), var_NT_matmul: T.handle):
+        T.func_attr({"tir.HoistIfThenElseExprWithBlock": 1, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ batch_size = T.int64()
+        A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(4096)), "float16")
+        NT_matmul = T.match_buffer(var_NT_matmul, (batch_size, T.int64(1), T.int64(4096)), "float16")
+ # with T.block("root"):
+        NT_matmul_pad_local = T.alloc_buffer(((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        NT_matmul_pad_rf_local = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        NT_matmul_pad_rf_local_1 = T.alloc_buffer((T.int64(64), (batch_size + T.int64(3)) // T.int64(4) * T.int64(4), T.int64(1), T.int64(4096)), "float16", scope="local")
+        for ax0_0 in T.thread_binding((batch_size + T.int64(3)) // T.int64(4), thread="blockIdx.y"):
+            for u_fused_ax1_fused_fused_0 in T.thread_binding(T.int64(1024), thread="blockIdx.x"):
+                for u_fused_ax1_fused_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
+                    for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
+                        for ax0_1_init, u_fused_ax1_fused_fused_2_init in T.grid(T.int64(4), T.int64(2)):
+                            for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(T.int64(1)):
+ with T.block("NT_matmul_rf_init"):
+                                    vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
+                                    v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1_init)
+                                    v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2_init)
+                                    T.reads()
+                                    T.writes(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1])
+                                    NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = T.float16(0)
+                        for ax2_fused_u_fused_0 in T.serial(T.int64(8), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                            for ax0_1, u_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(T.int64(4), T.int64(2), T.int64(8)):
+                                for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(T.int64(1)):
+ with T.block("NT_matmul_rf_update"):
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(T.int64(64), ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
+                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax0_1)
+                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + u_fused_ax1_fused_fused_1 * T.int64(2) + u_fused_ax1_fused_fused_2)
+                                        vax2_fused_u_fused_0, vax2_fused_u_fused_2 = T.axis.remap("RR", [ax2_fused_u_fused_0, ax2_fused_u_fused_2])
+                                        T.reads(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1], A[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], B[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2])
+                                        T.writes(NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1])
+                                        NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] = NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, v0, T.int64(0), v1] + T.if_then_else(v0 < batch_size, A[v0, T.int64(0), vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused * T.int64(8) + vax2_fused_u_fused_2], T.float16(0)) * B[v1, vax2_fused_u_fused_0 * T.int64(512) + vax2_fused_u_fused_1_ax2_fused_u_ [...]
+                for ax3_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
+                    for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
+                        for ax3_fused_1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2 in range(T.int64(4)):
+ for ax3_fused_1_1 in T.vectorized(T.int64(2)):
+ with T.block("NT_matmul_rf_init"):
+                                        vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(T.int64(64), ax0)
+                                        v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                        v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
+                                        T.reads()
+                                        T.writes(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
+                                        NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = T.float16(0)
+ for ax1 in range(T.int64(1)):
+ with T.block("NT_matmul_rf_update"):
+                                            vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
+                                            v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax2)
+                                            v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax3_fused_0 * T.int64(2) + ax3_fused_1_0 * T.int64(2) + ax3_fused_1_1)
+                                            T.reads(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1], NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1])
+                                            T.writes(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
+                                            NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] = NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1] + NT_matmul_pad_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, v0, T.int64(0), v1]
+ for ax2_fused_1, ax1 in T.grid(T.int64(2), T.int64(4)):
+                    for ax2_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
+                        for ax0 in T.thread_binding(T.int64(64), thread="threadIdx.y"):
+ with T.block("NT_matmul"):
+                                vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(T.int64(64), ax0)
+                                v0 = T.axis.spatial((batch_size + T.int64(3)) // T.int64(4) * T.int64(4), ax0_0 * T.int64(4) + ax1)
+                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax2_fused_0 * T.int64(2) + ax2_fused_1)
+                                T.reads(NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1])
+                                T.writes(NT_matmul_pad_local[v0, T.int64(0), v1])
+ with T.init():
+                                    NT_matmul_pad_local[v0, T.int64(0), v1] = T.float16(0)
+                                NT_matmul_pad_local[v0, T.int64(0), v1] = NT_matmul_pad_local[v0, T.int64(0), v1] + NT_matmul_pad_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, v0, T.int64(0), v1]
+ for ax0 in range(T.int64(4)):
+                    for ax1_fused_0 in T.thread_binding(T.int64(2), thread="threadIdx.x"):
+ for ax1_fused_1 in range(T.int64(2)):
+ with T.block("NT_matmul_pad"):
+                                v0 = T.axis.spatial(batch_size, ax0_0 * T.int64(4) + ax0)
+                                v1 = T.axis.spatial(T.int64(4096), u_fused_ax1_fused_fused_0 * T.int64(4) + ax1_fused_0 * T.int64(2) + ax1_fused_1)
+                                T.where((ax0_0 - (batch_size + T.int64(3)) // T.int64(4) < T.int64(0) or ax0_0 == T.int64(0)) and ax0_0 * T.int64(4) + ax0 < batch_size)
+                                T.reads(NT_matmul_pad_local[v0, T.int64(0), v1])
+                                T.writes(NT_matmul[v0, T.int64(0), v1])
+                                NT_matmul[v0, T.int64(0), v1] = NT_matmul_pad_local[v0, T.int64(0), v1]
+ # fmt: on
+ mod = tvm.IRModule({"main": before})
+ with Target("metal"):
+ mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+ tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()