masahi commented on code in PR #15679:
URL: https://github.com/apache/tvm/pull/15679#discussion_r1336715785


##########
python/tvm/relax/transform/legalize_ops/manipulate.py:
##########
@@ -182,7 +182,15 @@ def te_layout_transform(data, name):
         )
 
     index_map: tvm.tir.IndexMap = call.attrs.index_map
-    pad_value = call.attrs.pad_value.value
+    pad_value = call.attrs.pad_value
+    if pad_value is not None:
+        pad_value = pad_value.value
+    else:
+        if "int" in call.args[0].struct_info.dtype:
+            pad_value = int(pad_value)

Review Comment:
   `pad_value` is always `None` in this `else` branch, so `int(pad_value)` will raise a `TypeError`.



##########
src/relax/transform/alter_op_impl.cc:
##########
@@ -176,16 +190,41 @@ class AlterOpImplMutator : public ExprMutator {
   Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
                               const TensorStructInfo& old_tensor_sinfo,
                               const Array<IntImm>& axis_separator) {
+    if (IsScalarConstant(expr) || index_map.get() == nullptr) {
+      return expr;
+    }
     Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
     Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
     arith::Analyzer analyzer;
     auto [inverse_index_map, padding_predicate] =
         index_map.NonSurjectiveInverse(initial_ranges, &analyzer);
-    ICHECK(tir::is_zero(padding_predicate))
-        << "Only bijective transformations on input/output buffers are 
supported, but found "
-           "padding predicate "
-        << padding_predicate << " on initial range " << initial_ranges;
-    return TransformLayout(expr, inverse_index_map, axis_separator);
+
+    if (tir::is_zero(padding_predicate)) {
+      return TransformLayout(expr, inverse_index_map, axis_separator);
+    } else {
+      auto padded_expr =
+          builder_->Normalize(TransformLayout(expr, inverse_index_map, 
axis_separator));
+      const auto& tensor_sinfo = 
Downcast<TensorStructInfo>(padded_expr->struct_info_);
+      Array<PrimExpr> padded_shape = 
GetShapeFromTensorStructInfo(tensor_sinfo);
+
+      te::Tensor placeholder_tensor = te::placeholder(padded_shape, 
tensor_sinfo->dtype, "input");
+      te::Tensor output_tensor = te::compute(
+          old_shape,
+          [&placeholder_tensor](const Array<tir::Var>& indices) {
+            return placeholder_tensor(indices);
+          },
+          "output", topi::kElementWise);
+
+      String op_name = "remove_pad";
+      PrimFunc remove_pad_with_frozen_layout =
+          WithAttr(CreatePrimFunc({placeholder_tensor, output_tensor}), 
kOperatorName, op_name);
+
+      GlobalVar gv_remove_pad = 
builder_->AddFunction(remove_pad_with_frozen_layout, op_name);
+      builder_->UpdateFunction(gv_remove_pad,
+                               WithoutAttr(remove_pad_with_frozen_layout, 
"global_symbol"));

Review Comment:
   The code between L210 and L224 is hard to read, and it does not seem to need
to be defined here. I think you can define `gv_remove_pad` with a dynamic input
shape once in the constructor and reuse it here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to