cbalint13 commented on code in PR #18182: URL: https://github.com/apache/tvm/pull/18182#discussion_r2289517280
########## python/tvm/meta_schedule/tune_context.py: ########## @@ -28,6 +28,7 @@ from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule +from tvm.target.codegen import target_has_features Review Comment: * Not needed here (see below) ########## python/tvm/tir/tensor_intrin/riscv_cpu.py: ########## @@ -0,0 +1,691 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring,unused-import +"""Intrinsics for RVV tensorization, both for C and LLVM targets. +===================== +**Author**: `Federico Peccia <https://fPecc.github.io/>`_ +""" +import re +import logging +from tvm.script import tir as T +from tvm.target.codegen import llvm_get_vector_width +from .. import TensorIntrin + +logger = logging.getLogger(__name__) + +##################################################### +# LLVM RISC-V Intrinsic usage: +# https://llvm.org/docs//RISCV/RISCVVectorExtension.html +# +# Vector types are represented using scalable vector +# types, of the form <vscale x n x ty>. n and ty +# control LMUL and SEW respectively (see table in docs). +# TVM represents this with dtype = "tyxvscalexn". +# +# n is calculated as (64/SEW)*LMUL. 
+# VL is passed to each intrinsic. +# +# Some examples (see table in docs): +# int8 vector type with LMUL = 1 => int8xvscalex8 +# int16 vector type with LMUL = 4 => int16xvscalex16 +# int32 vector type with LMUL = 2 => int32xvscalex4 +# +##################################################### + +##################################################### +# Helper functions +##################################################### + +RISCV_MIN_VL = 4 + + +def get_vlmax(vlen: int, lmul: int, max_sew: int) -> int: + """Return VLMAX + + Args: + vlen (int): Actual VLEN + lmul (int): LMUL + max_sew (int): SEW + + Returns: + int: VLMAX + """ + return (lmul * vlen) // max_sew + + +def get_vlen_from_mattrs(mattrs: list) -> int: + """Extract VLEN from LLVM mattrs list + + Args: + mattrs (list): LLVM list of CPU mattrs + + Returns: + int: VLEN + """ + vlen_regex = r"zvl(\d+)b" + vlen = 0 + for mattr in mattrs: + match = re.search(vlen_regex, mattr) + + if match: + vlen = int(match.group(1)) + break + return vlen + + +def _dtype_to_bits(dtype: str) -> int: + """Get bits from data type + + Args: + dtype (str): Data type + + Returns: + int: bits + """ + bits_per_item = int( + re.match(r"((float)|(int)|(uint))(?P<width_bits>[0-9]+)", dtype).group("width_bits") + ) + assert bits_per_item is not None, f"don't know how to compute size of type {dtype}" + return bits_per_item + + +def _get_dtype_string(dtype: str) -> str: + """Get only type of data type, without bits + + Args: + dtype (str): Data type + + Returns: + str: only string type + """ + return str(re.match(r"[a-z]+", dtype).group(0)) + + +##################################################### +# Parameterized intrinsics +##################################################### + + +def rvv_vmacc(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: disable=unused-argument + input_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + 
+ output_dtype_prefix = output_str_type[0] + + input_lmul = lmul if output_dtype_prefix == "f" else lmul // 2 + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = "llvm.riscv.vle" + macc_llvm_intrinsic = "llvm.riscv.vmacc" if output_dtype_prefix != "f" else "llvm.riscv.vfmacc" + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + n_input_dtype = (64 // input_bits) * input_lmul + n_output_dtype = (64 // output_bits) * lmul + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_macc_dtype = f"{output_str_type}{output_bits}xvscalex{n_output_dtype}" + + broadcast_input = T.int16(0) if input_dtype == "int16" else T.float32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_vmacc_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((int(vlmax),), output_dtype, align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0 : int(vlmax)], A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0 : int(vlmax)]) + for j in range(0, int(vlmax)): + with T.block("update"): + vj = T.axis.remap("S", [j]) + C[vj] = C[vj] + T.cast(A[vj], output_dtype) * T.cast(B[vj], output_dtype) + + @T.prim_func + def rvv_vmacc_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((int(vlmax),), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0 : int(vlmax)]) + + vec_A = ( + T.call_llvm_intrin( + llvm_macc_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + 
T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_macc_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + init = T.call_llvm_intrin( + llvm_macc_dtype, + init_llvm_intrinsic, + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + C.access_ptr(access_mask=C.READ, ptr_type="handle"), + T.uint64(vlmax), + ) + + product = ( + T.call_llvm_intrin( + llvm_macc_dtype, + macc_llvm_intrinsic, + init, + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + T.uint64(3), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_macc_dtype, + macc_llvm_intrinsic, + init, + vec_A, + vec_B, + T.uint64(vlmax), + T.uint64(3), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + product, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(vlmax), + ) + + return rvv_vmacc_desc, rvv_vmacc_llvm_impl + + +def rvv_multivmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: disable=unused-argument + assert J > 1 + + input_bits = _dtype_to_bits(input_dtype) + kernel_bits = _dtype_to_bits(input_dtype) + output_bits = 
_dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + + output_dtype_prefix = ( + "i" if output_str_type == "int" else ("u" if output_str_type == "uint" else "f") + ) + + intermmediate_bits = output_bits if output_dtype_prefix == "f" else input_bits + kernel_bits + intermmediate_bits = input_bits + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = ( + "llvm.riscv.vmv.v.x" if output_dtype_prefix != "f" else "llvm.riscv.vfmv.v.f" + ) + mult_llvm_intrinsic = "llvm.riscv.vmul" if output_dtype_prefix != "f" else "llvm.riscv.vfmul" + redsum_llvm_intrinsic = ( + "llvm.riscv.vwredsum" if output_dtype_prefix != "f" else "llvm.riscv.vfredusum" + ) + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + # vscale = vlen // 64 + n_input_dtype = (64 // input_bits) * lmul + n_kernel_dtype = (64 // kernel_bits) * lmul + n_intermmediate_dtype = (64 // intermmediate_bits) * lmul + + n_redsum_dtype = (64 // output_bits) * 1 + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_kernel_dtype = f"{input_dtype}xvscalex{n_kernel_dtype}" + llvm_redsum_dtype = f"{output_dtype}xvscalex{n_redsum_dtype}" + llvm_mult_dtype = f"{output_str_type}{intermmediate_bits}xvscalex{n_intermmediate_dtype}" + + broadcast_input = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_kernel = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_intermmediate = T.int16(0) if intermmediate_bits == 16 else T.int32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_multivmul_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((J, int(vlmax)), input_dtype, align=4, offset_factor=1), + C: T.Buffer((J,), output_dtype, 
align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0:J], A[0 : int(vlmax)], B[0:J, 0 : int(vlmax)]) + T.writes(C[0:J]) + for j in range(0, J): + for k in range(0, int(vlmax)): + with T.block("update"): + vj, vk = T.axis.remap("SR", [j, k]) + C[vj] = C[vj] + T.cast(A[vk], output_dtype) * T.cast( + B[vj, vk], output_dtype + ) + + @T.prim_func + def rvv_multivmul_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer( + (J, int(vlmax)), input_dtype, align=4, offset_factor=1, strides=[T.int32(), T.int32()] + ), + C: T.Buffer((J,), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0:J, 0 : int(vlmax)]) + T.writes(C[0:J]) + + vec_A = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_kernel_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_kernel, n_kernel_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + 
) + + redsum = T.call_llvm_intrin( + llvm_redsum_dtype, + init_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + C[0], + T.uint64(1), + ) + + product = ( + T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(vlmax), + ) + ) + + redsum_result = ( + T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(vlmax), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + redsum_result, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(1), + ) + + return rvv_multivmul_desc, rvv_multivmul_llvm_impl + + +def rvv_vmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: disable=unused-argument + input_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + + output_dtype_prefix = ( + "i" if output_str_type == "int" else ("u" if output_str_type == "uint" else "f") + ) + + intermmediate_bits = output_bits if output_dtype_prefix == "f" else input_bits * 2 + intermmediate_bits = input_bits + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = ( + "llvm.riscv.vmv.v.x" if output_dtype_prefix != "f" else "llvm.riscv.vfmv.v.f" + ) + mult_llvm_intrinsic = "llvm.riscv.vmul" if output_dtype_prefix != "f" else 
"llvm.riscv.vfmul" + redsum_llvm_intrinsic = ( + "llvm.riscv.vwredsum" if output_dtype_prefix != "f" else "llvm.riscv.vfredusum" + ) + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + # vscale = vlen // 64 + n_input_dtype = (64 // input_bits) * lmul + n_kernel_dtype = (64 // input_bits) * lmul + n_intermmediate_dtype = (64 // intermmediate_bits) * lmul + + n_redsum_dtype = (64 // output_bits) * 1 + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_kernel_dtype = f"{input_dtype}xvscalex{n_kernel_dtype}" + llvm_redsum_dtype = f"{output_dtype}xvscalex{n_redsum_dtype}" + llvm_mult_dtype = f"{output_str_type}{intermmediate_bits}xvscalex{n_intermmediate_dtype}" + + broadcast_input = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_kernel = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_intermmediate = T.int16(0) if intermmediate_bits == 16 else T.int32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_vmul_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((1,), output_dtype, align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0], A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0]) + for k in range(0, int(vlmax)): + with T.block("update"): + vk = T.axis.remap("R", [k]) + C[0] = C[0] + T.cast(A[vk], output_dtype) * T.cast(B[vk], output_dtype) + + @T.prim_func + def rvv_vmul_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((1,), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0 : 
int(vlmax)]) + T.writes(C[0]) + + vec_A = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_kernel_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_kernel, n_kernel_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + redsum = T.call_llvm_intrin( + llvm_redsum_dtype, + init_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + C[0], + T.uint64(1), + ) + + product = ( + T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(vlmax), + ) + ) + + redsum_result = ( + T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, 
n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(vlmax), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + redsum_result, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(1), + ) + + return rvv_vmul_desc, rvv_vmul_llvm_impl + + +##################################################### +# Registering intrinsics +##################################################### + + +def register_intrinsic_combinations( + outer_loops, initial_vlmax, lmul, input_dtype, output_dtype, prefix, generator +): + for J in outer_loops: + current_vlmax = initial_vlmax + while current_vlmax >= RISCV_MIN_VL: + + name = f"{prefix}_{J}_{current_vlmax}_m{lmul}" + + desc, impl = generator(J, current_vlmax, input_dtype, output_dtype, lmul) + + logger.debug(f"Registering intrin {name}...") + + TensorIntrin.register(name, desc, impl, override=True) + + current_vlmax = current_vlmax // 2 + + +def register_riscv_tensor_intrinsics(target): + target_kind = target.kind.name + assert target_kind in ["llvm"] + + vlen = llvm_get_vector_width(target) + + for vmul_type, func, outer_loops in zip( + ["vmacc", "multivmul", "vmul"], + [rvv_vmacc, rvv_multivmul, rvv_vmul], + [[1], [get_vlmax(vlen, lmul=1, max_sew=32)], [1]], + ): + + for idtype, odtype in zip(["int16", "float16", "float32"], ["int32", "float16", "float32"]): + + if idtype == "float32" and vmul_type == "multivmul": + continue + + vlmax = get_vlmax(vlen, lmul=8, max_sew=32) + register_intrinsic_combinations( + outer_loops, vlmax, 8, idtype, odtype, f"rvv_{idtype}_{vmul_type}", func + ) + + print("Finished registering all intrinsics.") Review Comment: Let's use ```logger.debug``` instead of ```print``` ########## python/tvm/meta_schedule/tune_context.py: ########## @@ -117,6 +118,13 @@ def __init__( if target
is not None: if not isinstance(target, Target): target = Target(target) + if "riscv_cpu" in target.keys: + if target_has_features("v", target): + # Because the RVV intrinsics depend on the target, we register them here + # pylint: disable=import-outside-toplevel + from tvm.tir.tensor_intrin.riscv_cpu import register_riscv_tensor_intrinsics + + register_riscv_tensor_intrinsics(target) Review Comment: * Let's move this to the very bottom of ```tvm/tir/tensor_intrin/riscv_cpu.py``` ########## src/meta_schedule/space_generator/space_generator.cc: ########## @@ -117,6 +128,11 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) { default_sch_rules = ScheduleRule::DefaultX86("avx512"); default_postprocs = Postproc::DefaultCPUTensorization(); default_mutator_probs = Mutator::DefaultLLVM(); + } else if (kind == "rvv") { + int vlen = GetRISCVVLENFromLLVMTarget(context->target.value()); + default_sch_rules = ScheduleRule::DefaultRISCV(vlen); + default_postprocs = Postproc::DefaultRISCV(); + default_mutator_probs = Mutator::DefaultLLVM(); } else if (kind == "asimd") { Review Comment: Let's also use experimental mode here (explained earlier in ```tir/tensor_intrin/riscv_cpu.py```): ``` } else if (kind == "rvv") { if (context->target.value()->GetAttr<String>("model") == "rvv") { // experimental rvv tensorization int vlen = GetRISCVVLENFromLLVMTarget(context->target.value()); default_sch_rules = ScheduleRule::DefaultRISCV(vlen); default_postprocs = Postproc::DefaultRISCV(); } else { default_sch_rules = ScheduleRule::DefaultLLVM(); default_postprocs = Postproc::DefaultLLVM(); } default_mutator_probs = Mutator::DefaultLLVM(); } else if (kind == "asimd") { ``` ########## src/meta_schedule/schedule_rule/schedule_rule.cc: ########## @@ -304,6 +304,122 @@ Array<ScheduleRule> ScheduleRule::DefaultHexagon() { }; } +int GetVLMAX(int vlen, int lmul, int max_sew) { return (lmul * vlen) / max_sew; } + +Array<ScheduleRule> ScheduleRule::DefaultRISCV(int
vlen) { + Array<ScheduleRule> rules; + + rules.push_back(ScheduleRule::ApplyCustomRule()); + + rules.push_back(ScheduleRule::InlineConstantScalars()); + + rules.push_back(ScheduleRule::AutoInline( + /*into_producer=*/false, + /*into_consumer=*/true, + /*inline_const_tensor=*/true, + /*disallow_if_then_else=*/true, + /*require_injective=*/true, + /*require_ordered=*/true, + /*disallow_op=*/Array<String>{"tir.exp"})); + + rules.push_back(ScheduleRule::AddRFactor( + /*max_jobs_per_core=*/16, + /*max_innermost_factor=*/Integer(64))); + + int vlmax = 0; + int RISCV_MIN_VL = 4; + std::vector<std::string> vmul_types = {"multivmul", "vmul", "vmacc"}; + String intrin_name = ""; + int j = 1; + + for (const std::string& vmul_type : vmul_types) { + if (vmul_type == "multivmul") + j = GetVLMAX(vlen, 1, 32); + else + j = 1; + + // Registering for int16 + vlmax = GetVLMAX(vlen, 8, 32); + while (vlmax >= RISCV_MIN_VL) { + intrin_name = + "rvv_int16_" + vmul_type + "_" + std::to_string(j) + "_" + std::to_string(vlmax) + "_m8"; + rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin( + /*intrin_name=*/intrin_name, + /*structure=*/"SSRSRS", + /*tile_binds=*/std::nullopt, + /*max_innermost_factor=*/Integer(vlmax), + /*vector_load_lens=*/std::nullopt, + /*reuse_read=*/std::nullopt, + /*reuse_write=*/ + Map<String, ffi::Any>{{"req", String("may")}, + {"levels", Array<Integer>{1, 2}}, + {"scope", String("global")}})); + vlmax /= 2; + } + + // Registering for float16 + vlmax = GetVLMAX(vlen, 8, 16); Review Comment: I think it should be ```GetVLMAX(vlen, 8, 32)```, otherwise we get the error: ```ValueError: TensorIntrin 'rvv_float16_multivmul_8_128_m8' is not registered``` A ```max_sew``` of 16 is not consistent with what is declared in ```tir/tensor_intrin/riscv_cpu.py``` ########## python/tvm/tir/tensor_intrin/riscv_cpu.py: ########## @@ -0,0 +1,691 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring,unused-import +"""Intrinsics for RVV tensorization, both for C and LLVM targets. +===================== +**Author**: `Federico Peccia <https://fPecc.github.io/>`_ Review Comment: Along with the code author, you could also cite the arXiv paper, e.g.: https://github.com/apache/tvm/blob/main/python/tvm/topi/nn/winograd_util.py#L18-L23 https://github.com/apache/tvm/blob/main/python/tvm/topi/nn/winograd_util.py#L93-L98 If you also add the arXiv reference, this entry becomes ```s/Author/Code Author/``` ########## python/tvm/tir/tensor_intrin/riscv_cpu.py: ########## @@ -0,0 +1,691 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring,unused-import +"""Intrinsics for RVV tensorization, both for C and LLVM targets. +===================== +**Author**: `Federico Peccia <https://fPecc.github.io/>`_ +""" +import re +import logging +from tvm.script import tir as T +from tvm.target.codegen import llvm_get_vector_width +from .. import TensorIntrin + +logger = logging.getLogger(__name__) + +##################################################### +# LLVM RISC-V Intrinsic usage: +# https://llvm.org/docs//RISCV/RISCVVectorExtension.html +# +# Vector types are represented using scalable vector +# types, of the form <vscale x n x ty>. n and ty +# control LMUL and SEW respectively (see table in docs). +# TVM represents this with dtype = "tyxvscalexn". +# +# n is calculated as (64/SEW)*LMUL. +# VL is passed to each intrinsic. 
+# +# Some examples (see table in docs): +# int8 vector type with LMUL = 1 => int8xvscalex8 +# int16 vector type with LMUL = 4 => int16xvscalex16 +# int32 vector type with LMUL = 2 => int32xvscalex4 +# +##################################################### + +##################################################### +# Helper functions +##################################################### + +RISCV_MIN_VL = 4 + + +def get_vlmax(vlen: int, lmul: int, max_sew: int) -> int: + """Return VLMAX + + Args: + vlen (int): Actual VLEN + lmul (int): LMUL + max_sew (int): SEW + + Returns: + int: VLMAX + """ + return (lmul * vlen) // max_sew + + +def get_vlen_from_mattrs(mattrs: list) -> int: + """Extract VLEN from LLVM mattrs list + + Args: + mattrs (list): LLVM list of CPU mattrs + + Returns: + int: VLEN + """ + vlen_regex = r"zvl(\d+)b" + vlen = 0 + for mattr in mattrs: + match = re.search(vlen_regex, mattr) + + if match: + vlen = int(match.group(1)) + break + return vlen + + +def _dtype_to_bits(dtype: str) -> int: + """Get bits from data type + + Args: + dtype (str): Data type + + Returns: + int: bits + """ + bits_per_item = int( + re.match(r"((float)|(int)|(uint))(?P<width_bits>[0-9]+)", dtype).group("width_bits") + ) + assert bits_per_item is not None, f"don't know how to compute size of type {dtype}" + return bits_per_item + + +def _get_dtype_string(dtype: str) -> str: + """Get only type of data type, without bits + + Args: + dtype (str): Data type + + Returns: + str: only string type + """ + return str(re.match(r"[a-z]+", dtype).group(0)) + + +##################################################### +# Parameterized intrinsics +##################################################### + + +def rvv_vmacc(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: disable=unused-argument + input_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + + output_dtype_prefix = 
output_str_type[0] + + input_lmul = lmul if output_dtype_prefix == "f" else lmul // 2 + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = "llvm.riscv.vle" + macc_llvm_intrinsic = "llvm.riscv.vmacc" if output_dtype_prefix != "f" else "llvm.riscv.vfmacc" + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + n_input_dtype = (64 // input_bits) * input_lmul + n_output_dtype = (64 // output_bits) * lmul + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_macc_dtype = f"{output_str_type}{output_bits}xvscalex{n_output_dtype}" + + broadcast_input = T.int16(0) if input_dtype == "int16" else T.float32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_vmacc_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((int(vlmax),), output_dtype, align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0 : int(vlmax)], A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0 : int(vlmax)]) + for j in range(0, int(vlmax)): + with T.block("update"): + vj = T.axis.remap("S", [j]) + C[vj] = C[vj] + T.cast(A[vj], output_dtype) * T.cast(B[vj], output_dtype) + + @T.prim_func + def rvv_vmacc_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((int(vlmax),), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0 : int(vlmax)]) + + vec_A = ( + T.call_llvm_intrin( + llvm_macc_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * 
T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_macc_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + init = T.call_llvm_intrin( + llvm_macc_dtype, + init_llvm_intrinsic, + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + C.access_ptr(access_mask=C.READ, ptr_type="handle"), + T.uint64(vlmax), + ) + + product = ( + T.call_llvm_intrin( + llvm_macc_dtype, + macc_llvm_intrinsic, + init, + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + T.uint64(3), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_macc_dtype, + macc_llvm_intrinsic, + init, + vec_A, + vec_B, + T.uint64(vlmax), + T.uint64(3), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + product, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(vlmax), + ) + + return rvv_vmacc_desc, rvv_vmacc_llvm_impl + + +def rvv_multivmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: disable=unused-argument + assert J > 1 + + input_bits = _dtype_to_bits(input_dtype) + kernel_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = 
_get_dtype_string(output_dtype) + + output_dtype_prefix = ( + "i" if output_str_type == "int" else ("u" if output_str_type == "uint" else "f") + ) + + intermmediate_bits = output_bits if output_dtype_prefix == "f" else input_bits + kernel_bits + intermmediate_bits = input_bits + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = ( + "llvm.riscv.vmv.v.x" if output_dtype_prefix != "f" else "llvm.riscv.vfmv.v.f" + ) + mult_llvm_intrinsic = "llvm.riscv.vmul" if output_dtype_prefix != "f" else "llvm.riscv.vfmul" + redsum_llvm_intrinsic = ( + "llvm.riscv.vwredsum" if output_dtype_prefix != "f" else "llvm.riscv.vfredusum" + ) + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + # vscale = vlen // 64 + n_input_dtype = (64 // input_bits) * lmul + n_kernel_dtype = (64 // kernel_bits) * lmul + n_intermmediate_dtype = (64 // intermmediate_bits) * lmul + + n_redsum_dtype = (64 // output_bits) * 1 + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_kernel_dtype = f"{input_dtype}xvscalex{n_kernel_dtype}" + llvm_redsum_dtype = f"{output_dtype}xvscalex{n_redsum_dtype}" + llvm_mult_dtype = f"{output_str_type}{intermmediate_bits}xvscalex{n_intermmediate_dtype}" + + broadcast_input = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_kernel = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_intermmediate = T.int16(0) if intermmediate_bits == 16 else T.int32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_multivmul_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((J, int(vlmax)), input_dtype, align=4, offset_factor=1), + C: T.Buffer((J,), output_dtype, align=4, offset_factor=1), + ) -> None: + with 
T.block("root"): + T.reads(C[0:J], A[0 : int(vlmax)], B[0:J, 0 : int(vlmax)]) + T.writes(C[0:J]) + for j in range(0, J): + for k in range(0, int(vlmax)): + with T.block("update"): + vj, vk = T.axis.remap("SR", [j, k]) + C[vj] = C[vj] + T.cast(A[vk], output_dtype) * T.cast( + B[vj, vk], output_dtype + ) + + @T.prim_func + def rvv_multivmul_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer( + (J, int(vlmax)), input_dtype, align=4, offset_factor=1, strides=[T.int32(), T.int32()] + ), + C: T.Buffer((J,), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0:J, 0 : int(vlmax)]) + T.writes(C[0:J]) + + vec_A = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_kernel_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_kernel, n_kernel_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + redsum = T.call_llvm_intrin( + 
llvm_redsum_dtype, + init_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + C[0], + T.uint64(1), + ) + + product = ( + T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(vlmax), + ) + ) + + redsum_result = ( + T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(vlmax), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + redsum_result, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(1), + ) + + return rvv_multivmul_desc, rvv_multivmul_llvm_impl + + +def rvv_vmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: disable=unused-argument + input_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + + output_dtype_prefix = ( + "i" if output_str_type == "int" else ("u" if output_str_type == "uint" else "f") + ) + + intermmediate_bits = output_bits if output_dtype_prefix == "f" else input_bits * 2 + intermmediate_bits = input_bits + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = ( + "llvm.riscv.vmv.v.x" if output_dtype_prefix != "f" else "llvm.riscv.vfmv.v.f" + ) + mult_llvm_intrinsic = "llvm.riscv.vmul" if output_dtype_prefix != "f" else "llvm.riscv.vfmul" + redsum_llvm_intrinsic = 
( + "llvm.riscv.vwredsum" if output_dtype_prefix != "f" else "llvm.riscv.vfredusum" + ) + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + # vscale = vlen // 64 + n_input_dtype = (64 // input_bits) * lmul + n_kernel_dtype = (64 // input_bits) * lmul + n_intermmediate_dtype = (64 // intermmediate_bits) * lmul + + n_redsum_dtype = (64 // output_bits) * 1 + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_kernel_dtype = f"{input_dtype}xvscalex{n_kernel_dtype}" + llvm_redsum_dtype = f"{output_dtype}xvscalex{n_redsum_dtype}" + llvm_mult_dtype = f"{output_str_type}{intermmediate_bits}xvscalex{n_intermmediate_dtype}" + + broadcast_input = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_kernel = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_intermmediate = T.int16(0) if intermmediate_bits == 16 else T.int32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_vmul_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((1,), output_dtype, align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0], A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0]) + for k in range(0, int(vlmax)): + with T.block("update"): + vk = T.axis.remap("R", [k]) + C[0] = C[0] + T.cast(A[vk], output_dtype) * T.cast(B[vk], output_dtype) + + @T.prim_func + def rvv_vmul_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((1,), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0]) + + vec_A = ( + 
T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_kernel_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_kernel, n_kernel_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + redsum = T.call_llvm_intrin( + llvm_redsum_dtype, + init_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + C[0], + T.uint64(1), + ) + + product = ( + T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(vlmax), + ) + ) + + redsum_result = ( + T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + 
T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(vlmax), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + redsum_result, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(1), + ) + + return rvv_vmul_desc, rvv_vmul_llvm_impl + + +##################################################### +# Registering intrinsics +##################################################### + + +def register_intrinsic_combinations( + outer_loops, initial_vlmax, lmul, input_dtype, output_dtype, prefix, generator +): + for J in outer_loops: + current_vlmax = initial_vlmax + while current_vlmax >= RISCV_MIN_VL: + + name = f"{prefix}_{J}_{current_vlmax}_m{lmul}" + + desc, impl = generator(J, current_vlmax, input_dtype, output_dtype, lmul) + + logger.debug(f"Registering intrin {name}...") + + TensorIntrin.register(name, desc, impl, override=True) + + current_vlmax = current_vlmax // 2 + + +def register_riscv_tensor_intrinsics(target): + target_kind = target.kind.name + assert target_kind in ["llvm"] + + vlen = llvm_get_vector_width(target) + + for vmul_type, func, outer_loops in zip( + ["vmacc", "multivmul", "vmul"], + [rvv_vmacc, rvv_multivmul, rvv_vmul], + [[1], [get_vlmax(vlen, lmul=1, max_sew=32)], [1]], + ): + + for idtype, odtype in zip(["int16", "float16", "float32"], ["int32", "float16", "float32"]): + + if idtype == "float32" and vmul_type == "multivmul": + continue + + vlmax = get_vlmax(vlen, lmul=8, max_sew=32) + register_intrinsic_combinations( + outer_loops, vlmax, 8, idtype, odtype, f"rvv_{idtype}_{vmul_type}", func + ) + + print("Finished registering all intrinsics.") Review Comment: Let's move the instantiation right here, away from ```tvm/meta_schedule/tune_context.py```: ``` from tvm.target import Target {...} target = Target.current() if "riscv_cpu" in target.keys and
"rvv" in target.model and target_has_features("v", target): register_riscv_tensor_intrinsics(target) ``` The instantiation is now in line with the x86 options and the other ones: https://github.com/apache/tvm/blob/main/python/tvm/tir/tensor_intrin/x86.py#L101-L111 Let's keep this as experimental for a while, since it will land as a new feature, so the user must provide an extra ```llvm {...} -device=riscv_cpu -model=rvv``` argument to their TVM target. ########## python/tvm/tir/tensor_intrin/riscv_cpu.py: ########## @@ -0,0 +1,691 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,missing-function-docstring,unused-import +"""Intrinsics for RVV tensorization, both for C and LLVM targets. +===================== +**Author**: `Federico Peccia <https://fPecc.github.io/>`_ +""" +import re +import logging +from tvm.script import tir as T +from tvm.target.codegen import llvm_get_vector_width +from .. import TensorIntrin + +logger = logging.getLogger(__name__) + +##################################################### +# LLVM RISC-V Intrinsic usage: +# https://llvm.org/docs//RISCV/RISCVVectorExtension.html +# +# Vector types are represented using scalable vector +# types, of the form <vscale x n x ty>. 
n and ty +# control LMUL and SEW respectively (see table in docs). +# TVM represents this with dtype = "tyxvscalexn". +# +# n is calculated as (64/SEW)*LMUL. +# VL is passed to each intrinsic. +# +# Some examples (see table in docs): +# int8 vector type with LMUL = 1 => int8xvscalex8 +# int16 vector type with LMUL = 4 => int16xvscalex16 +# int32 vector type with LMUL = 2 => int32xvscalex4 +# +##################################################### + +##################################################### +# Helper functions +##################################################### + +RISCV_MIN_VL = 4 + + +def get_vlmax(vlen: int, lmul: int, max_sew: int) -> int: + """Return VLMAX + + Args: + vlen (int): Actual VLEN + lmul (int): LMUL + max_sew (int): SEW + + Returns: + int: VLMAX + """ + return (lmul * vlen) // max_sew + + +def get_vlen_from_mattrs(mattrs: list) -> int: + """Extract VLEN from LLVM mattrs list + + Args: + mattrs (list): LLVM list of CPU mattrs + + Returns: + int: VLEN + """ + vlen_regex = r"zvl(\d+)b" + vlen = 0 + for mattr in mattrs: + match = re.search(vlen_regex, mattr) + + if match: + vlen = int(match.group(1)) + break + return vlen + + +def _dtype_to_bits(dtype: str) -> int: + """Get bits from data type + + Args: + dtype (str): Data type + + Returns: + int: bits + """ + bits_per_item = int( + re.match(r"((float)|(int)|(uint))(?P<width_bits>[0-9]+)", dtype).group("width_bits") + ) + assert bits_per_item is not None, f"don't know how to compute size of type {dtype}" + return bits_per_item + + +def _get_dtype_string(dtype: str) -> str: + """Get only type of data type, without bits + + Args: + dtype (str): Data type + + Returns: + str: only string type + """ + return str(re.match(r"[a-z]+", dtype).group(0)) + + +##################################################### +# Parameterized intrinsics +##################################################### + + +def rvv_vmacc(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: 
disable=unused-argument + input_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + + output_dtype_prefix = output_str_type[0] + + input_lmul = lmul if output_dtype_prefix == "f" else lmul // 2 + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = "llvm.riscv.vle" + macc_llvm_intrinsic = "llvm.riscv.vmacc" if output_dtype_prefix != "f" else "llvm.riscv.vfmacc" + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + n_input_dtype = (64 // input_bits) * input_lmul + n_output_dtype = (64 // output_bits) * lmul + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_macc_dtype = f"{output_str_type}{output_bits}xvscalex{n_output_dtype}" + + broadcast_input = T.int16(0) if input_dtype == "int16" else T.float32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_vmacc_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((int(vlmax),), output_dtype, align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0 : int(vlmax)], A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0 : int(vlmax)]) + for j in range(0, int(vlmax)): + with T.block("update"): + vj = T.axis.remap("S", [j]) + C[vj] = C[vj] + T.cast(A[vj], output_dtype) * T.cast(B[vj], output_dtype) + + @T.prim_func + def rvv_vmacc_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((int(vlmax),), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0 : int(vlmax)]) + + vec_A = ( + T.call_llvm_intrin( + llvm_macc_dtype, + 
expand_llvm_intrinsic, + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_macc_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + init = T.call_llvm_intrin( + llvm_macc_dtype, + init_llvm_intrinsic, + T.broadcast(broadcast_output, n_output_dtype * T.vscale()), + C.access_ptr(access_mask=C.READ, ptr_type="handle"), + T.uint64(vlmax), + ) + + product = ( + T.call_llvm_intrin( + llvm_macc_dtype, + macc_llvm_intrinsic, + init, + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + T.uint64(3), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_macc_dtype, + macc_llvm_intrinsic, + init, + vec_A, + vec_B, + T.uint64(vlmax), + T.uint64(3), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + product, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(vlmax), + ) + + return rvv_vmacc_desc, rvv_vmacc_llvm_impl + + +def rvv_multivmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: 
disable=unused-argument + assert J > 1 + + input_bits = _dtype_to_bits(input_dtype) + kernel_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + + output_dtype_prefix = ( + "i" if output_str_type == "int" else ("u" if output_str_type == "uint" else "f") + ) + + intermmediate_bits = output_bits if output_dtype_prefix == "f" else input_bits + kernel_bits + intermmediate_bits = input_bits + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = ( + "llvm.riscv.vmv.v.x" if output_dtype_prefix != "f" else "llvm.riscv.vfmv.v.f" + ) + mult_llvm_intrinsic = "llvm.riscv.vmul" if output_dtype_prefix != "f" else "llvm.riscv.vfmul" + redsum_llvm_intrinsic = ( + "llvm.riscv.vwredsum" if output_dtype_prefix != "f" else "llvm.riscv.vfredusum" + ) + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + # vscale = vlen // 64 + n_input_dtype = (64 // input_bits) * lmul + n_kernel_dtype = (64 // kernel_bits) * lmul + n_intermmediate_dtype = (64 // intermmediate_bits) * lmul + + n_redsum_dtype = (64 // output_bits) * 1 + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_kernel_dtype = f"{input_dtype}xvscalex{n_kernel_dtype}" + llvm_redsum_dtype = f"{output_dtype}xvscalex{n_redsum_dtype}" + llvm_mult_dtype = f"{output_str_type}{intermmediate_bits}xvscalex{n_intermmediate_dtype}" + + broadcast_input = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_kernel = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_intermmediate = T.int16(0) if intermmediate_bits == 16 else T.int32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_multivmul_desc( + A: T.Buffer((int(vlmax),), 
input_dtype, align=4, offset_factor=1), + B: T.Buffer((J, int(vlmax)), input_dtype, align=4, offset_factor=1), + C: T.Buffer((J,), output_dtype, align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0:J], A[0 : int(vlmax)], B[0:J, 0 : int(vlmax)]) + T.writes(C[0:J]) + for j in range(0, J): + for k in range(0, int(vlmax)): + with T.block("update"): + vj, vk = T.axis.remap("SR", [j, k]) + C[vj] = C[vj] + T.cast(A[vk], output_dtype) * T.cast( + B[vj, vk], output_dtype + ) + + @T.prim_func + def rvv_multivmul_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer( + (J, int(vlmax)), input_dtype, align=4, offset_factor=1, strides=[T.int32(), T.int32()] + ), + C: T.Buffer((J,), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0:J, 0 : int(vlmax)]) + T.writes(C[0:J]) + + vec_A = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_kernel_dtype, + 
load_llvm_intrinsic, + T.broadcast(broadcast_kernel, n_kernel_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + redsum = T.call_llvm_intrin( + llvm_redsum_dtype, + init_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + C[0], + T.uint64(1), + ) + + product = ( + T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(vlmax), + ) + ) + + redsum_result = ( + T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(vlmax), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + redsum_result, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(1), + ) + + return rvv_multivmul_desc, rvv_multivmul_llvm_impl + + +def rvv_vmul(J: int, vlmax: int, input_dtype: str, output_dtype: str, lmul: int): + # pylint: disable=unused-argument + input_bits = _dtype_to_bits(input_dtype) + output_bits = _dtype_to_bits(output_dtype) + + output_str_type = _get_dtype_string(output_dtype) + + output_dtype_prefix = ( + "i" if output_str_type == "int" else ("u" if output_str_type == "uint" else "f") + ) + + intermmediate_bits = output_bits if output_dtype_prefix == "f" else input_bits * 2 + intermmediate_bits = input_bits + + load_llvm_intrinsic = "llvm.riscv.vle" + expand_llvm_intrinsic = "llvm.riscv.vsext" + init_llvm_intrinsic = ( + 
"llvm.riscv.vmv.v.x" if output_dtype_prefix != "f" else "llvm.riscv.vfmv.v.f" + ) + mult_llvm_intrinsic = "llvm.riscv.vmul" if output_dtype_prefix != "f" else "llvm.riscv.vfmul" + redsum_llvm_intrinsic = ( + "llvm.riscv.vwredsum" if output_dtype_prefix != "f" else "llvm.riscv.vfredusum" + ) + store_llvm_intrinsic = "llvm.riscv.vse" + + # Calculated from https://llvm.org/docs//RISCV/RISCVVectorExtension.html + # vscale = vlen // 64 + n_input_dtype = (64 // input_bits) * lmul + n_kernel_dtype = (64 // input_bits) * lmul + n_intermmediate_dtype = (64 // intermmediate_bits) * lmul + + n_redsum_dtype = (64 // output_bits) * 1 + + llvm_input_dtype = f"{input_dtype}xvscalex{n_input_dtype}" + llvm_kernel_dtype = f"{input_dtype}xvscalex{n_kernel_dtype}" + llvm_redsum_dtype = f"{output_dtype}xvscalex{n_redsum_dtype}" + llvm_mult_dtype = f"{output_str_type}{intermmediate_bits}xvscalex{n_intermmediate_dtype}" + + broadcast_input = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_kernel = ( + T.int8(0) + if input_dtype == "int8" + else (T.int16(0) if input_dtype == "int16" else T.float32(0)) + ) + broadcast_intermmediate = T.int16(0) if intermmediate_bits == 16 else T.int32(0) + broadcast_output = T.int32(0) if output_dtype == "int32" else T.float32(0) + + @T.prim_func + def rvv_vmul_desc( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + C: T.Buffer((1,), output_dtype, align=4, offset_factor=1), + ) -> None: + with T.block("root"): + T.reads(C[0], A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0]) + for k in range(0, int(vlmax)): + with T.block("update"): + vk = T.axis.remap("R", [k]) + C[0] = C[0] + T.cast(A[vk], output_dtype) * T.cast(B[vk], output_dtype) + + @T.prim_func + def rvv_vmul_llvm_impl( + A: T.Buffer((int(vlmax),), input_dtype, align=4, offset_factor=1), + B: T.Buffer((int(vlmax),), input_dtype, 
align=4, offset_factor=1), + C: T.Buffer((1,), output_dtype, align=4, offset_factor=1), + ) -> None: + + with T.block("root"): + + T.reads(A[0 : int(vlmax)], B[0 : int(vlmax)]) + T.writes(C[0]) + + vec_A = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + A.access_ptr(access_mask=A.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + vec_B = ( + T.call_llvm_intrin( + llvm_mult_dtype, + expand_llvm_intrinsic, + T.broadcast(broadcast_intermmediate, n_intermmediate_dtype * T.vscale()), + T.call_llvm_intrin( + llvm_input_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_input, n_input_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ), + T.int64(vlmax), + ) + if output_dtype_prefix != "f" + else T.call_llvm_intrin( + llvm_kernel_dtype, + load_llvm_intrinsic, + T.broadcast(broadcast_kernel, n_kernel_dtype * T.vscale()), + B.access_ptr(access_mask=B.READ, ptr_type="handle"), + T.int64(vlmax), + ) + ) + + redsum = T.call_llvm_intrin( + llvm_redsum_dtype, + init_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + C[0], + T.uint64(1), + ) + + product = ( + T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + vec_B, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_mult_dtype, + mult_llvm_intrinsic, + T.broadcast(broadcast_output, n_intermmediate_dtype * T.vscale()), + vec_A, + 
vec_B, + T.uint64(vlmax), + ) + ) + + redsum_result = ( + T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(7), + T.uint64(vlmax), + ) + if output_dtype_prefix == "f" + else T.call_llvm_intrin( + llvm_redsum_dtype, + redsum_llvm_intrinsic, + T.broadcast(broadcast_output, n_redsum_dtype * T.vscale()), + product, + redsum, + T.uint64(vlmax), + ) + ) + + T.call_llvm_intrin( + "", + store_llvm_intrinsic, + redsum_result, + C.access_ptr(access_mask=C.WRITE, ptr_type="handle"), + T.uint64(1), + ) + + return rvv_vmul_desc, rvv_vmul_llvm_impl + + +##################################################### +# Registering intrinsics +##################################################### + + +def register_intrinsic_combinations( + outer_loops, initial_vlmax, lmul, input_dtype, output_dtype, prefix, generator +): + for J in outer_loops: + current_vlmax = initial_vlmax + while current_vlmax >= RISCV_MIN_VL: + + name = f"{prefix}_{J}_{current_vlmax}_m{lmul}" + + desc, impl = generator(J, current_vlmax, input_dtype, output_dtype, lmul) + + logger.debug(f"Registering intrin {name}...") + + TensorIntrin.register(name, desc, impl, override=True) + + current_vlmax = current_vlmax // 2 + + +def register_riscv_tensor_intrinsics(target): + target_kind = target.kind.name + assert target_kind in ["llvm"] + + vlen = llvm_get_vector_width(target) + + for vmul_type, func, outer_loops in zip( + ["vmacc", "multivmul", "vmul"], + [rvv_vmacc, rvv_multivmul, rvv_vmul], + [[1], [get_vlmax(vlen, lmul=1, max_sew=32)], [1]], + ): + + for idtype, odtype in zip(["int16", "float16", "float32"], ["int32", "float16", "float32"]): + + if idtype == "float32" and vmul_type == "multivmul": + continue Review Comment: This skip gives the error: ```ValueError: TensorIntrin 'rvv_float32_multivmul_8_64_m8' is not registered``` I think it should be removed: the error then goes away and the registrations become uniform across all intrinsic types. 
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
