junrushao1994 commented on a change in pull request #7809:
URL: https://github.com/apache/tvm/pull/7809#discussion_r616979231
##########
File path: src/tir/transforms/lower_intrin.cc
##########
@@ -42,28 +42,36 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string
mtriple = "")
: IRMutatorWithAnalyzer(analyzer) {
- patterns_.push_back("tvm.intrin.rule." + target + ".");
+ patterns_.push_back(target + ".FLowerIntrinsic");
bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
if (is_llvm_aarch64) {
- patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64.");
+ patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
}
- patterns_.push_back("tvm.intrin.rule.default.");
- fma_ = runtime::Registry::Get(patterns_[0] + "fma");
+ patterns_.push_back("default.FLowerIntrinsic");
+ fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
if (target == "stackvm") {
support_bitwise_op_ = false;
}
}
PrimExpr VisitExpr_(const CallNode* op) final {
if (auto* ptr_op = op->op.as<OpNode>()) {
- // Still use legacy string based rewriting
- // TODO(tvm-team): migrate the pattern application from global function
look up
- // to an OpAttrMap<PackedFunc>
- std::string name = ptr_op->name;
- PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
- if (r.defined()) return r;
+ for (size_t i = 0; i < patterns_.size(); ++i)
Review comment:
```suggestion
for (const std::string& pattern : patterns_)
```
##########
File path: src/tir/transforms/lower_intrin.cc
##########
@@ -42,28 +42,36 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string
mtriple = "")
: IRMutatorWithAnalyzer(analyzer) {
- patterns_.push_back("tvm.intrin.rule." + target + ".");
+ patterns_.push_back(target + ".FLowerIntrinsic");
bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
if (is_llvm_aarch64) {
- patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64.");
+ patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
}
- patterns_.push_back("tvm.intrin.rule.default.");
- fma_ = runtime::Registry::Get(patterns_[0] + "fma");
+ patterns_.push_back("default.FLowerIntrinsic");
+ fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
if (target == "stackvm") {
support_bitwise_op_ = false;
}
}
PrimExpr VisitExpr_(const CallNode* op) final {
if (auto* ptr_op = op->op.as<OpNode>()) {
- // Still use legacy string based rewriting
- // TODO(tvm-team): migrate the pattern application from global function
look up
- // to an OpAttrMap<PackedFunc>
- std::string name = ptr_op->name;
- PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
- if (r.defined()) return r;
+ for (size_t i = 0; i < patterns_.size(); ++i)
+ if (Op::HasAttrMap(patterns_[i])) {
+ auto default_intrin = Op::GetAttrMap<FLowerIntrinsic>(patterns_[i]);
Review comment:
This is not necessarily the default intrinsic map, so let's use a more accurate name
here, e.g. `f_lower_intrin_map`.
##########
File path: src/tir/transforms/lower_intrin.cc
##########
@@ -42,28 +42,36 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string
mtriple = "")
: IRMutatorWithAnalyzer(analyzer) {
- patterns_.push_back("tvm.intrin.rule." + target + ".");
+ patterns_.push_back(target + ".FLowerIntrinsic");
bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
if (is_llvm_aarch64) {
- patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64.");
+ patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
}
- patterns_.push_back("tvm.intrin.rule.default.");
- fma_ = runtime::Registry::Get(patterns_[0] + "fma");
+ patterns_.push_back("default.FLowerIntrinsic");
+ fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
if (target == "stackvm") {
support_bitwise_op_ = false;
}
}
PrimExpr VisitExpr_(const CallNode* op) final {
if (auto* ptr_op = op->op.as<OpNode>()) {
- // Still use legacy string based rewriting
- // TODO(tvm-team): migrate the pattern application from global function
look up
- // to an OpAttrMap<PackedFunc>
- std::string name = ptr_op->name;
- PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
- if (r.defined()) return r;
+ for (size_t i = 0; i < patterns_.size(); ++i)
+ if (Op::HasAttrMap(patterns_[i])) {
+ auto default_intrin = Op::GetAttrMap<FLowerIntrinsic>(patterns_[i]);
+ FLowerIntrinsic f = default_intrin.get(GetRef<Op>(ptr_op), nullptr);
+ PrimExpr e = GetRef<PrimExpr>(op);
+ if (f != nullptr) {
+ PrimExpr r = f(e);
Review comment:
```suggestion
if (f != nullptr) {
PrimExpr e = GetRef<PrimExpr>(op);
PrimExpr r = f(e);
```
##########
File path: src/tir/transforms/lower_intrin.cc
##########
@@ -42,28 +42,36 @@ class IntrinInjecter : public
tvm::arith::IRMutatorWithAnalyzer {
IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string
mtriple = "")
: IRMutatorWithAnalyzer(analyzer) {
- patterns_.push_back("tvm.intrin.rule." + target + ".");
+ patterns_.push_back(target + ".FLowerIntrinsic");
bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
if (is_llvm_aarch64) {
- patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64.");
+ patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
}
- patterns_.push_back("tvm.intrin.rule.default.");
- fma_ = runtime::Registry::Get(patterns_[0] + "fma");
+ patterns_.push_back("default.FLowerIntrinsic");
+ fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
if (target == "stackvm") {
support_bitwise_op_ = false;
}
}
PrimExpr VisitExpr_(const CallNode* op) final {
if (auto* ptr_op = op->op.as<OpNode>()) {
- // Still use legacy string based rewriting
- // TODO(tvm-team): migrate the pattern application from global function
look up
- // to an OpAttrMap<PackedFunc>
- std::string name = ptr_op->name;
- PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
- if (r.defined()) return r;
+ for (size_t i = 0; i < patterns_.size(); ++i)
+ if (Op::HasAttrMap(patterns_[i])) {
+ auto default_intrin = Op::GetAttrMap<FLowerIntrinsic>(patterns_[i]);
+ FLowerIntrinsic f = default_intrin.get(GetRef<Op>(ptr_op), nullptr);
+ PrimExpr e = GetRef<PrimExpr>(op);
+ if (f != nullptr) {
+ PrimExpr r = f(e);
+ ICHECK(r.defined()) << "intrinsic rule must always return valid
Expr";
+ if (!r.same_as(e)) {
+ r = this->VisitExpr(r);
+ if (r.defined()) return r;
Review comment:
```suggestion
if (r.defined()) {
return r;
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]