This is an automated email from the ASF dual-hosted git repository.

xiyou pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 675a22e080 [Unity][Transform] Introduce data-dependent operation of 
reshape and its constant folding (#14282)
675a22e080 is described below

commit 675a22e080201c33217f5e35a305dc5a90878cf6
Author: Sunghyun Park <[email protected]>
AuthorDate: Tue Mar 21 13:33:56 2023 -0700

    [Unity][Transform] Introduce data-dependent operation of reshape and its 
constant folding (#14282)
    
    * FEAT: Support data-dependent operation of reshape
    
    * FEAT: Support constant folding with data-dependent reshape
    
    * fix
    
    * remove empty line
    
    * reflect feedback
    
    * Lift the lowering of tensor_to_shape from builtin to 
DecomposeCompositeOps pass
    
    * fix and comment
    
    * fix
    
    * add comments
    
    * reflect feedback
    
    * add comment
    
    * fix
---
 include/tvm/relax/transform.h                      |   7 +-
 python/tvm/relax/op/base.py                        |  14 ++
 .../tvm/relax/transform/legalize_ops/manipulate.py |   6 +-
 python/tvm/relax/transform/transform.py            |  11 +-
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 src/relax/backend/vm/vm_builtin_lower.cc           |  20 ++-
 src/relax/op/op.cc                                 |  26 +++
 src/relax/op/tensor/manipulate.cc                  |   7 +-
 ...orm_inference.cc => decompose_composite_ops.cc} |  51 +++++-
 src/relax/transform/fold_constant.cc               |  70 +++++++-
 src/runtime/relax_vm/builtin.cc                    |  34 ++++
 tests/python/relax/test_op_manipulate.py           |  25 +--
 ...y => test_transform_decompose_composite_ops.py} |  25 ++-
 tests/python/relax/test_transform_fold_constant.py |  37 ++++
 .../test_transform_legalize_ops_manipulate.py      | 195 +++++++++++++++++++++
 15 files changed, 493 insertions(+), 37 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index a3d0d4a0e9..4f45ba9c25 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -372,12 +372,13 @@ TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, 
ObjectRef>>> target_opt
                         Array<runtime::String> entry_functions);
 
 /*!
- * \brief Simplify normalization operators during inference. For example, the 
result
+ * \brief Decompose composite operators during inference. For example, the 
result
  * of a batch norm which is indexed at tuple index 0 will be unpacked into a
- * number of simplified operators.
+ * number of simplified operators. Operators like Attention, Erf, etc. can
+ * also be simplified into several operators.
  * \return The Pass.
  */
-TVM_DLL Pass SimplifyNormInference();
+TVM_DLL Pass DecomposeCompositeOps();
 
 /*!
  * \brief Returns a pass which replaces PrimFuncs which have matching 
kOperatorName attribute in \p
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index aef0e731db..becd3f2a0f 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -402,3 +402,17 @@ def shape_of(expr: Expr) -> Expr:
         A relax Call, which gets the shape of the input
     """
     return _ffi_api.shape_of(expr)  # type: ignore # pylint: disable=no-member
+
+
+def tensor_to_shape(expr: Expr) -> Expr:
+    """Convert tensor to shape expr.
+    Parameters
+    ----------
+    expr : Expr
+        The input Expr
+    Returns
+    -------
+    result : Expr
+        A relax Call, which transforms the tensor values to the shape
+    """
+    return _ffi_api.tensor_to_shape(expr)  # type: ignore # pylint: 
disable=no-member
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py 
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index e7cae1af34..144ef04748 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -23,7 +23,7 @@ import tvm
 from tvm import topi, tir, relax, te
 from tvm.tir.expr import IntImm
 from ...block_builder import BlockBuilder
-from ...expr import Call, Expr, Var, Tuple, TupleGetItem
+from ...expr import Call, Expr, Var, Tuple, TupleGetItem, ShapeExpr
 from .common import TEFunc, LegalizeFunc, register_legalize
 
 
@@ -32,6 +32,10 @@ def _reshape(
 ) -> LegalizeFunc:
     def reshape_call_te(bb: BlockBuilder, call: Call):
         tgt_shape = call.args[1].struct_info.shape if is_collapse_sum_like 
else call.args[1]
+        # If target shape is Var, pass its bound expr only when it is ShapeExpr
+        if isinstance(tgt_shape, Var):
+            tgt_shape = bb.lookup_binding(tgt_shape)
+            assert isinstance(tgt_shape, ShapeExpr)
         return bb.call_te(te_func, call.args[0], tgt_shape, 
primfunc_name_hint=primfunc_name)
 
     return reshape_call_te
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index e8e3d73113..c03df804ee 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -588,10 +588,11 @@ def MetaScheduleTuneIRMod(
     return _ffi_api.MetaScheduleTuneIRMod(params, work_dir, max_trials_global) 
 # type: ignore
 
 
-def SimplifyNormInference() -> tvm.ir.transform.Pass:
-    """Simplify normalization operators during inference. For example, the 
result
-    of a batch norm which is indexed at tuple index 0 will be unpacked into a
-    number of simplified operators.
+def DecomposeCompositeOps() -> tvm.ir.transform.Pass:
+    """Decompose composite operators that are composed of other operators
+    during inference. For example, the result of a batch norm which is indexed
+    at tuple index 0 will be unpacked into a number of simplified operators.
+    Attention, tensor_to_shape, etc. can also be decomposed into a number of
+    simplified operators.
 
     Returns
     -------
@@ -599,7 +600,7 @@ def SimplifyNormInference() -> tvm.ir.transform.Pass:
         The registered pass
     """
 
-    return _ffi_api.SimplifyNormInference()  # type: ignore
+    return _ffi_api.DecomposeCompositeOps()  # type: ignore
 
 
 def AlterOpImpl(
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 32d6083e8a..ae0918a082 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -97,6 +97,7 @@ from tvm.relax.op import (
     prod,
     repeat,
     reshape,
+    tensor_to_shape,
     round,
     shape_of,
     std,
@@ -612,6 +613,7 @@ __all__ = [
     "prod",
     "repeat",
     "reshape",
+    "tensor_to_shape",
     "round",
     "shape",
     "shape_of",
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc 
b/src/relax/backend/vm/vm_builtin_lower.cc
index 00d8512dc6..5bf4194997 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -129,9 +129,23 @@ class VMBuiltinLowerMutator : public ExprMutator {
   Expr Reshape(const Call& call_node) {
     ICHECK(call_node->args.size() == 2);
     ICHECK(call_node->struct_info_.defined());
-    CHECK(call_node->args[1]->IsInstance<ShapeExprNode>())
-        << "VMBuiltinLower expects the shape arg of reshape op to be a 
ShapeExpr";
-    return Call(builtin_reshape_, call_node->args, Attrs(), 
{GetStructInfo(call_node)});
+    auto arg = call_node->args[1];
+    CHECK(arg->IsInstance<ShapeExprNode>() || arg->IsInstance<VarNode>())
+        << "VMBuiltinLower expects the shape arg of reshape op to be a 
ShapeExpr or VarNode bound "
+           "to a ShapeExpr";
+
+    if (arg->IsInstance<ShapeExprNode>()) {
+      return Call(builtin_reshape_, call_node->args, Attrs(), 
{GetStructInfo(call_node)});
+    } else {
+      // Handling the case when arg is VarNode
+      Optional<Expr> _bound_val = LookupBinding(Downcast<Var>(arg));
+      ICHECK(_bound_val.defined());
+      Expr bound_val = _bound_val.value();
+      CHECK(bound_val->IsInstance<ShapeExprNode>())
+          << "VMBuiltinLower expects bound value to be a ShapeExpr";
+      return Call(builtin_reshape_, {call_node->args[0], bound_val}, Attrs(),
+                  {GetStructInfo(call_node)});
+    }
   }
 
   Expr ShapeOf(const Call& call_node) {
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index a603040394..49df881dcb 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -315,6 +315,32 @@ Expr MakeShapeOf(Expr expr) {
 
 TVM_REGISTER_GLOBAL("relax.op.shape_of").set_body_typed(MakeShapeOf);
 
+// tensor_to_shape
+
+StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& 
ctx) {
+  ICHECK(call->args.size() == 1);
+  ICHECK(call->args[0]->struct_info_.defined());
+  const auto* tsinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tsinfo && tsinfo->shape.defined());
+  ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
+  ICHECK(shape_expr->values.size() == 1);
+  const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
+  ICHECK(ndim);
+  return ShapeStructInfo(ndim->value);
+}
+
+RELAY_REGISTER_OP("relax.tensor_to_shape")
+    .set_num_inputs(1)
+    .add_argument("input", "Expr", "The input expression")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
ReturnTensorToShapeStructInfo);
+
+Expr MakeTensorToShape(Expr expr) {
+  static const Op& op = Op::Get("relax.tensor_to_shape");
+  return Call(op, {expr}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape);
+
 // alloc_tensor
 
 StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& 
ctx) {
diff --git a/src/relax/op/tensor/manipulate.cc 
b/src/relax/op/tensor/manipulate.cc
index dbeb6f8d5b..faa5ee3bc0 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -751,7 +751,12 @@ StructInfo InferStructInfoReshape(const Call& call, const 
BlockBuilder& ctx) {
                        << new_shape_prod);
     }
   }
-  return TensorStructInfo(call->args[1], data_sinfo->dtype);
+  Expr target_shape = call->args[1];
+  // If shape values are defined, use them
+  if (target_shape->IsInstance<VarNode>() && 
new_shape_sinfo->values.defined()) {
+    return TensorStructInfo(ShapeExpr(new_shape_sinfo->values.value()), 
data_sinfo->dtype);
+  }
+  return TensorStructInfo(target_shape, data_sinfo->dtype);
 }
 
 TVM_REGISTER_OP("relax.reshape")
diff --git a/src/relax/transform/simplify_norm_inference.cc 
b/src/relax/transform/decompose_composite_ops.cc
similarity index 69%
rename from src/relax/transform/simplify_norm_inference.cc
rename to src/relax/transform/decompose_composite_ops.cc
index 545098db28..3681442221 100644
--- a/src/relax/transform/simplify_norm_inference.cc
+++ b/src/relax/transform/decompose_composite_ops.cc
@@ -21,6 +21,7 @@
 
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/attrs/nn.h>
+#include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
 
 #include "utils.h"
@@ -110,21 +111,63 @@ class NormInferenceSimplifier : public ExprMutator {
   Map<Expr, Expr> batch_norm_map_;
 };
 
+class OpDecomposer : public ExprMutator {
+ public:
+  static Expr Decompose(Expr expr) { return OpDecomposer()(expr); }
+
+ private:
+  using ExprMutator::VisitExpr_;
+  Expr TensorToShape(const Call& call_node) {
+    ICHECK(call_node->struct_info_.defined());
+    Expr expr = call_node->args[0];
+    const ShapeStructInfoNode* sinfo = 
GetStructInfoAs<ShapeStructInfoNode>(call_node);
+    ICHECK(sinfo);
+    // call builtin function that converts tensor to shape tuple
+    // TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape"
+    Var call = builder_->Emit(Call(ExternFunc("vm.builtin.tensor_to_shape"), 
{expr}, {},
+                                   {GetRef<ShapeStructInfo>(sinfo)}));
+
+    // Operators like reshape take the output of `TensorToShape` as their 
output shape.
+    // Because TOPI expects to have such output shape in symbolic shape at 
least (i.e.,
+    // Array<PrimExpr>), we define symbolic variables and return them as a
+    // ShapeExpr.
+    Array<PrimExpr> shape_var;
+    for (int i = 0; i < sinfo->ndim; i++) {
+      shape_var.push_back(tir::Var("x", DataType::Int(64)));
+    }
+    // bind symbolic variables to the shape tuple
+    relax::Var var("y", ShapeStructInfo(shape_var));
+    builder_->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var)));
+    return ShapeExpr(shape_var);
+  }
+
+  Expr VisitExpr_(const CallNode* call_node) final {
+    Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
+    if (call->op == tensor_to_shape_op_) {
+      return TensorToShape(call);
+    } else {
+      return call;
+    }
+  }
+
+  const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
+};
+
 namespace transform {
-Pass SimplifyNormInference() {
+Pass DecomposeCompositeOps() {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> 
pass_func =
       [=](Function f, IRModule m, PassContext pc) {
         f = Downcast<Function>(NormInferenceSimplifier::Simplify(f));
-        // Remove original batch_norm op if it's not used.
+        f = Downcast<Function>(OpDecomposer::Decompose(f));
+        // Remove original ops if they are not used.
         return RemoveAllUnused(f);
       };
   return CreateFunctionPass(/*pass_function=*/pass_func,            //
                             /*opt_level=*/0,                        //
-                            /*pass_name=*/"SimplifyNormInference",  //
+                            /*pass_name=*/"DecomposeCompositeOps",  //
                             /*required=*/{});
 }
 
-TVM_REGISTER_GLOBAL("relax.transform.SimplifyNormInference").set_body_typed(SimplifyNormInference);
+TVM_REGISTER_GLOBAL("relax.transform.DecomposeCompositeOps").set_body_typed(DecomposeCompositeOps);
 
 }  // namespace transform
 }  // namespace relax
diff --git a/src/relax/transform/fold_constant.cc 
b/src/relax/transform/fold_constant.cc
index 622dd9ad09..315b3dc1f2 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -196,6 +196,10 @@ class ConstantFolder : public ExprMutator {
 
   using ExprMutator::VisitExpr_;
 
+  // TODO(@sunggg):
+  // Next PR will support fold with PackedFunc and MatchCast
+  // Until then, DecomposeCompositeOps() should be applied after
+  // this pass to fold `tensor_to_shape` op.
   Expr VisitExpr_(const CallNode* call) final {
     // post-order mutation
     Call post_call = Downcast<Call>(VisitExprPostOrder_(call));
@@ -217,14 +221,64 @@ class ConstantFolder : public ExprMutator {
       return VisitCallTIR(post_call).value_or(post_call);
     }
 
-    // If we are in a dataflow block, we can fold ops by lowering them to 
call_tir.
-    if (builder_->CurrentBlockIsDataFlow() && legalize_map.count(op)) {
-      // Get the legalized expression
-      Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, 
post_call));
-      // If the legalized expression is call_tir, try to fold it.
-      const CallNode* call = legalized_expr.as<CallNode>();
-      if (call && call->op.same_as(call_tir_op)) {
-        return VisitCallTIR(GetRef<Call>(call)).value_or(post_call);
+    // Special logic to fold ShapeExpr between operators
+    // e.g.,
+    //  <Before>
+    //     lv: R.Shape([16, 16]) = R.shape([16, 16])
+    //     gv: R.Tensor(lv, dtype="float32") = R.reshape(data, lv)
+    //  <After>
+    //     gv: R.Tensor((16, 16), dtype="float32") = R.reshape(data, R.shape([16, 16]))
+    //
+    Array<Expr> new_args;
+    for (auto arg : post_call->args) {
+      if (arg->IsInstance<VarNode>()) {
+        Optional<Expr> val = LookupBinding(Downcast<Var>(arg));
+        if (val.defined() && val.value()->IsInstance<ShapeExprNode>()) {
+          new_args.push_back(val.value());
+          continue;
+        }
+      }
+      new_args.push_back(arg);
+    }
+    post_call =
+        Call(post_call->op, new_args, post_call->attrs, post_call->sinfo_args, 
post_call->span);
+
+    // If we are in a dataflow block, we can fold ops.
+    if (builder_->CurrentBlockIsDataFlow()) {
+      // Check if we can legalize them to call_tir
+      if (legalize_map.count(op)) {
+        // Get the legalized expression
+        Expr legalized_expr = builder_->Normalize(legalize_map[op](builder_, 
post_call));
+        // If the legalized expression is call_tir, try to fold it.
+        const CallNode* call = legalized_expr.as<CallNode>();
+        if (call && call->op.same_as(call_tir_op)) {
+          return VisitCallTIR(GetRef<Call>(call)).value_or(post_call);
+        }
+      } else if (op->name == "relax.tensor_to_shape") {
+        // Special handling for composite op "relax.tensor_to_shape"
+        // If its input is constant, we can access its value and create 
ShapeExpr
+        // TODO(@sunggg):
+        //   currently, we do not have an info map about decomposition.
+        //   Thus, this is a temporary solution until we have a consensus about
+        //   how to deal with composite ops. One possibility is we register the
+        //   decomposition map for each op in a similar way we do for 
legalization.
+        ICHECK_EQ(post_call->args.size(), 1);
+        Expr arg = post_call->args[0];
+        if (arg->IsInstance<ConstantNode>()) {
+          Constant constant = Downcast<Constant>(arg);
+          runtime::NDArray ndarray = constant->data;
+          ICHECK_EQ(ndarray->device.device_type, kDLCPU);
+          ICHECK(ndarray->strides == nullptr);
+          ICHECK_EQ(ndarray->byte_offset, 0);
+          ICHECK_EQ(ndarray->ndim, 1);
+          const int64_t* data = static_cast<const int64_t*>(ndarray->data);
+          int64_t num_elems = ndarray->shape[0];
+          Array<PrimExpr> shape_values;
+          for (int64_t i = 0; i < num_elems; i++) {
+            shape_values.push_back(IntImm(DataType::Int(64), data[i]));
+          }
+          return ShapeExpr(shape_values);
+        }
       }
     }
 
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 15a4f8702b..5a7c1d6620 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -380,6 +380,40 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args, TVMRetVal
   *rv = arr;
 });
 
+TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray 
data) {
+  NDArray arr = data;
+  if (data->device.device_type != kDLCPU) {
+    arr = data.CopyTo(DLDevice{kDLCPU, 0});
+  }
+
+  ICHECK_EQ(arr->ndim, 1);
+  ICHECK_EQ(arr->dtype.code, kDLInt);
+
+  std::vector<int64_t> out_shape;
+  for (int i = 0; i < arr.Shape()[0]; ++i) {
+    int64_t result;
+    switch (arr->dtype.bits) {
+      case 16: {
+        result = reinterpret_cast<int16_t*>(arr->data)[i];
+        break;
+      }
+      case 32: {
+        result = reinterpret_cast<int32_t*>(arr->data)[i];
+        break;
+      }
+      case 64: {
+        result = reinterpret_cast<int64_t*>(arr->data)[i];
+        break;
+      }
+      default:
+        LOG(FATAL) << "Unknown scalar int type: " << 
DLDataType2String(arr->dtype);
+        throw;
+    }
+    out_shape.push_back(result);
+  }
+  return ShapeTuple(out_shape);
+});
+
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/relax/test_op_manipulate.py 
b/tests/python/relax/test_op_manipulate.py
index 16bbc04d26..3edf63764a 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -88,12 +88,13 @@ def test_reshape_infer_struct_info():
     _check_inference(
         bb, relax.op.reshape(x5, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), 
dtype="")
     )
-    _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo(s0, 
"float32"))
-    _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo(s0, 
"float32"))
-    _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo(s0, 
"float32"))
-    _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo(s0, 
dtype=""))
-    _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo(s0, 
dtype=""))
-    _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo(s0, 
dtype=""))
+    # Remove Var from StructInfo when we can
+    _check_inference(bb, relax.op.reshape(x0, s0), relax.TensorStructInfo((3, 
8, 5), "float32"))
+    _check_inference(bb, relax.op.reshape(x1, s0), relax.TensorStructInfo((3, 
8, 5), "float32"))
+    _check_inference(bb, relax.op.reshape(x2, s0), relax.TensorStructInfo((3, 
8, 5), "float32"))
+    _check_inference(bb, relax.op.reshape(x3, s0), relax.TensorStructInfo((3, 
8, 5), dtype=""))
+    _check_inference(bb, relax.op.reshape(x4, s0), relax.TensorStructInfo((3, 
8, 5), dtype=""))
+    _check_inference(bb, relax.op.reshape(x5, s0), relax.TensorStructInfo((3, 
8, 5), dtype=""))
     _check_inference(bb, relax.op.reshape(x0, s1), relax.TensorStructInfo(s1, 
"float32"))
     _check_inference(bb, relax.op.reshape(x1, s1), relax.TensorStructInfo(s1, 
"float32"))
     _check_inference(bb, relax.op.reshape(x2, s1), relax.TensorStructInfo(s1, 
"float32"))
@@ -160,7 +161,8 @@ def test_reshape_infer_struct_info_shape_symbolic():
             (c, a * b * d, tir.floordiv(a * b * c * d, c * (a * b * d))), 
"float32"
         ),
     )
-    _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo(s0, 
"float32"))
+    # Remove Var from StructInfo when we can
+    _check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo((c, 
a, d, b), "float32"))
     _check_inference(bb, relax.op.reshape(x, s1), relax.TensorStructInfo(s1, 
"float32"))
     _check_inference(bb, relax.op.reshape(x, s2), relax.TensorStructInfo(s2, 
"float32"))
 
@@ -188,17 +190,20 @@ def test_reshape_infer_struct_info_shape_var():
     _check_inference(
         bb, relax.op.reshape(x0, (3, -1, 5)), relax.TensorStructInfo((3, 8, 
5), "float32")
     )
-    _check_inference(bb, relax.op.reshape(x0, ns0), 
relax.TensorStructInfo(ns0, "float32"))
+    # Remove Var from StructInfo when we can
+    _check_inference(bb, relax.op.reshape(x0, ns0), relax.TensorStructInfo((3, 
8, 5), "float32"))
     _check_inference(bb, relax.op.reshape(x0, ns1), 
relax.TensorStructInfo(ns1, "float32"))
     _check_inference(
         bb, relax.op.reshape(x1, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), 
"float32")
     )
-    _check_inference(bb, relax.op.reshape(x1, ns0), 
relax.TensorStructInfo(ns0, "float32"))
+    # Remove Var from StructInfo when we can
+    _check_inference(bb, relax.op.reshape(x1, ns0), relax.TensorStructInfo((3, 
8, 5), "float32"))
     _check_inference(bb, relax.op.reshape(x1, ns1), 
relax.TensorStructInfo(ns1, "float32"))
     _check_inference(
         bb, relax.op.reshape(x2, (3, 8, 5)), relax.TensorStructInfo((3, 8, 5), 
"float32")
     )
-    _check_inference(bb, relax.op.reshape(x2, ns0), 
relax.TensorStructInfo(ns0, "float32"))
+    # Remove Var from StructInfo when we can
+    _check_inference(bb, relax.op.reshape(x2, ns0), relax.TensorStructInfo((3, 
8, 5), "float32"))
     _check_inference(bb, relax.op.reshape(x2, ns1), 
relax.TensorStructInfo(ns1, "float32"))
 
 
diff --git a/tests/python/relax/test_transform_simpilify_norm_inference.py 
b/tests/python/relax/test_transform_decompose_composite_ops.py
similarity index 86%
rename from tests/python/relax/test_transform_simpilify_norm_inference.py
rename to tests/python/relax/test_transform_decompose_composite_ops.py
index 3c981ba035..08483600a3 100644
--- a/tests/python/relax/test_transform_simpilify_norm_inference.py
+++ b/tests/python/relax/test_transform_decompose_composite_ops.py
@@ -22,7 +22,7 @@ import tvm.script
 import tvm.testing
 from tvm import IRModule, relax
 from tvm.relax import Function
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T
 
 
 def _check(before: Union[Function, IRModule], expected: Union[Function, 
IRModule]):
@@ -30,7 +30,7 @@ def _check(before: Union[Function, IRModule], expected: 
Union[Function, IRModule
         before = IRModule({"main": before})
     if isinstance(expected, Function):
         expected = IRModule({"main": expected})
-    after = relax.transform.SimplifyNormInference()(before)
+    after = relax.transform.DecomposeCompositeOps()(before)
     tvm.ir.assert_structural_equal(expected, after)
 
 
@@ -149,5 +149,26 @@ def test_batch_norm_complex():
     _check(before, expected)
 
 
+def test_op_tensor_to_shape():
+    @R.function
+    def before(t: R.Tensor(ndim=1, dtype="int64")):
+        gv: R.Shape(ndim=3) = R.tensor_to_shape(t)
+        return gv
+
+    @R.function
+    def expected(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3):
+        x = T.int64()
+        x_1 = T.int64()
+        x_2 = T.int64()
+        gv: R.Shape(ndim=3) = R.call_packed(
+            "vm.builtin.tensor_to_shape", t, sinfo_args=(R.Shape(ndim=3),)
+        )
+        y: R.Shape([x, x_1, x_2]) = R.match_cast(gv, R.Shape([x, x_1, x_2]))
+        gv_1: R.Shape([x, x_1, x_2]) = R.shape([x, x_1, x_2])
+        return gv_1
+
+    _check(before, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fold_constant.py 
b/tests/python/relax/test_transform_fold_constant.py
index 5bf2d3d9ab..ebd4348b64 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -349,6 +349,43 @@ def test_do_not_fold_ops_outside_dataflow():
     tvm.ir.assert_structural_equal(after, before)
 
 
+def test_fold_multiple_relax_ops_with_data_dependent_reshape():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def before(
+            data: R.Tensor((256,), "float32"),
+            c0: R.Tensor((2,), "int64"),
+            c1: R.Tensor((2,), "int64"),
+        ):
+            with R.dataflow():
+                lv0 = R.add(c0, c0)
+                target_shape = R.multiply(lv0, c1)
+                lv2: R.Shape(ndim=2) = R.tensor_to_shape(target_shape)
+                gv: R.Tensor(ndim=2, dtype="float32") = R.reshape(data, lv2)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def expected(data: R.Tensor((256,), "float32")) -> R.Tensor((16, 16), 
dtype="float32"):
+            R.func_attr({"global_symbol": "main"})
+            with R.dataflow():
+                gv: R.Tensor((16, 16), dtype="float32") = R.reshape(data, 
R.shape([16, 16]))
+                R.output(gv)
+            return gv
+
+    c0_np = [8, 8]
+    c1_np = [1, 1]
+    before = gen_mod(Module, "before", {"c0": c0_np, "c1": c1_np})
+    assert relax.analysis.well_formed(before)
+
+    c2_np = np.multiply(np.add(c0_np, c0_np), c1_np)
+    expected = gen_mod(Module, "expected", {"c2": c2_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 def test_unsupported_fold_ops_legalized_to_multiple_calls():
     @tvm.script.ir_module
     class Module:
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py 
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index b50ba91089..cce35a9026 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -17,6 +17,7 @@
 
 import pytest
 import tvm
+from tvm import relax
 from tvm.relax.transform import LegalizeOps
 from tvm.script import relax as R, tir as T, ir as I
 import tvm.testing
@@ -498,6 +499,55 @@ def test_reshape():
     mod = LegalizeOps()(Reshape)
     tvm.ir.assert_structural_equal(mod, Expected)
 
+    # fmt: off
+    # ShapeExpr might be produced by shape computation
+    @tvm.script.ir_module
+    class Reshape2:
+        @R.function
+        def main(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((8, 3), 
"float32"):
+            lv: R.Shape((8, 3)) = R.shape((8, 3))
+            gv: R.Tensor((8, 3), "float32") = R.reshape(x, lv)
+            return gv
+
+    # After lowering, redundant var might be removed by later dead code 
elimination
+    @tvm.script.ir_module
+    class Expected2:
+        @T.prim_func
+        def reshape(
+            rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(3), 
T.int64(4)), "float32"),
+            T_reshape: T.Buffer((T.int64(8), T.int64(3)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(8), T.int64(3)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(
+                        rxplaceholder[
+                            T.int64(0),
+                            (v_ax0 * T.int64(3) + v_ax1) % T.int64(24) // 
T.int64(12),
+                            (v_ax0 * T.int64(3) + v_ax1) % T.int64(12) // 
T.int64(4),
+                            (v_ax0 * T.int64(3) + v_ax1) % T.int64(4),
+                        ]
+                    )
+                    T.writes(T_reshape[v_ax0, v_ax1])
+                    T_reshape[v_ax0, v_ax1] = rxplaceholder[
+                        T.int64(0),
+                        (v_ax0 * T.int64(3) + v_ax1) % T.int64(24) // 
T.int64(12),
+                        (v_ax0 * T.int64(3) + v_ax1) % T.int64(12) // 
T.int64(4),
+                        (v_ax0 * T.int64(3) + v_ax1) % T.int64(4),
+                    ]
+
+        @R.function
+        def main(x: R.Tensor((1, 2, 3, 4), dtype="float32")) -> R.Tensor((8, 
3), dtype="float32"):
+            lv: R.Shape((8, 3)) = R.shape((8, 3))
+            gv = R.call_tir(Expected2.reshape, (x,), out_sinfo=R.Tensor((8, 
3), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod2 = LegalizeOps()(Reshape2)
+    tvm.ir.assert_structural_equal(mod2, Expected2)
+
 
 def test_reshape_symbolic():
     # fmt: off
@@ -537,6 +587,151 @@ def test_reshape_symbolic():
     mod = LegalizeOps()(Reshape)
     tvm.ir.assert_structural_equal(mod, Expected)
 
+    # ShapeExpr might be produced by shape computation
+    @tvm.script.ir_module
+    class Reshape2:
+        @R.function
+        def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"):
+            a = T.int64()
+            b = T.int64()
+            lv: R.Shape((a // 2, b * 2)) = R.shape((a // 2, b * 2))
+            gv: R.Tensor((a // 2, b * 2), "float32") = R.reshape(x, lv)
+            return gv
+
+    # After lowering, redundant var might be removed by later dead code elimination
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(x: R.Tensor(("a", "b"), "float32")) -> R.Tensor(("a // 2", "b * 2"), "float32"):
+            a = T.int64()
+            b = T.int64()
+            lv: R.Shape((a // 2, b * 2)) = R.shape((a // 2, b * 2))
+            gv = R.call_tir(Expected2.reshape, (x,), R.Tensor(((a // 2), (b * 2)), dtype="float32"))
+            return gv
+
+        @T.prim_func
+        def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle):
+            T.func_attr({"tir.noalias": True})
+            a = T.int64()
+            b = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, [a, b], dtype="float32")
+            T_reshape = T.match_buffer(
+                var_T_reshape, [a // T.int64(2), b * T.int64(2)], dtype="float32"
+            )
+            for i0, i1 in T.grid(a // T.int64(2), b * T.int64(2)):
+                with T.block("T_reshape"):
+                    ax0, ax1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(
+                        rxplaceholder[
+                            (ax0 * (b * T.int64(2)) + ax1) // b % a,
+                            (ax0 * (b * T.int64(2)) + ax1) % b,
+                        ]
+                    )
+                    T.writes(T_reshape[ax0, ax1])
+                    T_reshape[ax0, ax1] = rxplaceholder[
+                        (ax0 * (b * T.int64(2)) + ax1) // b % a, (ax0 * (b * T.int64(2)) + ax1) % b
+                    ]
+
+    mod2 = LegalizeOps()(Reshape2)
+    tvm.ir.assert_structural_equal(mod2, Expected2)
+
+    # ShapeExpr might be produced by shape computation
+    @I.ir_module
+    class Reshape3:
+        @R.function
+        def main(x: R.Tensor((10, "b"), "float32")) -> R.Tensor((5, "b * 2"), "float32"):
+            a = T.int64()
+            b = T.int64()
+            lv: R.Shape((5, b * 2)) = R.shape((5, b * 2))
+            gv: R.Tensor((5, b * 2), "float32") = R.reshape(x, lv)
+            return gv
+
+    # After lowering, redundant var might be removed by later dead code elimination
+    @I.ir_module
+    class Expected3:
+        @T.prim_func
+        def reshape(var_rxplaceholder: T.handle, var_T_reshape: T.handle):
+            T.func_attr({"tir.noalias": True})
+            b = T.int64()
+            rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(10), b))
+            T_reshape = T.match_buffer(var_T_reshape, (T.int64(5), b * T.int64(2)))
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(5), b * T.int64(2)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(
+                        rxplaceholder[
+                            (v_ax0 * (b * T.int64(2)) + v_ax1) // b % T.int64(10),
+                            (v_ax0 * (b * T.int64(2)) + v_ax1) % b,
+                        ]
+                    )
+                    T.writes(T_reshape[v_ax0, v_ax1])
+                    T_reshape[v_ax0, v_ax1] = rxplaceholder[
+                        (v_ax0 * (b * T.int64(2)) + v_ax1) // b % T.int64(10),
+                        (v_ax0 * (b * T.int64(2)) + v_ax1) % b,
+                    ]
+
+        @R.function
+        def main(
+            x: R.Tensor((10, "b"), dtype="float32")
+        ) -> R.Tensor((5, "b * 2"), dtype="float32"):
+            b = T.int64()
+            lv: R.Shape([5, b * 2]) = R.shape([5, b * 2])
+            gv = R.call_tir(
+                Expected3.reshape, (x,), out_sinfo=R.Tensor((5, b * 2), dtype="float32")
+            )
+            return gv
+
+    mod3 = LegalizeOps()(Reshape3)
+    tvm.ir.assert_structural_equal(mod3, Expected3)
+
+
+def test_data_dependent_reshape():
+    # fmt: off
+    @tvm.script.ir_module
+    class DDReshape:
+        @R.function
+        def main(x: R.Tensor((3, ), dtype="int64")):
+            lv: R.Shape([3,]) = R.tensor_to_shape(x)
+            gv = R.reshape(x, lv)
+            return gv
+    # fmt: on
+
+    assert relax.analysis.well_formed(DDReshape)
+    mod = relax.transform.DecomposeCompositeOps()(DDReshape)
+    out_mod = relax.transform.LegalizeOps()(mod)
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def reshape(
+            rxplaceholder: T.Buffer((T.int64(3),), "int64"), var_T_reshape: T.handle
+        ):
+            T.func_attr({"tir.noalias": True})
+            x = T.int64()
+            T_reshape = T.match_buffer(var_T_reshape, (x,), "int64")
+            # with T.block("root"):
+            for ax0 in range(x):
+                with T.block("T_reshape"):
+                    v_ax0 = T.axis.spatial(x, ax0)
+                    T.reads(rxplaceholder[v_ax0 % T.int64(3)])
+                    T.writes(T_reshape[v_ax0])
+                    T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)]
+
+        @R.function
+        def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor((3,), dtype="int64"):
+            x_1 = T.int64()
+            gv: R.Shape([3]) = R.call_packed(
+                "vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),)
+            )
+            y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1]))
+            lv: R.Shape([x_1]) = R.shape([x_1])
+            gv_1 = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((x_1,), dtype="int64"))
+            return gv_1
+    # fmt: on
+    tvm.ir.assert_structural_equal(out_mod, Expected)
+
 
 def test_split_by_indices():
     # fmt: off


Reply via email to