This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5a5ccd8df1 [REFACTOR] Phase out LegacyReprPrinter and improve
CommonSubExprElim (#18080)
5a5ccd8df1 is described below
commit 5a5ccd8df13f47fed5d1d8281e887e4adb54c05c
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Jun 19 09:52:36 2025 -0400
[REFACTOR] Phase out LegacyReprPrinter and improve CommonSubExprElim
(#18080)
This PR phases out LegacyReprPrinter. Previously, common subexpr elim
relied on sorting by legacy repr for determinism, which is hacky.
This PR introduces an ordered_map impl in support to ensure determinism
and migrates the CSE pass to use that instead.
---
include/tvm/node/repr_printer.h | 32 -
src/ir/analysis.cc | 2 +-
src/node/repr_printer.cc | 35 -
src/node/script_printer.cc | 5 +-
src/relax/analysis/computable_at_compile_time.cc | 2 +-
src/relax/analysis/udchain.cc | 4 +-
src/relax/ir/binding_rewrite.cc | 3 +-
src/relax/transform/inline_functions.cc | 2 +-
src/relax/transform/run_codegen.cc | 2 +-
src/script/printer/legacy_repr.cc | 894 ---------------------
src/script/printer/utils.h | 14 +-
src/support/ordered_map.h | 145 ++++
src/support/ordered_set.h | 57 +-
src/tir/transforms/common_subexpr_elim.cc | 42 +-
src/tir/transforms/common_subexpr_elim.h | 3 +-
src/tir/transforms/common_subexpr_elim_tools.cc | 28 +-
src/tir/transforms/common_subexpr_elim_tools.h | 8 +-
.../test_tir_transform_common_subexpr_elim.py | 58 +-
.../test_tir_transform_inject_ptx_async_copy.py | 20 +-
.../test_tir_transform_lower_tvm_builtin.py | 4 +-
tests/python/tvmscript/test_tvmscript_roundtrip.py | 10 +-
21 files changed, 249 insertions(+), 1121 deletions(-)
diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h
index 30bfe8e951..e3baf397f2 100644
--- a/include/tvm/node/repr_printer.h
+++ b/include/tvm/node/repr_printer.h
@@ -52,32 +52,6 @@ class ReprPrinter {
TVM_DLL static FType& vtable();
};
-/*! \brief Legacy behavior of ReprPrinter. */
-class ReprLegacyPrinter {
- public:
- /*! \brief The indentation level. */
- int indent{0};
-
- explicit ReprLegacyPrinter(std::ostream& stream) // NOLINT(*)
- : stream(stream) {}
-
- /*! \brief The node to be printed. */
- TVM_DLL void Print(const ObjectRef& node);
- /*! \brief Print indent to the stream */
- TVM_DLL void PrintIndent();
- /*! \brief Could the LegacyPrinter dispatch the node */
- TVM_DLL static bool CanDispatch(const ObjectRef& node);
- /*! \brief Return the ostream it maintains */
- TVM_DLL std::ostream& Stream() const;
- // Allow registration to be printer.
- using FType = NodeFunctor<void(const ObjectRef&, ReprLegacyPrinter*)>;
- TVM_DLL static FType& vtable();
-
- private:
- /*! \brief The output stream */
- std::ostream& stream;
-};
-
/*!
* \brief Dump the node to stderr, used for debug purposes.
* \param node The input node
@@ -113,12 +87,6 @@ inline std::ostream& operator<<(std::ostream& os, const
Variant<V...>& n) { //
return os;
}
-inline std::string AsLegacyRepr(const ObjectRef& n) {
- std::ostringstream os;
- ReprLegacyPrinter(os).Print(n);
- return os.str();
-}
} // namespace ffi
-using ffi::AsLegacyRepr;
} // namespace tvm
#endif // TVM_NODE_REPR_PRINTER_H_
diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc
index 3a54085c22..26a348bcee 100644
--- a/src/ir/analysis.cc
+++ b/src/ir/analysis.cc
@@ -31,7 +31,7 @@ namespace ir {
Map<GlobalVar, Array<GlobalVar>> CollectCallMap(const IRModule& mod) {
struct CalleeCollectorImpl : CalleeCollector {
void Mark(GlobalVar gvar) override { gvars.push_back(gvar); }
- support::OrderedSet<GlobalVar> gvars;
+ support::OrderedSet<GlobalVar, ObjectPtrHash, ObjectPtrEqual> gvars;
};
Map<GlobalVar, Array<GlobalVar>> call_map;
diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc
index aa999655c0..69cb05c121 100644
--- a/src/node/repr_printer.cc
+++ b/src/node/repr_printer.cc
@@ -97,38 +97,6 @@ ReprPrinter::FType& ReprPrinter::vtable() {
return inst;
}
-void ReprLegacyPrinter::Print(const ObjectRef& node) {
- static const FType& f = vtable();
- if (!node.defined()) {
- stream << "(nullptr)";
- } else if (f.can_dispatch(node)) {
- f(node, this);
- } else {
- try {
- stream << node; // Use ReprPrinter
- } catch (const tvm::Error& e) {
- LOG(WARNING) << "ReprPrinter fails";
- stream << node->GetTypeKey() << '(' << node.get() << ')';
- }
- }
-}
-
-bool ReprLegacyPrinter::CanDispatch(const ObjectRef& node) {
- static const FType& f = vtable();
- return !node.defined() || f.can_dispatch(node);
-}
-
-void ReprLegacyPrinter::PrintIndent() {
- for (int i = 0; i < indent; ++i) {
- stream << ' ';
- }
-}
-
-ReprLegacyPrinter::FType& ReprLegacyPrinter::vtable() {
- static FType inst;
- return inst;
-}
-
void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; }
void Dump(const runtime::Object* n) {
Dump(runtime::GetRef<runtime::ObjectRef>(n)); }
@@ -138,7 +106,4 @@
TVM_FFI_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](ffi::Any obj) {
os << obj;
return os.str();
});
-
-TVM_FFI_REGISTER_GLOBAL("node.AsLegacyRepr").set_body_typed(ffi::AsLegacyRepr);
-
} // namespace tvm
diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc
index ee7880f448..c815435796 100644
--- a/src/node/script_printer.cc
+++ b/src/node/script_printer.cc
@@ -32,7 +32,10 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() {
std::string TVMScriptPrinter::Script(const ObjectRef& node, const
Optional<PrinterConfig>& cfg) {
if (!TVMScriptPrinter::vtable().can_dispatch(node)) {
- return AsLegacyRepr(node);
+ std::ostringstream os;
+ ReprPrinter printer(os);
+ printer.Print(node);
+ return os.str();
}
return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig()));
}
diff --git a/src/relax/analysis/computable_at_compile_time.cc
b/src/relax/analysis/computable_at_compile_time.cc
index ba163b51d6..5825895db7 100644
--- a/src/relax/analysis/computable_at_compile_time.cc
+++ b/src/relax/analysis/computable_at_compile_time.cc
@@ -83,7 +83,7 @@ class CompileTimeCollector : ExprVisitor {
}
}
- support::OrderedSet<Var> known_relax_vars_;
+ support::OrderedSet<Var, ObjectPtrHash, ObjectPtrEqual> known_relax_vars_;
std::unordered_set<tir::Var> known_tir_vars_;
};
} // namespace
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index f62254b695..2f04d86594 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -56,8 +56,8 @@ class UDChain : relax::ExprVisitor {
private:
Map<Var, Expr> bound_values;
std::unordered_set<Var> forward_declarations;
- std::unordered_map<Var, support::OrderedSet<Var>> usage_map;
- support::OrderedSet<Var> outputs;
+ std::unordered_map<Var, support::OrderedSet<Var, ObjectPtrHash,
ObjectPtrEqual>> usage_map;
+ support::OrderedSet<Var, ObjectPtrHash, ObjectPtrEqual> outputs;
Optional<Var> cur_user_;
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index f35b443b5b..11a0fd29a9 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -321,7 +321,8 @@ Expr RemoveAllUnused(Expr expr) {
auto var_usage = CollectVarUsage(expr);
// For the purpose of
- support::OrderedSet<Var> externally_exposed(var_usage.outputs.begin(),
var_usage.outputs.end());
+ support::OrderedSet<Var, ObjectPtrHash, ObjectPtrEqual> externally_exposed(
+ var_usage.outputs.begin(), var_usage.outputs.end());
for (const auto& [var, expr] : var_usage.bound_values) {
if (ContainsImpureCall(expr)) {
externally_exposed.insert(var);
diff --git a/src/relax/transform/inline_functions.cc
b/src/relax/transform/inline_functions.cc
index 26b106373f..e295226e9e 100644
--- a/src/relax/transform/inline_functions.cc
+++ b/src/relax/transform/inline_functions.cc
@@ -138,7 +138,7 @@ class FunctionInliner : public ExprMutator {
}
const Map<Variant<String, GlobalVar>, Function>& replacements_;
- support::OrderedSet<GlobalVar> inline_stack_;
+ std::unordered_set<GlobalVar, ObjectPtrHash, ObjectPtrEqual> inline_stack_;
};
} // namespace
diff --git a/src/relax/transform/run_codegen.cc
b/src/relax/transform/run_codegen.cc
index d29bdaacb9..33d3f485a5 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -44,7 +44,7 @@ class CodeGenRunner : ExprMutator {
Array<String> entry_function_names) {
IRModule mod = builder_->GetContextIRModule();
- support::OrderedSet<GlobalVar> entry_functions;
+ support::OrderedSet<GlobalVar, ObjectPtrHash, ObjectPtrEqual>
entry_functions;
// Any user-provided functions are treated as entry functions.
for (const auto& name : entry_function_names) {
entry_functions.insert(mod->GetGlobalVar(name));
diff --git a/src/script/printer/legacy_repr.cc
b/src/script/printer/legacy_repr.cc
deleted file mode 100644
index 57dd691b88..0000000000
--- a/src/script/printer/legacy_repr.cc
+++ /dev/null
@@ -1,894 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include <tvm/ir/module.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/op.h>
-#include <tvm/tir/stmt.h>
-#include <tvm/tir/stmt_functor.h>
-
-#include <sstream>
-
-#include "../../support/str_escape.h"
-
-namespace tvm {
-
-#define TVM_LEGACY_REPR_PRINTER_DEF_OP(Type) \
- ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, Type value) { \
- p.Stream() << value; \
- return p; \
- }
-
-TVM_LEGACY_REPR_PRINTER_DEF_OP(int);
-TVM_LEGACY_REPR_PRINTER_DEF_OP(int64_t);
-TVM_LEGACY_REPR_PRINTER_DEF_OP(float);
-TVM_LEGACY_REPR_PRINTER_DEF_OP(double);
-TVM_LEGACY_REPR_PRINTER_DEF_OP(char);
-TVM_LEGACY_REPR_PRINTER_DEF_OP(const char*);
-TVM_LEGACY_REPR_PRINTER_DEF_OP(const std::string&);
-TVM_LEGACY_REPR_PRINTER_DEF_OP(runtime::DataType);
-TVM_LEGACY_REPR_PRINTER_DEF_OP(const void*);
-TVM_LEGACY_REPR_PRINTER_DEF_OP(const String&);
-
-std::ostream& ReprLegacyPrinter::Stream() const { return stream; }
-
-ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, const ObjectRef& value) {
- p.Stream() << AsLegacyRepr(value);
- return p;
-}
-
-ReprLegacyPrinter& operator<<(ReprLegacyPrinter& out, tir::ForKind type) { //
NOLINT(*)
- using tvm::tir::ForKind;
- switch (type) {
- case ForKind::kSerial:
- out << "for";
- break;
- case ForKind::kParallel:
- out << "parallel";
- break;
- case ForKind::kUnrolled:
- out << "unrolled";
- break;
- case ForKind::kVectorized:
- out << "vectorized";
- break;
- case ForKind::kThreadBinding:
- out << "launch_thread";
- break;
- }
- return out;
-}
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<ffi::ArrayObj>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const ffi::ArrayObj*>(node.get());
- (*p) << '[';
- for (size_t i = 0; i < op->size(); ++i) {
- if (i != 0) {
- (*p) << ", ";
- }
- p->Print(op->at(i).cast<ObjectRef>());
- }
- (*p) << ']';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<ffi::MapObj>([](const ObjectRef& node, ReprLegacyPrinter* p)
{
- auto* op = static_cast<const ffi::MapObj*>(node.get());
- (*p) << '{';
- for (auto it = op->begin(); it != op->end(); ++it) {
- if (it != op->begin()) {
- (*p) << ", ";
- }
- if (it->first.as<ffi::StringObj>()) {
- (*p) << '\"' << Downcast<ffi::String>(it->first) << "\": ";
- } else {
- p->Print(it->first.cast<ObjectRef>());
- (*p) << ": ";
- }
- p->Print(it->second.cast<ObjectRef>());
- }
- (*p) << '}';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<ffi::ShapeObj>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const ffi::ShapeObj*>(node.get());
- (*p) << '[';
- for (size_t i = 0; i < op->size; ++i) {
- if (i != 0) {
- (*p) << ", ";
- }
- (*p) << op->data[i];
- }
- (*p) << ']';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<IntImmNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const IntImmNode*>(node.get());
- if (op->dtype == DataType::Int(32)) {
- (*p) << op->value;
- } else {
- (*p) << "(" << op->dtype << ")" << op->value;
- }
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const FloatImmNode*>(node.get());
- switch (op->dtype.bits()) {
- case 64:
- (*p) << op->value;
- break;
- case 32:
- (*p) << op->value << 'f';
- break;
- case 16:
- (*p) << op->value << 'h';
- break;
- default:
- LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits();
- }
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<RangeNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const RangeNode*>(node.get());
- (*p) << "range(min=" << op->min << ", ext=" << op->extent << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<PrimTypeNode>([](const ObjectRef& ref, ReprLegacyPrinter* p)
{
- auto* node = static_cast<const PrimTypeNode*>(ref.get());
- (*p) << node->dtype;
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<PointerTypeNode>([](const ObjectRef& ref, ReprLegacyPrinter*
p) {
- auto* node = static_cast<const PointerTypeNode*>(ref.get());
- if (!node->storage_scope.empty()) {
- (*p) << node->storage_scope << " ";
- }
- p->Print(node->element_type);
- (*p) << '*';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<TupleTypeNode>([](const ObjectRef& ref, ReprLegacyPrinter*
p) {
- auto* node = static_cast<const TupleTypeNode*>(ref.get());
- (*p) << "TupleTypeNode(" << node->fields << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const DictAttrsNode*>(node.get());
- (*p) << op->dict;
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprLegacyPrinter*
p) {
- auto* node = static_cast<const GlobalVarNode*>(ref.get());
- (*p) << "GlobalVar(" << node->name_hint << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprLegacyPrinter* p)
{
- auto* node = static_cast<const IRModuleNode*>(ref.get());
- (*p) << "IRModule(" << node->functions << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<FuncTypeNode>([](const ObjectRef& ref, ReprLegacyPrinter* p)
{
- auto* node = static_cast<const FuncTypeNode*>(ref.get());
- (*p) << "FuncType(" << node->arg_types << ", " << node->ret_type << ")";
- });
-
-} // namespace tvm
-
-namespace tvm {
-namespace tir {
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<BufferNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const BufferNode*>(node.get());
- (*p) << "buffer(" << op->name << ", " << op << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<VarNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const VarNode*>(node.get());
- // omit the type
- // stream << op->name << "." << op->type;
- (*p) << op->name_hint;
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<SizeVarNode>([](const ObjectRef& node, ReprLegacyPrinter* p)
{
- auto* op = static_cast<const SizeVarNode*>(node.get());
- (*p) << "{" << op->name_hint << "|" << op->name_hint << ">=0}";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<IterVarNode>([](const ObjectRef& node, ReprLegacyPrinter* p)
{
- auto* op = static_cast<const IterVarNode*>(node.get());
- (*p) << "iter_var(";
- if (op->var->name_hint.length() != 0) {
- (*p) << op->var->name_hint << ", ";
- }
- if (op->dom.defined()) {
- (*p) << op->dom;
- }
- if (op->thread_tag.length() != 0) {
- (*p) << ", " << op->thread_tag;
- }
- (*p) << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<StringImmNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const StringImmNode*>(node.get());
- (*p) << '\"' << support::StrEscape(op->value) << '\"';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<CastNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const CastNode*>(node.get());
- (*p) << op->dtype << '(';
- p->Print(op->value);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<AddNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const AddNode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " + ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<SubNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const SubNode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " - ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<MulNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const MulNode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << "*";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<DivNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const DivNode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << "/";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<ModNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const ModNode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " % ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<FloorDivNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const FloorDivNode*>(node.get());
- (*p) << "floordiv(" << op->a << ", " << op->b << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<FloorModNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const FloorModNode*>(node.get());
- (*p) << "floormod(" << op->a << ", " << op->b << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<MinNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const MinNode*>(node.get());
- (*p) << "min(";
- p->Print(op->a);
- (*p) << ", ";
- p->Print(op->b);
- (*p) << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<MaxNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const MaxNode*>(node.get());
- (*p) << "max(";
- p->Print(op->a);
- (*p) << ", ";
- p->Print(op->b);
- (*p) << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<EQNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const EQNode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " == ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<NENode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const NENode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " != ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<LTNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const LTNode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " < ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<LENode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const LENode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " <= ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<GTNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const GTNode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " > ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<GENode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const GENode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " >= ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<AndNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const AndNode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " && ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<OrNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const OrNode*>(node.get());
- (*p) << '(';
- p->Print(op->a);
- (*p) << " || ";
- p->Print(op->b);
- (*p) << ')';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<NotNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const NotNode*>(node.get());
- (*p) << '!';
- p->Print(op->a);
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<SelectNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const SelectNode*>(node.get());
- (*p) << "select(";
- p->Print(op->condition);
- (*p) << ", ";
- p->Print(op->true_value);
- (*p) << ", ";
- p->Print(op->false_value);
- (*p) << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<RampNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const RampNode*>(node.get());
- (*p) << "ramp(";
- p->Print(op->base);
- (*p) << ", ";
- p->Print(op->stride);
- (*p) << ", " << op->lanes << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<BroadcastNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const BroadcastNode*>(node.get());
- (*p) << "x" << op->lanes << "(";
- p->Print(op->value);
- (*p) << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<LetNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const LetNode*>(node.get());
- (*p) << "(let " << op->var << " = ";
- p->Print(op->value);
- (*p) << " in ";
- p->Print(op->body);
- (*p) << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<CallNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const CallNode*>(node.get());
- if (auto* ptr_op = op->op.as<OpNode>()) {
- (*p) << ptr_op->name << "(";
- } else {
- auto* ptr_gvar = op->op.as<GlobalVarNode>();
- ICHECK(ptr_gvar != nullptr);
- (*p) << "@" << ptr_gvar->name_hint << "(";
- }
- for (size_t i = 0; i < op->args.size(); ++i) {
- p->Print(op->args[i]);
- if (i < op->args.size() - 1) {
- (*p) << ", ";
- }
- }
- (*p) << ")";
- });
-
-template <typename T>
-void PrintList(const Array<T>& exprs, ReprLegacyPrinter* p) {
- for (size_t i = 0; i < exprs.size(); ++i) {
- p->Print(exprs[i]);
- if (i < exprs.size() - 1) {
- (*p) << ", ";
- }
- }
-}
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<ShuffleNode>([](const ObjectRef& node, ReprLegacyPrinter* p)
{
- auto* op = static_cast<const ShuffleNode*>(node.get());
- (*p) << "shuffle(";
- PrintList(op->vectors, p);
- (*p) << ", ";
- PrintList(op->indices, p);
- (*p) << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<CommReducerNode>([](const ObjectRef& node,
ReprLegacyPrinter* p) {
- auto* op = static_cast<const CommReducerNode*>(node.get());
- (*p) << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs <<
", rhs=" << op->rhs
- << ", identity_element=" << op->identity_element << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<ReduceNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const ReduceNode*>(node.get());
- (*p) << "reduce(combiner=" << op->combiner;
- (*p) << ", source=" << op->source;
- (*p) << ", init=" << op->init;
- (*p) << ", axis=" << op->axis;
- (*p) << ", where=" << op->condition;
- (*p) << ", value_index=" << op->value_index;
- (*p) << ")";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<BufferLoadNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const BufferLoadNode*>(node.get());
- (*p) << op->buffer->name << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) {
- (*p) << ", ";
- }
- }
- (*p) << "]";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<ProducerLoadNode>([](const ObjectRef& node,
ReprLegacyPrinter* p) {
- auto* op = static_cast<const ProducerLoadNode*>(node.get());
- (*p) << op->producer->GetNameHint() << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) {
- (*p) << ", ";
- }
- }
- (*p) << "]";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprLegacyPrinter* p)
{
- auto* node = static_cast<const PrimFuncNode*>(ref.get());
- (*p) << "PrimFunc(" << node->params << ") ";
- if (node->attrs.defined()) {
- (*p) << "attrs=" << node->attrs;
- }
- (*p) << " {\n";
- p->indent += 2;
- p->Print(node->body);
- p->indent -= 2;
- (*p) << "}\n";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<LetStmtNode>([](const ObjectRef& node, ReprLegacyPrinter* p)
{
- auto* op = static_cast<const LetStmtNode*>(node.get());
- p->PrintIndent();
- (*p) << "let " << op->var << " = ";
- p->Print(op->value);
- (*p) << '\n';
- p->Print(op->body);
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<AttrStmtNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const AttrStmtNode*>(node.get());
- p->PrintIndent();
- (*p) << "// attr [";
- p->Print(op->node);
- (*p) << "] " << op->attr_key << " = ";
- p->Print(op->value);
- (*p) << '\n';
- p->Print(op->body);
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<AssertStmtNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const AssertStmtNode*>(node.get());
- p->PrintIndent();
- (*p) << "assert(";
- p->Print(op->condition);
- (*p) << ", ";
- p->Print(op->message);
- (*p) << ")\n";
- p->Print(op->body);
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<ForNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const ForNode*>(node.get());
- p->PrintIndent();
- (*p) << op->kind << " (" << op->loop_var << ", ";
- p->Print(op->min);
- (*p) << ", ";
- p->Print(op->extent);
- (*p) << ") {\n";
-
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
-
- p->PrintIndent();
- (*p) << "}\n";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<WhileNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const WhileNode*>(node.get());
- p->PrintIndent();
- (*p) << "while(" << op->condition << ") {\n";
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
- p->PrintIndent();
- (*p) << "}\n";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<AllocateNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const AllocateNode*>(node.get());
- const auto* ptr_type =
op->buffer_var->type_annotation.as<PointerTypeNode>();
- ICHECK(ptr_type) << "The provided variable is not of pointer type";
- p->PrintIndent();
- (*p) << "allocate " << op->buffer_var << "[" << op->dtype;
- for (size_t i = 0; i < op->extents.size(); ++i) {
- (*p) << " * ";
- p->Print(op->extents[i]);
- }
- (*p) << "], storage_scope = " << ptr_type->storage_scope;
- if (!is_one(op->condition)) {
- (*p) << " if ";
- p->Print(op->condition);
- }
- (*p) << "\n";
- p->Print(op->body);
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<AllocateConstNode>([](const ObjectRef& node,
ReprLegacyPrinter* p) {
- auto* op = static_cast<const AllocateConstNode*>(node.get());
- p->PrintIndent();
- (*p) << "constant " << op->buffer_var << "[" << op->dtype;
- for (size_t i = 0; i < op->extents.size(); ++i) {
- (*p) << " * ";
- p->Print(op->extents[i]);
- }
- (*p) << "]";
- (*p) << "\n";
- p->Print(op->body);
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<DeclBufferNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const DeclBufferNode*>(node.get());
- p->PrintIndent();
- (*p) << "decl_buffer " << op->buffer << "\n";
- (*p) << op->body;
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<SeqStmtNode>([](const ObjectRef& node, ReprLegacyPrinter* p)
{
- auto* op = static_cast<const SeqStmtNode*>(node.get());
- for (Stmt stmt : op->seq) {
- p->Print(stmt);
- }
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<IfThenElseNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const IfThenElseNode*>(node.get());
- p->PrintIndent();
- while (true) {
- (*p) << "if (" << op->condition << ") {\n";
- p->indent += 2;
- p->Print(op->then_case);
- p->indent -= 2;
-
- if (!op->else_case) {
- break;
- }
-
- if (const IfThenElseNode* nested_if =
op->else_case.as<IfThenElseNode>()) {
- p->PrintIndent();
- (*p) << "} else ";
- op = nested_if;
- } else {
- p->PrintIndent();
- (*p) << "} else {\n";
- p->indent += 2;
- p->Print(op->else_case);
- p->indent -= 2;
- break;
- }
- }
- p->PrintIndent();
- (*p) << "}\n";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<EvaluateNode>([](const ObjectRef& node, ReprLegacyPrinter*
p) {
- auto* op = static_cast<const EvaluateNode*>(node.get());
- p->PrintIndent();
- p->Print(op->value);
- (*p) << "\n";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<BufferStoreNode>([](const ObjectRef& node,
ReprLegacyPrinter* p) {
- auto* op = static_cast<const BufferStoreNode*>(node.get());
- p->PrintIndent();
- (*p) << op->buffer->name << "[";
- for (size_t i = 0; i < op->indices.size(); ++i) {
- p->Print(op->indices[i]);
- if (i < op->indices.size() - 1) (*p) << ", ";
- }
- (*p) << "]";
- (*p) << " = ";
- p->Print(op->value);
- (*p) << '\n';
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<BufferRealizeNode>([](const ObjectRef& node,
ReprLegacyPrinter* p) {
- auto* op = static_cast<const BufferRealizeNode*>(node.get());
- p->PrintIndent();
- (*p) << "buffer_realize " << op->buffer->name << "(";
- for (size_t i = 0; i < op->bounds.size(); ++i) {
- (*p) << "[";
- p->Print(op->bounds[i]->min);
- (*p) << ", ";
- p->Print(op->bounds[i]->extent);
- (*p) << "]";
- if (i < op->bounds.size() - 1) (*p) << ", ";
- }
- (*p) << ")";
- if (!is_one(op->condition)) {
- (*p) << " if ";
- p->Print(op->condition);
- }
- (*p) << " {\n";
-
- p->indent += 2;
- p->Print(op->body);
- p->indent -= 2;
-
- p->PrintIndent();
- (*p) << "}\n";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<BufferRegionNode>([](const ObjectRef& node,
ReprLegacyPrinter* p) {
- auto* op = static_cast<const BufferRegionNode*>(node.get());
- (*p) << op->buffer->name;
- (*p) << "[";
- for (size_t i = 0; i < op->region.size(); ++i) {
- const auto& range = op->region[i];
- p->Print(range->min);
- if (!is_one(range->extent)) {
- (*p) << ":";
- p->Print(range->min + range->extent);
- }
- if (i != op->region.size() - 1) (*p) << ", ";
- }
- (*p) << "]";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<MatchBufferRegionNode>([](const ObjectRef& node,
ReprLegacyPrinter* p) {
- auto* op = static_cast<const MatchBufferRegionNode*>(node.get());
- p->PrintIndent();
- (*p) << op->buffer->name << " = match_buffer(";
- p->Print(op->source);
- (*p) << ")\n";
- });
-
-void PrintBlockTitle(const BlockNode* op, ReprLegacyPrinter* p) {
- (*p) << "block " << op->name_hint << "(";
- for (size_t i = 0; i < op->iter_vars.size(); i++) {
- p->Print(op->iter_vars[i]);
- if (i < op->iter_vars.size() - 1) (*p) << ", ";
- }
- (*p) << ")";
-}
-
-void PrintBlockSignature(const BlockNode* op, ReprLegacyPrinter* p) {
- // print read/write regions
- p->PrintIndent();
- (*p) << "reads(";
- p->Print(op->reads);
- (*p) << ")\n";
- p->PrintIndent();
- (*p) << "writes(";
- p->Print(op->writes);
- (*p) << ")\n";
- // Print alloc_buffers
- for (const auto& alloc_buf : op->alloc_buffers) {
- p->PrintIndent();
- (*p) << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "[";
- for (size_t i = 0; i < alloc_buf->shape.size(); ++i) {
- if (i > 0) (*p) << ", ";
- p->Print(alloc_buf->shape[i]);
- }
- (*p) << "])\n";
- }
- // Print match_buffer_regions
- for (const auto& match_buf : op->match_buffers) {
- p->Print(match_buf);
- }
- if (!op->annotations.empty()) {
- p->PrintIndent();
- (*p) << "annotations(" << op->annotations << ")\n";
- }
-}
-
-void PrintBlockBody(const BlockNode* op, ReprLegacyPrinter* p) {
- // Print init
- if (op->init.defined()) {
- p->PrintIndent();
- (*p) << "with init() {\n";
- p->indent += 2;
- p->Print(op->init.value());
- p->indent -= 2;
- p->PrintIndent();
- (*p) << "}\n";
- }
- // Print body
- p->Print(op->body);
-}
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<BlockNode>([](const ObjectRef& node, ReprLegacyPrinter* p) {
- auto* op = static_cast<const BlockNode*>(node.get());
- p->PrintIndent();
- PrintBlockTitle(op, p);
- (*p) << " {\n";
- p->indent += 2;
-
- // Print block elements (e.g. reads/writes, etc)
- PrintBlockSignature(op, p);
- // Print block init and body
- PrintBlockBody(op, p);
-
- p->indent -= 2;
- p->PrintIndent();
- (*p) << "}\n";
- });
-
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable)
- .set_dispatch<BlockRealizeNode>([](const ObjectRef& node,
ReprLegacyPrinter* p) {
- auto* op = static_cast<const BlockRealizeNode*>(node.get());
- auto* block_op = op->block.get();
- p->PrintIndent();
- PrintBlockTitle(block_op, p);
- (*p) << " {\n";
- p->indent += 2;
-
- // Print binding iter_values
- for (size_t i = 0; i < block_op->iter_vars.size(); ++i) {
- p->PrintIndent();
- (*p) << "bind(";
- p->Print(block_op->iter_vars[i]->var);
- (*p) << ", ";
- p->Print(op->iter_values[i]);
- (*p) << ")\n";
- }
- // Print predicate
- if (!is_one(op->predicate)) {
- p->PrintIndent();
- (*p) << "where(";
- p->Print(op->predicate);
- (*p) << ")\n";
- }
- // Print block elements (e.g. reads/writes, etc)
- PrintBlockSignature(block_op, p);
- // Print block init and body
- PrintBlockBody(block_op, p);
-
- p->indent -= 2;
- p->PrintIndent();
- (*p) << "}\n";
- });
-
-} // namespace tir
-} // namespace tvm
diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h
index 03341c4cd9..95d24c91c4 100644
--- a/src/script/printer/utils.h
+++ b/src/script/printer/utils.h
@@ -42,18 +42,8 @@ inline void RedirectedReprPrinterMethod(const ObjectRef&
obj, ReprPrinter* p) {
try {
p->stream << TVMScriptPrinter::Script(obj, std::nullopt);
} catch (const tvm::Error& e) {
- if (ReprLegacyPrinter::CanDispatch(obj)) {
- LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter
with the error:\n"
- << e.what();
- try {
- p->stream << AsLegacyRepr(obj);
- } catch (const tvm::Error& e) {
- LOG(WARNING) << "AsLegacyRepr fails. Falling back to the basic address
printer";
- }
- } else {
- LOG(WARNING) << "TVMScript printer falls back to the basic address
printer with the error:\n"
- << e.what();
- }
+ LOG(WARNING) << "TVMScript printer falls back to the basic address printer
with the error:\n"
+ << e.what();
p->stream << obj->GetTypeKey() << '(' << obj.get() << ')';
}
}
diff --git a/src/support/ordered_map.h b/src/support/ordered_map.h
new file mode 100644
index 0000000000..81b0fd38a7
--- /dev/null
+++ b/src/support/ordered_map.h
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file support/ordered_map.h
+ * \brief An STL-like map that preserves insertion order.
+ */
+#ifndef TVM_SUPPORT_ORDERED_MAP_H_
+#define TVM_SUPPORT_ORDERED_MAP_H_
+
+#include <functional>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace support {
+
+/**
+ * \brief An STL-like map that preserves insertion order.
+ *
+ * \tparam K The key type.
+ * \tparam V The value type.
+ * \tparam Hash The hash function.
+ * \tparam KeyEqual The key equality function.
+ * \note we don't support erase since it is less needed and vector backing is
more efficient.
+ */
+template <typename K, typename V, typename Hash = std::hash<K>,
+ typename KeyEqual = std::equal_to<K>>
+class OrderedMap {
+ public:
+ OrderedMap() = default;
+
+ /* \brief Explicit copy constructor
+ *
+ * Copies `elements_` from `other` and then rebuilds `elem_to_index_`
+ * from the copied elements via InitElementToIter(). The lookup table
+ * stores positions into `elements_` rather than iterators, so the
+ * copy never refers to the original's storage.
+ */
+ OrderedMap(const OrderedMap<K, V, Hash, KeyEqual>& other) :
elements_(other.elements_) {
+ InitElementToIter();
+ }
+
+ /* \brief Explicit copy assignment
+ *
+ * Implemented in terms of the copy constructor, and the default
+ * move assignment.
+ */
+ OrderedMap& operator=(const OrderedMap<K, V, Hash, KeyEqual>& other) {
+ return *this = OrderedMap(other);
+ }
+
+ OrderedMap(OrderedMap<K, V, Hash, KeyEqual>&&) = default;
+ OrderedMap& operator=(OrderedMap<K, V, Hash, KeyEqual>&&) = default;
+
+ template <typename Iter>
+ OrderedMap(Iter begin, Iter end) : elements_(begin, end) {
+ InitElementToIter();
+ }
+
+ auto find(const K& k) {
+ auto it = elem_to_index_.find(k);
+ if (it != elem_to_index_.end()) {
+ return elements_.begin() + it->second;
+ }
+ return elements_.end();
+ }
+
+ auto find(const K& k) const {
+ auto it = elem_to_index_.find(k);
+ if (it != elem_to_index_.end()) {
+ return elements_.begin() + it->second;
+ }
+ return elements_.end();
+ }
+
+ V& operator[](const K& k) {
+ auto it = elem_to_index_.find(k);
+ if (it != elem_to_index_.end()) {
+ return elements_[it->second].second;
+ }
+ elements_.emplace_back(k, V());
+ elem_to_index_[k] = elements_.size() - 1;
+ return elements_.back().second;
+ }
+
+ void insert(const K& k, V v) {
+ auto it = elem_to_index_.find(k);
+ if (it != elem_to_index_.end()) {
+ elements_[it->second].second = std::move(v);
+ } else {
+ elements_.emplace_back(k, v);
+ elem_to_index_[k] = elements_.size() - 1;
+ }
+ }
+
+ void clear() {
+ elements_.clear();
+ elem_to_index_.clear();
+ }
+
+ size_t count(const K& k) const { return elem_to_index_.count(k); }
+
+ auto begin() const { return elements_.begin(); }
+ auto end() const { return elements_.end(); }
+ auto begin() { return elements_.begin(); }
+ auto end() { return elements_.end(); }
+
+ size_t size() const { return elements_.size(); }
+ bool empty() const { return elements_.empty(); }
+
+ void reserve(size_t n) { elem_to_index_.reserve(n); }
+
+ private:
+ void InitElementToIter() {
+ for (size_t i = 0; i < elements_.size(); i++) {
+ elem_to_index_[elements_[i].first] = i;
+ }
+ }
+
+ std::vector<std::pair<K, V>> elements_;
+ std::unordered_map<K, size_t, Hash, KeyEqual> elem_to_index_;
+};
+
+} // namespace support
+} // namespace tvm
+
+#endif // TVM_SUPPORT_ORDERED_MAP_H_
diff --git a/src/support/ordered_set.h b/src/support/ordered_set.h
index 11acb8c3fe..169f738e70 100644
--- a/src/support/ordered_set.h
+++ b/src/support/ordered_set.h
@@ -26,30 +26,14 @@
#include <tvm/runtime/object.h>
-#include <list>
+#include <functional>
#include <unordered_map>
+#include <vector>
namespace tvm {
namespace support {
-namespace detail {
-/* \brief Utility to allow use for standard and ObjectRef types
- *
- * \tparam T The type held by the OrderedSet
- */
-template <typename T, typename = void>
-struct OrderedSetLookupType {
- using MapType = std::unordered_map<T, typename std::list<T>::iterator>;
-};
-
-template <typename T>
-struct OrderedSetLookupType<T,
std::enable_if_t<std::is_base_of_v<runtime::ObjectRef, T>>> {
- using MapType = std::unordered_map<T, typename std::list<T>::iterator,
runtime::ObjectPtrHash,
- runtime::ObjectPtrEqual>;
-};
-} // namespace detail
-
-template <typename T>
+template <typename T, typename Hash = std::hash<T>, typename KeyEqual =
std::equal_to<T>>
class OrderedSet {
public:
OrderedSet() = default;
@@ -61,17 +45,21 @@ class OrderedSet {
* `elements_`, the copy of `elem_to_iter_` would contain references
* to the original's `element_`, rather than to its own
*/
- OrderedSet(const OrderedSet<T>& other) : elements_(other.elements_) {
InitElementToIter(); }
+ OrderedSet(const OrderedSet<T, Hash, KeyEqual>& other) :
elements_(other.elements_) {
+ InitElementToIter();
+ }
/* \brief Explicit copy assignment
*
* Implemented in terms of the copy constructor, and the default
* move assignment.
*/
- OrderedSet& operator=(const OrderedSet<T>& other) { return *this =
OrderedSet(other); }
+ OrderedSet& operator=(const OrderedSet<T, Hash, KeyEqual>& other) {
+ return *this = OrderedSet(other);
+ }
- OrderedSet(OrderedSet<T>&&) = default;
- OrderedSet& operator=(OrderedSet<T>&&) = default;
+ OrderedSet(OrderedSet<T, Hash, KeyEqual>&&) = default;
+ OrderedSet& operator=(OrderedSet<T, Hash, KeyEqual>&&) = default;
template <typename Iter>
OrderedSet(Iter begin, Iter end) : elements_(begin, end) {
@@ -79,27 +67,20 @@ class OrderedSet {
}
void push_back(const T& t) {
- if (!elem_to_iter_.count(t)) {
+ if (!elem_to_index_.count(t)) {
elements_.push_back(t);
- elem_to_iter_[t] = std::prev(elements_.end());
+ elem_to_index_[t] = elements_.size() - 1;
}
}
void insert(const T& t) { push_back(t); }
- void erase(const T& t) {
- if (auto it = elem_to_iter_.find(t); it != elem_to_iter_.end()) {
- elements_.erase(it->second);
- elem_to_iter_.erase(it);
- }
- }
-
void clear() {
elements_.clear();
- elem_to_iter_.clear();
+ elem_to_index_.clear();
}
- size_t count(const T& t) const { return elem_to_iter_.count(t); }
+ size_t count(const T& t) const { return elem_to_index_.count(t); }
auto begin() const { return elements_.begin(); }
auto end() const { return elements_.end(); }
@@ -108,13 +89,13 @@ class OrderedSet {
private:
void InitElementToIter() {
- for (auto it = elements_.begin(); it != elements_.end(); it++) {
- elem_to_iter_[*it] = it;
+ for (size_t i = 0; i < elements_.size(); ++i) {
+ elem_to_index_[elements_[i]] = i;
}
}
- std::list<T> elements_;
- typename detail::OrderedSetLookupType<T>::MapType elem_to_iter_;
+ std::vector<T> elements_;
+ std::unordered_map<T, size_t, Hash, KeyEqual> elem_to_index_;
};
} // namespace support
diff --git a/src/tir/transforms/common_subexpr_elim.cc
b/src/tir/transforms/common_subexpr_elim.cc
index 42409efb0b..3fd78a5233 100644
--- a/src/tir/transforms/common_subexpr_elim.cc
+++ b/src/tir/transforms/common_subexpr_elim.cc
@@ -43,8 +43,7 @@
#include <algorithm> // For the algorithm std::find
#include <iostream>
#include <string>
-#include <unordered_map> // For the hashtable datatype
-#include <utility> // For std::pair and std::move
+#include <utility>
#include <vector>
#include "../analysis/check_contains.h" // For the visitor CheckContains
@@ -131,41 +130,24 @@ bool
CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExp
* they appeared in the hashtable was based on some runtime addresses,
so it can potentially
* change with every execution.
*/
-bool
CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair<PrimExpr,
size_t> a,
-
std::pair<PrimExpr, size_t> b) {
+bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(const
std::pair<PrimExpr, size_t>& a,
+ const
std::pair<PrimExpr, size_t>& b) {
size_t a_size = CalculateExprComplexity(a.first);
size_t b_size = CalculateExprComplexity(b.first);
-
- // Criteria 1 - Size of the expression comes first
- // `a` comes before `b` if the size of `a` is bigger
- if (a_size > b_size) {
- return true;
- }
- // `a` does NOT come before `b` if the size of `b` is bigger
- if (b_size > a_size) {
- return false;
- }
-
- // Criteria 2 - If they had the same size, use the lexicographic order as a
last resort
- // as we need a deterministic order
- std::stringstream a_stream;
- std::stringstream b_stream;
- a_stream << AsLegacyRepr(a.first);
- b_stream << AsLegacyRepr(b.first);
- return (a_stream.str().compare(b_stream.str()) < 0);
+ return a_size > b_size;
}
/*!
- * \brief Generates a new fresh variable, whose name will be cse_var_i.
+ * \brief Generates a new fresh variable, whose name will be cse_vi.
* \param type_annotation The type of the new variable to generate
- * \return A new variable of type `type_annotation` called cse_var_i where i
is the first available
+ * \return A new variable of type `type_annotation` called cse_vi where i is
the first available
integer.
*/
Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) {
// Increase `num_last_try_` for this new attempt
num_last_try_++;
- // Builds the variable name, which is sce_var_i where i will go up from 1
- std::string prefix = "cse_var_";
+ // Builds the variable name, which is cse_vi where i will go up from 1
+ std::string prefix = "cse_v";
std::string name = prefix.append(std::to_string(num_last_try_));
// Builds a String using the std::string
String string_name(name);
@@ -241,8 +223,8 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const
PrimExpr& expr) {
SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr,
identify_equiv_terms_);
// Sort the vector of semantic entities by decreasing size
- std::sort(semantic_comp_done_by_expr.begin(),
semantic_comp_done_by_expr.end(),
- OrderOnExprAndFrequency);
+ std::stable_sort(semantic_comp_done_by_expr.begin(),
semantic_comp_done_by_expr.end(),
+ OrderOnExprAndFrequency);
// For each computation done (considering them from biggest to smallest)
for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
@@ -421,8 +403,8 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt&
stmt) {
SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt,
identify_equiv_terms_);
// Sort the vector of semantic entities by decreasing size
- std::sort(semantic_comp_done_by_stmt.begin(),
semantic_comp_done_by_stmt.end(),
- OrderOnExprAndFrequency);
+ std::stable_sort(semantic_comp_done_by_stmt.begin(),
semantic_comp_done_by_stmt.end(),
+ OrderOnExprAndFrequency);
// For each computation done (considering them from biggest to smallest)
for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
diff --git a/src/tir/transforms/common_subexpr_elim.h
b/src/tir/transforms/common_subexpr_elim.h
index 5c14caf1a6..12a71458e1 100644
--- a/src/tir/transforms/common_subexpr_elim.h
+++ b/src/tir/transforms/common_subexpr_elim.h
@@ -83,7 +83,8 @@ class CommonSubexpressionEliminator : public StmtExprMutator {
static bool ForbiddenComputation(const PrimExpr& expr);
static bool IsEligibleComputation(const PrimExpr& expr);
static bool CanContainEligibleComputations(const PrimExpr& expr);
- static bool OrderOnExprAndFrequency(std::pair<PrimExpr, size_t> a,
std::pair<PrimExpr, size_t> b);
+ static bool OrderOnExprAndFrequency(const std::pair<PrimExpr, size_t>& a,
+ const std::pair<PrimExpr, size_t>& b);
Var GenerateNewVar(DataType type_annotation);
};
diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc
b/src/tir/transforms/common_subexpr_elim_tools.cc
index ce8aef4587..f71d2cf42a 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.cc
+++ b/src/tir/transforms/common_subexpr_elim_tools.cc
@@ -797,7 +797,7 @@ std::vector<std::pair<PrimExpr, size_t>>
SyntacticToSemanticComputations(
// normalized. This normalized table will keep the count for each set of
equivalent terms
// (i.e. each equivalence class), together with a term that did appear in
this equivalence class
// (in practice, the first term of the equivalence class that was
encountered).
- std::unordered_map<PrimExpr, std::pair<PrimExpr, size_t>, StructuralHash,
ExprDeepEqual>
+ support::OrderedMap<PrimExpr, std::pair<PrimExpr, size_t>, StructuralHash,
ExprDeepEqual>
norm_table;
// In order to avoid frequent rehashing if the norm_table becomes big, we
immediately ask for
@@ -806,23 +806,7 @@ std::vector<std::pair<PrimExpr, size_t>>
SyntacticToSemanticComputations(
// equivalence classes as there are elements)
norm_table.reserve(table.size());
- // Transform the input hashtable to a vector and sort it according to some
order, as we will be
- // iterating through its items soon, and the order of appearance will be
used to determine the
- // individual representant for each class of equivalence, which we want to
be deterministic
- // (otherwise {x+y, y+x} could be both replaced by x+y, and on another run
by y+x).
- std::vector<std::pair<PrimExpr, size_t>>
sorted_items_of_table(table.begin(), table.end());
-
- // We do the ordering by comparing the string repr of each expr to get a
determinstic ordering
- sort(sorted_items_of_table.begin(), sorted_items_of_table.end(),
- [](std::pair<PrimExpr, size_t> a, std::pair<PrimExpr, size_t> b) {
- std::stringstream a_stream;
- std::stringstream b_stream;
- a_stream << AsLegacyRepr(a.first);
- b_stream << AsLegacyRepr(b.first);
- return a_stream.str().compare(b_stream.str()) < 0;
- });
-
- for (const auto& elem : sorted_items_of_table) {
+ for (const auto& elem : table) {
PrimExpr norm_elem = NormalizeTerm(elem.first, identify_equiv_terms);
// If the normalized term is not already a key in the normalized table
auto it_found = norm_table.find(norm_elem);
@@ -831,7 +815,7 @@ std::vector<std::pair<PrimExpr, size_t>>
SyntacticToSemanticComputations(
// (i.e. `norm_elem` has been seen `elem`.second many times so far, and
the chosen element
// to represent the equivalence class will be `elem`.first as it's the
first element of the
// class that we see)
- norm_table[norm_elem] = elem;
+ norm_table.insert(norm_elem, elem);
} else {
// Otherwise, it's not the first time we see a term in this equivalence
class, so we just
// increase the count of this equivalence class as we now have
`elem`.second additional items
@@ -850,10 +834,8 @@ std::vector<std::pair<PrimExpr, size_t>>
SyntacticToSemanticComputations(
// Careful: the pairs will never change (the canonical representatives chosen
will always be the
// same), but the order in which the pairs are produced can vary as we are
iterating through the
// hashtable `norm_table`. It is not an issue as the caller will be sorting
the result anyway.
- std::unordered_map<PrimExpr, std::pair<PrimExpr, size_t>, StructuralHash,
- ExprDeepEqual>::const_iterator it_norm_table;
- for (it_norm_table = norm_table.begin(); it_norm_table != norm_table.end();
++it_norm_table) {
- result.push_back(it_norm_table->second);
+ for (const auto& kv : norm_table) {
+ result.push_back(kv.second);
}
return result;
diff --git a/src/tir/transforms/common_subexpr_elim_tools.h
b/src/tir/transforms/common_subexpr_elim_tools.h
index 58014e6a40..31a81dabdb 100644
--- a/src/tir/transforms/common_subexpr_elim_tools.h
+++ b/src/tir/transforms/common_subexpr_elim_tools.h
@@ -34,10 +34,12 @@
#include <tvm/tir/stmt_functor.h> // For the class StmtExprVisitor
#include <optional>
-#include <unordered_map> // For the hashtable datatype
-#include <utility> // For pairs datatype
+#include <unordered_map>
+#include <utility> // For pairs datatype
#include <vector>
+#include "../../support/ordered_map.h"
+
namespace tvm {
namespace tir {
@@ -50,7 +52,7 @@ namespace tir {
not do variables remapping), so it is compatible with StructuralHash
(intended to be used
with StructuralEqual).
*/
-using ComputationTable = std::unordered_map<PrimExpr, size_t, StructuralHash,
ExprDeepEqual>;
+using ComputationTable = support::OrderedMap<PrimExpr, size_t, StructuralHash,
ExprDeepEqual>;
/*!
* \brief A cache of computations is made of a pair of two hashtables, which
respectively associate
diff --git
a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
index e7e64d8916..1be5e57ba1 100644
--- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py
@@ -93,14 +93,14 @@ def test_cse():
assert body.var.name == "z2"
assert body.value == 2
- # This is the let-in for the first variable generated cse_var_1
+ # This is the let-in for the first variable generated cse_v1
assert isinstance(body.body, tvm.tir.LetStmt)
body = body.body
# And this is the name and value of this variable
- cse_var_1 = body.var # Keep the variable accessible for later checking
the replacements
- assert body.var.name == "cse_var_1"
+ cse_v1 = body.var # Keep the variable accessible for later checking the
replacements
+ assert body.var.name == "cse_v1"
tvm.ir.assert_structural_equal(body.value, z1 + z2)
assert isinstance(body.body, tvm.tir.SeqStmt)
@@ -118,27 +118,27 @@ def test_cse():
assert body.var.name == "y"
assert body.value == 1
- # This is the let-in for the second variable generated cse_var_2
+ # This is the let-in for the second variable generated cse_v2
assert isinstance(body.body, tvm.tir.LetStmt)
body = body.body
# And this is the name and value of this variable
- cse_var_2 = body.var # Keep the variable accessible for later checking
the replacements
- assert body.var.name == "cse_var_2"
+ cse_v2 = body.var # Keep the variable accessible for later checking the
replacements
+ assert body.var.name == "cse_v2"
tvm.ir.assert_structural_equal(body.value, x + y)
body = body.body
body.var.name == "a"
# Check that the replacement has been done correctly!
- tvm.ir.assert_structural_equal(body.value, cse_var_2 + cse_var_1)
+ tvm.ir.assert_structural_equal(body.value, cse_v2 + cse_v1)
body = body.body
body.var.name == "b"
# Check that the replacement has been done correctly!
- tvm.ir.assert_structural_equal(body.value, cse_var_2 + z3)
+ tvm.ir.assert_structural_equal(body.value, cse_v2 + z3)
assert isinstance(body.body, tvm.tir.BufferStore)
@@ -199,7 +199,7 @@ def test_cse_ifNode_1():
body = body.then_case
# The let-in introduced by the CSE should appear now, inside the Then
branch of the If node
- assert body.var.name == "cse_var_1"
+ assert body.var.name == "cse_v1"
# and it should contain the expression (y+z) that was redundant
tvm.ir.assert_structural_equal(body.value, y + z)
@@ -250,7 +250,7 @@ def test_cse_ifNode_2():
assert isinstance(body, tvm.tir.LetStmt)
# The let-in introduced by the CSE should appear now, at the toplevel
(i.e. before the If)
- assert body.var.name == "cse_var_1"
+ assert body.var.name == "cse_v1"
# and it should contain the expression (y+z) that was redundant
tvm.ir.assert_structural_equal(body.value, y + z)
@@ -291,8 +291,8 @@ def test_cse_cascade():
assert isinstance(body, tvm.tir.LetStmt)
# The second let-in (by order introduced) introduced by the CSE should
appear first
- cse_var_2 = body.var # Keep the variable accessible for later checking
the replacements
- assert body.var.name == "cse_var_2"
+ cse_v2 = body.var # Keep the variable accessible for later checking the
replacements
+ assert body.var.name == "cse_v2"
# and it should contain the expression (x+y)
tvm.ir.assert_structural_equal(body.value, (x + y))
@@ -301,10 +301,10 @@ def test_cse_cascade():
assert isinstance(body, tvm.tir.LetStmt)
# The first let-in (by order introduced) introduced by the CSE should
appear now, after the 2nd
- cse_var_1 = body.var # Keep the variable accessible for later checking
the replacements
- assert body.var.name == "cse_var_1"
- # and it should contain the expression cse_var_2+z
- tvm.ir.assert_structural_equal(body.value, cse_var_2 + z)
+ cse_v1 = body.var # Keep the variable accessible for later checking the
replacements
+ assert body.var.name == "cse_v1"
+ # and it should contain the expression cse_v2+z
+ tvm.ir.assert_structural_equal(body.value, cse_v2 + z)
body = body.body
@@ -317,9 +317,9 @@ def test_cse_cascade():
store2 = body[1]
store3 = body[2]
- tvm.ir.assert_structural_equal(store1.value, cse_var_1)
- tvm.ir.assert_structural_equal(store2.value, cse_var_1)
- tvm.ir.assert_structural_equal(store3.value, cse_var_2)
+ tvm.ir.assert_structural_equal(store1.value, cse_v1)
+ tvm.ir.assert_structural_equal(store2.value, cse_v1)
+ tvm.ir.assert_structural_equal(store3.value, cse_v2)
#
-----------------------------------------------------------------------------------------
@@ -360,9 +360,9 @@ def func_distributivity(
def func_distributivity_expected(
B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y:
T.int32, z: T.int32
) -> None:
- with T.LetStmt(x * y + x * z) as cse_var_1:
- B[i1] = cse_var_1
- B[i2] = cse_var_1
+ with T.LetStmt((y + z) * x) as cse_v1:
+ B[i1] = cse_v1
+ B[i2] = cse_v1
@T.prim_func
@@ -377,9 +377,9 @@ def func_associativity(
def func_associativity_expected(
B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y:
T.int32, z: T.int32
) -> None:
- with T.LetStmt((x + y) + z) as cse_var_1:
- B[i1] = cse_var_1
- B[i2] = cse_var_1
+ with T.LetStmt(x + y + z) as cse_v1:
+ B[i1] = cse_v1
+ B[i2] = cse_v1
def _check(original, transformed):
@@ -410,10 +410,10 @@ def test_deterministic_cse():
result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3)
-->
- cse_var_3 = (x + 1)
- cse_var_2 = (x + 2)
- cse_var_1 = (x + 3)
- result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 +
cse_var_1
+ cse_v3 = (x + 1)
+ cse_v2 = (x + 2)
+ cse_v1 = (x + 3)
+ result = cse_v3 + cse_v2 + cse_v1 + cse_v3 + cse_v2 + cse_v1
"""
NUM_TERMS = 10
REPEATS = 10
diff --git
a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index da079f46e3..13487b42f0 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -329,11 +329,11 @@ __asm__ __volatile__("cp.async.commit_group;");
__asm__ __volatile__("cp.async.commit_group;");
for (int i = 0; i < 13; ++i) {
- bool cse_var_1 = (i < 12);
+ bool cse_v1 = (i < 12);
{
unsigned int addr = cast_smem_ptr_to_int(A_shared + ((((i + 3) & 3) * 16)
+ ((int)threadIdx.x)));
- int pred_guard = (int)cse_var_1;
+ int pred_guard = (int)cse_v1;
__asm__ __volatile__(
"{ .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
@@ -356,7 +356,7 @@ __asm__ __volatile__("cp.async.wait_group 5;");
{
unsigned int addr = cast_smem_ptr_to_int(B_shared + ((((i + 3) & 3) * 16)
+ ((int)threadIdx.x)));
- int pred_guard = (int)cse_var_1;
+ int pred_guard = (int)cse_v1;
__asm__ __volatile__(
"{ .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
@@ -954,10 +954,10 @@ class
TestMultiplicationNodesAreInligned(tvm.testing.CompareBeforeAfter):
T.attr("default", "async_scope", 1)
for i in range(16):
- cse_var_1: T.int64 = T.Cast("int64", i)
- A_shared[
- T.Ramp(tx * T.int64(128) + cse_var_1 * T.int64(8), T.int64(1),
8)
- ] = A_flattened[T.Ramp(tx * T.int64(128) + cse_var_1 * T.int64(8),
T.int64(1), 8)]
+ cse_v1: T.int64 = T.Cast("int64", i)
+ A_shared[T.Ramp(tx * T.int64(128) + cse_v1 * T.int64(8),
T.int64(1), 8)] = A_flattened[
+ T.Ramp(tx * T.int64(128) + cse_v1 * T.int64(8), T.int64(1), 8)
+ ]
T.ptx_commit_group()
T.ptx_wait_group(0)
@@ -965,13 +965,13 @@ class
TestMultiplicationNodesAreInligned(tvm.testing.CompareBeforeAfter):
tx = T.launch_thread("threadIdx.x", T.int64(32))
A_shared = T.decl_buffer((4096,), "float16", scope="shared")
for i in range(16):
- cse_var_1: T.int64 = T.Cast("int64", i)
+ cse_v1: T.int64 = T.Cast("int64", i)
T.ptx_cp_async(
"float16",
A_shared.data,
- tx * T.int64(128) + cse_var_1 * T.int64(8),
+ tx * T.int64(128) + cse_v1 * T.int64(8),
A.data,
- tx * T.int64(128) + cse_var_1 * T.int64(8),
+ tx * T.int64(128) + cse_v1 * T.int64(8),
16,
)
T.ptx_commit_group()
diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
index c63d2f8a41..299c193146 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
@@ -154,8 +154,8 @@ def test_lower_overflow_int32():
rxplaceholder_1 = T.Buffer((T.int64(822083584),),
data=rxplaceholder.data)
T_subtract_1 = T.Buffer((T.int64(822083584),), data=T_subtract)
for ax1, ax2 in T.grid(32, 25690112):
- cse_var_1: T.int32 = ax1 * 25690112 + ax2
- T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] -
rxplaceholder_red_1[ax1]
+ cse_v1: T.int32 = ax1 * 25690112 + ax2
+ T_subtract_1[cse_v1] = rxplaceholder_1[cse_v1] -
rxplaceholder_red_1[ax1]
func = variance4
tvm.compile(func, target="llvm") # should not crash
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py
b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index af2db34415..0e1b328844 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -3679,6 +3679,7 @@ def merge_shape_var_def():
# uninitialized vars
@T.prim_func(check_well_formed=False)
def main(A: T.handle, B: T.handle):
+ # fmt: off
T.func_attr({"global_symbol": "main", "tir.noalias": True})
m, n = T.int32(), T.int32()
A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"),
buffer_type="auto")
@@ -3687,8 +3688,8 @@ def merge_shape_var_def():
if T.likely(i_outer * 10 + i_inner < m):
for j_inner in range(5):
if T.likely(j_outer * 5 + j_inner < n):
- cse_var_2: T.int32 = j_outer * 5 + j_inner
- cse_var_1: T.int32 = i_outer * 10 + i_inner
+ cse_v2: T.int32 = j_outer * 5 + j_inner
+ cse_v1: T.int32 = i_outer * 10 + i_inner
B_2 = T.Buffer(
(B_1.strides[0] * m,),
data=B_1.data,
@@ -3701,9 +3702,10 @@ def merge_shape_var_def():
strides=("A_2_s0",),
buffer_type="auto",
)
- B_2[cse_var_1 * B_1.strides[0] + cse_var_2 *
B_1.strides[1]] = A_2[
- cse_var_1 * A_1.strides[0] + cse_var_2 *
A_1.strides[1]
+ B_2[cse_v1 * B_1.strides[0] + cse_v2 * B_1.strides[1]]
= A_2[
+ cse_v1 * A_1.strides[0] + cse_v2 * A_1.strides[1]
]
+ # fmt: on
return main