This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6c34361369 [Hexagon] Adapt some intrinsics for high vector lanes (#14345)
6c34361369 is described below
commit 6c3436136926613367ac3e19ece8bed90c4d5efb
Author: apeskov <[email protected]>
AuthorDate: Mon Mar 27 08:08:30 2023 +0400
[Hexagon] Adapt some intrinsics for high vector lanes (#14345)
* [HEX] Enhanced vector lanes for some intrinsics
* fix pylint
Signed-off-by: Alexander Peskov <[email protected]>
* fix lint 2
Signed-off-by: Alexander Peskov <[email protected]>
* Fix typo
Signed-off-by: Alexander Peskov <[email protected]>
---------
Signed-off-by: Alexander Peskov <[email protected]>
---
python/tvm/topi/hexagon/tensor_intrin.py | 309 +++++++++++++++------
.../test_hexagon/test_fixed_point_multiply.py | 138 ++++++++-
2 files changed, 363 insertions(+), 84 deletions(-)
diff --git a/python/tvm/topi/hexagon/tensor_intrin.py b/python/tvm/topi/hexagon/tensor_intrin.py
index 3e9fd47b0f..24bbacf37c 100644
--- a/python/tvm/topi/hexagon/tensor_intrin.py
+++ b/python/tvm/topi/hexagon/tensor_intrin.py
@@ -22,44 +22,165 @@ from tvm.ir import register_intrin_lowering
from tvm import te
+def get_lanes(dtype: str):
+ if "x" not in dtype:
+ return 1
+
+ _, lanes = dtype.split("x")
+ return int(lanes)
+
+
+def is_vector_type(dtype: str):
+ return get_lanes(dtype) != 1
+
+
+def is_power_of_2(n: int):
+ return (n & (n - 1) == 0) and n != 0
+
+
+def _adapt_to_highest_lanes(*args, intrinsic=None, intrinsic_lanes: int = 0):
+ """Apply provided lowering intrinsic to arguments with longer vector data type.
+
+ This wrapper will do next actions:
+ * Split each argument into chunks with size equal intrinsic_lanes
+ * Apply provided intrinsic for each argument chunk
+ * Concatenate results
+
+ Parameters
+ ----------
+ args: List[PrimExpr]
+ List of arguments. Each arg expression should have vector type with lanes
+ equal `intrinsic_lanes * 2**n`.
+
+ intrinsic: callable
+ Intrinsic implementation to apply.
+
+ intrinsic_lanes: int
+ Vector length required by intrinsic implementation.
+
+ Returns
+ -------
+ res : PrimExpr
+ Resulting expression.
+ """
+
+ def split_args(args_set):
+ res_args_set = []
+ for args_chunk in args_set:
+ res_args_chunk_l = []
+ res_args_chunk_h = []
+ for arg_chunk in args_chunk:
+ element, lanes = arg_chunk.dtype.split("x")
+ res_arg_chunk_dtype = f"{element}x{int(lanes) // 2}"
+
+ res_args_chunk_l.append(tvm.tir.op.vectorlow(res_arg_chunk_dtype, arg_chunk))
+ res_args_chunk_h.append(tvm.tir.op.vectorhigh(res_arg_chunk_dtype, arg_chunk))
+ res_args_set += [res_args_chunk_l, res_args_chunk_h]
+
+ return res_args_set
+
+ def concat_args(res_chunks):
+ merged_res_chunks = []
+ for i in range(0, len(res_chunks), 2):
+ arg_chunk_l = res_chunks[i]
+ arg_chunk_h = res_chunks[i + 1]
+ element, lanes = arg_chunk_l.dtype.split("x")
+ res_arg_chunk_dtype = f"{element}x{int(lanes) * 2}"
+
+ merged_res_chunks.append(
+ tvm.tir.op.vectorcombine(res_arg_chunk_dtype, arg_chunk_l, arg_chunk_h)
+ )
+
+ return merged_res_chunks
+
+ num_chunks = None
+ for arg in args:
+ _, lanes = arg.dtype.split("x")
+ lanes = int(lanes)
+ assert lanes % intrinsic_lanes == 0
+ if num_chunks is None:
+ assert is_power_of_2(lanes // intrinsic_lanes)
+ num_chunks = lanes // intrinsic_lanes
+
+ assert num_chunks == lanes // intrinsic_lanes
+
+ # Split arguments
+ lowered_args = [args]
+ while len(lowered_args) != num_chunks:
+ lowered_args = split_args(lowered_args)
+
+ # Intrinsic application
+ lowered_res = []
+ for l_arg in lowered_args:
+ res = intrinsic(*l_arg)
+ lowered_res.append(res)
+
+ # Result concatenation
+ while len(lowered_res) != 1:
+ lowered_res = concat_args(lowered_res)
+
+ return lowered_res[0]
+
+
def _q_multiply_shift_hexagon(op):
"""
Implementation of q_multiply_shift through hexagon intrinsics vmpyewuh and vmpyowh when q == 31.
"""
- x = op.args[0]
- y = op.args[1]
- fractional_bits = op.args[2]
- shift = op.args[3]
-
- # Don't use this intrinsic if we don't have a int32x32 vector
- # or if we are not multiplying q31 numbers
- if x.dtype != "int32x32" or fractional_bits.value != 31:
- return op
+ arg_x = op.args[0]
+ arg_fractional_bits = op.args[2]
- # Case 1, shift is negative
- mul_e_1 = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
- )
- mul_o_1 = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_1, x, y
- )
- fixup = 1 << (-shift - 1)
- round_mul = mul_o_1 + fixup
- out_negative_shift = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vaslwv.128B", tvm.tir.const(2, "uint32"), round_mul, shift
- )
+ # Don't use this intrinsic if we are not multiplying q31 numbers
+ if arg_fractional_bits.value != 31:
+ return op
- # Case 2, shift is positive
- x = x * (1 << (shift))
- mul_e_2 = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
- )
- mul_o_2 = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_2, x, y
- )
+ x_lanes = get_lanes(arg_x.dtype)
+ if x_lanes % 32 != 0 or not is_power_of_2(x_lanes // 32):
+ return op
- # Select depending on the shift
- return tvm.tir.Select(shift < 0, out_negative_shift, mul_o_2)
+ # pylint: disable=unused-argument
+ def intrinsic_lowering_32(x, y, fractional_bits, shift):
+ lowered_dtype = "int32x32"
+
+ # Case 1, shift is negative
+ mul_e_1 = tvm.tir.call_llvm_intrin(
+ lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
+ )
+ mul_o_1 = tvm.tir.call_llvm_intrin(
+ lowered_dtype,
+ "llvm.hexagon.V6.vmpyowh.sacc.128B",
+ tvm.tir.const(3, "uint32"),
+ mul_e_1,
+ x,
+ y,
+ )
+ fixup = 1 << (-shift - 1)
+ round_mul = mul_o_1 + fixup
+ out_negative_shift = tvm.tir.call_llvm_intrin(
+ lowered_dtype,
+ "llvm.hexagon.V6.vaslwv.128B",
+ tvm.tir.const(2, "uint32"),
+ round_mul,
+ shift,
+ )
+
+ # Case 2, shift is positive
+ x = x * (1 << (shift))
+ mul_e_2 = tvm.tir.call_llvm_intrin(
+ lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
+ )
+ mul_o_2 = tvm.tir.call_llvm_intrin(
+ lowered_dtype,
+ "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B",
+ tvm.tir.const(3, "uint32"),
+ mul_e_2,
+ x,
+ y,
+ )
+
+ # Select depending on the shift
+ return tvm.tir.Select(shift < 0, out_negative_shift, mul_o_2)
+
+ return _adapt_to_highest_lanes(*op.args, intrinsic=intrinsic_lowering_32, intrinsic_lanes=32)
register_intrin_lowering(
@@ -72,65 +193,87 @@ def _q_multiply_shift_per_axis_hexagon(op):
Implementation of q_multiply_shift_per_axis through hexagon intrinsics vmpyewuh and vmpyowh when
q == 31.
"""
- x = op.args[0]
- y = op.args[1]
- left_shift = op.args[2]
- right_shift = op.args[3]
- fractional_bits = op.args[4]
- is_lshift_required = op.args[5]
- is_rshift_required = op.args[6]
-
- # Don't use this intrinsic if we don't have a int32x32 vector
- # or if we are not multiplying q31 numbers
- if x.dtype != "int32x32" or fractional_bits.value != 31:
+ arg_x = op.args[0]
+ arg_fractional_bits = op.args[4]
+ arg_is_lshift_required = op.args[5]
+ arg_is_rshift_required = op.args[6]
+
+ # Don't use this intrinsic if we are not multiplying q31 numbers
+ if arg_fractional_bits.value != 31:
+ return op
+
+ x_lanes = get_lanes(arg_x.dtype)
+ if x_lanes % 32 != 0 or not is_power_of_2(x_lanes // 32):
return op
# Don't use this intrinsic when we need do both: left and right shifts.
# For now it is not clear how to implement this case through vector HVX instructions without
# accuracy drop.
- if is_rshift_required.value and is_lshift_required.value:
+ if arg_is_rshift_required.value and arg_is_lshift_required.value:
return op
- # Case 1: do the left shift
- shifted_x = x << left_shift
- mul_e_1 = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), shifted_x, y
- )
- left_shift_out = tvm.tir.call_llvm_intrin(
- op.dtype,
- "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B",
- tvm.tir.const(3, "uint32"),
- mul_e_1,
- shifted_x,
- y,
- )
-
- # Case 2: do the right shift
- mul_e_2 = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
- )
- mul_o_2 = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vmpyowh.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_2, x, y
- )
- fixup = 1 << (right_shift - 1)
- round_mul = mul_o_2 + fixup
- right_shift_out = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vasrwv.128B", tvm.tir.const(2, "uint32"), round_mul, right_shift
- )
-
- # Case 3: do neither right nor left shift
- mul_e_3 = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
- )
- no_shift_out = tvm.tir.call_llvm_intrin(
- op.dtype, "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B", tvm.tir.const(3, "uint32"), mul_e_3, x, y
- )
-
- return tvm.tir.Select(
- tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required)),
- no_shift_out,
- tvm.tir.Select(is_lshift_required, left_shift_out, right_shift_out),
- )
+ # pylint: disable=unused-argument
+ def intrinsic_impl_32(
+ x, y, left_shift, right_shift, fractional_bits, is_lshift_required, is_rshift_required
+ ):
+ lowered_dtype = "int32x32"
+
+ # Case 1: do the left shift
+ shifted_x = x << left_shift
+ mul_e_1 = tvm.tir.call_llvm_intrin(
+ lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), shifted_x, y
+ )
+ left_shift_out = tvm.tir.call_llvm_intrin(
+ lowered_dtype,
+ "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B",
+ tvm.tir.const(3, "uint32"),
+ mul_e_1,
+ shifted_x,
+ y,
+ )
+
+ # Case 2: do the right shift
+ mul_e_2 = tvm.tir.call_llvm_intrin(
+ lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
+ )
+ mul_o_2 = tvm.tir.call_llvm_intrin(
+ lowered_dtype,
+ "llvm.hexagon.V6.vmpyowh.sacc.128B",
+ tvm.tir.const(3, "uint32"),
+ mul_e_2,
+ x,
+ y,
+ )
+ fixup = 1 << (right_shift - 1)
+ round_mul = mul_o_2 + fixup
+ right_shift_out = tvm.tir.call_llvm_intrin(
+ lowered_dtype,
+ "llvm.hexagon.V6.vasrwv.128B",
+ tvm.tir.const(2, "uint32"),
+ round_mul,
+ right_shift,
+ )
+
+ # Case 3: do neither right nor left shift
+ mul_e_3 = tvm.tir.call_llvm_intrin(
+ lowered_dtype, "llvm.hexagon.V6.vmpyewuh.128B", tvm.tir.const(2, "uint32"), x, y
+ )
+ no_shift_out = tvm.tir.call_llvm_intrin(
+ lowered_dtype,
+ "llvm.hexagon.V6.vmpyowh.rnd.sacc.128B",
+ tvm.tir.const(3, "uint32"),
+ mul_e_3,
+ x,
+ y,
+ )
+
+ return tvm.tir.Select(
+ tvm.tir.Not(tvm.tir.Or(is_lshift_required, is_rshift_required)),
+ no_shift_out,
+ tvm.tir.Select(is_lshift_required, left_shift_out, right_shift_out),
+ )
+
+ return _adapt_to_highest_lanes(*op.args, intrinsic=intrinsic_impl_32, intrinsic_lanes=32)
register_intrin_lowering(
diff --git a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py
index 5eac35f2d6..fdfe3ad2b7 100644
--- a/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py
+++ b/tests/python/contrib/test_hexagon/test_fixed_point_multiply.py
@@ -21,6 +21,7 @@ import numpy as np
import tvm.testing
from tvm import relay
+from tvm import te
from tvm.relay.backend import Executor
from tvm.contrib.hexagon.session import Session
from tvm.contrib.hexagon.pytest_plugin import HEXAGON_AOT_LLVM_TARGET
@@ -100,7 +101,7 @@ class TestFixedPointMultiply:
)
@tvm.testing.requires_hexagon
- def test_fixed_point_multiply(self, hexagon_session: Session, multiplier: int, shift: int):
+ def test_per_tensor(self, hexagon_session: Session, multiplier: int, shift: int):
"""Fixed point multiply test."""
ishape = (6, 32)
a = relay.var("a", relay.TensorType(ishape, "int32"))
@@ -169,6 +170,141 @@ class TestFixedPointMultiply:
tvm.testing.assert_allclose(hexagon_output, expected_output)
+ vector_size = tvm.testing.parameter(32, 64, 128, 256)
+
+ def test_per_tensor_with_lanes(self, hexagon_session: Session, vector_size):
+ """Test fixed point multiply with vectorization.
+ Vectorization size is more than hw vector length"""
+ ishape = [2, 256, 16]
+
+ def q_mul_shift(shape):
+ x = te.placeholder(shape, name="X", dtype="int32")
+ out = te.compute(
+ shape,
+ lambda i, j, k: tvm.tir.q_multiply_shift(
+ x[i, j, k],
+ tvm.tir.const(1395864320, "int32"),
+ tvm.tir.const(31, "int32"),
+ tvm.tir.const(1, "int32"),
+ ),
+ name="compute",
+ )
+ return te.create_prim_func([x, out])
+
+ mod = q_mul_shift(ishape)
+
+ # Schedule with vectorization
+ sch = tvm.tir.Schedule(mod)
+ b00 = sch.get_block(name="compute", func_name="main")
+ fused = sch.fuse(*sch.get_loops(block=b00))
+ _, v = sch.split(loop=fused, factors=[None, vector_size])
+ sch.vectorize(v)
+
+ with tvm.transform.PassContext(opt_level=3):
+ hex_lib = tvm.build(sch.mod["main"], target=get_hexagon_target("v68"))
+ host_lib = tvm.build(mod, target=tvm.target.Target("llvm"))
+
+ asm = hex_lib.get_source("asm")
+
+ # Check that 'vmpye' instruction was generated in asm file.
+ vmpye_regex = re.compile(r"v\d{1,2}.w = vmpye\(v\d{1,2}.w,v\d{1,2}.uh\)")
+ assert vmpye_regex.search(asm) is not None
+
+ # Check that 'vmpyo' instruction was generated in asm file.
+ vmpyo_regex = re.compile(r"v\d{1,2}.w \+= vmpyo\(v\d{1,2}.w,v\d{1,2}.h\):<<1:rnd:sat:shift")
+ assert vmpyo_regex.search(asm) is not None
+
+ # Verify accuracy
+ a_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32")
+ b_np = np.random.randint(-1000, 1000, size=np.prod(ishape)).reshape(ishape).astype("int32")
+ hex_args = [
+ tvm.runtime.ndarray.array(arg, device=hexagon_session.device, mem_scope="global")
+ for arg in [a_np, b_np]
+ ]
+ host_args = [tvm.runtime.ndarray.array(arg) for arg in [a_np, b_np]]
+
+ hex_rt = hexagon_session.load_module(hex_lib)
+ hex_rt(*hex_args)
+ host_lib(*host_args)
+
+ assert np.allclose(hex_args[1].numpy(), host_args[1].numpy())
+
+ def test_per_channel_with_lanes(self, hexagon_session: Session, vector_size):
+ """Test fixed point multiply with vectorization.
+ Vectorization size is more than hw vector length"""
+ a_shape = [2, 256, 16]
+ b_shape = [256]
+
+ def q_mul_shift(shape):
+ shift_shape = [shape[1]]
+ x = te.placeholder(shape, name="X", dtype="int32")
+ y = te.placeholder(shift_shape, name="X", dtype="int32")
+ l_shift = te.placeholder(shift_shape, name="X", dtype="int32")
+ r_shift = te.placeholder(shift_shape, name="X", dtype="int32")
+
+ out = te.compute(
+ shape,
+ lambda i, j, k: tvm.tir.q_multiply_shift_per_axis(
+ x[i, j, k],
+ y[j],
+ l_shift[j],
+ r_shift[j],
+ tvm.tir.const(31, "int32"),
+ tvm.tir.const(1, "bool"),
+ tvm.tir.const(0, "bool"),
+ ),
+ name="compute",
+ )
+ return te.create_prim_func([x, y, l_shift, r_shift, out])
+
+ mod = q_mul_shift(a_shape)
+
+ # Schedule with vectorization
+ sch = tvm.tir.Schedule(mod)
+ b00 = sch.get_block(name="compute", func_name="main")
+ fused = sch.fuse(*sch.get_loops(block=b00))
+ _, v = sch.split(loop=fused, factors=[None, vector_size])
+ sch.vectorize(v)
+
+ with tvm.transform.PassContext(opt_level=3):
+ hex_lib = tvm.build(sch.mod["main"], target=get_hexagon_target("v68"))
+ host_lib = tvm.build(mod, target=tvm.target.Target("llvm"))
+
+ asm = hex_lib.get_source("asm")
+
+ # Check that 'vmpye' instruction was generated in asm file.
+ vmpye_regex = re.compile(r"v\d{1,2}.w = vmpye\(v\d{1,2}.w,v\d{1,2}.uh\)")
+ assert vmpye_regex.search(asm) is not None
+
+ # Check that 'vmpyo' instruction was generated in asm file.
+ vmpyo_regex = re.compile(r"v\d{1,2}.w \+= vmpyo\(v\d{1,2}.w,v\d{1,2}.h\):<<1:rnd:sat:shift")
+ assert vmpyo_regex.search(asm) is not None
+
+ # Verify accuracy
+ x_np = (
+ np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32")
+ )
+ y_np = (
+ np.random.randint(-1000, 1000, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
+ )
+ lsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
+ rsh_np = np.random.randint(0, 10, size=np.prod(b_shape)).reshape(b_shape).astype("int32")
+ b_np = (
+ np.random.randint(-1000, 1000, size=np.prod(a_shape)).reshape(a_shape).astype("int32")
+ )
+ np_args = [x_np, y_np, lsh_np, rsh_np, b_np]
+ hex_args = [
+ tvm.runtime.ndarray.array(arg, device=hexagon_session.device, mem_scope="global")
+ for arg in np_args
+ ]
+ host_args = [tvm.runtime.ndarray.array(arg) for arg in np_args]
+
+ hex_rt = hexagon_session.load_module(hex_lib)
+ hex_rt(*hex_args)
+ host_lib(*host_args)
+
+ assert np.allclose(hex_args[4].numpy(), host_args[4].numpy())
+
if __name__ == "__main__":
tvm.testing.main()