This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a60cd0fecf [TIR] Allow symbolic bounds in IndexMap analysis (#15264)
a60cd0fecf is described below
commit a60cd0fecfdbc56187c7632469c5f444f835e99f
Author: Junru Shao <[email protected]>
AuthorDate: Sat Jul 8 21:50:05 2023 -0700
[TIR] Allow symbolic bounds in IndexMap analysis (#15264)
This PR threads an `arith::Analyzer` through the IndexMap APIs and adds the
bounds of shape variables to it, so that expressions over symbolic shapes can be simplified.
---
include/tvm/tir/index_map.h | 12 +--
include/tvm/topi/transform.h | 9 +-
pyproject.toml | 3 +
python/tvm/te/schedule.py | 11 +-
python/tvm/tir/function.py | 20 ++--
python/tvm/tir/schedule/schedule.py | 40 ++++++--
python/tvm/tir/schedule/testing.py | 6 +-
src/arith/ir_mutator_with_analyzer.cc | 20 ++++
src/arith/ir_mutator_with_analyzer.h | 6 ++
src/arith/iter_affine_map.cc | 6 ++
src/relay/backend/te_compiler_cache.cc | 4 +-
src/relay/op/tensor/transform.cc | 4 +-
src/runtime/logging.cc | 3 +
src/target/source/codegen_cuda.cc | 3 +-
src/te/schedule/message_passing.cc | 14 ++-
src/te/schedule/schedule_lang.cc | 6 +-
src/tir/ir/index_map.cc | 64 ++++++------
src/tir/schedule/analysis.h | 10 ++
src/tir/schedule/analysis/analysis.cc | 14 +++
src/tir/schedule/primitive/cache_read_write.cc | 7 +-
src/tir/schedule/primitive/compute_at.cc | 14 ---
.../schedule/primitive/layout_transformation.cc | 81 ++++++++-------
src/tir/schedule/transform.cc | 22 ++--
src/tir/transforms/flatten_buffer.cc | 13 +--
src/tir/transforms/storage_flatten.cc | 7 +-
src/tir/transforms/transform_mma_buffer_layout.cc | 6 +-
.../python/unittest/test_arith_iter_affine_map.py | 4 +-
.../test_meta_schedule_relay_integration.py | 20 +++-
...meta_schedule_schedule_cuda_layout_transform.py | 6 +-
.../test_meta_schedule_schedule_rule_mlt_tc.py | 19 ++--
.../unittest/test_meta_schedule_trace_apply.py | 10 +-
.../python/unittest/test_meta_schedule_tune_tir.py | 5 +
...sform_layout.py => test_te_transform_layout.py} | 0
.../{test_index_map.py => test_tir_index_map.py} | 11 +-
.../unittest/test_tir_schedule_transform_layout.py | 112 +++++++++++++++++++--
35 files changed, 411 insertions(+), 181 deletions(-)
diff --git a/include/tvm/tir/index_map.h b/include/tvm/tir/index_map.h
index 796e04e74a..340d953ccf 100644
--- a/include/tvm/tir/index_map.h
+++ b/include/tvm/tir/index_map.h
@@ -102,8 +102,7 @@ class IndexMapNode : public Object {
* \returns The indices in the output space. Contains one value for
* each expression in `final_indices`.
*/
- Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices,
- arith::Analyzer* analyzer = nullptr) const;
+ Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices, arith::Analyzer*
analyzer) const;
/*! \brief Map a memory range to the output space
*
@@ -121,7 +120,7 @@ class IndexMapNode : public Object {
* \returns The ranges in the output space. Contains one value for
* each expression in `final_indices`.
*/
- Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer
= nullptr) const;
+ Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer*
analyzer) const;
/*! \brief Map a buffer shape to the output space
*
@@ -134,7 +133,7 @@ class IndexMapNode : public Object {
* \returns The buffer shape in the output space. Contains one
* value for each expression in `final_indices`.
*/
- Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer*
analyzer = nullptr) const;
+ Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer*
analyzer) const;
/* \brief Map an NDArray according to this index map
*
@@ -203,7 +202,7 @@ class IndexMap : public ObjectRef {
* If the user has supplied an `inverse_index_map`, that map is
* assumed to be correct and bijective, and is returned.
*/
- IndexMap Inverse(Array<Range> initial_ranges) const;
+ IndexMap Inverse(Array<Range> initial_ranges, arith::Analyzer* analyzer)
const;
/*! \brief Rename the variables in the index map and ensure the names are
unique.
*
@@ -225,7 +224,8 @@ class IndexMap : public ObjectRef {
* \return The inverted index map, along with the predicate for
* which the inverse maps to a valid range.
*/
- std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range>
initial_ranges) const;
+ std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range>
initial_ranges,
+ arith::Analyzer*
analyzer) const;
TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
};
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index cab3466765..d881c4f423 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -24,6 +24,7 @@
#ifndef TVM_TOPI_TRANSFORM_H_
#define TVM_TOPI_TRANSFORM_H_
+#include <tvm/arith/analyzer.h>
#include <tvm/te/operation.h>
#include <tvm/tir/data_layout.h>
#include <tvm/tir/index_map.h>
@@ -1738,16 +1739,18 @@ inline Tensor auto_scheduler_layout_transform(const
Tensor& src, const String& s
inline Tensor meta_schedule_layout_transform(const Tensor& src, const
tir::IndexMap& index_map,
const String name =
"T_meta_schedule_layout_trans",
const String tag = kInjective) {
+ arith::Analyzer analyzer;
Array<Range> iter_domain;
iter_domain.reserve(src->shape.size());
for (const PrimExpr& e : src->shape) {
iter_domain.push_back(Range::FromMinExtent(make_zero(e->dtype), e));
}
- Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape);
+ Array<PrimExpr> post_transform_shape = index_map->MapShape(src->shape,
&analyzer);
return compute(
post_transform_shape,
- [src, inv = index_map.Inverse(iter_domain)](const Array<Var>& indices)
-> PrimExpr {
- return src(inv->MapIndices(Array<PrimExpr>{indices.begin(),
indices.end()}));
+ [src, inv = index_map.Inverse(iter_domain, &analyzer),
+ &analyzer](const Array<Var>& indices) -> PrimExpr {
+ return src(inv->MapIndices(Array<PrimExpr>{indices.begin(),
indices.end()}, &analyzer));
},
name, tag);
}
diff --git a/pyproject.toml b/pyproject.toml
index 5cca711ddb..91740f2b4b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+[tool.isort]
+profile = "black"
+src_paths = ["python", "tests/python"]
[tool.black]
line-length = 100
diff --git a/python/tvm/te/schedule.py b/python/tvm/te/schedule.py
index 3dbf2cefe4..0b6abc2566 100644
--- a/python/tvm/te/schedule.py
+++ b/python/tvm/te/schedule.py
@@ -22,13 +22,12 @@ from typing import Callable, List
import tvm._ffi
from tvm._ffi.base import string_types
-
-from tvm.runtime import Object, convert
from tvm.ir import container as _container
-from tvm.tir import IterVar, Buffer, Var, IndexMap
+from tvm.runtime import Object, convert
+from tvm.tir import Buffer, IndexMap, IterVar, Var
-from . import tensor as _tensor
from . import _ffi_api
+from . import tensor as _tensor
@tvm._ffi.register_object
@@ -600,7 +599,9 @@ class Stage(Object):
"""
ndim = len(self.op.output(0).shape)
- index_map, axis_separators =
IndexMap.from_func_with_separators(mapping_function, ndim=ndim)
+ index_map, axis_separators = IndexMap.from_func_with_separators(
+ mapping_function, ndim=ndim, index_dtype="int32"
+ )
new_iter_vars = _ffi_api.StageTransformLayout(
self, index_map.initial_indices, index_map.final_indices
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 32ec347039..bd44e3f7c3 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -67,7 +67,6 @@ class PrimFunc(BaseFunc, Scriptable):
attrs=None,
span=None,
):
-
param_list = []
buffer_map = {} if buffer_map is None else buffer_map
for x in params:
@@ -266,6 +265,8 @@ class IndexMap(Object):
mapping_function: Callable,
ndim: Optional[int] = None,
inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
+ *,
+ index_dtype: str = "int64",
):
"""Create an index map from a function
@@ -302,7 +303,10 @@ class IndexMap(Object):
"""
index_map, axis_separators = IndexMap.from_func_with_separators(
- mapping_function, ndim, inverse_index_map
+ mapping_function,
+ ndim,
+ inverse_index_map,
+ index_dtype=index_dtype,
)
assert not axis_separators, (
"The mapping_function provided to IndexMap.from_func "
@@ -316,6 +320,8 @@ class IndexMap(Object):
mapping_function: Callable,
ndim: Optional[int] = None,
inverse_index_map: Union[Callable, Optional["IndexMap"]] = None,
+ *,
+ index_dtype: str = "int64",
):
"""Create an index map from a function
@@ -346,6 +352,9 @@ class IndexMap(Object):
It is the user's responsibility to ensure the correctness of the
pre-defined inverse
index map.
+ index_dtype : str
+ The default index dtype to use for input iters in the mapping
function.
+
Returns
-------
ret: Tuple[IndexMap, List[int]]
@@ -361,20 +370,19 @@ class IndexMap(Object):
args = []
var_arg_name = None
kwargs = collections.OrderedDict()
- default_index_dtype = "int32"
for name, param in params.items():
if param.kind in [
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
]:
- args.append(tvm.tir.Var(name, default_index_dtype))
+ args.append(tvm.tir.Var(name, index_dtype))
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
var_arg_name = name
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
- kwargs[name] = tvm.tir.Var(name, default_index_dtype)
+ kwargs[name] = tvm.tir.Var(name, index_dtype)
else:
raise ValueError("transform_layout mapping may not have *args")
@@ -386,7 +394,7 @@ class IndexMap(Object):
assert ndim is not None, "ndim must be specified when *args is
used"
num_var_args = ndim - len(args) - len(kwargs)
for i in range(num_var_args):
- args.append(tvm.tir.Var(f"{var_arg_name}_{i}",
default_index_dtype))
+ args.append(tvm.tir.Var(f"{var_arg_name}_{i}", index_dtype))
mapping = mapping_function(*args, **kwargs)
diff --git a/python/tvm/tir/schedule/schedule.py
b/python/tvm/tir/schedule/schedule.py
index 4f717474e0..6c42f15a2f 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -93,6 +93,15 @@ def _parse_seed(seed: Optional[int]) -> int:
return seed
+def _get_block_default_dtype(block: Block) -> str:
+ for i in block.iter_vars:
+ return i.var.dtype
+ for buffer_region in list(block.reads) + list(block.writes):
+ for dom in buffer_region.region:
+ return dom.min.dtype
+ return "int64"
+
+
@_register_object("tir.Schedule")
class Schedule(Object):
"""The user-facing schedule class
@@ -1492,7 +1501,10 @@ class Schedule(Object):
block = self._normalize_block_arg(block)
if callable(index_map):
- index_map = IndexMap.from_func(index_map)
+ index_map = IndexMap.from_func(
+ index_map,
+ index_dtype=_get_block_default_dtype(self.get(block)),
+ )
return _ffi_api.ScheduleReindexCacheRead( # type: ignore # pylint:
disable=no-member
self, block, read_buffer_index, storage_scope, index_map
)
@@ -1589,7 +1601,10 @@ class Schedule(Object):
block = self._normalize_block_arg(block)
if callable(index_map):
- index_map = IndexMap.from_func(index_map)
+ index_map = IndexMap.from_func(
+ index_map,
+ index_dtype=_get_block_default_dtype(self.get(block)),
+ )
return _ffi_api.ScheduleReindexCacheWrite( # type: ignore # pylint:
disable=no-member
self, block, write_buffer_index, storage_scope, index_map
)
@@ -3246,14 +3261,22 @@ class Schedule(Object):
ndim = len(buffer_obj.shape)
if callable(index_map):
- index_map, axis_separators =
IndexMap.from_func_with_separators(index_map, ndim=ndim)
+ index_map, axis_separators = IndexMap.from_func_with_separators(
+ index_map,
+ ndim=ndim,
+ index_dtype=_get_block_default_dtype(self.get(block)),
+ )
else:
axis_separators = []
if pad_value is None:
pass
elif callable(pad_value):
- pad_value = IndexMap.from_func(pad_value,
ndim=len(index_map.final_indices))
+ pad_value = IndexMap.from_func(
+ pad_value,
+ ndim=len(index_map.final_indices),
+ index_dtype=_get_block_default_dtype(self.get(block)),
+ )
elif not isinstance(pad_value, IndexMap):
# Explicitly convert python int/float arguments to the
# buffer's type. If the default `tvm.runtime.convert`
@@ -3264,7 +3287,9 @@ class Schedule(Object):
elif "float" in buffer_obj.dtype and isinstance(pad_value, float):
pad_value = FloatImm(buffer_obj.dtype, pad_value)
pad_value = IndexMap.from_func(
- lambda *indices: pad_value, ndim=len(index_map.final_indices)
+ lambda *indices: pad_value,
+ ndim=len(index_map.final_indices),
+ index_dtype=_get_block_default_dtype(self.get(block)),
)
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
@@ -3337,7 +3362,10 @@ class Schedule(Object):
"""
block = self._normalize_block_arg(block)
if callable(index_map):
- index_map = IndexMap.from_func(index_map)
+ index_map = IndexMap.from_func(
+ index_map,
+ index_dtype=_get_block_default_dtype(self.get(block)),
+ )
_ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint:
disable=no-member
self, block, index_map
)
diff --git a/python/tvm/tir/schedule/testing.py
b/python/tvm/tir/schedule/testing.py
index f38a657123..a293b54b46 100644
--- a/python/tvm/tir/schedule/testing.py
+++ b/python/tvm/tir/schedule/testing.py
@@ -51,6 +51,8 @@ def verify_trace_roundtrip(
The text format or formats whose round-trip behavior should be
validated. If a single string, validate round-trips through
"""
+ from tvm.script import tir as T # pylint: disable=import-outside-toplevel
+
if not isinstance(text_format, str):
for opt in text_format:
new_sch = verify_trace_roundtrip(sch, mod, debug_mask=debug_mask,
text_format=opt)
@@ -66,7 +68,9 @@ def verify_trace_roundtrip(
Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch)
elif text_format == "python":
py_trace = "\n".join(trace.as_python())
- exec(py_trace, tvm.tir.__dict__, {"sch": new_sch}) # pylint:
disable=exec-used
+ vars_dict = {"T": T}
+ vars_dict.update(tvm.tir.__dict__)
+ exec(py_trace, vars_dict, {"sch": new_sch}) # pylint:
disable=exec-used
else:
assert text_format in ("json", "python"), f"Unknown text format:
{text_format}"
diff --git a/src/arith/ir_mutator_with_analyzer.cc
b/src/arith/ir_mutator_with_analyzer.cc
index 1f087d9934..2ee427beb8 100644
--- a/src/arith/ir_mutator_with_analyzer.cc
+++ b/src/arith/ir_mutator_with_analyzer.cc
@@ -22,6 +22,7 @@
*/
#include "ir_mutator_with_analyzer.h"
+#include <tvm/arith/iter_affine_map.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
@@ -39,6 +40,25 @@ void IRMutatorWithAnalyzer::MarkBufferMapShapes(const
tir::PrimFunc& func) {
}
}
+Array<PrimExpr> IRMutatorWithAnalyzer::IterMapSimplifyWithContext(const
Array<PrimExpr>& indices,
+ bool
non_trivial_only) {
+ PrimExpr pred = const_true();
+ for (PrimExpr val : iter_predicates_) {
+ pred = pred && val;
+ }
+ int n = indices.size();
+ Array<PrimExpr> simplified = arith::IterMapSimplify(
+ indices, this->iter_vars_, pred, arith::IterMapLevel::Surjective,
this->analyzer_);
+ if (non_trivial_only) {
+ for (int i = 0; i < n; ++i) {
+ if (simplified[i]->IsInstance<IntImmNode>() &&
indices[i]->IsInstance<VarNode>()) {
+ simplified.Set(i, indices[i]);
+ }
+ }
+ }
+ return simplified;
+}
+
Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
// record the loop variable as iterators
Range dom = Range::FromMinExtent(op->min, op->extent);
diff --git a/src/arith/ir_mutator_with_analyzer.h
b/src/arith/ir_mutator_with_analyzer.h
index f04b40e7ae..fb01fd19ce 100644
--- a/src/arith/ir_mutator_with_analyzer.h
+++ b/src/arith/ir_mutator_with_analyzer.h
@@ -70,6 +70,12 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
*/
void MarkBufferMapShapes(const tir::PrimFunc& func);
+ /*!
+ * \brief Use internal bound information to perform iter map simplification
of indices.
+ * \note Only do this during layout remapping
+ */
+ Array<PrimExpr> IterMapSimplifyWithContext(const Array<PrimExpr>& indices,
bool non_trivial_only);
+
/*! \brief internal analyzer field. */
Analyzer* analyzer_;
// the following two fields are useful in case we want
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index f8a36daf53..af1128aa27 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2061,6 +2061,12 @@ class IterMapToExprNormalizer : public ExprMutator {
if (analyzer_->CanProve(expr->extent == expr->source->extent) &&
is_one(expr->lower_factor)) {
return source * expr->scale;
} else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor
* expr->extent)) {
+ // Simplify if `expr` is always 0. The 2nd condition guarantees that we
do not aggressively
+ // simplify trivial iters like `vi \in [0, 1)`, which can be useful for
subsequent analysis
+ // like tensorization.
+ if (is_one(expr->extent) && !is_one(expr->source->extent)) {
+ return make_const(expr->extent->dtype, 0);
+ }
return floordiv(source, expr->lower_factor) * expr->scale;
} else {
return floordiv(floormod(source, expr->lower_factor * expr->extent),
expr->lower_factor) *
diff --git a/src/relay/backend/te_compiler_cache.cc
b/src/relay/backend/te_compiler_cache.cc
index 275d1b6bf7..b747855bff 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -19,6 +19,7 @@
#include "./te_compiler_cache.h"
+#include <tvm/arith/analyzer.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/name_supply.h>
#include <tvm/ir/type_functor.h>
@@ -594,7 +595,8 @@ class ScheduleBuilder : public ExprVisitor {
src_size_1d *= c->shape[i];
orig_shape.push_back(PrimExpr(static_cast<int>((c->shape[i]))));
}
- auto dst_shape = index_map->MapShape(orig_shape);
+ arith::Analyzer analyzer;
+ auto dst_shape = index_map->MapShape(orig_shape,
&analyzer);
std::vector<int64_t> dst_shape_int;
size_t dst_size_1d = 1;
for (size_t i = 0; i < dst_shape.size(); ++i) {
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index a0111ff7cd..fde6daa4d8 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -23,6 +23,7 @@
*/
#include "transform.h"
+#include <tvm/arith/analyzer.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
@@ -3434,9 +3435,10 @@ Array<te::Tensor>
MetaScheduleLayoutTransformCompute(const Attrs& attrs,
bool MetaScheduleLayoutTransformRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
TensorType data_type = Downcast<TensorType>(types[0]);
+ arith::Analyzer analyzer;
const MetaScheduleLayoutTransformAttrs* params =
attrs.as<MetaScheduleLayoutTransformAttrs>();
ICHECK(params);
- Array<PrimExpr> new_shape = params->index_map->MapShape(data_type->shape);
+ Array<PrimExpr> new_shape = params->index_map->MapShape(data_type->shape,
&analyzer);
reporter->Assign(types[1], TensorType(new_shape, data_type->dtype));
return true;
}
diff --git a/src/runtime/logging.cc b/src/runtime/logging.cc
index 5e7431e510..04b25f764c 100644
--- a/src/runtime/logging.cc
+++ b/src/runtime/logging.cc
@@ -130,6 +130,9 @@ int BacktraceFullCallback(void* data, uintptr_t pc, const
char* filename, int li
backtrace_syminfo(_bt_state, pc, BacktraceSyminfoCallback,
BacktraceErrorCallback,
symbol_str.get());
}
+ if (filename == nullptr && strstr(symbol_str.get()->data(), "ffi_call_")) {
+ return 0;
+ }
s << *symbol_str;
if (filename != nullptr) {
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index cd0ec0e34f..22103f7b0f 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -898,8 +898,9 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op,
std::ostream& os) {
runtime::Registry::Get("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout");
ICHECK(index_map_func);
+ arith::Analyzer analyzer;
auto inverse_index_map =
- IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0,
n)});
+ IndexMap::FromFunc(2, *index_map_func).Inverse({Range(0, m), Range(0,
n)}, &analyzer);
auto indices_16x16 = inverse_index_map->final_indices;
// "//" and "%" in the index map are translated to FloorDiv/Mod, but the
plain Div/Mod are fine.
diff --git a/src/te/schedule/message_passing.cc
b/src/te/schedule/message_passing.cc
index 7041b751c5..233663feac 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -193,7 +193,7 @@ void PassDownDomain(const Stage& stage,
std::unordered_map<IterVar, Range>* p_st
for (const auto& iter_var : s->original_variables) {
original_ranges.push_back(state[iter_var]);
}
- Array<Range> updated_ranges =
s->forward_transformation->MapRanges(original_ranges);
+ Array<Range> updated_ranges =
s->forward_transformation->MapRanges(original_ranges, actx);
ICHECK_EQ(updated_ranges.size(), s->transformed_variables.size());
for (size_t i = 0; i < updated_ranges.size(); i++) {
@@ -269,6 +269,7 @@ void PassUpIndex(const Stage& stage, const Map<IterVar,
Range>& dom_map,
}
} else if (rel.as<SingletonNode>()) {
} else if (const TransformNode* s = rel.as<TransformNode>()) {
+ arith::Analyzer analyzer;
bool missing_transformed = false;
for (const auto& iter_var : s->transformed_variables) {
if (!state.count(iter_var)) {
@@ -284,7 +285,8 @@ void PassUpIndex(const Stage& stage, const Map<IterVar,
Range>& dom_map,
for (const auto& iter_var : s->transformed_variables) {
transformed_indices.push_back(state[iter_var]);
}
- Array<PrimExpr> original_indices =
s->inverse_transformation->MapIndices(transformed_indices);
+ Array<PrimExpr> original_indices =
+ s->inverse_transformation->MapIndices(transformed_indices,
&analyzer);
ICHECK_EQ(original_indices.size(), s->original_variables.size());
for (size_t i = 0; i < original_indices.size(); i++) {
@@ -352,7 +354,9 @@ void PassDownIndex(const Stage& stage, const Map<IterVar,
Range>& dom_map,
for (const auto& iter_var : s->original_variables) {
original_indices.push_back(state[iter_var]);
}
- Array<PrimExpr> transformed_indices =
s->forward_transformation->MapIndices(original_indices);
+ arith::Analyzer analyzer;
+ Array<PrimExpr> transformed_indices =
+ s->forward_transformation->MapIndices(original_indices, &analyzer);
ICHECK_EQ(transformed_indices.size(), s->transformed_variables.size());
for (size_t i = 0; i < transformed_indices.size(); i++) {
@@ -449,7 +453,9 @@ Array<IntSet> PassUpDomain(const TransformNode* s,
transformed_indices.push_back(iter_var->var);
}
- Array<PrimExpr> transformed_exprs =
s->inverse_transformation->MapIndices(transformed_indices);
+ arith::Analyzer analyzer;
+ Array<PrimExpr> transformed_exprs =
+ s->inverse_transformation->MapIndices(transformed_indices, &analyzer);
ICHECK_EQ(transformed_exprs.size(), s->original_variables.size());
for (size_t i = 0; i < transformed_exprs.size(); i++) {
diff --git a/src/te/schedule/schedule_lang.cc b/src/te/schedule/schedule_lang.cc
index 56fe0cfc65..44e742eee4 100644
--- a/src/te/schedule/schedule_lang.cc
+++ b/src/te/schedule/schedule_lang.cc
@@ -21,6 +21,7 @@
* \file schedule_lang.cc
*/
#include <dmlc/thread_local.h>
+#include <tvm/arith/analyzer.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
@@ -491,10 +492,11 @@ Stage& Stage::transform_layout(const Array<Var>&
initial_indices,
for (const auto& iter_var : compute->axis) {
initial_ranges.push_back(iter_var->dom);
}
- Array<Range> final_ranges = map->MapRanges(initial_ranges);
+ arith::Analyzer analyzer;
+ Array<Range> final_ranges = map->MapRanges(initial_ranges, &analyzer);
// Make IterVar objects to represent the new iterations.
- auto inverse = map.Inverse(initial_ranges);
+ auto inverse = map.Inverse(initial_ranges, &analyzer);
Array<IterVar> final_indices_iter;
ICHECK_EQ(inverse->initial_indices.size(), final_ranges.size());
for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index a39149ceba..149e4cecd4 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -55,7 +55,9 @@ IndexMap IndexMap::FromFunc(int ndim,
runtime::TypedPackedFunc<Array<PrimExpr>(A
std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self,
const Array<Range>&
initial_ranges,
- arith::IterMapLevel
check_level) {
+ arith::IterMapLevel
check_level,
+ arith::Analyzer* analyzer) {
+ ICHECK(analyzer != nullptr);
if (self->inverse_index_map.defined()) {
// return the pre-defined inverse index map if exists. In this
// case, the user-defined inverse is assumed to be correct and
@@ -88,9 +90,8 @@ std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const
IndexMap& self,
// Unpack the output indices into linear combinations of the initial
// indices.
- arith::Analyzer analyzer;
- auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /*
predicate = */ 1,
- /*check_level=*/check_level, &analyzer,
+ auto padded_iter_map = DetectIterMap(self->final_indices, input_iters,
/*predicate=*/1,
+ /*check_level=*/check_level, analyzer,
/*simplify_trivial_iterators=*/false);
CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of
iterators. "
<< "Error: " <<
padded_iter_map->errors[0];
@@ -110,15 +111,15 @@ std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const
IndexMap& self,
} else {
expr = inverse_exprs_map.at(index);
}
- inverse_exprs.push_back(analyzer.Simplify(expr));
+ inverse_exprs.push_back(analyzer->Simplify(expr));
}
PrimExpr padding_predicate = padded_iter_map->padding_predicate;
padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate);
padding_predicate = Substitute(padding_predicate, inverse_exprs_map);
+ auto output_ranges = self->MapRanges(initial_ranges, analyzer);
{
- auto output_ranges = self->MapRanges(initial_ranges);
ICHECK_EQ(output_ranges.size(), output_vars.size());
arith::Analyzer analyzer;
@@ -133,15 +134,17 @@ std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const
IndexMap& self,
return {IndexMap(output_vars, inverse_exprs), padding_predicate};
}
-std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range>
initial_ranges) const {
- return IndexMapInverseImpl(*this, initial_ranges,
arith::IterMapLevel::NoCheck);
+std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range>
initial_ranges,
+ arith::Analyzer*
analyzer) const {
+ ICHECK(analyzer != nullptr);
+ return IndexMapInverseImpl(*this, initial_ranges,
arith::IterMapLevel::NoCheck, analyzer);
}
-IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
+IndexMap IndexMap::Inverse(Array<Range> initial_ranges, arith::Analyzer*
analyzer) const {
+ ICHECK(analyzer != nullptr);
auto [inverse, padding_predicate] =
- IndexMapInverseImpl(*this, initial_ranges,
arith::IterMapLevel::Bijective);
- arith::Analyzer analyzer;
- CHECK(analyzer.CanProve(!padding_predicate))
+ IndexMapInverseImpl(*this, initial_ranges,
arith::IterMapLevel::Bijective, analyzer);
+ CHECK(analyzer->CanProve(!padding_predicate))
<< "Bijective inverse should not contain padding, but inverse of " <<
*this << " over range "
<< initial_ranges << " resulted in a padding predicate of " <<
padding_predicate;
return inverse;
@@ -149,6 +152,7 @@ IndexMap IndexMap::Inverse(Array<Range> initial_ranges)
const {
Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
arith::Analyzer* analyzer) const {
+ ICHECK(analyzer != nullptr);
ICHECK_EQ(indices.size(), initial_indices.size());
Map<Var, PrimExpr> vmap;
@@ -157,11 +161,6 @@ Array<PrimExpr> IndexMapNode::MapIndices(const
Array<PrimExpr>& indices,
vmap.Set(initial_indices[i], indices[i]);
}
- arith::Analyzer local_analyzer;
- if (!analyzer) {
- analyzer = &local_analyzer;
- }
-
Array<PrimExpr> output = final_indices.Map([&](PrimExpr index) {
PrimExpr result = SubstituteWithDataTypeLegalization(
std::move(index), [&](const Var& var) { return vmap.Get(var); });
@@ -171,18 +170,13 @@ Array<PrimExpr> IndexMapNode::MapIndices(const
Array<PrimExpr>& indices,
}
Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges,
arith::Analyzer* analyzer) const {
+ ICHECK(analyzer != nullptr);
ICHECK_EQ(ranges.size(), initial_indices.size());
Map<Var, Range> input_iters;
for (size_t i = 0; i < initial_indices.size(); i++) {
input_iters.Set(initial_indices[i], ranges[i]);
}
-
- arith::Analyzer local_analyzer;
- if (!analyzer) {
- analyzer = &local_analyzer;
- }
-
auto iter_map = DetectIterMap(final_indices, input_iters, /* predicate = */
1,
/*check_level=*/arith::IterMapLevel::NoCheck,
analyzer,
/*simplify_trivial_iterators=*/false);
@@ -240,6 +234,7 @@ Array<Range> IndexMapNode::MapRanges(const Array<Range>&
ranges, arith::Analyzer
Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
arith::Analyzer* analyzer) const {
+ ICHECK(analyzer != nullptr);
ICHECK_EQ(shape.size(), initial_indices.size());
Array<Range> ranges;
@@ -258,6 +253,7 @@ Array<PrimExpr> IndexMapNode::MapShape(const
Array<PrimExpr>& shape,
}
runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const {
+ arith::Analyzer analyzer;
auto shape = arr_src.Shape();
ICHECK(shape.size() == initial_indices.size())
<< "The rank of the input array should be " << initial_indices.size() <<
" but got "
@@ -268,7 +264,7 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray
arr_src) const {
size_1d *= shape[i];
orig_shape.push_back(PrimExpr(static_cast<int>((shape[i]))));
}
- auto dst_shape = MapShape(orig_shape);
+ auto dst_shape = MapShape(orig_shape, &analyzer);
std::vector<int64_t> dst_shape_int;
for (size_t i = 0; i < dst_shape.size(); ++i) {
@@ -292,7 +288,7 @@ runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray
arr_src) const {
src_indices.push_back(PrimExpr(static_cast<int>((src_linear_index /
div_factor))));
src_linear_index %= div_factor;
}
- auto dst_indices = MapIndices(src_indices);
+ auto dst_indices = MapIndices(src_indices, &analyzer);
// Convert an N-d coordinate to a linear coordinate
// (z, y, x) -> z * height * width + y * width + x
@@ -430,19 +426,29 @@ TVM_REGISTER_GLOBAL("tir.IndexMap")
});
TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
- .set_body_typed([](IndexMap map, Array<PrimExpr> indices) { return
map->MapIndices(indices); });
+ .set_body_typed([](IndexMap map, Array<PrimExpr> indices) {
+ arith::Analyzer analyzer;
+ return map->MapIndices(indices, &analyzer);
+ });
TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map,
Array<PrimExpr> shape) {
- return map->MapShape(shape);
+ arith::Analyzer analyzer;
+ return map->MapShape(shape, &analyzer);
});
-TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse);
+
+TVM_REGISTER_GLOBAL("tir.IndexMapInverse")
+ .set_body_typed([](IndexMap map, Array<Range> initial_ranges) {
+ arith::Analyzer analyzer;
+ return map.Inverse(initial_ranges, &analyzer);
+ });
TVM_REGISTER_GLOBAL("tir.IndexMapMapNDArray")
.set_body_typed([](IndexMap map, runtime::NDArray arr) { return
map->MapNDArray(arr); });
TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse")
.set_body_typed([](IndexMap forward, Array<Range> initial_ranges) {
- auto result = forward.NonSurjectiveInverse(initial_ranges);
+ arith::Analyzer analyzer;
+ auto result = forward.NonSurjectiveInverse(initial_ranges, &analyzer);
return Array<ObjectRef>{result.first, result.second};
});
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index 6e361cbe05..868fbe8563 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -72,6 +72,16 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod,
const StmtNode* root_bl
*/
StmtSRef GetSRefTreeRoot(const StmtSRef& sref);
+/*!
+ * \brief Given an arbitrary sref, mark the buffer shapes of the PrimFunc it
belongs to as non-negative in the
+ * given analyzer
+ * \param state The schedule state
+ * \param sref The given sref
+ * \param analyzer The analyzer to be bound
+ */
+void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
+ arith::Analyzer* analyzer);
+
/******** Scope ********/
/*!
* \brief Checks if scope the specified sref is in is a stage-pipeline and
return it
diff --git a/src/tir/schedule/analysis/analysis.cc
b/src/tir/schedule/analysis/analysis.cc
index a765eb4d9f..17cc39d5cb 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1269,6 +1269,20 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
return GetRef<StmtSRef>(p);
}
+void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
+ arith::Analyzer* analyzer) {
+ while (sref->parent != nullptr) {
+ sref = sref->parent;
+ }
+ const PrimFuncNode* f = GetRootPrimFunc(state->mod, sref->stmt, nullptr);
+ for (const auto& kv : f->buffer_map) {
+ const Buffer& buffer = kv.second;
+ for (const PrimExpr& e : buffer->shape) {
+ analyzer->MarkGlobalNonNegValue(e);
+ }
+ }
+}
+
/******** Misc ********/
bool HasOp(const Stmt& stmt, const Array<Op>& ops) {
diff --git a/src/tir/schedule/primitive/cache_read_write.cc
b/src/tir/schedule/primitive/cache_read_write.cc
index 6f9aa11275..3fbdf856b5 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -1909,13 +1909,14 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
ReindexCacheStageInfo* info, const IRModule& mod, const StmtSRef&
block_sref,
const String& storage_scope, const IndexMap& index_map, const Block& block,
const BlockRealize& realize, const Buffer& old_buffer, const BufferRegion&
cache_region) {
+ arith::Analyzer analyzer;
Array<PrimExpr> block_iter_vars, block_shape;
for (const IterVar& iter_var : block->iter_vars) {
block_iter_vars.push_back(iter_var);
block_shape.push_back(iter_var->dom->extent);
}
- Array<PrimExpr> new_indices = index_map->MapIndices(block_iter_vars);
- Array<PrimExpr> new_shape = index_map->MapShape(block_shape);
+ Array<PrimExpr> new_indices = index_map->MapIndices(block_iter_vars,
&analyzer);
+ Array<PrimExpr> new_shape = index_map->MapShape(block_shape, &analyzer);
info->indices = new_indices;
// Step 5. Update CacheTouchedInfo
@@ -1926,8 +1927,6 @@ void CollectReindexCacheStageInfoAndCreateBuffer(
old_indices.push_back(range->min);
}
- arith::Analyzer analyzer;
-
VarUseDefAnalyzer collector_new(/*defined_vars=*/{});
for (const PrimExpr& idx : new_indices) {
collector_new(idx);
diff --git a/src/tir/schedule/primitive/compute_at.cc
b/src/tir/schedule/primitive/compute_at.cc
index 45d0c81050..fc388b0048 100644
--- a/src/tir/schedule/primitive/compute_at.cc
+++ b/src/tir/schedule/primitive/compute_at.cc
@@ -680,20 +680,6 @@ void CalculateProvidedRequiredRegions(
/******** Main Implementation ********/
-void AddShapeVarBounds(const ScheduleState& state, const StmtSRefNode* sref,
- arith::Analyzer* analyzer) {
- while (sref->parent != nullptr) {
- sref = sref->parent;
- }
- const PrimFuncNode* f = GetRootPrimFunc(state->mod, sref->stmt, nullptr);
- for (const auto& kv : f->buffer_map) {
- const Buffer& buffer = kv.second;
- for (const PrimExpr& e : buffer->shape) {
- analyzer->MarkGlobalNonNegValue(e);
- }
- }
-}
-
template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef&
block_sref,
const StmtSRef& loop_sref, bool
preserve_unit_loops,
diff --git a/src/tir/schedule/primitive/layout_transformation.cc
b/src/tir/schedule/primitive/layout_transformation.cc
index bb2abc559d..d9a9f3cfed 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -17,6 +17,7 @@
* under the License.
*/
+#include <tvm/arith/analyzer.h>
#include <tvm/node/node.h>
#include <optional>
@@ -93,12 +94,12 @@ class TransformLayoutPlanner : private StmtExprVisitor {
static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer,
IndexMap index_map,
IndexMap inverse, PrimExpr padding_predicate,
- Optional<IndexMap> pad_value) {
+ Optional<IndexMap> pad_value, arith::Analyzer*
analyzer) {
ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() ==
1)
<< "Internal error: Should be caught by ScheduleError checks prior to
this point";
TransformLayoutPlanner visitor(old_buffer);
visitor(block);
- return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate,
pad_value);
+ return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate,
pad_value, analyzer);
}
private:
@@ -220,14 +221,15 @@ class TransformLayoutPlanner : private StmtExprVisitor {
public:
BufferStoreReplacer(const WriteInfo& info, const Buffer& new_buffer,
PrimExpr padding_predicate,
const IndexMap& inverse, const Optional<IndexMap>&
pad_value,
- Map<Block, Block>* new_block_to_old)
+ Map<Block, Block>* new_block_to_old, arith::Analyzer*
analyzer)
: info(info),
new_buffer(new_buffer),
new_indices(inverse->initial_indices),
padding_predicate(padding_predicate),
inverse(inverse),
pad_value(pad_value),
- new_block_to_old(*new_block_to_old) {
+ new_block_to_old(*new_block_to_old),
+ analyzer(analyzer) {
ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size());
for (size_t i = 0; i < info.dependent_loopnest.size(); i++) {
Var var = info.dependent_loopnest[i]->loop_var;
@@ -353,7 +355,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
if (can_replace) {
Array<PrimExpr> new_index_exprs =
new_indices.Map([](const auto& var) -> PrimExpr { return var; });
- PrimExpr pad_value_at_index =
pad_value.value()->MapIndices(new_index_exprs)[0];
+ PrimExpr pad_value_at_index =
pad_value.value()->MapIndices(new_index_exprs, analyzer)[0];
store =
BufferStore(new_buffer, if_then_else(padding_predicate,
pad_value_at_index, op->value),
new_index_exprs);
@@ -429,22 +431,24 @@ class TransformLayoutPlanner : private StmtExprVisitor {
const Optional<IndexMap>& pad_value;
Map<Block, Block>& new_block_to_old;
bool all_stores_replaced{true};
+ arith::Analyzer* analyzer;
Map<Var, PrimExpr> var_remap;
};
TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap
inverse,
- PrimExpr padding_predicate, Optional<IndexMap>
pad_value) const {
- if (auto prologue_plan =
- FinalizeProloguePlan(new_buffer, index_map, inverse,
padding_predicate, pad_value);
+ PrimExpr padding_predicate, Optional<IndexMap>
pad_value,
+ arith::Analyzer* analyzer) const {
+ if (auto prologue_plan = FinalizeProloguePlan(new_buffer, index_map,
inverse, padding_predicate,
+ pad_value, analyzer);
prologue_plan.has_value()) {
return prologue_plan.value();
- } else if (auto replacement_plan = FinalizeReplacementPlan(new_buffer,
index_map, inverse,
-
padding_predicate, pad_value);
+ } else if (auto replacement_plan = FinalizeReplacementPlan(
+ new_buffer, index_map, inverse, padding_predicate,
pad_value, analyzer);
replacement_plan.has_value()) {
return replacement_plan.value();
} else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer,
index_map, inverse,
- padding_predicate,
pad_value);
+ padding_predicate,
pad_value, analyzer);
epilogue_plan.has_value()) {
return epilogue_plan.value();
} else {
@@ -454,7 +458,8 @@ class TransformLayoutPlanner : private StmtExprVisitor {
std::optional<ProloguePlan> FinalizeProloguePlan(Buffer new_buffer, IndexMap
index_map,
IndexMap inverse, PrimExpr
padding_predicate,
- Optional<IndexMap>
pad_value) const {
+ Optional<IndexMap>
pad_value,
+ arith::Analyzer* analyzer)
const {
if (write_info_.size() || is_zero(padding_predicate) ||
!pad_value.defined()) {
return std::nullopt;
}
@@ -476,7 +481,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
}
padding_predicate = Substitute(std::move(padding_predicate),
loop_indices_to_block_indices);
- PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0];
+ PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices,
analyzer)[0];
PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices)
== pad_value_at_index);
Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr}));
@@ -498,7 +503,8 @@ class TransformLayoutPlanner : private StmtExprVisitor {
std::optional<ReplacementPlan> FinalizeReplacementPlan(Buffer new_buffer,
IndexMap index_map,
IndexMap inverse,
PrimExpr
padding_predicate,
- Optional<IndexMap>
pad_value) const {
+ Optional<IndexMap>
pad_value,
+ arith::Analyzer*
analyzer) const {
if (write_info_.empty() || is_zero(padding_predicate) ||
!pad_value.defined()) {
return std::nullopt;
}
@@ -511,7 +517,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
}
BufferStoreReplacer replacer(info, new_buffer, padding_predicate,
inverse, pad_value,
- &new_block_to_old);
+ &new_block_to_old, analyzer);
Stmt stmt = replacer(info.dependent_loopnest.back()->body);
if (!replacer.is_all_stores_replaced()) {
return NullOpt;
@@ -547,7 +553,8 @@ class TransformLayoutPlanner : private StmtExprVisitor {
std::optional<EpiloguePlan> FinalizeEpiloguePlan(Buffer new_buffer, IndexMap
index_map,
IndexMap inverse, PrimExpr
padding_predicate,
- Optional<IndexMap>
pad_value) const {
+ Optional<IndexMap>
pad_value,
+ arith::Analyzer* analyzer)
const {
if (write_info_.empty() || is_zero(padding_predicate) ||
!pad_value.defined()) {
return std::nullopt;
}
@@ -566,7 +573,7 @@ class TransformLayoutPlanner : private StmtExprVisitor {
iter_values.push_back(loop_var);
}
- PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0];
+ PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices,
analyzer)[0];
Stmt stmt = BufferStore(new_buffer, pad_value_at_index, indices);
std::stringstream block_name;
@@ -757,12 +764,13 @@ class TransformLayoutRewriter : private
arith::IRMutatorWithAnalyzer {
const Block& scope_stmt, const Buffer& old_buffer, const Buffer&
new_buffer,
const IndexMap& index_map, const Optional<IndexMap>& opt_inverse,
const PrimExpr& padding_predicate, const Optional<IndexMap>& pad_value) {
- auto plan = pad_value.defined() ? TransformLayoutPlanner::Plan(
- scope_stmt, old_buffer, new_buffer,
index_map,
- opt_inverse.value(),
padding_predicate, pad_value)
- :
TransformLayoutPlanner::NoPaddingRequired();
-
arith::Analyzer analyzer;
+ auto plan = pad_value.defined()
+ ? TransformLayoutPlanner::Plan(scope_stmt, old_buffer,
new_buffer, index_map,
+ opt_inverse.value(),
padding_predicate,
+ pad_value, &analyzer)
+ : TransformLayoutPlanner::NoPaddingRequired();
+
TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan,
&analyzer);
Block result = Downcast<Block>(rewriter(scope_stmt));
if (auto plan_ptr =
std::get_if<TransformLayoutPlanner::ProloguePlan>(&plan)) {
@@ -794,9 +802,8 @@ class TransformLayoutRewriter : private
arith::IRMutatorWithAnalyzer {
void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
*buffer = new_buffer_;
- *indices = index_map_->MapIndices(*indices);
- (*indices).MutateByApply(
- [&](const PrimExpr& e) { return SimplifyNonTrivialExpr(e, analyzer_);
});
+ *indices = index_map_->MapIndices(*indices, &index_simplifier_);
+ *indices = this->IterMapSimplifyWithContext(*indices, true);
}
using Parent = arith::IRMutatorWithAnalyzer;
@@ -913,6 +920,7 @@ class TransformLayoutRewriter : private
arith::IRMutatorWithAnalyzer {
const TransformLayoutPlanner::TransformPlan& plan_;
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Block, Block> new_block_to_old_;
+ arith::Analyzer index_simplifier_;
};
class BufferIsSubregionError : public ScheduleError {
@@ -1069,7 +1077,8 @@ class TransformationIntroducesPaddingError : public
ScheduleError {
}
String DetailRenderTemplate() const final {
- auto new_shape = index_map_->MapShape(buffer_->shape);
+ arith::Analyzer analyzer;
+ auto new_shape = index_map_->MapShape(buffer_->shape, &analyzer);
std::ostringstream os;
os << "The transformation " << index_map_ << " applied on buffer " <<
buffer_->name
<< " of shape " << buffer_->shape << " would result in shape " <<
new_shape
@@ -1138,6 +1147,8 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map,
const Array<PrimExpr>&
void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int
buffer_index,
BufferIndexType buffer_index_type, const IndexMap&
index_map_orig,
const Optional<IndexMap>& pad_value, bool
assume_injective_transform) {
+ arith::Analyzer analyzer;
+ AddShapeVarBounds(self, block_sref.get(), &analyzer);
// Step 1: Input handling and error checking
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
Buffer old_buffer =
@@ -1173,7 +1184,7 @@ void TransformLayout(ScheduleState self, const StmtSRef&
block_sref, int buffer_
for (const auto& dim : old_buffer->shape) {
region.push_back(Range::FromMinExtent(make_zero(dim.dtype()), dim));
}
- return index_map.NonSurjectiveInverse(region);
+ return index_map.NonSurjectiveInverse(region, &analyzer);
}();
}
@@ -1184,7 +1195,7 @@ void TransformLayout(ScheduleState self, const StmtSRef&
block_sref, int buffer_
// Step 2: Infer the shape of the new buffer
Buffer new_buffer = old_buffer;
- new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape);
+ new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape,
&analyzer);
// Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write
regions, and block
// alloc_buffers.
@@ -1336,6 +1347,7 @@ void TransformBlockLayout(ScheduleState self, const
StmtSRef& block_sref,
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
const Block& block = GetRef<Block>(block_ptr);
arith::Analyzer analyzer;
+ AddShapeVarBounds(self, block_sref.get(), &analyzer);
// Step 1: Collect outer loops and loop vars
Array<StmtSRef> loops = GetLoops(block_sref); // outer loops of the block
@@ -1375,8 +1387,8 @@ void TransformBlockLayout(ScheduleState self, const
StmtSRef& block_sref,
// Step 4: Apply the IndexMap to block iters.
IndexMapNotApplicableToBlockIterError::Check(self->mod, block, index_map);
- Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars);
- Array<PrimExpr> new_block_iter_range =
index_map->MapShape(block_iter_range_array);
+ Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars,
&analyzer);
+ Array<PrimExpr> new_block_iter_range =
index_map->MapShape(block_iter_range_array, &analyzer);
// Step 5: Create the new block after transformation.
@@ -1408,14 +1420,13 @@ void TransformBlockLayout(ScheduleState self, const
StmtSRef& block_sref,
}
IndexMap inverse_index_map{nullptr};
try {
- inverse_index_map = index_map.Inverse(initial_ranges);
+ inverse_index_map = index_map.Inverse(initial_ranges, &analyzer);
} catch (...) {
throw NotBijectiveAffineIndexMapError(self->mod, index_map);
}
-
- Array<PrimExpr> inversed_new_block_vars = inverse_index_map->MapIndices(
- new_block_vars); // old block vars written in terms of new block vars
-
+ // old block vars written in terms of new block vars
+ Array<PrimExpr> inversed_new_block_vars =
+ inverse_index_map->MapIndices(new_block_vars, &analyzer);
for (int i = 0, n = block_vars.size(); i < n; ++i) {
inverse_subst_map.Set(Downcast<Var>(block_vars[i]),
inversed_new_block_vars[i]);
}
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index e4047273b6..62662a4ce1 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -388,16 +388,17 @@
TVM_REGISTER_GLOBAL("tir.schedule.TileWithTensorIntrin").set_body_typed(TileWith
/******** BlockBufferAccessSimplifier ********/
void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>*
old_access_regions) {
auto fmutate = [this](const BufferRegion& buffer_region) {
- std::vector<Range> new_buffer_region;
+ Array<Range> new_buffer_region;
+ Array<PrimExpr> simplified_min;
for (const auto& range : buffer_region->region) {
- if (is_one(range->extent) && range->min->IsInstance<VarNode>()) {
- new_buffer_region.push_back(Range::FromMinExtent(
- SimplifyNonTrivialExpr(range->min, analyzer_),
make_const(range->min.dtype(), 1)));
- } else {
- new_buffer_region.push_back(
- Range::FromMinExtent(SimplifyNonTrivialExpr(range->min, analyzer_),
- SimplifyNonTrivialExpr(range->extent,
analyzer_)));
- }
+ simplified_min.push_back(range->min);
+ }
+ simplified_min = this->IterMapSimplifyWithContext(simplified_min, true);
+ int n = buffer_region->region.size();
+ for (int i = 0; i < n; ++i) {
+ PrimExpr min = simplified_min[i];
+ PrimExpr extent = analyzer_->Simplify(buffer_region->region[i]->extent);
+ new_buffer_region.push_back(Range::FromMinExtent(min, extent));
}
return BufferRegion(buffer_region->buffer, new_buffer_region);
};
@@ -405,8 +406,7 @@ void
BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>* old_
}
void BlockBufferAccessSimplifier::SimplifyBufferIndices(Array<PrimExpr>*
indices) {
- (*indices).MutateByApply(
- [this](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr,
analyzer_); });
+ *indices = this->IterMapSimplifyWithContext(*indices, true);
}
Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) {
diff --git a/src/tir/transforms/flatten_buffer.cc
b/src/tir/transforms/flatten_buffer.cc
index f37c21593f..c04e12b839 100644
--- a/src/tir/transforms/flatten_buffer.cc
+++ b/src/tir/transforms/flatten_buffer.cc
@@ -220,18 +220,7 @@ class BufferFlattener : public
arith::IRMutatorWithAnalyzer {
Array<PrimExpr> GetSimplifiedElemOffset(const Buffer& buffer, const
Array<PrimExpr>& indices) {
auto flattened_indices = buffer->ElemOffset(indices);
- // Use IterMapSimplify to enable constant fold of fused indices
- // IterMapSimplify is more powerful and time-consuming than normal
- // simplify as it tries to deal with symbolic fusion
- //
- // Only use to handle indices during layout transformations
- // So we restrict the use to here
- PrimExpr pred = const_true();
- for (PrimExpr val : iter_predicates_) {
- pred = pred && val;
- }
- return arith::IterMapSimplify(flattened_indices, this->iter_vars_, pred,
- arith::IterMapLevel::Surjective,
this->analyzer_);
+ return this->IterMapSimplifyWithContext(flattened_indices, false);
}
template <typename Node>
diff --git a/src/tir/transforms/storage_flatten.cc
b/src/tir/transforms/storage_flatten.cc
index 8c409fba5e..9c12448381 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -1265,7 +1265,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
Array<IndexMap> transforms = lookup.value();
for (const auto& transform : transforms) {
- write_ptr->bounds = transform->MapRanges(realize->bounds);
+ write_ptr->bounds = transform->MapRanges(realize->bounds, &analyzer);
}
}
@@ -1292,7 +1292,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
Array<IndexMap> transforms = lookup.value();
for (const auto& transform : transforms) {
- write_ptr->indices = transform->MapIndices(node->indices);
+ write_ptr->indices = transform->MapIndices(node->indices, &analyzer);
}
}
return node;
@@ -1315,7 +1315,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
auto write_ptr = buf.CopyOnWrite();
for (const auto& transform : transforms) {
- write_ptr->shape = transform->MapShape(buf->shape);
+ write_ptr->shape = transform->MapShape(buf->shape, &analyzer);
}
}
@@ -1326,6 +1326,7 @@ class ApplyLayoutTransforms : public StmtExprMutator {
std::unordered_map<const BufferNode*, Buffer> buf_map_;
Map<Buffer, Array<IndexMap>> layout_transforms_;
+ arith::Analyzer analyzer;
};
class StorageFlattener : public StmtExprMutator {
diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc
b/src/tir/transforms/transform_mma_buffer_layout.cc
index 82fd6cfa9a..abe0bc3a3d 100644
--- a/src/tir/transforms/transform_mma_buffer_layout.cc
+++ b/src/tir/transforms/transform_mma_buffer_layout.cc
@@ -130,7 +130,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator {
const auto* index_map_func =
runtime::Registry::Get("tir.index_map_m16n8k8.matrixC");
ICHECK(index_map_func);
auto index_map = IndexMap::FromFunc(2, *index_map_func);
- auto new_indices = index_map->MapIndices(store->indices);
+ auto new_indices = index_map->MapIndices(store->indices, &analyzer);
n->buffer = buffer_map_[store->buffer];
n->indices = std::move(new_indices);
} else if (store->buffer.scope() == "m16n8k8.matrixA" ||
@@ -149,7 +149,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator {
const auto* index_map_func =
runtime::Registry::Get("tir.index_map_m16n8k8.matrixC");
ICHECK(index_map_func);
auto index_map = IndexMap::FromFunc(2, *index_map_func);
- auto new_indices = index_map->MapIndices(load->indices);
+ auto new_indices = index_map->MapIndices(load->indices, &analyzer);
n->buffer = buffer_map_[load->buffer];
n->indices = std::move(new_indices);
} else if (load->buffer.scope() == "m16n8k8.matrixA" ||
@@ -179,7 +179,7 @@ Pass TransformMmaBufferLayout() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = MmaBufferLayoutTransformer()(std::move(n->body));
- return std::move(f);
+ return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.TransformMmaBufferLayout", {});
}
diff --git a/tests/python/unittest/test_arith_iter_affine_map.py
b/tests/python/unittest/test_arith_iter_affine_map.py
index 594dec73ea..cee9922e86 100644
--- a/tests/python/unittest/test_arith_iter_affine_map.py
+++ b/tests/python/unittest/test_arith_iter_affine_map.py
@@ -16,7 +16,7 @@
# under the License.
import tvm
import tvm.testing
-from tvm.tir import floormod, floordiv
+from tvm.tir import floordiv, floormod
def ifuse(inputs, pred_extent=None):
@@ -1211,7 +1211,7 @@ def test_iter_map_simplify_unit_loop_order():
# When we have iterators that have same scale but one of them come
# with unit extent, we should prioritize unit extent
assert_iter_map_simplify(
- {x // 128 + y + z: y + x // 128 + z},
+ {x // 128 + y + z: y + z},
var_dom([(x, 128), (y, 128), (z, 1)]),
simplify_trivial_iterators=False,
)
diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py
b/tests/python/unittest/test_meta_schedule_relay_integration.py
index 162dec6271..3a2ca69cba 100644
--- a/tests/python/unittest/test_meta_schedule_relay_integration.py
+++ b/tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Integration test for MetaSchedule"""
-import tempfile
import platform
+import tempfile
from typing import List
import numpy as np
@@ -56,6 +56,7 @@ class MockModule:
# pylint:
enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument
[email protected]("Integration tests")
def test_meta_schedule_dynamic_loop_extent():
a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
@@ -64,6 +65,7 @@ def test_meta_schedule_dynamic_loop_extent():
assert not extracted_tasks
[email protected]("Integration tests")
@pytest.mark.skipif(
platform.machine() == "aarch64",
reason="Currently torch.jit.trace fails on AArch64",
@@ -104,6 +106,7 @@ def test_meta_schedule_integration_extract_from_resnet():
assert t.task_name in expected_task_names, t.task_name
[email protected]("Integration tests")
@pytest.mark.skipif(
platform.machine() == "aarch64",
reason="Currently torch.jit.trace fails on AArch64",
@@ -126,6 +129,7 @@ def test_task_extraction_winograd_tensorcore():
assert len([t for t in extracted_tasks if "winograd" in t.task_name]) == 4
[email protected]("Integration tests")
@pytest.mark.skipif(
platform.machine() == "aarch64",
reason="Currently torch.jit.trace fails on AArch64",
@@ -165,6 +169,7 @@ def test_task_extraction_anchor_block():
assert t.task_name in expected_task_names, t.task_name
[email protected]("Integration tests")
@tvm.testing.requires_package("torch")
def test_meta_schedule_integration_extract_from_bert_base():
pytest.importorskip(
@@ -263,6 +268,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
assert expected_shape == shape, t.task_name
[email protected]("Integration tests")
@pytest.mark.skipif(
platform.machine() == "aarch64",
reason="Currently torch.jit.trace fails on AArch64",
@@ -374,6 +380,7 @@ def extract_task_qbert_avx512():
extract_task_qbert("llvm -mcpu=skylake-avx512", "avx512")
[email protected]("Integration tests")
@tvm.testing.skip_if_32bit(reason="Apparently the LLVM version on i386 image
is too old")
def test_extract_task_arm_conv2d_nchwc():
data_shape = (1, 64, 128, 128)
@@ -419,6 +426,7 @@ def test_extract_task_arm_conv2d_nchwc():
assert list(out_type.shape) == [1, 8, 130, 130, 4]
[email protected]("Integration tests")
def test_meta_schedule_te2primfunc_argument_order_and_lowering():
# pylint:
disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
# fmt: off
@@ -581,7 +589,9 @@ def
test_meta_schedule_te2primfunc_argument_order_and_lowering():
dev,
)
- with target, _create_verification_database(), PassContext( # pylint:
disable=not-context-manager
+ with (
+ target
+ ), _create_verification_database(), PassContext( # pylint:
disable=not-context-manager
opt_level=3,
config={
"relay.backend.use_meta_schedule": True,
@@ -607,6 +617,7 @@ def
test_meta_schedule_te2primfunc_argument_order_and_lowering():
assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
[email protected]("Integration tests")
def test_rewrite_layout_link_params():
I, O, H, W = 64, 64, 56, 56
kH = kW = 3
@@ -685,6 +696,7 @@ def test_rewrite_layout_link_params():
np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)
[email protected]("Integration tests")
def test_module_equality_ignore_ndarray():
target = "llvm --num-cores=4"
@@ -800,6 +812,7 @@ def _test_anchor_tuning(target, space):
np.testing.assert_allclose(ref, out, atol=1e-3)
[email protected]("Integration tests")
@pytest.mark.parametrize(
"space",
[
@@ -811,6 +824,7 @@ def test_anchor_tuning_cpu(space):
_test_anchor_tuning("llvm --num-cores=4", space)
[email protected]("Integration tests")
def test_anchor_tuning_cpu_link_params():
data_shape = (128, 128)
weight_shape1 = (128, 128)
@@ -863,6 +877,7 @@ def test_anchor_tuning_cpu_link_params():
np.testing.assert_allclose(ref, out, atol=1e-3)
[email protected]("Integration tests")
@pytest.mark.xfail(raises=tvm.error.TVMError)
def test_disabled_pass_param():
"""
@@ -908,6 +923,7 @@ def test_disabled_pass_param():
pytest.fail("'disabled_pass' argument does not work")
[email protected]("Integration tests")
def test_rewrite_layout_link_params_1x1_conv2d():
I, O, H, W = 32, 16, 256, 256
kH = kW = 1
diff --git
a/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py
b/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py
index d1ba84d836..437aae9e6b 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_cuda_layout_transform.py
@@ -21,11 +21,14 @@ import tempfile
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
+import pytest
import tvm
import tvm.testing
from tvm import meta_schedule, relay
-from tvm.meta_schedule.schedule.cuda.layout_transform import
cuda_layout_transform_schedule_rule
+from tvm.meta_schedule.schedule.cuda.layout_transform import (
+ cuda_layout_transform_schedule_rule,
+)
from tvm.relay.op import OpPattern
from tvm.script import ir as I
from tvm.script import tir as T
@@ -170,6 +173,7 @@ def run_primfunc(
lib(*input_tensors)
[email protected]("Integration test")
class TestRandomRelayE2ECorrectness:
"""Tests E2E correctness of layout transform schedule.
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
index 7cf06b54ca..d1f4b6bdce 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -14,9 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint:
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+# pylint:
disable=missing-module-docstring,missing-function-docstring,missing-class-docstring,line-too-long,invalid-name,too-many-locals,too-many-statements,too-many-nested-blocks,too-many-branches,too-many-lines,chained-comparison
import pytest
+
import tvm
import tvm.testing
from tvm import meta_schedule as ms
@@ -413,10 +414,10 @@ def test_conv2d(shared_scope):
with T.block("PadInput_reindex_shared.dyn"):
v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused
* 16 + ax0_ax1_fused // 288)
v1 = T.axis.spatial(288, ax0_ax1_fused % 288)
- T.reads(PadInput[v0 // 256, v1 // 96 + v0 //
16, v1 % 96 // 32 + v0 % 16, v1 % 32])
+ T.reads(PadInput[0, v0 // 16 + v1 // 96, v0 %
16 + v1 % 96 // 32, v1 % 32])
T.writes(PadInput_reindex_shared_dyn[v0, v1])
T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 2})
- PadInput_reindex_shared_dyn[v0, v1] =
PadInput[v0 // 256, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]
+ PadInput_reindex_shared_dyn[v0, v1] =
PadInput[0, v0 // 16 + v1 // 96, v0 % 16 + v1 % 96 // 32, v1 % 32]
for ax0_ax1_fused in range(4608):
with T.block("weight_reindex_shared.dyn"):
v0 = T.axis.spatial(288, ax0_ax1_fused // 16)
@@ -497,9 +498,9 @@ def test_conv2d(shared_scope):
v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
// 16)
v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 16)
T.reads(conv2d_nhwc_reindex_shared_dyn[v0, v1, v2,
v3, v4, v5])
- T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 +
v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16])
+ T.writes(conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4
+ v0 * 16) % 16, v5 + v1 * 16])
T.block_attr({"meta_schedule.cooperative_fetch":
3})
- conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16)
// 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared_dyn[v0,
v1, v2, v3, v4, v5]
+ conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 *
16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared_dyn[v0, v1, v2, v3, v4, v5]
# fmt: on
decision_0 = [
("SamplePerfectTile", [1, 16, 1, 1, 1]),
@@ -915,10 +916,10 @@ def test_conv_1x1():
with T.block("PadInput_reindex_shared"):
v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused
// 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 64)
v1 = T.axis.spatial(64, ax0_ax1_fused % 64)
- T.reads(inputs[v0 // 256, v0 // 16, v0 % 16,
v1])
+ T.reads(inputs[0, v0 // 16, v0 % 16, v1])
T.writes(PadInput_reindex_shared[v0, v1])
T.block_attr({"buffer_dim_align": [[0, 0, 32,
8]], "meta_schedule.cooperative_fetch": 1})
- PadInput_reindex_shared[v0, v1] = inputs[v0 //
256, v0 // 16, v0 % 16, v1]
+ PadInput_reindex_shared[v0, v1] = inputs[0, v0
// 16, v0 % 16, v1]
for ax0_ax1_ax2_ax3_fused in range(2048):
with T.block("weight_reindex_shared"):
v0 = T.axis.spatial(1, 0)
@@ -1007,9 +1008,9 @@ def test_conv_1x1():
v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 256 // 16)
v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused
% 16)
T.reads(conv2d_nhwc_reindex_shared[v0, v1, v2, v3,
v4, v5])
- T.writes(conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 +
v0 * 16) // 16, (v4 + v0 * 16) % 16, v5 + v1 * 16])
+ T.writes(conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4
+ v0 * 16) % 16, v5 + v1 * 16])
T.block_attr({"meta_schedule.cooperative_fetch":
2})
- conv2d_nhwc[(v4 + v0 * 16) // 256, (v4 + v0 * 16)
// 16, (v4 + v0 * 16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1,
v2, v3, v4, v5]
+ conv2d_nhwc[0, (v4 + v0 * 16) // 16, (v4 + v0 *
16) % 16, v5 + v1 * 16] = conv2d_nhwc_reindex_shared[v0, v1, v2, v3, v4, v5]
# fmt: on
decision_0 = [
diff --git a/tests/python/unittest/test_meta_schedule_trace_apply.py
b/tests/python/unittest/test_meta_schedule_trace_apply.py
index 78b2fdbf3d..8562205753 100644
--- a/tests/python/unittest/test_meta_schedule_trace_apply.py
+++ b/tests/python/unittest/test_meta_schedule_trace_apply.py
@@ -15,16 +15,14 @@
# specific language governing permissions and limitations
# under the License.
import pytest
-
import tvm
-import tvm.testing
import tvm.meta_schedule as ms
+import tvm.testing
from tvm.script import tir as T
-from tvm.tir import Schedule, floormod, floordiv
-from tvm.tir.tensor_intrin.cuda import *
from tvm.target import Target
from tvm.target.codegen import llvm_lookup_intrinsic_id
-
+from tvm.tir import Schedule, floordiv, floormod
+from tvm.tir.tensor_intrin.cuda import *
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
@@ -1885,6 +1883,7 @@ def test_dense_add_cpu():
((i0 * 64) + i2),
i1,
),
+ index_dtype="int32",
),
pad_value=None,
)
@@ -1950,6 +1949,7 @@ def test_dense_add_cpu_no_write_cache():
((i1 * 32) + i3),
((i0 * 16) + i2),
),
+ index_dtype="int32",
),
pad_value=None,
)
diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py
b/tests/python/unittest/test_meta_schedule_tune_tir.py
index aa45120c23..c8fc4a73f5 100644
--- a/tests/python/unittest/test_meta_schedule_tune_tir.py
+++ b/tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -20,6 +20,7 @@ import tempfile
import numpy as np
import pytest
+
import tvm
import tvm.testing
from tvm import meta_schedule as ms
@@ -61,6 +62,7 @@ def two_step(a: T.handle, c: T.handle) -> None:
C[vi, vj] = B[vi, vj] + 3.0
[email protected]("Integration test")
@tvm.testing.requires_llvm
def test_tune_matmul_cpu():
with tempfile.TemporaryDirectory() as work_dir:
@@ -80,6 +82,7 @@ def test_tune_matmul_cpu():
sch.trace.show()
[email protected]("Integration test")
@tvm.testing.requires_cuda
def test_tune_matmul_cuda():
with tempfile.TemporaryDirectory() as work_dir:
@@ -99,6 +102,7 @@ def test_tune_matmul_cuda():
sch.trace.show()
[email protected]("Integration test")
def test_tune_run_module_via_rpc():
target = tvm.target.Target("llvm")
rt_mod = tvm.build(matmul, target)
@@ -141,6 +145,7 @@ def test_tune_run_module_via_rpc():
tvm.testing.assert_allclose(result.numpy(), c_np, rtol=1e-3)
[email protected]("Integration test")
def test_tune_block_cpu():
@ms.derived_object
class RemoveBlock(ms.schedule_rule.PyScheduleRule):
diff --git a/tests/python/unittest/test_transform_layout.py
b/tests/python/unittest/test_te_transform_layout.py
similarity index 100%
rename from tests/python/unittest/test_transform_layout.py
rename to tests/python/unittest/test_te_transform_layout.py
diff --git a/tests/python/unittest/test_index_map.py
b/tests/python/unittest/test_tir_index_map.py
similarity index 97%
rename from tests/python/unittest/test_index_map.py
rename to tests/python/unittest/test_tir_index_map.py
index 5eb31cd378..e893ed897d 100644
--- a/tests/python/unittest/test_index_map.py
+++ b/tests/python/unittest/test_tir_index_map.py
@@ -15,17 +15,16 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
-
import pytest
+
import tvm
import tvm.testing
from tvm.ir import assert_structural_equal
-from tvm.tir import IndexMap, IntImm, floordiv, floormod
from tvm.runtime import const
+from tvm.tir import IndexMap, IntImm, floordiv, floormod
def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None:
-
iters_1 = map1.map_indices(map2.initial_indices)
iters_2 = map2.final_indices
assert len(iters_1) == len(iters_2)
@@ -36,7 +35,7 @@ def assert_equal_index_map(map1: IndexMap, map2: IndexMap) ->
None:
def test_index_mapping():
- index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
+ index_map = IndexMap.from_func(lambda i: [i // 4, i % 4],
index_dtype="int32")
assert_structural_equal(index_map.map_indices([0]), [0, 0])
assert_structural_equal(index_map.map_indices([3]), [0, 3])
@@ -48,7 +47,7 @@ def test_index_mapping():
def test_shape_mapping():
- index_map = IndexMap.from_func(lambda i: [i // 4, i % 4])
+ index_map = IndexMap.from_func(lambda i: [i // 4, i % 4],
index_dtype="int32")
assert_structural_equal(index_map.map_shape([4]), [1, 4])
assert_structural_equal(index_map.map_shape([16]), [4, 4])
@@ -184,7 +183,7 @@ padding_test_case = tvm.testing.parameter(
def test_nonsurjective_inverse(padding_test_case):
- index_map = IndexMap.from_func(padding_test_case["forward"])
+ index_map = IndexMap.from_func(padding_test_case["forward"],
index_dtype="int32")
inverse, padding_predicate =
index_map.non_surjective_inverse(padding_test_case["pre_shape"])
expected_inverse = IndexMap.from_func(padding_test_case["inverse"])
diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py
b/tests/python/unittest/test_tir_schedule_transform_layout.py
index 8de11d8bd5..04bd00111e 100644
--- a/tests/python/unittest/test_tir_schedule_transform_layout.py
+++ b/tests/python/unittest/test_tir_schedule_transform_layout.py
@@ -154,11 +154,9 @@ def conv2d_nhwc_transformed(
for ax0, ax1, ax2 in T.grid(12544, 64, 147):
with T.block("conv2d_nhwc"):
v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2])
- T.reads(PadInput[v0 // 12544, v0 // 112 * 2 + v2 // 21, v0 % 112 *
2 + v2 % 21 // 3, v2 % 3], Weight[v2 // 21, v2 % 21 // 3, v2 % 3, v1])
- T.writes(Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1])
with T.init():
- Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1] =
T.float32(0)
- Conv2d_nhwc[v0 // 12544, v0 // 112, v0 % 112, v1] = Conv2d_nhwc[v0
// 12544, v0 // 112, v0 % 112, v1] + PadInput[v0 // 12544, v0 // 112 * 2 + v2
// 21, v0 % 112 * 2 + v2 % 21 // 3, v2 % 3] * Weight[v2 // 21, v2 % 21 // 3, v2
% 3, v1]
+ Conv2d_nhwc[0, v0 // 112, v0 % 112, v1] = T.float32(0)
+ Conv2d_nhwc[0, v0 // 112, v0 % 112, v1] = Conv2d_nhwc[0, v0 //
112, v0 % 112, v1] + PadInput[0, v0 // 112 * 2 + v2 // 21, v0 % 112 * 2 + v2 %
21 // 3, v2 % 3] * Weight[v2 // 21, v2 % 21 // 3, v2 % 3, v1]
@T.prim_func
@@ -461,11 +459,6 @@ def
test_transform_block_layout_int64_extent(use_block_name):
sch = tir.Schedule(elementwise_int64_extent, debug_mask="all")
block = "B" if use_block_name else sch.get_block("B")
sch.transform_block_layout(block, lambda i, j: (i * 128 + j,))
- print(
- tvm.ir.base.get_first_structural_mismatch(
- elementwise_int64_extent_transformed, sch.mod["main"]
- )
- )
tvm.ir.assert_structural_equal(elementwise_int64_extent_transformed,
sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=elementwise_int64_extent)
@@ -1085,5 +1078,106 @@ def test_index_map_dtype_legalize_with_constant():
sch.transform_layout(block="block", buffer="A", index_map=func,
pad_value=0)
+def test_transform_layout_with_symbolic_bound():
+ # fmt: off
+ # pylint: disable=invalid-name,line-too-long,too-many-locals
+ @T.prim_func
+ def before(a: T.handle, b: T.handle, c: T.handle):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ n = T.int64()
+ A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1),
T.int64(128)), "float16")
+ B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)),
"float16")
+ C = T.match_buffer(c, (T.int64(1), T.int64(32), T.int64(1), n),
"float16")
+ for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1),
n, T.int64(128)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1,
i2, i3, k])
+ T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+ T.writes(C[v_i0, v_i1, v_i2, v_i3])
+ with T.init():
+ C[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
+ C[v_i0, v_i1, v_i2, v_i3] = C[v_i0, v_i1, v_i2, v_i3] +
A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_i3, v_k]
+
+ @T.prim_func
+ def after(a: T.handle, b: T.handle, c: T.handle):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ n = T.int64()
+ A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1),
T.int64(128)), "float16")
+ B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)),
"float16")
+ C = T.match_buffer(c, (n * T.int64(32),), "float16")
+ for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1),
n, T.int64(128)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1,
i2, i3, k])
+ T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+ T.writes(C[v_i1 * n + v_i3])
+ with T.init():
+ C[v_i1 * n + v_i3] = T.float16(0)
+ C[v_i1 * n + v_i3] = C[v_i1 * n + v_i3] + A[v_i0, v_i1, v_i2,
v_k] * B[v_i0, v_i1, v_i3, v_k]
+ # pylint: enable=invalid-name,line-too-long,too-many-locals
+ # fmt: on
+ # pylint: disable=invalid-name
+ _, _, n, _ = before.buffer_map[before.params[1]].shape
+ sch = tvm.tir.Schedule(before)
+ block = sch.get_block("NT_matmul")
+ sch.transform_layout(
+ block,
+ ("write", 0),
+ lambda x, y, z, w: x * 32 * n + y * n + z * n + w,
+ assume_injective_transform=True,
+ )
+ # pylint: enable=invalid-name
+ tvm.ir.assert_structural_equal(after, sch.mod["main"])
+
+
+def test_transform_block_layout_with_symbolic_bound():
+ # fmt: off
+ # pylint: disable=invalid-name,line-too-long,too-many-locals
+ @T.prim_func
+ def before(a: T.handle, b: T.handle, c: T.handle):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ n = T.int64()
+ A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1),
T.int64(128)), "float16")
+ B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)),
"float16")
+ C = T.match_buffer(c, (n * T.int64(32),), "float16")
+ for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1),
n, T.int64(128)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1,
i2, i3, k])
+ T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+ T.writes(C[v_i1 * n + v_i3])
+ with T.init():
+ C[v_i1 * n + v_i3] = T.float16(0)
+ C[v_i1 * n + v_i3] = C[v_i1 * n + v_i3] + A[v_i0, v_i1, v_i2,
v_k] * B[v_i0, v_i1, v_i3, v_k]
+
+ @T.prim_func
+ def after(a: T.handle, b: T.handle, c: T.handle):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ n = T.int64()
+ A = T.match_buffer(a, (T.int64(1), T.int64(32), T.int64(1),
T.int64(128)), "float16")
+ B = T.match_buffer(b, (T.int64(1), T.int64(32), n, T.int64(128)),
"float16")
+ C = T.match_buffer(c, (n * T.int64(32),), "float16")
+ for ax0, ax1 in T.grid(n * T.int64(32), T.int64(128)):
+ with T.block("NT_matmul"):
+ v0, v1 = T.axis.remap("SR", [ax0, ax1])
+ T.reads(A[T.int64(0), v0 // n, T.int64(0), v1], B[T.int64(0),
v0 // n, v0 % n, v1])
+ T.writes(C[v0])
+ with T.init():
+ C[v0] = T.float16(0)
+ C[v0] = C[v0] + A[T.int64(0), v0 // n, T.int64(0), v1] *
B[T.int64(0), v0 // n, v0 % n, v1]
+ # pylint: enable=invalid-name,line-too-long,too-many-locals
+ # fmt: on
+ # pylint: disable=invalid-name
+ _, _, n, _ = before.buffer_map[before.params[1]].shape
+ sch = tvm.tir.Schedule(before)
+ block = sch.get_block("NT_matmul")
+ sch.transform_block_layout(
+ block,
+ lambda x, y, z, w, k: (
+ x * 32 * n + y * n + z * n + w,
+ k,
+ ),
+ )
+ # pylint: enable=invalid-name
+ tvm.ir.assert_structural_equal(after, sch.mod["main"])
+
+
if __name__ == "__main__":
tvm.testing.main()