This is an automated email from the ASF dual-hosted git repository.
cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 06fb02e3fc [LLVM][METASCHEDULE] Add RISCV V-extension v1.0 kernels to metaschedule (#18243)
06fb02e3fc is described below
commit 06fb02e3fcab3b2c9e449bb5590bebabeaea0faa
Author: Balint Cristian <[email protected]>
AuthorDate: Mon Sep 8 01:41:17 2025 +0300
[LLVM][METASCHEDULE] Add RISCV V-extension v1.0 kernels to metaschedule (#18243)

- Enables high-performance kernels covering the majority of common ML input datatypes
- Compliant with the RVV specification v1.0 (does not work with the older v0.7.1)
- The TIR kernels implemented here use the recently added VLA extension support
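
A minimal usage sketch, assuming an LLVM build with RVV support:

    import tvm
    from tvm.tir.tensor_intrin.riscv_cpu import register_rvv_isa_intrinsics

    # "licheepi3a" maps to -mcpu=spacemit-x60 (rv64gcv), added in target.py below
    target = tvm.target.riscv_cpu(model="licheepi3a")
    # inventory_only=True returns {kernel_name: max n_elems} without registering
    inventory = register_rvv_isa_intrinsics(target, inventory_only=True)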
---
 include/tvm/meta_schedule/postproc.h             |   2 +
 include/tvm/meta_schedule/schedule_rule.h        |   2 +
 python/tvm/target/target.py                      |   8 +
 python/tvm/tir/tensor_intrin/__init__.py         |   2 +-
 python/tvm/tir/tensor_intrin/riscv_cpu.py        | 236 +++++++++++++++++++++
 src/meta_schedule/postproc/postproc.cc           |   8 +
 src/meta_schedule/schedule_rule/schedule_rule.cc |  57 +++++
 .../space_generator/space_generator.cc           |  11 +
 8 files changed, 325 insertions(+), 1 deletion(-)
diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index c511271d20..6ed7272fe9 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -166,6 +166,8 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Array<Postproc, void> DefaultLLVM();
/*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */
TVM_DLL static Array<Postproc, void> DefaultCPUTensorization();
+ /*! \brief Create default postprocessors for RISCV */
+ TVM_DLL static Array<Postproc, void> DefaultRISCV();
/*! \brief Create default postprocessors for CUDA */
TVM_DLL static Array<Postproc, void> DefaultCUDA();
/*! \brief Create default postprocessors for CUDA with TensorCore */
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 9011ebe0c1..407914e3d0 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -301,6 +301,8 @@ class ScheduleRule : public runtime::ObjectRef {
TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
/*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */
TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);
+ /*! \brief Create default schedule rules for RISCV CPU (RVV) */
+ TVM_DLL static Array<ScheduleRule, void> DefaultRISCV(int vlen);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef,
ScheduleRuleNode);
};
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index a9191df773..eb6e25f045 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -637,6 +637,14 @@ def riscv_cpu(model="sifive-u54", options=None):
"-mabi=lp64d",
# cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d
-mcpu=sifive-u74
],
+ "licheepi3a": [
+ "-num-cores=8",
+ "-mtriple=riscv64-unknown-linux-gnu",
+ "-mcpu=spacemit-x60",
+ "-mfloat-abi=hard",
+ "-mabi=lp64d",
+ # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gcv -mabi=lp64d
-mcpu=spacemit-x60
+ ],
}
pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py
index 5646554552..0a6cf5310c 100644
--- a/python/tvm/tir/tensor_intrin/__init__.py
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -20,4 +20,4 @@ from tvm.runtime import enabled
from . import cuda

if enabled("llvm"):
-    from . import arm_cpu, x86, rocm, hexagon
+    from . import arm_cpu, x86, rocm, hexagon, riscv_cpu
diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py
new file mode 100644
index 0000000000..febddc2bf3
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py
@@ -0,0 +1,236 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,line-too-long
+"""Intrinsics for RISCV tensorization"""
+
+import logging
+from tvm.ffi import register_func
+from tvm.runtime import DataType
+from tvm.script import tir as T
+from tvm.target.codegen import llvm_get_vector_width, target_has_features, Target
+from .. import TensorIntrin
+
+logger = logging.getLogger(__name__)
+
+
+def get_max_elems(vlen: int, lmul: int, sew: int) -> int:
+    """Returns the number of elements of a given element width (SEW)
+    that fit into a multiple (LMUL) of the vector register length (VLEN).
+
+    Args:
+        vlen (int): VLEN vector length in bits
+        lmul (int): LMUL vector length multiplier
+        sew (int): SEW standard (single) element width in bits
+
+    Returns:
+        int: Number of elements
+    """
+    return (vlen // sew) * lmul
+
+
+def rvv_vec_dot_product_kernels(
+    n_elems: int,
+    n_lanes: int,
+    data_dtype: str,
+    weight_dtype: str,
+    out_dtype: str,
+    lmul: int,
+):
+    """Dot product of a vector and matrix rows using RISC-V vector instructions.
+
+    These kernels take an array A[ELEMS] and a matrix B[LANES][ELEMS] and
+    compute the dot product of A[ELEMS] with each row of B, accumulating
+    the results into C[LANES].
+
+    The pseudo code is as follows:
+    .. code-block:: c
+
+        void vec_dot_prod(A[ELEMS], B[LANES][ELEMS], C[LANES]) {
+            for (j = 0; j < LANES; j++) {
+                for (k = 0; k < ELEMS; k++) {
+                    C[j] += A[k] * B[j][k]
+                }
+            }
+        }
+    """
+
+    @T.prim_func
+    def rvv_vec_dot_prod_desc(
+        A: T.Buffer((n_elems,), data_dtype, offset_factor=1),
+        B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1),
+        C: T.Buffer((n_lanes,), out_dtype, offset_factor=1),
+    ) -> None:
+        with T.block("root"):
+            T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems])
+            T.writes(C[0:n_lanes])
+            for j in T.serial(0, n_lanes):
+                for k in T.serial(0, n_elems):
+                    with T.block("update"):
+                        vj, vk = T.axis.remap("SR", [j, k])
+                        C[vj] = C[vj] + T.cast(A[vk], out_dtype) * T.cast(B[vj, vk], out_dtype)
+
+    # LLVM only supports ELEN=32 or ELEN=64
+    # https://llvm.org/docs/RISCV/RISCVVectorExtension.html
+    d_dtype_lanes = (64 // DataType(data_dtype).bits) * lmul
+    w_dtype_lanes = (64 // DataType(weight_dtype).bits) * lmul
+    # the reduction narrows the lane count
+    o_dtype_lanes = (64 // DataType(out_dtype).bits) * lmul // n_lanes
+    # keep a minimum of two lanes for the data type widening case
+    o_dtype_lanes = max(o_dtype_lanes, 2)
+
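+    # float intrinsics take an extra rounding-mode operand (7 = dynamic, use frm)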
+    mask_args = () if data_dtype[0] in ("i", "u") else (T.uint64(7),)
+
+    wide_dtype = out_dtype
+    if DataType(out_dtype).bits > DataType(data_dtype).bits:
+        wide_dtype = "".join(c for c in data_dtype if not c.isdigit())
+        wide_dtype += str(DataType(data_dtype).bits * 2)
+
+    # fmt: off
+    @T.prim_func
+    def rvv_vec_dot_prod_impl(
+        A: T.Buffer((n_elems,), data_dtype, offset_factor=1),
+        B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1),
+        C: T.Buffer((n_lanes,), out_dtype, offset_factor=1),
+    ) -> None:
+        with T.block("root"):
+            T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems])
+            T.writes(C[0:n_lanes])
+
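+            # unit-stride vector load (llvm.riscv.vle) of the whole A operand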
+            vec_A = T.call_llvm_intrin(
+                f"{data_dtype}xvscalex{d_dtype_lanes}",
+                "llvm.riscv.vle",
+                T.broadcast(T.Cast(data_dtype, 0), T.vscale() * d_dtype_lanes),
+                T.tvm_access_ptr(T.type_annotation(data_dtype), A.data, 0, n_elems, 1),
+                T.int64(n_elems))
+
+            for i in range(n_lanes):
+                with T.block("update"):
+                    T.reads(B[i, 0:n_elems])
+                    T.writes(C[i])
+
+                    vec_B_row = T.call_llvm_intrin(
+                        f"{weight_dtype}xvscalex{w_dtype_lanes}",
+                        "llvm.riscv.vle",
+                        T.broadcast(T.Cast(data_dtype, 0), T.vscale() * w_dtype_lanes),
+                        T.tvm_access_ptr(T.type_annotation(weight_dtype), B.data, i * n_elems, n_elems, 1),
+                        T.int64(n_elems))
+
+                    product = T.call_llvm_intrin(
+                        f"{wide_dtype}xvscalex{w_dtype_lanes}",
+                        "llvm.riscv.vfmul" if out_dtype[0] == "f" else \
+                        "llvm.riscv.vwmulsu" if (data_dtype[0] != weight_dtype[0]) else \
+                        "llvm.riscv.vwmul",
+                        T.broadcast(T.Cast(wide_dtype, 0), T.vscale() * w_dtype_lanes),
+                        vec_B_row,
+                        vec_A,
+                        *mask_args,
+                        T.uint64(n_elems))
+
+                    ini_acc = T.call_llvm_intrin(
+                        f"{out_dtype}xvscalex{o_dtype_lanes}",
+                        "llvm.riscv.vle",
+                        T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes),
+                        T.tvm_access_ptr(T.type_annotation(out_dtype), C.data, i, 1, 1),
+                        T.int64(1))
+
+                    red_sum = T.call_llvm_intrin(
+                        f"{out_dtype}xvscalex{o_dtype_lanes}",
+                        "llvm.riscv.vfredusum" if out_dtype[0] == "f" else \
+                        "llvm.riscv.vwredsum",
+                        T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes),
+                        product,
+                        ini_acc,
+                        *mask_args,
+                        T.uint64(n_elems))
+
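+                    # move element 0 of the reduced vector back into a scalar
+                    # register (vfmv.f.s for floats, vmv.x.s for integers)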
+                    C[i] = T.call_llvm_intrin(
+                        out_dtype,
+                        "llvm.riscv.vfmv.f.s" if out_dtype[0] == "f" else \
+                        "llvm.riscv.vmv.x.s",
+                        red_sum)
+    # fmt: on
+    return rvv_vec_dot_prod_desc, rvv_vec_dot_prod_impl
+
+
+@register_func("tir.tensor_intrin.register_rvv_isa_intrinsics")
+def register_rvv_isa_intrinsics(target: Target, inventory_only=False) -> dict:
+    """Register RISCV V (vector) intrinsics.
+
+    The implementation follows the version 1.0 vector specification:
+    https://github.com/riscvarchive/riscv-v-spec/releases/tag/v1.0
+
+    Args:
+        target (Target): TVM target
+        inventory_only (bool): If True, only build the inventory, do not register
+
+    Returns:
+        dict: A catalog with the registered kernel names and their properties
+    """
+    if not target_has_features("v", target):
+        raise RuntimeError("Current target does not support `v` extension.")
+
+    vlen = llvm_get_vector_width(target)
+    # get maximum reduction lanes (without grouping)
+    n_lanes = get_max_elems(vlen, lmul=1, sew=32)
+
+    kernels_inventory = {}
+
+    data_dtype = ["uint8", "int8", "float16", "float32"]
+    weight_dtype = ["int8", "int8", "float16", "float32"]
+    output_dtype = ["int32", "int32", "float16", "float32"]
+
+    for d_dtype, w_dtype, o_dtype in zip(data_dtype, weight_dtype, output_dtype):
+        # max elements to grouped registers
+        max_elems = get_max_elems(vlen, lmul=8, sew=DataType(d_dtype).bits)
+        # data widening halves the available vector registers
+        if DataType(o_dtype).bits > DataType(d_dtype).bits:
+            max_elems //= 2
+        # compute the optimal LMUL for a full load
+        lmul = max_elems // (vlen // DataType(d_dtype).bits)
+
+        n_elems = max_elems
+        while n_elems >= 4:
+
+            dt = DataType(d_dtype)
+            wt = DataType(w_dtype)
+            ot = DataType(o_dtype)
+            kernel_name = "rvv_dot"
+            kernel_name += f"_{n_elems}{dt[0]}{dt.bits}"
+            kernel_name += f"_{n_lanes}x{n_elems}{wt[0]}{wt.bits}"
+            kernel_name += f"_{n_lanes}{ot[0]}{ot.bits}"
+            kernels_inventory[kernel_name] = n_elems
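+            # e.g. with VLEN=256 the first int8 kernel is named
+            # "rvv_dot_128i8_8x128i8_8i32" (A: 128xi8, B: 8x128xi8, C: 8xi32)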
+
+            if not inventory_only:
+                logger.debug(f"Registering kernel {kernel_name}")
+                desc, impl = rvv_vec_dot_product_kernels(
+                    n_elems, n_lanes, d_dtype, w_dtype, o_dtype, lmul
+                )
+                TensorIntrin.register(kernel_name, desc, impl, override=True)
+
+            n_elems //= 2
+
+    return kernels_inventory
+
+
+def register_riscv_intrinsics(target: Target):
+    """Register RISCV intrinsics
+
+    Args:
+        target (Target): TVM target
+    """
+
+    # RISCV `v` 1.0 extension templates
+    _ = register_rvv_isa_intrinsics(target)
+    logger.debug("Finished registering riscv intrinsics.")
diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc
index ccf280860d..6d11929648 100644
--- a/src/meta_schedule/postproc/postproc.cc
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -69,6 +69,14 @@ Array<Postproc> Postproc::DefaultCPUTensorization() {
  };
}

+Array<Postproc> Postproc::DefaultRISCV() {
+  return Array<Postproc>{
+      Postproc::DisallowDynamicLoop(), Postproc::RewriteParallelVectorizeUnroll(),
+      Postproc::RewriteReductionBlock(), Postproc::RewriteTensorize(/*vectorize_init_loop=*/false),
+      Postproc::RewriteLayout(),
+  };
+}
+
Array<Postproc> Postproc::DefaultCUDA() {
  return Array<Postproc>{
      Postproc::DisallowDynamicLoop(),
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 9570c0d0f9..e23ca117c6 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -17,6 +17,7 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/data_type.h>
#include "../utils.h"
@@ -304,6 +305,62 @@ Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
  };
}

+Array<ScheduleRule> ScheduleRule::DefaultRISCV(const int vlen) {
+  Array<ScheduleRule> rules;
+  rules.push_back(ScheduleRule::ApplyCustomRule());
+  rules.push_back(ScheduleRule::InlineConstantScalars());
+  rules.push_back(ScheduleRule::AutoInline(
+      /*into_producer=*/false,
+      /*into_consumer=*/true,
+      /*inline_const_tensor=*/true,
+      /*disallow_if_then_else=*/true,
+      /*require_injective=*/true,
+      /*require_ordered=*/true,
+      /*disallow_op=*/Array<String>{"tir.exp"}));
+  rules.push_back(ScheduleRule::AddRFactor(
+      /*max_jobs_per_core=*/16,
+      /*max_innermost_factor=*/Integer(64)));
+  auto current_target = tvm::Target::Current();
+  const auto reg_rvv_intrinsics =
+      tvm::ffi::Function::GetGlobalRequired("tir.tensor_intrin.register_rvv_isa_intrinsics");
+  const auto rvv_kernels_inventory =
+      reg_rvv_intrinsics(current_target, /* inventory_only */ true).cast<Map<String, int>>();
+  for (const auto& intrin : rvv_kernels_inventory) {
+    if (!tir::TensorIntrin::Get(intrin.first, /*allow_missing*/ true)) {
+      // register the intrinsics on demand
+      reg_rvv_intrinsics(current_target, /* inventory_only */ false);
+    }
+    rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin(
+        /*intrin_name=*/intrin.first,
+        /*structure=*/"SSRSRS",
+        /*tile_binds=*/std::nullopt,
+        /*max_innermost_factor=*/Integer(intrin.second),
+        /*vector_load_lens=*/std::nullopt,
+        /*reuse_read=*/std::nullopt,
+        /*reuse_write=*/
+        Map<String, ffi::Any>{{"req", String("may")},
+                              {"levels", Array<Integer>{1, 2}},
+                              {"scope", String("global")}}));
+  }
+  rules.push_back(ScheduleRule::MultiLevelTiling(
+      /*structure=*/"SSRSRS",
+      /*tile_binds=*/std::nullopt,
+      /*max_innermost_factor=*/Integer(64),
+      /*vector_load_lens=*/std::nullopt,
+      /*reuse_read=*/std::nullopt,
+      /*reuse_write=*/
+      Map<String, ffi::Any>{
+          {"req", String("may")}, {"levels", Array<Integer>{1, 2}}, {"scope", String("global")}}));
+  rules.push_back(ScheduleRule::ParallelizeVectorizeUnroll(
+      /*max_jobs_per_core=*/16,
+      /*max_vectorize_extent=*/64,
+      /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
+      /*unroll_explicit=*/true));
+  rules.push_back(ScheduleRule::RandomComputeLocation());
+
+  return rules;
+}
+
Array<ScheduleRule> GetARMNeonSpecificRules() {
  return {
      ScheduleRule::MultiLevelTilingWithIntrin(
diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc
index 709b36417c..20d2d36268 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -39,6 +39,10 @@ String GetRuleKindFromTarget(const Target& target) {
return "avx512";
}
}
+ bool have_rvv = target_has_feature_fn_ptr("v", target).cast<bool>();
+ if (have_rvv) {
+ return "rvv";
+ }
TargetJSON target_json =
target::parsers::aprofile::ParseTarget(target->Export());
TargetFeatures afeatures =
Downcast<TargetFeatures>(target_json.at("features"));
@@ -117,6 +121,13 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
      default_sch_rules = ScheduleRule::DefaultX86("avx512");
      default_postprocs = Postproc::DefaultCPUTensorization();
      default_mutator_probs = Mutator::DefaultLLVM();
+    } else if (kind == "rvv") {
+      static auto llvm_get_vector_width =
+          tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
+      const int vlen = llvm_get_vector_width(context->target.value()).cast<int>();
+      default_sch_rules = ScheduleRule::DefaultRISCV(vlen);
+      default_postprocs = Postproc::DefaultRISCV();
+      default_mutator_probs = Mutator::DefaultLLVM();
    } else if (kind == "asimd") {
      default_sch_rules = ScheduleRule::DefaultARM("neon");
      default_postprocs = Postproc::DefaultCPUTensorization();