Lunderberg commented on code in PR #13463:
URL: https://github.com/apache/tvm/pull/13463#discussion_r1029680983
##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -173,6 +173,35 @@ def two_elementwise_unit_dim(A: T.Buffer[(1, 128), "float32"], C: T.Buffer[(1, 1
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.ir_module
Review Comment:
Is this entire definition required for the test case? As a reader, it's
hard to tell which parts of this PrimFunc are needed to trigger the bug.
##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -1055,13 +1055,43 @@ class TransformationIntroducesPaddingError : public ScheduleError {
PrimExpr padding_predicate_;
};
+// Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid
+// dtype-mismatch issues later.
+IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) {
+  auto initial_indices_orig = index_map->initial_indices;
+  ICHECK(buf->shape.size() == initial_indices_orig.size());
+
+  Array<Var> initial_indices;
+  Map<Var, PrimExpr> var_map;
+
+  for (size_t i = 0; i < buf->shape.size(); ++i) {
+    if (buf->shape[i]->dtype != initial_indices_orig[i].dtype()) {
Review Comment:
I think this would produce an error if only some of the index dtypes are
mismatched. In that case, `initial_indices` would be filled only with the
variables that have a mismatched dtype, when it should have the same size as
`initial_indices_orig`.
```c++
if (buf->shape[i]->dtype == initial_indices_orig[i].dtype()) {
  initial_indices.push_back(initial_indices_orig[i]);
} else {
  auto new_idx = ...
}
```
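To illustrate the concern, here is a minimal plain-Python model of the loop, not the TVM code itself. The `Var` tuple and both function names are hypothetical stand-ins introduced only for this sketch. It assumes the loop appends to the output only in the mismatched branch, as described above.

```python
from collections import namedtuple

# Hypothetical stand-in for a TIR variable: just a name and a dtype string.
Var = namedtuple("Var", ["name", "dtype"])


def legalize_mismatched_only(shape_dtypes, indices):
    """Model of the loop as described: append only when the dtypes differ."""
    out = []
    for shape_dtype, idx in zip(shape_dtypes, indices):
        if shape_dtype != idx.dtype:
            out.append(Var(idx.name, shape_dtype))
    return out


def legalize_all(shape_dtypes, indices):
    """Model of the suggested fix: keep matching indices, replace the rest."""
    out = []
    for shape_dtype, idx in zip(shape_dtypes, indices):
        if shape_dtype == idx.dtype:
            out.append(idx)  # dtype already matches, reuse the original
        else:
            out.append(Var(idx.name, shape_dtype))
    return out


# Only the second and third indices disagree with the buffer shape dtypes.
shape_dtypes = ["int64", "int32", "int64"]
indices = [Var("i", "int64"), Var("j", "int64"), Var("k", "int32")]

partial = legalize_mismatched_only(shape_dtypes, indices)
full = legalize_all(shape_dtypes, indices)
```

In this model, `partial` ends up shorter than `indices` because the matching index `i` is dropped, while `full` keeps one entry per original index.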
##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -925,5 +954,24 @@ def expected(a: T.handle):
A[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 42, dtype="int32")
+def test_index_map_dtype_legalize():
+    """Test dtype legalization of the index map indices."""
+
+    def index_map_nchw32c_nchw8h8w32c(n_batch, channel, height, width, channel_32):
+        return [n_batch, channel, height // 8, width // 8, height % 8, width % 8, channel_32]
+
+    sch = tir.Schedule(Conv2dNCHW32c)
+
+    conv2d_block = sch.get_block("conv2d_NCHWc_int8")
+    sch.cache_read(conv2d_block, 0, "global.vtcm")
+
+    # The following error is raised from the IterVar constructor without the dtype legalization.
+    # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) :
+    # The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32)
+    sch.transform_layout(
+        conv2d_block, ("read", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0
Review Comment:
Instead of `("read", 0)`, the buffer can be specified by name, `buffer =
"data_pad_global_vtcm"`.
##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -1055,13 +1055,43 @@ class TransformationIntroducesPaddingError : public ScheduleError {
PrimExpr padding_predicate_;
};
+// Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid
+// dtype-mismatch issues later.
+IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) {
+  auto initial_indices_orig = index_map->initial_indices;
Review Comment:
Nit: `const auto&` instead of `auto`.
##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -173,6 +173,35 @@ def two_elementwise_unit_dim(A: T.Buffer[(1, 128), "float32"], C: T.Buffer[(1, 1
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0
+
+
+@tvm.script.ir_module
+class Conv2dNCHW32c:
+    @T.prim_func
Review Comment:
At this location, `@T.prim_func` would run while the tests are being
collected. This can make failures difficult to troubleshoot, since a failure
when parsing/constructing would prevent any unit test from running, not just
the unit tests that make use of it. It would be better to have the
`@T.prim_func` be in the unit test itself, so a failure to construct the
PrimFunc would only cause a failure in a single test.
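As a rough illustration of the timing difference, here is a plain-Python sketch that does not use TVM. The `record` decorator is a hypothetical stand-in for a parsing decorator such as `@T.prim_func`, and `example_test` stands in for a unit test body. It only shows when each decorator executes.

```python
calls = []


def record(func):
    """Hypothetical stand-in for a parsing decorator such as `@T.prim_func`."""
    calls.append(func.__name__)  # runs when the decorated `def` is executed
    return func


@record
def module_level():
    """Decorated at import time, i.e. while the test file is being collected."""


def example_test():
    @record
    def local_func():
        """Decorated only when the enclosing test body runs."""

    return local_func


after_import = list(calls)  # only the module-level decorator has run so far
example_test()
after_test = list(calls)  # the local decorator ran during the test call
```

If `record` raised an exception, the module-level use would fail as soon as the file is imported, while the use inside `example_test` would fail only when that function is called.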
##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -925,5 +954,24 @@ def expected(a: T.handle):
A[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 42, dtype="int32")
+def test_index_map_dtype_legalize():
+    """Test dtype legalization of the index map indices."""
+
+    def index_map_nchw32c_nchw8h8w32c(n_batch, channel, height, width, channel_32):
+        return [n_batch, channel, height // 8, width // 8, height % 8, width % 8, channel_32]
+
+    sch = tir.Schedule(Conv2dNCHW32c)
+
+    conv2d_block = sch.get_block("conv2d_NCHWc_int8")
+    sch.cache_read(conv2d_block, 0, "global.vtcm")
Review Comment:
Why does the test case call `cache_read`, instead of starting with the input
provided to `transform_layout`?