This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 4404748 [TIR][REFACTOR] Add tir prefix to type keys (#5802)
4404748 is described below
commit 4404748e15fad281e0c51ad77ae267db07a5bdd0
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Jun 14 09:45:46 2020 -0700
[TIR][REFACTOR] Add tir prefix to type keys (#5802)
---
include/tvm/tir/buffer.h | 4 +-
include/tvm/tir/data_layout.h | 4 +-
include/tvm/tir/expr.h | 64 ++++++++++-----------
include/tvm/tir/stmt.h | 32 +++++------
include/tvm/tir/var.h | 2 +-
python/tvm/ir/json_compact.py | 47 +++++++++++++--
python/tvm/te/hybrid/util.py | 2 +-
python/tvm/tir/buffer.py | 4 +-
python/tvm/tir/data_layout.py | 4 +-
python/tvm/tir/expr.py | 66 +++++++++++-----------
python/tvm/tir/stmt.py | 30 +++++-----
src/tir/pass/hoist_if_then_else.cc | 8 +--
tests/python/unittest/test_target_codegen_cuda.py | 2 +-
tests/python/unittest/test_target_codegen_llvm.py | 2 +-
tests/python/unittest/test_tir_pass_hoist_if.py | 40 ++++++-------
.../unittest/test_tir_stmt_functor_ir_transform.py | 2 +-
tutorials/dev/low_level_custom_pass.py | 4 +-
vta/python/vta/transform.py | 18 +++---
18 files changed, 186 insertions(+), 149 deletions(-)
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index 34b0155..e150ff3 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -118,7 +118,7 @@ class BufferNode : public Object {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
}
- static constexpr const char* _type_key = "Buffer";
+ static constexpr const char* _type_key = "tir.Buffer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferNode, Object);
@@ -228,7 +228,7 @@ class DataProducerNode : public Object {
void SHashReduce(SHashReducer hash_reduce) const {}
- static constexpr const char* _type_key = "DataProducer";
+ static constexpr const char* _type_key = "tir.DataProducer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_BASE_OBJECT_INFO(DataProducerNode, Object);
diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h
index b7cb686..d3a77cc 100644
--- a/include/tvm/tir/data_layout.h
+++ b/include/tvm/tir/data_layout.h
@@ -112,7 +112,7 @@ class LayoutNode : public Object {
v->Visit("axes", &axes);
}
- static constexpr const char* _type_key = "Layout";
+ static constexpr const char* _type_key = "tir.Layout";
TVM_DECLARE_FINAL_OBJECT_INFO(LayoutNode, Object);
};
@@ -308,7 +308,7 @@ class BijectiveLayoutNode : public Object {
v->Visit("backward_rule", &backward_rule);
}
- static constexpr const char* _type_key = "BijectiveLayout";
+ static constexpr const char* _type_key = "tir.BijectiveLayout";
TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object);
};
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index cfb7f1e..1518d1f 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -64,7 +64,7 @@ class StringImmNode : public PrimExprNode {
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
- static constexpr const char* _type_key = "StringImm";
+ static constexpr const char* _type_key = "tir.StringImm";
TVM_DECLARE_FINAL_OBJECT_INFO(StringImmNode, PrimExprNode);
};
@@ -101,7 +101,7 @@ class CastNode : public PrimExprNode {
hash_reduce(value);
}
- static constexpr const char* _type_key = "Cast";
+ static constexpr const char* _type_key = "tir.Cast";
TVM_DECLARE_FINAL_OBJECT_INFO(CastNode, PrimExprNode);
};
@@ -149,7 +149,7 @@ class BinaryOpNode : public PrimExprNode {
/*! \brief a + b */
class AddNode : public BinaryOpNode<AddNode> {
public:
- static constexpr const char* _type_key = "Add";
+ static constexpr const char* _type_key = "tir.Add";
};
/*!
@@ -165,7 +165,7 @@ class Add : public PrimExpr {
/*! \brief a - b */
class SubNode : public BinaryOpNode<SubNode> {
public:
- static constexpr const char* _type_key = "Sub";
+ static constexpr const char* _type_key = "tir.Sub";
};
/*!
@@ -181,7 +181,7 @@ class Sub : public PrimExpr {
/*! \brief a * b */
class MulNode : public BinaryOpNode<MulNode> {
public:
- static constexpr const char* _type_key = "Mul";
+ static constexpr const char* _type_key = "tir.Mul";
};
/*!
@@ -200,7 +200,7 @@ class Mul : public PrimExpr {
*/
class DivNode : public BinaryOpNode<DivNode> {
public:
- static constexpr const char* _type_key = "Div";
+ static constexpr const char* _type_key = "tir.Div";
};
/*!
@@ -219,7 +219,7 @@ class Div : public PrimExpr {
*/
class ModNode : public BinaryOpNode<ModNode> {
public:
- static constexpr const char* _type_key = "Mod";
+ static constexpr const char* _type_key = "tir.Mod";
};
/*!
@@ -235,7 +235,7 @@ class Mod : public PrimExpr {
/*! \brief Floor division, floor(a/b) */
class FloorDivNode : public BinaryOpNode<FloorDivNode> {
public:
- static constexpr const char* _type_key = "FloorDiv";
+ static constexpr const char* _type_key = "tir.FloorDiv";
};
/*!
@@ -251,7 +251,7 @@ class FloorDiv : public PrimExpr {
/*! \brief The remainder of the floordiv */
class FloorModNode : public BinaryOpNode<FloorModNode> {
public:
- static constexpr const char* _type_key = "FloorMod";
+ static constexpr const char* _type_key = "tir.FloorMod";
};
/*!
@@ -267,7 +267,7 @@ class FloorMod : public PrimExpr {
/*! \brief min(a, b) */
class MinNode : public BinaryOpNode<MinNode> {
public:
- static constexpr const char* _type_key = "Min";
+ static constexpr const char* _type_key = "tir.Min";
};
/*!
@@ -283,7 +283,7 @@ class Min : public PrimExpr {
/*! \brief max(a, b) */
class MaxNode : public BinaryOpNode<MaxNode> {
public:
- static constexpr const char* _type_key = "Max";
+ static constexpr const char* _type_key = "tir.Max";
};
/*!
@@ -330,7 +330,7 @@ class CmpOpNode : public PrimExprNode {
/*! \brief a == b */
class EQNode : public CmpOpNode<EQNode> {
public:
- static constexpr const char* _type_key = "EQ";
+ static constexpr const char* _type_key = "tir.EQ";
};
/*!
@@ -346,7 +346,7 @@ class EQ : public PrimExpr {
/*! \brief a != b */
class NENode : public CmpOpNode<NENode> {
public:
- static constexpr const char* _type_key = "NE";
+ static constexpr const char* _type_key = "tir.NE";
};
/*!
@@ -362,7 +362,7 @@ class NE : public PrimExpr {
/*! \brief a < b */
class LTNode : public CmpOpNode<LTNode> {
public:
- static constexpr const char* _type_key = "LT";
+ static constexpr const char* _type_key = "tir.LT";
};
/*!
@@ -378,7 +378,7 @@ class LT : public PrimExpr {
/*! \brief a <= b */
struct LENode : public CmpOpNode<LENode> {
public:
- static constexpr const char* _type_key = "LE";
+ static constexpr const char* _type_key = "tir.LE";
};
/*!
@@ -394,7 +394,7 @@ class LE : public PrimExpr {
/*! \brief a > b */
class GTNode : public CmpOpNode<GTNode> {
public:
- static constexpr const char* _type_key = "GT";
+ static constexpr const char* _type_key = "tir.GT";
};
/*!
@@ -410,7 +410,7 @@ class GT : public PrimExpr {
/*! \brief a >= b */
class GENode : public CmpOpNode<GENode> {
public:
- static constexpr const char* _type_key = "GE";
+ static constexpr const char* _type_key = "tir.GE";
};
/*!
@@ -447,7 +447,7 @@ class AndNode : public PrimExprNode {
hash_reduce(b);
}
- static constexpr const char* _type_key = "And";
+ static constexpr const char* _type_key = "tir.And";
TVM_DECLARE_FINAL_OBJECT_INFO(AndNode, PrimExprNode);
};
@@ -485,7 +485,7 @@ class OrNode : public PrimExprNode {
hash_reduce(b);
}
- static constexpr const char* _type_key = "Or";
+ static constexpr const char* _type_key = "tir.Or";
TVM_DECLARE_FINAL_OBJECT_INFO(OrNode, PrimExprNode);
};
@@ -519,7 +519,7 @@ class NotNode : public PrimExprNode {
hash_reduce(a);
}
- static constexpr const char* _type_key = "Not";
+ static constexpr const char* _type_key = "tir.Not";
TVM_DECLARE_FINAL_OBJECT_INFO(NotNode, PrimExprNode);
};
@@ -568,7 +568,7 @@ class SelectNode : public PrimExprNode {
hash_reduce(false_value);
}
- static constexpr const char* _type_key = "Select";
+ static constexpr const char* _type_key = "tir.Select";
TVM_DECLARE_FINAL_OBJECT_INFO(SelectNode, PrimExprNode);
};
@@ -617,7 +617,7 @@ class BufferLoadNode : public PrimExprNode {
hash_reduce(indices);
}
- static constexpr const char* _type_key = "BufferLoad";
+ static constexpr const char* _type_key = "tir.BufferLoad";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferLoadNode, PrimExprNode);
};
@@ -664,7 +664,7 @@ class ProducerLoadNode : public PrimExprNode {
hash_reduce(indices);
}
- static constexpr const char* _type_key = "ProducerLoad";
+ static constexpr const char* _type_key = "tir.ProducerLoad";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerLoadNode, PrimExprNode);
};
@@ -722,7 +722,7 @@ class LoadNode : public PrimExprNode {
hash_reduce(predicate);
}
- static constexpr const char* _type_key = "Load";
+ static constexpr const char* _type_key = "tir.Load";
TVM_DECLARE_FINAL_OBJECT_INFO(LoadNode, PrimExprNode);
};
@@ -773,7 +773,7 @@ class RampNode : public PrimExprNode {
hash_reduce(lanes);
}
- static constexpr const char* _type_key = "Ramp";
+ static constexpr const char* _type_key = "tir.Ramp";
TVM_DECLARE_FINAL_OBJECT_INFO(RampNode, PrimExprNode);
};
@@ -811,7 +811,7 @@ class BroadcastNode : public PrimExprNode {
hash_reduce(lanes);
}
- static constexpr const char* _type_key = "Broadcast";
+ static constexpr const char* _type_key = "tir.Broadcast";
TVM_DECLARE_FINAL_OBJECT_INFO(BroadcastNode, PrimExprNode);
};
@@ -856,7 +856,7 @@ class LetNode : public PrimExprNode {
hash_reduce(body);
}
- static constexpr const char* _type_key = "Let";
+ static constexpr const char* _type_key = "tir.Let";
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, PrimExprNode);
};
@@ -928,7 +928,7 @@ class CallNode : public PrimExprNode {
/*! \return Whether call node can be vectorized. */
bool is_vectorizable() const;
- static constexpr const char* _type_key = "Call";
+ static constexpr const char* _type_key = "tir.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, PrimExprNode);
// Build-in intrinsics
@@ -990,7 +990,7 @@ class ShuffleNode : public PrimExprNode {
hash_reduce(indices);
}
- static constexpr const char* _type_key = "Shuffle";
+ static constexpr const char* _type_key = "tir.Shuffle";
TVM_DECLARE_FINAL_OBJECT_INFO(ShuffleNode, PrimExprNode);
};
@@ -1048,7 +1048,7 @@ class CommReducerNode : public Object {
hash_reduce(identity_element);
}
- static constexpr const char* _type_key = "CommReducer";
+ static constexpr const char* _type_key = "tir.CommReducer";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(CommReducerNode, Object);
@@ -1108,7 +1108,7 @@ class ReduceNode : public PrimExprNode {
hash_reduce(value_index);
}
- static constexpr const char* _type_key = "Reduce";
+ static constexpr const char* _type_key = "tir.Reduce";
TVM_DECLARE_FINAL_OBJECT_INFO(ReduceNode, PrimExprNode);
};
@@ -1136,7 +1136,7 @@ class AnyNode : public PrimExprNode {
/*! \brief Convert to var. */
Var ToVar() const { return Var("any_dim", DataType::Int(32)); }
- static constexpr const char* _type_key = "Any";
+ static constexpr const char* _type_key = "tir.Any";
TVM_DECLARE_FINAL_OBJECT_INFO(AnyNode, PrimExprNode);
};
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index ee8e1eb..be1c567 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -37,7 +37,7 @@ namespace tir {
/*! \brief Base node of all statements. */
class StmtNode : public Object {
public:
- static constexpr const char* _type_key = "Stmt";
+ static constexpr const char* _type_key = "tir.Stmt";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 15;
@@ -79,7 +79,7 @@ class LetStmtNode : public StmtNode {
hash_reduce(body);
}
- static constexpr const char* _type_key = "LetStmt";
+ static constexpr const char* _type_key = "tir.LetStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(LetStmtNode, StmtNode);
};
@@ -134,7 +134,7 @@ class AttrStmtNode : public StmtNode {
hash_reduce(body);
}
- static constexpr const char* _type_key = "AttrStmt";
+ static constexpr const char* _type_key = "tir.AttrStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AttrStmtNode, StmtNode);
};
@@ -181,7 +181,7 @@ class AssertStmtNode : public StmtNode {
hash_reduce(body);
}
- static constexpr const char* _type_key = "AssertStmt";
+ static constexpr const char* _type_key = "tir.AssertStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode);
};
@@ -244,7 +244,7 @@ class StoreNode : public StmtNode {
hash_reduce(predicate);
}
- static constexpr const char* _type_key = "Store";
+ static constexpr const char* _type_key = "tir.Store";
TVM_DECLARE_FINAL_OBJECT_INFO(StoreNode, StmtNode);
};
@@ -295,7 +295,7 @@ class BufferStoreNode : public StmtNode {
hash_reduce(indices);
}
- static constexpr const char* _type_key = "BufferStore";
+ static constexpr const char* _type_key = "tir.BufferStore";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
};
@@ -355,7 +355,7 @@ class BufferRealizeNode : public StmtNode {
BufferRealizeNode(Buffer buffer, Array<Range> bounds, PrimExpr condition,
Stmt body)
: buffer(buffer), bounds(bounds), condition(condition), body(body) {}
- static constexpr const char* _type_key = "BufferRealize";
+ static constexpr const char* _type_key = "tir.BufferRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRealizeNode, StmtNode);
};
@@ -406,7 +406,7 @@ class ProducerStoreNode : public StmtNode {
hash_reduce(indices);
}
- static constexpr const char* _type_key = "ProducerStore";
+ static constexpr const char* _type_key = "tir.ProducerStore";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode);
};
@@ -462,7 +462,7 @@ class ProducerRealizeNode : public StmtNode {
hash_reduce(body);
}
- static constexpr const char* _type_key = "ProducerRealize";
+ static constexpr const char* _type_key = "tir.ProducerRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode);
};
@@ -529,7 +529,7 @@ class AllocateNode : public StmtNode {
*/
TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>&
extents);
- static constexpr const char* _type_key = "Allocate";
+ static constexpr const char* _type_key = "tir.Allocate";
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateNode, StmtNode);
};
@@ -559,7 +559,7 @@ class FreeNode : public StmtNode {
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(buffer_var); }
- static constexpr const char* _type_key = "Free";
+ static constexpr const char* _type_key = "tir.Free";
TVM_DECLARE_FINAL_OBJECT_INFO(FreeNode, StmtNode);
};
@@ -598,7 +598,7 @@ class SeqStmtNode : public StmtNode {
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(seq); }
- static constexpr const char* _type_key = "SeqStmt";
+ static constexpr const char* _type_key = "tir.SeqStmt";
TVM_DECLARE_FINAL_OBJECT_INFO(SeqStmtNode, StmtNode);
};
@@ -697,7 +697,7 @@ class IfThenElseNode : public StmtNode {
hash_reduce(else_case);
}
- static constexpr const char* _type_key = "IfThenElse";
+ static constexpr const char* _type_key = "tir.IfThenElse";
TVM_DECLARE_FINAL_OBJECT_INFO(IfThenElseNode, StmtNode);
};
@@ -731,7 +731,7 @@ class EvaluateNode : public StmtNode {
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(value); }
- static constexpr const char* _type_key = "Evaluate";
+ static constexpr const char* _type_key = "tir.Evaluate";
TVM_DECLARE_FINAL_OBJECT_INFO(EvaluateNode, StmtNode);
};
@@ -817,7 +817,7 @@ class ForNode : public StmtNode {
hash_reduce(body);
}
- static constexpr const char* _type_key = "For";
+ static constexpr const char* _type_key = "tir.For";
TVM_DECLARE_FINAL_OBJECT_INFO(ForNode, StmtNode);
};
@@ -860,7 +860,7 @@ class PrefetchNode : public StmtNode {
PrefetchNode() = default;
PrefetchNode(Buffer buffer, Array<Range> bounds) : buffer(buffer),
bounds(bounds) {}
- static constexpr const char* _type_key = "Prefetch";
+ static constexpr const char* _type_key = "tir.Prefetch";
TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode);
};
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
index 2a44909..f1651c1 100644
--- a/include/tvm/tir/var.h
+++ b/include/tvm/tir/var.h
@@ -266,7 +266,7 @@ class IterVarNode : public Object {
hash_reduce(thread_tag);
}
- static constexpr const char* _type_key = "IterVar";
+ static constexpr const char* _type_key = "tir.IterVar";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py
index 94e9cf3..8b75685 100644
--- a/python/tvm/ir/json_compact.py
+++ b/python/tvm/ir/json_compact.py
@@ -138,11 +138,48 @@ def create_updater_06_to_07():
# TIR
"Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")],
"SizeVar": [_update_tir_var("tir.SizeVar"),
_update_from_std_str("name")],
- "StringImm": [_update_from_std_str("value")],
- "Call": [_update_from_std_str("name")],
- "AttrStmt": [_update_from_std_str("attr_key")],
- "Layout": [_update_from_std_str("name")],
- "Buffer": [_update_from_std_str("name"),
_update_from_std_str("scope")],
+ "StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")],
+ "Cast": [_rename("tir.Cast")],
+ "Add": [_rename("tir.Add")],
+ "Sub": [_rename("tir.Sub")],
+ "Mul": [_rename("tir.Mul")],
+ "Div": [_rename("tir.Div")],
+ "Mod": [_rename("tir.Mod")],
+ "FloorDiv": [_rename("tir.FloorDiv")],
+ "FloorMod": [_rename("tir.FloorMod")],
+ "Min": [_rename("tir.Min")],
+ "Max": [_rename("tir.Max")],
+ "EQ": [_rename("tir.EQ")],
+ "NE": [_rename("tir.NE")],
+ "LT": [_rename("tir.LT")],
+ "LE": [_rename("tir.LE")],
+ "GT": [_rename("tir.GT")],
+ "GE": [_rename("tir.GE")],
+ "And": [_rename("tir.And")],
+ "Or": [_rename("tir.Or")],
+ "Not": [_rename("tir.Not")],
+ "Select": [_rename("tir.Select")],
+ "Load": [_rename("tir.Load")],
+ "BufferLoad": [_rename("tir.BufferLoad")],
+ "Ramp": [_rename("tir.Ramp")],
+ "Broadcast": [_rename("tir.Broadcast")],
+ "Shuffle": [_rename("tir.Shuffle")],
+ "Call": [_rename("tir.Call"), _update_from_std_str("name")],
+ "Let": [_rename("tir.Let")],
+ "Any": [_rename("tir.Any")],
+ "LetStmt": [_rename("tir.LetStmt")],
+ "AssertStmt": [_rename("tir.AssertStmt")],
+ "Store": [_rename("tir.Store")],
+ "BufferStore": [_rename("tir.BufferStore")],
+ "BufferRealize": [_rename("tir.BufferRealize")],
+ "Allocate": [_rename("tir.Allocate")],
+ "IfThenElse": [_rename("tir.IfThenElse")],
+ "Evaluate": [_rename("tir.Evaluate")],
+ "Prefetch": [_rename("tir.Prefetch")],
+ "AttrStmt": [_rename("tir.AttrStmt"),
_update_from_std_str("attr_key")],
+ "Layout": [_rename("tir.Layout"), _update_from_std_str("name")],
+ "Buffer": [
+ _rename("tir.Buffer"), _update_from_std_str("name"),
_update_from_std_str("scope")],
}
return create_updater(node_map, "0.6", "0.7")
diff --git a/python/tvm/te/hybrid/util.py b/python/tvm/te/hybrid/util.py
index 810509b..891d7ba 100644
--- a/python/tvm/te/hybrid/util.py
+++ b/python/tvm/te/hybrid/util.py
@@ -83,7 +83,7 @@ def replace_io(body, rmap):
return _expr.ProducerLoad(buf, op.indices)
return None
- return stmt_functor.ir_transform(body, None, replace, ['ProducerStore',
'ProducerLoad'])
+ return stmt_functor.ir_transform(body, None, replace,
['tir.ProducerStore', 'tir.ProducerLoad'])
def _is_tvm_arg_types(args):
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index e4dec5f..11bfb4c 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -24,7 +24,7 @@ from tvm.ir import PrimExpr
from . import _ffi_api
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Buffer")
class Buffer(Object):
"""Symbolic data buffer in TVM.
@@ -247,6 +247,6 @@ def decl_buffer(shape,
data_alignment, offset_factor, buffer_type)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.DataProducer")
class DataProducer(Object):
pass
diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py
index fd8c7a9..1616473 100644
--- a/python/tvm/tir/data_layout.py
+++ b/python/tvm/tir/data_layout.py
@@ -20,7 +20,7 @@ import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Layout")
class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
@@ -77,7 +77,7 @@ class Layout(Object):
return _ffi_api.LayoutFactorOf(self, axis)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BijectiveLayout")
class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index d55370e..f8cb054 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -321,7 +321,7 @@ class SizeVar(Var):
_ffi_api.SizeVar, name, dtype)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.IterVar")
class IterVar(Object, ExprOp):
"""Represent iteration variable.
@@ -373,7 +373,7 @@ class IterVar(Object, ExprOp):
_ffi_api.IterVar, dom, var, iter_type, thread_tag)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.CommReducer")
class CommReducer(Object):
"""Communicative reduce operator
@@ -396,7 +396,7 @@ class CommReducer(Object):
_ffi_api.CommReducer, lhs, rhs, result, identity_element)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Reduce")
class Reduce(PrimExprWithOp):
"""Reduce node.
@@ -475,7 +475,7 @@ class IntImm(ConstExpr):
return self.__nonzero__()
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.StringImm")
class StringImm(ConstExpr):
"""String constant.
@@ -499,7 +499,7 @@ class StringImm(ConstExpr):
return self.value != other
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Cast")
class Cast(PrimExprWithOp):
"""Cast expression.
@@ -516,7 +516,7 @@ class Cast(PrimExprWithOp):
_ffi_api.Cast, dtype, value)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Add")
class Add(BinaryOpExpr):
"""Add node.
@@ -533,7 +533,7 @@ class Add(BinaryOpExpr):
_ffi_api.Add, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Sub")
class Sub(BinaryOpExpr):
"""Sub node.
@@ -550,7 +550,7 @@ class Sub(BinaryOpExpr):
_ffi_api.Sub, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Mul")
class Mul(BinaryOpExpr):
"""Mul node.
@@ -567,7 +567,7 @@ class Mul(BinaryOpExpr):
_ffi_api.Mul, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Div")
class Div(BinaryOpExpr):
"""Div node.
@@ -584,7 +584,7 @@ class Div(BinaryOpExpr):
_ffi_api.Div, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Mod")
class Mod(BinaryOpExpr):
"""Mod node.
@@ -601,7 +601,7 @@ class Mod(BinaryOpExpr):
_ffi_api.Mod, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.FloorDiv")
class FloorDiv(BinaryOpExpr):
"""FloorDiv node.
@@ -618,7 +618,7 @@ class FloorDiv(BinaryOpExpr):
_ffi_api.FloorDiv, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.FloorMod")
class FloorMod(BinaryOpExpr):
"""FloorMod node.
@@ -635,7 +635,7 @@ class FloorMod(BinaryOpExpr):
_ffi_api.FloorMod, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Min")
class Min(BinaryOpExpr):
"""Min node.
@@ -652,7 +652,7 @@ class Min(BinaryOpExpr):
_ffi_api.Min, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Max")
class Max(BinaryOpExpr):
"""Max node.
@@ -669,7 +669,7 @@ class Max(BinaryOpExpr):
_ffi_api.Max, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.EQ")
class EQ(CmpExpr):
"""EQ node.
@@ -686,7 +686,7 @@ class EQ(CmpExpr):
_ffi_api.EQ, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.NE")
class NE(CmpExpr):
"""NE node.
@@ -703,7 +703,7 @@ class NE(CmpExpr):
_ffi_api.NE, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.LT")
class LT(CmpExpr):
"""LT node.
@@ -720,7 +720,7 @@ class LT(CmpExpr):
_ffi_api.LT, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.LE")
class LE(CmpExpr):
"""LE node.
@@ -737,7 +737,7 @@ class LE(CmpExpr):
_ffi_api.LE, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.GT")
class GT(CmpExpr):
"""GT node.
@@ -754,7 +754,7 @@ class GT(CmpExpr):
_ffi_api.GT, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.GE")
class GE(CmpExpr):
"""GE node.
@@ -771,7 +771,7 @@ class GE(CmpExpr):
_ffi_api.GE, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.And")
class And(LogicalExpr):
"""And node.
@@ -788,7 +788,7 @@ class And(LogicalExpr):
_ffi_api.And, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Or")
class Or(LogicalExpr):
"""Or node.
@@ -805,7 +805,7 @@ class Or(LogicalExpr):
_ffi_api.Or, a, b)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Not")
class Not(LogicalExpr):
"""Not node.
@@ -819,7 +819,7 @@ class Not(LogicalExpr):
_ffi_api.Not, a)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Select")
class Select(PrimExprWithOp):
"""Select node.
@@ -847,7 +847,7 @@ class Select(PrimExprWithOp):
_ffi_api.Select, condition, true_value, false_value)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Load")
class Load(PrimExprWithOp):
"""Load node.
@@ -871,7 +871,7 @@ class Load(PrimExprWithOp):
_ffi_api.Load, dtype, buffer_var, index, *args)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BufferLoad")
class BufferLoad(PrimExprWithOp):
"""Buffer load node.
@@ -888,7 +888,7 @@ class BufferLoad(PrimExprWithOp):
_ffi_api.BufferLoad, buffer, indices)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.ProducerLoad")
class ProducerLoad(PrimExprWithOp):
"""Producer load node.
@@ -905,7 +905,7 @@ class ProducerLoad(PrimExprWithOp):
_ffi_api.ProducerLoad, producer, indices)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Ramp")
class Ramp(PrimExprWithOp):
"""Ramp node.
@@ -925,7 +925,7 @@ class Ramp(PrimExprWithOp):
_ffi_api.Ramp, base, stride, lanes)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Broadcast")
class Broadcast(PrimExprWithOp):
"""Broadcast node.
@@ -942,7 +942,7 @@ class Broadcast(PrimExprWithOp):
_ffi_api.Broadcast, value, lanes)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Shuffle")
class Shuffle(PrimExprWithOp):
"""Shuffle node.
@@ -959,7 +959,7 @@ class Shuffle(PrimExprWithOp):
_ffi_api.Shuffle, vectors, indices)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Call")
class Call(PrimExprWithOp):
"""Call node.
@@ -987,7 +987,7 @@ class Call(PrimExprWithOp):
_ffi_api.Call, dtype, name, args, call_type)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Let")
class Let(PrimExprWithOp):
"""Let node.
@@ -1007,7 +1007,7 @@ class Let(PrimExprWithOp):
_ffi_api.Let, var, value, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Any")
class Any(PrimExpr):
"""Any node.
"""
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index f4d8471..4536580 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -36,7 +36,7 @@ class Stmt(Object):
"""Base class of all the statements."""
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.LetStmt")
class LetStmt(Stmt):
"""LetStmt node.
@@ -56,7 +56,7 @@ class LetStmt(Stmt):
_ffi_api.LetStmt, var, value, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.AssertStmt")
class AssertStmt(Stmt):
"""AssertStmt node.
@@ -76,7 +76,7 @@ class AssertStmt(Stmt):
_ffi_api.AssertStmt, condition, message, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.For")
class For(Stmt):
"""For node.
@@ -116,7 +116,7 @@ class For(Stmt):
for_type, device_api, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Store")
class Store(Stmt):
"""Store node.
@@ -140,7 +140,7 @@ class Store(Stmt):
_ffi_api.Store, buffer_var, value, index, *args)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BufferStore")
class BufferStore(Stmt):
"""Buffer store node.
@@ -160,7 +160,7 @@ class BufferStore(Stmt):
_ffi_api.BufferStore, buffer, value, indices)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.BufferRealize")
class BufferRealize(Stmt):
"""Buffer realize node.
@@ -183,7 +183,7 @@ class BufferRealize(Stmt):
_ffi_api.BufferRealize, buffer, bounds, condition, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.ProducerStore")
class ProducerStore(Stmt):
"""ProducerStore node.
@@ -203,7 +203,7 @@ class ProducerStore(Stmt):
_ffi_api.ProducerStore, producer, value, indices)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Allocate")
class Allocate(Stmt):
"""Allocate node.
@@ -235,7 +235,7 @@ class Allocate(Stmt):
extents, condition, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.AttrStmt")
class AttrStmt(Stmt):
"""AttrStmt node.
@@ -258,7 +258,7 @@ class AttrStmt(Stmt):
_ffi_api.AttrStmt, node, attr_key, value, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Free")
class Free(Stmt):
"""Free node.
@@ -272,7 +272,7 @@ class Free(Stmt):
_ffi_api.Free, buffer_var)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.ProducerRealize")
class ProducerRealize(Stmt):
"""ProducerRealize node.
@@ -299,7 +299,7 @@ class ProducerRealize(Stmt):
_ffi_api.ProducerRealize, producer, bounds, condition, body)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.SeqStmt")
class SeqStmt(Stmt):
"""Sequence of statements.
@@ -319,7 +319,7 @@ class SeqStmt(Stmt):
return len(self.seq)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.IfThenElse")
class IfThenElse(Stmt):
"""IfThenElse node.
@@ -339,7 +339,7 @@ class IfThenElse(Stmt):
_ffi_api.IfThenElse, condition, then_case, else_case)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Evaluate")
class Evaluate(Stmt):
"""Evaluate node.
@@ -353,7 +353,7 @@ class Evaluate(Stmt):
_ffi_api.Evaluate, value)
-@tvm._ffi.register_object
+@tvm._ffi.register_object("tir.Prefetch")
class Prefetch(Stmt):
"""Prefetch node.
diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc
index 868845f..d1e24b9 100644
--- a/src/tir/pass/hoist_if_then_else.cc
+++ b/src/tir/pass/hoist_if_then_else.cc
@@ -159,7 +159,7 @@ Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
}
});
- return IRTransform(parent_for_stmt, nullptr, replace_target_for,
Array<String>{"For"});
+ return IRTransform(parent_for_stmt, nullptr, replace_target_for,
Array<String>{"tir.For"});
}
// Remove IfThenElse node from a For node.
@@ -183,9 +183,9 @@ std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
}
});
- then_for = IRTransform(for_stmt, nullptr, replace_then_case,
Array<String>{"IfThenElse"});
+ then_for = IRTransform(for_stmt, nullptr, replace_then_case,
Array<String>{"tir.IfThenElse"});
if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
- else_for = IRTransform(for_stmt, nullptr, replace_else_case,
Array<String>{"IfThenElse"});
+ else_for = IRTransform(for_stmt, nullptr, replace_else_case,
Array<String>{"tir.IfThenElse"});
}
return std::make_pair(then_for, else_for);
@@ -393,7 +393,7 @@ Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
*ret = new_for;
}
});
- return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"For"});
+ return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"tir.For"});
}
Stmt HoistIfThenElse(Stmt stmt) { return
IfThenElseHoist().VisitAndMutate(stmt); }
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index bafa957..1a7163f 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -214,7 +214,7 @@ def test_cuda_shuffle():
def _transform(f, *_):
return f.with_body(
- tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer,
['For']))
+ tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer,
['tir.For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0,
name="MyVectorize")
with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1,
MyVectorize())]}):
diff --git a/tests/python/unittest/test_target_codegen_llvm.py
b/tests/python/unittest/test_target_codegen_llvm.py
index 34db08f..1173b71 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -724,7 +724,7 @@ def test_llvm_shuffle():
def _transform(f, *_):
return f.with_body(
- tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer,
['For']))
+ tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer,
['tir.For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0,
name="my_vectorize")
diff --git a/tests/python/unittest/test_tir_pass_hoist_if.py
b/tests/python/unittest/test_tir_pass_hoist_if.py
index 346239d..80e93a7 100644
--- a/tests/python/unittest/test_tir_pass_hoist_if.py
+++ b/tests/python/unittest/test_tir_pass_hoist_if.py
@@ -33,12 +33,12 @@ def verify_structure(stmt, expected_struct):
if isinstance(op, tvm.tir.IfThenElse):
global var_list
tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars)
- val = [(op.then_case, op.else_case), ("IfThenElse",
tuple(var_list))]
+ val = [(op.then_case, op.else_case), ("tir.IfThenElse",
tuple(var_list))]
var_list.clear()
elif isinstance(op, tvm.tir.For):
- val = [(op.body,), ("For", op.loop_var.name)]
+ val = [(op.body,), ("tir.For", op.loop_var.name)]
elif isinstance(op, tvm.tir.AttrStmt):
- val = [(op.body,), ("AttrStmt", op.attr_key, int(op.value))]
+ val = [(op.body,), ("tir.AttrStmt", op.attr_key, int(op.value))]
else:
return
node_dict[key] = val
@@ -68,9 +68,9 @@ def test_basic():
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
- expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),),
- ('IfThenElse', ('i',)): (('For', 'j'), ('For', 'j')),
- ('For', 'i'): (('IfThenElse', ('i',)),)}
+ expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'):
(('tir.For', 'k'),),
+ ('tir.IfThenElse', ('i',)): (('tir.For', 'j'),
('tir.For', 'j')),
+ ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
def test_no_else():
@@ -87,9 +87,9 @@ def test_no_else():
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
- expected_struct = {('For', 'k'): (None,), ('For', 'j'): (('For', 'k'),),
- ('IfThenElse', ('i',)): (('For', 'j'), None),
- ('For', 'i'): (('IfThenElse', ('i',)),)}
+ expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'):
(('tir.For', 'k'),),
+ ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
+ ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
def test_attr_stmt():
@@ -114,10 +114,10 @@ def test_attr_stmt():
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
- expected_struct = {('For', 'k'): (None,), ('IfThenElse', ('i', 'j')):
(('For', 'k'), ('For', 'k')),
- ('For', 'j'): (('IfThenElse', ('i', 'j')),), ('For',
'i'): (('For', 'j'),),
- ('AttrStmt', 'thread_extent', 64): (('For', 'i'),),
- ('AttrStmt', 'thread_extent', 32): (('AttrStmt',
'thread_extent', 64),)}
+ expected_struct = {('tir.For', 'k'): (None,), ('tir.IfThenElse', ('i',
'j')): (('tir.For', 'k'), ('tir.For', 'k')),
+ ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),),
('tir.For', 'i'): (('tir.For', 'j'),),
+ ('tir.AttrStmt', 'thread_extent', 64): (('tir.For',
'i'),),
+ ('tir.AttrStmt', 'thread_extent', 32):
(('tir.AttrStmt', 'thread_extent', 64),)}
verify_structure(new_stmt, expected_struct)
def test_nested_for():
@@ -138,9 +138,9 @@ def test_nested_for():
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
- expected_struct = {('IfThenElse', ('i', 'j')): (None, None), ('For', 'l'):
(('IfThenElse', ('i', 'j')),),
- ('For', 'k'): (('For', 'l'),), ('For', 'j'): (None,),
('IfThenElse', ('i',)): (('For', 'j'), None),
- ('For', 'i'): (('IfThenElse', ('i',)),)}
+ expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None),
('tir.For', 'l'): (('tir.IfThenElse', ('i', 'j')),),
+ ('tir.For', 'k'): (('tir.For', 'l'),), ('tir.For',
'j'): (None,), ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
+ ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
verify_structure(new_stmt, expected_struct)
def test_if_block():
@@ -171,10 +171,10 @@ def test_if_block():
stmt = ib.get()
new_stmt = tvm.testing.HoistIfThenElse(stmt)
- expected_struct = {('IfThenElse', ('i', 'j')): (None, None),
('IfThenElse', ('j',)): (None, None),
- ('For', 'l'): (None,), ('For', 'k'): (None,), ('For',
'j'): (('For', 'j'),),
- ('IfThenElse', ('i',)): (('For', 'j'), None), ('For',
'i'): (('IfThenElse', ('i',)),),
- ('IfThenElse', ('n',)): (('For', 'j'), None)}
+ expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None),
('tir.IfThenElse', ('j',)): (None, None),
+ ('tir.For', 'l'): (None,), ('tir.For', 'k'): (None,),
('tir.For', 'j'): (('tir.For', 'j'),),
+ ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
('tir.For', 'i'): (('tir.IfThenElse', ('i',)),),
+ ('tir.IfThenElse', ('n',)): (('tir.For', 'j'), None)}
verify_structure(new_stmt, expected_struct)
diff --git a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py
b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py
index 7bf7011..38529e9 100644
--- a/tests/python/unittest/test_tir_stmt_functor_ir_transform.py
+++ b/tests/python/unittest/test_tir_stmt_functor_ir_transform.py
@@ -37,7 +37,7 @@ def test_ir_transform():
if op.name == "TestA":
return tvm.tir.call_extern("int32", "TestB", op.args[0] + 1)
return op
- body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder,
["Call"])
+ body = tvm.tir.stmt_functor.ir_transform(body, preorder, postorder,
["tir.Call"])
stmt_list = tvm.tir.stmt_list(body.body.body)
assert stmt_list[0].value.args[0].name == "TestB"
assert stmt_list[1].value.value == 0
diff --git a/tutorials/dev/low_level_custom_pass.py
b/tutorials/dev/low_level_custom_pass.py
index db50572..17f864f 100644
--- a/tutorials/dev/low_level_custom_pass.py
+++ b/tutorials/dev/low_level_custom_pass.py
@@ -84,7 +84,7 @@ print(ir)
loops = []
def find_width8(op):
- """ Find all the 'For' nodes whose extent can be divided by 8. """
+ """ Find all the 'tir.For' nodes whose extent can be divided by 8. """
if isinstance(op, tvm.tir.For):
if isinstance(op.extent, tvm.tir.IntImm):
if op.extent.value % 8 == 0:
@@ -129,7 +129,7 @@ def vectorize(f, mod, ctx):
# The last list arugment indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
return f.with_body(
- tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ['For']))
+ tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8,
['tir.For']))
#####################################################################
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index 37b4e0e..207f784 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -87,7 +87,7 @@ def FoldUopLoop():
return op
ret = tvm.tir.stmt_functor.ir_transform(
- stmt.body, None, _post_order, ["Call"])
+ stmt.body, None, _post_order, ["tir.Call"])
if not fail[0] and all(x is not None for x in gemm_offsets):
def _visit(op):
@@ -132,7 +132,7 @@ def FoldUopLoop():
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.stmt_functor.ir_transform(
- f.body, _do_fold, None, ["AttrStmt"]))
+ f.body, _do_fold, None, ["tir.AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
@@ -188,7 +188,7 @@ def CPUAccessRewrite():
stmt_in = f.body
stmt = tvm.tir.stmt_functor.ir_transform(
- stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
+ stmt_in, None, _post_order, ["tir.Allocate", "tir.Load",
"tir.Store"])
for buffer_var, new_var in rw_info.items():
stmt = tvm.tir.LetStmt(
@@ -254,7 +254,7 @@ def LiftAllocToScopeBegin():
raise RuntimeError("not reached")
stmt_in = f.body
stmt = tvm.tir.stmt_functor.ir_transform(
- stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
+ stmt_in, _pre_order, _post_order, ["tir.Allocate", "tir.AttrStmt",
"tir.For"])
assert len(lift_stmt) == 1
return f.with_body(_merge_block(lift_stmt[0], stmt))
@@ -277,7 +277,7 @@ def InjectSkipCopy():
def _ftransform(f, mod, ctx):
return f.with_body(tvm.tir.stmt_functor.ir_transform(
- f.body, _do_fold, None, ["AttrStmt"]))
+ f.body, _do_fold, None, ["tir.AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
@@ -307,7 +307,7 @@ def InjectCoProcSync():
op.device_api, op.body)
return None
return f.with_body(tvm.tir.stmt_functor.ir_transform(
- f.body, None, _do_fold, ["AttrStmt"]))
+ f.body, None, _do_fold, ["tir.AttrStmt"]))
return tvm.transform.Sequential(
[tvm.tir.transform.prim_func_pass(_ftransform, 0,
"tir.vta.InjectCoProcSync"),
tvm.tir.transform.CoProcSync()],
@@ -708,7 +708,7 @@ def InjectConv2DTransposeSkip():
return None
return func.with_body(tvm.tir.stmt_functor.ir_transform(
- func.body, _do_fold, None, ["AttrStmt"]))
+ func.body, _do_fold, None, ["tir.AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
@@ -737,7 +737,7 @@ def AnnotateALUCoProcScope():
return stmt
return func.with_body(tvm.tir.stmt_functor.ir_transform(
- func.body, None, _do_fold, ["AttrStmt"]))
+ func.body, None, _do_fold, ["tir.AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
@@ -956,7 +956,7 @@ def InjectALUIntrin():
return stmt
return func.with_body(tvm.tir.stmt_functor.ir_transform(
- func.body, None, _do_fold, ["AttrStmt"]))
+ func.body, None, _do_fold, ["tir.AttrStmt"]))
return tvm.tir.transform.prim_func_pass(
_ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")