junrushao1994 commented on a change in pull request #7809:
URL: https://github.com/apache/tvm/pull/7809#discussion_r616979231



##########
File path: src/tir/transforms/lower_intrin.cc
##########
@@ -42,28 +42,36 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
 
   IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string 
mtriple = "")
       : IRMutatorWithAnalyzer(analyzer) {
-    patterns_.push_back("tvm.intrin.rule." + target + ".");
+    patterns_.push_back(target + ".FLowerIntrinsic");
 
     bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
     if (is_llvm_aarch64) {
-      patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64.");
+      patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
     }
 
-    patterns_.push_back("tvm.intrin.rule.default.");
-    fma_ = runtime::Registry::Get(patterns_[0] + "fma");
+    patterns_.push_back("default.FLowerIntrinsic");
+    fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
     if (target == "stackvm") {
       support_bitwise_op_ = false;
     }
   }
 
   PrimExpr VisitExpr_(const CallNode* op) final {
     if (auto* ptr_op = op->op.as<OpNode>()) {
-      // Still use legacy string based rewriting
-      // TODO(tvm-team): migrate the pattern application from global function 
look up
-      // to an OpAttrMap<PackedFunc>
-      std::string name = ptr_op->name;
-      PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
-      if (r.defined()) return r;
+      for (size_t i = 0; i < patterns_.size(); ++i)

Review comment:
       ```suggestion
         for (const std::string& pattern : patterns_)
   ```

##########
File path: src/tir/transforms/lower_intrin.cc
##########
@@ -42,28 +42,36 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
 
   IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string 
mtriple = "")
       : IRMutatorWithAnalyzer(analyzer) {
-    patterns_.push_back("tvm.intrin.rule." + target + ".");
+    patterns_.push_back(target + ".FLowerIntrinsic");
 
     bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
     if (is_llvm_aarch64) {
-      patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64.");
+      patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
     }
 
-    patterns_.push_back("tvm.intrin.rule.default.");
-    fma_ = runtime::Registry::Get(patterns_[0] + "fma");
+    patterns_.push_back("default.FLowerIntrinsic");
+    fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
     if (target == "stackvm") {
       support_bitwise_op_ = false;
     }
   }
 
   PrimExpr VisitExpr_(const CallNode* op) final {
     if (auto* ptr_op = op->op.as<OpNode>()) {
-      // Still use legacy string based rewriting
-      // TODO(tvm-team): migrate the pattern application from global function 
look up
-      // to an OpAttrMap<PackedFunc>
-      std::string name = ptr_op->name;
-      PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
-      if (r.defined()) return r;
+      for (size_t i = 0; i < patterns_.size(); ++i)
+        if (Op::HasAttrMap(patterns_[i])) {
+          auto default_intrin = Op::GetAttrMap<FLowerIntrinsic>(patterns_[i]);

Review comment:
      It is not necessarily the default intrinsic map, so let's pick a more accurate name here, e.g. `f_lower_intrin_map`?

##########
File path: src/tir/transforms/lower_intrin.cc
##########
@@ -42,28 +42,36 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
 
   IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string 
mtriple = "")
       : IRMutatorWithAnalyzer(analyzer) {
-    patterns_.push_back("tvm.intrin.rule." + target + ".");
+    patterns_.push_back(target + ".FLowerIntrinsic");
 
     bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
     if (is_llvm_aarch64) {
-      patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64.");
+      patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
     }
 
-    patterns_.push_back("tvm.intrin.rule.default.");
-    fma_ = runtime::Registry::Get(patterns_[0] + "fma");
+    patterns_.push_back("default.FLowerIntrinsic");
+    fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
     if (target == "stackvm") {
       support_bitwise_op_ = false;
     }
   }
 
   PrimExpr VisitExpr_(const CallNode* op) final {
     if (auto* ptr_op = op->op.as<OpNode>()) {
-      // Still use legacy string based rewriting
-      // TODO(tvm-team): migrate the pattern application from global function 
look up
-      // to an OpAttrMap<PackedFunc>
-      std::string name = ptr_op->name;
-      PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
-      if (r.defined()) return r;
+      for (size_t i = 0; i < patterns_.size(); ++i)
+        if (Op::HasAttrMap(patterns_[i])) {
+          auto default_intrin = Op::GetAttrMap<FLowerIntrinsic>(patterns_[i]);
+          FLowerIntrinsic f = default_intrin.get(GetRef<Op>(ptr_op), nullptr);
+          PrimExpr e = GetRef<PrimExpr>(op);
+          if (f != nullptr) {
+            PrimExpr r = f(e);

Review comment:
       ```suggestion
             if (f != nullptr) {
               PrimExpr e = GetRef<PrimExpr>(op);
               PrimExpr r = f(e);
   ```

##########
File path: src/tir/transforms/lower_intrin.cc
##########
@@ -42,28 +42,36 @@ class IntrinInjecter : public 
tvm::arith::IRMutatorWithAnalyzer {
 
   IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string 
mtriple = "")
       : IRMutatorWithAnalyzer(analyzer) {
-    patterns_.push_back("tvm.intrin.rule." + target + ".");
+    patterns_.push_back(target + ".FLowerIntrinsic");
 
     bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos);
     if (is_llvm_aarch64) {
-      patterns_.push_back("tvm.intrin.rule." + target + "." + "aarch64.");
+      patterns_.push_back(target + ".aarch64.FLowerIntrinsic");
     }
 
-    patterns_.push_back("tvm.intrin.rule.default.");
-    fma_ = runtime::Registry::Get(patterns_[0] + "fma");
+    patterns_.push_back("default.FLowerIntrinsic");
+    fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma");
     if (target == "stackvm") {
       support_bitwise_op_ = false;
     }
   }
 
   PrimExpr VisitExpr_(const CallNode* op) final {
     if (auto* ptr_op = op->op.as<OpNode>()) {
-      // Still use legacy string based rewriting
-      // TODO(tvm-team): migrate the pattern application from global function 
look up
-      // to an OpAttrMap<PackedFunc>
-      std::string name = ptr_op->name;
-      PrimExpr r = ApplyPattern(name, GetRef<PrimExpr>(op));
-      if (r.defined()) return r;
+      for (size_t i = 0; i < patterns_.size(); ++i)
+        if (Op::HasAttrMap(patterns_[i])) {
+          auto default_intrin = Op::GetAttrMap<FLowerIntrinsic>(patterns_[i]);
+          FLowerIntrinsic f = default_intrin.get(GetRef<Op>(ptr_op), nullptr);
+          PrimExpr e = GetRef<PrimExpr>(op);
+          if (f != nullptr) {
+            PrimExpr r = f(e);
+            ICHECK(r.defined()) << "intrinsic rule must always return valid 
Expr";
+            if (!r.same_as(e)) {
+              r = this->VisitExpr(r);
+              if (r.defined()) return r;

Review comment:
       ```suggestion
                 if (r.defined()) {
                   return r;
                 }
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to