This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b724c87f76 [MetaSchedule][ARM] Enable ARM CPU intrinsic for MetaSchedule (#14209)
b724c87f76 is described below

commit b724c87f76071c1c4d39c17aa0638bd94c878d61
Author: dsbarinov1 <[email protected]>
AuthorDate: Fri Mar 31 11:46:38 2023 +0300

    [MetaSchedule][ARM] Enable ARM CPU intrinsic for MetaSchedule (#14209)
---
 include/tvm/meta_schedule/schedule_rule.h          |  2 +
 include/tvm/runtime/container/array.h              | 43 ++++++++++
 python/tvm/tir/tensor_intrin/arm_cpu.py            | 99 ++++++++++++++--------
 src/meta_schedule/schedule_rule/schedule_rule.cc   | 90 ++++++++++++++++++++
 .../space_generator/space_generator.cc             | 19 +++++
 .../test_meta_schedule_post_order_apply.py         | 73 ++++++++++++++++
 6 files changed, 291 insertions(+), 35 deletions(-)

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 7995d1fcee..d91812fb55 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -300,6 +300,8 @@ class ScheduleRule : public runtime::ObjectRef {
   TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
   /*! \brief Create default schedule rules for Micro */
   TVM_DLL static Array<ScheduleRule, void> DefaultMicro();
+  /*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */
+  TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);
 
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef, ScheduleRuleNode);
 };
diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h
index d1d5422a43..ff0bd03ab9 100644
--- a/include/tvm/runtime/container/array.h
+++ b/include/tvm/runtime/container/array.h
@@ -580,6 +580,36 @@ class Array : public ObjectRef {
     }
   }
 
+  template <typename... Args>
+  static size_t CalcCapacityImpl() {
+    return 0;
+  }
+
+  template <typename... Args>
+  static size_t CalcCapacityImpl(Array<T> value, Args... args) {
+    return value.size() + CalcCapacityImpl(args...);
+  }
+
+  template <typename... Args>
+  static size_t CalcCapacityImpl(T value, Args... args) {
+    return 1 + CalcCapacityImpl(args...);
+  }
+
+  template <typename... Args>
+  static void AgregateImpl(Array<T>& dest) {}  // NOLINT(*)
+
+  template <typename... Args>
+  static void AgregateImpl(Array<T>& dest, Array<T> value, Args... args) {  // NOLINT(*)
+    dest.insert(dest.end(), value.begin(), value.end());
+    AgregateImpl(dest, args...);
+  }
+
+  template <typename... Args>
+  static void AgregateImpl(Array<T>& dest, T value, Args... args) {  // NOLINT(*)
+    dest.push_back(value);
+    AgregateImpl(dest, args...);
+  }
+
  public:
   // Array's own methods
 
@@ -680,6 +710,19 @@ class Array : public ObjectRef {
   /*! \brief specify container node */
   using ContainerType = ArrayNode;
 
+  /*!
+   * \brief Agregate arguments into a single Array<T>
+   * \param args sequence of T or Array<T> elements
+   * \return Agregated Array<T>
+   */
+  template <typename... Args>
+  static Array<T> Agregate(Args... args) {
+    Array<T> result;
+    result.reserve(CalcCapacityImpl(args...));
+    AgregateImpl(result, args...);
+    return result;
+  }
+
  private:
   /*!
   * \brief Implement copy-on-write semantics, and ensures capacity is enough for extra elements.
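
For reference, Agregate flattens a mixture of scalar T and Array<T> arguments into a single array in one pass, reserving capacity up front. There is no direct Python binding for it, but the semantics can be mimicked from Python with the runtime converter (a sketch, not part of this commit):

    from tvm.runtime import convert

    a = convert([1, 2])
    b = convert([3])
    # Mixing elements and arrays, as Array<T>::Agregate(a, 7, b) would in C++:
    merged = convert(list(a) + [7] + list(b))
    print(merged)  # [1, 2, 7, 3]
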
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
index 9357f0ceb2..521d882e24 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -26,7 +26,7 @@ from .dot_product_common import DP4A_INTRIN  # pylint: disable=unused-import
 
 
 @T.prim_func
-def dot_product_4x4_i8i8i32_desc(
+def neon_4x4_i8i8i32_desc(
     A: T.Buffer((4,), "int8", offset_factor=1),
     B: T.Buffer((4, 4), "int8", offset_factor=1),
     C: T.Buffer((4,), "int32", offset_factor=1),
@@ -42,7 +42,7 @@ def dot_product_4x4_i8i8i32_desc(
 
 
 @T.prim_func
-def dot_product_4x4_i8i8i32_neon(
+def neon_4x4_i8i8i32_impl(
     A: T.Buffer((4,), "int8", offset_factor=1),
     B: T.Buffer((4, 4), "int8", offset_factor=1),
     C: T.Buffer((4,), "int32", offset_factor=1),
@@ -102,42 +102,71 @@ def dot_product_4x4_i8i8i32_neon(
         )
 
 
-@T.prim_func
-def dot_product_4x4_i8i8i32_sdot(
-    A: T.Buffer((4,), "int8", offset_factor=1),
-    B: T.Buffer((4, 4), "int8", offset_factor=1),
-    C: T.Buffer((4,), "int32", offset_factor=1),
-) -> None:
-    with T.block("root"):
-        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
-        T.writes(C[0:4])
-
-        A_i8x4 = A.vload([0], "int8x4")
-        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
-        vec_ai32 = T.broadcast(A_i32, 4)
-        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")
-
-        vec_b = B.vload([0, 0], dtype="int8x16")
-
-        vec_c = C.vload([0], dtype="int32x4")
-
-        C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
-            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
-            T.uint32(3),
-            vec_c,
-            vec_a,
-            vec_b,
-            dtype="int32x4",
-        )
+def get_dotprod_intrin(in_dtype, out_dtype):
+    if in_dtype == "uint8":
+        instr = "udot.v4u32.v16u8"
+    else:  # if in_dtype == "int8"
+        instr = "sdot.v4i32.v16i8"
+
+    in_dtype_x4 = "{TYPE}x4".format(TYPE=in_dtype)
+    out_dtype_x4 = "{TYPE}x4".format(TYPE=out_dtype)
+    in_dtype_x16 = "{TYPE}x16".format(TYPE=in_dtype)
+
+    @T.prim_func
+    def dot_prod_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1)
+        B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1)
+        C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1)
+        with T.block("root"):
+            T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+            T.writes(C[0:4])
+            for i in T.serial(0, 4):
+                for k in T.serial(0, 4):
+                    with T.block("update"):
+                        vi, vk = T.axis.remap("SR", [i, k])
+                        C[vi] = C[vi] + T.cast(A[vk], dtype=out_dtype) * T.cast(
+                            B[vi, vk], dtype=out_dtype
+                        )
+
+    @T.prim_func
+    def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1)
+        B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1)
+        C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1)
+        with T.block("root"):
+            T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+            T.writes(C[0:4])
+
+            A_i8x4 = A.vload([0], in_dtype_x4)
+            A_i32 = T.reinterpret(A_i8x4, dtype=out_dtype)
+            vec_ai32 = T.broadcast(A_i32, 4)
+            vec_a = T.reinterpret(vec_ai32, dtype=in_dtype_x16)
+
+            vec_b = B.vload([0, 0], dtype=in_dtype_x16)
+
+            vec_c = C.vload([0], dtype=out_dtype_x4)
+
+            C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
+                T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.{INSTR}".format(INSTR=instr)),
+                T.uint32(3),
+                vec_c,
+                vec_a,
+                vec_b,
+                dtype=out_dtype_x4,
+            )
+
+    return dot_prod_desc, dot_prod_impl
 
 
 ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
 ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"
+ARM_DOT_4x4_u8_UDOT_INTRIN = "dot_4x4_u8u8u32_udot"
+ARM_DOT_4x4_u8_HDOT_INTRIN = "dot_4x4_u8u8i32_hdot"
+
+TensorIntrin.register(ARM_DOT_4x4_i8_NEON_INTRIN, neon_4x4_i8i8i32_desc, neon_4x4_i8i8i32_impl)
+
+TensorIntrin.register(ARM_DOT_4x4_i8_SDOT_INTRIN, *get_dotprod_intrin("int8", "int32"))
 
-TensorIntrin.register(
-    ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_neon
-)
+TensorIntrin.register(ARM_DOT_4x4_u8_UDOT_INTRIN, *get_dotprod_intrin("uint8", "uint32"))
 
-TensorIntrin.register(
-    ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_sdot
-)
+TensorIntrin.register(ARM_DOT_4x4_u8_HDOT_INTRIN, *get_dotprod_intrin("uint8", "int32"))
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 49a7c9911c..35f1151c9c 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -295,6 +295,94 @@ Array<ScheduleRule> ScheduleRule::DefaultMicro() {
   };
 }
 
+Array<ScheduleRule> GetNeonSpecificRules() {
+  return {
+      ScheduleRule::MultiLevelTilingWithIntrin(
+          /*intrin_name=*/String("dot_4x4_i8i8s32_neon"),
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+  };
+}
+
+Array<ScheduleRule> GetDotprodSpecificRules() {
+  return {
+      ScheduleRule::MultiLevelTilingWithIntrin(
+          /*intrin_name=*/String("dot_4x4_i8i8s32_sdot"),
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::MultiLevelTilingWithIntrin(
+          /*intrin_name=*/String("dot_4x4_u8u8u32_udot"),
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::MultiLevelTilingWithIntrin(
+          /*intrin_name=*/String("dot_4x4_u8u8i32_hdot"),
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+  };
+}
+
+Array<ScheduleRule> ScheduleRule::DefaultARM(const String& type) {
+  return Array<ScheduleRule>::Agregate(
+      ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(),
+      ScheduleRule::AutoInline(
+          /*into_producer=*/false,
+          /*into_consumer=*/true,
+          /*inline_const_tensor=*/true,
+          /*disallow_if_then_else=*/true,
+          /*require_injective=*/true,
+          /*require_ordered=*/true,
+          /*disallow_op=*/Array<String>{"tir.exp"}),
+      ScheduleRule::AddRFactor(
+          /*max_jobs_per_core=*/8,
+          /*max_innermost_factor=*/Integer(32)),
+      "neon" == type ? GetNeonSpecificRules() : Array<ScheduleRule>{},
+      "dotprod" == type ? GetDotprodSpecificRules() : Array<ScheduleRule>{},
+      ScheduleRule::MultiLevelTiling(
+          /*structure=*/"SSRSRS",
+          /*tile_binds=*/NullOpt,
+          /*max_innermost_factor=*/Integer(32),
+          /*vector_load_lens=*/NullOpt,
+          /*reuse_read=*/NullOpt,
+          /*reuse_write=*/
+          Map<String, ObjectRef>{{"req", String("may")},
+                                 {"levels", Array<Integer>{1, 2}},
+                                 {"scope", String("global")}}),
+      ScheduleRule::ParallelizeVectorizeUnroll(
+          /*max_jobs_per_core=*/8,
+          /*max_vectorize_extent=*/32,
+          /*unroll_max_steps=*/Array<Integer>{0, 8, 32, 256},
+          /*unroll_explicit=*/true),
+      ScheduleRule::RandomComputeLocation());
+}
+
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<PyScheduleRuleNode>([](const ObjectRef& n, ReprPrinter* p) {
       const auto* self = n.as<PyScheduleRuleNode>();
@@ -325,6 +413,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon")
     .set_body_typed(ScheduleRule::DefaultHexagon);
 TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultMicro")
     .set_body_typed(ScheduleRule::DefaultMicro);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultARM")
+    .set_body_typed(ScheduleRule::DefaultARM);
 
 }  // namespace meta_schedule
 }  // namespace tvm
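
Because DefaultARM is exported via TVM_REGISTER_GLOBAL, the rule set can be inspected from Python through the FFI registry. A hedged sketch (only the global name registered above is assumed):

    import tvm

    default_arm = tvm.get_global_func("meta_schedule.ScheduleRuleDefaultARM")
    for kind in ("neon", "dotprod"):
        # "dotprod" carries three extra MultiLevelTilingWithIntrin rules
        # (sdot/udot/hdot) on top of the shared CPU rules.
        print(kind, len(default_arm(kind)))
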
diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc
index a3669e996f..c96554e6a2 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -16,6 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+#include "../../target/parsers/aprofile.h"
 #include "../utils.h"
 
 namespace tvm {
@@ -38,6 +39,16 @@ String GetRuleKindFromTarget(const Target& target) {
         return "avx512";
       }
     }
+
+    TargetJSON target_json = target::parsers::aprofile::ParseTarget(target->Export());
+    TargetFeatures afeatures = Downcast<TargetFeatures>(target_json.at("features"));
+
+    if (Downcast<Bool>(afeatures.at("has_dotprod"))) {
+      return "dotprod";
+    }
+    if (Downcast<Bool>(afeatures.at("has_asimd"))) {
+      return "asimd";
+    }
     return "llvm";
   }
   if (target->kind->name == "hexagon") {
@@ -110,6 +121,14 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const TuneContext& context) {
       default_sch_rules = ScheduleRule::DefaultMicro();
       default_postprocs = Postproc::DefaultMicro();
       default_mutator_probs = Mutator::DefaultMicro();
+    } else if (kind == "asimd") {
+      default_sch_rules = ScheduleRule::DefaultARM("neon");
+      default_postprocs = Postproc::DefaultCPUTensorization();
+      default_mutator_probs = Mutator::DefaultLLVM();
+    } else if (kind == "dotprod") {
+      default_sch_rules = ScheduleRule::DefaultARM("dotprod");
+      default_postprocs = Postproc::DefaultCPUTensorization();
+      default_mutator_probs = Mutator::DefaultLLVM();
     } else {
       LOG(FATAL) << "Unsupported kind: " << kind;
       throw;
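
The new dispatch keys off the aprofile target parser, so which branch a target takes can be checked from Python as well. A sketch (assumes the Python-side Target.features accessor exposes the parsed aprofile flags matching the has_dotprod/has_asimd keys read above):

    from tvm.target import Target

    neon = Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon")
    dot = Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod")

    print(neon.features.has_asimd, neon.features.has_dotprod)  # expected: True False
    print(dot.features.has_dotprod)                            # expected: True
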
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index 716f829653..6c069dc6bf 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -23,6 +23,8 @@ from typing import List
 import pytest
 import tvm
 import tvm.testing
+from tvm import te
+from tvm.ir.module import IRModule
 from tvm._ffi import register_func
 from tvm.error import TVMError
 from tvm.meta_schedule import TuneContext
@@ -36,6 +38,23 @@ from tvm.tir.schedule import BlockRV, Schedule
 # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,
 # fmt: off
 
+
+def get_matmul_packed(m, n, k, lhs_type="int8", rhs_dtype="int8", acc_dtype="int32"):
+    X = te.placeholder((m, k), name="X", dtype=lhs_type)
+    W = te.placeholder((n, k), name="W", dtype=rhs_dtype)
+
+    ak = te.reduce_axis((0, k), name="k")
+    matmul = te.compute(
+        (m, n),
+        lambda i, j: te.sum(
+            X[i, ak].astype(acc_dtype) * W[j, ak].astype(acc_dtype),
+            axis=ak,
+        ),
+        name="compute",
+    )
+    return te.create_prim_func([X, W, matmul])
+
+
 @tvm.script.ir_module
 class Matmul:
     @T.prim_func
@@ -404,6 +423,60 @@ def test_target_blocks_search_space():
     assert len(schs) == 8
 
 
+@pytest.mark.parametrize(
+    "target,mod,expected_intr",
+    [
+        (
+            Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon -num-cores 2"),
+            IRModule({"main": get_matmul_packed(128, 128, 128, "int8", "int8", "int32")}),
+            "dot_4x4_i8i8s32_neon",
+        ),
+        (
+            Target(
+                "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu 
-mattr=+neon,+v8.2a,+dotprod -num-cores 2"
+            ),
+            IRModule({"main": get_matmul_packed(128, 128, 128, "int8", "int8", 
"int32")}),
+            "dot_4x4_i8i8s32_sdot",
+        ),
+        (
+            Target(
+                "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu 
-mattr=+neon,+v8.2a,+dotprod -num-cores 2"
+            ),
+            IRModule({"main": get_matmul_packed(128, 128, 128, "uint8", 
"uint8", "uint32")}),
+            "dot_4x4_u8u8u32_udot",
+        ),
+        (
+            Target(
+                "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu 
-mattr=+neon,+v8.2a,+dotprod -num-cores 2"
+            ),
+            IRModule({"main": get_matmul_packed(128, 128, 128, "uint8", 
"uint8", "int32")}),
+            "dot_4x4_u8u8i32_hdot",
+        ),
+    ],
+)
+def test_meta_schedule_post_order_apply_arm_intrin(target, mod, expected_intr):
+    context = TuneContext(
+        mod=mod,
+        target=target,
+        task_name="Arm Intrinsic Task",
+        space_generator=PostOrderApply(),  # Triggers default generator
+        rand_state=1,  # Fixed seed; adjust it if a parametrized case fails to find the intrinsic
+    )
+    post_order_apply = context.space_generator
+    schs = post_order_apply.generate_design_space(mod)
+
+    assert len(schs) != 0
+
+    for sch in schs:
+        sch.enter_postproc()
+
+        for proc in context.space_generator.postprocs:
+            proc.apply(sch)
+
+    assert any(["call_llvm_pure_intrin" in sch.mod.script() for sch in schs])
+    assert any([expected_intr in str(sch.trace) for sch in schs])
+
+
 def test_meta_schedule_derived_object():
     @derived_object
     class RemoveBlock(PyScheduleRule):

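The parametrized test mirrors real usage: the design space is generated purely from the target string. A condensed sketch of one case outside pytest (get_matmul_packed is the helper added above; APIs as in the test body):

    from tvm.ir.module import IRModule
    from tvm.meta_schedule import TuneContext
    from tvm.meta_schedule.space_generator import PostOrderApply
    from tvm.target import Target

    mod = IRModule({"main": get_matmul_packed(128, 128, 128, "uint8", "uint8", "uint32")})
    target = Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon,+v8.2a,+dotprod -num-cores 2")
    ctx = TuneContext(mod=mod, target=target, space_generator=PostOrderApply(), rand_state=1)
    schs = ctx.space_generator.generate_design_space(mod)
    assert any("dot_4x4_u8u8u32_udot" in str(sch.trace) for sch in schs)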