masahi commented on code in PR #13642:
URL: https://github.com/apache/tvm/pull/13642#discussion_r1057990584


##########
python/tvm/topi/x86/dense.py:
##########
@@ -373,6 +375,153 @@ def _callback(op):
     return s
 
 
+def dense_amx_int8_compute(cfg, data, packed_w, bias=None):

Review Comment:
   This can be shared with the VNNI compute.



##########
python/tvm/topi/x86/dense.py:
##########
@@ -373,6 +375,153 @@ def _callback(op):
     return s
 
 
+def dense_amx_int8_compute(cfg, data, packed_w, bias=None):
+    """Compute for uint8 x int8 -> int32 dense"""
+    m, k = data.shape
+    n_o, _, n_i, _ = packed_w.shape
+    ak = te.reduce_axis((0, k), name="k")
+
+    C = te.compute(
+        (m, n_o * n_i),
+        lambda i, j: te.sum(
+            data[i, ak].astype("int32")
+            * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 
16, ak % 4].astype(
+                "int32"
+            ),
+            axis=ak,
+        ),
+        tag="dense_amx_int8",
+        attrs={"schedule_rule": "meta_schedule.dense_amx_int8"},
+    )
+
+    if bias is not None:
+        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], 
tag=tag.BROADCAST)
+
+    return C
+
+
+def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True):
+    """Schedule dense compute using AMX TMUL instruction"""
+    # C: The output of GEMM
+    # O: The output of the fused op
+    def split_x(out):
+        default_x_split_factor1 = 32
+        default_x_split_factor2 = 2
+        default_x_split_factor3 = 2
+        default_x_split_factor4 = 2
+        a_x = s[out].op.axis[-2]
+
+        if cfg.is_fallback:
+            a_xo, a_xi = s[out].split(a_x, factor=default_x_split_factor1)
+            a_xo2, a_xo1 = s[out].split(a_xo, factor=default_x_split_factor2)
+            a_xo3, a_xo2 = s[out].split(a_xo2, factor=default_x_split_factor3)
+            a_xo4, a_xo3 = s[out].split(a_xo3, factor=default_x_split_factor4)
+            return [a_xo4, a_xo3, a_xo2, a_xo1, a_xi]
+
+        cfg.define_split("tile_x", a_x, num_outputs=5, filter=lambda x: 
x.size[-1] == 32)
+        return cfg["tile_x"].apply(s, out, a_x)
+
+    def split_y(out):
+        default_y_split_factor1 = 32
+        default_y_split_factor2 = 4
+        default_y_split_factor3 = 4
+        default_y_split_factor4 = 4
+        a_y = s[out].op.axis[-1]
+
+        if cfg.is_fallback:
+            a_yo1, a_yo = s[out].split(a_y, factor=default_y_split_factor1)
+            a_yo2, a_yo1 = s[out].split(a_yo1, factor=default_y_split_factor2)
+            a_yo3, a_yo2 = s[out].split(a_yo2, factor=default_y_split_factor3)
+            a_yo4, a_yo3 = s[out].split(a_yo3, factor=default_y_split_factor4)
+            return [a_yo4, a_yo3, a_yo2, a_yo1, a_yo]
+
+        cfg.define_split("tile_y", a_y, num_outputs=5, filter=lambda y: 
y.size[-1] == 32)
+        return cfg["tile_y"].apply(s, out, a_y)
+
+    def split_k(out, rd_axis):
+        default_k_split_factor1 = 128
+        default_k_split_factor2 = 2
+        default_k_split_factor3 = 2
+        default_k_split_factor4 = 2
+
+        if cfg.is_fallback:
+            a_ko, a_ki = s[out].split(rd_axis, factor=default_k_split_factor1)
+            a_ko2, a_ko1 = s[out].split(a_ko, factor=default_k_split_factor2)
+            a_ko3, a_ko2 = s[out].split(a_ko2, factor=default_k_split_factor3)
+            a_ko4, a_ko3 = s[out].split(a_ko3, factor=default_k_split_factor4)
+            return [a_ko4, a_ko3, a_ko2, a_ko1, a_ki]
+
+        cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: 
y.size[-1] == 128)
+        return cfg["tile_k"].apply(s, out, rd_axis)
+
+    a_x, a_y = C.op.axis
+    (a_k,) = C.op.reduce_axis
+    CF = s.cache_write(C, "amx.tmm")
+
+    a_x3, a_x2, a_x1, a_xo, a_xi = split_x(C)
+    a_y3, a_y2, a_y1, a_yo, a_yi = split_y(C)
+    s[C].reorder(a_x3, a_y3, a_x2, a_y2, a_x1, a_y1, a_xo, a_yo, a_xi, a_yi)
+
+    s[CF].compute_at(s[C], a_yo)
+
+    (a_k_f,) = CF.op.reduce_axis
+    a_x_f, a_y_f = CF.op.axis
+
+    a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32)
+
+    a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32)
+    a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f)
+    s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, 
a_xi_f, a_yi_f)
+
+    (m, k) = CF.op.input_tensors[0].shape
+    (n, c, n_i, c_i) = CF.op.input_tensors[1].shape
+    n = n * n_i
+
+    s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k)))
+    s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=int(n)))
+
+    if C == O:
+        fused = s[O].fuse(a_x3, a_y3)
+    else:
+        a_y3, a_y2, a_y1, a_yr, a_yi = split_y(O)
+        a_x3, a_x2, a_x1, a_xr, a_xi = split_x(O)
+
+        s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi, 
a_xi)
+        s[O].vectorize(a_xi)
+
+        fused = s[O].fuse(a_x3, a_y3)
+
+    if do_parallel:
+        s[O].parallel(fused)
+
+    return s, fused
+
+
+@autotvm.register_topi_compute("dense_amx_int8.x86")
+def dense_amx_int8(cfg, data, weight, bias=None, out_dtype=None):
+    """Compute for uint8 x int8 -> int32 dense"""
+    if out_dtype is None:
+        out_dtype = data.dtype
+    assert len(weight.shape) == 4
+    assert data.dtype == "uint8" and weight.dtype == "int8"
+    _, _, _, n_inner = get_const_tuple(weight.shape)  # out_dim
+    assert n_inner == 4
+    return dense_amx_int8_compute(cfg, data, weight, bias)
+
+
+@autotvm.register_topi_schedule("dense_amx_int8.x86")
+def schedule_dense_amx_int8(cfg, outs):
+    """Create a schedule for dense_amx_int8"""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if "dense_amx_int8" in op.tag:
+            dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s

Review Comment:
   For the two functions above, try removing the duplication with the VNNI ones.



##########
python/tvm/topi/x86/tensor_intrin.py:
##########
@@ -348,3 +348,227 @@ def _instr(index):
         binds={data: a_buffer, kernel: b_buffer},
         default_buffer_params=buffer_params,
     )
+
+
+def dot_32x128x32_u8s8s32_sapphirerapids(LDA):
+    """
+    Int8 dot product by every 16x64 elements using AMX-TMUL Sapphire Rapids 
instructions.
+    The tdpxxd instruction takes two tile of uint8 and int8 datatype -- 
data[16][64] and
+    kernel[1][16][16][4] -- and computes a dot product of data[16][16] in 
int32 datatype.
+
+    (Physically, to efficiently leveraging the tile register, we constructing 
a 2x2 tiles
+    matmul which performs 32x128x32 in total)
+
+    The pseudo code is as follows:
+        for(k=0; k<2; k++){
+            for(n=0; n<2; n++){
+                tileload64(tmm_b, B)
+                for(m=0; m<2; m++){
+                    if(n==0)
+                        tileload64(tmm_a, A)

Review Comment:
   Have you considered tensorizing at a finer granularity? That is, instead of 
tensorizing the whole load + compute as one tensor intrin, tensorize each load 
and compute separately. That could make the hard-coded outer-loop factors (2 in 
this intrin) tunable.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to