psrivas2 commented on code in PR #14215:
URL: https://github.com/apache/tvm/pull/14215#discussion_r1132544805


##########
src/relax/transform/alter_op_impl.cc:
##########
@@ -0,0 +1,312 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/alter_op_impl.cc
+ * \brief Change the layout of PrimFunc in the graph. It uses the kOperatorName attribute to
+ * identify PrimFuncs to be replaced. Marks the new PrimFuncs with kFrozenLayout attribute set to
+ * true.
+ */
+#include <tvm/ir/attrs.h>
+#include <tvm/node/serialization.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/manipulate.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/transform.h>
+namespace tvm {
+namespace relax {
+
+using namespace tir;
+static constexpr const char* kOperatorName = "operator_name";
+
+/*! \brief Construct ranges from shape dimensions */
+static Array<Range> ConstructRangeFromShape(const Array<PrimExpr>& shape) {
+  return shape.Map([](const PrimExpr& dim) { return 
Range(tir::make_zero(dim.dtype()), dim); });
+}
+
+static Array<PrimExpr> GetShapeFromTensorStructInfo(const TensorStructInfo& 
tensor_sinfo) {
+  auto shape = tensor_sinfo->GetShape();
+  ICHECK(shape.defined());
+  return shape.value();
+}
+
+static Array<PrimExpr> GetShapeFromTensor(const Expr& expr) {
+  const auto& tensor_sinfo = Downcast<TensorStructInfo>(expr->struct_info_);
+  return GetShapeFromTensorStructInfo(tensor_sinfo);
+}
+
+static IndexMap DeepCopyIndexMap(const IndexMap& index_map) {
+  return Downcast<IndexMap>(LoadJSON(SaveJSON(index_map)));
+}
+
+/*! \brief Checks if the \p transform is bijective on the shape of \p expr */
+bool IsTransformBijective(const Expr& expr, const IndexMap& transform) {
+  Array<PrimExpr> input_shape = GetShapeFromTensor(expr);
+  Array<Range> initial_ranges = ConstructRangeFromShape(input_shape);
+  auto [inverse, padding_predicate] = 
transform.NonSurjectiveInverse(initial_ranges);
+  (void)inverse;  // to avoid unused variable warning;
+  arith::Analyzer analyzer;
+  if (!analyzer.CanProve(!padding_predicate)) return false;
+  return true;
+}
+
+/*!
+ * \brief Replace each call_tir to PrimFunc which matches the kOperatorName attribute with the
+ * provided replacement PrimFunc and mark it with kFrozenLayout attribute. Insert layout
+ * transformations on i/o buffers as necessary for correctness.
+ */
+class AlterOpImplMutator : public ExprMutator {
+ public:
+  AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>& 
op_impl_map,
+                     const Map<String, Array<IndexMap>>& op_buffer_transforms_)
+      : ExprMutator(mod),
+        mod_(mod),
+        op_impl_map_(op_impl_map),
+        op_buffer_transforms__(op_buffer_transforms_) {}
+
+  IRModule Run() {
+    for (const auto& [gv, func] : mod_->functions) {
+      if (func->IsInstance<relax::FunctionNode>()) {
+        relax::Function update_func = Downcast<Function>(VisitExpr(func));
+        builder_->UpdateFunction(gv, update_func);
+      }
+    }
+    return builder_->GetContextIRModule();
+  }
+
+ private:
+  Expr VisitExpr_(const CallNode* op) final {
+    auto call = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    // TODO(@tvm-team): When we differentiate the call for tir function and packed function,
+    // this logic should be changed accordingly.
+    if (!call->op.same_as(call_tir_op_)) return call;
+
+    // Do not do anything for external function
+    if (call->args[0].as<ExternFuncNode>()) return call;
+
+    // Get operator name from callee
+    ICHECK(call->args[0]->IsInstance<GlobalVarNode>());
+    const tir::PrimFunc& old_func =
+        
Downcast<tir::PrimFunc>(mod_->Lookup(Downcast<GlobalVar>(call->args[0])));
+    Optional<String> maybe_op_kind = 
old_func->attrs.GetAttr<String>(kOperatorName);
+
+    // If the callee does not have kOperatorName attribute or no replacement is requested for
+    // it, nothing to do here.
+    if (!maybe_op_kind.defined() || op_impl_map_.count(maybe_op_kind.value()) 
== 0) return call;
+    auto op_kind = maybe_op_kind.value();
+
+    const auto& replacement_func = op_impl_map_[op_kind];
+
+    Array<IndexMap> buffer_transforms;
+    if (op_buffer_transforms__.count(op_kind)) buffer_transforms = 
op_buffer_transforms__[op_kind];
+
+    ICHECK(buffer_transforms.empty() || buffer_transforms.size() == 
replacement_func->params.size())
+        << "Either the i/o buffers do not require any transformations or 
transformations for each "
+           "buffer is provided.";
+    ICHECK_EQ(old_func->params.size(), replacement_func->params.size())
+        << "Number of parameters of old and replacement PrimFunc must match";
+
+    GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, 
op_kind);
+
+    auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
+    Array<Expr> updated_calltir_args =
+        UpdatedArgs(replacement_gv, call_tir_inputs_tuple, buffer_transforms);
+
+    ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is 
expected to be 1";
+    StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], 
buffer_transforms);
+    auto updated_call = builder_->Normalize(Call(call_tir_op_, 
std::move(updated_calltir_args),
+                                                 call->attrs, 
{std::move(updated_ret_sinfo)}));
+
+    // Now transform each of the outputs to previous layout.
+    return TransformOutputs(updated_call, buffer_transforms, 
call->sinfo_args[0]);
+  }
+
+  Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo& 
output_sinfo) {
+    if (const auto* tensor_sinfo = output_sinfo.as<TensorStructInfoNode>())
+      return {GetRef<TensorStructInfo>(tensor_sinfo)};
+    const auto* tuple_sinfo = output_sinfo.as<TupleStructInfoNode>();
+    ICHECK(tuple_sinfo);
+
+    Array<TensorStructInfo> arr_tensor_sinfo;
+    arr_tensor_sinfo.reserve(tuple_sinfo->fields.size());
+    for (const auto& sinfo : tuple_sinfo->fields) {
+      const auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>();
+      ICHECK(tensor_sinfo) << "Nested tuples in output of call_tir is not 
supported yet";
+      arr_tensor_sinfo.push_back(GetRef<TensorStructInfo>(tensor_sinfo));
+    }
+    return arr_tensor_sinfo;
+  }
+
+  Expr TransformLayout(const Expr& expr, const IndexMap& index_map) {
+    ObjectPtr<LayoutTransformAttrs> attrs = 
make_object<LayoutTransformAttrs>();
+    // We want to avoid two layout_transform ops to share the same index map even if they are
+    // identical. The scope of vars used in index map initial indices is local to the op. Not doing
+    // so would confuse the structural equality check.
+    attrs->index_map = std::move(DeepCopyIndexMap(index_map));
+    return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
+  }
+
+  Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
+                              const TensorStructInfo& old_tensor_sinfo) {
+    Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
+    Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
+    auto [inverse_index_map, padding_predicate] = 
index_map.NonSurjectiveInverse(initial_ranges);
+    ICHECK(tir::is_zero(padding_predicate))
+        << "Only bijective transformations on input/output buffers are 
supported, but found "
+           "padding predicate "
+        << padding_predicate << " on initial range " << initial_ranges;
+    return TransformLayout(expr, inverse_index_map);
+  }
+
+  /*!
+   * \brief Adds the \p replacement_func to the module if it has not already been added before.
+   * \returns The global var associated with the PrimFunc.
+   */
+  GlobalVar GetOrCreateGlobalVarForFunc(const PrimFunc& replacement_func, 
const String& op_kind) {
+    if (cache_.count(replacement_func) != 0) {
+      return cache_[replacement_func];
+    }
+    // Retain the operator name attribute on the replacement PrimFunc. This can help any future
+    // passes that use kOperatorName attribute to identify operator represented by a PrimFunc.
+    PrimFunc replacement_func_with_frozen_layout =
+        WithAttr(replacement_func, kOperatorName, op_kind);
+
+    GlobalVar gv_replacement =
+        builder_->AddFunction(replacement_func_with_frozen_layout, op_kind + 
"_replacement");
+    cache_.Set(replacement_func, gv_replacement);
+    return gv_replacement;
+  }
+
+  /*!
+   * \brief Updates call_tir args with global var of replacement func and layout transformed inputs
+   */
+  Array<Expr> UpdatedArgs(const GlobalVar& replacement_gv, const Tuple& inputs,
+                          const Array<IndexMap>& transforms) {
+    if (transforms.empty()) return {replacement_gv, inputs};
+
+    Array<Expr> updated_inputs;
+    int index = 0;
+    for (const auto& input : inputs->fields) {
+      auto transform = transforms[index++];
+      ICHECK(IsTransformBijective(input, transform))
+          << "Non bijective transforms on input and output buffers are not 
supported.";
+      updated_inputs.push_back(TransformLayout(input, transform));
+    }
+    return {replacement_gv, Tuple(updated_inputs)};

Review Comment:
   Right, thanks! I've updated the function.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to