This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ab7c1a91d8 [Relax] Support `input_axis_separator` to allow 2D to 1D conversion (#17115)
ab7c1a91d8 is described below
commit ab7c1a91d81ae91ad806c2f97c11f6b104ab2ec5
Author: Abhikrant Sharma <[email protected]>
AuthorDate: Mon Jul 1 12:31:07 2024 +0530
[Relax] Support `input_axis_separator` to allow 2D to 1D conversion (#17115)

* [Relax] Support input_axis_separator to allow 2D to 1D conversion

Introduce input_axis_separator in the relax.layout_transform op to allow
conversion of 2D buffers to 1D buffers. The 2D->1D conversion is handled
while lowering the layout_transform operator.

Also introduce support for input_axis_separator in the AlterOpImpl pass.

* Fix LINT errors

* Fix review comments
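
For context, a minimal usage sketch of the new parameter (illustrative only,
not part of the commit; variable names are hypothetical). The forward
transform splits a 1D tensor into a 2D layout with an axis separator; the
inverse transform flattens it back, with input_axis_separators recording that
the incoming buffer carries a separator:

    import tvm
    from tvm import relax

    # 1D -> 2D: split a 16-element tensor into (4, 4) and mark axis 1 as a
    # separator so the result is lowered as a true 2D buffer.
    x = relax.Var("x", relax.TensorStructInfo((16,), "float32"))
    to_2d = relax.op.layout_transform(
        x, index_map=lambda i: (i // 4, i % 4), axis_separators=[1]
    )

    # 2D -> 1D: flatten back; input_axis_separators tells the lowering that
    # the *input* buffer was created with an axis separator after axis 0.
    y = relax.Var("y", relax.TensorStructInfo((4, 4), "float32"))
    to_1d = relax.op.layout_transform(
        y,
        index_map=lambda i, j: (i * 4 + j,),
        axis_separators=[],
        input_axis_separators=[1],
    )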
---
include/tvm/relax/attrs/manipulate.h | 8 ++
include/tvm/relax/transform.h | 4 +-
python/tvm/relax/op/manipulate.py | 8 +-
.../tvm/relax/transform/legalize_ops/manipulate.py | 13 +++-
python/tvm/relax/transform/transform.py | 12 ++-
src/relax/op/tensor/manipulate.cc | 4 +-
src/relax/op/tensor/manipulate.h | 4 +-
src/relax/transform/alter_op_impl.cc | 68 ++++++++++++-----
tests/python/relax/test_transform_alter_op_impl.py | 85 +++++++++++++++++++++-
9 files changed, 179 insertions(+), 27 deletions(-)
diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index b9d0b9f53b..ef4265d73b 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -66,6 +66,12 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
* first input axis that is part of a new flattened axis.
*/
Optional<Array<IntImm>> axis_separators;
+ /*!
+ * axis_separators for input buffers.
+ * Needed to identify whether the input buffer to layout_transform
+ * contains an axis separator.
+ */
+ Optional<Array<IntImm>> input_axis_separators;
TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
@@ -74,6 +80,8 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
"padding. If not specified, the compiler is free to choose any value.");
TVM_ATTR_FIELD(axis_separators)
.describe("The separators between input axes when generating flat output axes");
+ TVM_ATTR_FIELD(input_axis_separators)
+ .describe("The separators between axes to regenerate output");
}
}; // struct LayoutTransformAttrs
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index d8f36e4786..5a7b85ac13 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -559,11 +559,13 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional<String> func_name);
* \param op_buffer_transforms Map from kOperatorName attr to layout transformations on each of the
* PrimFunc i/o buffers.
* \param axis_separators Map from kOperatorName attr to axis_separators of each buffer_transforms
+ * \param input_axis_separators Map from kOperatorName attr to axis_separator for input buffer
* \return The Pass.
*/
TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<tir::IndexMap>>& op_buffer_transforms,
- const Map<String, Array<Array<IntImm>>>& axis_separators);
+ const Map<String, Array<Array<IntImm>>>& axis_separators,
+ const Map<String, Array<Array<IntImm>>>& input_axis_separators);
/*!
* \brief Layout conversion pass.
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 9bd99020e9..da0a09cc7b 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -115,6 +115,7 @@ def layout_transform(
index_map: Union[Callable, IndexMap],
pad_value: Optional[Union[int, float, PrimValue]] = None,
axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
+ input_axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
):
"""Modifies the layout of a tensor.
@@ -158,7 +159,12 @@ def layout_transform(
if axis_separators is None:
axis_separators = []
- return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators)  # type: ignore
+ if input_axis_separators is None:
+ input_axis_separators = []
+
+ return _ffi_api.layout_transform(
+ x, index_map, pad_value, axis_separators, input_axis_separators
+ )
def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index e56240dc0d..4d30b97f64 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -181,6 +181,9 @@ def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
name=name,
)
+ def set_axis_sep(axis_sep: list, sch: tir.Schedule, buffer_type: str):
+ sch.set_axis_separator(primfunc_name, (buffer_type, 0), axis_separators=axis_sep)
+
index_map: tvm.tir.IndexMap = call.attrs.index_map
pad_value = call.attrs.pad_value
if pad_value is not None:
@@ -192,8 +195,10 @@ def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
pad_value = float(0.0)
axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators
+ input_axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.input_axis_separators
+
# Convert to list from array
- axis_separators = list(map(lambda x: x.value, axis_separators))
+ axis_separators = [int(sep) for sep in axis_separators]
primfunc_name = "te_layout_transform"
_, padding_predicate = index_map.non_surjective_inverse(call.args[0].struct_info.shape)
if not isinstance(padding_predicate, tvm.tir.expr.IntImm):
@@ -206,8 +211,10 @@ def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
# Create TIR schedule to apply layout changes with axis separators
sch = tir.Schedule(tir_func)
sch.transform_layout(primfunc_name, ("write", 0), index_map, pad_value)
- if len(axis_separators) != 0:
- sch.set_axis_separator(primfunc_name, ("write", 0), axis_separators=axis_separators)
+ set_axis_sep(axis_separators, sch, "write")
+ if input_axis_separators is not None:
+ input_axis_separators = [int(sep) for sep in input_axis_separators]
+ set_axis_sep(input_axis_separators, sch, "read")
gvar = bb.add_func(sch.mod["main"], primfunc_name)
output_shape = index_map.map_shape(list(call_args[0].struct_info.shape))
output_dtype = call_args[0].struct_info.dtype
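
For reference, a self-contained sketch (not part of the commit) of the
schedule primitives that the set_axis_sep helper above relies on; the PrimFunc
and block name mirror the te_layout_transform naming used by the legalization:

    import tvm
    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def te_layout_transform(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
        for i in range(16):
            with T.block("te_layout_transform"):
                vi = T.axis.remap("S", [i])
                B[vi] = A[vi]

    sch = tir.Schedule(te_layout_transform)
    # Rewrite the output buffer into a (4, 4) layout ...
    sch.transform_layout("te_layout_transform", ("write", 0), lambda i: (i // 4, i % 4))
    # ... then mark axis 1 as a separator so the buffer is lowered as 2D. The
    # new code path applies the same primitive to the ("read", 0) buffer when
    # input_axis_separators is present.
    sch.set_axis_separator("te_layout_transform", ("write", 0), axis_separators=[1])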
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 38e7994eb9..3528b4429e 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -24,6 +24,7 @@ from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Uni
import numpy as np # type: ignore
import tvm.ir
+from tvm.ir.container import Array
from tvm.relax import Expr, Var, StructInfo
from tvm.relax.dpl import DFPattern
from tvm.runtime import NDArray, Object
@@ -1280,6 +1281,7 @@ def AlterOpImpl(
op_impl_map: Dict[str, PrimFunc],
op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
+ op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
):
"""Replace all PrimFunc's which have matching 'operator_name' attribute,
with replacement
PrimFunc that could possibly have different layouts on i/o buffers. The
layout
@@ -1295,6 +1297,8 @@ def AlterOpImpl(
op_kind to layout transformation map for each of the buffers
op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
op_kind to axis_separator for each index_map
+ op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
+ op_kind to axis_separator for input index_map
Returns
-------
@@ -1303,13 +1307,19 @@ def AlterOpImpl(
for operator_name, transform_list in op_buffer_transforms.items():
l = []
for transform in transform_list:
+ # Extract the index_map
if isinstance(transform, Callable):
transform = IndexMap.from_func_with_separators(transform)[0]
+ elif isinstance(transform, (Array, tuple)) and isinstance(transform[0], IndexMap):
+ transform = transform[0]
l.append(transform)
op_buffer_transforms[operator_name] = l
return _ffi_api.AlterOpImpl(
- op_impl_map, op_buffer_transforms, op_buffer_axis_separators
+ op_impl_map,
+ op_buffer_transforms,
+ op_buffer_axis_separators,
+ op_buffer_input_axis_separators,
) # type: ignore
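
As context for the new fourth argument, a hedged sketch of how a caller builds
the separator lists (names such as relax.some_op and some_op_2d are
illustrative; the call shape matches the _check helper in the test below):

    import tvm
    from tvm import relax
    from tvm.tir import IndexMap
    from tvm.script import tir as T

    # from_func_with_separators returns (IndexMap, axis_separators); the pass
    # also accepts the returned pair directly, per the Array/tuple handling above.
    index_map, axis_sep = IndexMap.from_func_with_separators(
        lambda i: (i // 4, IndexMap.AXIS_SEPARATOR, i % 4)
    )

    # Hypothetical 2D replacement kernel for a 1D "relax.some_op".
    @T.prim_func(private=True)
    def some_op_2d(arg0: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32")):
        for ax0, ax1 in T.grid(4, 4):
            with T.block("copy"):
                v0, v1 = T.axis.remap("SS", [ax0, ax1])
                output0[v0, v1] = arg0[v0, v1]

    # One entry per i/o buffer: put the separator on the input's new layout and
    # record it as input_axis_separators for the output's inverse transform.
    alter = relax.transform.AlterOpImpl(
        {"relax.some_op": some_op_2d},
        {"relax.some_op": [index_map, index_map]},
        {"relax.some_op": [axis_sep, []]},
        {"relax.some_op": [[], axis_sep]},
    )

Applying alter to a module whose PrimFunc carries the matching
"operator_name" attribute then rewrites calls as in the test below.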
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index ad2a812c82..07c90756bf 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -472,11 +472,13 @@ TVM_REGISTER_OP("relax.flatten")
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
- Optional<Array<IntImm>> axis_separators) {
+ Optional<Array<IntImm>> axis_separators,
+ Optional<Array<IntImm>> input_axis_separators) {
ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
attrs->index_map = std::move(index_map);
attrs->pad_value = std::move(pad_value);
attrs->axis_separators = std::move(axis_separators);
+ attrs->input_axis_separators = std::move(input_axis_separators);
static const Op& op = Op::Get("relax.layout_transform");
return Call(op, {std::move(x)}, Attrs{attrs}, {});
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index b19e3b8507..32aa107768 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -67,10 +67,12 @@ Expr flatten(Expr x);
* not specified, any value can be used.
* \param axis_separators Array of values to differentiate between input axes
* when generating flattened output axes.
+ * \param input_axis_separators Array of values for the input buffer.
* \return The transformed result.
*/
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
- Optional<Array<IntImm>> axis_separators);
+ Optional<Array<IntImm>> axis_separators,
+ Optional<Array<IntImm>> input_axis_separators = NullOpt);
/*!
* \brief Permutes the dimensions of an array.
diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index 2cb226d56e..aaf643f801 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -81,12 +81,14 @@ class AlterOpImplMutator : public ExprMutator {
public:
AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<IndexMap>>& op_buffer_transforms_,
- const Map<String, Array<Array<IntImm>>>& axis_separators_)
+ const Map<String, Array<Array<IntImm>>>& axis_separators_,
+ const Map<String, Array<Array<IntImm>>>& input_axis_separators_)
: ExprMutator(mod),
mod_(mod),
op_impl_map_(op_impl_map),
op_buffer_transforms__(op_buffer_transforms_),
- op_buffer_axis_separators__(axis_separators_) {}
+ op_buffer_axis_separators__(axis_separators_),
+ op_buffer_input_axis_separators__(input_axis_separators_) {}
IRModule Run() {
for (const auto& gv : mod_->GetGlobalVars()) {
@@ -127,9 +129,12 @@ class AlterOpImplMutator : public ExprMutator {
Array<IndexMap> buffer_transforms;
Optional<Array<Array<IntImm>>> axis_separators;
+ Optional<Array<Array<IntImm>>> input_axis_separators;
if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind];
if (op_buffer_axis_separators__.count(op_kind))
axis_separators = op_buffer_axis_separators__[op_kind];
+ if (op_buffer_input_axis_separators__.count(op_kind))
+ input_axis_separators = op_buffer_input_axis_separators__[op_kind];
ICHECK(buffer_transforms.empty() || buffer_transforms.size() == replacement_func->params.size())
<< "Either the i/o buffers do not require any transformations or transformations for each "
@@ -140,7 +145,8 @@ class AlterOpImplMutator : public ExprMutator {
GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind);
auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
- Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators);
+ Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators,
+ input_axis_separators);
ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1";
StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms);
@@ -148,7 +154,8 @@ class AlterOpImplMutator : public ExprMutator {
Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo}));
// Now transform each of the outputs to previous layout.
- return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators);
+ return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators,
+ input_axis_separators);
}
Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) {
@@ -175,7 +182,8 @@ class AlterOpImplMutator : public ExprMutator {
}
Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
- const Array<IntImm>& axis_separators) {
+ const Array<IntImm>& axis_separators,
+ const Array<IntImm>& input_axis_separators) {
if (IsScalarConstant(expr) || index_map.get() == nullptr) {
return expr;
}
@@ -185,6 +193,7 @@ class AlterOpImplMutator : public ExprMutator {
// so would confuse the structural equality check.
attrs->index_map = std::move(DeepCopyIndexMap(index_map));
attrs->axis_separators = std::move(axis_separators);
+ attrs->input_axis_separators = std::move(input_axis_separators);
return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
}
@@ -232,7 +241,8 @@ class AlterOpImplMutator : public ExprMutator {
Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
const TensorStructInfo& old_tensor_sinfo,
- const Array<IntImm>& axis_separator) {
+ const Array<IntImm>& axis_separator,
+ const Array<IntImm>& input_axis_separator) {
if (IsScalarConstant(expr) || index_map.get() == nullptr) {
return expr;
}
@@ -243,10 +253,10 @@ class AlterOpImplMutator : public ExprMutator {
index_map.NonSurjectiveInverse(initial_ranges, &analyzer);
if (tir::is_zero(padding_predicate)) {
- return TransformLayout(expr, inverse_index_map, axis_separator);
+ return TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator);
} else {
- auto padded_expr =
- builder_->Normalize(TransformLayout(expr, inverse_index_map, axis_separator));
+ auto padded_expr = builder_->Normalize(
+ TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator));
const auto& tensor_sinfo = Downcast<TensorStructInfo>(padded_expr->struct_info_);
GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype);
@@ -277,19 +287,26 @@ class AlterOpImplMutator : public ExprMutator {
* \brief Updates call inputs with layout transformed inputs
*/
Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms,
- const Optional<Array<Array<IntImm>>>& axis_separators) {
+ const Optional<Array<Array<IntImm>>>& axis_separators,
+ const Optional<Array<Array<IntImm>>>& input_axis_separators) {
if (transforms.empty()) return inputs;
Array<Expr> updated_inputs;
int index = 0;
for (const auto& input : inputs->fields) {
Array<IntImm> axis_separator;
+ Array<IntImm> input_axis_separator;
if (axis_separators.defined()) {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_separator = axis_separators_value[index];
}
+ if (input_axis_separators.defined()) {
+ Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
+ input_axis_separator = input_axis_separators_value[index];
+ }
auto transform = transforms[index++];
- updated_inputs.push_back(TransformLayout(input, transform, axis_separator));
+ updated_inputs.push_back(
+ TransformLayout(input, transform, axis_separator, input_axis_separator));
}
return Tuple(updated_inputs);
}
@@ -338,12 +355,13 @@ class AlterOpImplMutator : public ExprMutator {
Expr TransformOutputs(const Expr& expr, const Array<IndexMap>& buffer_transforms,
const StructInfo& old_struct_info,
- const Optional<Array<Array<IntImm>>>& axis_separators) {
+ const Optional<Array<Array<IntImm>>>& axis_separators,
+ const Optional<Array<Array<IntImm>>>& input_axis_separators) {
if (buffer_transforms.empty()) return expr;
Array<TensorStructInfo> old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info);
- Array<IntImm> axis_sep;
+ Array<IntImm> axis_sep, input_axis_sep;
size_t num_outputs = old_output_sinfo.size();
if (num_outputs == 0) return expr;
@@ -355,7 +373,12 @@ class AlterOpImplMutator : public ExprMutator {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_sep = axis_separators_value[first_output_index];
}
- return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep);
+ if (input_axis_separators.defined()) {
+ Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
+ input_axis_sep = input_axis_separators_value[first_output_index];
+ }
+ return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep,
+ input_axis_sep);
}
// In case of more than one output, we would have to get each item of the output tuple,
@@ -367,9 +390,13 @@ class AlterOpImplMutator : public ExprMutator {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_sep = axis_separators_value[i + first_output_index];
}
+ if (input_axis_separators.defined()) {
+ Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
+ input_axis_sep = input_axis_separators_value[i + first_output_index];
+ }
auto output = builder_->Normalize(TupleGetItem(expr, static_cast<int>(i)));
- transformed_outputs.push_back(
- TransformLayoutInverse(output, output_map, old_output_sinfo[i], axis_sep));
+ transformed_outputs.push_back(TransformLayoutInverse(output, output_map, old_output_sinfo[i],
+ axis_sep, input_axis_sep));
}
return Tuple(transformed_outputs);
}
@@ -387,6 +414,8 @@ class AlterOpImplMutator : public ExprMutator {
const Map<String, Array<IndexMap>>& op_buffer_transforms__;
/*! \brief Map from kOperatorName attribute to the axis separators on i/o buffers */
const Map<String, Array<Array<IntImm>>>& op_buffer_axis_separators__;
+ /*! \brief Map from kOperatorName attribute to the input axis separators */
+ const Map<String, Array<Array<IntImm>>>& op_buffer_input_axis_separators__;
const Op& call_tir_op_ = Op::Get("relax.call_tir");
const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
@@ -396,10 +425,13 @@ namespace transform {
Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<IndexMap>>& op_buffer_transforms_,
- const Map<String, Array<Array<IntImm>>>& axis_separators_) {
+ const Map<String, Array<Array<IntImm>>>& axis_separators_,
+ const Map<String, Array<Array<IntImm>>>& input_axis_separators_) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext pc) {
- return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_).Run();
+ return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_,
+ input_axis_separators_)
+ .Run();
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py
index f2bad31f21..f1824eba6b 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -26,12 +26,19 @@ kOperatorName = "operator_name"
def _check(
- before, expected, operator_name, replacement_primfunc, layout_changes, axis_separator=None
+ before,
+ expected,
+ operator_name,
+ replacement_primfunc,
+ layout_changes,
+ axis_separator=None,
+ input_axis_separator=None,
):
after = relax.transform.AlterOpImpl(
{operator_name: replacement_primfunc},
{operator_name: layout_changes},
{operator_name: axis_separator},
+ {operator_name: input_axis_separator},
)(before)
after = relax.transform.DeadCodeElimination()(after)
tvm.ir.assert_structural_equal(after, expected)
@@ -572,5 +579,81 @@ def test_reshape():
)
+def test_input_axis_separator():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,), "float32")):
+ T.func_attr({"operator_name": "relax.some_op"})
+ for ax0 in range(16):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(16, ax0)
+ T.reads(arg0[v_ax0], arg1[v_ax0])
+ T.writes(output0[v_ax0], output1[v_ax0])
+ output0[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+ output1[v_ax0] = arg0[v_ax0] - arg1[v_ax0]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")):
+ with R.dataflow():
+ gv = R.call_tir(Before.some_op, (x, y), out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")])
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func(private=True)
+ def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")):
+ T.func_attr({"operator_name": "relax.some_op"})
+ for ax0, ax1 in T.grid(4, 4):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+ output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1])
+ lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1])
+ lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")])
+ lv3: R.Tensor((4, 4), dtype="float32") = lv2[0]
+ lv4: R.Tensor((16,), dtype="float32") = R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None, axis_separators=[], input_axis_separators=[1])
+ lv5: R.Tensor((4, 4), dtype="float32") = lv2[1]
+ lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None, axis_separators=[], input_axis_separators=[1])
+ gv: R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")) = (lv4, lv6)
+ R.output(gv)
+ return gv
+
+ @T.prim_func(private=True)
+ def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")):
+ for ax0, ax1 in T.grid(4, 4):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+ output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+ # fmt: on
+
+ index_map_axis_sep = IndexMap.from_func_with_separators(
+ lambda i: (i // 4, IndexMap.AXIS_SEPARATOR, i % 4)
+ )
+
+ _check(
+ Before,
+ Expected,
+ operator_name="relax.some_op",
+ replacement_primfunc=some_op_2d,
+ layout_changes=[
+ index_map_axis_sep,
+ index_map_axis_sep,
+ index_map_axis_sep,
+ index_map_axis_sep,
+ ],
+ axis_separator=[index_map_axis_sep[1], index_map_axis_sep[1], [], []],
+ input_axis_separator=[[], [], index_map_axis_sep[1], index_map_axis_sep[1]],
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()