This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 2658ebe  Dynamic ONNX Importer (#6351)
2658ebe is described below

commit 2658ebe737d38b441dee6121c01ba3f9f83ce518
Author: Matthew Brookhart <[email protected]>
AuthorDate: Fri Oct 2 20:32:39 2020 -0600

    Dynamic ONNX Importer (#6351)
    
    * Change onnx importer to use dynamic upsampling3d (#3)
    
    fix pylint
    
    * Refactor ONNX frontend to be dynamic
    
    Make OneHot dynamic
    
    Support BatchMatMul with dynamically shaped inputs
    
    fix dynamic broadcast
    
    Add null checks to broadcast_to rel functions
    
    fail more isolated broadcast_to test
    
    use StructuralEqual instead of pointer comparisons in dynamic_to_static pass
    
    add an optional weight freeze argument to onnx importer
    
    convert onnx resize to dynamic op
    
    add dynamic expand to onnx importer
    
    add a shape_func for power
    
    fix BERTSquad, lint
    
    handle onnx graph initializer parameters more intelligently
    
    * Dynamic ONNX importer: Upsampling and Pad (#2)
    
    fix lint
    
    fix Call reference
    
    fix a type issue with expand
    
    fix a bad test refactor
    
    respond to review comments, fix batch matmul tests
    
    * black format
    
    * fix batch matmul test
    
    * add dynamic strided slice to the onnx importer
    
    * fix clip importer
    
    * fix qnn tutorial
    
    * fix bad merge, respond to review comments
    
    * add a simple dynamic model test
    
    * Add dynamic-shaped autopadding to convolution and pooling ops
    
    * fix dynamic issues in a few ops
    
    * fix pylint
    
    * disable tests onnxrt doesn't support
    
    * fix pytorch test
    
    * respond to review comments
    
    * add documentation about partially supporting dynamic shapes
    
    Co-authored-by: Lily Orth-Smith <[email protected]>
---
 include/tvm/relay/transform.h                     |  11 +
 include/tvm/topi/broadcast.h                      |  11 +-
 python/tvm/relay/frontend/onnx.py                 | 607 +++++++++++-----------
 python/tvm/relay/op/_tensor.py                    |   1 +
 python/tvm/relay/op/nn/_nn.py                     |  51 +-
 python/tvm/relay/op/strategy/generic.py           |   2 +-
 python/tvm/relay/op/strategy/x86.py               |  21 +-
 python/tvm/topi/cuda/batch_matmul.py              |   2 +-
 python/tvm/topi/nn/batch_matmul.py                |  25 +-
 python/tvm/topi/x86/batch_matmul.py               |   6 +-
 src/relay/backend/build_module.cc                 |   3 +
 src/relay/op/dyn/tensor/transform.cc              |  18 +-
 src/relay/op/nn/convolution.h                     |  77 ++-
 src/relay/op/nn/nn.cc                             |  53 +-
 src/relay/op/nn/nn.h                              |   8 +-
 src/relay/op/tensor/transform.cc                  |  10 +-
 src/relay/transforms/dynamic_to_static.cc         |   9 +-
 tests/python/frontend/onnx/test_forward.py        | 413 ++++++++++++---
 tests/python/relay/dyn/test_dynamic_op_level10.py |  82 ++-
 tests/python/relay/test_op_level10.py             |  27 +
 tutorials/frontend/from_onnx.py                   |   9 +
 21 files changed, 957 insertions(+), 489 deletions(-)

diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index de2bcc4..faa2698 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -209,6 +209,17 @@ TVM_DLL Pass SimplifyInference();
 TVM_DLL Pass FastMath();
 
 /*!
+ * \brief Find Dynamic ops and make them static
+ *
+ * Searches the graph for dynamic ops. If the dynamic inputs to those ops are 
constants, it replaces
+ * them with static ops and re-performs type inference and constant folding. 
The pass repeats
+ * itself until the graph stops changing or we run too many iterations.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass DynamicToStatic();
+
+/*!
  * \brief Infer the type of an expression.
  *
  * The result of type checking is a new expression with unambigous
diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h
index 8fabaae..d03ddc9 100644
--- a/include/tvm/topi/broadcast.h
+++ b/include/tvm/topi/broadcast.h
@@ -54,14 +54,19 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& 
t,
       << "\nvs\ninput: " << t;
   auto bh = detail::BroadcastShape(output_shape, t->shape);
   CHECK_EQ(output_shape.size(), bh.common_shape.size());
+  Array<PrimExpr> oshape;
   for (size_t i = 0; i < output_shape.size(); ++i) {
-    CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
+    if (output_shape[i].as<tir::IntImmNode>() == nullptr) {
+      oshape.push_back(output_shape[i]);
+    } else {
+      CHECK(topi::detail::EqualCheck(output_shape[i], bh.common_shape[i]));
+      oshape.push_back(bh.common_shape[i]);
+    }
   }
   auto l = [&](tvm::Array<tvm::tir::Var> ovars) {
     return t(detail::InputIndexFromBroadcast(ovars, t, bh.vars2, bh.all_vars));
   };
-  return tvm::te::compute(tvm::Array<tvm::PrimExpr>(bh.common_shape.begin(), 
bh.common_shape.end()),
-                          l, name, tag);
+  return tvm::te::compute(oshape, l, name, tag);
 }
 
 #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)                                
                   \
diff --git a/python/tvm/relay/frontend/onnx.py 
b/python/tvm/relay/frontend/onnx.py
index 841ff77..59fdb32 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -28,31 +28,12 @@ from .. import function as _function
 from .. import op as _op
 from .. import vision as _vision
 
-from ..function import Function
-from ..expr import Call, Let
-from ..expr import If, Tuple, TupleGetItem
-from ..expr import RefCreate, RefRead, RefWrite
-from ..expr_functor import ExprFunctor
-from ..adt import Match, Clause
-from ..op.tensor import minimum as _minimum, maximum as _maximum
-
 from .common import AttrCvt, Renamer
 from .common import get_relay_op, new_var, infer_shape, infer_channels
 from .common import infer_type, get_name
-from .common import infer_value as _infer_value
-from .common import infer_value_simulated as _infer_value_simulated
-
-__all__ = ["from_onnx"]
-
-g = None
 
 
-def infer_value(input_val, params, mod=None):
-    return g.infer_value(input_val, params, mod)
-
-
-def infer_value_simulated(input_val, params):
-    return g.infer_value_simulated(input_val, params)
+__all__ = ["from_onnx"]
 
 
 class onnx_input:
@@ -256,21 +237,28 @@ class Pool(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        input_shape = infer_shape(inputs[0])
+        data = inputs[0]
+        input_shape = infer_shape(data)
+        ndim = len(input_shape)
         if "auto_pad" in attr:
             attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
             if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
-                pad_tuple = []
-                for axis in range(len(input_shape) - 2):
-                    axis_shape = input_shape[2 + axis]
-                    stride = attr["strides"][axis]
-                    kernel = attr["kernel_shape"][axis]
-                    pad = get_pad_pair(axis_shape, kernel, stride)
-                    pad_tuple.append(pad)
-                pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in 
pair])
-                attr["pads"] = pad_tuple
+                if cls.name == "avg_pool":
+                    pad_tuple = []
+                    for axis in range(len(input_shape) - 2):
+                        axis_shape = input_shape[2 + axis]
+                        stride = attr["strides"][axis]
+                        kernel = attr["kernel_shape"][axis]
+                        pad = get_pad_pair(axis_shape, kernel, stride)
+                        pad_tuple.append(pad)
+                    pad_tuple = tuple([val for pair in zip(*pad_tuple) for val 
in pair])
+                    attr["pads"] = pad_tuple
+                else:
+                    # Warning: Pool does not yet support dynamic shapes,
+                    # one will need to run dynamic_to_static on this model 
after import
+                    data = autopad(data, attr["strides"], 
attr["kernel_shape"], [1] * ndim, ndim)
             elif attr["auto_pad"] == "VALID":
-                attr["pads"] = 0
+                attr["pads"] = tuple([0 for i in range(ndim - 2)])
             elif attr["auto_pad"] == "NOTSET":
                 pass
             else:
@@ -290,7 +278,7 @@ class Pool(OnnxOpConverter):
             transforms={"kernel_shape": "pool_size", "pads": ("padding", 0)},
             ignores=["dilations", "storage_order"],
             custom_check=dimension_constraint(),
-        )(inputs, attr, params)
+        )([data], attr, params)
 
 
 class Absolute(Unary):
@@ -331,29 +319,68 @@ class InstanceNorm(OnnxOpConverter):
         return AttrCvt(op_name="instance_norm")(inputs, attr, params)
 
 
+def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", 
deconv=False):
+    """
+    Perform autopadding with dynamic input shapes
+    """
+    # get attributes as constants
+    strides = _op.const(np.array(strides), dtype="int64")
+    dilated_kernel_shape = _op.const(
+        np.array(
+            [(kernel - 1) * dilation + 1 for kernel, dilation in 
zip(kernel_shape, dilations)]
+        ),
+        dtype="int64",
+    )
+    shape = _op.strided_slice(_op.shape_of(data, dtype="int64"), [2], [ndim])
+    # get input shape
+
+    # set up integer constants
+    zero = _op.const(0, dtype="int64")
+    one = _op.const(1, dtype="int64")
+    two = _op.const(2, dtype="int64")
+
+    # Calculate total padding
+    mod = _op.mod(shape, strides)
+
+    left = _op.maximum(dilated_kernel_shape - strides, zero)
+    right = _op.maximum(dilated_kernel_shape - mod, zero)
+
+    total_pad = _op.where(_op.equal(mod, zero), left, right)
+    if deconv:
+        total_pad = _op.const(np.array(kernel_shape), dtype="int64") - one - 
total_pad
+
+    # split total padding into before and after
+    pad_before = _op.floor_divide(total_pad, two)
+    pad_after = total_pad - pad_before
+
+    # combine
+    pad = _op.concatenate(
+        [_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], 
axis=1
+    )
+
+    # pad N and C with zeros
+    pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), 
dtype="int64"), pad], axis=0)
+
+    return _op.nn.pad(data, pad, _op.const(0.0), pad_type)
+
+
 class Conv(OnnxOpConverter):
     """Operator converter for Conv."""
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
         # Use shape of input to determine convolution type.
-        input_shape = infer_shape(inputs[0])
+        data = inputs[0]
+        input_shape = infer_shape(data)
+        ndim = len(input_shape)
         if "auto_pad" in attr:
             attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
             if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
-                pad_tuple = []
-                for axis in range(len(input_shape) - 2):
-                    axis_shape = input_shape[2 + axis]
-                    stride = attr["strides"][axis]
-                    kernel = attr["kernel_shape"][axis]
-                    dilation = attr["dilations"][axis]
-                    dilated_kernel = (kernel - 1) * dilation + 1
-                    pad = get_pad_pair(axis_shape, dilated_kernel, stride)
-                    pad_tuple.append(pad)
-                pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in 
pair])
-                attr["pads"] = pad_tuple
+                # Warning: Convolution does not yet support dynamic shapes,
+                # one will need to run dynamic_to_static on this model after 
import
+                data = autopad(data, attr["strides"], attr["kernel_shape"], 
attr["dilations"], ndim)
             elif attr["auto_pad"] == "VALID":
-                attr["pads"] = tuple([0 for i in range(len(input_shape) - 2)])
+                attr["pads"] = tuple([0 for i in range(ndim - 2)])
             elif attr["auto_pad"] == "NOTSET":
                 pass
             else:
@@ -381,7 +408,7 @@ class Conv(OnnxOpConverter):
                 "group": ("groups", 1),
             },
             custom_check=dimension_constraint(),
-        )(inputs[:2], attr, params)
+        )([data, inputs[1]], attr, params)
 
         use_bias = len(inputs) == 3
         if use_bias:
@@ -400,21 +427,24 @@ class ConvTranspose(OnnxOpConverter):
         groups = attr.pop("group")
         attr["groups"] = groups
         # infer pads for auto_pad
+        data = inputs[0]
+        input_shape = infer_shape(data)
+        ndim = len(input_shape)
         if "auto_pad" in attr:
             attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
             if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
-                input_shape = infer_shape(inputs[0])
-                in_h, in_w = input_shape[2], input_shape[3]
-                stride_h, stride_w = attr["strides"]
-                kernel_h, kernel_w = attr["kernel_shape"]
-                dilation_h, dilation_w = attr["dilations"]
-                dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
-                dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
-                pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)
-                pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)
-                attr["pads"] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
+                # Warning: Convolution does not yet support dynamic shapes,
+                # one will need to run dynamic_to_static on this model after 
import
+                data = autopad(
+                    data,
+                    attr["strides"],
+                    attr["kernel_shape"],
+                    attr["dilations"],
+                    ndim,
+                    deconv=True,
+                )
             elif attr["auto_pad"] == "VALID":
-                attr["pads"] = (0, 0)
+                attr["pads"] = tuple([0 for i in range(ndim - 2)])
             elif attr["auto_pad"] == "NOTSET":
                 pass
             else:
@@ -426,12 +456,13 @@ class ConvTranspose(OnnxOpConverter):
             op_name=dimension_picker("conv", "_transpose"),
             transforms={
                 "kernel_shape": "kernel_size",
-                "dilations": ("dilation", (0, 0)),
-                "pads": ("padding", (0, 0), revert_caffe2_pad),
+                "dilations": ("dilation", 1),
+                "pads": ("padding", 0),
+                "group": ("groups", 1),
             },
             disables=["output_shape"],
             custom_check=dimension_constraint(),
-        )(inputs[:2], attr, params)
+        )([data, inputs[1]], attr, params)
         use_bias = len(inputs) == 3
         if use_bias:
             out = _op.nn.bias_add(out, inputs[2])
@@ -492,25 +523,46 @@ class MatMul(OnnxOpConverter):
     def _impl_v1(cls, inputs, attr, params):
         assert len(inputs) == 2, "MatMul op take 2 inputs, {} 
given".format(len(inputs))
         # Need to check input shape as batch matmul must be supported.
-        a_shape = infer_shape(inputs[0])
+        a_shape = _op.shape_of(inputs[0])
         # When performing a batch matmul, we need to properly handle N-dim 
shapes.
-        if len(a_shape) > 2:
-            b_shape = infer_shape(inputs[1])
+        if infer_shape(a_shape)[0] > 2:
+            b_shape = _op.shape_of(inputs[1])
+
+            def flatten_to_3d(x, x_shape):
+                ndims = infer_shape(x_shape)[0]
+                newshape = _op.concatenate(
+                    [_expr.const([-1]), _op.strided_slice(x_shape, [ndims - 
2], [ndims])], 0
+                )
+                out = _op.reshape(x, newshape)
+                return out
+
             # Convert a and b into 3 dimensional tensors.
-            a = _op.reshape(inputs[0], [-1, a_shape[-2], a_shape[-1]])
-            b = _op.reshape(inputs[1], [-1, b_shape[-2], b_shape[-1]])
+            a = flatten_to_3d(inputs[0], a_shape)
+            b = flatten_to_3d(inputs[1], b_shape)
             # Broadcast b to match batch size of a
-            new_b_shape = list(infer_shape(b))
-            new_a_shape = infer_shape(a)
-            if new_a_shape[0] > new_b_shape[0]:
-                new_b_shape[0] = new_a_shape[0]
-                b = _op.broadcast_to(b, new_b_shape)
+            new_b_shape = _op.concatenate(
+                [
+                    _op.strided_slice(_op.shape_of(a), [0], [1]),
+                    _op.strided_slice(_op.shape_of(b), [1], [3]),
+                ],
+                0,
+            )
+            b = _op.broadcast_to(b, new_b_shape)
             # Transpose matrix dimensions of b.
             b = _op.transpose(b, [0, 2, 1])
             # Perform a batch matmul.
             output = _op.nn.batch_matmul(a, b)
             # Reshape output to original dimensions.
-            return _op.reshape(output, [*a_shape[:-2], a_shape[-2], 
b_shape[-1]])
+            final_shape = _op.concatenate(
+                [
+                    _op.strided_slice(a_shape, [0], [infer_shape(a_shape)[0] - 
1]),
+                    _op.strided_slice(
+                        b_shape, [infer_shape(b_shape)[0] - 1], 
[infer_shape(b_shape)[0]]
+                    ),
+                ],
+                0,
+            )
+            return _op.reshape(output, final_shape)
         # Otherwise a simple dense op will get the job done.
         input_1_t = _op.transpose(inputs[1], axes=(1, 0))
         return _op.nn.dense(inputs[0], input_1_t)
@@ -545,23 +597,18 @@ class LpPool(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        input_shape = infer_shape(inputs[0])
         dtype = infer_type(inputs[0]).checked_type.dtype
-
+        data = inputs[0]
+        input_shape = infer_shape(data)
+        ndim = len(input_shape)
         if "auto_pad" in attr:
             attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
             if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
-                pad_tuple = []
-                for axis in range(len(input_shape) - 2):
-                    axis_shape = input_shape[2 + axis]
-                    stride = attr["strides"][axis]
-                    kernel = attr["kernel_shape"][axis]
-                    pad = get_pad_pair(axis_shape, kernel, stride)
-                    pad_tuple.append(pad)
-                pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in 
pair])
-                attr["pads"] = pad_tuple
+                # Warning: LpPool does not yet support dynamic shapes,
+                # one will need to run dynamic_to_static on this model after 
import
+                data = autopad(data, attr["strides"], attr["kernel_shape"], 
[1] * ndim, ndim)
             elif attr["auto_pad"] == "VALID":
-                attr["pads"] = 0
+                attr["pads"] = tuple([0 for i in range(ndim - 2)])
             elif attr["auto_pad"] == "NOTSET":
                 pass
             else:
@@ -578,7 +625,7 @@ class LpPool(OnnxOpConverter):
 
         p = _expr.const(attr["p"], dtype)
         reci_p = _expr.const(1.0 / attr["p"], dtype)
-        inputs[0] = _op.power(inputs[0], p)
+        data = _op.power(data, p)
 
         out = AttrCvt(
             op_name=dimension_picker("avg_pool"),
@@ -586,7 +633,7 @@ class LpPool(OnnxOpConverter):
             extras={"count_include_pad": True},
             ignores=["p"],
             custom_check=dimension_constraint(),
-        )(inputs, attr, params)
+        )([data], attr, params)
         kernels = attr["kernel_shape"]
         out = _op.abs(out) * _expr.const(np.prod(kernels).astype(dtype))
         return _op.power(out, reci_p)
@@ -651,27 +698,23 @@ class Pad(OnnxOpConverter):
 
     @classmethod
     def _impl_v11(cls, inputs, attr, params):
-        pad_width = []
-        pads = infer_value_simulated(inputs[1], params).asnumpy()
+        pads = inputs[1]
         if len(inputs) == 3:
-            value = infer_value_simulated(inputs[2], params).asnumpy().item()
+            value = _op.take(inputs[2], _op.const(0))
         else:
             value = 0
-        attr["pad_value"] = value
-        dims = int(len(pads) / 2)
-        for i in range(dims):
-            pad_width.append((pads[i], pads[i + dims]))
-        attr["pad_width"] = pad_width
+
+        pads_shape = infer_shape(pads)
+        dims = int(pads_shape[0] / 2)
+        pad_width_expr = _op.transpose(_op.reshape(pads, (2, dims)))
         pad_mode = attr.get("mode", b"constant").decode("utf-8")
-        if pad_mode in ["constant", "edge", "reflect"]:
-            attr["pad_mode"] = pad_mode
-            attr.pop("mode", None)
-        else:
+
+        if not pad_mode in ["constant", "edge", "reflect"]:
             raise tvm.error.OpAttributeInvalid(
                 "Value " + pad_mode + ' in attribute "mode" is invalid for 
operator Pad.'
             )
 
-        return AttrCvt("pad")(inputs[:1], attr, params)
+        return _op.nn.pad(inputs[0], pad_width_expr, value, pad_mode=pad_mode)
 
 
 class ParametricSoftPlus(OnnxOpConverter):
@@ -736,9 +779,7 @@ class Reshape(OnnxOpConverter):
             shape = 
tuple(params.pop(inputs[1].name_hint).asnumpy().astype("int32"))
             out = _op.reshape(inputs[0], shape)
         else:
-            data, shape = inputs
-            static_shape = infer_value_simulated(shape, params)
-            out = _op.reshape(data, 
newshape=tuple(static_shape.asnumpy().astype("int32")))
+            out = _op.reshape(*inputs)
         return out
 
 
@@ -883,17 +924,22 @@ class Upsample(OnnxOpConverter):
     @classmethod
     def _impl_v9(cls, inputs, attr, params):
         scales = attr.get("scales")
+
+        input_shape = infer_shape(inputs[0])
+        dims = len(input_shape)
+
         if not scales:
             # Here we are going to higher OPSET version.
-            assert len(inputs) == 2, "Upsample op take 2 inputs, {} 
given".format(len(inputs))
+            assert len(inputs) == 2, "Upsample op takes 2 inputs, {} 
given".format(len(inputs))
+
             if get_name(inputs[1]) in params:
                 scales = params[inputs[1].name_hint].asnumpy()
             else:
-                scales = infer_value_simulated(inputs[1], params).asnumpy()
-            inputs = inputs[:1]
-        assert scales[0] == 1.0 and scales[1] == 1.0
-        input_shape = infer_shape(inputs[0])
-        dims = len(input_shape)
+                scales = inputs[1]
+
+        if not isinstance(scales, _expr.Call):
+            assert scales[0] == 1.0 and scales[1] == 1.0
+
         mode = attr.get("mode")
         if mode == b"nearest":
             method = "nearest_neighbor"
@@ -903,21 +949,47 @@ class Upsample(OnnxOpConverter):
             raise tvm.error.OpAttributeInvalid(
                 'Value {} in attribute "mode" of operator Upsample is not 
valid.'.format(mode)
             )
-        attr = {"scale_h": scales[-2], "scale_w": scales[-1], "method": method}
+
+        if method == "nearest_neighbor":
+            align_corners = False
+        else:
+            align_corners = True
+        # in 3d case, we use the purely static op
         if dims == 5:
-            assert len(scales) == 5
-            attr["scale_d"] = scales[-3]
-            attr["layout"] = "NCDHW"
-            op_name = "upsampling3d"
+            if isinstance(scales, _expr.Call):
+                scale_h = _op.take(scales, _op.const(3))
+                scale_w = _op.take(scales, _op.const(4))
+                scale_d = _op.take(scales, _op.const(1))
+            else:
+                assert len(scales) == 5
+                scale_h = scales[-2]
+                scale_w = scales[-1]
+                scale_d = scales[-3]
+
+            layout = "NCDHW"
+            out = _op.nn.upsampling3d(
+                inputs[0], scale_d, scale_h, scale_w, layout=layout, 
method=method
+            )
+        # in 2d case, use dynamic op
         else:
-            assert len(scales) == 4
-            attr["layout"] = "NCHW"
-            if method == "nearest_neighbor":
-                attr["align_corners"] = False
+            if isinstance(scales, _expr.Call):
+                scale_h = _op.take(scales, _op.const(3))
+                scale_w = _op.take(scales, _op.const(4))
             else:
-                attr["align_corners"] = True
-            op_name = "upsampling"
-        return AttrCvt(op_name)(inputs, attr)
+                assert len(scales) == 4
+                scale_h = scales[-2]
+                scale_w = scales[-1]
+            layout = "NCHW"
+
+            out = _op.nn.upsampling(
+                inputs[0],
+                scale_h,
+                scale_w,
+                layout=layout,
+                method=method,
+                align_corners=align_corners,
+            )
+        return out
 
 
 class Shape(OnnxOpConverter):
@@ -970,8 +1042,7 @@ class Split(OnnxOpConverter):
                 attr["indices_or_sections"].append(index)
         # When splits isnt specified divide evenly over axis.
         else:
-            in_shape = infer_shape(inputs[0])
-            attr["indices_or_sections"] = in_shape[attr["axis"]]
+            attr["indices_or_sections"] = attr["tvm_custom"]["num_outputs"]
         return AttrCvt("split", ignores=["split"])(inputs, attr, params)
 
 
@@ -1022,38 +1093,35 @@ class Slice(OnnxOpConverter):
 
     @classmethod
     def _impl_v10(cls, inputs, attr, params):
-        attrs = {"starts": inputs[1], "ends": inputs[2]}
-        if len(inputs) >= 4:
-            attrs["axes"] = inputs[3]
-        if len(inputs) >= 5:
-            attrs["steps"] = inputs[4]
-
-        attrs = {k: (v, get_name(v)) for (k, v) in attrs.items()}
-        attrs = {
-            k: params[v[1]].asnumpy()
-            if v[1] in params
-            else infer_value_simulated(v[0], params).asnumpy()
-            for (k, v) in attrs.items()
-        }
+        starts = inputs[1]
+        ends = inputs[2]
+        axes = inputs[3]
+        steps = inputs[4]
 
-        # Update the starts and ends according to axes if required.
-        if "axes" in attrs and max(attrs["axes"] + 1) != len(attrs["axes"]):
-            new_starts, new_ends, _ = cls._common(attrs["starts"], 
attrs["ends"], attrs["axes"])
-            attrs["starts"] = new_starts
-            attrs["ends"] = new_ends
+        data_rank = len(infer_shape(inputs[0]))
 
-        begins = list(attrs["starts"])
-        ends = list(attrs["ends"])
-        strides = [1] * len(begins)
+        # Update the starts and ends according to axes if required.
+        if axes is not None:
+            data_shape = _op.shape_of(inputs[0], 
dtype=infer_type(ends).checked_type.dtype)
+            starts = _op.scatter(
+                _op.const([0] * data_rank, 
dtype=infer_type(starts).checked_type.dtype),
+                axes,
+                starts,
+                axis=0,
+            )
+            ends = _op.scatter(data_shape, axes, ends, axis=0)
+            if steps is not None:
+                steps = _op.scatter(
+                    _op.const([1] * data_rank, 
dtype=infer_type(steps).checked_type.dtype),
+                    axes,
+                    steps,
+                    axis=0,
+                )
 
-        if "steps" in attrs:
-            steps = list(attrs["steps"])
-            axes = attrs["axes"]
-            assert len(steps) == len(axes)
-            for axis, step in zip(axes, steps):
-                strides[axis] = step
+        if steps is None:
+            steps = _op.const([1] * data_rank, 
dtype=infer_type(starts).checked_type.dtype)
 
-        return _op.strided_slice(inputs[0], begin=begins, end=ends, 
strides=strides)
+        return _op.strided_slice(inputs[0], starts, ends, steps)
 
 
 class Gather(OnnxOpConverter):
@@ -1337,8 +1405,6 @@ class OneHot(OnnxOpConverter):
         off_value, on_value = _op.take(values, _op.const(0)), _op.take(values, 
_op.const(1))
         # Extract the datatype of the output from on_value.
         dtype = infer_type(on_value).checked_type.dtype
-        # Convert depth into an integer.
-        depth = int(infer_value(depth, params).asnumpy()[0])
         # set default value when axis is not set in the model
         if "axis" not in attr:
             attr["axis"] = -1
@@ -1357,8 +1423,7 @@ class ConstantOfShape(OnnxOpConverter):
         else:
             value = _expr.const(0)
             dtype = "float32"
-        static_shape = infer_value_simulated(inputs[0], params)
-        output = _op.full(value, 
shape=tuple(static_shape.asnumpy().astype("int32")), dtype=dtype)
+        output = _op.full(value, inputs[0], dtype=dtype)
         return output
 
 
@@ -1406,8 +1471,7 @@ class Tile(Elemwise):
 
     @classmethod
     def _impl_v6(cls, inputs, attr, params):
-        reps = tuple(infer_value_simulated(inputs[1], 
params).asnumpy().astype("int32"))
-        return _op.tile(inputs[0], reps)
+        return _op.tile(inputs[0], inputs[1])
 
 
 class Erf(OnnxOpConverter):
@@ -1466,11 +1530,9 @@ class Expand(OnnxOpConverter):
 
     @classmethod
     def _impl_v8(cls, inputs, attr, params):
-        in_shape = np.array(infer_shape(inputs[0])).astype("int32")
-        if get_name(inputs[1]) in params:
-            shape = params[inputs[1].name_hint].asnumpy().astype("int32")
-        else:
-            shape = infer_value_simulated(inputs[1], 
params).asnumpy().astype("int32")
+        dtype = infer_type(inputs[1]).checked_type.dtype
+        in_shape = _op.shape_of(inputs[0], dtype=dtype)
+        shape = inputs[1]
 
         # Currently 'op.broadcast_to' expect the rank of the given 'shape'
         # (the 2nd input) is always higher than that of the given 'input' (the 
1st input)
@@ -1485,28 +1547,41 @@ class Expand(OnnxOpConverter):
             intput. Also it replaces the extent of the shape with the 
corresponding extent
             of the intput when it is 1.
             """
-
-            # here we flip the shapes because this can be more simply written
-            # when the innermost dimension is located at the index 0.
-            in_shape = np.flip(in_shape, axis=0)
-            shape = np.flip(shape, axis=0)
-
-            if in_shape.size < shape.size:
-                for i in range(shape.size):
-                    if i < in_shape.size and in_shape[i] > shape[i]:
-                        shape[i] = in_shape[i]
-            else:
-                for i in range(in_shape.size):
-                    if i >= shape.size:
-                        np.append(shape, in_shape[i])
-                    elif shape[i] == 1:
-                        shape[i] = in_shape[i]
-
-            new_shape = np.flip(shape, axis=0)
+            in_dims = infer_shape(in_shape)[0]
+            new_dims = infer_shape(shape)[0]
+            if in_dims < new_dims:
+                in_shape = _op.concatenate(
+                    [
+                        _expr.const(
+                            [
+                                1,
+                            ]
+                            * (new_dims - in_dims),
+                            dtype=dtype,
+                        ),
+                        in_shape,
+                    ],
+                    axis=0,
+                )
+            elif new_dims > in_dims:
+                shape = _op.concatenate(
+                    [
+                        _expr.const(
+                            [
+                                1,
+                            ]
+                            * (in_dims - new_dims),
+                            dtype=dtype,
+                        ),
+                        shape,
+                    ],
+                    axis=0,
+                )
+            new_shape = _op.maximum(in_shape, shape)
             return new_shape
 
         shape = expand_shape(in_shape, shape)
-        return _op.broadcast_to(inputs[0], shape=tuple(shape))
+        return _op.broadcast_to(inputs[0], shape=shape)
 
 
 class RNN(OnnxOpConverter):
@@ -1779,14 +1854,18 @@ class Resize(OnnxOpConverter):
                 'Value {} in attribute "mode" of operator Resize is not 
valid.'.format(mode)
             )
 
-        in_size = np.array(infer_shape(inputs[0]))
-        scale = infer_value_simulated(inputs[2], params).asnumpy()
+        scale = inputs[2]
+        scale_shape = infer_shape(scale)
         if len(inputs) == 4:
-            assert len(scale) == 0, "One of scale or size should be passed, 
not both."
-            size = infer_value_simulated(inputs[3], 
params).asnumpy().astype(np.int32)
+            assert (
+                len(scale_shape) == 0 or scale_shape[0] == 0
+            ), "One of scale or size should be passed, not both."
+            size = inputs[3]
         else:
-            assert len(scale) != 0, "One of scale or size should be passed."
-            size = (in_size * scale).astype(np.int32)
+            assert len(scale_shape) != 0, "One of scale or size should be 
passed."
+            size = (
+                _op.cast(_op.shape_of(inputs[0]), 
infer_type(scale).type_annotation.dtype) * scale
+            )
 
         coord_trans = attr.get("coordinate_transformation_mode")
         if coord_trans in [b"pytorch_half_pixel", b"half_pixel"]:
@@ -1800,7 +1879,7 @@ class Resize(OnnxOpConverter):
                 "Unsupported coordinate_transformation_mode: 
{}".format(coord_trans)
             )
         layout = "NCHW"  # ONNX assumes NCHW layout
-        out_size = (size[2], size[3])
+        out_size = _op.strided_slice(size, [2], [4])
         return _op.image.resize(inputs[0], out_size, layout, method, 
coord_trans)
 
 
@@ -1831,9 +1910,7 @@ class TopK(OnnxOpConverter):
         if largest == 0:
             raise ValueError("TVM only supports finding TopK largest elements")
 
-        K = int(infer_value(inputs[1], params).asnumpy()[0])
-
-        return _op.topk(inputs[0], k=K, axis=axis)
+        return _op.topk(inputs[0], inputs[1], axis=axis)
 
 
 class MaxRoiPool(OnnxOpConverter):
@@ -1898,7 +1975,7 @@ class Clip(OnnxOpConverter):
 
         assert len(inputs) <= 3, "Clip-11 takes up to 3 inputs, input, min, 
max"
         result = inputs[0]
-        for i, op in enumerate([_maximum, _minimum]):
+        for i, op in enumerate([_op.tensor.maximum, _op.tensor.minimum]):
             if i < len(inputs) - 1:
                 result = op(result, inputs[i + 1])
         return result
@@ -2061,7 +2138,7 @@ def _get_convert_map(opset):
     }
 
 
-class GraphProto(ExprFunctor):
+class GraphProto:
     """A helper class for handling Relay expression copying from 
pb2.GraphProto.
     Definition: https://github.com/onnx/onnx/blob/master/onnx/onnx.proto
 
@@ -2077,108 +2154,22 @@ class GraphProto(ExprFunctor):
     def __init__(self, shape, dtype):
         self._nodes = {}
         self._params = {}
+        self._inputs = {}
         self._renames = {}
         self._num_input = 0
         self._num_param = 0
         self._shape = shape if shape else {}
         self._dtype = dtype
 
-        # For infering Values
-        self._tmp_params = {}
-        self._infer_simulated = True
-        self._mod = None
-        super(GraphProto, self).__init__()
-
-    def infer_value(self, input_val, params, mod=None):
-        self._tmp_params = params
-        self._infer_simulated = False
-        self._mod = mod
-        return self.visit(input_val).data
-
-    def infer_value_simulated(self, input_val, params):
-        self._tmp_params = params
-        self._infer_simulated = True
-        return self.visit(input_val).data
-
-    def infer(self, expr):
-        if self._infer_simulated:
-            out = _infer_value_simulated(expr, self._tmp_params)
-        else:
-            out = _infer_value(expr, self._tmp_params)
-        return _expr.const(out.asnumpy())
-
-    def visit_function(self, fn):
-        new_params = [self.visit(x) for x in fn.params]
-        new_body = self.visit(fn.body)
-        return self.infer(
-            Function(list(new_params), new_body, fn.ret_type, fn.type_params, 
fn.attrs)
-        )
-
-    def visit_let(self, let):
-        newvar = self.visit(let.var)
-        newval = self.visit(let.value)
-        newbody = self.visit(let.body)
-        return self.infer(Let(newvar, newval, newbody))
-
-    def visit_call(self, call):
-        new_fn = self.visit(call.op)
-        new_args = [self.visit(arg) for arg in call.args]
-        call = Call(new_fn, new_args, call.attrs)
-        if new_fn == _op.get("nn.batch_norm"):
-            return call
-        return self.infer(call)
-
-    def visit_var(self, var):
-        return self.infer(var)
-
-    def visit_global_id(self, global_var):
-        return self.infer(global_var)
-
-    def visit_if(self, ite):
-        return self.infer(
-            If(self.visit(ite.cond), self.visit(ite.true_branch), 
self.visit(ite.false_branch))
-        )
-
-    def visit_tuple(self, tup):
-        return Tuple([self.visit(field) for field in tup.fields])
-
-    def visit_tuple_getitem(self, op):
-        tuple_value = self.visit(op.tuple_value)
-        if not tuple_value.same_as(op.tuple_value):
-            return self.infer(TupleGetItem(tuple_value, op.index))
-        return self.infer(op)
-
-    def visit_global_var(self, gvar):
-        return self.infer(gvar)
-
-    def visit_op(self, op):
-        return op
-
-    def visit_constant(self, const):
-        return const
+    def freeze(self, func, params):
+        """Bind `params` into `func` as constants; return (frozen_func, {})."""
+        bind_map = {self._nodes[name]: _expr.const(params[name]) for name in params}
+        body = _expr.bind(func.body, bind_map)
+        # Keep the unbound inputs in func's original order (free_vars() need not preserve it).
+        fn = _function.Function([p for p in func.params if p not in bind_map], body)
+        return fn, {}
 
-    def visit_constructor(self, con):
-        return con
-
-    def visit_match(self, m):
-        return self.infer(
-            Match(
-                self.visit(m.data),
-                [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
-                complete=m.complete,
-            )
-        )
-
-    def visit_ref_create(self, r):
-        return RefCreate(self.visit(r.value))
-
-    def visit_ref_write(self, r):
-        return RefWrite(self.visit(r.ref), self.visit(r.value))
-
-    def visit_ref_read(self, r):
-        return RefRead(self.visit(r.ref))
-
-    def from_onnx(self, graph, opset):
+    def from_onnx(self, graph, opset, freeze_params=False):
         """Construct Relay expression from ONNX graph.
 
         Onnx graph is a python protobuf object.
@@ -2195,6 +2186,13 @@ class GraphProto(ExprFunctor):
 
         opset : opset version
 
+        freeze_params: bool
+            If this parameter is true, the importer will take any provided
+            onnx input values (weights, shapes, etc.) and embed them into the relay model
+            as Constants instead of variables. This allows more aggressive optimizations
+            at compile time and helps in making models static if certain inputs represent
+            attributes relay would traditionally consider compile-time constants.
+
         Returns
         -------
         mod : tvm.IRModule
@@ -2236,6 +2234,7 @@ class GraphProto(ExprFunctor):
                 else:
                     dtype = d_type
                 self._nodes[i_name] = new_var(i_name, shape=tshape, 
dtype=dtype)
+            self._inputs[i_name] = self._nodes[i_name]
         # get list of unsupported ops
         convert_map = _get_convert_map(opset)
         unsupported_ops = set()
@@ -2271,11 +2270,12 @@ class GraphProto(ExprFunctor):
                 )
             else:
                 i_name = self._parse_value_proto(node)
+                node_output = self._fix_outputs(op_name, node.output)
                 attr["tvm_custom"] = {}
                 attr["tvm_custom"]["name"] = i_name
+                attr["tvm_custom"]["num_outputs"] = len(node_output)
 
                 op = self._convert_operator(op_name, inputs, attr, opset)
-                node_output = self._fix_outputs(op_name, node.output)
                 if not isinstance(op, _expr.TupleWrapper):
                     outputs_num = 1
                 else:
@@ -2294,7 +2294,18 @@ class GraphProto(ExprFunctor):
         # now return the outputs
         outputs = [self._nodes[self._parse_value_proto(i)] for i in 
graph.output]
         outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-        func = _function.Function(analysis.free_vars(outputs), outputs)
+        ## Maintain the order of inputs and parameters from the ONNX graph, 
but only include
+        ## those parameters that are needed to execute the relay graph
+        free_vars = analysis.free_vars(outputs)
+        nodes = {v: k for k, v in self._nodes.items()}
+        free_vars = [nodes[var] for var in free_vars]
+        for i_name in self._params:
+            if i_name in free_vars and i_name not in self._inputs:
+                self._inputs[i_name] = self._nodes[i_name]
+        func = _function.Function([v for k, v in self._inputs.items()], 
outputs)
+        if freeze_params:
+            func, params = self.freeze(func, self._params)
+            return IRModule.from_expr(func), params
         return IRModule.from_expr(func), self._params
 
     def _parse_value_proto(self, value_proto):
@@ -2388,7 +2399,7 @@ class GraphProto(ExprFunctor):
         return outputs
 
 
-def from_onnx(model, shape=None, dtype="float32", opset=None):
+def from_onnx(model, shape=None, dtype="float32", opset=None, 
freeze_params=False):
     """Convert a ONNX model into an equivalent Relay Function.
 
     ONNX graphs are represented as Python Protobuf objects.
@@ -2398,6 +2409,13 @@ def from_onnx(model, shape=None, dtype="float32", 
opset=None):
     For convenience, we rename the `real` input names to "input_0",
     "input_1"... And renaming parameters to "param_0", "param_1"...
 
+    By default, ONNX defines models in terms of dynamic shapes. The ONNX importer
+    retains that dynamism upon import, and the compiler attempts to convert the
+    model into static shapes at compile time. If this fails, there may still
+    be dynamic operations in the model. Not all TVM kernels currently support
+    dynamic shapes; please file an issue on discuss.tvm.ai
+    if you hit an error with dynamic kernels.
+
     Parameters
     ----------
     model : protobuf object
@@ -2413,6 +2431,13 @@ def from_onnx(model, shape=None, dtype="float32", 
opset=None):
         Override to autodetected opset.
         This can be helpful for some testing.
 
+    freeze_params: bool
+        If this parameter is true, the importer will take any provided
+        onnx input values (weights, shapes, etc.) and embed them into the relay model
+        as Constants instead of variables. This allows more aggressive optimizations
+        at compile time and helps in making models static if certain inputs represent
+        attributes relay would traditionally consider compile-time constants.
+
     Returns
     -------
     mod : tvm.IRModule
@@ -2435,7 +2460,6 @@ def from_onnx(model, shape=None, dtype="float32", 
opset=None):
                 warnings.warn(str(e))
     except ImportError:
         pass
-    global g
     g = GraphProto(shape, dtype)
     graph = model.graph
     if opset is None:
@@ -2443,6 +2467,5 @@ def from_onnx(model, shape=None, dtype="float32", 
opset=None):
             opset = model.opset_import[0].version if model.opset_import else 1
         except AttributeError:
             opset = 1
-    mod, params = g.from_onnx(graph, opset)
-    g = None
+    mod, params = g.from_onnx(graph, opset, freeze_params)
     return mod, params
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 907c512..e9e608b 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -241,6 +241,7 @@ register_shape_func("subtract", False, broadcast_shape_func)
 register_shape_func("multiply", False, broadcast_shape_func)
 register_shape_func("divide", False, broadcast_shape_func)
 register_shape_func("floor_divide", False, broadcast_shape_func)
+register_shape_func("power", False, broadcast_shape_func)
 register_shape_func("mod", False, broadcast_shape_func)
 register_shape_func("floor_mod", False, broadcast_shape_func)
 register_shape_func("logical_and", False, broadcast_shape_func)
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 6694b5a..c83f6a9 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -722,29 +722,18 @@ reg.register_pattern("nn.correlation", 
OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
 @script
-def _conv2d_shape_func(dshape, kshape, strides, padding, dilation):
+def _conv_shape_func(dshape, kshape, strides, padding, dilation):
     out = output_tensor((dshape.shape[0],), "int64")
-    height = dshape[2]
-    width = dshape[3]
-    kheight = kshape[2]
-    kwidth = kshape[3]
-    dilated_kh = (kheight - 1) * dilation[0] + 1
-    dilated_kw = (kwidth - 1) * dilation[1] + 1
-
-    oc = kshape[0]
-
-    out_height = (height + 2 * padding[0] - dilated_kh) // strides[0] + 1
-    out_width = (width + 2 * padding[1] - dilated_kw) // strides[1] + 1
-
     out[0] = dshape[0]
-    out[1] = oc
-    out[2] = out_height
-    out[3] = out_width
+    out[1] = kshape[0]
+
+    for i in const_range(dshape.shape[0] - 2):
+        dilated_k = (kshape[i + 2] - 1) * dilation[i] + 1
+        out[i + 2] = (dshape[i + 2] + 2 * padding[i] - dilated_k) // 
strides[i] + 1
     return out
 
 
[email protected]_shape_func("nn.conv2d", False)
-def conv2d_shape_func(attrs, inputs, _):
+def conv_shape_func(attrs, inputs, _):
     """
     Shape function for contrib_conv2d_NCHWc op.
     """
@@ -753,7 +742,7 @@ def conv2d_shape_func(attrs, inputs, _):
     dilation = get_const_tuple(attrs.dilation)
 
     return [
-        _conv2d_shape_func(
+        _conv_shape_func(
             inputs[0],
             inputs[1],
             convert(strides),
@@ -763,6 +752,11 @@ def conv2d_shape_func(attrs, inputs, _):
     ]
 
 
+reg.register_shape_func("nn.conv1d", False, conv_shape_func)
+reg.register_shape_func("nn.conv2d", False, conv_shape_func)
+reg.register_shape_func("nn.conv3d", False, conv_shape_func)
+
+
 @script
 def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, 
oc_bn):
     out = output_tensor((dshape.shape[0],), "int64")
@@ -969,6 +963,25 @@ def dense_shape_func(attrs, inputs, _):
 
 
 @script
+def _batch_matmul_shape_func(data_shape, weight_shape):
+    out = output_tensor((data_shape.shape[0],), "int64")
+    for i in const_range(out.shape[0] - 1):
+        out[i] = data_shape[i]
+    out[out.shape[0] - 1] = weight_shape[weight_shape.shape[0] - 2]
+
+    return out
+
+
+@reg.register_shape_func("nn.batch_matmul", False)
+def batch_matmul_shape_func(attrs, inputs, _):
+    """
+    Shape function for batch_matmul op: x (B, M, K), y (B, N, K) -> (B, M, N).
+    """
+    ret = [_batch_matmul_shape_func(inputs[0], inputs[1])]
+    return ret
+
+
+@script
 def _pad_shape_func(data_shape, pad_width):
     out = output_tensor((data_shape.shape[0],), "int64")
     for i in const_range(out.shape[0]):
diff --git a/python/tvm/relay/op/strategy/generic.py 
b/python/tvm/relay/op/strategy/generic.py
index 68889f3..56ae976 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -683,7 +683,7 @@ def wrap_compute_batch_matmul(topi_compute):
     """wrap batch_matmul topi compute"""
 
     def _compute_batch_matmul(attrs, inputs, out_type):
-        return [topi_compute(inputs[0], inputs[1])]
+        return [topi_compute(inputs[0], inputs[1], out_type.shape)]
 
     return _compute_batch_matmul
 
diff --git a/python/tvm/relay/op/strategy/x86.py 
b/python/tvm/relay/op/strategy/x86.py
index 8925723..e2a82d3 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -21,6 +21,7 @@ import logging
 import re
 from tvm import topi
 from tvm.te import SpecializedCondition
+from tvm.relay.ty import is_dynamic
 from .generic import *
 from .. import op as _op
 
@@ -355,12 +356,20 @@ def dense_strategy_cpu(attrs, inputs, out_type, target):
 def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
     """batch_matmul x86 strategy"""
     strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.x86.batch_matmul),
-        wrap_topi_schedule(topi.x86.schedule_batch_matmul),
-        name="batch_matmul.x86",
-        plevel=10,
-    )
+    if is_dynamic(out_type):
+        strategy.add_implementation(
+            wrap_compute_batch_matmul(topi.nn.batch_matmul),
+            wrap_topi_schedule(topi.generic.nn.schedule_batch_matmul),
+            name="batch_matmul.generic",
+            plevel=10,
+        )
+    else:
+        strategy.add_implementation(
+            wrap_compute_batch_matmul(topi.x86.batch_matmul),
+            wrap_topi_schedule(topi.x86.schedule_batch_matmul),
+            name="batch_matmul.x86",
+            plevel=10,
+        )
     if "cblas" in target.libs:
         strategy.add_implementation(
             wrap_compute_batch_matmul(topi.x86.batch_matmul_cblas),
diff --git a/python/tvm/topi/cuda/batch_matmul.py 
b/python/tvm/topi/cuda/batch_matmul.py
index 26647dd..bb060b3 100644
--- a/python/tvm/topi/cuda/batch_matmul.py
+++ b/python/tvm/topi/cuda/batch_matmul.py
@@ -26,7 +26,7 @@ from ..util import traverse_inline, get_const_tuple, 
get_max_power2_factor
 
 
 @autotvm.register_topi_compute("batch_matmul.cuda")
-def batch_matmul(cfg, x, y):
+def batch_matmul(cfg, x, y, out_shape=None):
     """Compute conv2d with NCHW layout"""
     return nn.batch_matmul(x, y)
 
diff --git a/python/tvm/topi/nn/batch_matmul.py 
b/python/tvm/topi/nn/batch_matmul.py
index 7c8fead..34a8c6d 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -20,7 +20,7 @@ from tvm import te
 from ..util import get_const_tuple
 
 
-def batch_matmul(x, y):
+def batch_matmul(x, y, oshape=None):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
@@ -37,14 +37,19 @@ def batch_matmul(x, y):
     output : tvm.te.Tensor
         3-D with shape [batch, M, N]
     """
-    assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim 
batch_matmul"
-    x_shape = get_const_tuple(x.shape)
-    y_shape = get_const_tuple(y.shape)
-    assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
-    assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
-    batch, M, K = x.shape
-    N = y.shape[1]
-    k = te.reduce_axis((0, K), name="k")
+    if oshape is None:
+        assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim 
batch_matmul"
+        x_shape = get_const_tuple(x.shape)
+        y_shape = get_const_tuple(y.shape)
+        assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
+        assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
+        batch, M, K = x.shape
+        N = y.shape[1]
+        k = te.reduce_axis((0, K), name="k")
+        oshape = (batch, M, N)
+    else:
+        _, _, K = x.shape
+        k = te.reduce_axis((0, K), name="k")
     return te.compute(
-        (batch, M, N), lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], 
axis=k), tag="batch_matmul"
+        oshape, lambda b, i, j: te.sum(x[b, i, k] * y[b, j, k], axis=k), 
tag="batch_matmul"
     )
diff --git a/python/tvm/topi/x86/batch_matmul.py 
b/python/tvm/topi/x86/batch_matmul.py
index 333d3be..c095dcb 100644
--- a/python/tvm/topi/x86/batch_matmul.py
+++ b/python/tvm/topi/x86/batch_matmul.py
@@ -25,7 +25,7 @@ from ..util import traverse_inline, get_const_tuple, 
get_max_power2_factor
 
 
 @autotvm.register_topi_compute("batch_matmul.x86")
-def batch_matmul(cfg, x, y):
+def batch_matmul(cfg, x, y, out_shape=None):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
 
@@ -49,6 +49,10 @@ def batch_matmul(cfg, x, y):
     assert XK == YK, "shapes of x and y is inconsistant"
     B = XB
     K = XK
+    if out_shape is not None:
+        assert out_shape[0] == B, "got invalid output shape"
+        assert out_shape[1] == M, "got invalid output shape"
+        assert out_shape[2] == N, "got invalid output shape"
     if cfg.is_fallback:
         _default_batch_matmul_config(cfg, M, N, K)
 
diff --git a/src/relay/backend/build_module.cc 
b/src/relay/backend/build_module.cc
index 21fd591..b95e096 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -263,6 +263,9 @@ class RelayBuildModule : public runtime::ModuleNode {
       pass_seqs.push_back(transform::Legalize());
     }
 
+    // Convert Dynamic ops to static versions
+    pass_seqs.push_back(transform::DynamicToStatic());
+
     pass_seqs.push_back(transform::SimplifyInference());
     PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
       Expr expr = args[0];
diff --git a/src/relay/op/dyn/tensor/transform.cc 
b/src/relay/op/dyn/tensor/transform.cc
index de1cc5a..4b594ff 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -58,6 +58,11 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
 
   Array<IndexExpr> oshape;
   const auto* newshape = types[1].as<TensorTypeNode>();
+  if (newshape == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "reshape: expect input type to be TensorType but get " << types[1];
+    return false;
+  }
 
   // Doesn't support dynamic output rank
   for (int i = 0; i < newshape->shape[0].as<IntImmNode>()->value; i++) {
@@ -209,10 +214,17 @@ bool BroadCastToRel(const Array<Type>& types, int 
num_inputs, const Attrs& attrs
   // types = [data_type, broadcast_shape_type, ret_type]
   CHECK_EQ(types.size(), 3);
 
-  const auto* target_shape = types[1].as<TensorTypeNode>();
-  DataType out_dtype = types[0].as<TensorTypeNode>()->dtype;
+  const auto* input_type = types[0].as<TensorTypeNode>();
+  const auto* target_type = types[1].as<TensorTypeNode>();
+  if (target_type == nullptr) {
+    return false;
+  }
+  if (input_type == nullptr) {
+    return false;
+  }
+  auto out_dtype = input_type->dtype;
   // rank must be static
-  const IntImmNode* rank = target_shape->shape[0].as<IntImmNode>();
+  const IntImmNode* rank = target_type->shape[0].as<IntImmNode>();
   CHECK(rank) << "Target shape must have static rank";  // rank must be static 
even in dyn pass
                                                         // could add support 
for dyn rank in futures
 
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index f53f4e0..2311585 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -100,7 +100,9 @@ bool Conv1DRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
           << "Conv1D: shape of weight is inconsistent with channels, "
           << " channels=" << param->channels << " wshape=" << wshape;
     }
-    CHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1]));
+    if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
+      CHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1]));
+    }
     channels = wshape[0];
     dilated_ksize = 1 + (wshape[2] - 1) * param->dilation[0];
   }
@@ -211,7 +213,9 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
           << "Conv2D: shape of weight is inconsistent with channels, "
           << " channels=" << param->channels << " wshape=" << wshape;
     }
-    CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), 
wshape[1]));
+    if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
+      CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), 
wshape[1]));
+    }
     channels = wshape[0];
     dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -322,7 +326,9 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
           << "Conv3D: shape of weight is inconsistent with channels, "
           << " channels=" << param->channels << " wshape=" << wshape;
     }
-    CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), 
wshape[1]));
+    if (!dshape_ncdhw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
+      CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), 
wshape[1]));
+    }
     channels = wshape[0];
     dilated_ksize_z = 1 + (wshape[2] - 1) * param->dilation[0];
     dilated_ksize_y = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -800,7 +806,9 @@ bool Conv1DTransposeRel(const Array<Type>& types, int 
num_inputs, const Attrs& a
           << "Conv1D: shape of weight is inconsistent with channels, "
           << " channels=" << param->channels << " wshape=" << 
Array<IndexExpr>(wshape);
     }
-    CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), 
wshape[0]));
+    if (!dshape_ncw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
+      CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), 
wshape[0]));
+    }
     channels = wshape[1];
     dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
   }
@@ -808,8 +816,12 @@ bool Conv1DTransposeRel(const Array<Type>& types, int 
num_inputs, const Attrs& a
   IndexExpr pad_w;
   GetPaddingWidth(param->padding, &pad_w);
   Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
-  oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - 
pad_w +
-                 param->output_padding[0]));
+  if (!dshape_ncw[2].as<tir::AnyNode>()) {
+    oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x - 
pad_w +
+                   param->output_padding[0]));
+  } else {
+    oshape.Set(2, dshape_ncw[2]);
+  }
 
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
@@ -890,7 +902,9 @@ bool Conv3DTransposeRel(const Array<Type>& types, int 
num_inputs, const Attrs& a
           << "Conv3D: shape of weight is inconsistent with channels, "
           << " channels=" << param->channels << " wshape=" << 
Array<IndexExpr>(wshape);
     }
-    CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), 
wshape[0]));
+    if (!dshape_ncdhw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
+      CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), 
wshape[0]));
+    }
     channels = wshape[1];
     dilated_ksize_d = 1 + (wshape[2] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -901,12 +915,25 @@ bool Conv3DTransposeRel(const Array<Type>& types, int 
num_inputs, const Attrs& a
   Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});
   IndexExpr pad_d, pad_h, pad_w;
   GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
-  oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d - 
pad_d +
-                 param->output_padding[0]));
-  oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y - 
pad_h +
-                 param->output_padding[1]));
-  oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x - 
pad_w +
-                 param->output_padding[2]));
+
+  if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
+    oshape.Set(2, (param->strides[0] * (dshape_ncdhw[2] - 1) + dilated_ksize_d 
- pad_d +
+                   param->output_padding[0]));
+  } else {
+    oshape.Set(2, dshape_ncdhw[2]);
+  }
+  if (!dshape_ncdhw[3].as<tir::AnyNode>()) {
+    oshape.Set(3, (param->strides[1] * (dshape_ncdhw[3] - 1) + dilated_ksize_y 
- pad_h +
+                   param->output_padding[1]));
+  } else {
+    oshape.Set(3, dshape_ncdhw[3]);
+  }
+  if (!dshape_ncdhw[4].as<tir::AnyNode>()) {
+    oshape.Set(4, (param->strides[2] * (dshape_ncdhw[4] - 1) + dilated_ksize_x 
- pad_w +
+                   param->output_padding[2]));
+  } else {
+    oshape.Set(4, dshape_ncdhw[4]);
+  }
 
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
@@ -985,7 +1012,9 @@ bool Conv2DTransposeRel(const Array<Type>& types, int 
num_inputs, const Attrs& a
           << "Conv2D: shape of weight is inconsistent with channels, "
           << " channels=" << param->channels << " wshape=" << 
Array<IndexExpr>(wshape);
     }
-    CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), 
wshape[0]));
+    if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[0].as<tir::AnyNode>()) {
+      CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), 
wshape[0]));
+    }
     channels = wshape[1];
     dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
     dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
@@ -994,10 +1023,18 @@ bool Conv2DTransposeRel(const Array<Type>& types, int 
num_inputs, const Attrs& a
   Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
   IndexExpr pad_h, pad_w;
   GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
-  oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y - 
pad_h +
-                 param->output_padding[0]));
-  oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x - 
pad_w +
-                 param->output_padding[1]));
+  if (!dshape_nchw[2].as<tir::AnyNode>()) {
+    oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y 
- pad_h +
+                   param->output_padding[0]));
+  } else {
+    oshape.Set(2, dshape_nchw[2]);
+  }
+  if (!dshape_nchw[3].as<tir::AnyNode>()) {
+    oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x 
- pad_w +
+                   param->output_padding[1]));
+  } else {
+    oshape.Set(3, dshape_nchw[3]);
+  }
 
   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
@@ -1053,7 +1090,9 @@ bool DeformableConv2DRel(const Array<Type>& types, int 
num_inputs, const Attrs&
           << "DeformableConv2D: shape of weight is inconsistent with channels, 
"
           << " channels=" << param->channels << " wshape=" << wshape;
     }
-    CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), 
wshape[1]));
+    if (!data->shape[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
+      CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), 
wshape[1]));
+    }
     channels = wshape[0];
     ksize_y = wshape[2];
     ksize_x = wshape[3];
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
index 619b86d..38ebe42 100644
--- a/src/relay/op/nn/nn.cc
+++ b/src/relay/op/nn/nn.cc
@@ -851,15 +851,26 @@ bool BatchMatmulRel(const Array<Type>& types, int 
num_inputs, const Attrs& attrs
   const auto* y = types[1].as<TensorTypeNode>();
   if (x == nullptr || y == nullptr) return false;
   CHECK(x->shape.size() == 3 && y->shape.size() == 3);
-  CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
-      << "BatchDot: batch dimension doesn't match, "
-      << " x shape=" << x->shape << ", y shape=" << y->shape;
-  CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
-      << "BatchDot: shapes of x and y is inconsistent, "
-      << " x shape=" << x->shape << ", y shape=" << y->shape;
-
-  Array<tvm::PrimExpr> oshape = x->shape;
-  oshape.Set(2, y->shape[1]);
+  bool is_dyn = false;
+  Array<tvm::PrimExpr> oshape;
+  for (size_t i = 0; i < 3; ++i) {
+    if (x->shape[i].as<tir::AnyNode>() != nullptr || 
y->shape[i].as<tir::AnyNode>() != nullptr) {
+      is_dyn = true;
+      oshape.push_back(Any());
+    } else {
+      oshape.push_back(x->shape[i]);
+    }
+  }
+  if (!is_dyn) {
+    CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
+        << "BatchDot: batch dimension doesn't match, "
+        << " x shape=" << x->shape << ", y shape=" << y->shape;
+    CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
+        << "BatchDot: shapes of x and y is inconsistent, "
+        << " x shape=" << x->shape << ", y shape=" << y->shape;
+
+    oshape.Set(2, y->shape[1]);
+  }
 
   // assign output type
   reporter->Assign(types[2], TensorType(oshape, x->dtype));
@@ -1021,9 +1032,15 @@ bool DepthToSpaceRel(const Array<Type>& types, int 
num_inputs, const Attrs& attr
       << " But got " << in_layout;
 
   auto oshape = layout_converter.ForwardShape(data->shape);
-  oshape.Set(1, indexdiv(oshape[1], (block_size * block_size)));
-  oshape.Set(2, oshape[2] * block_size);
-  oshape.Set(3, oshape[3] * block_size);
+  if (!oshape[1].as<tir::AnyNode>()) {
+    oshape.Set(1, indexdiv(oshape[1], (block_size * block_size)));
+  }
+  if (!oshape[2].as<tir::AnyNode>()) {
+    oshape.Set(2, oshape[2] * block_size);
+  }
+  if (!oshape[3].as<tir::AnyNode>()) {
+    oshape.Set(3, oshape[3] * block_size);
+  }
 
   // Assign output type
   reporter->Assign(types[1], 
TensorType(layout_converter.BackwardShape(oshape), data->dtype));
@@ -1078,9 +1095,15 @@ bool SpaceToDepthRel(const Array<Type>& types, int 
num_inputs, const Attrs& attr
       << " But got " << in_layout;
 
   auto oshape = layout_converter.ForwardShape(data->shape);
-  oshape.Set(1, oshape[1] * (block_size * block_size));
-  oshape.Set(2, indexdiv(oshape[2], block_size));
-  oshape.Set(3, indexdiv(oshape[3], block_size));
+  if (!oshape[1].as<tir::AnyNode>()) {
+    oshape.Set(1, oshape[1] * (block_size * block_size));
+  }
+  if (!oshape[2].as<tir::AnyNode>()) {
+    oshape.Set(2, indexdiv(oshape[2], block_size));
+  }
+  if (!oshape[3].as<tir::AnyNode>()) {
+    oshape.Set(3, indexdiv(oshape[3], block_size));
+  }
 
   // Assign output type
   reporter->Assign(types[1], 
TensorType(layout_converter.BackwardShape(oshape), data->dtype));
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 0fb0263..e7f5a4b 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -63,9 +63,11 @@ bool DenseRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
     if (weight == nullptr) return false;
     Array<tvm::PrimExpr> wshape = weight->shape;
     CHECK(static_cast<int>(weight->shape.size()) == 2);
-    CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], 
weight->shape[1]))
-        << "DenseRel: input dimension doesn't match,"
-        << " data shape=" << data->shape << ", weight shape=" << weight->shape;
+    if (!data->shape.back().as<tir::AnyNode>()) {
+      CHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], 
weight->shape[1]))
+          << "DenseRel: input dimension doesn't match,"
+          << " data shape=" << data->shape << ", weight shape=" << 
weight->shape;
+    }
     oshape.Set((oshape.size() - 1), wshape[0]);
   }
 
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 8d2d391..1649586 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1822,9 +1822,9 @@ bool SqueezeRel(const Array<Type>& types, int num_inputs, 
const Attrs& attrs,
       if (p.second) {
         result_shape.push_back(p.first);
       } else {
-        const int64_t* axis_ptr = tir::as_const_int(p.first);
-        CHECK(axis_ptr != nullptr) << "cannot get concrete shape of input 
tensor";
-        CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not 
equal to 1";
+        if (const int64_t* axis_ptr = tir::as_const_int(p.first)) {
+          CHECK_EQ(*axis_ptr, 1) << "cannot squeeze axis with dimension not 
equal to 1";
+        }
       }
     }
   }
@@ -2028,7 +2028,9 @@ bool StridedSliceRel(const Array<Type>& types, int 
num_inputs, const Attrs& attr
                      const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 2);
   const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
-  CHECK(param != nullptr);
+  if (param == nullptr) {
+    return false;
+  }
   const auto* data = types[0].as<TensorTypeNode>();
 
   if (data == nullptr) {
diff --git a/src/relay/transforms/dynamic_to_static.cc 
b/src/relay/transforms/dynamic_to_static.cc
index 113b599..edcb839 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -227,6 +227,9 @@ Expr DynamicToStatic(Function f, IRModule m) {
     vars.Set(kv.second, kv.first);
   }
   const auto gv = vars[f];
+  // Put a limit on the while loop
+  // Primarily used to prevent accidental infinite lops in development
+  const int loop_limit = 1000;
   int i = 0;
   do {
     pre = expr;
@@ -236,13 +239,13 @@ Expr DynamicToStatic(Function f, IRModule m) {
     expr = mutator.Mutate(m->functions[gv]);
     m->Update(gv, Downcast<BaseFunc>(expr));
     i += 1;
-  } while (pre != expr && i < 1000);
+  } while (!StructuralEqual()(pre, expr) && i < loop_limit);
   return expr;
 }
 
 namespace transform {
 
-Pass ConvertDynamicToStatic() {
+Pass DynamicToStatic() {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> 
pass_func =
       [=](Function f, IRModule m, PassContext pc) {
         return Downcast<Function>(DynamicToStatic(f, m));
@@ -251,7 +254,7 @@ Pass ConvertDynamicToStatic() {
 }
 
 TVM_REGISTER_GLOBAL("relay._transform.DynamicToStatic").set_body_typed([]() {
-  return ConvertDynamicToStatic();
+  return DynamicToStatic();
 });
 
 }  // namespace transform
diff --git a/tests/python/frontend/onnx/test_forward.py 
b/tests/python/frontend/onnx/test_forward.py
index 1c0fced..1aeb430 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -42,12 +42,21 @@ def get_input_data_shape_dict(graph_def, input_data):
     return input_names, shape_dict
 
 
-def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None):
+def get_tvm_output_with_vm(
+    graph_def, input_data, target, ctx, opset=None, freeze_params=False, 
convert_to_static=False
+):
     """ Generic function to execute and get tvm output with vm executor"""
-
+    if not isinstance(input_data, list):
+        input_data = [input_data]
     _, shape_dict = get_input_data_shape_dict(graph_def, input_data)
 
-    mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
+    mod, params = relay.frontend.from_onnx(
+        graph_def, shape_dict, opset=opset, freeze_params=freeze_params
+    )
+    if convert_to_static:
+        from tvm.relay import transform
+
+        mod = transform.DynamicToStatic()(mod)
 
     ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
     result = ex.evaluate()(*input_data)
@@ -118,6 +127,8 @@ def verify_with_ort_with_inputs(
     targets=None,
     use_vm=False,
     opset=None,
+    freeze_params=False,
+    convert_to_static=False,
     dtype="float32",
     rtol=1e-5,
     atol=1e-5,
@@ -136,9 +147,16 @@ def verify_with_ort_with_inputs(
 
     for target in targets:
         ctx = tvm.context(target, 0)
-
         if use_vm:
-            tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, 
opset=opset)
+            tvm_out = get_tvm_output_with_vm(
+                model,
+                inputs,
+                target,
+                ctx,
+                opset=opset,
+                freeze_params=freeze_params,
+                convert_to_static=convert_to_static,
+            )
         else:
             tvm_out = get_tvm_output(model, inputs, target, ctx, out_shape, 
dtype, opset=opset)
 
@@ -152,6 +170,8 @@ def verify_with_ort(
     targets=None,
     use_vm=False,
     opset=None,
+    freeze_params=False,
+    convert_to_static=False,
     dtype="float32",
     rtol=1e-5,
     atol=1e-5,
@@ -164,6 +184,8 @@ def verify_with_ort(
         targets=targets,
         use_vm=use_vm,
         opset=opset,
+        freeze_params=freeze_params,
+        convert_to_static=convert_to_static,
         dtype=dtype,
         rtol=rtol,
         atol=atol,
@@ -213,21 +235,37 @@ def test_reshape():
         tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
 
 
[email protected]_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_expand():
-    def _test_expand(name, data, shape, ref_data):
+    def _test_expand(name, data, shape, ref_data, dtype="int32"):
         shape_array = np.array(shape)
-        shape_node = onnx.helper.make_node(
-            "Constant",
-            inputs=[],
-            outputs=["shape"],
-            value=onnx.helper.make_tensor(
-                name="const_tensor",
-                data_type=onnx.TensorProto.INT32,
-                dims=shape_array.shape,
-                vals=shape_array.flatten().astype("int32"),
-            ),
-        )
+        if dtype == "int32":
+            shape_node = onnx.helper.make_node(
+                "Constant",
+                inputs=[],
+                outputs=["shape"],
+                value=onnx.helper.make_tensor(
+                    name="const_tensor",
+                    data_type=onnx.TensorProto.INT32,
+                    dims=shape_array.shape,
+                    vals=shape_array.flatten().astype("int32"),
+                ),
+            )
+        elif dtype == "int64":
+            shape_node = onnx.helper.make_node(
+                "Constant",
+                inputs=[],
+                outputs=["shape"],
+                value=onnx.helper.make_tensor(
+                    name="const_tensor",
+                    data_type=onnx.TensorProto.INT64,
+                    dims=shape_array.shape,
+                    vals=shape_array.flatten().astype("int64"),
+                ),
+            )
+        else:
+            raise "Invalid dtype"
         expand_node = helper.make_node("Expand", ["in", "shape"], ["out"])
 
         graph = helper.make_graph(
@@ -240,20 +278,22 @@ def test_expand():
         model = helper.make_model(graph, producer_name=name)
 
         for target, ctx in tvm.testing.enabled_targets():
-            tvm_out = get_tvm_output(model, data, target, ctx, ref_data.shape, 
"float32")
+            tvm_out = get_tvm_output_with_vm(model, data, target, ctx, 
freeze_params=True)
             tvm.testing.assert_allclose(ref_data, tvm_out)
 
     in_shape = (3, 1)
     shape = (3, 4)
     data = np.random.uniform(size=in_shape).astype(np.float32)
     ref_data = np.tile(data, 4)
-    _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data)
+    _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data, 
"int32")
+    _test_expand("expand_with_dim_unchanged_test", data, shape, ref_data, 
"int64")
 
     in_shape = (3, 1)
     shape = (2, 1, 6)
     data = np.random.uniform(size=in_shape).astype(np.float32)
     ref_data = data * np.ones(shape, dtype=np.float32)
-    _test_expand("expand_with_dim_changed_test", data, shape, ref_data)
+    _test_expand("expand_with_dim_changed_test", data, shape, ref_data, 
"int32")
+    _test_expand("expand_with_dim_changed_test", data, shape, ref_data, 
"int64")
 
 
 def verify_depth_to_space(inshape, outshape, mode, blockSize):
@@ -650,11 +690,12 @@ def _test_slice_iteration_v10(indata, outdata, **attrs):
     model = helper.make_model(graph, producer_name="slice_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 
"float32", opset=10)
+        tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=10, 
freeze_params=True)
         tvm.testing.assert_allclose(outdata, tvm_out)
 
 
[email protected]_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_slice():
     x = np.random.randn(20, 10, 5).astype(np.float32)
     _test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), 
axes=(0, 1))
@@ -856,12 +897,13 @@ def test_gather_nd():
     verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], "float32")
 
 
[email protected]_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_onehot():
     indices_shape = [10]
     indices_array = np.random.randint(low=0, high=9, size=indices_shape, 
dtype="int32")
     depth = 10
-    values = np.asarray([0, 1])
+    values = np.asarray([0, 1]).astype("int32")
     out_np = np.eye(depth)[indices_array.reshape(-1)]
 
     onehot_node = helper.make_node("OneHot", ["indices", "depth", "values"], 
["out"])
@@ -874,17 +916,15 @@ def test_onehot():
             helper.make_tensor_value_info("depth", TensorProto.INT32, [1]),
             helper.make_tensor_value_info("values", TensorProto.INT32, 
values.shape),
         ],
-        initializer=[
-            helper.make_tensor("depth", TensorProto.INT32, [1], [depth]),
-            helper.make_tensor("values", TensorProto.INT32, values.shape, 
values),
-        ],
         outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, 
out_np.shape)],
     )
 
     model = helper.make_model(graph, producer_name="onehot_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [indices_array], target, ctx, 
out_np.shape)
+        tvm_out = get_tvm_output_with_vm(
+            model, [indices_array, np.array([depth]).astype("int32"), values], 
target, ctx
+        )
         tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
 
@@ -916,7 +956,7 @@ def test_matmul():
         tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
 
-def verify_batch_matmul(a_shape, b_shape):
+def verify_batch_matmul(a_shape, b_shape, target, ctx):
     a_array = np.random.uniform(size=a_shape).astype("float32")
     b_array = np.random.uniform(size=b_shape).astype("float32")
     out_np = np.matmul(a_array, b_array)
@@ -935,16 +975,67 @@ def verify_batch_matmul(a_shape, b_shape):
 
     model = helper.make_model(graph, producer_name="matmul_test")
 
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, 
out_np.shape)
+    tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx)
+    tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
+
+
+# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
[email protected]_targets("llvm")
+def test_batch_matmul(target, ctx):
+    verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
+    verify_batch_matmul((2, 4, 3), (3, 4), target, ctx)
+    verify_batch_matmul((2, 3, 4, 3), (3, 4), target, ctx)
+
+
+def verify_simple_dynamic_model(a_shape, b_shape, target, ctx):
+    def verify_model(ex, a_shape, b_shape):
+        a_array = np.random.uniform(size=a_shape).astype("float32")
+        b_array = np.random.uniform(size=b_shape).astype("float32")
+        # matmul
+        out_np = np.matmul(a_array, b_array)
+        # relu
+        out_np[out_np < 0] = 0
+
+        tvm_out = ex.evaluate()(a_array, b_array).asnumpy()
         tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
 
+    mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
+    relu_node = helper.make_node("Relu", ["out"], ["relu"])
 
[email protected]_gpu
-def test_batch_matmul():
-    verify_batch_matmul((2, 3, 4, 3), (2, 3, 3, 4))
-    verify_batch_matmul((2, 4, 3), (3, 4))
-    verify_batch_matmul((2, 3, 4, 3), (3, 4))
+    a_array = np.random.uniform(size=a_shape).astype("float32")
+    b_array = np.random.uniform(size=b_shape).astype("float32")
+    # matmul
+    out_np = np.matmul(a_array, b_array)
+
+    graph = helper.make_graph(
+        [mul_node, relu_node],
+        "matmul_test",
+        inputs=[
+            helper.make_tensor_value_info("a", TensorProto.FLOAT, 
list(a_shape)),
+            helper.make_tensor_value_info("b", TensorProto.FLOAT, 
list(b_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("relu", TensorProto.FLOAT, 
list(out_np.shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="matmul_test")
+
+    a_anys = [relay.Any()] * len(a_shape)
+    b_anys = [relay.Any()] * len(b_shape)
+
+    mod, params = relay.frontend.from_onnx(model, {"a": a_anys, "b": b_anys})
+
+    ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
+    verify_model(ex, a_shape, b_shape)
+    verify_model(ex, [a * 2 for a in a_shape], [b * 2 for b in b_shape])
+    verify_model(ex, [a * 3 for a in a_shape], [b * 3 for b in b_shape])
+
+
+# TODO(mbrookhart): enable cuda once VM supports heterogenous execution
[email protected]_targets("llvm")
+def test_batch_matmul_dynamic_model(target, ctx):
+    verify_simple_dynamic_model((2, 3, 4, 3), (2, 3, 3, 4), target, ctx)
+    verify_simple_dynamic_model((2, 4, 3), (3, 4), target, ctx)
+    verify_simple_dynamic_model((2, 3, 4, 3), (3, 4), target, ctx)
 
 
 def verify_lrn(shape, nsize, dtype, alpha=None, beta=None, bias=None):
@@ -1149,8 +1240,9 @@ def _test_upsample_bilinear_opset9():
     model = helper.make_model(graph, 
producer_name="upsample_bilinear_opset9_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 
"float32")
-        tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
+        tvm_out = get_tvm_output_with_vm(
+            model, [in_array], target, ctx, opset=9, freeze_params=True
+        )
 
 
 def _test_upsample3d_trilinear():
@@ -1194,7 +1286,8 @@ def _test_upsample3d_trilinear():
         tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
 
 
[email protected]_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_upsample():
     _test_upsample_nearest()
     _test_upsample_bilinear()
@@ -1475,18 +1568,19 @@ def verify_constantofshape(input_dim, value, dtype):
         "fill_test",
         inputs,
         outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
list(out.shape))],
-        initializer=[helper.make_tensor("input", TensorProto.INT32, 
(len(input_dim),), input_dim)],
     )
 
     model = helper.make_model(graph, producer_name="fill_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [], target, ctx, out.shape)
+        input_np = np.array(input_dim).astype("float32")
+        tvm_out = get_tvm_output_with_vm(model, [input_np], target, ctx)
 
         tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5)
 
 
[email protected]_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_constantofshape():
     verify_constantofshape((2, 3, 4, 5), 10, "float32")
     verify_constantofshape((3, 3), 0, "int32")
@@ -1550,7 +1644,7 @@ def verify_pad_v11(indata, pads, mode="constant", 
value=0.0):
             ],
         )
     else:
-        inputs = [indata, pads, np.array([value])]
+        inputs = [indata, pads, np.array([value]).astype("float32")]
         outdata = np.pad(indata, pad_width=np_pads, mode="constant", 
constant_values=value)
         node = helper.make_node(
             "Pad", inputs=["input", "pads", "constant_value"], 
outputs=["output"], mode="constant"
@@ -1561,7 +1655,7 @@ def verify_pad_v11(indata, pads, mode="constant", 
value=0.0):
             inputs=[
                 helper.make_tensor_value_info("input", TensorProto.FLOAT, 
list(indata.shape)),
                 helper.make_tensor_value_info("pads", TensorProto.INT64, 
(len(pads),)),
-                helper.make_tensor_value_info("constant_value", 
TensorProto.INT64, (1,)),
+                helper.make_tensor_value_info("constant_value", 
TensorProto.FLOAT, (1,)),
             ],
             initializer=[
                 helper.make_tensor("pads", TensorProto.INT64, (len(pads),), 
pads),
@@ -1574,11 +1668,12 @@ def verify_pad_v11(indata, pads, mode="constant", 
value=0.0):
     model = helper.make_model(graph, producer_name="pad_test")
     #  tvm result
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, inputs, target, ctx, outdata.shape, 
"float32", opset=11)
+        tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=11, 
freeze_params=False)
     tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
 
 
[email protected]_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_pad():
     verify_pad(np.random.randn(2, 2).astype(np.float32), [0, 1, 0, 0], 
"constant", 0.0)
     verify_pad(np.random.randn(2, 3).astype(np.float32), [1, 0, 0, 1], 
"constant", 0.0)
@@ -1660,20 +1755,28 @@ def test_all_reduce_funcs():
             )
 
 
-def verify_split(indata, outdatas, split, axis=0):
+def verify_split(indata, outdatas, split, axis=0, pass_split=True):
     indata = np.array(indata).astype(np.float32)
     outdatas = [np.array(o).astype(np.float32) for o in outdatas]
     if split:
         split_index = range(len(split))
     else:
         split_index = range(len(outdatas))
-    node = helper.make_node(
-        "Split",
-        inputs=["input"],
-        outputs=["output_{}".format(i) for i in range(len(split_index))],
-        axis=axis,
-        split=split,
-    )
+    if pass_split:
+        node = helper.make_node(
+            "Split",
+            inputs=["input"],
+            outputs=["output_{}".format(i) for i in range(len(split_index))],
+            axis=axis,
+            split=split,
+        )
+    else:
+        node = helper.make_node(
+            "Split",
+            inputs=["input"],
+            outputs=["output_{}".format(i) for i in range(len(split_index))],
+            axis=axis,
+        )
     graph = helper.make_graph(
         [node],
         "split_test",
@@ -1687,18 +1790,26 @@ def verify_split(indata, outdatas, split, axis=0):
     )
     model = helper.make_model(graph, producer_name="split_test")
 
+    import onnxruntime.backend
+
+    rep = onnxruntime.backend.prepare(model, "CPU")
+    onnx_out = rep.run(indata)
+
     for target, ctx in tvm.testing.enabled_targets():
         output_shape = [o.shape for o in outdatas]
         output_type = ["float32", "float32", "float32"]
         tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, 
output_type)
-    for o, t in zip(outdatas, tvm_out):
-        tvm.testing.assert_allclose(o, t)
+        for o, t in zip(onnx_out, tvm_out):
+            tvm.testing.assert_allclose(o, t)
 
 
 @tvm.testing.uses_gpu
 def test_split():
     # 1D
     verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0, 4.0], 
[5.0, 6.0]], [2, 2, 2], 0)
+    verify_split(
+        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 
[2, 2, 2], 0, False
+    )
     verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0], [4.0, 
5.0, 6.0]], [2, 1, 3], 0)
     # 2D
     verify_split(
@@ -1708,7 +1819,7 @@ def test_split():
         1,
     )
     # Split evenly (unstack)
-    verify_split([1, 2, 3], [[1], [2], [3]], False)
+    verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False)
 
 
 @tvm.testing.uses_gpu
@@ -2098,19 +2209,17 @@ def verify_tile_v6(indata, repeats, outdata):
             helper.make_tensor_value_info("repeats", TensorProto.INT64, 
list(repeats.shape)),
         ],
         outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(outdata.shape))],
-        initializer=[
-            helper.make_tensor("repeats", TensorProto.INT64, 
list(repeats.shape), repeats)
-        ],
     )
 
     model = helper.make_model(graph, producer_name="tile_test")
 
     for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape, 
opset=6)
+        tvm_out = get_tvm_output_with_vm(model, [indata, repeats], target, 
ctx, opset=6)
         tvm.testing.assert_allclose(outdata, tvm_out)
 
 
[email protected]_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_tile():
     x = np.random.rand(2, 3, 4, 5).astype(np.float32)
     repeats = np.random.randint(low=1, high=10, 
size=(np.ndim(x),)).astype(np.int64)
@@ -2283,9 +2392,11 @@ def test_batch_norm():
     verify_batch_norm([16, 16, 10, 10])
 
 
[email protected]_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_batch_norm_dynamic_subgraph():
     def verify_batch_norm_dynamic_subgraph(in_shape, o_shape):
+
         batchnorm = onnx.helper.make_node(
             "BatchNormalization", inputs=["x", "scale", "B", "mean", "var"], 
outputs=["Y"]
         )
@@ -2307,9 +2418,10 @@ def test_batch_norm_dynamic_subgraph():
         )
 
         model = helper.make_model(graph, producer_name="batchnorm_test")
+
         # X, inp, scale, b, mean, var
         inshapes = [in_shape, o_shape, in_shape[1], in_shape[1], in_shape[1], 
in_shape[1]]
-        verify_with_ort(model, inshapes, in_shape, use_vm=False)
+        verify_with_ort(model, inshapes, in_shape, use_vm=True)
 
     verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160])
 
@@ -2373,7 +2485,7 @@ def verify_conv(
 
     model = helper.make_model(graph, producer_name="conv_test")
 
-    verify_with_ort(model, [x_shape, w_shape], y_shape)
+    verify_with_ort(model, [x_shape, w_shape], y_shape, use_vm=True, 
convert_to_static=True)
 
 
 @tvm.testing.uses_gpu
@@ -2458,6 +2570,68 @@ def test_conv():
         )
 
 
+def verify_convtranspose_with_padding(
+    x_shape,
+    w_shape,
+    y_shape,
+    padding,
+    kernel_shape,
+    strides,
+    dilations,
+    auto_pad="NOTSET",
+    unset_pad=False,
+):
+    if unset_pad:
+        node = helper.make_node(
+            "ConvTranspose",
+            inputs=["x", "W"],
+            outputs=["y"],
+            kernel_shape=kernel_shape,
+            # Default values for other attributes:
+            strides=strides,
+            dilations=dilations,
+            group=1,
+        )
+    elif padding is None:
+        node = helper.make_node(
+            "ConvTranspose",
+            inputs=["x", "W"],
+            outputs=["y"],
+            kernel_shape=kernel_shape,
+            # Default values for other attributes:
+            strides=strides,
+            dilations=dilations,
+            group=1,
+            auto_pad=auto_pad,
+        )
+    else:
+        node = helper.make_node(
+            "ConvTranspose",
+            inputs=["x", "W"],
+            outputs=["y"],
+            kernel_shape=kernel_shape,
+            # Default values for other attributes:
+            strides=strides,
+            dilations=dilations,
+            group=1,
+            pads=padding,
+        )
+
+    graph = helper.make_graph(
+        [node],
+        "convtranspose_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, 
list(x_shape)),
+            helper.make_tensor_value_info("W", TensorProto.FLOAT, 
list(w_shape)),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, 
list(y_shape))],
+    )
+
+    model = helper.make_model(graph, producer_name="conv_test")
+
+    verify_with_ort(model, [x_shape, w_shape], y_shape, use_vm=True, 
convert_to_static=True)
+
+
 def verify_convtranspose(x_shape, w_shape, y_shape, p):
     node = onnx.helper.make_node(
         "ConvTranspose",
@@ -2492,6 +2666,88 @@ def test_convtranspose():
     # [1, 2, 1, 2] list for pads
     verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 
2])
 
+    def repeat(N, D):
+        return tuple([N for _ in range(D)])
+
+    # TODO(mbrookhart): onnxruntime in CI only supports 2D,
+    # find something else to test 1D and 3D against
+    for D in [2]:
+        # Convolution with padding
+        verify_convtranspose_with_padding(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(5, D),
+            2 * repeat(1, D),
+            repeat(3, D),
+            repeat(1, D),
+            repeat(1, D),
+        )
+        # Convolution without padding
+        verify_convtranspose_with_padding(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(7, D),
+            2 * repeat(0, D),
+            repeat(3, D),
+            repeat(1, D),
+            repeat(1, D),
+        )
+        # Convolution with autopadding
+        verify_convtranspose_with_padding(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(5, D),
+            None,
+            repeat(3, D),
+            repeat(1, D),
+            repeat(1, D),
+            auto_pad="SAME_UPPER",
+        )
+        # Convolution with valid autopadding
+        verify_convtranspose_with_padding(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(7, D),
+            None,
+            repeat(3, D),
+            repeat(1, D),
+            repeat(1, D),
+            auto_pad="VALID",
+        )
+        # Convolution with unset padding
+        verify_convtranspose_with_padding(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(7, D),
+            2 * repeat(0, D),
+            repeat(3, D),
+            repeat(1, D),
+            repeat(1, D),
+            True,
+        )
+        # Convolution with non uniform stride
+        verify_convtranspose_with_padding(
+            (1, 1) + repeat(5, D),
+            (1, 1) + repeat(3, D),
+            (1, 1) + repeat(9, D),
+            None,
+            repeat(3, D),
+            repeat(2, D),
+            repeat(1, D),
+            auto_pad="SAME_UPPER",
+        )
+        # Convolution with dilation
+        # TODO(mbrookhart): Relay doesn't currently support convtranspose with 
dilation
+        # verify_convtranspose_with_padding(
+        #     (1, 1) + repeat(5, D),
+        #     (1, 1) + repeat(3, D),
+        #     (1, 1) + repeat(5, D),
+        #     2 * repeat(2, D),
+        #     repeat(3, D),
+        #     repeat(1, D),
+        #     repeat(2, D),
+        # )
+
 
 @tvm.testing.uses_gpu
 def test_unsqueeze_constant():
@@ -2515,6 +2771,7 @@ def test_unsqueeze_constant():
 
 
 def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, 
auto_pad="NOTSET"):
+    print(x_shape, kernel_shape, strides, mode, pads, auto_pad)
     x_np = np.random.uniform(size=x_shape).astype("float32")
 
     if mode == "max":
@@ -2546,7 +2803,7 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, 
out_shape, mode, auto_p
     )
 
     model = helper.make_model(graph, producer_name="pooling_test")
-    verify_with_ort(model, [x_shape], out_shape)
+    verify_with_ort(model, [x_shape], out_shape, use_vm=True, 
convert_to_static=True)
 
 
 @tvm.testing.uses_gpu
@@ -2796,7 +3053,7 @@ def verify_lppool(x_shape, kernel_shape, p, strides, 
pads, out_shape, auto_pad="
     )
 
     model = helper.make_model(graph, producer_name="lppool_test")
-    verify_with_ort(model, [x_shape], out_shape)
+    verify_with_ort(model, [x_shape], out_shape, use_vm=True, 
convert_to_static=True)
 
 
 @tvm.testing.uses_gpu
@@ -3169,7 +3426,8 @@ def test_gru():
     )
 
 
[email protected]_gpu
+# TODO(mbrookhart): enable once VM supports heterogenous execution
+# @tvm.testing.uses_gpu
 def test_resize():
     def verify(ishape, oshape, scales, mode, coord_trans):
         nodes = [
@@ -3194,7 +3452,6 @@ def test_resize():
 
         if oshape == []:
             oshape = [round(dim * scale) for (dim, scale) in zip(ishape, 
scales)]
-
         graph = helper.make_graph(
             nodes,
             "resize_test",
@@ -3204,7 +3461,7 @@ def test_resize():
 
         model = helper.make_model(graph, producer_name="resize_test")
 
-        verify_with_ort(model, [ishape], oshape, use_vm=False, opset=11)
+        verify_with_ort(model, [ishape], oshape, use_vm=True, opset=11, 
freeze_params=True)
 
     # upsampling
     verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric")
@@ -3273,7 +3530,6 @@ def test_topk():
                     ],
                 ),
             ],
-            initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])],
             outputs=[
                 helper.make_tensor_value_info("Values", TensorProto.FLOAT, 
output_dims),
                 helper.make_tensor_value_info("Indicies", TensorProto.INT64, 
output_dims),
@@ -3283,17 +3539,10 @@ def test_topk():
         model = helper.make_model(graph, producer_name="topk_test")
 
         indata = np.random.uniform(-10, 10, input_dims).astype(np.float32)
-        onnx_out = get_onnxruntime_output(model, [indata, k])
+        onnx_out = get_onnxruntime_output(model, [indata, np.array([K])])
 
         for target, ctx in [("llvm", tvm.cpu())]:
-            tvm_out = get_tvm_output(
-                model,
-                indata,
-                target,
-                ctx,
-                [output_dims, output_dims],
-                output_dtype=["float32", "int64"],
-            )
+            tvm_out = get_tvm_output_with_vm(model, [indata, np.array(K)], 
target, ctx)
             tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, 
atol=1e-05)
 
     for n in [12, 32]:
diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py 
b/tests/python/relay/dyn/test_dynamic_op_level10.py
index 622e291..18e1dd5 100644
--- a/tests/python/relay/dyn/test_dynamic_op_level10.py
+++ b/tests/python/relay/dyn/test_dynamic_op_level10.py
@@ -27,34 +27,62 @@ import tvm.topi.testing
 import random
 import tvm.testing
 
-# TODO(mbrookhart): Enable when VM supports heterogenus execution
+# TODO(mbrookhart): Enable when the VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
-def test_dyn_broadcast_to():
-    dtype = "uint8"
-    rank = 3
-    shape_type = "int64"
-    dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), shape_type))
-    x_shape = (1,)
-    x = relay.Var("x", relay.ty.TensorType(x_shape, dtype))
-    z = relay.broadcast_to(x, dyn_shape)
-    zz = run_infer_type(z)
-
-    assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, dtype)
-
-    func = relay.Function([x, dyn_shape], z)
-
-    x = np.random.uniform(size=x_shape).astype(dtype)
-    dyn_shape = (1,) * rank
-    ref_res = np.broadcast_to(x, dyn_shape)
-    for target, ctx in tvm.testing.enabled_targets():
-        for kind in ["vm", "debug"]:
-            mod = tvm.ir.IRModule.from_expr(func)
-            intrp = relay.create_executor(kind, mod=mod, ctx=ctx, 
target=target)
-            op_res = intrp.evaluate(func)(x, 
np.array(dyn_shape).astype(shape_type))
-            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
-
-
-# TODO(mbrookhart): Enable when VM supports heterogenus execution
+def test_broadcast_to():
+    def verify_more_dynamic_broadcast_to(x_shape, out_shape):
+        rank = len(out_shape)
+        dtype = "float32"
+        shape_type = "int64"
+        reshape_shape = relay.Var("shape", 
relay.ty.TensorType((len(x_shape),), shape_type))
+        broadcast_shape = relay.Var("shape", relay.ty.TensorType((rank,), 
shape_type))
+        x = relay.Var("x", relay.ty.TensorType((np.prod(x_shape),), dtype))
+        r = relay.reshape(x, reshape_shape)
+        z = relay.broadcast_to(r, broadcast_shape)
+
+        func = relay.Function([x, reshape_shape, broadcast_shape], z)
+
+        x = np.random.uniform(size=np.prod(x_shape)).astype(dtype)
+        ref_res = np.broadcast_to(np.reshape(x, x_shape), out_shape)
+        for target, ctx in tvm.testing.enabled_targets():
+            for kind in ["vm", "debug"]:
+                mod = tvm.ir.IRModule.from_expr(func)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, 
target=target)
+                op_res = intrp.evaluate(func)(
+                    x, np.array(x_shape).astype(shape_type), 
np.array(out_shape).astype(shape_type)
+                )
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, 
rtol=1e-5)
+
+    verify_more_dynamic_broadcast_to((4, 3), (3, 4, 3))
+
+    def verify_broadcast_to(x_shape, out_shape):
+        rank = len(out_shape)
+        dtype = "float32"
+        shape_type = "int64"
+        dyn_shape = relay.Var("shape", relay.ty.TensorType((rank,), 
shape_type))
+        x = relay.Var("x", relay.ty.TensorType(x_shape, dtype))
+        z = relay.broadcast_to(x, dyn_shape)
+        zz = run_infer_type(z)
+
+        assert zz.checked_type == relay.ty.TensorType((relay.Any(),) * rank, 
dtype)
+
+        func = relay.Function([x, dyn_shape], z)
+
+        x = np.random.uniform(size=x_shape).astype(dtype)
+        ref_res = np.broadcast_to(x, out_shape)
+        for target, ctx in tvm.testing.enabled_targets():
+            for kind in ["vm", "debug"]:
+                mod = tvm.ir.IRModule.from_expr(func)
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, 
target=target)
+                op_res = intrp.evaluate(func)(x, 
np.array(out_shape).astype(shape_type))
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, 
rtol=1e-5)
+
+    verify_broadcast_to((1,), (1, 1, 1))
+    verify_broadcast_to((1, 1), (4, 1, 1))
+    verify_broadcast_to((4, 1), (1, 4, 3))
+
+
+# TODO(mbrookhart): Enable when the VM supports heterogeneous execution
 # @tvm.testing.uses_gpu
 def test_dyn_one_hot():
     def _get_oshape(indices_shape, depth, axis):
diff --git a/tests/python/relay/test_op_level10.py 
b/tests/python/relay/test_op_level10.py
index edb7c46..bc56568 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -362,6 +362,33 @@ def test_batch_matmul():
     verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
 
 
+def verify_dynamic_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
+    x = relay.var("x", relay.TensorType(x_shape, dtype))
+    y = relay.var("y", relay.TensorType((relay.Any(),) * len(y_shape), dtype))
+    z = relay.nn.batch_matmul(x, y)
+
+    func = relay.Function([x, y], z)
+    x_np = np.random.uniform(size=x_shape).astype(dtype)
+    y_np = np.random.uniform(size=y_shape).astype(dtype)
+    z_np = tvm.topi.testing.batch_matmul(x_np, y_np)
+
+    for target, ctx in tvm.testing.enabled_targets():
+        for kind in ["vm", "debug"]:
+            mod = tvm.ir.IRModule.from_expr(func)
+            intrp = relay.create_executor(kind, mod=mod, ctx=ctx, 
target=target)
+            z = intrp.evaluate()(x_np, y_np)
+            tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5)
+
+
+# TODO(mbrookhart): enable once VM supports heterogeneous execution
+# @tvm.testing.uses_gpu
+def test_dynamic_batch_matmul():
+    verify_dynamic_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16))
+    verify_dynamic_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16))
+    verify_dynamic_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
+    verify_dynamic_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
+
+
 @tvm.testing.uses_gpu
 def test_shape_of():
     shape = (10, 5, 12)
diff --git a/tutorials/frontend/from_onnx.py b/tutorials/frontend/from_onnx.py
index e68a398..22c839c 100644
--- a/tutorials/frontend/from_onnx.py
+++ b/tutorials/frontend/from_onnx.py
@@ -103,3 +103,12 @@ canvas[0:224, 0:224, :] = np.asarray(img)
 canvas[:, 672:, :] = np.asarray(result)
 plt.imshow(canvas.astype(np.uint8))
 plt.show()
+
+######################################################################
+# Notes
+# ---------------------------------------------
+# By default, ONNX defines models in terms of dynamic shapes. The ONNX importer
+# retains that dynamism upon import, and the compiler attempts to convert the model
+# into static shapes at compile time. If this fails, there may still be dynamic
+# operations in the model. Not all TVM kernels currently support dynamic shapes;
+# please file an issue on discuss.tvm.ai if you hit an error with dynamic kernels.

Reply via email to