This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3635945  Refactor Dynamic to Static (#7368)
3635945 is described below

commit 3635945e48d9d1a7d8e76df418f057a4a3b88dc4
Author: Matthew Brookhart <[email protected]>
AuthorDate: Mon Feb 1 16:55:23 2021 -0700

    Refactor Dynamic to Static (#7368)
    
    * DynamicToStatic Refactor
    
    * fix test
    
    * add regression tests
    
    * cleanup
    
    * skip PrepareInput if the arg is already a constant
    
    * fix an issue with type inference with global functions
---
 src/relay/transforms/dynamic_to_static.cc         | 155 ++++++++++++++--------
 tests/python/relay/test_pass_dynamic_to_static.py |  44 +++++-
 2 files changed, 138 insertions(+), 61 deletions(-)

diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index c580f60..815e4d2 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -34,27 +34,30 @@ namespace relay {
 
 class DynamicToStaticMutator : public MixedModeMutator {
  public:
-  DynamicToStaticMutator() {
+  DynamicToStaticMutator(IRModule mod, Function func) : mod_(mod), func_(func) {
     op_map_ = {
         {Op::Get("dyn.reshape"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
              ICHECK_EQ(shape->data->ndim, 1);
              return MakeReshape(call_node->args[0], ToVector(shape->data));
            }
            return Expr(nullptr);
          }},
         {Op::Get("dyn.tile"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* reps = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* reps = args[1].as<ConstantNode>()) {
              ICHECK_EQ(reps->data->ndim, 1);
              return MakeTile(call_node->args[0], ToVector(reps->data));
            }
            return Expr(nullptr);
          }},
         {Op::Get("dyn.topk"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* k = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* k = args[1].as<ConstantNode>()) {
              const TopKAttrs* param = call_node->attrs.as<TopKAttrs>();
              ICHECK(param);
              return MakeTopK(call_node->args[0], static_cast<int>(ToScalar(k->data, 0)),
@@ -63,16 +66,18 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
         {Op::Get("dyn.broadcast_to"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
              ICHECK_EQ(shape->data->ndim, 1);
              return MakeBroadCastTo(call_node->args[0], ToVector(shape->data));
            }
            return Expr(nullptr);
          }},
         {Op::Get("dyn.zeros"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
              const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
              ICHECK(param);
              return MakeZeros(ToVector(shape->data), param->dtype);
@@ -80,8 +85,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
         {Op::Get("dyn.ones"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* shape = call_node->args[0].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* shape = args[0].as<ConstantNode>()) {
              const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
              ICHECK(param);
              return MakeOnes(ToVector(shape->data), param->dtype);
@@ -89,8 +95,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
         {Op::Get("dyn.one_hot"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* depth = call_node->args[3].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* depth = args[3].as<ConstantNode>()) {
              const OneHotAttrs* param = call_node->attrs.as<OneHotAttrs>();
              ICHECK(param);
              return MakeOneHot(call_node->args[0], call_node->args[1], call_node->args[2],
@@ -100,8 +107,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
         {Op::Get("dyn.image.resize"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* size = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* size = args[1].as<ConstantNode>()) {
              const ResizeAttrs* param = call_node->attrs.as<ResizeAttrs>();
              ICHECK(param);
              auto size_int = ToVector(size->data);
@@ -115,8 +123,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
         {Op::Get("dyn.full"),
-         [](const CallNode* call_node) {
-           if (const ConstantNode* shape = call_node->args[1].as<ConstantNode>()) {
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           if (const ConstantNode* shape = args[1].as<ConstantNode>()) {
              ICHECK_EQ(shape->data->ndim, 1);
              const InitOpAttrs* param = call_node->attrs.as<InitOpAttrs>();
              ICHECK(param);
@@ -125,9 +134,10 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
         {Op::Get("dyn.nn.upsampling"),
-         [](const CallNode* call_node) {
-           const ConstantNode* scale_h = call_node->args[1].as<ConstantNode>();
-           const ConstantNode* scale_w = call_node->args[2].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           const ConstantNode* scale_h = args[1].as<ConstantNode>();
+           const ConstantNode* scale_w = args[2].as<ConstantNode>();
            if (scale_h && scale_w) {
              ICHECK_EQ(scale_h->data->ndim, 0);
              ICHECK_EQ(scale_w->data->ndim, 0);
@@ -140,10 +150,11 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
         {Op::Get("dyn.nn.upsampling3d"),
-         [](const CallNode* call_node) {
-           const ConstantNode* scale_d = call_node->args[1].as<ConstantNode>();
-           const ConstantNode* scale_h = call_node->args[2].as<ConstantNode>();
-           const ConstantNode* scale_w = call_node->args[3].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           const ConstantNode* scale_d = args[1].as<ConstantNode>();
+           const ConstantNode* scale_h = args[2].as<ConstantNode>();
+           const ConstantNode* scale_w = args[3].as<ConstantNode>();
            if (scale_d && scale_h && scale_w) {
              ICHECK_EQ(scale_d->data->ndim, 0);
              ICHECK_EQ(scale_h->data->ndim, 0);
@@ -159,9 +170,10 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
         {Op::Get("dyn.nn.pad"),
-         [](const CallNode* call_node) {
-           const ConstantNode* pad_width = call_node->args[1].as<ConstantNode>();
-           const ConstantNode* pad_fill = call_node->args[2].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           const ConstantNode* pad_width = args[1].as<ConstantNode>();
+           const ConstantNode* pad_fill = args[2].as<ConstantNode>();
            if (pad_width && pad_fill) {
              ICHECK_EQ(pad_fill->data->ndim, 0);   // pad_val is 1d
              ICHECK_EQ(pad_width->data->ndim, 2);  // pad_width is 2d
@@ -174,10 +186,11 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
         {Op::Get("dyn.strided_slice"),
-         [](const CallNode* call_node) {
-           const ConstantNode* begin = call_node->args[1].as<ConstantNode>();
-           const ConstantNode* end = call_node->args[2].as<ConstantNode>();
-           const ConstantNode* stride = call_node->args[3].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           const ConstantNode* begin = args[1].as<ConstantNode>();
+           const ConstantNode* end = args[2].as<ConstantNode>();
+           const ConstantNode* stride = args[3].as<ConstantNode>();
            if (begin && end && stride) {
              ICHECK_EQ(begin->data->ndim, 1);
              ICHECK_EQ(end->data->ndim, 1);
@@ -190,8 +203,9 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
         {Op::Get("dyn.sparse_to_dense"),
-         [](const CallNode* call_node) {
-           const ConstantNode* output_shape = call_node->args[3].as<ConstantNode>();
+         [this](const CallNode* call_node) {
+           auto args = PrepareArgs(call_node);
+           const ConstantNode* output_shape = args[3].as<ConstantNode>();
            if (output_shape) {
              ICHECK_EQ(output_shape->data->ndim, 1);
              return MakeSparseToDense(call_node->args[0], ToVector(output_shape->data),
@@ -200,6 +214,45 @@ class DynamicToStaticMutator : public MixedModeMutator {
            return Expr(nullptr);
          }},
     };
+    Map<BaseFunc, GlobalVar> vars;
+    for (auto kv : mod_->functions) {
+      vars.Set(kv.second, kv.first);
+    }
+    gv_ = vars[func_];
+  }
+
+  Expr PrepareInput(const Expr& expr) {
+    BaseFunc func;
+    if (auto* func_node = expr.as<BaseFuncNode>()) {
+      func = GetRef<BaseFunc>(func_node);
+    } else {
+      func =
+          relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {});
+    }
+    mod_->Update(gv_, func);
+    mod_ = transform::FoldConstant()(mod_);
+    mod_ = transform::InferType()(mod_);
+    mod_ = transform::FoldConstant()(mod_);
+    mod_ = transform::InferType()(mod_);
+    Expr out;
+    if (expr.as<FunctionNode>()) {
+      out = mod_->Lookup(gv_);
+    } else {
+      out = mod_->Lookup(gv_).as<FunctionNode>()->body;
+    }
+    return out;
+  }
+
+  std::vector<Expr> PrepareArgs(const CallNode* call_node) {
+    std::vector<Expr> args;
+    for (auto arg : call_node->args) {
+      if (arg.as<ConstantNode>()) {
+        args.emplace_back(arg);
+      } else {
+        args.emplace_back(PrepareInput(arg));
+      }
+    }
+    return args;
   }
 
  private:
@@ -222,35 +275,19 @@ class DynamicToStaticMutator : public MixedModeMutator {
     }
     return post;
   }
+
   std::unordered_map<Expr, std::function<Expr(const CallNode*)>, ObjectPtrHash, ObjectPtrEqual>
       op_map_;
+  IRModule mod_;
+  Function func_;
+  GlobalVar gv_;
 };
 
 Expr DynamicToStatic(Function f, IRModule m) {
-  Expr pre = f;
-  Expr expr = f;
-  auto fold_const = transform::FoldConstant();
-  auto infer_type = transform::InferType();
-  DynamicToStaticMutator mutator;
-  Map<BaseFunc, GlobalVar> vars;
-  for (auto kv : m->functions) {
-    vars.Set(kv.second, kv.first);
-  }
-  const auto gv = vars[f];
-  // Put a limit on the while loop
-  // Primarily used to prevent accidental infinite lops in development
-  const int loop_limit = 1000;
-  int i = 0;
-  do {
-    pre = expr;
-    // TODO(mbrookhart): Is it possible to run these passes JUST on the current function?
-    m = infer_type(m);
-    m = fold_const(m);
-    expr = mutator.Mutate(m->functions[gv]);
-    m->Update(gv, Downcast<BaseFunc>(expr));
-    i += 1;
-  } while (!StructuralEqual()(pre, expr) && i < loop_limit);
-  return expr;
+  DynamicToStaticMutator mutator(m, f);
+  Expr expr = mutator.Mutate(f);
+  Expr out = mutator.PrepareInput(expr);
+  return out;
 }
 
 namespace transform {
diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py
index 141023d..c9e047a 100644
--- a/tests/python/relay/test_pass_dynamic_to_static.py
+++ b/tests/python/relay/test_pass_dynamic_to_static.py
@@ -232,11 +232,11 @@ def test_dynamic_to_static_zeros_ones():
 
             func = run_infer_type(relay.Function([x], y))
             func2 = run_opt_pass(
-                run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()
+                run_opt_pass(func, transform.DynamicToStatic()),
+                transform.InferType(),
             )
 
             zz = func2.body
-            assert isinstance(zz, relay.Constant)
             assert zz.checked_type == relay.ty.TensorType(shape, dtype)
 
             x_data = np.random.uniform(low=1, high=1, size=shape)
@@ -518,5 +518,45 @@ def test_dyn_to_static_sparse_to_dense():
     verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0])  # default value not specified
 
 
+@tvm.testing.uses_gpu
+def test_dynamic_to_static_dynamic_rank():
+    def verify_full(fill_value, fill_shape, dtype):
+        x = relay.var("x", relay.scalar_type(dtype))
+        y = relay.var("y", relay.TensorType(fill_shape, "int64"))
+        shape = relay.shape_of(y)
+        shape = relay.strided_slice(shape, [0], relay.shape_of(shape))
+        z = relay.full(x, shape, dtype)
+
+        func = relay.Function([x, y], z)
+        func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+        zz = func2.body
+        assert isinstance(zz, relay.Call)
+        assert zz.op == relay.op.get("full")
+
+        ref_res = np.full(fill_shape, fill_value).astype(dtype)
+        y_data = np.random.uniform(low=-1, high=1, size=fill_shape).astype("int64")
+        verify_func(func2, [fill_value, y_data], ref_res)
+
+    verify_full(4, (1, 2, 3, 4), "int32")
+    verify_full(4.0, (1, 2, 8, 10), "float32")
+
+
+@tvm.testing.uses_gpu
+def test_dynamic_to_static_dynamic_if():
+    x = relay.var("x", relay.TensorType((2, 2), "int64"))
+    cond = relay.const(1)
+    iff = relay.If(cond, relay.reshape(x, [1, 4]), relay.reshape(x, (4, 1)))
+
+    func = relay.Function([x], iff)
+    func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType())
+
+    zz = func2.body
+    assert isinstance(zz, relay.Call)
+    assert zz.op == relay.op.get("reshape")
+    x_data = np.random.uniform(low=-1, high=1, size=(2, 2)).astype("int64")
+    verify_func(func2, [x_data], x_data.reshape(1, 4))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to