Qianshui-Jiang commented on code in PR #13642:
URL: https://github.com/apache/tvm/pull/13642#discussion_r1058782556
##########
python/tvm/topi/x86/dense.py:
##########
@@ -373,6 +375,153 @@ def _callback(op):
return s
+def dense_amx_int8_compute(cfg, data, packed_w, bias=None):
+ """Compute for uint8 x int8 -> int32 dense"""
+ m, k = data.shape
+ n_o, _, n_i, _ = packed_w.shape
+ ak = te.reduce_axis((0, k), name="k")
+
+ C = te.compute(
+ (m, n_o * n_i),
+ lambda i, j: te.sum(
+ data[i, ak].astype("int32")
+ * packed_w[tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j %
16, ak % 4].astype(
+ "int32"
+ ),
+ axis=ak,
+ ),
+ tag="dense_amx_int8",
+ attrs={"schedule_rule": "meta_schedule.dense_amx_int8"},
+ )
+
+ if bias is not None:
+ C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j],
tag=tag.BROADCAST)
+
+ return C
+
+
+def dense_amx_int8_schedule(cfg, s, C, O, do_parallel=True):
+ """Schedule dense compute using AMX TMUL instruction"""
+ # C: The output of GEMM
+ # O: The output of the fused op
+ def split_x(out):
+ default_x_split_factor1 = 32
+ default_x_split_factor2 = 2
+ default_x_split_factor3 = 2
+ default_x_split_factor4 = 2
+ a_x = s[out].op.axis[-2]
+
+ if cfg.is_fallback:
+ a_xo, a_xi = s[out].split(a_x, factor=default_x_split_factor1)
+ a_xo2, a_xo1 = s[out].split(a_xo, factor=default_x_split_factor2)
+ a_xo3, a_xo2 = s[out].split(a_xo2, factor=default_x_split_factor3)
+ a_xo4, a_xo3 = s[out].split(a_xo3, factor=default_x_split_factor4)
+ return [a_xo4, a_xo3, a_xo2, a_xo1, a_xi]
+
+ cfg.define_split("tile_x", a_x, num_outputs=5, filter=lambda x:
x.size[-1] == 32)
+ return cfg["tile_x"].apply(s, out, a_x)
+
+ def split_y(out):
+ default_y_split_factor1 = 32
+ default_y_split_factor2 = 4
+ default_y_split_factor3 = 4
+ default_y_split_factor4 = 4
+ a_y = s[out].op.axis[-1]
+
+ if cfg.is_fallback:
+ a_yo1, a_yo = s[out].split(a_y, factor=default_y_split_factor1)
+ a_yo2, a_yo1 = s[out].split(a_yo1, factor=default_y_split_factor2)
+ a_yo3, a_yo2 = s[out].split(a_yo2, factor=default_y_split_factor3)
+ a_yo4, a_yo3 = s[out].split(a_yo3, factor=default_y_split_factor4)
+ return [a_yo4, a_yo3, a_yo2, a_yo1, a_yo]
+
+ cfg.define_split("tile_y", a_y, num_outputs=5, filter=lambda y:
y.size[-1] == 32)
+ return cfg["tile_y"].apply(s, out, a_y)
+
+ def split_k(out, rd_axis):
+ default_k_split_factor1 = 128
+ default_k_split_factor2 = 2
+ default_k_split_factor3 = 2
+ default_k_split_factor4 = 2
+
+ if cfg.is_fallback:
+ a_ko, a_ki = s[out].split(rd_axis, factor=default_k_split_factor1)
+ a_ko2, a_ko1 = s[out].split(a_ko, factor=default_k_split_factor2)
+ a_ko3, a_ko2 = s[out].split(a_ko2, factor=default_k_split_factor3)
+ a_ko4, a_ko3 = s[out].split(a_ko3, factor=default_k_split_factor4)
+ return [a_ko4, a_ko3, a_ko2, a_ko1, a_ki]
+
+ cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y:
y.size[-1] == 128)
+ return cfg["tile_k"].apply(s, out, rd_axis)
+
+ a_x, a_y = C.op.axis
+ (a_k,) = C.op.reduce_axis
+ CF = s.cache_write(C, "amx.tmm")
+
+ a_x3, a_x2, a_x1, a_xo, a_xi = split_x(C)
+ a_y3, a_y2, a_y1, a_yo, a_yi = split_y(C)
+ s[C].reorder(a_x3, a_y3, a_x2, a_y2, a_x1, a_y1, a_xo, a_yo, a_xi, a_yi)
+
+ s[CF].compute_at(s[C], a_yo)
+
+ (a_k_f,) = CF.op.reduce_axis
+ a_x_f, a_y_f = CF.op.axis
+
+ a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32)
+
+ a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32)
+ a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f)
+ s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f,
a_xi_f, a_yi_f)
+
+ (m, k) = CF.op.input_tensors[0].shape
+ (n, c, n_i, c_i) = CF.op.input_tensors[1].shape
+ n = n * n_i
+
+ s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k)))
+ s[C].tensorize(a_xi, acc_32x32_int32_sapphirerapids(LDC=int(n)))
+
+ if C == O:
+ fused = s[O].fuse(a_x3, a_y3)
+ else:
+ a_y3, a_y2, a_y1, a_yr, a_yi = split_y(O)
+ a_x3, a_x2, a_x1, a_xr, a_xi = split_x(O)
+
+ s[O].reorder(a_y3, a_x3, a_y2, a_x2, a_y1, a_x1, a_yr, a_xr, a_yi,
a_xi)
+ s[O].vectorize(a_xi)
+
+ fused = s[O].fuse(a_x3, a_y3)
+
+ if do_parallel:
+ s[O].parallel(fused)
+
+ return s, fused
+
+
[email protected]_topi_compute("dense_amx_int8.x86")
+def dense_amx_int8(cfg, data, weight, bias=None, out_dtype=None):
+ """Compute for uint8 x int8 -> int32 dense"""
+ if out_dtype is None:
+ out_dtype = data.dtype
+ assert len(weight.shape) == 4
+ assert data.dtype == "uint8" and weight.dtype == "int8"
+ _, _, _, n_inner = get_const_tuple(weight.shape) # out_dim
+ assert n_inner == 4
+ return dense_amx_int8_compute(cfg, data, weight, bias)
+
+
[email protected]_topi_schedule("dense_amx_int8.x86")
+def schedule_dense_amx_int8(cfg, outs):
+ """Create a schedule for dense_amx_int8"""
+ s = te.create_schedule([x.op for x in outs])
+
+ def _callback(op):
+ if "dense_amx_int8" in op.tag:
+ dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
+
+ traverse_inline(s, outs[0].op, _callback)
+ return s
Review Comment:
Yes, the AMX and VNNI interfaces have already been unified in the x86 strategy.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]