This is an automated email from the ASF dual-hosted git repository.

tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 9896c6a39464fcefadfa9b7e99801de89b830db7
Author: Tristan Konolige <[email protected]>
AuthorDate: Mon May 15 23:14:36 2023 +0000

    fix tuning, comparisons
---
 python/tvm/relax/frontend/torch/fx_translator.py | 19 +++++++++++++++++--
 python/tvm/relax/utils.py                        |  2 +-
 src/relax/ir/emit_te.cc                          |  2 +-
 src/tir/analysis/estimate_flops.cc               | 15 ++++++++++++---
 4 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index bdf2c3375f..6222b11432 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -376,6 +376,18 @@ class TorchFXImporter:
         lhs, rhs = self.retrieve_args(node)
         return self._call_binary_op(relax.op.less, lhs, rhs)
 
+    def _le(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        return self._call_binary_op(relax.op.less_equal, lhs, rhs)
+
+    def _gt(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        return self._call_binary_op(relax.op.greater, lhs, rhs)
+
+    def _ge(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        return self._call_binary_op(relax.op.greater_equal, lhs, rhs)
+
     def _eq(self, node: fx.node.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         return self._call_binary_op(relax.op.equal, lhs, rhs)
@@ -1154,8 +1166,8 @@ class TorchFXImporter:
 
         from torch import fx
         if any([isinstance(s, fx.node.Node) for s in size]):
-            size = relax.op.concat([self.env[s] for s in size])
-            # import IPython; IPython.embed()
+            size = relax.op.tensor_to_shape(relax.op.concat([self.env[s] for s in size]))
+            print(type(size))
 
         if method.startswith("nearest"):
             method = "nearest_neighbor"
@@ -1426,6 +1438,9 @@ class TorchFXImporter:
             "sqrt": self._sqrt,
             "round": self._round,
             "lt": self._lt,
+            "le": self._le,
+            "gt": self._gt,
+            "ge": self._ge,
             "eq": self._eq,
             "ne": self._ne,
             "truediv": self._truediv,
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 02dd941080..8593af7b16 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -368,7 +368,7 @@ def gen_call_tir_inputs(
                 if isinstance(arg.struct_info, ShapeStructInfo):
                     assert isinstance(
                         arg, ShapeExpr
-                    ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
+                    ), f"For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr (got {type(arg)}). You may need to run tvm.relax.transform.DecomposeOpsForInference first."
                     return [_convert_te_arg_helper(val) for val in arg.values]
                 if isinstance(arg.struct_info, PrimStructInfo):
                     return arg.value
diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc
index bfb5896c99..23b840b8ca 100644
--- a/src/relax/ir/emit_te.cc
+++ b/src/relax/ir/emit_te.cc
@@ -60,7 +60,7 @@ te::Tensor TETensor(Expr value, Map<tir::Var, PrimExpr> tir_var_map, std::string
   }
   ICHECK(value->struct_info_.defined()) << "value must be normalized and contain StructInfo";
   auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(value);
-  ICHECK(tensor_sinfo) << "Value must be a tensor";
+  ICHECK(tensor_sinfo) << "Value must be a tensor but it is a " << value->GetTypeKey();
   auto* shape_expr = tensor_sinfo->shape.as<ShapeExprNode>();
   CHECK(shape_expr)
       << "ValueError: Expression does not have an known symbolic shape, please consider use "
diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc
index 1c427e5fd9..59274f9c45 100644
--- a/src/tir/analysis/estimate_flops.cc
+++ b/src/tir/analysis/estimate_flops.cc
@@ -69,6 +69,13 @@ struct TResult {
     return *this;
   }
 
+  TResult SetInvalid() {
+    for (auto& kv : data_) {
+      kv.second = -1;
+    }
+    return *this;
+  }
+
   TResult MaxWith(const TResult& rhs) {
     for (const auto& kv : rhs.data_) {
       double& v = data_[kv.first];
@@ -140,9 +147,11 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,
   TResult VisitStmt_(const ForNode* loop) override {
     TResult result = VisitStmt(loop->body);
     const auto* int_imm = loop->extent.as<IntImmNode>();
-    ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: "
-                    << loop->extent->GetTypeKey();
-    result *= int_imm->value;
+    if(int_imm == nullptr) {
+      result.SetInvalid();
+    } else {
+      result *= int_imm->value;
+    }
     return result;
   }
 

Reply via email to