This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new b54f9ff194 [Unity] Add support for AXIS_SEPARATOR in AlterOpImpl Pass
(#15315)
b54f9ff194 is described below
commit b54f9ff194e65988b57be1d46fd833824b7c548f
Author: Abhikrant Sharma <[email protected]>
AuthorDate: Tue Jul 25 01:30:17 2023 +0530
[Unity] Add support for AXIS_SEPARATOR in AlterOpImpl Pass (#15315)
* [Unity] Add support for AXIS_SEPARATOR in AlterOpImpl Pass
Enable support of AXIS_SEPARATOR to handle non-flat buffers
Modified pass to handle AXIS_SEPARATOR
* Fix LINT errors.
---
include/tvm/relax/attrs/manipulate.h | 9 +++
include/tvm/relax/transform.h | 4 +-
python/tvm/relax/op/manipulate.py | 10 ++-
python/tvm/relax/transform/transform.py | 8 +-
src/relax/op/tensor/manipulate.cc | 4 +-
src/relax/op/tensor/manipulate.h | 5 +-
src/relax/transform/alter_op_impl.cc | 55 ++++++++++----
tests/python/relax/test_transform_alter_op_impl.py | 88 +++++++++++++++++++++-
8 files changed, 161 insertions(+), 22 deletions(-)
diff --git a/include/tvm/relax/attrs/manipulate.h
b/include/tvm/relax/attrs/manipulate.h
index 550515c032..b9d0b9f53b 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -59,12 +59,21 @@ struct LayoutTransformAttrs : public
tvm::AttrsNode<LayoutTransformAttrs> {
// pad_value is chosen to be of PrimValue type, as it represents constant
TIR POD expression. This
// needs to be revisited in case PrimValue is evolved to represent symbolic
expression in future.
Optional<PrimValue> pad_value;
+ /*!
+ * axis_separators between input axes when generating flattened output axes.
For buffers
+ * representing flat 1-d memory (e.g. any buffer in RAM), this should be an
empty array.
+ * For buffers representing non-flat memory, each entry in axis_separators
should be the
+ * first input axis that is part of a new flattened axis.
+ */
+ Optional<Array<IntImm>> axis_separators;
TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
TVM_ATTR_FIELD(pad_value).describe(
"The specific value to be used to pad if the layout transform would
result in implicit "
"padding. If not specified, the compiler is free to choose any
value.");
+ TVM_ATTR_FIELD(axis_separators)
+ .describe("The separators between input axes when generating flat
output axes");
}
}; // struct LayoutTransformAttrs
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index b618672292..dc2476f383 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -464,10 +464,12 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional<String>
func_name);
* \param op_impl_map Map from kOperatorName attr (e.g., relax.conv2d) to
replacement PrimFunc
* \param op_buffer_transforms Map from kOperatorName attr to layout
transformations on each of the
* PrimFunc i/o buffers.
+ * \param axis_separators Map from kOperatorName attr to axis_separators of
each buffer_transforms
* \return The Pass.
*/
TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
- const Map<String, Array<tir::IndexMap>>&
op_buffer_transforms);
+ const Map<String, Array<tir::IndexMap>>&
op_buffer_transforms,
+ const Map<String, Array<Array<IntImm>>>&
axis_separators);
/*!
* \brief Layout conversion pass.
diff --git a/python/tvm/relax/op/manipulate.py
b/python/tvm/relax/op/manipulate.py
index a51b516b2e..a6f1c580a4 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -114,6 +114,7 @@ def layout_transform(
x: Expr,
index_map: Union[Callable, IndexMap],
pad_value: Optional[Union[int, float, PrimValue]] = None,
+ axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
):
"""Modifies the layout of a tensor.
@@ -129,6 +130,9 @@ def layout_transform(
The value used for padding if the transformation results in implicit
padding.
If not specified, any value can be used.
+ axis_separators : Optional[Union[int, IndexMap.AXIS_SEPARATOR]]
+ The axis_separators for index_map to create non-flat buffers.
+
Returns
-------
result : relax.Expr
@@ -150,7 +154,11 @@ def layout_transform(
elif "float" in x_dtype and (isinstance(pad_value, (int, float))):
pad_value = FloatImm(x_dtype, float(pad_value))
pad_value = PrimValue(pad_value)
- return _ffi_api.layout_transform(x, index_map, pad_value) # type: ignore
+
+ if axis_separators is None:
+ axis_separators = []
+
+ return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators)
# type: ignore
def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index d1d337a8c0..73267c43ae 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -903,6 +903,7 @@ def DecomposeOpsForTraining(func_name: Optional[str] =
None) -> tvm.ir.transform
def AlterOpImpl(
op_impl_map: Dict[str, PrimFunc],
op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
+ op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR,
Callable]]],
):
"""Replace all PrimFunc's which have matching 'operator_name' attribute,
with replacement
PrimFunc that could possibly have different layouts on i/o buffers. The
layout
@@ -916,6 +917,9 @@ def AlterOpImpl(
op_kind to PrimFunc map
op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]]
op_kind to layout transformation map for each of the buffers
+ op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR,
Callable]]]
+ op_kind to axis_separator for each index_map
+
Returns
-------
ret: tvm.ir.transform.Pass
@@ -928,7 +932,9 @@ def AlterOpImpl(
l.append(transform)
op_buffer_transforms[operator_name] = l
- return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms) # type:
ignore
+ return _ffi_api.AlterOpImpl(
+ op_impl_map, op_buffer_transforms, op_buffer_axis_separators
+ ) # type: ignore
def ConvertLayout(desired_layouts: Dict[str, List[str]]) ->
tvm.ir.transform.Pass:
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index a55d199822..2d7e60c4f0 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -425,10 +425,12 @@ TVM_REGISTER_OP("relax.flatten")
/* relax.layout_transform */
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
-Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue>
pad_value) {
+Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue>
pad_value,
+ Optional<Array<IntImm>> axis_separators) {
ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
attrs->index_map = std::move(index_map);
attrs->pad_value = std::move(pad_value);
+ attrs->axis_separators = std::move(axis_separators);
static const Op& op = Op::Get("relax.layout_transform");
return Call(op, {std::move(x)}, Attrs{attrs}, {});
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 1593b19c17..b19e3b8507 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -65,9 +65,12 @@ Expr flatten(Expr x);
* \param index_map The transformation to apply.
* \param pad_value The value used for padding if the transformation results
in implicit padding. If
* not specified, any value can be used.
+ * \param axis_separators Array of values to differentiate between input axes
+ * when generating flattened output axes.
* \return The transformed result.
*/
-Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue>
pad_value);
+Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue>
pad_value,
+ Optional<Array<IntImm>> axis_separators);
/*!
* \brief Permutes the dimensions of an array.
diff --git a/src/relax/transform/alter_op_impl.cc
b/src/relax/transform/alter_op_impl.cc
index f40ee3b3bf..c303a2c8f0 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -76,11 +76,13 @@ bool IsTransformBijective(const Expr& expr, const IndexMap&
transform) {
class AlterOpImplMutator : public ExprMutator {
public:
AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>&
op_impl_map,
- const Map<String, Array<IndexMap>>& op_buffer_transforms_)
+ const Map<String, Array<IndexMap>>& op_buffer_transforms_,
+ const Map<String, Array<Array<IntImm>>>& axis_separators_)
: ExprMutator(mod),
mod_(mod),
op_impl_map_(op_impl_map),
- op_buffer_transforms__(op_buffer_transforms_) {}
+ op_buffer_transforms__(op_buffer_transforms_),
+ op_buffer_axis_separators__(axis_separators_) {}
IRModule Run() {
for (const auto& [gv, func] : mod_->functions) {
@@ -119,7 +121,10 @@ class AlterOpImplMutator : public ExprMutator {
const auto& replacement_func = op_impl_map_[op_kind];
Array<IndexMap> buffer_transforms;
+ Optional<Array<Array<IntImm>>> axis_separators;
if (op_buffer_transforms__.count(op_kind)) buffer_transforms =
op_buffer_transforms__[op_kind];
+ if (op_buffer_axis_separators__.count(op_kind))
+ axis_separators = op_buffer_axis_separators__[op_kind];
ICHECK(buffer_transforms.empty() || buffer_transforms.size() ==
replacement_func->params.size())
<< "Either the i/o buffers do not require any transformations or
transformations for each "
@@ -130,7 +135,7 @@ class AlterOpImplMutator : public ExprMutator {
GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func,
op_kind);
auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
- Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple,
buffer_transforms);
+ Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple,
buffer_transforms, axis_separators);
ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is
expected to be 1";
StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0],
buffer_transforms);
@@ -138,7 +143,7 @@ class AlterOpImplMutator : public ExprMutator {
Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs,
{updated_ret_sinfo}));
// Now transform each of the outputs to previous layout.
- return TransformOutputs(updated_call, buffer_transforms,
call->sinfo_args[0]);
+ return TransformOutputs(updated_call, buffer_transforms,
call->sinfo_args[0], axis_separators);
}
Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo&
output_sinfo) {
@@ -157,17 +162,20 @@ class AlterOpImplMutator : public ExprMutator {
return arr_tensor_sinfo;
}
- Expr TransformLayout(const Expr& expr, const IndexMap& index_map) {
+ Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
+ const Array<IntImm> axis_separators) {
ObjectPtr<LayoutTransformAttrs> attrs =
make_object<LayoutTransformAttrs>();
// We want to avoid two layout_transform ops to share the same index map
even if they are
// identical. The scope of vars used in index map initial indices is local
to the op. Not doing
// so would confuse the structural equality check.
attrs->index_map = std::move(DeepCopyIndexMap(index_map));
+ attrs->axis_separators = std::move(axis_separators);
return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
}
Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
- const TensorStructInfo& old_tensor_sinfo) {
+ const TensorStructInfo& old_tensor_sinfo,
+ const Array<IntImm>& axis_separator) {
Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
arith::Analyzer analyzer;
@@ -177,7 +185,7 @@ class AlterOpImplMutator : public ExprMutator {
<< "Only bijective transformations on input/output buffers are
supported, but found "
"padding predicate "
<< padding_predicate << " on initial range " << initial_ranges;
- return TransformLayout(expr, inverse_index_map);
+ return TransformLayout(expr, inverse_index_map, axis_separator);
}
/*!
@@ -202,16 +210,22 @@ class AlterOpImplMutator : public ExprMutator {
/*!
* \brief Updates call inputs with layout transformed inputs
*/
- Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms) {
+ Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms,
+ const Optional<Array<Array<IntImm>>>& axis_separators) {
if (transforms.empty()) return inputs;
Array<Expr> updated_inputs;
int index = 0;
for (const auto& input : inputs->fields) {
+ Array<IntImm> axis_separator;
+ if (axis_separators.defined()) {
+ Array<Array<IntImm>> axis_separators_value = axis_separators.value();
+ axis_separator = axis_separators_value[index];
+ }
auto transform = transforms[index++];
ICHECK(IsTransformBijective(input, transform))
<< "Non bijective transforms on input and output buffers are not
supported.";
- updated_inputs.push_back(TransformLayout(input, transform));
+ updated_inputs.push_back(TransformLayout(input, transform,
axis_separator));
}
return Tuple(updated_inputs);
}
@@ -254,11 +268,13 @@ class AlterOpImplMutator : public ExprMutator {
}
Expr TransformOutputs(const Expr& expr, const Array<IndexMap>&
buffer_transforms,
- const StructInfo& old_struct_info) {
+ const StructInfo& old_struct_info,
+ const Optional<Array<Array<IntImm>>>& axis_separators)
{
if (buffer_transforms.empty()) return expr;
Array<TensorStructInfo> old_output_sinfo =
GetTensorStructInfoPerOutput(old_struct_info);
+ Array<IntImm> axis_sep;
size_t num_outputs = old_output_sinfo.size();
if (num_outputs == 0) return expr;
@@ -266,7 +282,11 @@ class AlterOpImplMutator : public ExprMutator {
// If there is a single output, return the transformed output.
if (num_outputs == 1) {
IndexMap output_map = buffer_transforms[first_output_index];
- return TransformLayoutInverse(expr, output_map, old_output_sinfo[0]);
+ if (axis_separators.defined()) {
+ Array<Array<IntImm>> axis_separators_value = axis_separators.value();
+ axis_sep = axis_separators_value[first_output_index];
+ }
+ return TransformLayoutInverse(expr, output_map, old_output_sinfo[0],
axis_sep);
}
// In case of more than one output, we would have to get each item of the
output tuple,
@@ -274,9 +294,13 @@ class AlterOpImplMutator : public ExprMutator {
Array<Expr> transformed_outputs;
for (size_t i = 0; i + first_output_index < buffer_transforms.size(); ++i)
{
const auto& output_map = buffer_transforms[i + first_output_index];
+ if (axis_separators.defined()) {
+ Array<Array<IntImm>> axis_separators_value = axis_separators.value();
+ axis_sep = axis_separators_value[i + first_output_index];
+ }
auto output = builder_->Normalize(TupleGetItem(expr,
static_cast<int>(i)));
transformed_outputs.push_back(
- TransformLayoutInverse(output, output_map, old_output_sinfo[i]));
+ TransformLayoutInverse(output, output_map, old_output_sinfo[i],
axis_sep));
}
return Tuple(transformed_outputs);
}
@@ -290,6 +314,8 @@ class AlterOpImplMutator : public ExprMutator {
const Map<String, PrimFunc>& op_impl_map_;
/*! \brief Map from kOperatorName attribute to the layout transforms on i/o
buffers */
const Map<String, Array<IndexMap>>& op_buffer_transforms__;
+ /*! \brief Map from kOperatorName attribute to the axis separators on i/o
buffers */
+ const Map<String, Array<Array<IntImm>>>& op_buffer_axis_separators__;
const Op& call_tir_op_ = Op::Get("relax.call_tir");
const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
@@ -298,10 +324,11 @@ class AlterOpImplMutator : public ExprMutator {
namespace transform {
Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
- const Map<String, Array<IndexMap>>& op_buffer_transforms_) {
+ const Map<String, Array<IndexMap>>& op_buffer_transforms_,
+ const Map<String, Array<Array<IntImm>>>& axis_separators_) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod,
PassContext pc) {
- return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_).Run();
+ return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_,
axis_separators_).Run();
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
diff --git a/tests/python/relax/test_transform_alter_op_impl.py
b/tests/python/relax/test_transform_alter_op_impl.py
index 77e2d4e359..aa35067d58 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -20,13 +20,18 @@ import tvm.testing
from tvm import relax
from tvm.script import tir as T, ir as I, relax as R
+from tvm.tir import IndexMap
kOperatorName = "operator_name"
-def _check(before, expected, operator_name, replacement_primfunc,
layout_changes):
+def _check(
+ before, expected, operator_name, replacement_primfunc, layout_changes,
axis_separator=None
+):
after = relax.transform.AlterOpImpl(
- {operator_name: replacement_primfunc}, {operator_name: layout_changes}
+ {operator_name: replacement_primfunc},
+ {operator_name: layout_changes},
+ {operator_name: axis_separator},
)(before)
after = relax.transform.DeadCodeElimination()(after)
tvm.ir.assert_structural_equal(after, expected)
@@ -225,6 +230,79 @@ def test_multiple_outputs():
)
+def test_multiple_outputs_with_axis_sep():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,),
"float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,),
"float32")):
+ T.func_attr({"operator_name": "relax.some_op"})
+ for ax0 in range(16):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(16, ax0)
+ T.reads(arg0[v_ax0], arg1[v_ax0])
+ T.writes(output0[v_ax0], output1[v_ax0])
+ output0[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+ output1[v_ax0] = arg0[v_ax0] - arg1[v_ax0]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,),
dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,),
dtype="float32")):
+ with R.dataflow():
+ gv = R.call_tir(Before.some_op, (x, y),
out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")])
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1:
T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1:
T.Buffer((4, 4), "float32")):
+ T.func_attr({"operator_name": "relax.some_op"})
+ for ax0, ax1 in T.grid(4, 4):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+ T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
+ output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0,
v_ax1]
+ output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0,
v_ax1]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,),
dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,),
dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x,
index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1])
+ lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y,
index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1])
+ lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv,
lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4),
dtype="float32")])
+ lv3: R.Tensor((4, 4), dtype="float32") = lv2[0]
+ lv4: R.Tensor((16,), dtype="float32") =
R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,),
pad_value=None, axis_separators=[1])
+ lv5: R.Tensor((4, 4), dtype="float32") = lv2[1]
+ lv6: R.Tensor((16,), dtype="float32") =
R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,),
pad_value=None, axis_separators=[1])
+ gv: R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,),
dtype="float32")) = (lv4, lv6)
+ R.output(gv)
+ return gv
+
+ @T.prim_func
+ def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4),
"float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4),
"float32")):
+ for ax0, ax1 in T.grid(4, 4):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+ T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
+ output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+ output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+ # fmt: on
+
+ index_map, axis_sep = IndexMap.from_func_with_separators(
+ lambda i: (i // 4, IndexMap.AXIS_SEPARATOR, i % 4)
+ )
+ _check(
+ Before,
+ Expected,
+ operator_name="relax.some_op",
+ replacement_primfunc=some_op_2d,
+ layout_changes=[index_map, index_map, index_map, index_map],
+ axis_separator=[axis_sep, axis_sep, axis_sep, axis_sep],
+ )
+
+
def test_unsupported_implicit_padding():
@I.ir_module
class InputModule:
@@ -264,7 +342,11 @@ def test_unsupported_implicit_padding():
tvm.TVMError, match="Non bijective transforms on input and output
buffers are not supported"
):
_ = relax.transform.AlterOpImpl(
- {operator_name: relu_pad}, {operator_name: [index_map, index_map]}
+ {operator_name: relu_pad},
+ {
+ operator_name: [index_map, index_map],
+ },
+ {operator_name: None},
)(before)