This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c8140643d3 [REFACTOR][S-TIR] Move remaining data structures to s_tir
(#18743)
c8140643d3 is described below
commit c8140643d33917ce44ebfdcd379e1b882c5f8f9d
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Feb 10 18:24:46 2026 -0500
[REFACTOR][S-TIR] Move remaining data structures to s_tir (#18743)
This PR moves remaining related data structures to s_tir.
- Moves sblock_dependence_info, sblock_scope, data_layout and utils.
- Moves related analyses.
- Makes data_type_rewriter private by moving its header from include/tvm/tir to src/tir/ir.
---
include/tvm/s_tir/analysis.h | 88 +++++++++++++
include/tvm/{tir => s_tir}/data_layout.h | 12 +-
.../sblock_dependence_info.h} | 38 +++---
.../{tir/block_scope.h => s_tir/sblock_scope.h} | 14 +--
include/tvm/s_tir/schedule/state.h | 2 +-
include/tvm/{tir => s_tir}/utils.h | 8 +-
include/tvm/tir/analysis.h | 53 +-------
include/tvm/topi/einsum.h | 1 -
include/tvm/topi/transform.h | 2 +-
python/tvm/relax/transform/legalize_ops/nn.py | 20 +--
python/tvm/s_tir/__init__.py | 4 +-
python/tvm/s_tir/_ffi_api.py | 4 +-
python/tvm/s_tir/analysis/__init__.py | 120 ++++++++++++++++++
python/tvm/s_tir/{ => analysis}/_ffi_api.py | 4 +-
python/tvm/s_tir/block_dependence_info.py | 6 +-
python/tvm/{tir => s_tir}/data_layout.py | 4 +-
...ependence_info.py => sblock_dependence_info.py} | 6 +-
.../tvm/s_tir/{block_scope.py => sblock_scope.py} | 6 +-
python/tvm/s_tir/schedule/__init__.py | 2 +-
python/tvm/s_tir/schedule/state.py | 2 +-
python/tvm/tir/__init__.py | 1 -
python/tvm/tir/analysis/analysis.py | 97 +-------------
python/tvm/topi/utils.py | 3 +-
src/contrib/msc/core/ir/graph.h | 2 +-
src/contrib/msc/core/ir/graph_builder.h | 2 +-
src/contrib/msc/core/ir/plugin.h | 2 +-
src/relax/op/op_common.h | 2 +-
src/relax/transform/infer_amp_utils.h | 1 -
src/relax/transform/infer_layout_utils.h | 2 +-
.../analysis/find_anchor_sblock.cc} | 49 ++------
.../analysis/sblock_access_region_detector.cc} | 10 +-
.../analysis/sblock_buffer_access_lca_detector.cc} | 4 +-
src/{tir/ir => s_tir}/data_layout.cc | 22 ++--
.../sblock_dependence_info.cc} | 40 +++---
src/s_tir/schedule/primitive/blockize_tensorize.cc | 2 +-
src/s_tir/schedule/utils.h | 2 +-
src/te/operation/create_primfunc.cc | 2 +-
src/tir/analysis/stmt_finding.cc | 87 -------------
src/tir/ir/block_scope.cc | 16 +--
src/tir/ir/data_type_rewriter.cc | 6 +-
.../tvm/tir => src/tir/ir}/data_type_rewriter.h | 6 +-
src/tir/ir/stmt_functor.cc | 2 +-
src/tir/transform/force_narrow_index_to_i32.cc | 3 +-
src/tir/transform/narrow_datatype.cc | 2 +-
tests/cpp/data_type_rewriter_test.cc | 140 ---------------------
.../analysis/test_sblock_access_region.py} | 40 +++---
.../analysis/test_sblock_buffer_access_lca.py} | 12 +-
...ence_info.py => test_sblock_dependence_info.py} | 2 +-
.../base}/test_tir_data_layout.py | 34 ++---
.../base}/test_tir_te_extern_primfunc.py | 0
tests/scripts/task_python_unittest.sh | 1 +
51 files changed, 400 insertions(+), 590 deletions(-)
diff --git a/include/tvm/s_tir/analysis.h b/include/tvm/s_tir/analysis.h
new file mode 100644
index 0000000000..c9a00648af
--- /dev/null
+++ b/include/tvm/s_tir/analysis.h
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/s_tir/analysis.h
+ * \brief Analysis utilities for Schedulable TensorIR (S-TIR).
+ */
+#ifndef TVM_S_TIR_ANALYSIS_H_
+#define TVM_S_TIR_ANALYSIS_H_
+
+#include <tvm/ir/module.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt.h>
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief Auto detect the block access region according to its body stmt
+ * It will detect the access region as an array in order of appearance
in AST
+ * \param block The block to be detected
+ * \param buffer_var_map The outside buffers which may be accessed the block.
+ * It is a map from buffer var to the buffer.
+ * \return Array of access regions.
+ * There are three arrays of BufferRegion:
+ * - first: read regions
+ * - second: write regions
+ * - third: opaque regions
+ */
+TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetSBlockAccessRegion(
+ const SBlock& block, const ffi::Map<Var, Buffer>& buffer_var_map);
+
+/*!
+ * \brief Auto detect the block read/write region according to its body stmt.
An opaque access will
+ * be counted as both a read and a write access
+ * \param block The block to be detected
+ * \param buffer_var_map The outside buffers which may be accessed the block.
+ * It is a map from buffer var to the buffer
+ * \return An array only consisting of the read regions and write regions of
the input block
+ */
+TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetSBlockReadWriteRegion(
+ const SBlock& block, const ffi::Map<Var, Buffer>& buffer_var_map);
+
+/*!
+ * \brief Detect the lowest common ancestor(LCA) of buffer access, including
both high-level
+ * access(BufferLoad, BufferStore) and low-level access(Load, Store and
opaque access).
+ * The LCA may be a For loop or a Block.
+ * \param func The PrimFunc to be detected.
+ * \return The Map from buffer to the LCA of all access to it. The lca is
function root if the
+ * return stmt is std::nullopt.
+ */
+TVM_DLL ffi::Map<Buffer, ffi::Optional<Stmt>> DetectBufferAccessLCA(const
PrimFunc& func);
+
+/*!
+ * \brief Find the "anchor block" of the given module.
+ * We define the anchor block to be the block with (1) an init statement and
(2) having
+ * the biggest flops count. The latter condition is only used when there are
multiple blocks
+ * with an init statement.
+ * For example, if the input module is conv2d + fused spatial blocks, conv2d
is the anchor block.
+ * The input module may not contain more than one such block. For example, a
module having
+ * two conv2d is not allowed as an input.
+ * However, a module created from winograd convolution has multiple blocks
with an init statement
+ * (input transform, batched GEMM, and output transform). We use the second
condition, the flops
+ * count, to determine that the batched GEMM block is the anchor block.
+ * \param mod The input TIR module.
+ * \return The anchor block if found, nullptr otherwise.
+ */
+const tir::SBlockNode* FindAnchorBlock(const IRModule& mod);
+
+} // namespace tir
+} // namespace tvm
+#endif // TVM_S_TIR_ANALYSIS_H_
diff --git a/include/tvm/tir/data_layout.h b/include/tvm/s_tir/data_layout.h
similarity index 97%
rename from include/tvm/tir/data_layout.h
rename to include/tvm/s_tir/data_layout.h
index 4f2a4452b8..5bdad33ba0 100644
--- a/include/tvm/tir/data_layout.h
+++ b/include/tvm/s_tir/data_layout.h
@@ -18,12 +18,12 @@
*/
/*!
- * \file tvm/tir/data_layout.h
+ * \file tvm/s_tir/data_layout.h
* \brief Layout expression to describe the data organization of a tensor.
* And BijectiveLayout to mapping two data layouts between each other.
*/
-#ifndef TVM_TIR_DATA_LAYOUT_H_
-#define TVM_TIR_DATA_LAYOUT_H_
+#ifndef TVM_S_TIR_DATA_LAYOUT_H_
+#define TVM_S_TIR_DATA_LAYOUT_H_
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/expr.h>
@@ -114,7 +114,7 @@ class LayoutNode : public Object {
.def_ro("name", &LayoutNode::name)
.def_ro("axes", &LayoutNode::axes);
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Layout", LayoutNode, Object);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Layout", LayoutNode, Object);
};
/*!
@@ -321,7 +321,7 @@ class BijectiveLayoutNode : public Object {
.def_ro("shape_forward_rule", &BijectiveLayoutNode::shape_forward_rule)
.def_ro("shape_backward_rule",
&BijectiveLayoutNode::shape_backward_rule);
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.BijectiveLayout",
BijectiveLayoutNode, Object);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.BijectiveLayout",
BijectiveLayoutNode, Object);
};
/*!
@@ -354,4 +354,4 @@ class BijectiveLayout : public ObjectRef {
} // namespace tir
} // namespace tvm
-#endif // TVM_TIR_DATA_LAYOUT_H_
+#endif // TVM_S_TIR_DATA_LAYOUT_H_
diff --git a/include/tvm/tir/block_dependence_info.h
b/include/tvm/s_tir/sblock_dependence_info.h
similarity index 72%
rename from include/tvm/tir/block_dependence_info.h
rename to include/tvm/s_tir/sblock_dependence_info.h
index 2e56058eff..e1ec8b9588 100644
--- a/include/tvm/tir/block_dependence_info.h
+++ b/include/tvm/s_tir/sblock_dependence_info.h
@@ -17,10 +17,10 @@
* under the License.
*/
/*!
- * \file tvm/tir/block_dependence_info.h
- * \brief Define BlockDependenceInfoNode that uses the SBlockScope and
StmtSRef objects to
+ * \file tvm/s_tir/sblock_dependence_info.h
+ * \brief Define SBlockDependenceInfoNode that uses the SBlockScope and
StmtSRef objects to
* store the block level dependences
- * \sa BlockDependenceInfoNode
+ * \sa SBlockDependenceInfoNode
*/
/**
@@ -28,11 +28,11 @@
* analysis
*/
-#ifndef TVM_TIR_BLOCK_DEPENDENCE_INFO_H_
-#define TVM_TIR_BLOCK_DEPENDENCE_INFO_H_
+#ifndef TVM_S_TIR_SBLOCK_DEPENDENCE_INFO_H_
+#define TVM_S_TIR_SBLOCK_DEPENDENCE_INFO_H_
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/tir/block_scope.h>
+#include <tvm/s_tir/sblock_scope.h>
#include <unordered_map>
@@ -51,7 +51,7 @@ namespace tir {
* dependences. This provides the advantage that the scope block (parent
block) for a given block
* sref can be directly accessed using the sref->parent member
*/
-class BlockDependenceInfoNode : public Object {
+class SBlockDependenceInfoNode : public Object {
public:
/*!
* \brief Mapping from a block sref to its corresponding SBlockScope,
@@ -63,9 +63,9 @@ class BlockDependenceInfoNode : public Object {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
- refl::ObjectDef<BlockDependenceInfoNode>();
+ refl::ObjectDef<SBlockDependenceInfoNode>();
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockDependenceInfo",
BlockDependenceInfoNode, Object);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SBlockDependenceInfo",
SBlockDependenceInfoNode, Object);
/*!
* \brief Get the SBlockScope corresponding to the sref of scope root block
@@ -82,23 +82,23 @@ class BlockDependenceInfoNode : public Object {
};
/*!
- * \brief Managed reference to BlockDependenceInfoNode
- * \sa BlockDependenceInfo
+ * \brief Managed reference to SBlockDependenceInfoNode
+ * \sa SBlockDependenceInfo
*/
-class BlockDependenceInfo : public ObjectRef {
- /*! \brief Construct an empty BlockDependenceInfo
+class SBlockDependenceInfo : public ObjectRef {
+ /*! \brief Construct an empty SBlockDependenceInfo
*/
- TVM_DLL BlockDependenceInfo();
+ TVM_DLL SBlockDependenceInfo();
public:
- /*! \brief Construct a BlockDependenceInfo from IRModule
+ /*! \brief Construct a SBlockDependenceInfo from IRModule
*/
- TVM_DLL BlockDependenceInfo(IRModule mod);
+ TVM_DLL SBlockDependenceInfo(IRModule mod);
- TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(BlockDependenceInfo, ObjectRef,
- BlockDependenceInfoNode);
+ TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SBlockDependenceInfo,
ObjectRef,
+ SBlockDependenceInfoNode);
};
} // namespace tir
} // namespace tvm
-#endif // TVM_TIR_BLOCK_DEPENDENCE_INFO_H_
+#endif // TVM_S_TIR_SBLOCK_DEPENDENCE_INFO_H_
diff --git a/include/tvm/tir/block_scope.h b/include/tvm/s_tir/sblock_scope.h
similarity index 96%
rename from include/tvm/tir/block_scope.h
rename to include/tvm/s_tir/sblock_scope.h
index d356643cda..a302cab260 100644
--- a/include/tvm/tir/block_scope.h
+++ b/include/tvm/s_tir/sblock_scope.h
@@ -17,13 +17,13 @@
* under the License.
*/
/*!
- * \file tvm/tir/block_scope.h
+ * \file tvm/s_tir/sblock_scope.h
* \brief Definition of two pillar data structure for TensorIR scheduling:
StmtSRef, SBlockScope.
* \sa StmtSRefNode
* \sa SBlockScopeNode
*/
-#ifndef TVM_TIR_BLOCK_SCOPE_H_
-#define TVM_TIR_BLOCK_SCOPE_H_
+#ifndef TVM_S_TIR_SBLOCK_SCOPE_H_
+#define TVM_S_TIR_SBLOCK_SCOPE_H_
#include <tvm/ir/module.h>
#include <tvm/tir/function.h>
@@ -73,7 +73,7 @@ class StmtSRefNode : public Object {
}
static constexpr const bool _type_mutable = true;
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.StmtSRef", StmtSRefNode, Object);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.StmtSRef", StmtSRefNode, Object);
/*! \brief Reset the object inplace to the invalid state */
void Reset() {
@@ -223,7 +223,7 @@ class DependencyNode : public Object {
.def_ro("dst", &DependencyNode::dst)
.def_ro("kind", &DependencyNode::kind);
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.Dependency", DependencyNode, Object);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.Dependency", DependencyNode,
Object);
};
/*!
@@ -267,7 +267,7 @@ class SBlockScopeNode : public Object {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<SBlockScopeNode>();
}
- TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tir.SBlockScope", SBlockScopeNode,
Object);
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("s_tir.SBlockScope", SBlockScopeNode,
Object);
public:
/******** Dependency ********/
@@ -314,4 +314,4 @@ class SBlockScope : public ObjectRef {
} // namespace tir
} // namespace tvm
-#endif // TVM_TIR_BLOCK_SCOPE_H_
+#endif // TVM_S_TIR_SBLOCK_SCOPE_H_
diff --git a/include/tvm/s_tir/schedule/state.h
b/include/tvm/s_tir/schedule/state.h
index 821125037c..03d7ddbdb3 100644
--- a/include/tvm/s_tir/schedule/state.h
+++ b/include/tvm/s_tir/schedule/state.h
@@ -25,7 +25,7 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
-#include <tvm/tir/block_scope.h>
+#include <tvm/s_tir/sblock_scope.h>
#include <tvm/tir/function.h>
#include <unordered_map>
diff --git a/include/tvm/tir/utils.h b/include/tvm/s_tir/utils.h
similarity index 98%
rename from include/tvm/tir/utils.h
rename to include/tvm/s_tir/utils.h
index a62b136219..bedcb372d3 100644
--- a/include/tvm/tir/utils.h
+++ b/include/tvm/s_tir/utils.h
@@ -16,10 +16,10 @@
* specific language governing permissions and limitations
* under the License.
*/
-#ifndef TVM_TIR_UTILS_H_
-#define TVM_TIR_UTILS_H_
+#ifndef TVM_S_TIR_UTILS_H_
+#define TVM_S_TIR_UTILS_H_
-#include <tvm/tir/block_scope.h>
+#include <tvm/s_tir/sblock_scope.h>
#include <tvm/tir/stmt.h>
#include <unordered_map>
@@ -138,4 +138,4 @@ inline void SetSeqIndexInChildren(
} // namespace tir
} // namespace tvm
-#endif // TVM_TIR_UTILS_H_
+#endif // TVM_S_TIR_UTILS_H_
diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h
index 8e71d5a1a5..7953396abb 100644
--- a/include/tvm/tir/analysis.h
+++ b/include/tvm/tir/analysis.h
@@ -26,6 +26,7 @@
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
+#include <tvm/s_tir/analysis.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
@@ -221,32 +222,6 @@ TVM_DLL bool VerifyVTCMLimit(const IRModule& mod, Integer
limit);
*/
TVM_DLL bool VerifyVTCMLimit(const PrimFunc& func, Integer limit);
-/*!
- * \brief Auto detect the block access region according to its body stmt
- * It will detect the access region as an array in order of appearance
in AST
- * \param block The block to be detected
- * \param buffer_var_map The outside buffers which may be accessed the block.
- * It is a map from buffer var to the buffer.
- * \return Array of access regions.
- * There are three arrays of BufferRegion:
- * - first: read regions
- * - second: write regions
- * - third: opaque regions
- */
-TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetSBlockAccessRegion(
- const SBlock& block, const ffi::Map<Var, Buffer>& buffer_var_map);
-
-/*!
- * \brief Auto detect the block read/write region according to its body stmt.
An opaque access will
- * be counted as both a read and a write access
- * \param block The block to be detected
- * \param buffer_var_map The outside buffers which may be accessed the block.
- * It is a map from buffer var to the buffer
- * \return An array only consisting of the read regions and write regions of
the input block
- */
-TVM_DLL ffi::Array<ffi::Array<BufferRegion>> GetSBlockReadWriteRegion(
- const SBlock& block, const ffi::Map<Var, Buffer>& buffer_var_map);
-
/*! \brief Helper struct for return value of IdentifyMemCpy
*
* This helper struct is not strictly necessary, as `IdentifyMemCpy`
@@ -310,16 +285,6 @@ TVM_DLL tvm::ffi::Map<ffi::String,
tvm::ffi::Map<ffi::String, Integer>> Calculat
TVM_DLL tvm::ffi::Map<ffi::String, tvm::ffi::Map<ffi::String, Integer>>
CalculateAllocatedBytes(
const IRModule& mod);
-/*!
- * \brief Detect the lowest common ancestor(LCA) of buffer access, including
both high-level
- * access(BufferLoad, BufferStore) and low-level access(Load, Store and
opaque access).
- * The LCA may be a For loop or a Block.
- * \param func The PrimFunc to be detected.
- * \return The Map from buffer to the LCA of all access to it. The lca is
function root if the
- * return stmt is std::nullopt.
- */
-TVM_DLL ffi::Map<Buffer, ffi::Optional<Stmt>> DetectBufferAccessLCA(const
PrimFunc& func);
-
/*!
* \brief Verify if the given TIR is well-formed. The verification includes:
*
@@ -365,22 +330,6 @@ TVM_DLL bool VerifyWellFormed(const IRModule& mod, bool
assert_mode = true);
*/
const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar*
result_g_var);
-/*!
- * \brief Find the "anchor block" of the given module.
- * We define the anchor block to be the block with (1) an init statement and
(2) having
- * the biggest flops count. The latter condition is only used when there are
multiple blocks
- * with an init statement.
- * For example, if the input module is conv2d + fused spatial blocks, conv2d
is the anchor block.
- * The input module may not contain more than one such block. For example, a
module having
- * two conv2d is not allowed as an input.
- * However, a module created from winograd convolution has multiple blocks
with an init statement
- * (input transform, batched GEMM, and output transform). We use the second
condition, the flops
- * count, to determine that the batched GEMM block is the anchor block.
- * \param mod The input TIR module.
- * \return The anchor block if found, nullptr otherwise.
- */
-const tir::SBlockNode* FindAnchorBlock(const IRModule& mod);
-
// Pass variants of verification analysis
// directly throws RuntimeError when verification fails.
namespace transform {
diff --git a/include/tvm/topi/einsum.h b/include/tvm/topi/einsum.h
index 44f01b0a96..6aaad7b6db 100644
--- a/include/tvm/topi/einsum.h
+++ b/include/tvm/topi/einsum.h
@@ -29,7 +29,6 @@
#define NPY_MAXARGS 16
#include <tvm/te/operation.h>
-#include <tvm/tir/data_layout.h>
#include <tvm/topi/detail/constant_utils.h>
#include <tvm/topi/detail/ravel_unravel.h>
#include <tvm/topi/detail/tensor_utils.h>
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index ef4830a46a..6f395575ce 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -25,8 +25,8 @@
#define TVM_TOPI_TRANSFORM_H_
#include <tvm/arith/analyzer.h>
+#include <tvm/s_tir/data_layout.h>
#include <tvm/te/operation.h>
-#include <tvm/tir/data_layout.h>
#include <tvm/tir/index_map.h>
#include <tvm/topi/broadcast.h>
#include <tvm/topi/detail/broadcast.h>
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py
b/python/tvm/relax/transform/legalize_ops/nn.py
index 1a0477af20..c4f9901a03 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -20,7 +20,7 @@ import logging
import math
from typing import Optional
-from tvm import te, tir, topi
+from tvm import s_tir, te, tir, topi
from ...block_builder import BlockBuilder
from ...expr import Call, Expr
@@ -42,8 +42,8 @@ def _nn_conv1d(bb: BlockBuilder, call: Call) -> Expr:
)
return call
if call.attrs.groups != 1:
- data_layout = tir.layout(call.attrs.data_layout)
- kernel_layout = tir.layout(call.attrs.kernel_layout)
+ data_layout = s_tir.layout(call.attrs.data_layout)
+ kernel_layout = s_tir.layout(call.attrs.kernel_layout)
ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")]
oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")]
if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm):
@@ -83,8 +83,8 @@ def _nn_conv2d(bb: BlockBuilder, call: Call) -> Expr:
)
return call
if call.attrs.groups != 1:
- data_layout = tir.layout(call.attrs.data_layout)
- kernel_layout = tir.layout(call.attrs.kernel_layout)
+ data_layout = s_tir.layout(call.attrs.data_layout)
+ kernel_layout = s_tir.layout(call.attrs.kernel_layout)
ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")]
oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")]
if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm):
@@ -124,8 +124,8 @@ def _nn_conv3d(bb: BlockBuilder, call: Call) -> Expr:
)
return call
if call.attrs.groups != 1:
- data_layout = tir.layout(call.attrs.data_layout)
- kernel_layout = tir.layout(call.attrs.kernel_layout)
+ data_layout = s_tir.layout(call.attrs.data_layout)
+ kernel_layout = s_tir.layout(call.attrs.kernel_layout)
ic = call.args[0].struct_info.shape.values[data_layout.index_of("C")]
oc = call.args[1].struct_info.shape.values[kernel_layout.index_of("O")]
if not isinstance(ic, tir.IntImm) or not isinstance(oc, tir.IntImm):
@@ -407,7 +407,7 @@ def _nn_adaptive_avg_pool1d(bb: BlockBuilder, call: Call)
-> Expr:
def te_adaptive_avg_pool1d(data, output_size, layout_str):
if output_size is None:
- layout = tir.layout(layout_str)
+ layout = s_tir.layout(layout_str)
idx_W = layout.index_of("W")
assert idx_W != -1
output_size = data.shape[idx_W]
@@ -434,7 +434,7 @@ def _nn_adaptive_avg_pool2d(bb: BlockBuilder, call: Call)
-> Expr:
def te_adaptive_avg_pool2d(data, output_size, layout_str):
if output_size is None:
- layout = tir.layout(layout_str)
+ layout = s_tir.layout(layout_str)
idx_H = layout.index_of("H")
idx_W = layout.index_of("W")
assert idx_H != -1 and idx_W != -1
@@ -462,7 +462,7 @@ def _nn_adaptive_avg_pool3d(bb: BlockBuilder, call: Call)
-> Expr:
def te_adaptive_avg_pool3d(data, output_size, layout_str):
if output_size is None:
- layout = tir.layout(layout_str)
+ layout = s_tir.layout(layout_str)
idx_D = layout.index_of("D")
idx_H = layout.index_of("H")
idx_W = layout.index_of("W")
diff --git a/python/tvm/s_tir/__init__.py b/python/tvm/s_tir/__init__.py
index fc2fdab196..ff6ca7347b 100644
--- a/python/tvm/s_tir/__init__.py
+++ b/python/tvm/s_tir/__init__.py
@@ -28,8 +28,10 @@ from . import pipeline
from . import transform
from . import schedule
from .schedule import StmtSRef, SBlockScope, ScheduleState, Schedule,
ScheduleError, Trace
-from .block_dependence_info import SBlockDependenceInfo
+from .sblock_dependence_info import SBlockDependenceInfo
+from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
if not _RUNTIME_ONLY:
+ from . import analysis
from . import meta_schedule
from . import dlight
diff --git a/python/tvm/s_tir/_ffi_api.py b/python/tvm/s_tir/_ffi_api.py
index 4140cda741..49cd4606ab 100644
--- a/python/tvm/s_tir/_ffi_api.py
+++ b/python/tvm/s_tir/_ffi_api.py
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.tir"""
+"""FFI APIs for tvm.s_tir"""
import tvm_ffi
-tvm_ffi.init_ffi_api("tir", __name__)
+tvm_ffi.init_ffi_api("s_tir", __name__)
diff --git a/python/tvm/s_tir/analysis/__init__.py
b/python/tvm/s_tir/analysis/__init__.py
new file mode 100644
index 0000000000..194586a72c
--- /dev/null
+++ b/python/tvm/s_tir/analysis/__init__.py
@@ -0,0 +1,120 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Analysis utilities for Schedulable TensorIR (S-TIR)."""
+# pylint: disable=invalid-name
+from typing import Dict, List, Optional
+
+from tvm.ir import IRModule
+from tvm.tir.expr import Var
+from tvm.tir.stmt import SBlock, BufferRegion
+
+from tvm.tir import Buffer, Stmt
+from tvm.tir.function import PrimFunc
+from . import _ffi_api
+
+
+def get_sblock_access_region(
+ block: SBlock, buffer_var_map: Dict[Var, Buffer]
+) -> List[List[BufferRegion]]:
+ """Detect which regions of tensors in this block are read or written to.
+ Regions are sorted by order of appearance in the AST.
+
+ Parameters
+ ----------
+ block: tvm.tir.SBlock
+ The block in which we are detecting read/write regions.
+
+ buffer_var_map : Dict[Var, Buffer]
+ The outside buffers which may access the block. Mapping from buffer
var to the buffer
+
+ Returns
+ -------
+ result : List[List[BufferRegion]]
+ Array of access regions. There are three arrays of BufferRegion:
+ - first: read regions
+ - second: write regions
+ - third: opaque regions
+ """
+ return _ffi_api.GetSBlockAccessRegion(block, buffer_var_map) # type:
ignore
+
+
+def get_sblock_read_write_region(
+ block: SBlock, buffer_var_map: Dict[Var, Buffer]
+) -> List[List[BufferRegion]]:
+ """Auto detect the block read/write region according to its body stmt.
+ An opaque access will be counted as both a read and a write access
+
+ Parameters
+ ----------
+ block: tvm.tir.SBlock
+ The block in which we are detecting read/write regions.
+
+ buffer_var_map : Dict[Var, Buffer]
+ The outside buffers which may access the block. Mapping from buffer
var to the buffer
+
+ Returns
+ -------
+ result : List[List[BufferRegion]]
+ An array only consisting of the read regions and write regions of the
input block
+ """
+ return _ffi_api.GetSBlockReadWriteRegion(block, buffer_var_map) # type:
ignore
+
+
+def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
+ """Detect the lowest common ancestor(LCA) of buffer access, including both
high-level
+ access (BufferLoad, BufferStore) and low-level access (BufferLoad,
BufferStore and opaque
+ access).
+ The LCA may be a For loop or a Block.
+
+ Parameters
+ ----------
+ func: tvm.tir.PrimFunc
+ The function to be detected.
+
+ Returns
+ -------
+ result : Dict[Buffer, Stmt]
+ Map from buffer to the LCA of all access to it.
+ """
+ return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint:
disable=no-member
+
+
+def find_anchor_sblock(mod: IRModule) -> Optional[SBlock]:
+ """Find the "anchor block" of the given module.
+
+ We define the anchor block to be the block with (1) an init statement and
(2) having
+ the biggest flops count. The latter condition is only used when there are
multiple blocks
+ with an init statement.
+
+ For example, if the input module is conv2d + fused spatial blocks, conv2d
is the anchor block.
+ The input module may not contain more than one such block. For example, a
module having
+ two conv2d is not allowed as an input.
+
+ However, a module created from winograd convolution has multiple blocks
with an init statement
+ (input transform, batched GEMM, and output transform). We use the second
condition, the flops
+ count, to determine that the batched GEMM block is the anchor block.
+
+ Parameters
+ ----------
+ mod: tvm.ir.IRModule
+ The input TIR module.
+ Returns
+ -------
+ anchor_block: Optional[SBlock]
+ The anchor block if found, None otherwise.
+ """
+ return _ffi_api.find_anchor_sblock(mod) # type: ignore # pylint:
disable=no-member
diff --git a/python/tvm/s_tir/_ffi_api.py
b/python/tvm/s_tir/analysis/_ffi_api.py
similarity index 90%
copy from python/tvm/s_tir/_ffi_api.py
copy to python/tvm/s_tir/analysis/_ffi_api.py
index 4140cda741..0d56a235ec 100644
--- a/python/tvm/s_tir/_ffi_api.py
+++ b/python/tvm/s_tir/analysis/_ffi_api.py
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""FFI APIs for tvm.tir"""
+"""FFI APIs for tvm.s_tir.analysis"""
import tvm_ffi
-tvm_ffi.init_ffi_api("tir", __name__)
+tvm_ffi.init_ffi_api("s_tir.analysis", __name__)
diff --git a/python/tvm/s_tir/block_dependence_info.py
b/python/tvm/s_tir/block_dependence_info.py
index 8deba7e3a7..a95f018d30 100644
--- a/python/tvm/s_tir/block_dependence_info.py
+++ b/python/tvm/s_tir/block_dependence_info.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Define BlockDependenceInfoNode that uses the SBlockScope and StmtSRef
objects
+"""Define SBlockDependenceInfoNode that uses the SBlockScope and StmtSRef
objects
to store the block level dependences"""
from typing import Union, Optional
@@ -23,11 +23,11 @@ from tvm.ir.module import IRModule
from tvm.runtime import Object
from tvm.tir import SBlock, PrimFunc
-from .block_scope import SBlockScope, StmtSRef
+from .sblock_scope import SBlockScope, StmtSRef
from . import _ffi_api
-@register_object("tir.SBlockDependenceInfo")
+@register_object("s_tir.SBlockDependenceInfo")
class SBlockDependenceInfo(Object):
"""
SBlockDependenceInfo
diff --git a/python/tvm/tir/data_layout.py b/python/tvm/s_tir/data_layout.py
similarity index 98%
rename from python/tvm/tir/data_layout.py
rename to python/tvm/s_tir/data_layout.py
index f9c0e0cdc7..6875eacb85 100644
--- a/python/tvm/tir/data_layout.py
+++ b/python/tvm/s_tir/data_layout.py
@@ -23,7 +23,7 @@ from tvm.runtime import Object
from . import _ffi_api
-@tvm_ffi.register_object("tir.Layout")
+@tvm_ffi.register_object("s_tir.Layout")
class Layout(Object):
"""Layout is composed of upper cases, lower cases and numbers,
where upper case indicates a primal axis and
@@ -81,7 +81,7 @@ class Layout(Object):
return _ffi_api.LayoutFactorOf(self, axis) # type: ignore
-@tvm_ffi.register_object("tir.BijectiveLayout")
+@tvm_ffi.register_object("s_tir.BijectiveLayout")
class BijectiveLayout(Object):
"""Bijective mapping for two layouts (src-layout and dst-layout).
It provides shape and index conversion between each other.
diff --git a/python/tvm/s_tir/block_dependence_info.py
b/python/tvm/s_tir/sblock_dependence_info.py
similarity index 94%
copy from python/tvm/s_tir/block_dependence_info.py
copy to python/tvm/s_tir/sblock_dependence_info.py
index 8deba7e3a7..a95f018d30 100644
--- a/python/tvm/s_tir/block_dependence_info.py
+++ b/python/tvm/s_tir/sblock_dependence_info.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Define BlockDependenceInfoNode that uses the SBlockScope and StmtSRef
objects
+"""Define SBlockDependenceInfoNode that uses the SBlockScope and StmtSRef
objects
to store the block level dependences"""
from typing import Union, Optional
@@ -23,11 +23,11 @@ from tvm.ir.module import IRModule
from tvm.runtime import Object
from tvm.tir import SBlock, PrimFunc
-from .block_scope import SBlockScope, StmtSRef
+from .sblock_scope import SBlockScope, StmtSRef
from . import _ffi_api
-@register_object("tir.SBlockDependenceInfo")
+@register_object("s_tir.SBlockDependenceInfo")
class SBlockDependenceInfo(Object):
"""
SBlockDependenceInfo
diff --git a/python/tvm/s_tir/block_scope.py b/python/tvm/s_tir/sblock_scope.py
similarity index 97%
rename from python/tvm/s_tir/block_scope.py
rename to python/tvm/s_tir/sblock_scope.py
index d8bc9b16e9..4963ec6a71 100644
--- a/python/tvm/s_tir/block_scope.py
+++ b/python/tvm/s_tir/sblock_scope.py
@@ -25,7 +25,7 @@ from tvm.tir import SBlock, For
from . import _ffi_api
-@register_object("tir.StmtSRef")
+@register_object("s_tir.StmtSRef")
class StmtSRef(Object):
"""An object that refers to schedulable elements in the TensorIR, aka
"sref".
@@ -86,7 +86,7 @@ class DepKind(IntEnum):
OPAQUE = 3
-@register_object("tir.Dependency")
+@register_object("s_tir.Dependency")
class Dependency(Object):
"""A tuple (src, dst, kind) representing certain types of dependency.
For example, (A, B, kRAW) means block B depends on block A, and the
dependency kind is
@@ -107,7 +107,7 @@ class Dependency(Object):
kind: DepKind
-@register_object("tir.SBlockScope")
+@register_object("s_tir.SBlockScope")
class SBlockScope(Object):
"""An object corresponds to each block sref in the sref tree, which
tracks the producer-consumer dependency between blocks.
diff --git a/python/tvm/s_tir/schedule/__init__.py
b/python/tvm/s_tir/schedule/__init__.py
index 170d6dd9ab..de55fc3ac6 100644
--- a/python/tvm/s_tir/schedule/__init__.py
+++ b/python/tvm/s_tir/schedule/__init__.py
@@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Namespace for the TensorIR schedule API."""
-from ..block_scope import SBlockScope, Dependency, DepKind, StmtSRef
+from ..sblock_scope import SBlockScope, Dependency, DepKind, StmtSRef
from .instruction import Instruction, InstructionKind
from .schedule import SBlockRV, ExprRV, LoopRV, Schedule, ScheduleError
from .state import ScheduleDebugMask, ScheduleState
diff --git a/python/tvm/s_tir/schedule/state.py
b/python/tvm/s_tir/schedule/state.py
index e9f0f43090..c98090e6d1 100644
--- a/python/tvm/s_tir/schedule/state.py
+++ b/python/tvm/s_tir/schedule/state.py
@@ -26,7 +26,7 @@ from tvm.runtime import Object
from tvm.tir import SBlock, SBlockRealize, For, PrimFunc
from . import _ffi_api
-from ..block_scope import SBlockScope, StmtSRef
+from ..sblock_scope import SBlockScope, StmtSRef
CachedFlags = namedtuple("CachedFlags", ["affine_binding", "region_cover",
"stage_pipeline"])
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index 8d50cccc01..e48b08bbd7 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -20,7 +20,6 @@ from tvm.ir import PrimExpr
from tvm.runtime import const
from .buffer import Buffer, decl_buffer, DataProducer
-from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
from .expr import convert
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
diff --git a/python/tvm/tir/analysis/analysis.py
b/python/tvm/tir/analysis/analysis.py
index f99da9dedb..b574c78895 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -21,9 +21,9 @@ from typing import Dict, List, Optional, Union
import tvm
from tvm.ir import IRModule
from tvm.tir.expr import Var
-from tvm.tir.stmt import SBlock, BufferRegion, PrimExpr
+from tvm.tir.stmt import PrimExpr
-from .. import Buffer, Stmt
+from .. import Stmt
from ..function import PrimFunc
from . import _ffi_api
@@ -116,53 +116,6 @@ def verify_gpu_code(func: PrimFunc, constraints: Dict[str,
int]) -> None:
return _ffi_api.verify_gpu_code(func, constraints) # type: ignore
-def get_sblock_access_region(
- block: SBlock, buffer_var_map: Dict[Var, Buffer]
-) -> List[List[BufferRegion]]:
- """Detect which regions of tensors in this block are read or written to.
- Regions are sorted by order of appearance in the AST.
-
- Parameters
- ----------
- block: tvm.tir.SBlock
- The block in which we are detecting read/write regions.
-
- buffer_var_map : Dict[Var, Buffer]
- The outside buffers which may access the block. Mapping from buffer
var to the buffer
-
- Returns
- -------
- result : List[List[BufferRegion]]
- Array of access regions. There are three arrays of BufferRegion:
- - first: read regions
- - second: write regions
- - third: opaque regions
- """
- return _ffi_api.GetSBlockAccessRegion(block, buffer_var_map) # type:
ignore
-
-
-def get_sblock_read_write_region(
- block: SBlock, buffer_var_map: Dict[Var, Buffer]
-) -> List[List[BufferRegion]]:
- """Auto detect the block read/write region according to its body stmt.
- An opaque access will be counted as both a read and a write access
-
- Parameters
- ----------
- block: tvm.tir.SBlock
- The block in which we are detecting read/write regions.
-
- buffer_var_map : Dict[Var, Buffer]
- The outside buffers which may access the block. Mapping from buffer
var to the buffer
-
- Returns
- -------
- result : List[List[BufferRegion]]
- An array only consisting of the read regions and write regions of the
input block
- """
- return _ffi_api.GetSBlockReadWriteRegion(block, buffer_var_map) # type:
ignore
-
-
def calculate_allocated_bytes(
func_or_mod: Union[PrimFunc, IRModule],
) -> Union[Dict[str, int], Dict[str, Dict[str, int]]]:
@@ -188,25 +141,6 @@ def calculate_allocated_bytes(
return _ffi_api.calculate_allocated_bytes(func_or_mod) # type: ignore
-def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
- """Detect the lowest common ancestor(LCA) of buffer access, including both
high-level
- access (BufferLoad, BufferStore) and low-level access (BufferLoad,
BufferStore and opaque
- access).
- The LCA may be a For loop or a Block.
-
- Parameters
- ----------
- func: tvm.tir.PrimFunc
- The function to be detected.
-
- Returns
- -------
- result : Dict[Buffer, Stmt]
- Map from buffer to the LCA of all access to it.
- """
- return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint:
disable=no-member
-
-
def estimate_tir_flops(stmt_or_mod: Union[Stmt, IRModule]) -> float:
"""Estimate the FLOPs of a TIR fragment.
@@ -274,33 +208,6 @@ def OOBChecker():
return _ffi_api.OOBChecker() # type: ignore
-def find_anchor_sblock(mod: IRModule) -> SBlock:
- """Find the "anchor block" of the given module.
-
- We define the anchor block to be the block with (1) an init statement and
(2) having
- the biggest flops count. The latter condition is only used when there are
multiple blocks
- with an init statement.
-
- For example, if the input module is conv2d + fused spatial blocks, conv2d
is the anchor block.
- The input module may not contain more than one such block. For example, a
module having
- two conv2d is not allowed as an input.
-
- However, a module created from winograd convolution has multiple blocks
with an init statement
- (input transform, batched GEMM, and output transform). We use the second
condition, the flops
- count, to determine that the batched GEMM block is the anchor block.
-
- Parameters
- ----------
- mod: tvm.ir.IRModule
- The input TIR module.
- Returns
- -------
- anchor_block: SBlock
- The anchor block if found, None otherwise.
- """
- return _ffi_api.find_anchor_sblock(mod) # type: ignore # pylint:
disable=no-member
-
-
def has_if_then_else(stmt: Stmt) -> bool:
return tvm.ffi.get_global_func("s_tir.schedule.HasIfThenElse")(stmt)
diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py
index d74d5d2a84..e24c499a0e 100644
--- a/python/tvm/topi/utils.py
+++ b/python/tvm/topi/utils.py
@@ -23,7 +23,8 @@ from numbers import Integral
import numpy as np
import tvm
from tvm import te
-from tvm.tir import SizeVar, bijective_layout, layout
+from tvm.tir import SizeVar
+from tvm.s_tir import bijective_layout, layout
from . import cpp, tag
diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h
index d795bea7fa..bb62cb194a 100644
--- a/src/contrib/msc/core/ir/graph.h
+++ b/src/contrib/msc/core/ir/graph.h
@@ -26,7 +26,7 @@
#include <dmlc/json.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/tir/data_layout.h>
+#include <tvm/s_tir/data_layout.h>
#include <string>
#include <unordered_map>
diff --git a/src/contrib/msc/core/ir/graph_builder.h
b/src/contrib/msc/core/ir/graph_builder.h
index 22a4929fe1..c86801352f 100644
--- a/src/contrib/msc/core/ir/graph_builder.h
+++ b/src/contrib/msc/core/ir/graph_builder.h
@@ -29,7 +29,7 @@
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/runtime/tensor.h>
-#include <tvm/tir/data_layout.h>
+#include <tvm/s_tir/data_layout.h>
#include <set>
#include <stack>
diff --git a/src/contrib/msc/core/ir/plugin.h b/src/contrib/msc/core/ir/plugin.h
index eaf3167dcf..838b3948bf 100644
--- a/src/contrib/msc/core/ir/plugin.h
+++ b/src/contrib/msc/core/ir/plugin.h
@@ -26,7 +26,7 @@
#include <dmlc/json.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/tir/data_layout.h>
+#include <tvm/s_tir/data_layout.h>
#include <string>
#include <unordered_map>
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 5c4f563beb..ee82f3eebc 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -27,7 +27,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/relax/op_attr_types.h>
-#include <tvm/tir/data_layout.h>
+#include <tvm/s_tir/data_layout.h>
#include <optional>
#include <tuple>
diff --git a/src/relax/transform/infer_amp_utils.h
b/src/relax/transform/infer_amp_utils.h
index e8ac586036..1b884d5f4b 100644
--- a/src/relax/transform/infer_amp_utils.h
+++ b/src/relax/transform/infer_amp_utils.h
@@ -29,7 +29,6 @@
#include <tvm/relax/expr.h>
#include <tvm/relax/nested_msg.h>
#include <tvm/relax/op_attr_types.h>
-#include <tvm/tir/data_layout.h>
#include <unordered_map>
#include <unordered_set>
diff --git a/src/relax/transform/infer_layout_utils.h
b/src/relax/transform/infer_layout_utils.h
index 973e46b45c..e5524d3435 100644
--- a/src/relax/transform/infer_layout_utils.h
+++ b/src/relax/transform/infer_layout_utils.h
@@ -38,7 +38,7 @@
#include <tvm/relax/expr.h>
#include <tvm/relax/nested_msg.h>
#include <tvm/relax/op_attr_types.h>
-#include <tvm/tir/data_layout.h>
+#include <tvm/s_tir/data_layout.h>
#include <array>
#include <string>
diff --git a/src/tir/analysis/stmt_finding.cc
b/src/s_tir/analysis/find_anchor_sblock.cc
similarity index 71%
copy from src/tir/analysis/stmt_finding.cc
copy to src/s_tir/analysis/find_anchor_sblock.cc
index 58879277e9..7e0be3eef6 100644
--- a/src/tir/analysis/stmt_finding.cc
+++ b/src/s_tir/analysis/find_anchor_sblock.cc
@@ -16,6 +16,12 @@
* specific language governing permissions and limitations
* under the License.
*/
+
+/*!
+ * \file s_tir/analysis/find_anchor_sblock.cc
+ * \brief Find the "anchor block" of a given module
+ */
+
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
@@ -23,47 +29,6 @@
namespace tvm {
namespace tir {
-const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar*
result_g_var) {
- GlobalVar result = NullValue<GlobalVar>();
- // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
- int num_prim_func = 0;
- const tir::PrimFuncNode* main_func = nullptr;
- const tir::PrimFuncNode* last_func = nullptr;
- for (const auto& kv : mod->functions) {
- GlobalVar gv = kv.first;
- BaseFunc base_func = kv.second;
- if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
- last_func = func;
- if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
- if (result_g_var != nullptr) {
- *result_g_var = gv;
- }
- return func;
- }
- if (gv->name_hint == "main") {
- main_func = func;
- result = gv;
- }
- ++num_prim_func;
- }
- }
- // Priority 2: PrimFunc whose name is `main`
- if (main_func != nullptr) {
- if (result_g_var != nullptr) {
- *result_g_var = result;
- }
- return main_func;
- }
- // Priority 3: The only PrimFunc in the IRModule
- if (num_prim_func == 1) {
- if (result_g_var != nullptr) {
- *result_g_var = result;
- }
- return last_func;
- }
- return nullptr;
-}
-
Stmt GetEnclosingLoop(const SBlockNode* block, Stmt func_body) {
struct GetRootSeqStmt : public StmtVisitor {
void VisitStmt_(const SeqStmtNode* seq) override { result = seq; }
@@ -142,7 +107,7 @@ const SBlockNode* FindAnchorBlock(const IRModule& mod) {
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.analysis.find_anchor_sblock", [](const IRModule&
mod) {
+ refl::GlobalDef().def("s_tir.analysis.find_anchor_sblock", [](const
IRModule& mod) {
auto ret = FindAnchorBlock(mod);
if (ret) {
return ffi::Optional<SBlock>(ffi::GetRef<SBlock>(ret));
diff --git a/src/tir/analysis/block_access_region_detector.cc
b/src/s_tir/analysis/sblock_access_region_detector.cc
similarity index 98%
rename from src/tir/analysis/block_access_region_detector.cc
rename to src/s_tir/analysis/sblock_access_region_detector.cc
index b79141cd1e..e8dda4d88e 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/s_tir/analysis/sblock_access_region_detector.cc
@@ -18,8 +18,8 @@
*/
/*!
- * \file tir/analysis/block_region_detector.cc
- * \brief Detect block read/write regions by visiting its body
+ * \file s_tir/analysis/sblock_access_region_detector.cc
+ * \brief Detect sblock read/write regions by visiting its body
*/
#include <tvm/arith/analyzer.h>
@@ -29,7 +29,7 @@
#include <unordered_map>
-#include "../transform/ir_utils.h"
+#include "../../tir/transform/ir_utils.h"
namespace tvm {
namespace tir {
@@ -414,8 +414,8 @@ ffi::Array<ffi::Array<BufferRegion>>
GetSBlockReadWriteRegion(
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
- .def("tir.analysis.GetSBlockAccessRegion", GetSBlockAccessRegion)
- .def("tir.analysis.GetSBlockReadWriteRegion", GetSBlockReadWriteRegion);
+ .def("s_tir.analysis.GetSBlockAccessRegion", GetSBlockAccessRegion)
+ .def("s_tir.analysis.GetSBlockReadWriteRegion",
GetSBlockReadWriteRegion);
}
} // namespace tir
diff --git a/src/tir/analysis/buffer_access_lca_detector.cc
b/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc
similarity index 98%
rename from src/tir/analysis/buffer_access_lca_detector.cc
rename to src/s_tir/analysis/sblock_buffer_access_lca_detector.cc
index 467f8123c4..67ee0dbe69 100644
--- a/src/tir/analysis/buffer_access_lca_detector.cc
+++ b/src/s_tir/analysis/sblock_buffer_access_lca_detector.cc
@@ -18,7 +18,7 @@
*/
/*!
- * \file tir/analysis/buffer_access_lca_detector.cc
+ * \file s_tir/analysis/sblock_buffer_access_lca_detector.cc
* \brief Detect the lowest common ancestor(LCA) of buffer access
*/
@@ -349,7 +349,7 @@ ffi::Map<Buffer, ffi::Optional<Stmt>>
DetectBufferAccessLCA(const PrimFunc& func
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.analysis.detect_buffer_access_lca",
DetectBufferAccessLCA);
+ refl::GlobalDef().def("s_tir.analysis.detect_buffer_access_lca",
DetectBufferAccessLCA);
}
} // namespace tir
} // namespace tvm
diff --git a/src/tir/ir/data_layout.cc b/src/s_tir/data_layout.cc
similarity index 96%
rename from src/tir/ir/data_layout.cc
rename to src/s_tir/data_layout.cc
index 75f9bb50d1..9e028d2042 100644
--- a/src/tir/ir/data_layout.cc
+++ b/src/s_tir/data_layout.cc
@@ -24,7 +24,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/tir/data_layout.h>
+#include <tvm/s_tir/data_layout.h>
#include <tvm/tir/stmt_functor.h>
#include <cctype>
@@ -433,29 +433,29 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
- .def("tir.Layout", [](std::string name, DataType dtype) { return
Layout(name, dtype); })
- .def("tir.LayoutIndexOf",
+ .def("s_tir.Layout", [](std::string name, DataType dtype) { return
Layout(name, dtype); })
+ .def("s_tir.LayoutIndexOf",
[](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::Get(axis));
})
- .def("tir.LayoutFactorOf",
+ .def("s_tir.LayoutFactorOf",
[](Layout layout, std::string axis) -> int {
return layout.FactorOf(LayoutAxis::Get(axis));
})
- .def("tir.LayoutNdim", [](Layout layout) -> int { return layout.ndim();
})
- .def("tir.LayoutGetItem",
+ .def("s_tir.LayoutNdim", [](Layout layout) -> int { return
layout.ndim(); })
+ .def("s_tir.LayoutGetItem",
[](Layout layout, int idx) -> std::string {
const LayoutAxis& axis = layout[idx];
return axis.name();
})
- .def("tir.BijectiveLayout",
+ .def("s_tir.BijectiveLayout",
[](Layout src_layout, Layout dst_layout) -> BijectiveLayout {
return BijectiveLayout(src_layout, dst_layout);
})
- .def_method("tir.BijectiveLayoutForwardIndex",
&BijectiveLayout::ForwardIndex)
- .def_method("tir.BijectiveLayoutBackwardIndex",
&BijectiveLayout::BackwardIndex)
- .def_method("tir.BijectiveLayoutForwardShape",
&BijectiveLayout::ForwardShape)
- .def_method("tir.BijectiveLayoutBackwardShape",
&BijectiveLayout::BackwardShape);
+ .def_method("s_tir.BijectiveLayoutForwardIndex",
&BijectiveLayout::ForwardIndex)
+ .def_method("s_tir.BijectiveLayoutBackwardIndex",
&BijectiveLayout::BackwardIndex)
+ .def_method("s_tir.BijectiveLayoutForwardShape",
&BijectiveLayout::ForwardShape)
+ .def_method("s_tir.BijectiveLayoutBackwardShape",
&BijectiveLayout::BackwardShape);
}
} // namespace tir
} // namespace tvm
diff --git a/src/tir/ir/block_dependence_info.cc
b/src/s_tir/sblock_dependence_info.cc
similarity index 68%
rename from src/tir/ir/block_dependence_info.cc
rename to src/s_tir/sblock_dependence_info.cc
index deff65e8cf..bbfa691cee 100644
--- a/src/tir/ir/block_dependence_info.cc
+++ b/src/s_tir/sblock_dependence_info.cc
@@ -18,25 +18,25 @@
*/
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/tir/block_dependence_info.h>
-#include <tvm/tir/utils.h>
+#include <tvm/s_tir/sblock_dependence_info.h>
+#include <tvm/s_tir/utils.h>
namespace tvm {
namespace tir {
-TVM_FFI_STATIC_INIT_BLOCK() { BlockDependenceInfoNode::RegisterReflection(); }
+TVM_FFI_STATIC_INIT_BLOCK() { SBlockDependenceInfoNode::RegisterReflection(); }
/**
* @brief A helper class to collect and build SBlock Dependences using
SBlockScope class
*/
-class BlockDependenceInfoCollector : private StmtVisitor {
+class SBlockDependenceInfoCollector : private StmtVisitor {
public:
- static void Collect(BlockDependenceInfoNode* self, const Stmt& stmt) {
- BlockDependenceInfoCollector collector(self);
+ static void Collect(SBlockDependenceInfoNode* self, const Stmt& stmt) {
+ SBlockDependenceInfoCollector collector(self);
collector.VisitStmt(stmt);
}
- explicit BlockDependenceInfoCollector(BlockDependenceInfoNode* self)
+ explicit SBlockDependenceInfoCollector(SBlockDependenceInfoNode* self)
: self_(self), block_frames_{} {
block_frames_.emplace_back();
}
@@ -65,23 +65,25 @@ class BlockDependenceInfoCollector : private StmtVisitor {
SetSeqIndexInChildren(self_->stmt2ref, seq_stmt, false);
}
- BlockDependenceInfoNode* self_;
+ SBlockDependenceInfoNode* self_;
/*! \brief The stack frames of blocks in the DFS visit. */
std::vector<ffi::Array<StmtSRef>> block_frames_;
};
-BlockDependenceInfo::BlockDependenceInfo() { data_ =
ffi::make_object<BlockDependenceInfoNode>(); }
+SBlockDependenceInfo::SBlockDependenceInfo() {
+ data_ = ffi::make_object<SBlockDependenceInfoNode>();
+}
-BlockDependenceInfo::BlockDependenceInfo(IRModule mod) {
- ObjectPtr<BlockDependenceInfoNode> n =
ffi::make_object<BlockDependenceInfoNode>();
- BlockDependenceInfoNode* self = n.get();
+SBlockDependenceInfo::SBlockDependenceInfo(IRModule mod) {
+ ObjectPtr<SBlockDependenceInfoNode> n =
ffi::make_object<SBlockDependenceInfoNode>();
+ SBlockDependenceInfoNode* self = n.get();
n->stmt2ref = SRefTreeCreator::Create(mod, /* include_loops */ false);
for (const auto& kv : mod->functions) {
const BaseFunc& base_func = kv.second;
if (auto opt = base_func.as<PrimFunc>()) {
auto func = opt.value();
- BlockDependenceInfoCollector::Collect(self, func->body);
+ SBlockDependenceInfoCollector::Collect(self, func->body);
}
}
data_ = std::move(n);
@@ -90,12 +92,12 @@ BlockDependenceInfo::BlockDependenceInfo(IRModule mod) {
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
- .def("tir.SBlockDependenceInfo",
- [](IRModule mod) -> BlockDependenceInfo { return
BlockDependenceInfo(mod); })
- .def_method("tir.SBlockDependenceInfoGetSBlockScope",
- &BlockDependenceInfoNode::GetSBlockScope)
- .def("tir.SBlockDependenceInfoGetSRef",
- [](BlockDependenceInfo self, Stmt stmt) -> ffi::Optional<StmtSRef> {
+ .def("s_tir.SBlockDependenceInfo",
+ [](IRModule mod) -> SBlockDependenceInfo { return
SBlockDependenceInfo(mod); })
+ .def_method("s_tir.SBlockDependenceInfoGetSBlockScope",
+ &SBlockDependenceInfoNode::GetSBlockScope)
+ .def("s_tir.SBlockDependenceInfoGetSRef",
+ [](SBlockDependenceInfo self, Stmt stmt) -> ffi::Optional<StmtSRef>
{
auto it = self->stmt2ref.find(stmt.get());
return it != self->stmt2ref.end() ? it->second :
ffi::Optional<StmtSRef>(std::nullopt);
});
diff --git a/src/s_tir/schedule/primitive/blockize_tensorize.cc
b/src/s_tir/schedule/primitive/blockize_tensorize.cc
index ac3c63159b..f7cfda5479 100644
--- a/src/s_tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/s_tir/schedule/primitive/blockize_tensorize.cc
@@ -16,10 +16,10 @@
* specific language governing permissions and limitations
* under the License.
*/
-#include <tvm/tir/data_type_rewriter.h>
#include <functional>
+#include "../../../tir/ir/data_type_rewriter.h"
#include "../../../tir/transform/simplify.h"
#include "../ir_comparator.h"
#include "../utils.h"
diff --git a/src/s_tir/schedule/utils.h b/src/s_tir/schedule/utils.h
index a82978435d..c3abdae6c9 100644
--- a/src/s_tir/schedule/utils.h
+++ b/src/s_tir/schedule/utils.h
@@ -27,11 +27,11 @@
#include <tvm/s_tir/schedule/schedule.h>
#include <tvm/s_tir/schedule/state.h>
#include <tvm/s_tir/schedule/trace.h>
+#include <tvm/s_tir/utils.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
-#include <tvm/tir/utils.h>
#include <string>
#include <unordered_map>
diff --git a/src/te/operation/create_primfunc.cc
b/src/te/operation/create_primfunc.cc
index b7f85b226b..6499eac91e 100644
--- a/src/te/operation/create_primfunc.cc
+++ b/src/te/operation/create_primfunc.cc
@@ -25,7 +25,6 @@
#include <tvm/ir/name_supply.h>
#include <tvm/te/operation.h>
#include <tvm/tir/analysis.h>
-#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
@@ -37,6 +36,7 @@
#include <vector>
#include "../../support/array.h"
+#include "../../tir/ir/data_type_rewriter.h"
#include "../../tir/ir/functor_common.h"
#include "../../tir/transform/ir_utils.h"
#include "graph.h"
diff --git a/src/tir/analysis/stmt_finding.cc b/src/tir/analysis/stmt_finding.cc
index 58879277e9..6093c5da9f 100644
--- a/src/tir/analysis/stmt_finding.cc
+++ b/src/tir/analysis/stmt_finding.cc
@@ -64,92 +64,5 @@ const PrimFuncNode* FindEntryFunc(const IRModule& mod,
GlobalVar* result_g_var)
return nullptr;
}
-Stmt GetEnclosingLoop(const SBlockNode* block, Stmt func_body) {
- struct GetRootSeqStmt : public StmtVisitor {
- void VisitStmt_(const SeqStmtNode* seq) override { result = seq; }
- const SeqStmtNode* result;
- };
-
- struct BlockFinder : public StmtVisitor {
- explicit BlockFinder(const SBlockNode* tgt) : target(tgt) {}
-
- void VisitStmt_(const SBlockNode* block) override {
- if (block == target) {
- found = true;
- }
- }
-
- const SBlockNode* target;
- bool found = false;
- };
-
- GetRootSeqStmt seq_finder;
- seq_finder(func_body);
-
- ICHECK(seq_finder.result);
-
- for (auto stmt : seq_finder.result->seq) {
- if (stmt->IsInstance<ForNode>()) {
- BlockFinder finder(block);
- finder(stmt);
- if (finder.found) {
- return stmt;
- }
- }
- }
-
- LOG(FATAL) << "Enclosing loop not found for a block " <<
ffi::GetRef<SBlock>(block);
- TVM_FFI_UNREACHABLE();
-}
-
-const SBlockNode* FindAnchorBlock(const IRModule& mod) {
- struct ReductionSBlockCollector : public StmtVisitor {
- void VisitStmt_(const SBlockNode* block) override {
- if (block->init) {
- blocks.push_back(block);
- }
- StmtVisitor::VisitStmt(block->body);
- }
- std::vector<const SBlockNode*> blocks;
- };
-
- if (auto prim_func = FindEntryFunc(mod, nullptr)) {
- ReductionSBlockCollector collector;
- collector(prim_func->body);
-
- const auto& candidates = collector.blocks;
-
- if (candidates.empty()) {
- return nullptr;
- } else if (candidates.size() == 1) {
- return candidates[0];
- }
-
- double best_flops = -1;
- int best_idx = 0;
- for (size_t i = 0; i < candidates.size(); ++i) {
- auto loop = GetEnclosingLoop(candidates[i], prim_func->body);
- auto flops = EstimateTIRFlops(loop);
- if (flops > best_flops) {
- best_flops = flops;
- best_idx = i;
- }
- }
- return candidates[best_idx];
- }
- return nullptr;
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
- namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("tir.analysis.find_anchor_sblock", [](const IRModule&
mod) {
- auto ret = FindAnchorBlock(mod);
- if (ret) {
- return ffi::Optional<SBlock>(ffi::GetRef<SBlock>(ret));
- }
- return ffi::Optional<SBlock>(std::nullopt);
- });
-}
-
} // namespace tir
} // namespace tvm
diff --git a/src/tir/ir/block_scope.cc b/src/tir/ir/block_scope.cc
index 8b2675936f..dc8d12d0ac 100644
--- a/src/tir/ir/block_scope.cc
+++ b/src/tir/ir/block_scope.cc
@@ -17,8 +17,8 @@
* under the License.
*/
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/tir/block_scope.h>
-#include <tvm/tir/utils.h>
+#include <tvm/s_tir/sblock_scope.h>
+#include <tvm/s_tir/utils.h>
namespace tvm {
namespace tir {
@@ -196,18 +196,18 @@ void SRefTreeCreator::VisitStmt_(const SeqStmtNode*
seq_stmt) {
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
- .def("tir.StmtSRefStmt",
+ .def("s_tir.StmtSRefStmt",
[](StmtSRef sref) -> ffi::Optional<Stmt> {
return ffi::GetRef<ffi::Optional<Stmt>>(sref->stmt);
})
- .def("tir.StmtSRefParent",
+ .def("s_tir.StmtSRefParent",
[](StmtSRef sref) -> ffi::Optional<StmtSRef> {
return ffi::GetRef<ffi::Optional<StmtSRef>>(sref->parent);
})
- .def("tir.StmtSRefRootMark", StmtSRef::RootMark)
- .def("tir.StmtSRefInlineMark", StmtSRef::InlineMark)
- .def_method("tir.SBlockScopeGetDepsBySrc",
&SBlockScopeNode::GetDepsBySrc)
- .def_method("tir.SBlockScopeGetDepsByDst",
&SBlockScopeNode::GetDepsByDst);
+ .def("s_tir.StmtSRefRootMark", StmtSRef::RootMark)
+ .def("s_tir.StmtSRefInlineMark", StmtSRef::InlineMark)
+ .def_method("s_tir.SBlockScopeGetDepsBySrc",
&SBlockScopeNode::GetDepsBySrc)
+ .def_method("s_tir.SBlockScopeGetDepsByDst",
&SBlockScopeNode::GetDepsByDst);
}
} // namespace tir
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index d6d7c1e609..f7f6f7256a 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -22,10 +22,14 @@
* \brief Rewrite the data type of expressions.
*/
+#include "data_type_rewriter.h"
+
#include <tvm/tir/builtin.h>
-#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/op.h>
+#include <algorithm>
+#include <utility>
+
#include "./functor_common.h"
#include "tvm/ir/expr.h"
#include "tvm/tir/expr.h"
diff --git a/include/tvm/tir/data_type_rewriter.h
b/src/tir/ir/data_type_rewriter.h
similarity index 98%
rename from include/tvm/tir/data_type_rewriter.h
rename to src/tir/ir/data_type_rewriter.h
index e100eeb590..0662c60fa1 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/src/tir/ir/data_type_rewriter.h
@@ -21,8 +21,8 @@
* \file data_type_rewriter.h
* \brief Rewrite the data type of expressions.
*/
-#ifndef TVM_TIR_DATA_TYPE_REWRITER_H_
-#define TVM_TIR_DATA_TYPE_REWRITER_H_
+#ifndef TVM_TIR_IR_DATA_TYPE_REWRITER_H_
+#define TVM_TIR_IR_DATA_TYPE_REWRITER_H_
#include <tvm/tir/stmt_functor.h>
@@ -165,4 +165,4 @@ class IndexDataTypeNormalizer : public
IndexDataTypeRewriter {
} // namespace tir
} // namespace tvm
-#endif // TVM_TIR_DATA_TYPE_REWRITER_H_
+#endif // TVM_TIR_IR_DATA_TYPE_REWRITER_H_
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc
index a2f3c46aba..2d4ec3a1ca 100644
--- a/src/tir/ir/stmt_functor.cc
+++ b/src/tir/ir/stmt_functor.cc
@@ -22,12 +22,12 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/ir/module.h>
-#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <functional>
+#include "data_type_rewriter.h"
#include "functor_common.h"
namespace tvm {
diff --git a/src/tir/transform/force_narrow_index_to_i32.cc
b/src/tir/transform/force_narrow_index_to_i32.cc
index cee6018150..e908d35125 100644
--- a/src/tir/transform/force_narrow_index_to_i32.cc
+++ b/src/tir/transform/force_narrow_index_to_i32.cc
@@ -24,10 +24,11 @@
*/
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
+#include "../ir/data_type_rewriter.h"
+
namespace tvm {
namespace tir {
diff --git a/src/tir/transform/narrow_datatype.cc
b/src/tir/transform/narrow_datatype.cc
index 31e5cb348e..8d03f8c157 100644
--- a/src/tir/transform/narrow_datatype.cc
+++ b/src/tir/transform/narrow_datatype.cc
@@ -25,12 +25,12 @@
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
-#include <tvm/tir/data_type_rewriter.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../arith/ir_visitor_with_analyzer.h"
+#include "../ir/data_type_rewriter.h"
namespace tvm {
namespace tir {
diff --git a/tests/cpp/data_type_rewriter_test.cc
b/tests/cpp/data_type_rewriter_test.cc
deleted file mode 100644
index b7575812fe..0000000000
--- a/tests/cpp/data_type_rewriter_test.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <gtest/gtest.h>
-#include <tvm/tir/builtin.h>
-#include <tvm/tir/data_type_rewriter.h>
-#include <tvm/tir/op.h>
-
-using namespace tvm;
-using namespace tvm::tir;
-using namespace tvm::runtime;
-
-using BinaryOpTypes =
- ::testing::Types<Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod, Min, Max,
EQ, NE, LT, LE, GT, GE>;
-
-template <typename T>
-class DataTypeLegalizerBinaryOp : public ::testing::Test {};
-
-TYPED_TEST_SUITE(DataTypeLegalizerBinaryOp, BinaryOpTypes);
-
-TYPED_TEST(DataTypeLegalizerBinaryOp, Basic) {
- using RefType = TypeParam;
- using NodeType = typename RefType::ContainerType;
- auto node = ffi::make_object<NodeType>();
- node->a = Var("a", DataType::Int(32));
- node->b = IntImm(DataType::Int(64), 2);
- DataTypeLegalizer legalizer;
- auto new_expr = Downcast<RefType>(legalizer(RefType(node)));
- auto target_dtype = DataType::Int(64);
- ASSERT_EQ(new_expr->a.dtype(), target_dtype);
- ASSERT_EQ(new_expr->b.dtype(), target_dtype);
-}
-
-TEST(DataTypeLegalizer, Select) {
- auto node = ffi::make_object<SelectNode>();
- node->condition = Var("cond", DataType::Bool());
- node->true_value = Var("a", DataType::Int(64));
- node->false_value = IntImm(DataType::Int(32), 2);
- DataTypeLegalizer legalizer;
- Select new_select = Downcast<Select>(legalizer(Select(node)));
- auto target_dtype = DataType::Int(64);
- ASSERT_EQ(new_select->true_value.dtype(), target_dtype);
- ASSERT_EQ(new_select->false_value.dtype(), target_dtype);
- ASSERT_EQ(new_select.dtype(), target_dtype);
- ASSERT_EQ(new_select->condition.dtype(), node->condition.dtype());
-}
-TEST(DataTypeLegalizer, IfThenElse) {
- auto cond = Var("cond", DataType::Bool());
- PrimExpr call = Call(DataType::Int(32), builtin::if_then_else(),
- {cond, Var("a", DataType::Int(64)),
IntImm(DataType::Int(32), 2)});
- DataTypeLegalizer legalizer;
- Call new_call = Downcast<Call>(legalizer(call));
- auto target_dtype = DataType::Int(64);
- ASSERT_EQ(new_call->args[1].dtype(), target_dtype);
- ASSERT_EQ(new_call->args[2].dtype(), target_dtype);
- ASSERT_EQ(new_call->dtype, target_dtype);
-}
-
-TEST(DataTypeLegalizer, Block) {
- auto block_node = ffi::make_object<SBlockNode>();
- auto iter_var_node = ffi::make_object<IterVarNode>();
- iter_var_node->var = Var("i", DataType::Int(32));
- iter_var_node->dom =
- Range::FromMinExtent(IntImm(DataType::Int(64), 0),
IntImm(DataType::Int(64), 10));
- iter_var_node->iter_type = IterVarType::kDataPar;
- block_node->iter_vars = {IterVar(iter_var_node)};
- block_node->reads = {};
- block_node->writes = {};
- block_node->name_hint = "block";
- block_node->body = Evaluate(Integer(0));
- auto block_realize_node = ffi::make_object<SBlockRealizeNode>();
- auto loop_var = Var("i", DataType::Int(32));
- block_realize_node->iter_values = {loop_var};
- block_realize_node->predicate = const_true();
- block_realize_node->block = SBlock(block_node);
- auto for_node = ffi::make_object<ForNode>();
- for_node->loop_var = loop_var;
- for_node->min = IntImm(DataType::Int(64), 0);
- for_node->extent = IntImm(DataType::Int(64), 10);
- for_node->kind = ForKind::kSerial;
- for_node->body = SBlockRealize(block_realize_node);
- Stmt stmt = For(for_node);
-
- DataTypeLegalizer legalizer;
- DataType target_dtype = loop_var->dtype;
- Stmt new_stmt = legalizer(stmt);
- const ForNode* new_for = new_stmt.as<ForNode>();
- ASSERT_EQ(new_for->loop_var.dtype(), target_dtype);
- ASSERT_EQ(new_for->min.dtype(), target_dtype);
- ASSERT_EQ(new_for->extent.dtype(), target_dtype);
- const SBlockRealizeNode* new_block_realize =
new_for->body.as<SBlockRealizeNode>();
- ASSERT_EQ(new_block_realize->iter_values[0].dtype(), target_dtype);
- const SBlockNode* new_block = new_block_realize->block.as<SBlockNode>();
- ASSERT_EQ(new_block->iter_vars[0]->dom->min.dtype(), target_dtype);
- ASSERT_EQ(new_block->iter_vars[0]->dom->extent.dtype(), target_dtype);
- ASSERT_EQ(new_block->iter_vars[0]->var.dtype(), target_dtype);
-}
-
-TEST(DataTypeLegalizer, For) {
- auto node = ffi::make_object<ForNode>();
- node->body = Evaluate(Integer(0));
- node->loop_var = Var("i", DataType::Int(32));
- node->min = IntImm(DataType::Int(64), 0);
- node->extent = IntImm(DataType::Int(64), 10);
- DataTypeLegalizer legalizer;
- For new_for = Downcast<For>(legalizer(For(node)));
- ASSERT_EQ(new_for->min.dtype(), DataType::Int(32));
- ASSERT_EQ(new_for->extent.dtype(), DataType::Int(32));
- ASSERT_EQ(new_for->loop_var.dtype(), DataType::Int(32));
-}
-
-TEST(DataTypeLegalizer, Ramp) {
- auto node = ffi::make_object<RampNode>();
- node->base = IntImm(DataType::Int(64), 0);
- node->stride = IntImm(DataType::Int(32), 1);
- int lanes = 4;
- node->lanes = lanes;
- DataTypeLegalizer legalizer;
- Ramp new_ramp = Downcast<Ramp>(legalizer(Ramp(node)));
- DataType target_dtype = DataType::Int(64);
- ASSERT_EQ(new_ramp->base.dtype(), target_dtype);
- ASSERT_EQ(new_ramp->stride.dtype(), target_dtype);
- ASSERT_EQ(new_ramp->dtype, target_dtype.with_lanes(lanes));
-}
diff --git
a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
b/tests/python/s_tir/analysis/test_sblock_access_region.py
similarity index 91%
rename from
tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
rename to tests/python/s_tir/analysis/test_sblock_access_region.py
index 40c84f1956..2427f4eb6e 100644
--- a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/s_tir/analysis/test_sblock_access_region.py
@@ -18,7 +18,7 @@ import pytest
import tvm
import tvm.testing
-from tvm import tir
+from tvm import s_tir
from tvm.ir import Range
from tvm.script import tir as T
@@ -210,7 +210,7 @@ def test_block_access_region_detector():
block = func.body.block.body.block
alloc_buffers = func.body.block.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
- ret = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.reads, ret[0])
tvm.ir.assert_structural_equal(block.writes, ret[1])
@@ -225,12 +225,12 @@ def test_opaque_block():
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
block0 = opaque_block_func.body.block.body.body.block
- ret = tir.analysis.get_sblock_access_region(block0, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block0, buffer_var_map)
tvm.ir.assert_structural_equal(block0.reads, ret[0])
tvm.ir.assert_structural_equal(block0.writes, ret[1])
block1 = block0.body.body.block
- ret = tir.analysis.get_sblock_access_region(block1, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block1, buffer_var_map)
tvm.ir.assert_structural_equal(block1.reads, ret[0])
tvm.ir.assert_structural_equal(block1.writes, ret[1])
@@ -240,8 +240,8 @@ def test_opaque_access():
alloc_buffers = opaque_access_func.body.block.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
- ret0 = tir.analysis.get_sblock_read_write_region(block, buffer_var_map)
- ret1 = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret0 = s_tir.analysis.get_sblock_read_write_region(block, buffer_var_map)
+ ret1 = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
with pytest.raises(ValueError):
tvm.ir.assert_structural_equal(ret0[0], ret1[0])
with pytest.raises(ValueError):
@@ -253,8 +253,8 @@ def test_opaque_access_with_tvm_access_ptr():
alloc_buffers =
opaque_access_with_tvm_access_ptr_func.body.block.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
- ret0 = tir.analysis.get_sblock_read_write_region(block, buffer_var_map)
- ret1 = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret0 = s_tir.analysis.get_sblock_read_write_region(block, buffer_var_map)
+ ret1 = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.reads, ret0[0])
tvm.ir.assert_structural_equal(block.writes, ret0[1])
with pytest.raises(ValueError):
@@ -271,13 +271,13 @@ def test_match_buffer():
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
# Check block
- ret = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.writes, ret[1])
# B is opaque access
tvm.ir.assert_structural_equal(block.reads, ret[2])
# Check inner block AAA without updating buffer_var_map
- ret = tir.analysis.get_sblock_access_region(block_inner, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block_inner, buffer_var_map)
# Since AA is not in the buffer_var_map, region of AA will not be
collected.
tvm.ir.assert_structural_equal([], ret[1])
@@ -286,7 +286,7 @@ def test_match_buffer():
target_buffer = match_buffer.buffer
buffer_var_map[target_buffer.data] = target_buffer
- ret = tir.analysis.get_sblock_access_region(block_inner, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block_inner, buffer_var_map)
tvm.ir.assert_structural_equal(block_inner.reads, ret[0])
tvm.ir.assert_structural_equal(block_inner.writes, ret[1])
@@ -295,8 +295,8 @@ def test_access_in_if_then_else_func():
block = access_in_if_then_else_func.body.block.body.block
alloc_buffers = access_in_if_then_else_func.body.block.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
- ret0 = tir.analysis.get_sblock_read_write_region(block, buffer_var_map)
- ret1 = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret0 = s_tir.analysis.get_sblock_read_write_region(block, buffer_var_map)
+ ret1 = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(ret0[0], ret1[0])
tvm.ir.assert_structural_equal(ret0[1], ret1[1])
@@ -305,8 +305,8 @@ def test_access_in_branch_func():
block = access_in_branch_func.body.block.body.block
alloc_buffers = access_in_branch_func.body.block.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
- ret0 = tir.analysis.get_sblock_read_write_region(block, buffer_var_map)
- ret1 = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret0 = s_tir.analysis.get_sblock_read_write_region(block, buffer_var_map)
+ ret1 = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(ret0[0], ret1[0])
tvm.ir.assert_structural_equal(ret0[1], ret1[1])
@@ -327,7 +327,7 @@ def test_access_of_padding_pattern():
block = s.get_sref(s.get_sblock(block_name)).stmt
expect_reads = block.reads
expect_writes = block.writes
- ret = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
for i, read in enumerate(ret[0]):
do_compare_buffer_region(read, expect_reads[i])
for i, write in enumerate(ret[1]):
@@ -341,7 +341,7 @@ def test_access_of_reduction():
block = gemm.body.block.body.body.body.body.body.body.block
alloc_buffers = gemm.body.block.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
- ret = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.reads, ret[0])
tvm.ir.assert_structural_equal(block.writes, ret[1])
@@ -352,7 +352,7 @@ def test_access_of_decompose_reduction():
alloc_buffers = decomposed_gemm.body.block.alloc_buffers
buffer_var_map = {buf.data: buf for buf in alloc_buffers}
for block in [init, update]:
- ret = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.reads, ret[0])
tvm.ir.assert_structural_equal(block.writes, ret[1])
@@ -380,7 +380,7 @@ def test_buffer_access_with_let_binding():
block = func.body.block.body.body.body.block
buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()}
- ret = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.reads, ret[0])
tvm.ir.assert_structural_equal(block.writes, ret[1])
@@ -406,7 +406,7 @@ def test_buffer_access_with_nested_let_binding():
block = func.body.block.body.body.body.block
buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()}
- ret = tir.analysis.get_sblock_access_region(block, buffer_var_map)
+ ret = s_tir.analysis.get_sblock_access_region(block, buffer_var_map)
tvm.ir.assert_structural_equal(block.reads, ret[0])
tvm.ir.assert_structural_equal(block.writes, ret[1])
diff --git
a/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py
b/tests/python/s_tir/analysis/test_sblock_buffer_access_lca.py
similarity index 95%
rename from
tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py
rename to tests/python/s_tir/analysis/test_sblock_buffer_access_lca.py
index 9a06e610f2..f1fe5ece97 100644
--- a/tests/python/tir-analysis/test_tir_analysis_detect_buffer_access_lca.py
+++ b/tests/python/s_tir/analysis/test_sblock_buffer_access_lca.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
-from tvm import tir
+from tvm import s_tir
from tvm.script import tir as T
@@ -110,7 +110,7 @@ def test_buffer_load_store():
func = buffer_load_store_func
A, B = [func.buffer_map[x] for x in func.params]
C, D = func.body.block.alloc_buffers
- lca = tir.analysis.detect_buffer_access_lca(func)
+ lca = s_tir.analysis.detect_buffer_access_lca(func)
# LCA of Buffer A is root
root_block = func.body.block
@@ -133,7 +133,7 @@ def test_buffer_load_store():
def test_opaque_access():
func = buffer_opaque_access
B, C = [func.buffer_map[x] for x in func.params]
- lca = tir.analysis.detect_buffer_access_lca(func)
+ lca = s_tir.analysis.detect_buffer_access_lca(func)
# Cannot detect buffer A since it is define by low-level Allocate
@@ -148,14 +148,14 @@ def test_opaque_access():
def test_lca_func_root():
func = lca_is_func_root
(A,) = [func.buffer_map[x] for x in func.params]
- lca = tir.analysis.detect_buffer_access_lca(func)
+ lca = s_tir.analysis.detect_buffer_access_lca(func)
assert lca[A] is None
def test_match_buffer():
func = match_buffer_func
A, B = [func.buffer_map[x] for x in func.params]
- lca = tir.analysis.detect_buffer_access_lca(func)
+ lca = s_tir.analysis.detect_buffer_access_lca(func)
root_block = func.body.block
block = root_block.body.body.body.block
@@ -171,7 +171,7 @@ def test_match_buffer():
def test_global_buffer_with_blockidx():
func = global_buffer_with_blockidx
A, B = [func.buffer_map[x] for x in func.params]
- lca = tir.analysis.detect_buffer_access_lca(func)
+ lca = s_tir.analysis.detect_buffer_access_lca(func)
root_block = func.body.block
blockidx_loop = root_block.body
diff --git a/tests/python/s_tir/base/test_tir_block_dependence_info.py
b/tests/python/s_tir/base/test_sblock_dependence_info.py
similarity index 99%
rename from tests/python/s_tir/base/test_tir_block_dependence_info.py
rename to tests/python/s_tir/base/test_sblock_dependence_info.py
index c39be195e1..6b288d80d5 100644
--- a/tests/python/s_tir/base/test_tir_block_dependence_info.py
+++ b/tests/python/s_tir/base/test_sblock_dependence_info.py
@@ -26,7 +26,7 @@ from tvm.ir import IRModule
from tvm.script import tir as T
from tvm.tir import PrimFunc
from tvm.s_tir import SBlockDependenceInfo
-from tvm.s_tir.block_scope import DepKind
+from tvm.s_tir.sblock_scope import DepKind
from tvm.tir.stmt_functor import post_order_visit
# pylint: disable=no-member,invalid-name,unused-variable
diff --git a/tests/python/tir-base/test_tir_data_layout.py
b/tests/python/s_tir/base/test_tir_data_layout.py
similarity index 76%
rename from tests/python/tir-base/test_tir_data_layout.py
rename to tests/python/s_tir/base/test_tir_data_layout.py
index a76cb50da3..397c6d673c 100644
--- a/tests/python/tir-base/test_tir_data_layout.py
+++ b/tests/python/s_tir/base/test_tir_data_layout.py
@@ -23,9 +23,9 @@ from tvm.topi.utils import get_const_tuple
def test_layout():
- layout = tvm.tir.layout("NCHW16c")
+ layout = tvm.s_tir.layout("NCHW16c")
assert layout is not None
- assert isinstance(layout, tvm.tir.Layout)
+ assert isinstance(layout, tvm.s_tir.Layout)
assert layout.factor_of("c") == 16
assert layout.factor_of("C") == 16
@@ -54,7 +54,7 @@ def test_layout():
def test_layout_dtype():
- layout_i32 = tvm.tir.layout("NCHW")
+ layout_i32 = tvm.s_tir.layout("NCHW")
assert layout_i32.axes[0].var.dtype == "int32"
assert layout_i32.axes[0].dom.min.dtype == "int32"
assert layout_i32.axes[0].dom.extent.dtype == "int32"
@@ -62,7 +62,7 @@ def test_layout_dtype():
assert layout_i32.axes[1].dom.min.dtype == "int32"
assert layout_i32.axes[1].dom.extent.dtype == "int32"
- layout_i64 = tvm.tir.layout("NCHW", dtype="int64")
+ layout_i64 = tvm.s_tir.layout("NCHW", dtype="int64")
assert layout_i64.axes[2].var.dtype == "int64"
assert layout_i64.axes[2].dom.min.dtype == "int64"
assert layout_i64.axes[2].dom.extent.dtype == "int64"
@@ -71,27 +71,27 @@ def test_layout_dtype():
assert layout_i64.axes[3].dom.extent.dtype == "int64"
with pytest.raises(TypeError):
- tvm.tir.layout("NCHW", dtype="float32")
+ tvm.s_tir.layout("NCHW", dtype="float32")
with pytest.raises(TypeError):
- tvm.tir.layout("NCHW", dtype=None)
+ tvm.s_tir.layout("NCHW", dtype=None)
def test_bilayout_convertible():
# not convertible
- assert tvm.tir.bijective_layout("NCHW", "ABCD") is None
- assert tvm.tir.bijective_layout("__undef__", "NCHW") is None
- assert tvm.tir.bijective_layout("NCHW", "__undef__") is None
- assert tvm.tir.bijective_layout("__undef__", "__undef__") is None
- assert tvm.tir.bijective_layout("", "NCHW") is None
- assert tvm.tir.bijective_layout("NCHW", "") is None
- assert tvm.tir.bijective_layout("", "") is None
+ assert tvm.s_tir.bijective_layout("NCHW", "ABCD") is None
+ assert tvm.s_tir.bijective_layout("__undef__", "NCHW") is None
+ assert tvm.s_tir.bijective_layout("NCHW", "__undef__") is None
+ assert tvm.s_tir.bijective_layout("__undef__", "__undef__") is None
+ assert tvm.s_tir.bijective_layout("", "NCHW") is None
+ assert tvm.s_tir.bijective_layout("NCHW", "") is None
+ assert tvm.s_tir.bijective_layout("", "") is None
# convertible
- assert tvm.tir.bijective_layout("NCHW", "NCHW16c") is not None
+ assert tvm.s_tir.bijective_layout("NCHW", "NCHW16c") is not None
def test_bilayout_shape():
- bilayout = tvm.tir.bijective_layout("NCHW", "NCHW16c")
- assert isinstance(bilayout, tvm.tir.BijectiveLayout)
+ bilayout = tvm.s_tir.bijective_layout("NCHW", "NCHW16c")
+ assert isinstance(bilayout, tvm.s_tir.BijectiveLayout)
dst_shape = bilayout.forward_shape((1, 32, 7, 7))
assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)
@@ -101,7 +101,7 @@ def test_bilayout_shape():
def test_bilayout_index():
- bilayout = tvm.tir.bijective_layout("NCHW", "NCHW16c")
+ bilayout = tvm.s_tir.bijective_layout("NCHW", "NCHW16c")
dst_index = bilayout.forward_index([0, 18, 6, 6])
assert get_const_tuple(dst_index) == (0, 1, 6, 6, 2)
diff --git a/tests/python/tir-base/test_tir_te_extern_primfunc.py
b/tests/python/s_tir/base/test_tir_te_extern_primfunc.py
similarity index 100%
rename from tests/python/tir-base/test_tir_te_extern_primfunc.py
rename to tests/python/s_tir/base/test_tir_te_extern_primfunc.py
diff --git a/tests/scripts/task_python_unittest.sh
b/tests/scripts/task_python_unittest.sh
index 3619b51e79..36d15a3a85 100755
--- a/tests/scripts/task_python_unittest.sh
+++ b/tests/scripts/task_python_unittest.sh
@@ -50,6 +50,7 @@ TEST_FILES=(
"s_tir/base"
"s_tir/schedule"
"s_tir/dlight"
+ "s_tir/analysis"
"tir-analysis"
"tir-base"
"tir-transform"