This is an automated email from the ASF dual-hosted git repository.

cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 06fb02e3fc [LLVM][METASCHEDULE] Add RISCV V-extension v1.0 kernels to metaschedule (#18243)
06fb02e3fc is described below

commit 06fb02e3fcab3b2c9e449bb5590bebabeaea0faa
Author: Balint Cristian <[email protected]>
AuthorDate: Mon Sep 8 01:41:17 2025 +0300

    [LLVM][METASCHEDULE] Add RISCV V-extension v1.0 kernels to metaschedule (#18243)
    
    - Enables high-performance kernels covering the majority of common ML datatype inputs
    - Compliant with the RVV specification version v1.0 (does not work with the older v0.7.1)
    - The TIR kernels implemented here use the recently added VLA (vector-length-agnostic) extension support
---
 include/tvm/meta_schedule/postproc.h               |   2 +
 include/tvm/meta_schedule/schedule_rule.h          |   2 +
 python/tvm/target/target.py                        |   8 +
 python/tvm/tir/tensor_intrin/__init__.py           |   2 +-
 python/tvm/tir/tensor_intrin/riscv_cpu.py          | 236 +++++++++++++++++++++
 src/meta_schedule/postproc/postproc.cc             |   8 +
 src/meta_schedule/schedule_rule/schedule_rule.cc   |  57 +++++
 .../space_generator/space_generator.cc             |  11 +
 8 files changed, 325 insertions(+), 1 deletion(-)

diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h
index c511271d20..6ed7272fe9 100644
--- a/include/tvm/meta_schedule/postproc.h
+++ b/include/tvm/meta_schedule/postproc.h
@@ -166,6 +166,8 @@ class Postproc : public runtime::ObjectRef {
   TVM_DLL static Array<Postproc, void> DefaultLLVM();
   /*! \brief Create default postprocessors for x86 (AVX512 and VNNI) */
   TVM_DLL static Array<Postproc, void> DefaultCPUTensorization();
+  /*! \brief Create default postprocessors for RISCV */
+  TVM_DLL static Array<Postproc, void> DefaultRISCV();
   /*! \brief Create default postprocessors for CUDA */
   TVM_DLL static Array<Postproc, void> DefaultCUDA();
   /*! \brief Create default postprocessors for CUDA with TensorCore */
diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 9011ebe0c1..407914e3d0 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -301,6 +301,8 @@ class ScheduleRule : public runtime::ObjectRef {
   TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
   /*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */
   TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);
+  /*! \brief Create default schedule rules for RISCV CPU (RVV) */
+  TVM_DLL static Array<ScheduleRule, void> DefaultRISCV(int vlen);
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, 
ScheduleRuleNode);
 };
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index a9191df773..eb6e25f045 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -637,6 +637,14 @@ def riscv_cpu(model="sifive-u54", options=None):
             "-mabi=lp64d",
             # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gc -mabi=lp64d 
-mcpu=sifive-u74
         ],
+        "licheepi3a": [
+            "-num-cores=8",
+            "-mtriple=riscv64-unknown-linux-gnu",
+            "-mcpu=spacemit-x60",
+            "-mfloat-abi=hard",
+            "-mabi=lp64d",
+            # cc: riscv64-unknown-linux-gnu-g++ -march=rv64gcv -mabi=lp64d 
-mcpu=spacemit-x60
+        ],
     }
     pre_defined_opt = trans_table.get(model, ["-model=%s" % model])
 
diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py
index 5646554552..0a6cf5310c 100644
--- a/python/tvm/tir/tensor_intrin/__init__.py
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -20,4 +20,4 @@ from tvm.runtime import enabled
 from . import cuda
 
 if enabled("llvm"):
-    from . import arm_cpu, x86, rocm, hexagon
+    from . import arm_cpu, x86, rocm, hexagon, riscv_cpu
diff --git a/python/tvm/tir/tensor_intrin/riscv_cpu.py b/python/tvm/tir/tensor_intrin/riscv_cpu.py
new file mode 100644
index 0000000000..febddc2bf3
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/riscv_cpu.py
@@ -0,0 +1,236 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,line-too-long
+"""Intrinsics for RISCV tensorization"""
+
+import logging
+from tvm.ffi import register_func
+from tvm.runtime import DataType
+from tvm.script import tir as T
+from tvm.target.codegen import llvm_get_vector_width, target_has_features, 
Target
+from .. import TensorIntrin
+
+logger = logging.getLogger(__name__)
+
+
+def get_max_elems(vlen: int, lmul: int, sew: int) -> int:
+    """Count how many SEW-bit elements fit in a group of LMUL vector registers.
+
+    Args:
+        vlen (int): VLEN, the width of a single vector register in bits.
+        lmul (int): LMUL, the vector length multiplier (registers per group).
+        sew (int): SEW, the selected element width in bits.
+
+    Returns:
+        int: Element capacity of the register group, ``(vlen // sew) * lmul``.
+    """
+    elems_per_register = vlen // sew
+    return elems_per_register * lmul
+
+
+def rvv_vec_dot_product_kernels(
+    n_elems: int,
+    n_lanes: int,
+    data_dtype: str,
+    weight_dtype: str,
+    out_dtype: str,
+    lmul: int,
+):
+    """Dot product of a vector with the rows of a matrix using RISC-V vector instructions.
+
+    The kernel takes a vector A[ELEMS] and a matrix B[LANES][ELEMS], computes the
+    dot product of A with each row of B, and accumulates the results into C[LANES].
+
+    The pseudo code is as follows:
+
+    .. code-block:: c
+
+        void vec_dot_prod(A[ELEMS], B[LANES][ELEMS], C[LANES]) {
+            for (j = 0; j < LANES; j++) {
+                for (k = 0; k < ELEMS; k++) {
+                    C[j] += A[k] * B[j][k];
+                }
+            }
+        }
+
+    Args:
+        n_elems (int): ELEMS, length of A and of each row of B; also used as the vector length.
+        n_lanes (int): LANES, number of rows of B and number of elements of C.
+        data_dtype (str): Element type of A.
+        weight_dtype (str): Element type of B.
+        out_dtype (str): Element type of C (the accumulator).
+        lmul (int): Register-grouping factor used to size the scalable vector types.
+
+    Returns:
+        tuple: ``(desc, impl)`` TIR prim_funcs, as passed to ``TensorIntrin.register``.
+    """
+
+    @T.prim_func
+    def rvv_vec_dot_prod_desc(
+        A: T.Buffer((n_elems,), data_dtype, offset_factor=1),
+        B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1),
+        C: T.Buffer((n_lanes,), out_dtype, offset_factor=1),
+    ) -> None:
+        with T.block("root"):
+            T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems])
+            T.writes(C[0:n_lanes])
+            for j in T.serial(0, n_lanes):
+                for k in T.serial(0, n_elems):
+                    with T.block("update"):
+                        vj, vk = T.axis.remap("SR", [j, k])
+                        C[vj] = C[vj] + T.cast(A[vk], out_dtype) * T.cast(B[vj, vk], out_dtype)
+
+    # LLVM only supports ELEN=32 or ELEN=64; its scalable types <vscale x N x T>
+    # are sized in 64-bit blocks of VLEN, hence the "64 //" below.
+    # https://llvm.org/docs//RISCV/RISCVVectorExtension.html
+    d_dtype_lanes = (64 // DataType(data_dtype).bits) * lmul
+    w_dtype_lanes = (64 // DataType(weight_dtype).bits) * lmul
+    # the reduction result uses a narrower vector type
+    o_dtype_lanes = (64 // DataType(out_dtype).bits) * lmul // n_lanes
+    # keep at least 2 lanes for the data type widening case
+    o_dtype_lanes = max(o_dtype_lanes, 2)
+
+    # floating-point intrinsics take an extra rounding-mode operand (7 selects the
+    # dynamic rounding mode); integer intrinsics take none
+    frm_args = () if data_dtype[0] in ("i", "u") else (T.uint64(7),)
+
+    wide_dtype = out_dtype
+    if DataType(out_dtype).bits > DataType(data_dtype).bits:
+        # widening multiply: the product is twice the input width; its type class
+        # follows the accumulator so a mixed-sign (vwmulsu) product stays signed
+        wide_dtype = "".join(c for c in out_dtype if not c.isdigit())
+        wide_dtype += str(DataType(data_dtype).bits * 2)
+
+    # select the LLVM intrinsics once, outside of the TVMScript body
+    is_float = out_dtype[0] == "f"
+    if is_float:
+        mul_intrin = "llvm.riscv.vfmul"
+    elif data_dtype[0] != weight_dtype[0]:
+        # signed(vs2 = B row) x unsigned(vs1 = A); assumes A is the unsigned operand
+        mul_intrin = "llvm.riscv.vwmulsu"
+    else:
+        mul_intrin = "llvm.riscv.vwmul"
+    red_intrin = "llvm.riscv.vfredusum" if is_float else "llvm.riscv.vwredsum"
+    mov_intrin = "llvm.riscv.vfmv.f.s" if is_float else "llvm.riscv.vmv.x.s"
+
+    # fmt: off
+    @T.prim_func
+    def rvv_vec_dot_prod_impl(
+        A: T.Buffer((n_elems,), data_dtype, offset_factor=1),
+        B: T.Buffer((n_lanes, n_elems), weight_dtype, offset_factor=1),
+        C: T.Buffer((n_lanes,), out_dtype, offset_factor=1),
+    ) -> None:
+        with T.block("root"):
+            T.reads(C[0:n_lanes], A[0:n_elems], B[0:n_lanes, 0:n_elems])
+            T.writes(C[0:n_lanes])
+
+            # load A once; it is reused for every row of B
+            vec_A = T.call_llvm_intrin(
+                f"{data_dtype}xvscalex{d_dtype_lanes}",
+                "llvm.riscv.vle",
+                T.broadcast(T.Cast(data_dtype, 0), T.vscale() * d_dtype_lanes),
+                T.tvm_access_ptr(T.type_annotation(data_dtype), A.data, 0, n_elems, 1),
+                T.int64(n_elems))
+
+            for i in range(n_lanes):
+                with T.block("update"):
+                    T.reads(B[i, 0:n_elems])
+                    T.writes(C[i])
+
+                    # load row i of B; the passthru operand must carry the weight type
+                    vec_B_row = T.call_llvm_intrin(
+                        f"{weight_dtype}xvscalex{w_dtype_lanes}",
+                        "llvm.riscv.vle",
+                        T.broadcast(T.Cast(weight_dtype, 0), T.vscale() * w_dtype_lanes),
+                        T.tvm_access_ptr(T.type_annotation(weight_dtype), B.data, i * n_elems, n_elems, 1),
+                        T.int64(n_elems))
+
+                    # element-wise (possibly widening) product of the row with A
+                    product = T.call_llvm_intrin(
+                        f"{wide_dtype}xvscalex{w_dtype_lanes}",
+                        mul_intrin,
+                        T.broadcast(T.Cast(wide_dtype, 0), T.vscale() * w_dtype_lanes),
+                        vec_B_row,
+                        vec_A,
+                        *frm_args,
+                        T.uint64(n_elems))
+
+                    # load the current C[i] to seed the reduction
+                    ini_acc = T.call_llvm_intrin(
+                        f"{out_dtype}xvscalex{o_dtype_lanes}",
+                        "llvm.riscv.vle",
+                        T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes),
+                        T.tvm_access_ptr(T.type_annotation(out_dtype), C.data, i, 1, 1),
+                        T.int64(1))
+
+                    # sum all product lanes together with the seed
+                    red_sum = T.call_llvm_intrin(
+                        f"{out_dtype}xvscalex{o_dtype_lanes}",
+                        red_intrin,
+                        T.broadcast(T.Cast(out_dtype, 0), T.vscale() * o_dtype_lanes),
+                        product,
+                        ini_acc,
+                        *frm_args,
+                        T.uint64(n_elems))
+
+                    # move element 0 of the reduction result back to the scalar C[i]
+                    C[i] = T.call_llvm_intrin(
+                        out_dtype,
+                        mov_intrin,
+                        red_sum)
+    # fmt: on
+    return rvv_vec_dot_prod_desc, rvv_vec_dot_prod_impl
+
+
+@register_func("tir.tensor_intrin.register_rvv_isa_intrinsics")
+def register_rvv_isa_intrinsics(target: Target, inventory_only=False) -> dict:
+    """Register RISCV V (vector) intrinsics.
+
+    The implementation follows the version 1.0 vector specification:
+    https://github.com/riscvarchive/riscv-v-spec/releases/tag/v1.0
+
+    Args:
+        target (Target): TVM target; must support the `v` extension.
+        inventory_only (bool): When True, only build the catalog of kernel
+            names and do not register any tensor intrinsic.
+
+    Returns:
+        dict: Mapping of kernel name to its reduction length (``n_elems``).
+
+    Raises:
+        RuntimeError: If the target does not support the `v` extension.
+    """
+    if not target_has_features("v", target):
+        raise RuntimeError("Current target does not support `v` extension.")
+
+    vlen = llvm_get_vector_width(target)
+    # maximum reduction lanes: 32-bit elements in one register (no grouping)
+    n_lanes = get_max_elems(vlen, lmul=1, sew=32)
+
+    kernels_inventory = {}
+
+    # (data, weight, output) dtype combinations that get kernels
+    dtype_configs = (
+        ("uint8", "int8", "int32"),
+        ("int8", "int8", "int32"),
+        ("float16", "float16", "float16"),
+        ("float32", "float32", "float32"),
+    )
+
+    for d_dtype, w_dtype, o_dtype in dtype_configs:
+        dt = DataType(d_dtype)
+        wt = DataType(w_dtype)
+        ot = DataType(o_dtype)
+        # max elements that fit in grouped registers (LMUL=8)
+        max_elems = get_max_elems(vlen, lmul=8, sew=dt.bits)
+        # data widening halves the available vector registers
+        if ot.bits > dt.bits:
+            max_elems //= 2
+        # LMUL needed to hold max_elems elements of the data type
+        lmul = max_elems // (vlen // dt.bits)
+
+        # emit kernels for max_elems and each halving of it, down to 4 elements
+        n_elems = max_elems
+        while n_elems >= 4:
+            kernel_name = (
+                f"rvv_dot_{n_elems}{d_dtype[0]}{dt.bits}"
+                f"_{n_lanes}x{n_elems}{w_dtype[0]}{wt.bits}"
+                f"_{n_lanes}{o_dtype[0]}{ot.bits}"
+            )
+            kernels_inventory[kernel_name] = n_elems
+
+            if not inventory_only:
+                logger.debug("Registering kernel %s", kernel_name)
+                desc, impl = rvv_vec_dot_product_kernels(
+                    n_elems, n_lanes, d_dtype, w_dtype, o_dtype, lmul
+                )
+                TensorIntrin.register(kernel_name, desc, impl, override=True)
+
+            n_elems //= 2
+
+    return kernels_inventory
+
+
+def register_riscv_intrinsics(target: Target):
+    """Register every RISCV tensor intrinsic provided by this module.
+
+    At present this registers the RVV 1.0 (`v` extension) dot-product kernels.
+
+    Args:
+        target (Target): TVM target to register the intrinsics for.
+    """
+    # the returned kernel catalog is not needed here
+    register_rvv_isa_intrinsics(target)
+    logger.debug("Finished registering riscv intrinsics.")
diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc
index ccf280860d..6d11929648 100644
--- a/src/meta_schedule/postproc/postproc.cc
+++ b/src/meta_schedule/postproc/postproc.cc
@@ -69,6 +69,14 @@ Array<Postproc> Postproc::DefaultCPUTensorization() {
   };
 }
 
+/*!
+ * \brief Default postprocessors for RISCV targets.
+ * \return DisallowDynamicLoop, RewriteParallelVectorizeUnroll, RewriteReductionBlock,
+ *         RewriteTensorize (init loop not vectorized) and RewriteLayout, in that order.
+ */
+Array<Postproc> Postproc::DefaultRISCV() {
+  Array<Postproc> postprocs;
+  postprocs.push_back(Postproc::DisallowDynamicLoop());
+  postprocs.push_back(Postproc::RewriteParallelVectorizeUnroll());
+  postprocs.push_back(Postproc::RewriteReductionBlock());
+  postprocs.push_back(Postproc::RewriteTensorize(/*vectorize_init_loop=*/false));
+  postprocs.push_back(Postproc::RewriteLayout());
+  return postprocs;
+}
+
 Array<Postproc> Postproc::DefaultCUDA() {
   return Array<Postproc>{
       Postproc::DisallowDynamicLoop(),
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 9570c0d0f9..e23ca117c6 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -17,6 +17,7 @@
  * under the License.
  */
 #include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/data_type.h>
 
 #include "../utils.h"
 
@@ -304,6 +305,62 @@ Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
   };
 }
 
+/*!
+ * \brief Default schedule rules for RISCV CPUs with the RVV (`v`) extension.
+ * \param vlen Vector register width in bits, as computed by the caller.
+ *        NOTE(review): not read here; the RVV kernel catalog is derived from
+ *        tvm::Target::Current() instead -- confirm this is intended.
+ * \return Custom rules, constant/auto inlining, rfactor, one tensorized
+ *         multi-level tiling rule per RVV kernel, a generic multi-level tiling
+ *         fallback, parallelize/vectorize/unroll and random compute location.
+ */
+Array<ScheduleRule> ScheduleRule::DefaultRISCV(const int vlen) {
+  Array<ScheduleRule> rules;
+  rules.push_back(ScheduleRule::ApplyCustomRule());
+  rules.push_back(ScheduleRule::InlineConstantScalars());
+  rules.push_back(ScheduleRule::AutoInline(
+      /*into_producer=*/false,
+      /*into_consumer=*/true,
+      /*inline_const_tensor=*/true,
+      /*disallow_if_then_else=*/true,
+      /*require_injective=*/true,
+      /*require_ordered=*/true,
+      /*disallow_op=*/Array<String>{"tir.exp"}));
+  rules.push_back(ScheduleRule::AddRFactor(
+      /*max_jobs_per_core=*/16,
+      /*max_innermost_factor=*/Integer(64)));
+
+  // Fetch the RVV kernel catalog (name -> reduction length) without registering.
+  const tvm::Target rvv_target = tvm::Target::Current();
+  const auto register_rvv =
+      tvm::ffi::Function::GetGlobalRequired("tir.tensor_intrin.register_rvv_isa_intrinsics");
+  const Map<String, int> inventory =
+      register_rvv(rvv_target, /*inventory_only=*/true).cast<Map<String, int>>();
+
+  // Register on demand: a single non-inventory call registers the whole catalog.
+  bool any_missing = false;
+  for (const auto& kv : inventory) {
+    if (!tir::TensorIntrin::Get(kv.first, /*allow_missing=*/true)) {
+      any_missing = true;
+      break;
+    }
+  }
+  if (any_missing) {
+    register_rvv(rvv_target, /*inventory_only=*/false);
+  }
+
+  // One tensorized tiling rule per kernel; its reduction length caps the innermost tile.
+  for (const auto& kv : inventory) {
+    rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin(
+        /*intrin_name=*/kv.first,
+        /*structure=*/"SSRSRS",
+        /*tile_binds=*/std::nullopt,
+        /*max_innermost_factor=*/Integer(kv.second),
+        /*vector_load_lens=*/std::nullopt,
+        /*reuse_read=*/std::nullopt,
+        /*reuse_write=*/
+        Map<String, ffi::Any>{{"req", String("may")},
+                              {"levels", Array<Integer>{1, 2}},
+                              {"scope", String("global")}}));
+  }
+
+  // Generic fallback for blocks that no RVV kernel matches.
+  rules.push_back(ScheduleRule::MultiLevelTiling(
+      /*structure=*/"SSRSRS",
+      /*tile_binds=*/std::nullopt,
+      /*max_innermost_factor=*/Integer(64),
+      /*vector_load_lens=*/std::nullopt,
+      /*reuse_read=*/std::nullopt,
+      /*reuse_write=*/
+      Map<String, ffi::Any>{
+          {"req", String("may")}, {"levels", Array<Integer>{1, 2}}, {"scope", String("global")}}));
+  rules.push_back(ScheduleRule::ParallelizeVectorizeUnroll(
+      /*max_jobs_per_core=*/16,
+      /*max_vectorize_extent=*/64,
+      /*unroll_max_steps=*/Array<Integer>{0, 16, 64, 512},
+      /*unroll_explicit=*/true));
+  rules.push_back(ScheduleRule::RandomComputeLocation());
+  return rules;
+}
+
 Array<ScheduleRule> GetARMNeonSpecificRules() {
   return {
       ScheduleRule::MultiLevelTilingWithIntrin(
diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc
index 709b36417c..20d2d36268 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -39,6 +39,10 @@ String GetRuleKindFromTarget(const Target& target) {
         return "avx512";
       }
     }
+    bool have_rvv = target_has_feature_fn_ptr("v", target).cast<bool>();
+    if (have_rvv) {
+      return "rvv";
+    }
 
     TargetJSON target_json = 
target::parsers::aprofile::ParseTarget(target->Export());
     TargetFeatures afeatures = 
Downcast<TargetFeatures>(target_json.at("features"));
@@ -117,6 +121,13 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
       default_sch_rules = ScheduleRule::DefaultX86("avx512");
       default_postprocs = Postproc::DefaultCPUTensorization();
       default_mutator_probs = Mutator::DefaultLLVM();
+    } else if (kind == "rvv") {
+      static auto llvm_get_vector_width =
+          
tvm::ffi::Function::GetGlobalRequired("target.llvm_get_vector_width");
+      const int vlen = 
llvm_get_vector_width(context->target.value()).cast<int>();
+      default_sch_rules = ScheduleRule::DefaultRISCV(vlen);
+      default_postprocs = Postproc::DefaultRISCV();
+      default_mutator_probs = Mutator::DefaultLLVM();
     } else if (kind == "asimd") {
       default_sch_rules = ScheduleRule::DefaultARM("neon");
       default_postprocs = Postproc::DefaultCPUTensorization();

Reply via email to