This is an automated email from the ASF dual-hosted git repository. tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 9896c6a39464fcefadfa9b7e99801de89b830db7 Author: Tristan Konolige <[email protected]> AuthorDate: Mon May 15 23:14:36 2023 +0000 fix tuning, comparisons --- python/tvm/relax/frontend/torch/fx_translator.py | 19 +++++++++++++++++-- python/tvm/relax/utils.py | 2 +- src/relax/ir/emit_te.cc | 2 +- src/tir/analysis/estimate_flops.cc | 15 ++++++++++++--- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index bdf2c3375f..6222b11432 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -376,6 +376,18 @@ class TorchFXImporter: lhs, rhs = self.retrieve_args(node) return self._call_binary_op(relax.op.less, lhs, rhs) + def _le(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.less_equal, lhs, rhs) + + def _gt(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.greater, lhs, rhs) + + def _ge(self, node: fx.node.Node) -> relax.Expr: + lhs, rhs = self.retrieve_args(node) + return self._call_binary_op(relax.op.greater_equal, lhs, rhs) + def _eq(self, node: fx.node.Node) -> relax.Expr: lhs, rhs = self.retrieve_args(node) return self._call_binary_op(relax.op.equal, lhs, rhs) @@ -1154,8 +1166,8 @@ class TorchFXImporter: from torch import fx if any([isinstance(s, fx.node.Node) for s in size]): - size = relax.op.concat([self.env[s] for s in size]) - # import IPython; IPython.embed() + size = relax.op.tensor_to_shape(relax.op.concat([self.env[s] for s in size])) + print(type(size)) if method.startswith("nearest"): method = "nearest_neighbor" @@ -1426,6 +1438,9 @@ class TorchFXImporter: "sqrt": self._sqrt, "round": self._round, "lt": self._lt, + "le": self._le, + "gt": self._gt, + "ge": self._ge, "eq": self._eq, "ne": self._ne, "truediv": self._truediv, diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 02dd941080..8593af7b16 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -368,7 +368,7 @@ def gen_call_tir_inputs( if isinstance(arg.struct_info, ShapeStructInfo): assert isinstance( arg, ShapeExpr - ), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr" + ), f"For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr (got {type(arg)}). You may need to run tvm.relax.transform.DecomposeOpsForInference first." return [_convert_te_arg_helper(val) for val in arg.values] if isinstance(arg.struct_info, PrimStructInfo): return arg.value diff --git a/src/relax/ir/emit_te.cc b/src/relax/ir/emit_te.cc index bfb5896c99..23b840b8ca 100644 --- a/src/relax/ir/emit_te.cc +++ b/src/relax/ir/emit_te.cc @@ -60,7 +60,7 @@ te::Tensor TETensor(Expr value, Map<tir::Var, PrimExpr> tir_var_map, std::string } ICHECK(value->struct_info_.defined()) << "value must be normalized and contain StructInfo"; auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(value); - ICHECK(tensor_sinfo) << "Value must be a tensor"; + ICHECK(tensor_sinfo) << "Value must be a tensor but it is a " << value->GetTypeKey(); auto* shape_expr = tensor_sinfo->shape.as<ShapeExprNode>(); CHECK(shape_expr) << "ValueError: Expression does not have an known symbolic shape, please consider use " diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index 1c427e5fd9..59274f9c45 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -69,6 +69,13 @@ struct TResult { return *this; } + TResult SetInvalid() { + for (auto& kv : data_) { + kv.second = -1; + } + return *this; + } + TResult MaxWith(const TResult& rhs) { for (const auto& kv : rhs.data_) { double& v = data_[kv.first]; @@ -140,9 +147,11 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>, TResult VisitStmt_(const ForNode* loop) override { TResult result = VisitStmt(loop->body); const auto* int_imm = loop->extent.as<IntImmNode>(); - ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: " - << loop->extent->GetTypeKey(); - result *= int_imm->value; + if(int_imm == nullptr) { + result.SetInvalid(); + } else { + result *= int_imm->value; + } return result; }
