This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b724c87f76 [MetaSchedule][ARM] Enable ARM CPU intrinsic for
MetaSchedule (#14209)
b724c87f76 is described below
commit b724c87f76071c1c4d39c17aa0638bd94c878d61
Author: dsbarinov1 <[email protected]>
AuthorDate: Fri Mar 31 11:46:38 2023 +0300
[MetaSchedule][ARM] Enable ARM CPU intrinsic for MetaSchedule (#14209)
---
include/tvm/meta_schedule/schedule_rule.h | 2 +
include/tvm/runtime/container/array.h | 43 ++++++++++
python/tvm/tir/tensor_intrin/arm_cpu.py | 99 ++++++++++++++--------
src/meta_schedule/schedule_rule/schedule_rule.cc | 90 ++++++++++++++++++++
.../space_generator/space_generator.cc | 19 +++++
.../test_meta_schedule_post_order_apply.py | 73 ++++++++++++++++
6 files changed, 291 insertions(+), 35 deletions(-)
diff --git a/include/tvm/meta_schedule/schedule_rule.h
b/include/tvm/meta_schedule/schedule_rule.h
index 7995d1fcee..d91812fb55 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -300,6 +300,8 @@ class ScheduleRule : public runtime::ObjectRef {
TVM_DLL static Array<ScheduleRule, void> DefaultHexagon();
/*! \brief Create default schedule rules for Micro */
TVM_DLL static Array<ScheduleRule, void> DefaultMicro();
+ /*! \brief Create default schedule rules for ARM CPU (NEON and DOTPROD) */
+ TVM_DLL static Array<ScheduleRule, void> DefaultARM(const String& type);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ScheduleRule, ObjectRef,
ScheduleRuleNode);
};
diff --git a/include/tvm/runtime/container/array.h
b/include/tvm/runtime/container/array.h
index d1d5422a43..ff0bd03ab9 100644
--- a/include/tvm/runtime/container/array.h
+++ b/include/tvm/runtime/container/array.h
@@ -580,6 +580,36 @@ class Array : public ObjectRef {
}
}
+ template <typename... Args>
+ static size_t CalcCapacityImpl() {
+ return 0;
+ }
+
+ template <typename... Args>
+ static size_t CalcCapacityImpl(Array<T> value, Args... args) {
+ return value.size() + CalcCapacityImpl(args...);
+ }
+
+ template <typename... Args>
+ static size_t CalcCapacityImpl(T value, Args... args) {
+ return 1 + CalcCapacityImpl(args...);
+ }
+
+ template <typename... Args>
+ static void AgregateImpl(Array<T>& dest) {} // NOLINT(*)
+
+ template <typename... Args>
+ static void AgregateImpl(Array<T>& dest, Array<T> value, Args... args) { //
NOLINT(*)
+ dest.insert(dest.end(), value.begin(), value.end());
+ AgregateImpl(dest, args...);
+ }
+
+ template <typename... Args>
+ static void AgregateImpl(Array<T>& dest, T value, Args... args) { //
NOLINT(*)
+ dest.push_back(value);
+ AgregateImpl(dest, args...);
+ }
+
public:
// Array's own methods
@@ -680,6 +710,19 @@ class Array : public ObjectRef {
/*! \brief specify container node */
using ContainerType = ArrayNode;
+ /*!
+ * \brief Agregate arguments into a single Array<T>
+ * \param args sequence of T or Array<T> elements
+ * \return Agregated Array<T>
+ */
+ template <typename... Args>
+ static Array<T> Agregate(Args... args) {
+ Array<T> result;
+ result.reserve(CalcCapacityImpl(args...));
+ AgregateImpl(result, args...);
+ return result;
+ }
+
private:
/*!
* \brief Implement copy-on-write semantics, and ensures capacity is enough
for extra elements.
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py
b/python/tvm/tir/tensor_intrin/arm_cpu.py
index 9357f0ceb2..521d882e24 100644
--- a/python/tvm/tir/tensor_intrin/arm_cpu.py
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -26,7 +26,7 @@ from .dot_product_common import DP4A_INTRIN # pylint:
disable=unused-import
@T.prim_func
-def dot_product_4x4_i8i8i32_desc(
+def neon_4x4_i8i8i32_desc(
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
@@ -42,7 +42,7 @@ def dot_product_4x4_i8i8i32_desc(
@T.prim_func
-def dot_product_4x4_i8i8i32_neon(
+def neon_4x4_i8i8i32_impl(
A: T.Buffer((4,), "int8", offset_factor=1),
B: T.Buffer((4, 4), "int8", offset_factor=1),
C: T.Buffer((4,), "int32", offset_factor=1),
@@ -102,42 +102,71 @@ def dot_product_4x4_i8i8i32_neon(
)
-@T.prim_func
-def dot_product_4x4_i8i8i32_sdot(
- A: T.Buffer((4,), "int8", offset_factor=1),
- B: T.Buffer((4, 4), "int8", offset_factor=1),
- C: T.Buffer((4,), "int32", offset_factor=1),
-) -> None:
- with T.block("root"):
- T.reads(C[0:4], A[0:4], B[0:4, 0:4])
- T.writes(C[0:4])
-
- A_i8x4 = A.vload([0], "int8x4")
- A_i32 = T.reinterpret(A_i8x4, dtype="int32")
- vec_ai32 = T.broadcast(A_i32, 4)
- vec_a = T.reinterpret(vec_ai32, dtype="int8x16")
-
- vec_b = B.vload([0, 0], dtype="int8x16")
-
- vec_c = C.vload([0], dtype="int32x4")
-
- C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
- T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
- T.uint32(3),
- vec_c,
- vec_a,
- vec_b,
- dtype="int32x4",
- )
+def get_dotprod_intrin(in_dtype, out_dtype):
+ if in_dtype == "uint8":
+ instr = "udot.v4u32.v16u8"
+ else: # if in_dtype == "int8"
+ instr = "sdot.v4i32.v16i8"
+
+ in_dtype_x4 = "{TYPE}x4".format(TYPE=in_dtype)
+ out_dtype_x4 = "{TYPE}x4".format(TYPE=out_dtype)
+ in_dtype_x16 = "{TYPE}x16".format(TYPE=in_dtype)
+
+ @T.prim_func
+ def dot_prod_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1)
+ B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1)
+ C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1)
+ with T.block("root"):
+ T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+ T.writes(C[0:4])
+ for i in T.serial(0, 4):
+ for k in T.serial(0, 4):
+ with T.block("update"):
+ vi, vk = T.axis.remap("SR", [i, k])
+ C[vi] = C[vi] + T.cast(A[vk], dtype=out_dtype) *
T.cast(
+ B[vi, vk], dtype=out_dtype
+ )
+
+ @T.prim_func
+ def dot_prod_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
+ A = T.match_buffer(a, (4,), dtype=in_dtype, offset_factor=1)
+ B = T.match_buffer(b, (4, 4), dtype=in_dtype, offset_factor=1)
+ C = T.match_buffer(c, (4,), dtype=out_dtype, offset_factor=1)
+ with T.block("root"):
+ T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+ T.writes(C[0:4])
+
+ A_i8x4 = A.vload([0], in_dtype_x4)
+ A_i32 = T.reinterpret(A_i8x4, dtype=out_dtype)
+ vec_ai32 = T.broadcast(A_i32, 4)
+ vec_a = T.reinterpret(vec_ai32, dtype=in_dtype_x16)
+
+ vec_b = B.vload([0, 0], dtype=in_dtype_x16)
+
+ vec_c = C.vload([0], dtype=out_dtype_x4)
+
+ C[T.ramp(T.int32(0), 1, 4)] = T.call_llvm_pure_intrin(
+
T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.{INSTR}".format(INSTR=instr)),
+ T.uint32(3),
+ vec_c,
+ vec_a,
+ vec_b,
+ dtype=out_dtype_x4,
+ )
+
+ return dot_prod_desc, dot_prod_impl
ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"
+ARM_DOT_4x4_u8_UDOT_INTRIN = "dot_4x4_u8u8u32_udot"
+ARM_DOT_4x4_u8_HDOT_INTRIN = "dot_4x4_u8u8i32_hdot"
+
+TensorIntrin.register(ARM_DOT_4x4_i8_NEON_INTRIN, neon_4x4_i8i8i32_desc,
neon_4x4_i8i8i32_impl)
+
+TensorIntrin.register(ARM_DOT_4x4_i8_SDOT_INTRIN, *get_dotprod_intrin("int8",
"int32"))
-TensorIntrin.register(
- ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc,
dot_product_4x4_i8i8i32_neon
-)
+TensorIntrin.register(ARM_DOT_4x4_u8_UDOT_INTRIN, *get_dotprod_intrin("uint8",
"uint32"))
-TensorIntrin.register(
- ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc,
dot_product_4x4_i8i8i32_sdot
-)
+TensorIntrin.register(ARM_DOT_4x4_u8_HDOT_INTRIN, *get_dotprod_intrin("uint8",
"int32"))
diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc
b/src/meta_schedule/schedule_rule/schedule_rule.cc
index 49a7c9911c..35f1151c9c 100644
--- a/src/meta_schedule/schedule_rule/schedule_rule.cc
+++ b/src/meta_schedule/schedule_rule/schedule_rule.cc
@@ -295,6 +295,94 @@ Array<ScheduleRule> ScheduleRule::DefaultMicro() {
};
}
+Array<ScheduleRule> GetNeonSpecificRules() {
+ return {
+ ScheduleRule::MultiLevelTilingWithIntrin(
+ /*intrin_name=*/String("dot_4x4_i8i8s32_neon"),
+ /*structure=*/"SSRSRS",
+ /*tile_binds=*/NullOpt,
+ /*max_innermost_factor=*/Integer(32),
+ /*vector_load_lens=*/NullOpt,
+ /*reuse_read=*/NullOpt,
+ /*reuse_write=*/
+ Map<String, ObjectRef>{{"req", String("may")},
+ {"levels", Array<Integer>{1, 2}},
+ {"scope", String("global")}}),
+ };
+}
+
+Array<ScheduleRule> GetDotprodSpecificRules() {
+ return {
+ ScheduleRule::MultiLevelTilingWithIntrin(
+ /*intrin_name=*/String("dot_4x4_i8i8s32_sdot"),
+ /*structure=*/"SSRSRS",
+ /*tile_binds=*/NullOpt,
+ /*max_innermost_factor=*/Integer(32),
+ /*vector_load_lens=*/NullOpt,
+ /*reuse_read=*/NullOpt,
+ /*reuse_write=*/
+ Map<String, ObjectRef>{{"req", String("may")},
+ {"levels", Array<Integer>{1, 2}},
+ {"scope", String("global")}}),
+ ScheduleRule::MultiLevelTilingWithIntrin(
+ /*intrin_name=*/String("dot_4x4_u8u8u32_udot"),
+ /*structure=*/"SSRSRS",
+ /*tile_binds=*/NullOpt,
+ /*max_innermost_factor=*/Integer(32),
+ /*vector_load_lens=*/NullOpt,
+ /*reuse_read=*/NullOpt,
+ /*reuse_write=*/
+ Map<String, ObjectRef>{{"req", String("may")},
+ {"levels", Array<Integer>{1, 2}},
+ {"scope", String("global")}}),
+ ScheduleRule::MultiLevelTilingWithIntrin(
+ /*intrin_name=*/String("dot_4x4_u8u8i32_hdot"),
+ /*structure=*/"SSRSRS",
+ /*tile_binds=*/NullOpt,
+ /*max_innermost_factor=*/Integer(32),
+ /*vector_load_lens=*/NullOpt,
+ /*reuse_read=*/NullOpt,
+ /*reuse_write=*/
+ Map<String, ObjectRef>{{"req", String("may")},
+ {"levels", Array<Integer>{1, 2}},
+ {"scope", String("global")}}),
+ };
+}
+
+Array<ScheduleRule> ScheduleRule::DefaultARM(const String& type) {
+ return Array<ScheduleRule>::Agregate(
+ ScheduleRule::ApplyCustomRule(), ScheduleRule::InlineConstantScalars(),
+ ScheduleRule::AutoInline(
+ /*into_producer=*/false,
+ /*into_consumer=*/true,
+ /*inline_const_tensor=*/true,
+ /*disallow_if_then_else=*/true,
+ /*require_injective=*/true,
+ /*require_ordered=*/true,
+ /*disallow_op=*/Array<String>{"tir.exp"}),
+ ScheduleRule::AddRFactor(
+ /*max_jobs_per_core=*/8,
+ /*max_innermost_factor=*/Integer(32)),
+ "neon" == type ? GetNeonSpecificRules() : Array<ScheduleRule>{},
+ "dotprod" == type ? GetDotprodSpecificRules() : Array<ScheduleRule>{},
+ ScheduleRule::MultiLevelTiling(
+ /*structure=*/"SSRSRS",
+ /*tile_binds=*/NullOpt,
+ /*max_innermost_factor=*/Integer(32),
+ /*vector_load_lens=*/NullOpt,
+ /*reuse_read=*/NullOpt,
+ /*reuse_write=*/
+ Map<String, ObjectRef>{{"req", String("may")},
+ {"levels", Array<Integer>{1, 2}},
+ {"scope", String("global")}}),
+ ScheduleRule::ParallelizeVectorizeUnroll(
+ /*max_jobs_per_core=*/8,
+ /*max_vectorize_extent=*/32,
+ /*unroll_max_steps=*/Array<Integer>{0, 8, 32, 256},
+ /*unroll_explicit=*/true),
+ ScheduleRule::RandomComputeLocation());
+}
+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PyScheduleRuleNode>([](const ObjectRef& n, ReprPrinter* p) {
const auto* self = n.as<PyScheduleRuleNode>();
@@ -325,6 +413,8 @@
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultHexagon")
.set_body_typed(ScheduleRule::DefaultHexagon);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultMicro")
.set_body_typed(ScheduleRule::DefaultMicro);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleDefaultARM")
+ .set_body_typed(ScheduleRule::DefaultARM);
} // namespace meta_schedule
} // namespace tvm
diff --git a/src/meta_schedule/space_generator/space_generator.cc
b/src/meta_schedule/space_generator/space_generator.cc
index a3669e996f..c96554e6a2 100644
--- a/src/meta_schedule/space_generator/space_generator.cc
+++ b/src/meta_schedule/space_generator/space_generator.cc
@@ -16,6 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
+#include "../../target/parsers/aprofile.h"
#include "../utils.h"
namespace tvm {
@@ -38,6 +39,16 @@ String GetRuleKindFromTarget(const Target& target) {
return "avx512";
}
}
+
+ TargetJSON target_json =
target::parsers::aprofile::ParseTarget(target->Export());
+ TargetFeatures afeatures =
Downcast<TargetFeatures>(target_json.at("features"));
+
+ if (Downcast<Bool>(afeatures.at("has_dotprod"))) {
+ return "dotprod";
+ }
+ if (Downcast<Bool>(afeatures.at("has_asimd"))) {
+ return "asimd";
+ }
return "llvm";
}
if (target->kind->name == "hexagon") {
@@ -110,6 +121,14 @@ void SpaceGeneratorNode::InitializeWithTuneContext(const
TuneContext& context) {
default_sch_rules = ScheduleRule::DefaultMicro();
default_postprocs = Postproc::DefaultMicro();
default_mutator_probs = Mutator::DefaultMicro();
+ } else if (kind == "asimd") {
+ default_sch_rules = ScheduleRule::DefaultARM("neon");
+ default_postprocs = Postproc::DefaultCPUTensorization();
+ default_mutator_probs = Mutator::DefaultLLVM();
+ } else if (kind == "dotprod") {
+ default_sch_rules = ScheduleRule::DefaultARM("dotprod");
+ default_postprocs = Postproc::DefaultCPUTensorization();
+ default_mutator_probs = Mutator::DefaultLLVM();
} else {
LOG(FATAL) << "Unsupported kind: " << kind;
throw;
diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py
b/tests/python/unittest/test_meta_schedule_post_order_apply.py
index 716f829653..6c069dc6bf 100644
--- a/tests/python/unittest/test_meta_schedule_post_order_apply.py
+++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -23,6 +23,8 @@ from typing import List
import pytest
import tvm
import tvm.testing
+from tvm import te
+from tvm.ir.module import IRModule
from tvm._ffi import register_func
from tvm.error import TVMError
from tvm.meta_schedule import TuneContext
@@ -36,6 +38,23 @@ from tvm.tir.schedule import BlockRV, Schedule
# pylint:
disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,
# fmt: off
+
+def get_matmul_packed(m, n, k, lhs_type="int8", rhs_dtype="int8",
acc_dtype="int32"):
+ X = te.placeholder((m, k), name="X", dtype=lhs_type)
+ W = te.placeholder((n, k), name="W", dtype=rhs_dtype)
+
+ ak = te.reduce_axis((0, k), name="k")
+ matmul = te.compute(
+ (m, n),
+ lambda i, j: te.sum(
+ X[i, ak].astype(acc_dtype) * W[j, ak].astype(acc_dtype),
+ axis=ak,
+ ),
+ name="compute",
+ )
+ return te.create_prim_func([X, W, matmul])
+
+
@tvm.script.ir_module
class Matmul:
@T.prim_func
@@ -404,6 +423,60 @@ def test_target_blocks_search_space():
assert len(schs) == 8
[email protected](
+ "target,mod,expected_intr",
+ [
+ (
+ Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu
-mattr=+neon -num-cores 2"),
+ IRModule({"main": get_matmul_packed(128, 128, 128, "int8", "int8",
"int32")}),
+ "dot_4x4_i8i8s32_neon",
+ ),
+ (
+ Target(
+ "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu
-mattr=+neon,+v8.2a,+dotprod -num-cores 2"
+ ),
+ IRModule({"main": get_matmul_packed(128, 128, 128, "int8", "int8",
"int32")}),
+ "dot_4x4_i8i8s32_sdot",
+ ),
+ (
+ Target(
+ "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu
-mattr=+neon,+v8.2a,+dotprod -num-cores 2"
+ ),
+ IRModule({"main": get_matmul_packed(128, 128, 128, "uint8",
"uint8", "uint32")}),
+ "dot_4x4_u8u8u32_udot",
+ ),
+ (
+ Target(
+ "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu
-mattr=+neon,+v8.2a,+dotprod -num-cores 2"
+ ),
+ IRModule({"main": get_matmul_packed(128, 128, 128, "uint8",
"uint8", "int32")}),
+ "dot_4x4_u8u8i32_hdot",
+ ),
+ ],
+)
+def test_meta_schedule_post_order_apply_arm_intrin(target, mod, expected_intr):
+ context = TuneContext(
+ mod=mod,
+ target=target,
+ task_name="Arm Intrinsic Task",
+ space_generator=PostOrderApply(), # Triggers default generator
+ rand_state=1, # Change it while all tests are not passing
+ )
+ post_order_apply = context.space_generator
+ schs = post_order_apply.generate_design_space(mod)
+
+ assert len(schs) != 0
+
+ for sch in schs:
+ sch.enter_postproc()
+
+ for proc in context.space_generator.postprocs:
+ proc.apply(sch)
+
+ assert any(["call_llvm_pure_intrin" in sch.mod.script() for sch in schs])
+ assert any([expected_intr in str(sch.trace) for sch in schs])
+
+
def test_meta_schedule_derived_object():
@derived_object
class RemoveBlock(PyScheduleRule):