cbalint13 commented on code in PR #18182:
URL: https://github.com/apache/tvm/pull/18182#discussion_r2289576172
##########
src/meta_schedule/schedule_rule/schedule_rule.cc:
##########
@@ -304,6 +304,122 @@ Array<ScheduleRule> ScheduleRule::DefaultHexagon() {
};
}
+int GetVLMAX(int vlen, int lmul, int max_sew) { return (lmul * vlen) /
max_sew; }
+
+Array<ScheduleRule> ScheduleRule::DefaultRISCV(int vlen) {
+ Array<ScheduleRule> rules;
+
+ rules.push_back(ScheduleRule::ApplyCustomRule());
+
+ rules.push_back(ScheduleRule::InlineConstantScalars());
+
+ rules.push_back(ScheduleRule::AutoInline(
+ /*into_producer=*/false,
+ /*into_consumer=*/true,
+ /*inline_const_tensor=*/true,
+ /*disallow_if_then_else=*/true,
+ /*require_injective=*/true,
+ /*require_ordered=*/true,
+ /*disallow_op=*/Array<String>{"tir.exp"}));
+
+ rules.push_back(ScheduleRule::AddRFactor(
+ /*max_jobs_per_core=*/16,
+ /*max_innermost_factor=*/Integer(64)));
+
+ int vlmax = 0;
+ int RISCV_MIN_VL = 4;
+ std::vector<std::string> vmul_types = {"multivmul", "vmul", "vmacc"};
+ String intrin_name = "";
+ int j = 1;
+
+ for (const std::string& vmul_type : vmul_types) {
+ if (vmul_type == "multivmul")
+ j = GetVLMAX(vlen, 1, 32);
+ else
+ j = 1;
+
+ // Registering for int16
+ vlmax = GetVLMAX(vlen, 8, 32);
+ while (vlmax >= RISCV_MIN_VL) {
+ intrin_name =
+ "rvv_int16_" + vmul_type + "_" + std::to_string(j) + "_" +
std::to_string(vlmax) + "_m8";
+ rules.push_back(ScheduleRule::MultiLevelTilingWithIntrin(
+ /*intrin_name=*/intrin_name,
+ /*structure=*/"SSRSRS",
+ /*tile_binds=*/std::nullopt,
+ /*max_innermost_factor=*/Integer(vlmax),
+ /*vector_load_lens=*/std::nullopt,
+ /*reuse_read=*/std::nullopt,
+ /*reuse_write=*/
+ Map<String, ffi::Any>{{"req", String("may")},
+ {"levels", Array<Integer>{1, 2}},
+ {"scope", String("global")}}));
+ vlmax /= 2;
+ }
+
+ // Registering for float16
+ vlmax = GetVLMAX(vlen, 8, 16);
Review Comment:
I think this should be ```GetVLMAX(vlen, 8, 32)```; otherwise we get the error:
```ValueError: TensorIntrin 'rvv_float16_multivmul_8_128_m8' is not
registered```
The value ```16``` is inconsistent with what is declared in
```tir/tensor_intrin/riscv_cpu.py```, so either this call or that declaration needs to change.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]