This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 3c0639eee9 [Unity][Pass] Add a pass to alter the TIR implementation of an operator (#14215)
3c0639eee9 is described below

commit 3c0639eee9a332bd405864321a7874185f76dcd7
Author: Prakalp Srivastava <[email protected]>
AuthorDate: Wed Mar 15 15:02:38 2023 -0400

    [Unity][Pass] Add a pass to alter the TIR implementation of an operator (#14215)
    
    * [Unity][Pass] Add a pass to alter the TIR implementation of an operator
    (identified by the operator_name attribute on the PrimFunc). It also
    inserts layout changes to the i/o buffers at the Relax level.
    
    * deep copy index map to avoid structural_equality failures
    
    * do not mark layouts as frozen
    
    * address comments
    
    * fix call_tir global symbol in tests
---
 include/tvm/relax/struct_info.h                    |   7 +
 include/tvm/relax/transform.h                      |  17 +-
 python/tvm/relax/transform/transform.py            |  32 ++
 src/relax/transform/alter_op_impl.cc               | 310 +++++++++++++++++++
 tests/python/relax/test_transform_alter_op_impl.py | 342 +++++++++++++++++++++
 5 files changed, 707 insertions(+), 1 deletion(-)
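
For orientation, the new pass is driven by two maps keyed on the operator_name
attribute of a PrimFunc: one supplies the replacement PrimFunc, the other the
layout transformation for each of its buffers (inputs first, outputs last).
The sketch below mirrors the tests added in this commit; the operator name,
the 4x4 tiling, and the surrounding module are illustrative assumptions, not
part of the commit itself.

    import tvm
    from tvm import relax
    from tvm.script import tir as T

    # Illustrative replacement for a PrimFunc tagged with
    # T.func_attr({"operator_name": "relax.add"}): a 4x4 tiled add.
    @T.prim_func
    def add_2d(arg0: T.Buffer((4, 4), "float32"),
               arg1: T.Buffer((4, 4), "float32"),
               output: T.Buffer((4, 4), "float32")):
        for ax0, ax1 in T.grid(4, 4):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]

    # One transform per buffer of the replacement PrimFunc: two inputs, one output.
    index_map = lambda i: (i // 4, i % 4)
    mod = relax.transform.AlterOpImpl(
        {"relax.add": add_2d},
        {"relax.add": [index_map, index_map, index_map]},
    )(mod)  # `mod` is an assumed pre-existing IRModule containing the tagged PrimFunc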

diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index b9aebc5494..0c1973bcea 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -170,6 +170,13 @@ class TensorStructInfoNode : public StructInfoNode {
   /*! \return Whether the struct info contains unknown dtype. */
   bool IsUnknownDtype() const { return dtype.is_void(); }
 
+  /*! \return Shape if it is known. */
+  Optional<Array<PrimExpr>> GetShape() const {
+    if (!shape.defined()) return {};
+    ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(this->shape.value()->struct_info_);
+    return shape_sinfo->values;
+  }
+
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("shape", &shape);
     v->Visit("dtype", &dtype);
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 446b75da9f..3ff863dd09 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -27,7 +27,8 @@
 #include <tvm/ir/transform.h>
 #include <tvm/relax/dataflow_pattern.h>
 #include <tvm/relax/expr.h>
-
+#include <tvm/tir/function.h>
+#include <tvm/tir/index_map.h>
 namespace tvm {
 namespace relax {
 namespace transform {
@@ -279,6 +280,20 @@ TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> target_opt
  * \return The Pass.
  */
 TVM_DLL Pass SimplifyNormInference();
+/*!
+ * \brief Returns a pass which replaces each PrimFunc whose kOperatorName attribute has a
+ * matching entry in \p op_impl_map with a replacement PrimFunc that may use different layouts
+ * on its i/o buffers. The layout transformations on the i/o buffers are given in
+ * \p op_buffer_transforms. The pass inserts layout transformations at the call sites of the
+ * replaced PrimFuncs to bring the i/o buffers into the expected layout.
+ *
+ * \param op_impl_map Map from kOperatorName attr (e.g., relax.conv2d) to replacement PrimFunc
+ * \param op_buffer_transforms Map from kOperatorName attr to the layout transformations on
+ * each of the PrimFunc's i/o buffers.
+ * \return The Pass.
+ */
+TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
+                         const Map<String, Array<tir::IndexMap>>& op_buffer_transforms);
 }  // namespace transform
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 97c8772b3b..c59104ca58 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -20,6 +20,7 @@ import functools
 import inspect
 import types
 from typing import Callable, Dict, Union, Optional, List, Tuple
+from tvm.tir import PrimFunc, IndexMap
 import numpy as np  # type: ignore
 import tvm.ir
 from tvm.runtime import NDArray
@@ -542,6 +543,37 @@ def SimplifyNormInference() -> tvm.ir.transform.Pass:
     return _ffi_api.SimplifyNormInference()  # type: ignore
 
 
+def AlterOpImpl(
+    op_impl_map: Dict[str, PrimFunc],
+    op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
+):
+    """Replace all PrimFuncs which have a matching 'operator_name' attribute with a
+    replacement PrimFunc that may use different layouts on its i/o buffers. The layout
+    transformations on the i/o buffers are given in the op_buffer_transforms map. Layout
+    transformations are inserted at the call sites of the replaced PrimFuncs to bring the
+    i/o tensors into the layout expected by the new PrimFunc.
+
+    Parameters
+    ----------
+    op_impl_map: Dict[str, PrimFunc]
+        op_kind to PrimFunc map
+    op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]]
+        op_kind to layout transformation map for each of the buffers
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    for operator_name, transform_list in op_buffer_transforms.items():
+        l = []
+        for transform in transform_list:
+            if isinstance(transform, Callable):
+                transform = IndexMap.from_func(transform)
+            l.append(transform)
+        op_buffer_transforms[operator_name] = l
+
+    return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms)  # type: ignore
+
+
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass."""
 
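A note on the conversion above: plain callables in op_buffer_transforms are
turned into tir.IndexMap objects with IndexMap.from_func before the maps cross
the FFI boundary. A minimal sketch of that conversion in isolation (the tiling
lambda is illustrative):

    from tvm import tir

    # The same conversion AlterOpImpl applies to each callable entry.
    index_map = tir.IndexMap.from_func(lambda i: (i // 4, i % 4))
    print(index_map)  # now an IndexMap, ready to pass through the FFI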
diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
new file mode 100644
index 0000000000..6a740b4f55
--- /dev/null
+++ b/src/relax/transform/alter_op_impl.cc
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/alter_op_impl.cc
+ * \brief Change the layout of PrimFuncs in the graph. It uses the kOperatorName attribute to
+ * identify the PrimFuncs to be replaced, and retains that attribute on the replacement
+ * PrimFuncs.
+ */
+#include <tvm/ir/attrs.h>
+#include <tvm/node/serialization.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/manipulate.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/transform.h>
+namespace tvm {
+namespace relax {
+
+using namespace tir;
+static constexpr const char* kOperatorName = "operator_name";
+
+/*! \brief Construct ranges from shape dimensions */
+static Array<Range> ConstructRangeFromShape(const Array<PrimExpr>& shape) {
+  return shape.Map([](const PrimExpr& dim) { return Range(tir::make_zero(dim.dtype()), dim); });
+}
+
+static Array<PrimExpr> GetShapeFromTensorStructInfo(const TensorStructInfo& tensor_sinfo) {
+  auto shape = tensor_sinfo->GetShape();
+  ICHECK(shape.defined());
+  return shape.value();
+}
+
+static Array<PrimExpr> GetShapeFromTensor(const Expr& expr) {
+  const auto& tensor_sinfo = Downcast<TensorStructInfo>(expr->struct_info_);
+  return GetShapeFromTensorStructInfo(tensor_sinfo);
+}
+
+static IndexMap DeepCopyIndexMap(const IndexMap& index_map) {
+  return Downcast<IndexMap>(LoadJSON(SaveJSON(index_map)));
+}
+
+/*! \brief Checks if the \p transform is bijective on the shape of \p expr */
+bool IsTransformBijective(const Expr& expr, const IndexMap& transform) {
+  Array<PrimExpr> input_shape = GetShapeFromTensor(expr);
+  Array<Range> initial_ranges = ConstructRangeFromShape(input_shape);
+  auto [inverse, padding_predicate] = transform.NonSurjectiveInverse(initial_ranges);
+  (void)inverse;  // to avoid unused variable warning
+  arith::Analyzer analyzer;
+  if (!analyzer.CanProve(!padding_predicate)) return false;
+  return true;
+}
+
+/*!
+ * \brief Replace each call_tir to a PrimFunc that matches the kOperatorName attribute with
+ * the provided replacement PrimFunc. Insert layout transformations on the i/o buffers as
+ * necessary for correctness.
+ */
+class AlterOpImplMutator : public ExprMutator {
+ public:
+  AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>& op_impl_map,
+                     const Map<String, Array<IndexMap>>& op_buffer_transforms_)
+      : ExprMutator(mod),
+        mod_(mod),
+        op_impl_map_(op_impl_map),
+        op_buffer_transforms__(op_buffer_transforms_) {}
+
+  IRModule Run() {
+    for (const auto& [gv, func] : mod_->functions) {
+      if (func->IsInstance<relax::FunctionNode>()) {
+        relax::Function update_func = Downcast<Function>(VisitExpr(func));
+        builder_->UpdateFunction(gv, update_func);
+      }
+    }
+    return builder_->GetContextIRModule();
+  }
+
+ private:
+  Expr VisitExpr_(const CallNode* op) final {
+    auto call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    // TODO(@tvm-team): When we differentiate the call for tir function and packed function,
+    // this logic should be changed accordingly.
+    if (!call->op.same_as(call_tir_op_)) return call;
+
+    // Do not do anything for external function
+    if (call->args[0].as<ExternFuncNode>()) return call;
+
+    // Get operator name from callee
+    ICHECK(call->args[0]->IsInstance<GlobalVarNode>());
+    const tir::PrimFunc& old_func =
+        Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
+    Optional<String> maybe_op_kind = old_func->attrs.GetAttr<String>(kOperatorName);
+
+    // If the callee does not have the kOperatorName attribute, or no replacement is requested
+    // for it, there is nothing to do here.
+    if (!maybe_op_kind.defined() || op_impl_map_.count(maybe_op_kind.value()) == 0) return call;
+    auto op_kind = maybe_op_kind.value();
+
+    const auto& replacement_func = op_impl_map_[op_kind];
+
+    Array<IndexMap> buffer_transforms;
+    if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind];
+
+    ICHECK(buffer_transforms.empty() || buffer_transforms.size() == replacement_func->params.size())
+        << "Either the i/o buffers do not require any transformations, or a transformation "
+           "must be provided for each buffer.";
+    ICHECK_EQ(old_func->params.size(), replacement_func->params.size())
+        << "Number of parameters of old and replacement PrimFunc must match";
+
+    GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind);
+
+    auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
+    Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms);
+
+    ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1";
+    StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms);
+    auto updated_call = builder_->Normalize(
+        Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo}));
+
+    // Now transform each of the outputs to previous layout.
+    return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0]);
+  }
+
+  Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) {
+    if (const auto* tensor_sinfo = output_sinfo.as<TensorStructInfoNode>())
+      return {GetRef<TensorStructInfo>(tensor_sinfo)};
+    const auto* tuple_sinfo = output_sinfo.as<TupleStructInfoNode>();
+    ICHECK(tuple_sinfo);
+
+    Array<TensorStructInfo> arr_tensor_sinfo;
+    arr_tensor_sinfo.reserve(tuple_sinfo->fields.size());
+    for (const auto& sinfo : tuple_sinfo->fields) {
+      const auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>();
+      ICHECK(tensor_sinfo) << "Nested tuples in the output of call_tir are not supported yet";
+      arr_tensor_sinfo.push_back(GetRef<TensorStructInfo>(tensor_sinfo));
+    }
+    return arr_tensor_sinfo;
+  }
+
+  Expr TransformLayout(const Expr& expr, const IndexMap& index_map) {
+    ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
+    // We want to avoid having two layout_transform ops share the same index map, even if they
+    // are identical: the scope of the vars used in the index map's initial indices is local to
+    // the op, so sharing a map would confuse the structural equality check.
+    attrs->index_map = DeepCopyIndexMap(index_map);
+    return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
+  }
+
+  Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
+                              const TensorStructInfo& old_tensor_sinfo) {
+    Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
+    Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
+    auto [inverse_index_map, padding_predicate] = index_map.NonSurjectiveInverse(initial_ranges);
+    ICHECK(tir::is_zero(padding_predicate))
+        << "Only bijective transformations on input/output buffers are supported, but found "
+           "padding predicate "
+        << padding_predicate << " on initial range " << initial_ranges;
+    return TransformLayout(expr, inverse_index_map);
+  }
+
+  /*!
+   * \brief Adds the \p replacement_func to the module if it has not already been added.
+   * \returns The global var associated with the PrimFunc.
+   */
+  GlobalVar GetOrCreateGlobalVarForFunc(const PrimFunc& replacement_func, const String& op_kind) {
+    if (cache_.count(replacement_func) != 0) {
+      return cache_[replacement_func];
+    }
+    // Retain the operator name attribute on the replacement PrimFunc. This can help any future
+    // passes that use the kOperatorName attribute to identify the operator represented by a
+    // PrimFunc.
+    PrimFunc replacement_func_with_op_name = WithAttr(replacement_func, kOperatorName, op_kind);
+
+    GlobalVar gv_replacement =
+        builder_->AddFunction(replacement_func_with_op_name, op_kind + "_replacement");
+    cache_.Set(replacement_func, gv_replacement);
+    return gv_replacement;
+  }
+
+  /*!
+   * \brief Updates call inputs with layout transformed inputs
+   */
+  Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms) {
+    if (transforms.empty()) return inputs;
+
+    Array<Expr> updated_inputs;
+    int index = 0;
+    for (const auto& input : inputs->fields) {
+      auto transform = transforms[index++];
+      ICHECK(IsTransformBijective(input, transform))
+          << "Non bijective transforms on input and output buffers are not supported.";
+      updated_inputs.push_back(TransformLayout(input, transform));
+    }
+    return Tuple(updated_inputs);
+  }
+
+  /*! \brief Updates output struct info */
+  StructInfo UpdateStructInfo(const StructInfo& out_sinfo,
+                              const Array<IndexMap>& buffer_transforms) {
+    if (buffer_transforms.empty()) return out_sinfo;
+
+    if (out_sinfo->IsInstance<TensorStructInfoNode>())
+      return UpdateStructInfo(Downcast<TensorStructInfo>(out_sinfo),
+                              buffer_transforms[buffer_transforms.size() - 1]);
+
+    ICHECK(out_sinfo->IsInstance<TupleStructInfoNode>())
+        << "Expect output struct info of call_tir to be either TupleStructInfo or "
+           "TensorStructInfo, but got "
+        << out_sinfo;
+
+    const auto& tuple_sinfo = Downcast<TupleStructInfo>(out_sinfo);
+    Array<StructInfo> sinfo_fields;
+    size_t first_output_index = buffer_transforms.size() - tuple_sinfo->fields.size();
+    size_t i = 0;
+    for (const auto& si : tuple_sinfo->fields) {
+      ICHECK(si->IsInstance<TensorStructInfoNode>())
+          << "Fields of TupleStructInfo must be TensorStructInfo for call_tir "
+             "output structinfo, but got "
+          << si;
+      sinfo_fields.push_back(UpdateStructInfo(Downcast<TensorStructInfo>(si),
+                                              buffer_transforms[first_output_index + i++]));
+    }
+    return TupleStructInfo(sinfo_fields);
+  }
+
+  /*! \brief Returns the TensorStructInfo after applying the \p transform on its shape */
+  StructInfo UpdateStructInfo(const TensorStructInfo& tensor_sinfo, const IndexMap& transform) {
+    auto shape = GetShapeFromTensorStructInfo(tensor_sinfo);
+    auto new_shape = transform->MapShape(shape);
+    return TensorStructInfo(ShapeExpr(new_shape), tensor_sinfo->dtype);
+  }
+
+  Expr TransformOutputs(const Expr& expr, const Array<IndexMap>& buffer_transforms,
+                        const StructInfo& old_struct_info) {
+    if (buffer_transforms.empty()) return expr;
+
+    Array<TensorStructInfo> old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info);
+
+    size_t num_outputs = old_output_sinfo.size();
+    if (num_outputs == 0) return expr;
+
+    size_t first_output_index = buffer_transforms.size() - num_outputs;
+    // If there is a single output, return the transformed output.
+    if (num_outputs == 1) {
+      IndexMap output_map = buffer_transforms[first_output_index];
+      return TransformLayoutInverse(expr, output_map, old_output_sinfo[0]);
+    }
+
+    // In case of more than one output, we have to get each item of the output tuple,
+    // transform it, and return a tuple of all the transformed outputs.
+    Array<Expr> transformed_outputs;
+    for (size_t i = 0; i + first_output_index < buffer_transforms.size(); ++i) {
+      const auto& output_map = buffer_transforms[i + first_output_index];
+      auto output = builder_->Normalize(TupleGetItem(expr, static_cast<int>(i)));
+      transformed_outputs.push_back(
+          TransformLayoutInverse(output, output_map, old_output_sinfo[i]));
+    }
+    return Tuple(transformed_outputs);
+  }
+
+ private:
+  /*! \brief Cache to keep track of the GlobalVar associated with the new PrimFunc added */
+  Map<PrimFunc, GlobalVar> cache_;
+  /*! \brief Input IRModule */
+  const IRModule& mod_;
+  /*! \brief Map from kOperatorName attribute to the replacement PrimFunc */
+  const Map<String, PrimFunc>& op_impl_map_;
+  /*! \brief Map from kOperatorName attribute to the layout transforms on i/o buffers */
+  const Map<String, Array<IndexMap>>& op_buffer_transforms__;
+
+  const Op& call_tir_op_ = Op::Get("relax.call_tir");
+  const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
+};
+
+namespace transform {
+
+Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
+                 const Map<String, Array<IndexMap>>& op_buffer_transforms_) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) {
+    return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_).Run();
+  };
+  return CreateModulePass(/*pass_function=*/pass_func,  //
+                          /*opt_level=*/0,              //
+                          /*pass_name=*/"AlterOpImpl",  //
+                          /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AlterOpImpl").set_body_typed(AlterOpImpl);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
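
The bijectivity requirement enforced by IsTransformBijective and
TransformLayoutInverse above can be explored from Python, assuming the
tir.IndexMap.non_surjective_inverse binding (the Python counterpart of the C++
NonSurjectiveInverse used here); the shapes are illustrative:

    import tvm
    from tvm import tir

    m = tir.IndexMap.from_func(lambda i: (i // 4, i % 4))

    # On a length-16 buffer every slot of the 4x4 layout is produced, so the
    # padding predicate should simplify to False and the transform is accepted.
    _, pad16 = m.non_surjective_inverse([tvm.ir.Range(0, 16)])
    print(pad16)

    # On a length-14 buffer the last two slots of the 4x4 layout are never
    # written; the predicate is then non-trivial and the pass rejects the
    # transform with "Non bijective transforms ... are not supported."
    _, pad14 = m.non_surjective_inverse([tvm.ir.Range(0, 14)])
    print(pad14)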
diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py
new file mode 100644
index 0000000000..e8fa29a074
--- /dev/null
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -0,0 +1,342 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import tvm.testing
+
+from tvm import relax
+from tvm.script import tir as T, ir as I, relax as R
+
+kOperatorName = "operator_name"
+
+
+def _check(before, expected, operator_name, replacement_primfunc, layout_changes):
+    after = relax.transform.AlterOpImpl(
+        {operator_name: replacement_primfunc}, {operator_name: layout_changes}
+    )(before)
+    after = relax.transform.RemoveUnusedFunctions()(after)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_single_output():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def add(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")):
+            T.func_attr({"operator_name": "relax.add"})
+            for ax0 in range(16):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(16, ax0)
+                    T.reads(arg0[v_ax0], arg1[v_ax0])
+                    T.writes(output[v_ax0])
+                    output[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv = R.call_tir(Before.add, (x, y), out_sinfo=R.Tensor((16,), dtype="float32"))
+                gv: R.Tensor((16,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
+            T.func_attr({"operator_name": "relax.add"})
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None)
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None)
+                lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32"))
+                lv_1: R.Tensor((16,), dtype="float32") = R.layout_transform(lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
+                gv: R.Tensor((16,), dtype="float32") = lv_1
+                R.output(gv)
+            return gv
+
+    @T.prim_func
+    def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
+        for ax0, ax1 in T.grid(4, 4):
+            with T.block("T_add"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                T.writes(output[v_ax0, v_ax1])
+                output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+    # fmt: on
+    index_map = lambda i: (i // 4, i % 4)
+    _check(
+        Before,
+        Expected,
+        operator_name="relax.add",
+        replacement_primfunc=add_2d,
+        layout_changes=[index_map, index_map, index_map],
+    )
+
+
+def test_empty_layout_changes():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def mul_by_2(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")):
+            T.func_attr({"operator_name": "relax.mul_by_2"})
+            for ax0 in range(16):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(16, ax0)
+                    T.reads(arg0[v_ax0])
+                    T.writes(output[v_ax0])
+                    output[v_ax0] = arg0[v_ax0] * T.float32(2)
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv = R.call_tir(Before.mul_by_2, (x,), out_sinfo=R.Tensor((16,), dtype="float32"))
+                gv: R.Tensor((16,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def relax_mul_by_2_replacement(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")):
+            T.func_attr({"operator_name": "relax.mul_by_2"})
+            for ax0 in range(16):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(16, ax0)
+                    T.reads(arg0[v_ax0])
+                    T.writes(output[v_ax0])
+                    output[v_ax0] = arg0[v_ax0] + arg0[v_ax0]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv = R.call_tir(Expected.relax_mul_by_2_replacement, (x,), out_sinfo=R.Tensor((16,), dtype="float32"))
+                gv: R.Tensor((16,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @T.prim_func
+    def add_x_x(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")):
+        T.func_attr({"operator_name": "relax.mul_by_2"})
+        for ax0 in range(16):
+            with T.block("T_add"):
+                v_ax0 = T.axis.spatial(16, ax0)
+                T.reads(arg0[v_ax0])
+                T.writes(output[v_ax0])
+                output[v_ax0] = arg0[v_ax0] + arg0[v_ax0]
+    # fmt: on
+    _check(
+        Before,
+        Expected,
+        operator_name="relax.mul_by_2",
+        replacement_primfunc=add_x_x,
+        layout_changes=[],
+    )
+
+
+def test_multiple_outputs():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,), "float32")):
+            T.func_attr({"operator_name": "relax.some_op"})
+            for ax0 in range(16):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(16, ax0)
+                    T.reads(arg0[v_ax0], arg1[v_ax0])
+                    T.writes(output0[v_ax0], output1[v_ax0])
+                    output0[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+                    output1[v_ax0] = arg0[v_ax0] - arg1[v_ax0]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")):
+            with R.dataflow():
+                gv = R.call_tir(Before.some_op, (x, y), out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")])
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")):
+            T.func_attr({"operator_name": "relax.some_op"})
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
+                    output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+                    output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None)
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None)
+                lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")])
+                lv3: R.Tensor((4, 4), dtype="float32") = lv2[0]
+                lv4: R.Tensor((16,), dtype="float32") = R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
+                lv5: R.Tensor((4, 4), dtype="float32") = lv2[1]
+                lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
+                gv: R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")) = (lv4, lv6)
+                R.output(gv)
+            return gv
+
+    @T.prim_func
+    def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), 
"float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), 
"float32")):
+        for ax0, ax1 in T.grid(4, 4):
+            with T.block("T_add"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
+                output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+                output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+    # fmt: on
+
+    index_map = lambda i: (i // 4, i % 4)
+    _check(
+        Before,
+        Expected,
+        operator_name="relax.some_op",
+        replacement_primfunc=some_op_2d,
+        layout_changes=[index_map, index_map, index_map, index_map],
+    )
+
+
+def test_unsupported_implicit_padding():
+    @I.ir_module
+    class InputModule:
+        @R.function
+        def foo(x: R.Tensor((14,), dtype="float32")) -> R.Tensor((14,), dtype="float32"):
+            with R.dataflow():
+                lv = R.call_tir(InputModule.relu, (x,), out_sinfo=R.Tensor((14,), dtype="float32"))
+                gv: R.Tensor((14,), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+        @T.prim_func
+        def relu(arg0: T.Buffer((14,), "float32"), output: T.Buffer((14,), "float32")):
+            T.func_attr({"operator_name": "relax.relu"})
+            for ax0 in T.grid(14):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.remap("S", [ax0])
+                    T.reads(arg0[v_ax0])
+                    T.writes(output[v_ax0])
+                    output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))
+
+    before = InputModule
+
+    @T.prim_func
+    def relu_pad(arg0: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")):
+        for ax0 in T.grid(16):
+            with T.block("T_add"):
+                v_ax0 = T.axis.remap("S", [ax0])
+                T.reads(arg0[v_ax0])
+                T.writes(output[v_ax0])
+                output[v_ax0] = T.max(arg0[v_ax0], T.float32(0))
+
+    # introduces implicit padding for shape (14,)
+    index_map = lambda i: (i % 16)
+    operator_name = "relax.relu"
+    with pytest.raises(
+        tvm.TVMError, match="Non bijective transforms on input and output buffers are not supported"
+    ):
+        _ = relax.transform.AlterOpImpl(
+            {operator_name: relu_pad}, {operator_name: [index_map, index_map]}
+        )(before)
+
+
+def test_multiple_call_sites():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def add(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output: T.Buffer((16,), "float32")):
+            T.func_attr({"operator_name": "relax.add"})
+            for ax0 in range(16):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(16, ax0)
+                    T.reads(arg0[v_ax0], arg1[v_ax0])
+                    T.writes(output[v_ax0])
+                    output[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv0 = R.call_tir(Before.add, (x, y), out_sinfo=R.Tensor((16,), dtype="float32"))
+                lv1 = R.nn.relu(lv0)
+                lv2 = R.call_tir(Before.add, (lv0, lv1), out_sinfo=R.Tensor((16,), dtype="float32"))
+                gv: R.Tensor((16,), dtype="float32") = lv2
+                R.output(gv)
+            return gv
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def relax_add_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
+            T.func_attr({"operator_name": "relax.add"})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                    T.writes(output[v_ax0, v_ax1])
+                    output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None)
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None)
+                lv2 = R.call_tir(Expected.relax_add_replacement, (lv, lv1), out_sinfo=R.Tensor((4, 4), dtype="float32"))
+                lv0: R.Tensor((16,), dtype="float32") = R.layout_transform(lv2, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
+                lv1_1: R.Tensor((16,), dtype="float32") = R.nn.relu(lv0)
+                lv3: R.Tensor((4, 4), dtype="float32") = R.layout_transform(lv0, index_map=lambda i: (i // 4, i % 4), pad_value=None)
+                lv4: R.Tensor((4, 4), dtype="float32") = R.layout_transform(lv1_1, index_map=lambda i: (i // 4, i % 4), pad_value=None)
+                lv5 = R.call_tir(Expected.relax_add_replacement, (lv3, lv4), out_sinfo=R.Tensor((4, 4), dtype="float32"))
+                lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None)
+                gv: R.Tensor((16,), dtype="float32") = lv2_1
+                R.output(gv)
+            return gv
+    @T.prim_func
+    def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output: T.Buffer((4, 4), "float32")):
+        for ax0, ax1 in T.grid(4, 4):
+            with T.block("T_add"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
+                T.writes(output[v_ax0, v_ax1])
+                output[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+    # fmt: on
+    index_map = lambda i: (i // 4, i % 4)
+    _check(
+        Before,
+        Expected,
+        operator_name="relax.add",
+        replacement_primfunc=add_2d,
+        layout_changes=[index_map, index_map, index_map],
+    )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

