This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 3c0639eee9 [Unity][Pass] Add a pass to alter the TIR implementation of
an operator (#14215)
3c0639eee9 is described below
commit 3c0639eee9a332bd405864321a7874185f76dcd7
Author: Prakalp Srivastava <[email protected]>
AuthorDate: Wed Mar 15 15:02:38 2023 -0400
[Unity][Pass] Add a pass to alter the TIR implementation of an operator
(#14215)
* [Unity][Pass] Add a pass to alter the TIR
implementation of an operator (identified
by operator_kind attribute on PrimFunc).
It also inserts layout changes to i/o
buffers at Relax level.
* deep copy index map to avoid structural_equality fail
* do not mark layouts as frozen
* address comments
* fix call_tir global symbol in tests
---
include/tvm/relax/struct_info.h | 7 +
include/tvm/relax/transform.h | 17 +-
python/tvm/relax/transform/transform.py | 32 ++
src/relax/transform/alter_op_impl.cc | 310 +++++++++++++++++++
tests/python/relax/test_transform_alter_op_impl.py | 342 +++++++++++++++++++++
5 files changed, 707 insertions(+), 1 deletion(-)
diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index b9aebc5494..0c1973bcea 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -170,6 +170,13 @@ class TensorStructInfoNode : public StructInfoNode {
/*! \return Whether the struct info contains unknown dtype. */
bool IsUnknownDtype() const { return dtype.is_void(); }
+ /*! \return Shape if it is known. */
+ Optional<Array<PrimExpr>> GetShape() const {
+ if (!shape.defined()) return {};
+ ShapeStructInfo shape_sinfo =
Downcast<ShapeStructInfo>(this->shape.value()->struct_info_);
+ return shape_sinfo->values;
+ }
+
void VisitAttrs(AttrVisitor* v) {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 446b75da9f..3ff863dd09 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -27,7 +27,8 @@
#include <tvm/ir/transform.h>
#include <tvm/relax/dataflow_pattern.h>
#include <tvm/relax/expr.h>
-
+#include <tvm/tir/function.h>
+#include <tvm/tir/index_map.h>
namespace tvm {
namespace relax {
namespace transform {
@@ -279,6 +280,20 @@ TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String,
ObjectRef>>> target_opt
* \return The Pass.
*/
TVM_DLL Pass SimplifyNormInference();
+/*!
+ * \brief Returns a pass which replaces PrimFuncs which have matching
kOperatorName attribute in \p
+ * op_impl_map, with replacement PrimFunc that could possibly have different
layouts on i/o
+ * buffers. The layout transformations on i/o buffers are present in the \p
op_buffer_transforms. The
+ * pass inserts the layout transformations in the call sites of PrimFuncs
being replaced to
+ * transform i/o buffers into expected layout.
+ *
+ * \param op_impl_map Map from kOperatorName attr (e.g., relax.conv2d) to
replacement PrimFunc
+ * \param op_buffer_transforms Map from kOperatorName attr to layout
transformations on each of the
+ * PrimFunc i/o buffers.
+ * \return The Pass.
+ */
+TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
+ const Map<String, Array<tir::IndexMap>>&
op_buffer_transforms);
} // namespace transform
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 97c8772b3b..c59104ca58 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -20,6 +20,7 @@ import functools
import inspect
import types
from typing import Callable, Dict, Union, Optional, List, Tuple
+from tvm.tir import PrimFunc, IndexMap
import numpy as np # type: ignore
import tvm.ir
from tvm.runtime import NDArray
@@ -542,6 +543,37 @@ def SimplifyNormInference() -> tvm.ir.transform.Pass:
return _ffi_api.SimplifyNormInference() # type: ignore
+def AlterOpImpl(
+ op_impl_map: Dict[str, PrimFunc],
+ op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
+):
+    """Replace all PrimFuncs which have matching 'operator_name' attribute,
with replacement
+ PrimFunc that could possibly have different layouts on i/o buffers. The
layout
+    transformations on i/o buffers are present in the op_buffer_transforms map.
Inserts the layout
+ transformations in the call sites of PrimFuncs being replaced to transform
i/o
+ tensors into expected layout by new PrimFunc.
+
+ Parameters
+ ----------
+ op_impl_map: Dict[str, PrimFunc]
+ op_kind to PrimFunc map
+    op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]]
+ op_kind to layout transformation map for each of the buffers
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ for operator_name, transform_list in op_buffer_transforms.items():
+ l = []
+ for transform in transform_list:
+ if isinstance(transform, Callable):
+ transform = IndexMap.from_func(transform)
+ l.append(transform)
+ op_buffer_transforms[operator_name] = l
+
+ return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms) # type:
ignore
+
+
def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""
diff --git a/src/relax/transform/alter_op_impl.cc
b/src/relax/transform/alter_op_impl.cc
new file mode 100644
index 0000000000..6a740b4f55
--- /dev/null
+++ b/src/relax/transform/alter_op_impl.cc
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/alter_op_impl.cc
+ * \brief Change the layout of PrimFunc in the graph. It uses the
kOperatorName attribute to
+ * identify PrimFuncs to be replaced. Marks the new PrimFuncs with
kFrozenLayout attribute set to
+ * true.
+ */
+#include <tvm/ir/attrs.h>
+#include <tvm/node/serialization.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/manipulate.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/transform.h>
+namespace tvm {
+namespace relax {
+
+using namespace tir;
+static constexpr const char* kOperatorName = "operator_name";
+
+/*! \brief Construct ranges from shape dimensions */
+static Array<Range> ConstructRangeFromShape(const Array<PrimExpr>& shape) {
+ return shape.Map([](const PrimExpr& dim) { return
Range(tir::make_zero(dim.dtype()), dim); });
+}
+
+static Array<PrimExpr> GetShapeFromTensorStructInfo(const TensorStructInfo&
tensor_sinfo) {
+ auto shape = tensor_sinfo->GetShape();
+ ICHECK(shape.defined());
+ return shape.value();
+}
+
+static Array<PrimExpr> GetShapeFromTensor(const Expr& expr) {
+ const auto& tensor_sinfo = Downcast<TensorStructInfo>(expr->struct_info_);
+ return GetShapeFromTensorStructInfo(tensor_sinfo);
+}
+
+static IndexMap DeepCopyIndexMap(const IndexMap& index_map) {
+ return Downcast<IndexMap>(LoadJSON(SaveJSON(index_map)));
+}
+
+/*! \brief Checks if the \p transform is bijective on the shape of \p expr */
+bool IsTransformBijective(const Expr& expr, const IndexMap& transform) {
+ Array<PrimExpr> input_shape = GetShapeFromTensor(expr);
+ Array<Range> initial_ranges = ConstructRangeFromShape(input_shape);
+ auto [inverse, padding_predicate] =
transform.NonSurjectiveInverse(initial_ranges);
+ (void)inverse; // to avoid unused variable warning;
+ arith::Analyzer analyzer;
+ if (!analyzer.CanProve(!padding_predicate)) return false;
+ return true;
+}
+
+/*!
+ * \brief Replace each call_tir to PrimFunc which matches the kOperatorName
attribute with the
+ * provided replacement PrimFunc and mark it with kFrozenLayout attribute.
Insert layout
+ * transformations on i/o buffers as necessary for correctness.
+ */
+class AlterOpImplMutator : public ExprMutator {
+ public:
+ AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>&
op_impl_map,
+ const Map<String, Array<IndexMap>>& op_buffer_transforms_)
+ : ExprMutator(mod),
+ mod_(mod),
+ op_impl_map_(op_impl_map),
+ op_buffer_transforms__(op_buffer_transforms_) {}
+
+ IRModule Run() {
+ for (const auto& [gv, func] : mod_->functions) {
+ if (func->IsInstance<relax::FunctionNode>()) {
+ relax::Function update_func = Downcast<Function>(VisitExpr(func));
+ builder_->UpdateFunction(gv, update_func);
+ }
+ }
+ return builder_->GetContextIRModule();
+ }
+
+ private:
+ Expr VisitExpr_(const CallNode* op) final {
+ auto call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+ // TODO(@tvm-team): When we differentiate the call for tir function and
packed function,
+ // this logic should be changed accordingly.
+ if (!call->op.same_as(call_tir_op_)) return call;
+
+ // Do not do anything for external function
+ if (call->args[0].as<ExternFuncNode>()) return call;
+
+ // Get operator name from callee
+ ICHECK(call->args[0]->IsInstance<GlobalVarNode>());
+ const tir::PrimFunc& old_func =
+
Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
+ Optional<String> maybe_op_kind =
old_func->attrs.GetAttr<String>(kOperatorName);
+
+ // If the callee does not have kOperatorName attribute or no replacement
is requested for
+ // it, nothing to do here.
+ if (!maybe_op_kind.defined() || op_impl_map_.count(maybe_op_kind.value())
== 0) return call;
+ auto op_kind = maybe_op_kind.value();
+
+ const auto& replacement_func = op_impl_map_[op_kind];
+
+ Array<IndexMap> buffer_transforms;
+ if (op_buffer_transforms__.count(op_kind)) buffer_transforms =
op_buffer_transforms__[op_kind];
+
+ ICHECK(buffer_transforms.empty() || buffer_transforms.size() ==
replacement_func->params.size())
+ << "Either the i/o buffers do not require any transformations or
transformations for each "
+ "buffer is provided.";
+ ICHECK_EQ(old_func->params.size(), replacement_func->params.size())
+ << "Number of parameters of old and replacement PrimFunc must match";
+
+ GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func,
op_kind);
+
+ auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
+ Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple,
buffer_transforms);
+
+ ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is
expected to be 1";
+ StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0],
buffer_transforms);
+ auto updated_call = builder_->Normalize(
+ Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs,
{updated_ret_sinfo}));
+
+ // Now transform each of the outputs to previous layout.
+ return TransformOutputs(updated_call, buffer_transforms,
call->sinfo_args[0]);
+ }
+
+ Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo&
output_sinfo) {
+ if (const auto* tensor_sinfo = output_sinfo.as<TensorStructInfoNode>())
+ return {GetRef<TensorStructInfo>(tensor_sinfo)};
+ const auto* tuple_sinfo = output_sinfo.as<TupleStructInfoNode>();
+ ICHECK(tuple_sinfo);
+
+ Array<TensorStructInfo> arr_tensor_sinfo;
+ arr_tensor_sinfo.reserve(tuple_sinfo->fields.size());
+ for (const auto& sinfo : tuple_sinfo->fields) {
+ const auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>();
+ ICHECK(tensor_sinfo) << "Nested tuples in output of call_tir is not
supported yet";
+ arr_tensor_sinfo.push_back(GetRef<TensorStructInfo>(tensor_sinfo));
+ }
+ return arr_tensor_sinfo;
+ }
+
+ Expr TransformLayout(const Expr& expr, const IndexMap& index_map) {
+ ObjectPtr<LayoutTransformAttrs> attrs =
make_object<LayoutTransformAttrs>();
+    // We want to avoid two layout_transform ops sharing the same index map
even if they are
+ // identical. The scope of vars used in index map initial indices is local
to the op. Not doing
+ // so would confuse the structural equality check.
+ attrs->index_map = std::move(DeepCopyIndexMap(index_map));
+ return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
+ }
+
+ Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
+ const TensorStructInfo& old_tensor_sinfo) {
+ Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
+ Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
+ auto [inverse_index_map, padding_predicate] =
index_map.NonSurjectiveInverse(initial_ranges);
+ ICHECK(tir::is_zero(padding_predicate))
+ << "Only bijective transformations on input/output buffers are
supported, but found "
+ "padding predicate "
+ << padding_predicate << " on initial range " << initial_ranges;
+ return TransformLayout(expr, inverse_index_map);
+ }
+
+ /*!
+ * \brief Adds the \p replacement_func to the module if it has not already
been added before.
+ * \returns The global var associated with the PrimFunc.
+ */
+ GlobalVar GetOrCreateGlobalVarForFunc(const PrimFunc& replacement_func,
const String& op_kind) {
+ if (cache_.count(replacement_func) != 0) {
+ return cache_[replacement_func];
+ }
+ // Retain the operator name attribute on the replacement PrimFunc. This
can help any future
+ // passes that use kOperatorName attribute to identify operator
represented by a PrimFunc.
+ PrimFunc replacement_func_with_frozen_layout =
+ WithAttr(replacement_func, kOperatorName, op_kind);
+
+ GlobalVar gv_replacement =
+ builder_->AddFunction(replacement_func_with_frozen_layout, op_kind +
"_replacement");
+ cache_.Set(replacement_func, gv_replacement);
+ return gv_replacement;
+ }
+
+ /*!
+ * \brief Updates call inputs with layout transformed inputs
+ */
+ Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms) {
+ if (transforms.empty()) return inputs;
+
+ Array<Expr> updated_inputs;
+ int index = 0;
+ for (const auto& input : inputs->fields) {
+ auto transform = transforms[index++];
+ ICHECK(IsTransformBijective(input, transform))
+ << "Non bijective transforms on input and output buffers are not
supported.";
+ updated_inputs.push_back(TransformLayout(input, transform));
+ }
+ return Tuple(updated_inputs);
+ }
+
+ /*! \brief Updates output struct info */
+ StructInfo UpdateStructInfo(const StructInfo& out_sinfo,
+ const Array<IndexMap>& buffer_transforms) {
+ if (buffer_transforms.empty()) return out_sinfo;
+
+ if (out_sinfo->IsInstance<TensorStructInfoNode>())
+ return UpdateStructInfo(Downcast<TensorStructInfo>(out_sinfo),
+ buffer_transforms[buffer_transforms.size() - 1]);
+
+ ICHECK(out_sinfo->IsInstance<TupleStructInfoNode>())
+ << "Expect output struct info of call_tir to be either TupleStructInfo
or "
+ "TensorStructInfo, but got "
+ << out_sinfo;
+
+ const auto& tuple_sinfo = Downcast<TupleStructInfo>(out_sinfo);
+ Array<StructInfo> sinfo_fields;
+ size_t first_output_index = buffer_transforms.size() -
tuple_sinfo->fields.size();
+ size_t i = 0;
+ for (const auto& si : tuple_sinfo->fields) {
+ ICHECK(si->IsInstance<TensorStructInfoNode>())
+ << "Fields of TupleStructInfo must be TensorStructInfo for call_tir "
+ "output structinfo, but got "
+ << si;
+ sinfo_fields.push_back(UpdateStructInfo(Downcast<TensorStructInfo>(si),
+
buffer_transforms[first_output_index + i++]));
+ }
+ return TupleStructInfo(sinfo_fields);
+ }
+
+ /*! \brief Returns the TensorStructInfo after applying the \p transform on
its shape */
+ StructInfo UpdateStructInfo(const TensorStructInfo& tensor_sinfo, const
IndexMap& transform) {
+ auto shape = GetShapeFromTensorStructInfo(tensor_sinfo);
+ auto new_shape = transform->MapShape(shape);
+ return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype);
+ }
+
+ Expr TransformOutputs(const Expr& expr, const Array<IndexMap>&
buffer_transforms,
+ const StructInfo& old_struct_info) {
+ if (buffer_transforms.empty()) return expr;
+
+ Array<TensorStructInfo> old_output_sinfo =
GetTensorStructInfoPerOutput(old_struct_info);
+
+ size_t num_outputs = old_output_sinfo.size();
+ if (num_outputs == 0) return expr;
+
+ size_t first_output_index = buffer_transforms.size() - num_outputs;
+ // If there is a single output, return the transformed output.
+ if (num_outputs == 1) {
+ IndexMap output_map = buffer_transforms[first_output_index];
+ return TransformLayoutInverse(expr, output_map, old_output_sinfo[0]);
+ }
+
+ // In case of more than one output, we would have to get each item of the
output tuple,
+ // transform it and return a tuple of all transformed outputs.
+ Array<Expr> transformed_outputs;
+ for (size_t i = 0; i + first_output_index < buffer_transforms.size(); ++i)
{
+ const auto& output_map = buffer_transforms[i + first_output_index];
+ auto output = builder_->Normalize(TupleGetItem(expr,
static_cast<int>(i)));
+ transformed_outputs.push_back(
+ TransformLayoutInverse(output, output_map, old_output_sinfo[i]));
+ }
+ return Tuple(transformed_outputs);
+ }
+
+ private:
+ /*! \brief Cache to keep track of the GlobalVar associated with the new
PrimFunc added */
+ Map<PrimFunc, GlobalVar> cache_;
+ /*! \brief Input IRModule */
+ const IRModule& mod_;
+ /*! \brief Map from kOperatorName attribute to the replacement PrimFunc */
+ const Map<String, PrimFunc>& op_impl_map_;
+ /*! \brief Map from kOperatorName attribute to the layout transforms on i/o
buffers */
+ const Map<String, Array<IndexMap>>& op_buffer_transforms__;
+
+ const Op& call_tir_op_ = Op::Get("relax.call_tir");
+ const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
+};
+
+namespace transform {
+
+Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
+ const Map<String, Array<IndexMap>>& op_buffer_transforms_) {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod,
+
PassContext pc) {
+ return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_).Run();
+ };
+ return CreateModulePass(/*pass_function=*/pass_func, //
+ /*opt_level=*/0, //
+ /*pass_name=*/"AlterOpImpl", //
+ /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AlterOpImpl").set_body_typed(AlterOpImpl);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_transform_alter_op_impl.py
b/tests/python/relax/test_transform_alter_op_impl.py
new file mode 100644
index 0000000000..e8fa29a074
--- /dev/null
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -0,0 +1,342 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm.testing
+
+from tvm import relax
+from tvm.script import tir as T, ir as I, relax as R
+
+kOperatorName = "operator_name"
+
+
+def _check(before, expected, operator_name, replacement_primfunc,
layout_changes):
+ after = relax.transform.AlterOpImpl(
+ {operator_name: replacement_primfunc}, {operator_name: layout_changes}
+ )(before)
+ after = relax.transform.RemoveUnusedFunctions()(after)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_single_output():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def add(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,),
"float32"), output: T.Buffer((16,), "float32")):
+ T.func_attr({"operator_name": "relax.add"})
+ for ax0 in range(16):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(16, ax0)
+ T.reads(arg0[v_ax0], arg1[v_ax0])
+ T.writes(output[v_ax0])
+ output[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,),
dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+ with R.dataflow():
+ lv = R.call_tir(Before.add, (x, y), out_sinfo=R.Tensor((16,),
dtype="float32"))
+ gv: R.Tensor((16,), dtype="float32") = lv
+ R.output(gv)
+ return gv
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1:
T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
+ T.func_attr({"operator_name": "relax.add"})
+ for ax0, ax1 in T.grid(4, 4):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+ T.writes(output[v_ax0, v_ax1])
+ output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0,
v_ax1]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,),
dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x,
index_map=lambda i: (i // 4, i % 4), pad_value=None)
+ lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y,
index_map=lambda i: (i // 4, i % 4), pad_value=None)
+ lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1),
out_sinfo=R.Tensor((4, 4), dtype="float32"))
+ lv_1: R.Tensor((16,), dtype="float32") =
R.layout_transform(lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,),
pad_value=None)
+ gv: R.Tensor((16,), dtype="float32") = lv_1
+ R.output(gv)
+ return gv
+
+ @T.prim_func
+ def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4),
"float32"), output: T.Buffer((4, 4), "float32")):
+ for ax0, ax1 in T.grid(4, 4):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+ T.writes(output[v_ax0, v_ax1])
+ output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+ # fmt: on
+ index_map = lambda i: (i // 4, i % 4)
+ _check(
+ Before,
+ Expected,
+ operator_name="relax.add",
+ replacement_primfunc=add_2d,
+ layout_changes=[index_map, index_map, index_map],
+ )
+
+
+def test_empty_layout_changes():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def mul_by_2(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,),
"float32")):
+ T.func_attr({"operator_name": "relax.mul_by_2"})
+ for ax0 in range(16):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(16, ax0)
+ T.reads(arg0[v_ax0])
+ T.writes(output[v_ax0])
+ output[v_ax0] = arg0[v_ax0] * T.float32(2)
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,),
dtype="float32"):
+ with R.dataflow():
+ lv = R.call_tir(Before.mul_by_2, (x,),
out_sinfo=R.Tensor((16,), dtype="float32"))
+ gv: R.Tensor((16,), dtype="float32") = lv
+ R.output(gv)
+ return gv
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def relax_mul_by_2_replacement(arg0: T.Buffer((16,), "float32"),
output: T.Buffer((16,), "float32")):
+ T.func_attr({"operator_name": "relax.mul_by_2"})
+ for ax0 in range(16):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(16, ax0)
+ T.reads(arg0[v_ax0])
+ T.writes(output[v_ax0])
+ output[v_ax0] = arg0[v_ax0] + arg0[v_ax0]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,),
dtype="float32"):
+ with R.dataflow():
+ lv = R.call_tir(Expected.relax_mul_by_2_replacement, (x,),
out_sinfo=R.Tensor((16,), dtype="float32"))
+ gv: R.Tensor((16,), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ @T.prim_func
+ def add_x_x(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,),
"float32")):
+ T.func_attr({"operator_name": "relax.mul_by_2"})
+ for ax0 in range(16):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(16, ax0)
+ T.reads(arg0[v_ax0])
+ T.writes(output[v_ax0])
+ output[v_ax0] = arg0[v_ax0] + arg0[v_ax0]
+ # fmt: on
+ _check(
+ Before,
+ Expected,
+ operator_name="relax.mul_by_2",
+ replacement_primfunc=add_x_x,
+ layout_changes=[],
+ )
+
+
+def test_multiple_outputs():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,),
"float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,),
"float32")):
+ T.func_attr({"operator_name": "relax.some_op"})
+ for ax0 in range(16):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(16, ax0)
+ T.reads(arg0[v_ax0], arg1[v_ax0])
+ T.writes(output0[v_ax0], output1[v_ax0])
+ output0[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+ output1[v_ax0] = arg0[v_ax0] - arg1[v_ax0]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,),
dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,),
dtype="float32")):
+ with R.dataflow():
+ gv = R.call_tir(Before.some_op, (x, y),
out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")])
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1:
T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1:
T.Buffer((4, 4), "float32")):
+ T.func_attr({"operator_name": "relax.some_op"})
+ for ax0, ax1 in T.grid(4, 4):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+ T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
+ output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0,
v_ax1]
+ output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0,
v_ax1]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,),
dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,),
dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x,
index_map=lambda i: (i // 4, i % 4), pad_value=None)
+ lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y,
index_map=lambda i: (i // 4, i % 4), pad_value=None)
+ lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv,
lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4),
dtype="float32")])
+ lv3: R.Tensor((4, 4), dtype="float32") = lv2[0]
+ lv4: R.Tensor((16,), dtype="float32") =
R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,),
pad_value=None)
+ lv5: R.Tensor((4, 4), dtype="float32") = lv2[1]
+ lv6: R.Tensor((16,), dtype="float32") =
R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,),
pad_value=None)
+ gv: R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,),
dtype="float32")) = (lv4, lv6)
+ R.output(gv)
+ return gv
+
+ @T.prim_func
+ def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4),
"float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4),
"float32")):
+ for ax0, ax1 in T.grid(4, 4):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+ T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
+ output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+ output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+ # fmt: on
+
+ index_map = lambda i: (i // 4, i % 4)
+ _check(
+ Before,
+ Expected,
+ operator_name="relax.some_op",
+ replacement_primfunc=some_op_2d,
+ layout_changes=[index_map, index_map, index_map, index_map],
+ )
+
+
+def test_unsupported_implicit_padding():
+ @I.ir_module
+ class InputModule:
+ @R.function
+ def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,),
dtype="float32"):
+ with R.dataflow():
+ lv = R.call_tir(InputModule.relu, (x,),
out_sinfo=R.Tensor((14,), dtype="float32"))
+ gv: R.Tensor((14,), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ @T.prim_func
+ def relu(arg0: T.Buffer((14,), "float32"), output: T.Buffer((14,),
"float32")):
+ T.func_attr({"operator_name": "relax.relu"})
+ for ax0 in T.grid(14):
+ with T.block("T_add"):
+ v_ax0 = T.axis.remap("S", [ax0])
+ T.reads(arg0[v_ax0])
+ T.writes(output[v_ax0])
+ output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))
+
+ before = InputModule
+
+ @T.prim_func
+ def relu_pad(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,),
"float32")):
+ for ax0 in T.grid(16):
+ with T.block("T_add"):
+ v_ax0 = T.axis.remap("S", [ax0])
+ T.reads(arg0[v_ax0])
+ T.writes(output[v_ax0])
+ output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))
+
+ # introduces implicit padding for shape (14,)
+ index_map = lambda i: (i % 16)
+ operator_name = "relax.relu"
+ with pytest.raises(
+ tvm.TVMError, match="Non bijective transforms on input and output
buffers are not supported"
+ ):
+ _ = relax.transform.AlterOpImpl(
+ {operator_name: relu_pad}, {operator_name: [index_map, index_map]}
+ )(before)
+
+
+def test_multiple_call_sites():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def add(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,),
"float32"), output: T.Buffer((16,), "float32")):
+ T.func_attr({"operator_name": "relax.add"})
+ for ax0 in range(16):
+ with T.block("T_add"):
+ v_ax0 = T.axis.spatial(16, ax0)
+ T.reads(arg0[v_ax0], arg1[v_ax0])
+ T.writes(output[v_ax0])
+ output[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,),
dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+ with R.dataflow():
+ lv0 = R.call_tir(Before.add, (x, y), out_sinfo=R.Tensor((16,),
dtype="float32"))
+ lv1 = R.nn.relu(lv0)
+ lv2 = R.call_tir(Before.add, (lv0, lv1),
out_sinfo=R.Tensor((16,), dtype="float32"))
+ gv: R.Tensor((16,), dtype="float32") = lv2
+ R.output(gv)
+ return gv
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1:
T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
+ T.func_attr({"operator_name": "relax.add"})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(4, 4):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+ T.writes(output[v_ax0, v_ax1])
+ output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0,
v_ax1]
+
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,),
dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x,
index_map=lambda i: (i // 4, i % 4), pad_value=None)
+ lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y,
index_map=lambda i: (i // 4, i % 4), pad_value=None)
+ lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1),
out_sinfo=R.Tensor((4, 4), dtype="float32"))
+ lv0: R.Tensor((16,), dtype="float32") =
R.layout_transform(lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,),
pad_value=None)
+ lv1_1: R.Tensor((16,), dtype="float32") = R.nn.relu(lv0)
+ lv3: R.Tensor((4, 4), dtype="float32") =
R.layout_transform(lv0, index_map=lambda i: (i // 4, i % 4), pad_value=None)
+ lv4: R.Tensor((4, 4), dtype="float32") =
R.layout_transform(lv1_1, index_map=lambda i: (i // 4, i % 4), pad_value=None)
+ lv5 = R.call_tir(Expected.relax_add_replacement, (lv3, lv4),
out_sinfo=R.Tensor((4, 4), dtype="float32"))
+ lv2_1: R.Tensor((16,), dtype="float32") =
R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,),
pad_value=None)
+ gv: R.Tensor((16,), dtype="float32") = lv2_1
+ R.output(gv)
+ return gv
+ @T.prim_func
+ def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4),
"float32"), output: T.Buffer((4, 4), "float32")):
+ for ax0, ax1 in T.grid(4, 4):
+ with T.block("T_add"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+ T.writes(output[v_ax0, v_ax1])
+ output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+ # fmt: on
+ index_map = lambda i: (i // 4, i % 4)
+ _check(
+ Before,
+ Expected,
+ operator_name="relax.add",
+ replacement_primfunc=add_2d,
+ layout_changes=[index_map, index_map, index_map],
+ )
+
+
+if __name__ == "__main__":
+ tvm.testing.main()