This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f91b51d  [Relay][Frontend][Onnx] Compare against onnxruntime more 
consistently during testing (#7300)
f91b51d is described below

commit f91b51d638874973a2d9ccbcb4d49cf7c668f516
Author: Josh Fromm <[email protected]>
AuthorDate: Tue Jan 19 06:32:42 2021 -0800

    [Relay][Frontend][Onnx] Compare against onnxruntime more consistently 
during testing (#7300)
    
    
    
    Co-authored-by: Josh Fromm <[email protected]>
---
 python/tvm/relay/frontend/onnx.py          |  73 +--
 tests/python/frontend/onnx/test_forward.py | 801 ++++++++++-------------------
 2 files changed, 302 insertions(+), 572 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py 
b/python/tvm/relay/frontend/onnx.py
index 9405bc5..7a3b168 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name, import-self, len-as-condition, 
unused-argument, too-many-lines
 # pylint: disable=import-outside-toplevel
 """ONNX: Open Neural Network Exchange frontend for Relay."""
+import copy
 import warnings
 import numpy as np
 import tvm
@@ -1028,10 +1029,6 @@ class Upsample(OnnxOpConverter):
                 'Value {} in attribute "mode" of operator Upsample is not 
valid.'.format(mode)
             )
 
-        if method == "nearest_neighbor":
-            align_corners = False
-        else:
-            align_corners = True
         # in 3d case, we use the purely static op
         if dims == 5:
             if isinstance(scales, _expr.Call):
@@ -1065,7 +1062,7 @@ class Upsample(OnnxOpConverter):
                 scale_w,
                 layout=layout,
                 method=method,
-                align_corners=align_corners,
+                align_corners=False,
             )
         return out
 
@@ -1111,17 +1108,22 @@ class Split(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        splits = attr.get("split", False)
-        if splits:
+        splits = attr.get("split", None)
+        if splits is not None:
+            indices = []
             attr["indices_or_sections"] = []
             index = 0
             for i in splits[:-1]:
                 index += i
-                attr["indices_or_sections"].append(index)
+                indices.append(index)
+        # When splits isn't specified, divide evenly over axis.
         else:
-            attr["indices_or_sections"] = attr["tvm_custom"]["num_outputs"]
-        return AttrCvt("split", ignores=["split"])(inputs, attr, params)
+            indices = attr["tvm_custom"]["num_outputs"]
+        output = _op.split(inputs[0], indices, attr.get("axis", 0))
+        # If the output of split is a single value, unpack it from the 
TupleWrapper
+        if len(output) == 1:
+            output = output[0]
+        return output
 
 
 class Slice(OnnxOpConverter):
@@ -1227,7 +1229,9 @@ class GatherND(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        return _op.gather_nd(inputs[0], inputs[1])
+        indices_dims = len(infer_shape(inputs[1]))
+        indices = _op.transpose(inputs[1], axes=[-1] + list(range(indices_dims 
- 1)))
+        return _op.gather_nd(inputs[0], indices)
 
 
 class Scatter(OnnxOpConverter):
@@ -1539,15 +1543,6 @@ class Tile(Elemwise):
     """Operator converter for Tile"""
 
     @classmethod
-    def _impl_v1(cls, inputs, attr, params):
-        if "repeats" not in attr:
-            raise tvm.error.OpAttributeInvalid(
-                'Attribute "repeats" should be set ' "for operator Tile."
-            )
-        reps = attr.pop("repeats")  # The number of times repeating the tensor 
data.
-        return _op.tile(inputs[0], reps)
-
-    @classmethod
     def _impl_v6(cls, inputs, attr, params):
         return _op.tile(inputs[0], inputs[1])
 
@@ -2113,7 +2108,9 @@ class Loop(OnnxOpConverter):
         cond = inputs[1]
         loop_deps = inputs[2:]
         num_deps = len(loop_deps)
-        body = attr["body"]
+        # Create a copy of the body function to prevent the original
+        # from being modified.
+        body = copy.copy(attr["body"])
         iter_dtype = infer_type(max_loop_count).checked_type.dtype
 
         # Determine what condition mode we're in.
@@ -2150,6 +2147,8 @@ class Loop(OnnxOpConverter):
             checked_type = infer_type(val)
             if hasattr(checked_type, "type_annotation"):
                 checked_type = checked_type.type_annotation
+            if hasattr(checked_type, "checked_type"):
+                checked_type = checked_type.checked_type
             shape = get_const_tuple(checked_type.shape)
             actual_shape = []
             for dim in shape:
@@ -2185,8 +2184,14 @@ class Loop(OnnxOpConverter):
         scan_output_init = []
         for i in range(num_scan_outputs):
             name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps])
-            scan_output_vars.append(_expr.var(name, shape=([_ty.Any()] + 
shape), dtype=dtype))
-            scan_output_init.append(_op.reshape(_expr.const([]), [0] + shape))
+            if dtype == "float":
+                dtype = "float32"
+            scan_output_vars.append(
+                _expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), 
dtype=dtype)
+            )
+            scan_output_init.append(
+                _op.reshape(_expr.const(np.array([]).astype(dtype)), [0] + [1] 
* len(shape))
+            )
 
         # Now we can remove loop iter variables from our inner loop's inputs.
         # This is kind of a hack since we have graph inputs that we don't
@@ -2219,11 +2224,6 @@ class Loop(OnnxOpConverter):
             new_loop_vars = [loop_outputs[i] for i in range(1, 1 + num_deps)]
             new_scan_outputs = [loop_outputs[i] for i in range(1 + num_deps, 
len(loop_outputs))]
 
-            # Increment counter.
-            if max_loop_count is not None:
-                incr = _expr.const(1, dtype=iter_dtype)
-                loop_count = loop_count + incr
-
             # Add new scan outputs to tracking
             combined_scan_outputs = []
             for i, scan in enumerate(scan_outputs):
@@ -2231,6 +2231,11 @@ class Loop(OnnxOpConverter):
                 combined_scan = _op.concatenate([scan, new_scan], axis=0)
                 combined_scan_outputs.append(combined_scan)
 
+            # Increment counter.
+            if max_loop_count is not None:
+                incr = _expr.const(1, dtype=iter_dtype)
+                loop_count = loop_count + incr
+
             # Pack loop outputs for next iteration
             # [iter_count, cond, loop_deps, loop_scans]
             return [loop_count, max_count, new_cond] + new_loop_vars + 
combined_scan_outputs
@@ -2630,12 +2635,12 @@ def _get_convert_map(opset):
         "Greater": Greater.get_converter(opset),
         "Less": Less.get_converter(opset),
         "Log": Renamer("log"),
-        "ACos": Renamer("acos"),
-        "ACosh": Renamer("acosh"),
-        "ASin": Renamer("asin"),
-        "ASinh": Renamer("asinh"),
-        "ATan": Renamer("atan"),
-        "ATanh": Renamer("atanh"),
+        "Acos": Renamer("acos"),
+        "Acosh": Renamer("acosh"),
+        "Asin": Renamer("asin"),
+        "Asinh": Renamer("asinh"),
+        "Atan": Renamer("atan"),
+        "Atanh": Renamer("atanh"),
         "Cos": Renamer("cos"),
         "Cosh": Renamer("cosh"),
         "Sin": Renamer("sin"),
diff --git a/tests/python/frontend/onnx/test_forward.py 
b/tests/python/frontend/onnx/test_forward.py
index 96be6fb..20937d2 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
-import math
 import onnx
 from onnx import helper, TensorProto, mapping, numpy_helper
 import torch
@@ -94,7 +93,7 @@ def get_tvm_output(
     # execute
     m.run()
     # get outputs
-    if isinstance(output_shape, list) and isinstance(output_dtype, list):
+    if isinstance(output_shape, list):
         tvm_output_list = []
         for i, _ in enumerate(output_shape):
             tvm_output = m.get_output(i)
@@ -105,17 +104,19 @@ def get_tvm_output(
         return tvm_output.asnumpy()
 
 
-def get_onnxruntime_output(model, inputs, dtype="float32"):
+def get_onnxruntime_output(model, inputs):
     import onnxruntime.backend
 
     rep = onnxruntime.backend.prepare(model, "CPU")
-    if isinstance(inputs, list) and len(inputs) > 1:
-        return rep.run(inputs)
-    elif isinstance(inputs, list) and len(inputs) == 1:
+    if isinstance(inputs, list) and len(inputs) == 1:
         inp = inputs[0]
     else:
         inp = inputs
-    return rep.run(inp.astype(dtype))[0]
+    output = rep.run(inp)
+    # Unpack output if there's only a single value.
+    if len(output) == 1:
+        output = output[0]
+    return output
 
 
 def verify_with_ort_with_inputs(
@@ -130,15 +131,11 @@ def verify_with_ort_with_inputs(
     dtype="float32",
     rtol=1e-5,
     atol=1e-5,
+    apply_softmax=False,
 ):
-    def flatten(out):
-        if isinstance(out, list) and len(out) == 1:
-            out = out[0]
-        if isinstance(out, np.ndarray):
-            return out.flatten()
-        return out
-
-    ort_out = get_onnxruntime_output(model, inputs, dtype)
+    if opset is not None:
+        model.opset_import[0].version = opset
+    ort_out = get_onnxruntime_output(model, inputs)
 
     if targets is None:
         targets = [tgt for (tgt, _) in tvm.testing.enabled_targets()]
@@ -157,8 +154,15 @@ def verify_with_ort_with_inputs(
             )
         else:
             tvm_out = get_tvm_output(model, inputs, target, ctx, out_shape, 
dtype, opset=opset)
-
-        tvm.testing.assert_allclose(flatten(ort_out), flatten(tvm_out), 
rtol=rtol, atol=atol)
+        if not isinstance(tvm_out, list):
+            tvm_out = [tvm_out]
+        if not isinstance(ort_out, list):
+            ort_out = [ort_out]
+        for tvm_val, ort_val in zip(tvm_out, ort_out):
+            if apply_softmax:
+                ort_val = scipy.special.softmax(ort_val)
+                tvm_val = scipy.special.softmax(tvm_val)
+            tvm.testing.assert_allclose(ort_val, tvm_val, rtol=rtol, atol=atol)
 
 
 def verify_with_ort(
@@ -342,7 +346,7 @@ def verify_depth_to_space(inshape, outshape, mode, 
blockSize):
 
     model = helper.make_model(graph, producer_name="depth_to_space_test")
 
-    verify_with_ort(model, [inshape], outshape)
+    verify_with_ort(model, [inshape], [outshape])
 
 
 @tvm.testing.uses_gpu
@@ -365,7 +369,7 @@ def verify_space_to_depth(inshape, outshape, blockSize):
 
     model = helper.make_model(graph, producer_name="space_to_depth_test")
 
-    verify_with_ort(model, [inshape], outshape)
+    verify_with_ort(model, [inshape], [outshape])
 
 
 @tvm.testing.uses_gpu
@@ -494,11 +498,8 @@ def test_squeeze():
     )
 
     model = helper.make_model(graph, producer_name="squeeze_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=in_shape).astype("float32")
-        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32")
-        tvm.testing.assert_allclose(out_shape, tvm_out.shape)
+    x = np.random.uniform(size=in_shape).astype("float32")
+    verify_with_ort_with_inputs(model, [x], [out_shape])
 
 
 @tvm.testing.uses_gpu
@@ -518,11 +519,7 @@ def test_flatten():
     )
 
     model = helper.make_model(graph, producer_name="flatten_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=in_shape).astype("int32")
-        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, "float32")
-        tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
+    verify_with_ort(model, [in_shape])
 
 
 @tvm.testing.uses_gpu
@@ -540,16 +537,12 @@ def test_unsqueeze():
     )
 
     model = helper.make_model(graph, producer_name="squeeze_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        x = np.random.uniform(size=in_shape).astype("float32")
-        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, "float32")
-        tvm.testing.assert_allclose(out_shape, tvm_out.shape)
+    verify_with_ort(model, [in_shape])
 
 
 def verify_gather(in_shape, indices, axis, dtype):
     x = np.random.uniform(size=in_shape).astype(dtype)
-    indices = np.array(indices, dtype="int32")
+    indices = np.array(indices, dtype="int64")
     out_np = np.take(x, indices, axis=axis)
 
     y = helper.make_node("Gather", ["in", "indices"], ["out"], axis=axis)
@@ -558,16 +551,19 @@ def verify_gather(in_shape, indices, axis, dtype):
         [y],
         "gather_test",
         inputs=[
-            helper.make_tensor_value_info("in", TensorProto.FLOAT, 
list(in_shape)),
-            helper.make_tensor_value_info("indices", TensorProto.INT32, 
list(indices.shape)),
+            helper.make_tensor_value_info(
+                "in", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], 
list(in_shape)
+            ),
+            helper.make_tensor_value_info("indices", TensorProto.INT64, 
list(indices.shape)),
+        ],
+        outputs=[
+            helper.make_tensor_value_info(
+                "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], 
list(out_np.shape)
+            )
         ],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(out_np.shape))],
     )
     model = helper.make_model(graph, producer_name="gather_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x, indices], target, ctx, 
out_np.shape)
-        tvm.testing.assert_allclose(out_np, tvm_out)
+    verify_with_ort_with_inputs(model, [x, indices], dtype=dtype)
 
 
 @tvm.testing.uses_gpu
@@ -660,10 +656,7 @@ def _test_slice_iteration_v1(indata, outdata, starts, 
ends, axes=None):
     )
 
     model = helper.make_model(graph, producer_name="slice_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 
"float32", opset=1)
-        tvm.testing.assert_allclose(outdata, tvm_out)
+    verify_with_ort_with_inputs(model, [indata], [outdata.shape], opset=1)
 
 
 def _test_slice_iteration_v10(indata, outdata, **attrs):
@@ -738,14 +731,14 @@ def _test_slice_iteration_v10(indata, outdata, **attrs):
 
     if axes:
         axes = np.asarray(axes)
-        inputs.append(helper.make_tensor_value_info("axes", TensorProto.INT32, 
list(axes.shape)))
-        initializer.append(helper.make_tensor("axes", TensorProto.INT32, 
list(axes.shape), axes))
+        inputs.append(helper.make_tensor_value_info("axes", TensorProto.INT64, 
list(axes.shape)))
+        initializer.append(helper.make_tensor("axes", TensorProto.INT64, 
list(axes.shape), axes))
 
     if steps:
         assert axes is not None and len(axes) == len(steps)
         steps = np.asarray(steps)
-        inputs.append(helper.make_tensor_value_info("steps", 
TensorProto.INT32, list(axes.shape)))
-        initializer.append(helper.make_tensor("steps", TensorProto.INT32, 
list(steps.shape), steps))
+        inputs.append(helper.make_tensor_value_info("steps", 
TensorProto.INT64, list(axes.shape)))
+        initializer.append(helper.make_tensor("steps", TensorProto.INT64, 
list(steps.shape), steps))
 
     y = helper.make_node("Slice", ["data", *slice_inputs], ["out"])
 
@@ -758,10 +751,7 @@ def _test_slice_iteration_v10(indata, outdata, **attrs):
         initializer=initializer,
     )
     model = helper.make_model(graph, producer_name="slice_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=10, 
freeze_params=True)
-        tvm.testing.assert_allclose(outdata, tvm_out)
+    verify_with_ort_with_inputs(model, [indata], opset=10, freeze_params=True, 
use_vm=True)
 
 
 # TODO(mbrookhart): enable once VM supports heterogenous execution
@@ -854,10 +844,7 @@ def _test_onnx_op_elementwise(inshape, outfunc, npargs, 
dtype, opname, kwargs, o
     )
 
     model = helper.make_model(graph, producer_name=opname + "_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 
dtype, opset=opset)
-        tvm.testing.assert_allclose(outdata, tvm_out)
+    verify_with_ort_with_inputs(model, [indata], [outdata.shape], opset=opset, 
dtype=dtype)
 
 
 @tvm.testing.uses_gpu
@@ -879,6 +866,7 @@ def test_clip():
         "float32",
         "Clip",
         {"min": -1.0, "max": 1.0},
+        opset=6,
     )
 
     _test_onnx_op_elementwise(
@@ -888,7 +876,7 @@ def test_clip():
         "float32",
         "Clip",
         {"max": 1.0},
-        opset=1,
+        opset=6,
     )
 
     _test_onnx_op_elementwise(
@@ -898,7 +886,7 @@ def test_clip():
         "float32",
         "Clip",
         {"min": -1.0},
-        opset=1,
+        opset=6,
     )
 
 
@@ -919,7 +907,7 @@ def test_clip_min_max_as_inputs():
     )
     model = helper.make_model(graph, producer_name="clip_test")
 
-    verify_with_ort(model, [input_shape], input_shape)
+    verify_with_ort(model, [input_shape], out_shape=[input_shape])
 
 
 @tvm.testing.uses_gpu
@@ -941,10 +929,7 @@ def _test_finite_ops(inshape, outfunc, npargs, dtype, 
opname, kwargs):
     )
 
     model = helper.make_model(graph, producer_name=opname + "_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 
dtype)
-        tvm.testing.assert_allclose(outdata, tvm_out)
+    verify_with_ort_with_inputs(model, [indata], [outdata.shape], dtype=dtype)
 
 
 @tvm.testing.uses_gpu
@@ -957,10 +942,9 @@ def test_isnan():
     _test_finite_ops((2, 4, 5, 6), np.isnan, {}, "float32", "IsNaN", {})
 
 
-def verify_gather_nd(in_shape, indices, dtype):
+def verify_gather_nd(in_shape, indices, out_shape, dtype="float32"):
     x = np.random.uniform(size=in_shape).astype(dtype)
-    indices = np.array(indices, dtype="int32")
-    out_np = tvm.topi.testing.gather_nd_python(x, indices)
+    indices = np.array(indices, dtype="int64")
 
     y = helper.make_node("GatherND", ["in", "indices"], ["out"])
 
@@ -968,23 +952,27 @@ def verify_gather_nd(in_shape, indices, dtype):
         [y],
         "gather_test",
         inputs=[
-            helper.make_tensor_value_info("in", TensorProto.FLOAT, 
list(in_shape)),
-            helper.make_tensor_value_info("indices", TensorProto.INT32, 
list(indices.shape)),
+            helper.make_tensor_value_info(
+                "in", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], 
list(in_shape)
+            ),
+            helper.make_tensor_value_info("indices", TensorProto.INT64, 
list(indices.shape)),
+        ],
+        outputs=[
+            helper.make_tensor_value_info(
+                "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], 
list(out_shape)
+            )
         ],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(out_np.shape))],
     )
     model = helper.make_model(graph, producer_name="gather_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x, indices], target, ctx, 
out_np.shape)
-        tvm.testing.assert_allclose(out_np, tvm_out)
+    verify_with_ort_with_inputs(model, [x, indices], [out_shape])
 
 
 @tvm.testing.uses_gpu
 def test_gather_nd():
-    verify_gather_nd((2, 2), [[0, 0], [1, 1]], "int32")
-    verify_gather_nd((3, 3, 3), [[0, 1], [1, 0]], "float32")
-    verify_gather_nd((4, 3, 5, 6), [[2, 1, 0, 0]], "float32")
+    verify_gather_nd([2, 2], [[0, 0], [1, 1]], [2], "int32")
+    verify_gather_nd([2, 2], [[1], [0]], [2, 2])
+    verify_gather_nd([2, 2, 2], [[0, 1], [1, 0]], [2, 2])
+    verify_gather_nd([2, 2, 2], [[[0, 1]], [[1, 0]]], [2, 1, 2])
 
 
 # TODO(mbrookhart): enable once VM supports heterogenous execution
@@ -1011,6 +999,7 @@ def test_onehot():
 
     model = helper.make_model(graph, producer_name="onehot_test")
 
+    # TODO(jwfromm): Replace test against np with test against onnxrt once we 
update versions.
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output_with_vm(
             model, [indices_array, np.array([depth]).astype("int32"), values], 
target, ctx
@@ -1022,10 +1011,10 @@ def test_onehot():
 def test_matmul():
     a_shape = (4, 3)
     b_shape = (3, 4)
+    out_shape = [a_shape[0], b_shape[1]]
 
     a_array = np.random.uniform(size=a_shape).astype("float32")
     b_array = np.random.uniform(size=b_shape).astype("float32")
-    out_np = np.matmul(a_array, b_array)
 
     mul_node = helper.make_node("MatMul", ["a", "b"], ["out"])
 
@@ -1036,14 +1025,11 @@ def test_matmul():
             helper.make_tensor_value_info("a", TensorProto.FLOAT, 
list(a_shape)),
             helper.make_tensor_value_info("b", TensorProto.FLOAT, 
list(b_shape)),
         ],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(out_np.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(out_shape))],
     )
 
     model = helper.make_model(graph, producer_name="matmul_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [a_array, b_array], target, ctx, 
out_np.shape)
-        tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [a_array, b_array])
 
 
 def verify_batch_matmul(a_shape, b_shape, out_shape, target, ctx):
@@ -1063,10 +1049,7 @@ def verify_batch_matmul(a_shape, b_shape, out_shape, 
target, ctx):
     )
 
     model = helper.make_model(graph, producer_name="matmul_test")
-    onnx_out = get_onnxruntime_output(model, [a_array, b_array], "float32")[0]
-
-    tvm_out = get_tvm_output_with_vm(model, [a_array, b_array], target, ctx)
-    tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [a_array, b_array], use_vm=True, 
targets=[target])
 
 
 # TODO(mbrookhart): enable cuda once VM supports heterogenous execution
@@ -1152,29 +1135,7 @@ def verify_lrn(shape, nsize, dtype, alpha=None, 
beta=None, bias=None):
         outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(shape))],
     )
     model = helper.make_model(graph, producer_name="lrn_test")
-
-    def _get_python_lrn():
-        square_sum = np.zeros(shape).astype(dtype)
-        for n, c, h, w in np.ndindex(in_array.shape):
-            square_sum[n, c, h, w] = sum(
-                in_array[
-                    n,
-                    max(0, c - int(math.floor((nsize - 1) / 2))) : min(
-                        5, c + int(math.ceil((nsize - 1) / 2)) + 1
-                    ),
-                    h,
-                    w,
-                ]
-                ** 2
-            )
-        py_out = in_array / ((bias + (alpha / nsize) * square_sum) ** beta)
-        return py_out
-
-    for target, ctx in tvm.testing.enabled_targets():
-        input_name = model.graph.input[0].name
-        py_out = _get_python_lrn()
-        tvm_out = get_tvm_output(model, in_array, target, ctx, py_out.shape, 
"float32")
-        tvm.testing.assert_allclose(py_out, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [in_array])
 
 
 @tvm.testing.uses_gpu
@@ -1184,21 +1145,10 @@ def test_lrn():
 
 
 def verify_instance_norm(shape, axis=1):
-    def _get_python_instance_norm(x, gamma, beta, epsilon=1e-5):
-        dims_x = len(x.shape)
-        axis = tuple(range(2, dims_x))
-        mean = np.mean(x, axis=axis, keepdims=True)
-        var = np.var(x, axis=axis, keepdims=True)
-        dim_ones = (1,) * (dims_x - 2)
-        gamma = gamma.reshape(-1, *dim_ones)
-        beta = beta.reshape(-1, *dim_ones)
-        return gamma * (x - mean) / np.sqrt(var + epsilon) + beta
-
     x = np.random.randn(*shape).astype(np.float32)
     gamma = np.random.randn(shape[1]).astype(np.float32)
     beta = np.random.randn(shape[1]).astype(np.float32)
     epsilon = 1e-5
-    y = _get_python_instance_norm(x, gamma, beta, epsilon).astype(np.float32)
 
     node = onnx.helper.make_node(
         "InstanceNormalization",
@@ -1217,9 +1167,7 @@ def verify_instance_norm(shape, axis=1):
         outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, 
list(shape))],
     )
     model = helper.make_model(graph, producer_name="instance_norm_test")
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x, gamma, beta], target, ctx, shape, 
"float32")
-        tvm.testing.assert_allclose(y, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [x, gamma, beta], out_shape=[shape])
 
 
 @tvm.testing.uses_gpu
@@ -1230,14 +1178,13 @@ def test_instance_norm():
     verify_instance_norm((8, 7, 6, 5, 4))
 
 
-def _test_upsample_nearest():
+def verify_upsample_nearest():
     scale = 2
     in_shape = (1, 1, 3, 3)
     out_shape = (1, 1, 3 * scale, 3 * scale)
     y = helper.make_node("Upsample", ["in"], ["out"], mode="nearest", 
scales=[1.0, 1.0, 2.0, 2.0])
 
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = tvm.topi.testing.upsampling_python(in_array, (scale, scale), 
"NCHW")
 
     graph = helper.make_graph(
         [y],
@@ -1247,13 +1194,10 @@ def _test_upsample_nearest():
     )
 
     model = helper.make_model(graph, producer_name="upsample_nearest_test")
+    verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7)
 
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 
"float32")
-        tvm.testing.assert_allclose(out_array, tvm_out)
 
-
-def _test_upsample3d_nearest():
+def verify_upsample3d_nearest():
     scale = 2
     in_shape = (1, 1, 3, 3, 3)
     out_shape = (1, 1, 3 * scale, 3 * scale, 3 * scale)
@@ -1262,7 +1206,6 @@ def _test_upsample3d_nearest():
     )
 
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = tvm.topi.testing.upsampling3d_python(in_array, (scale, scale, 
scale), "NCDHW")
 
     graph = helper.make_graph(
         [y],
@@ -1272,20 +1215,17 @@ def _test_upsample3d_nearest():
     )
 
     model = helper.make_model(graph, producer_name="upsample_nearest_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 
"float32")
-        tvm.testing.assert_allclose(out_array, tvm_out)
+    # Upsample is deprecated after opset 9
+    verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7)
 
 
-def _test_upsample_bilinear():
+def verify_upsample_bilinear():
     scale = 2
     in_shape = (1, 1, 3, 3)
     out_shape = (1, 1, 3 * scale, 3 * scale)
     y = helper.make_node("Upsample", ["in"], ["out"], mode="linear", 
scales=[1.0, 1.0, 2.0, 2.0])
 
     in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = tvm.topi.testing.bilinear_resize_python(in_array, (3 * scale, 
3 * scale), "NCHW")
 
     graph = helper.make_graph(
         [y],
@@ -1295,51 +1235,10 @@ def _test_upsample_bilinear():
     )
 
     model = helper.make_model(graph, producer_name="upsample_bilinear_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 
"float32")
-        tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [in_array], [out_shape], opset=7)
 
 
-def _test_upsample_bilinear_opset9():
-    scale = 2
-    in_shape = (1, 1, 3, 3)
-    out_shape = (1, 1, 3 * scale, 3 * scale)
-    y = helper.make_node("Upsample", ["in", "scales"], ["out"], mode="linear")
-    scales = [1, 1, 2, 2]
-    in_array = np.random.uniform(size=in_shape).astype(np.float32)
-    out_array = tvm.topi.testing.bilinear_resize_python(in_array, (3 * scale, 
3 * scale), "NCHW")
-
-    ref_node = helper.make_node(
-        "Constant",
-        inputs=[],
-        outputs=["const"],
-        value=onnx.helper.make_tensor(
-            name="const_tensor",
-            data_type=TensorProto.FLOAT,
-            dims=scales,
-            vals=np.random.random(scales).flatten().astype(float),
-        ),
-    )
-
-    shape_node = helper.make_node("Shape", ["const"], ["scales"])
-
-    graph = helper.make_graph(
-        [ref_node, shape_node, y],
-        "upsample_bilinear_opset9_test",
-        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, 
list(in_shape))],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(out_shape))],
-    )
-
-    model = helper.make_model(graph, 
producer_name="upsample_bilinear_opset9_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output_with_vm(
-            model, [in_array], target, ctx, opset=9, freeze_params=True
-        )
-
-
-def _test_upsample3d_trilinear():
+def verify_upsample3d_trilinear():
     scale = 2
     in_shape = (1, 1, 3, 3, 3)
     out_shape = (1, 1, 3 * scale, 3 * scale, 3 * scale)
@@ -1374,7 +1273,8 @@ def _test_upsample3d_trilinear():
     )
 
     model = helper.make_model(graph, producer_name="upsample_trilinear_test")
-
+    # TODO(jwfromm): Trilinear upsampling not supported in 1.0.0 onnxruntime.
+    # Replace topi comparison with verify_with_ort once we update.
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output(model, in_array, target, ctx, out_shape, 
"float32")
         tvm.testing.assert_allclose(out_array, tvm_out, rtol=1e-5, atol=1e-5)
@@ -1383,41 +1283,36 @@ def _test_upsample3d_trilinear():
 # TODO(mbrookhart): enable once VM supports heterogenous execution
 # @tvm.testing.uses_gpu
 def test_upsample():
-    _test_upsample_nearest()
-    _test_upsample_bilinear()
-    _test_upsample_bilinear_opset9()
-    _test_upsample3d_nearest()
-    _test_upsample3d_trilinear()
+    verify_upsample_nearest()
+    verify_upsample_bilinear()
+    verify_upsample3d_nearest()
+    verify_upsample3d_trilinear()
 
 
-def _test_softmax(inshape, axis):
+def verify_softmax(inshape, axis):
     opname = "Softmax"
     indata = np.random.uniform(size=inshape).astype(np.float32)
     outshape = inshape
-    outdata = tvm.topi.testing.softmax_python(indata)
-    if isinstance(axis, int):
-        y = helper.make_node(opname, ["in"], ["out"], axis=axis)
-    elif axis is None:
-        y = helper.make_node(opname, ["in"], ["out"])
+    y = helper.make_node(opname, ["in"], ["out"])
+    if axis is not None:
+        axis_attr = helper.make_attribute("axis", axis)
+        y.attribute.append(axis_attr)
 
     graph = helper.make_graph(
         [y],
         opname + "_test",
         inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, 
list(indata.shape))],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(outdata.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(outshape))],
     )
 
     model = helper.make_model(graph, producer_name=opname + "_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outshape, 
"float32")
-        tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [indata])
 
 
 @tvm.testing.uses_gpu
 def test_softmax():
-    _test_softmax((1, 10), None)
-    _test_softmax((1, 10), 1)
+    verify_softmax((1, 10), None)
+    verify_softmax((1, 10), 1)
 
 
 def verify_min(input_dim):
@@ -1427,8 +1322,6 @@ def verify_min(input_dim):
     a_np2 = np.random.uniform(size=input_dim).astype(dtype)
     a_np3 = np.random.uniform(size=input_dim).astype(dtype)
 
-    b_np = np.min((a_np1, a_np2, a_np3), axis=0)
-
     min_node = helper.make_node("Min", ["a_np1", "a_np2", "a_np3"], ["out"])
 
     graph = helper.make_graph(
@@ -1439,14 +1332,11 @@ def verify_min(input_dim):
             helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, 
list(input_dim)),
             helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, 
list(input_dim)),
         ],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(b_np.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(input_dim))],
     )
 
     model = helper.make_model(graph, producer_name="Min_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, 
b_np.shape)
-        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3])
 
 
 @tvm.testing.uses_gpu
@@ -1462,8 +1352,6 @@ def verify_max(input_dim):
     a_np2 = np.random.uniform(size=input_dim).astype(dtype)
     a_np3 = np.random.uniform(size=input_dim).astype(dtype)
 
-    b_np = np.max((a_np1, a_np2, a_np3), axis=0)
-
     max_node = helper.make_node("Max", ["a_np1", "a_np2", "a_np3"], ["out"])
 
     graph = helper.make_graph(
@@ -1474,14 +1362,11 @@ def verify_max(input_dim):
             helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, 
list(input_dim)),
             helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, 
list(input_dim)),
         ],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(b_np.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(input_dim))],
     )
 
     model = helper.make_model(graph, producer_name="Max_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, 
b_np.shape)
-        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3])
 
 
 @tvm.testing.uses_gpu
@@ -1497,8 +1382,6 @@ def verify_mean(input_dim):
     a_np2 = np.random.uniform(size=input_dim).astype(dtype)
     a_np3 = np.random.uniform(size=input_dim).astype(dtype)
 
-    b_np = np.mean((a_np1, a_np2, a_np3), axis=0)
-
     mean_node = helper.make_node("Mean", ["a_np1", "a_np2", "a_np3"], ["out"])
 
     graph = helper.make_graph(
@@ -1509,14 +1392,11 @@ def verify_mean(input_dim):
             helper.make_tensor_value_info("a_np2", TensorProto.FLOAT, 
list(input_dim)),
             helper.make_tensor_value_info("a_np3", TensorProto.FLOAT, 
list(input_dim)),
         ],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(b_np.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(input_dim))],
     )
 
     model = helper.make_model(graph, producer_name="Mean_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [a_np1, a_np2, a_np3], target, ctx, 
b_np.shape)
-        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [a_np1, a_np2, a_np3])
 
 
 @tvm.testing.uses_gpu
@@ -1530,22 +1410,17 @@ def verify_hardsigmoid(input_dim, alpha, beta):
 
     a_np1 = np.random.uniform(size=input_dim).astype(dtype)
 
-    b_np = np.clip(a_np1 * alpha + beta, 0, 1)
-
     hardsigmoid_node = helper.make_node("HardSigmoid", ["a_np1"], ["out"], 
alpha=alpha, beta=beta)
 
     graph = helper.make_graph(
         [hardsigmoid_node],
         "HardSigmoid_test",
         inputs=[helper.make_tensor_value_info("a_np1", TensorProto.FLOAT, 
list(input_dim))],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(b_np.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(input_dim))],
     )
 
     model = helper.make_model(graph, producer_name="HardSigmoid_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape)
-        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [a_np1])
 
 
 @tvm.testing.uses_gpu
@@ -1554,98 +1429,51 @@ def test_forward_hardsigmoid():
     verify_hardsigmoid((20, 20), 0.3, 0.4)
 
 
-def verify_argmin(input_dim, axis=None, keepdims=None):
-    def _argmin_numpy(data, axis=0, keepdims=True):
-        result = np.argmin(data, axis=axis)
-        if keepdims == 1:
-            result = np.expand_dims(result, axis)
-        return result.astype(data.dtype)
-
+def verify_argreduce(input_dim, op_name, axis=None, keepdims=None):
     a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32)
-    if keepdims is None and axis is None:
-        b_np = _argmin_numpy(a_np1)
-        node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], 
outputs=["out"])
-    elif axis is None:
-        b_np = _argmin_numpy(a_np1, keepdims=keepdims)
-        node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], 
outputs=["out"], keepdims=keepdims)
-    elif keepdims is None:
-        b_np = _argmin_numpy(a_np1, axis=axis)
-        node = onnx.helper.make_node("ArgMin", inputs=["a_np1"], 
outputs=["out"], axis=axis)
+    out_shape = list(a_np1.shape)
+    def_axis = axis if axis is not None else 0
+    if keepdims == 1 or keepdims == None:
+        out_shape[def_axis] = 1
     else:
-        b_np = _argmin_numpy(a_np1, axis=axis, keepdims=keepdims)
-        node = onnx.helper.make_node(
-            "ArgMin", inputs=["a_np1"], outputs=["out"], axis=axis, 
keepdims=keepdims
-        )
-    graph = helper.make_graph(
-        [node],
-        "argmin_test",
-        inputs=[helper.make_tensor_value_info("a_np1", TensorProto.INT32, 
list(a_np1.shape))],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, 
list(b_np.shape))],
-    )
-
-    model = helper.make_model(graph, producer_name="argmin_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, 
b_np.dtype)
-        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
-
+        out_shape.pop(def_axis)
 
-def verify_argmax(input_dim, axis=None, keepdims=None):
-    def _argmax_numpy(data, axis=0, keepdims=True):
-        result = np.argmax(data, axis=axis)
-        if keepdims == 1:
-            result = np.expand_dims(result, axis)
-        return result.astype(data.dtype)
+    node = onnx.helper.make_node(op_name, inputs=["a_np1"], outputs=["out"])
 
-    a_np1 = np.random.uniform(-10, 10, input_dim).astype(np.int32)
-    if keepdims is None and axis is None:
-        b_np = _argmax_numpy(a_np1)
-        node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], 
outputs=["out"])
-    elif axis is None:
-        b_np = _argmax_numpy(a_np1, keepdims=keepdims)
-        node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], 
outputs=["out"], keepdims=keepdims)
-    elif keepdims is None:
-        b_np = _argmax_numpy(a_np1, axis=axis)
-        node = onnx.helper.make_node("ArgMax", inputs=["a_np1"], 
outputs=["out"], axis=axis)
-    else:
-        b_np = _argmax_numpy(a_np1, axis=axis, keepdims=keepdims)
-        node = onnx.helper.make_node(
-            "ArgMax", inputs=["a_np1"], outputs=["out"], axis=axis, 
keepdims=keepdims
-        )
+    if keepdims is not None:
+        keepdims_attr = helper.make_attribute("keepdims", keepdims)
+        node.attribute.append(keepdims_attr)
+    if axis is not None:
+        axis_attr = helper.make_attribute("axis", axis)
+        node.attribute.append(axis_attr)
 
     graph = helper.make_graph(
         [node],
-        "argmax_test",
+        "argreduce_test",
         inputs=[helper.make_tensor_value_info("a_np1", TensorProto.INT32, 
list(a_np1.shape))],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.INT32, 
list(b_np.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.INT64, 
list(out_shape))],
     )
 
-    model = helper.make_model(graph, producer_name="argmax_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [a_np1], target, ctx, b_np.shape, 
b_np.dtype)
-        tvm.testing.assert_allclose(b_np, tvm_out, rtol=1e-5, atol=1e-5)
+    model = helper.make_model(graph, producer_name="argreduce_test")
+    verify_with_ort_with_inputs(model, [a_np1])
 
 
 @tvm.testing.uses_gpu
 def test_forward_arg_min_max():
     """Verify argmin and argmax"""
-    verify_argmin([3, 4, 4])
-    verify_argmax([3, 4, 4])
-    verify_argmin([3, 4, 4], axis=1)
-    verify_argmax([3, 4, 4], axis=0)
-    verify_argmin([3, 4, 4], keepdims=0)
-    verify_argmax([3, 4, 4], keepdims=1)
+    verify_argreduce([3, 4, 4], "ArgMin")
+    verify_argreduce([3, 4, 4], "ArgMax")
+    verify_argreduce([3, 4, 4], "ArgMin", axis=1)
+    verify_argreduce([3, 4, 4], "ArgMax", axis=0)
+    verify_argreduce([3, 4, 4], "ArgMin", keepdims=0)
+    verify_argreduce([3, 4, 4], "ArgMax", keepdims=1)
     for axis in [None, 0, 1, 2]:
         for keepdims in [None, True, False]:
-            verify_argmin([3, 4, 4], axis, keepdims)
-            verify_argmax([3, 4, 4], axis, keepdims)
+            verify_argreduce([3, 4, 4], "ArgMin", axis, keepdims)
+            verify_argreduce([3, 4, 4], "ArgMax", axis, keepdims)
 
 
 def verify_constantofshape(input_dim, value, dtype):
-    out = np.empty(shape=input_dim, dtype=dtype)
-    out.fill(value)
-
     fill_node = helper.make_node(
         "ConstantOfShape",
         ["input"],
@@ -1655,22 +1483,22 @@ def verify_constantofshape(input_dim, value, dtype):
         ),
     )
 
-    inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, 
input_dim)]
+    inputs = [helper.make_tensor_value_info("input", TensorProto.INT64, 
[len(input_dim)])]
 
     graph = helper.make_graph(
         [fill_node],
         "fill_test",
         inputs,
-        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
list(out.shape))],
+        outputs=[
+            helper.make_tensor_value_info(
+                "output", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], 
input_dim
+            )
+        ],
     )
 
     model = helper.make_model(graph, producer_name="fill_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        input_np = np.array(input_dim).astype("float32")
-        tvm_out = get_tvm_output_with_vm(model, [input_np], target, ctx)
-
-        tvm.testing.assert_allclose(out, tvm_out, rtol=1e-5, atol=1e-5)
+    input_np = np.array(input_dim).astype("int64")
+    verify_with_ort_with_inputs(model, [input_np], use_vm=True)
 
 
 # TODO(mbrookhart): enable once VM supports heterogenous execution
@@ -1708,10 +1536,7 @@ def verify_pad(indata, pads, mode="constant", value=0.0):
         outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
list(outdata.shape))],
     )
     model = helper.make_model(graph, producer_name="pad_test")
-    #  tvm result
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, indata, target, ctx, outdata.shape, 
"float32", opset=2)
-    tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [indata], [outdata.shape], 
dtype="float32", opset=2)
 
 
 def verify_pad_v11(indata, pads, mode="constant", value=0.0):
@@ -1760,10 +1585,7 @@ def verify_pad_v11(indata, pads, mode="constant", 
value=0.0):
             ],
         )
     model = helper.make_model(graph, producer_name="pad_test")
-    #  tvm result
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output_with_vm(model, inputs, target, ctx, opset=11, 
freeze_params=False)
-    tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, inputs, opset=11, use_vm=True)
 
 
 # TODO(mbrookhart): enable once VM supports heterogenous execution
@@ -1804,7 +1626,7 @@ def verify_reduce_func(func, data, axis, keepdims):
 
     model = helper.make_model(graph, producer_name="reduce_test")
 
-    verify_with_ort_with_inputs(model, [data], outshape)
+    verify_with_ort_with_inputs(model, [data], [outshape])
 
 
 @tvm.testing.uses_gpu
@@ -1849,32 +1671,45 @@ def test_all_reduce_funcs():
             )
 
 
-def verify_split(indata, outdatas, split, axis=0, pass_split=True):
+def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11):
     indata = np.array(indata).astype(np.float32)
     outdatas = [np.array(o).astype(np.float32) for o in outdatas]
+    inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, 
list(indata.shape))]
+    input_names = ["input"]
+    initializer = []
+
     if split:
         split_index = range(len(split))
     else:
         split_index = range(len(outdatas))
+
     if pass_split:
-        node = helper.make_node(
-            "Split",
-            inputs=["input"],
-            outputs=["output_{}".format(i) for i in range(len(split_index))],
-            axis=axis,
-            split=split,
-        )
-    else:
-        node = helper.make_node(
-            "Split",
-            inputs=["input"],
-            outputs=["output_{}".format(i) for i in range(len(split_index))],
-            axis=axis,
-        )
+        if opset >= 13:
+            input_names.append("split")
+            np_split = np.array(split).astype(np.int64)
+            inputs.append(
+                helper.make_tensor_value_info("split", TensorProto.INT64, 
list(np_split.shape))
+            )
+            indata = [indata, np_split]
+            initializer.append(
+                helper.make_tensor("split", TensorProto.INT64, 
list(np_split.shape), np_split)
+            )
+    node = helper.make_node(
+        "Split",
+        inputs=input_names,
+        outputs=["output_{}".format(i) for i in range(len(split_index))],
+        axis=axis,
+    )
+
+    if pass_split and opset < 13:
+        split_attr = helper.make_attribute("split", split)
+        node.attribute.append(split_attr)
+
     graph = helper.make_graph(
         [node],
         "split_test",
-        inputs=[helper.make_tensor_value_info("input", TensorProto.FLOAT, 
list(indata.shape))],
+        inputs=inputs,
+        initializer=initializer,
         outputs=[
             helper.make_tensor_value_info(
                 "output_{}".format(i), TensorProto.FLOAT, 
list(outdatas[i].shape)
@@ -1883,18 +1718,7 @@ def verify_split(indata, outdatas, split, axis=0, 
pass_split=True):
         ],
     )
     model = helper.make_model(graph, producer_name="split_test")
-
-    import onnxruntime.backend
-
-    rep = onnxruntime.backend.prepare(model, "CPU")
-    onnx_out = rep.run(indata)
-
-    for target, ctx in tvm.testing.enabled_targets():
-        output_shape = [o.shape for o in outdatas]
-        output_type = ["float32", "float32", "float32"]
-        tvm_out = get_tvm_output(model, indata, target, ctx, output_shape, 
output_type)
-        for o, t in zip(onnx_out, tvm_out):
-            tvm.testing.assert_allclose(o, t)
+    verify_with_ort_with_inputs(model, indata, 
out_shape=list(range(len(split_index))), opset=opset)
 
 
 @tvm.testing.uses_gpu
@@ -1914,6 +1738,8 @@ def test_split():
     )
     # Split evenly (unstack)
     verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False)
+    # Split a single value to a single value
+    verify_split([1], [[1]], [1], pass_split=True)
 
 
 @tvm.testing.uses_gpu
@@ -1922,50 +1748,52 @@ def test_binary_ops():
     dtype = "float32"
     out_shape = in_shape
 
-    def verify_binary_ops(op, x, y, out_np, x_name="in1", y_name="in2", 
broadcast=None):
-        if broadcast is None:
-            z = helper.make_node(op, [x_name, y_name], ["out"])
-        else:
-            z = helper.make_node(op, [x_name, y_name], ["out"], broadcast=1)
+    def verify_binary_ops(op, x, y, out_type="float32"):
+        z = helper.make_node(op, ["in1", "in2"], ["out"])
         graph = helper.make_graph(
             [z],
             "_test",
             inputs=[
-                helper.make_tensor_value_info(x_name, TensorProto.FLOAT, 
list(in_shape)),
-                helper.make_tensor_value_info(y_name, TensorProto.FLOAT, 
list(in_shape)),
+                helper.make_tensor_value_info("in1", TensorProto.FLOAT, 
x.shape),
+                helper.make_tensor_value_info("in2", TensorProto.FLOAT, 
y.shape),
+            ],
+            outputs=[
+                helper.make_tensor_value_info(
+                    "out", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(out_type)], 
list(out_shape)
+                )
             ],
-            outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(out_shape))],
         )
         model = helper.make_model(graph, producer_name="_test")
-        for target, ctx in tvm.testing.enabled_targets():
-            tvm_out = get_tvm_output(model, [x, y], target, ctx)
-            tvm.testing.assert_allclose(out_np, tvm_out, rtol=1e-5, atol=1e-5)
+        verify_with_ort_with_inputs(model, [x, y])
 
     x = np.random.uniform(size=in_shape).astype(dtype)
     y = np.random.uniform(size=in_shape).astype(dtype)
     z = np.random.uniform(size=(3,)).astype(dtype)
-    verify_binary_ops("Add", x, y, x + y, broadcast=None)
-    verify_binary_ops("Add", x, z, x + z, broadcast=True)
-    verify_binary_ops("Sub", x, y, x - y, broadcast=None)
-    verify_binary_ops("Sub", x, z, x - z, broadcast=True)
-    verify_binary_ops("Mul", x, y, x * y, broadcast=None)
-    verify_binary_ops("Mul", x, z, x * z, broadcast=True)
-    verify_binary_ops("Mul", x, x, x * x, x_name="in1", y_name="in1", 
broadcast=None)
-    verify_binary_ops("Div", x, y, x / y, broadcast=None)
-    verify_binary_ops("Div", x, z, x / z, broadcast=True)
-    verify_binary_ops("Sum", x, y, x + y, broadcast=None)
-    verify_binary_ops("Greater", x, y, x > y, broadcast=True)
-    verify_binary_ops("Less", x, y, x < y, broadcast=True)
-    verify_binary_ops("Equal", x, y, x == y, broadcast=True)
-
-
[email protected]_gpu
-def test_single_ops():
+    verify_binary_ops("Add", x, y)
+    verify_binary_ops("Add", x, z)
+    verify_binary_ops("Sub", x, y)
+    verify_binary_ops("Sub", x, z)
+    verify_binary_ops("Mul", x, y)
+    verify_binary_ops("Mul", x, z)
+    verify_binary_ops("Div", x, y)
+    verify_binary_ops("Div", x, z)
+    verify_binary_ops("Sum", x, y)
+    verify_binary_ops("Sum", x, z)
+    verify_binary_ops("Greater", x, y, "bool")
+    verify_binary_ops("Greater", x, z, "bool")
+    verify_binary_ops("Less", x, y, "bool")
+    verify_binary_ops("Less", x, z, "bool")
+    verify_binary_ops("Equal", x, y, "bool")
+    verify_binary_ops("Equal", x, z, "bool")
+
+
[email protected]_gpu
+def test_unary_ops():
     in_shape = (1, 2, 3, 3)
     dtype = "float32"
     out_shape = in_shape
 
-    def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5):
+    def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5):
         z = helper.make_node(op, ["in1"], ["out"])
         graph = helper.make_graph(
             [z],
@@ -1976,33 +1804,31 @@ def test_single_ops():
             outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(out_shape))],
         )
         model = helper.make_model(graph, producer_name="_test")
-        for target, ctx in tvm.testing.enabled_targets():
-            tvm_out = get_tvm_output(model, [x], target, ctx)
-            tvm.testing.assert_allclose(out_np, tvm_out, rtol=rtol, atol=atol)
+        verify_with_ort_with_inputs(model, [x], rtol=rtol, atol=atol)
 
     x = np.random.uniform(size=in_shape).astype(dtype)
-    verify_single_ops("Neg", x, -x)
-    verify_single_ops("Abs", x, np.abs(x))
-    verify_single_ops("Reciprocal", x, 1 / x)
-    verify_single_ops("Sqrt", x, np.sqrt(x))
-    verify_single_ops("Relu", x, np.maximum(x, 0))
-    verify_single_ops("Exp", x, np.exp(x))
-    verify_single_ops("Log", x, np.log(x))
-    verify_single_ops("Log", x, np.log(x))
-    verify_single_ops("ACos", x, np.arccos(x))
-    verify_single_ops("ACosh", x, np.arccosh(x))
-    verify_single_ops("ASin", x, np.arcsin(x))
-    verify_single_ops("ASinh", x, np.arcsinh(x))
-    verify_single_ops("ATan", x, np.arctan(x))
-    verify_single_ops("ATanh", x, np.arctanh(x))
-    verify_single_ops("Cos", x, np.cos(x))
-    verify_single_ops("Cosh", x, np.cosh(x))
-    verify_single_ops("Sin", x, np.sin(x))
-    verify_single_ops("Sinh", x, np.sinh(x))
-    verify_single_ops("Tan", x, np.tan(x))
-    verify_single_ops("Tanh", x, np.tanh(x))
-    verify_single_ops("Sigmoid", x, 1 / (1 + np.exp(-x)))
-    verify_single_ops("Softsign", x, x / (1 + np.abs(x)))
+    verify_unary_ops("Neg", x)
+    verify_unary_ops("Abs", x)
+    verify_unary_ops("Reciprocal", x)
+    verify_unary_ops("Sqrt", x)
+    verify_unary_ops("Relu", x)
+    verify_unary_ops("Exp", x)
+    verify_unary_ops("Log", x)
+    verify_unary_ops("Log", x)
+    verify_unary_ops("Acos", x)
+    verify_unary_ops("Acosh", x)
+    verify_unary_ops("Asin", x)
+    verify_unary_ops("Asinh", x)
+    verify_unary_ops("Atan", x)
+    verify_unary_ops("Atanh", x)
+    verify_unary_ops("Cos", x)
+    verify_unary_ops("Cosh", x)
+    verify_unary_ops("Sin", x)
+    verify_unary_ops("Sinh", x)
+    verify_unary_ops("Tan", x)
+    verify_unary_ops("Tanh", x)
+    verify_unary_ops("Sigmoid", x)
+    verify_unary_ops("Softsign", x)
 
 
 @tvm.testing.uses_gpu
@@ -2058,7 +1884,11 @@ def test_prelu():
         model = helper.make_model(graph, producer_name="prelu_test")
 
         verify_with_ort(
-            model, [x_shape, a_shape], list(x_shape), use_vm=True, 
convert_to_static=True
+            model,
+            [x_shape, a_shape],
+            out_shape=[list(x_shape)],
+            use_vm=True,
+            convert_to_static=True,
         )
 
     verify_prelu([3, 4, 5, 6], [1, 4, 1, 1])
@@ -2086,46 +1916,6 @@ def test_ThresholdedRelu():
 
 
 @tvm.testing.uses_gpu
-def test_ScaledTanh():
-    def ScaledTanh_x(x, alpha, beta):
-        return alpha * np.tanh(beta * x)
-
-    _test_onnx_op_elementwise(
-        (2, 4, 5, 6),
-        ScaledTanh_x,
-        {"alpha": 0.25, "beta": 0.3},
-        "float32",
-        "ScaledTanh",
-        {"alpha": 0.25, "beta": 0.3},
-    )
-
-
[email protected]_gpu
-def test_ParametricSoftplus():
-    def ParametricSoftplus_x(x, alpha, beta):
-        return alpha * np.log(np.exp(beta * x) + 1)
-
-    _test_onnx_op_elementwise(
-        (2, 4, 5, 6),
-        ParametricSoftplus_x,
-        {"alpha": 0.25, "beta": 0.3},
-        "float32",
-        "ParametricSoftplus",
-        {"alpha": 0.25, "beta": 0.3},
-    )
-
-
[email protected]_gpu
-def test_Scale():
-    def Scale_x(x, scale):
-        return scale * x
-
-    _test_onnx_op_elementwise(
-        (2, 4, 5, 6), Scale_x, {"scale": 0.25}, "float32", "Scale", {"scale": 
0.25}
-    )
-
-
[email protected]_gpu
 def test_LogSoftmax():
     _test_onnx_op_elementwise(
         (1, 4), tvm.topi.testing.log_softmax_python, {}, "float32", 
"LogSoftmax", {"axis": 1}
@@ -2138,8 +1928,8 @@ def check_torch_conversion(model, input_size):
     # Set verbose=True for more output
     torch.onnx.export(model(), dummy_input, file_name, export_params=True, 
verbose=False)
     onnx_model = onnx.load(file_name)
-    input_data = np.random.uniform(size=input_size).astype("int32")
-    verify_with_ort_with_inputs(onnx_model, [input_data])
+    input_data = np.random.uniform(size=input_size).astype("float32")
+    verify_with_ort_with_inputs(onnx_model, [input_data], apply_softmax=True)
 
 
 @tvm.testing.uses_gpu
@@ -2191,7 +1981,6 @@ def test_sign():
 
 def verify_not(indata, dtype):
     x = indata.astype(dtype)
-    outdata = np.logical_not(x)
 
     node = helper.make_node(
         "Not",
@@ -2203,14 +1992,11 @@ def verify_not(indata, dtype):
         [node],
         "not_test",
         inputs=[helper.make_tensor_value_info("in", TensorProto.BOOL, 
list(x.shape))],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, 
list(outdata.shape))],
+        outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, 
list(x.shape))],
     )
 
     model = helper.make_model(graph, producer_name="not_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x], target, ctx, outdata.shape)
-        tvm.testing.assert_allclose(outdata, tvm_out)
+    verify_with_ort_with_inputs(model, [x])
 
 
 @tvm.testing.uses_gpu
@@ -2245,10 +2031,7 @@ def verify_and(indata, dtype):
     )
 
     model = helper.make_model(graph, producer_name="and_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape)
-        tvm.testing.assert_allclose(outdata, tvm_out)
+    verify_with_ort_with_inputs(model, [x, y], [outdata.shape])
 
 
 @tvm.testing.uses_gpu
@@ -2279,22 +2062,6 @@ def test_and():
     verify_and(indata=[x, y], dtype=bool)
 
 
-def verify_tile_v1(indata, outdata, **kwargs):
-    node = helper.make_node("Tile", inputs=["in"], outputs=["out"], **kwargs)
-    graph = helper.make_graph(
-        [node],
-        "tile_test",
-        inputs=[helper.make_tensor_value_info("in", TensorProto.FLOAT, 
list(indata.shape))],
-        outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(outdata.shape))],
-    )
-
-    model = helper.make_model(graph, producer_name="tile_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape, 
opset=1)
-        tvm.testing.assert_allclose(outdata, tvm_out)
-
-
 def verify_tile_v6(indata, repeats, outdata):
     node = helper.make_node("Tile", inputs=["input", "repeats"], 
outputs=["out"])
     graph = helper.make_graph(
@@ -2308,10 +2075,7 @@ def verify_tile_v6(indata, repeats, outdata):
     )
 
     model = helper.make_model(graph, producer_name="tile_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output_with_vm(model, [indata, repeats], target, 
ctx, opset=6)
-        tvm.testing.assert_allclose(outdata, tvm_out)
+    verify_with_ort_with_inputs(model, [indata, repeats], use_vm=True, opset=6)
 
 
 # TODO(mbrookhart): enable once VM supports heterogenous execution
@@ -2320,7 +2084,6 @@ def test_tile():
     x = np.random.rand(2, 3, 4, 5).astype(np.float32)
     repeats = np.random.randint(low=1, high=10, 
size=(np.ndim(x),)).astype(np.int64)
     z = np.tile(x, repeats)
-    verify_tile_v1(x, z, repeats=repeats)
     verify_tile_v6(x, repeats, z)
 
 
@@ -2333,10 +2096,7 @@ def verify_erf(indata, outdata):
         outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, 
list(outdata.shape))],
     )
     model = helper.make_model(graph, producer_name="erf_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [indata], target, ctx, outdata.shape)
-        tvm.testing.assert_allclose(outdata, tvm_out)
+    verify_with_ort_with_inputs(model, [indata], [outdata.shape])
 
 
 @tvm.testing.uses_gpu
@@ -2359,10 +2119,7 @@ def verify_where(condition, x, y, dtype, outdata):
         outputs=[helper.make_tensor_value_info("out", dtype, 
list(outdata.shape))],
     )
     model = helper.make_model(graph, producer_name="where_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [condition, x, y], target, ctx, 
outdata.shape)
-        tvm.testing.assert_allclose(outdata, tvm_out)
+    verify_with_ort_with_inputs(model, [condition, x, y], [outdata.shape])
 
 
 @tvm.testing.uses_gpu
@@ -2422,10 +2179,7 @@ def verify_or(indata, dtype):
     )
 
     model = helper.make_model(graph, producer_name="or_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape)
-        tvm.testing.assert_allclose(outdata, tvm_out)
+    verify_with_ort_with_inputs(model, [x, y], [outdata.shape])
 
 
 @tvm.testing.uses_gpu
@@ -2479,7 +2233,7 @@ def test_batch_norm():
         model = helper.make_model(graph, producer_name="batchnorm_test")
         # X, scale, b, mean, var
         inshapes = [in_shape, in_shape[1], in_shape[1], in_shape[1], 
in_shape[1]]
-        verify_with_ort(model, inshapes, in_shape)
+        verify_with_ort(model, inshapes, out_shape=[in_shape])
 
     verify_batch_norm([1, 3, 224, 224])
     verify_batch_norm([1, 3, 24, 24])
@@ -2517,7 +2271,7 @@ def test_batch_norm_dynamic_subgraph():
 
         # X, inp, scale, b, mean, var
         inshapes = [in_shape, o_shape, in_shape[1], in_shape[1], in_shape[1], 
in_shape[1]]
-        verify_with_ort(model, inshapes, in_shape, use_vm=True)
+        verify_with_ort(model, inshapes, out_shape=[in_shape], use_vm=True)
 
     verify_batch_norm_dynamic_subgraph([16, 16, 10, 10], [160, 160])
 
@@ -2581,7 +2335,7 @@ def verify_conv(
 
     model = helper.make_model(graph, producer_name="conv_test")
 
-    verify_with_ort(model, [x_shape, w_shape], y_shape, use_vm=True, 
convert_to_static=True)
+    verify_with_ort(model, [x_shape, w_shape], [y_shape], use_vm=True, 
convert_to_static=True)
 
 
 @tvm.testing.uses_gpu
@@ -2735,7 +2489,7 @@ def verify_convtranspose_with_padding(
 
     model = helper.make_model(graph, producer_name="conv_test")
 
-    verify_with_ort(model, [x_shape, w_shape], y_shape, use_vm=True, 
convert_to_static=True)
+    verify_with_ort(model, [x_shape, w_shape], [y_shape], use_vm=True, 
convert_to_static=True)
 
 
 def verify_convtranspose(x_shape, w_shape, y_shape, p):
@@ -2908,7 +2662,7 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, 
out_shape, mode, auto_p
     )
 
     model = helper.make_model(graph, producer_name="pooling_test")
-    verify_with_ort(model, [x_shape], out_shape, use_vm=True, 
convert_to_static=True)
+    verify_with_ort(model, [x_shape], [out_shape], use_vm=True, 
convert_to_static=True)
 
 
 @tvm.testing.uses_gpu
@@ -3013,7 +2767,7 @@ def verify_mod(x_shape, y_shape, fmod, out_shape, 
dtype="float32"):
         outputs=[helper.make_tensor_value_info("z", onnx_dtype, 
list(out_shape))],
     )
     model = helper.make_model(graph, producer_name="mod_test")
-    verify_with_ort_with_inputs(model, [x_np, y_np], out_shape)
+    verify_with_ort_with_inputs(model, [x_np, y_np], [out_shape])
 
 
 @tvm.testing.uses_gpu
@@ -3066,10 +2820,7 @@ def verify_xor(x_shape, y_shape):
         outputs=[helper.make_tensor_value_info("z", onnx_dtype, 
list(out_shape))],
     )
     model = helper.make_model(graph, producer_name="xor_test")
-
-    for target, ctx in tvm.testing.enabled_targets():
-        tvm_out = get_tvm_output(model, [x_np, y_np], target, ctx, out_shape)
-        tvm.testing.assert_allclose(np_out, tvm_out, rtol=1e-5, atol=1e-5)
+    verify_with_ort_with_inputs(model, [x_np, y_np], [out_shape])
 
 
 @tvm.testing.uses_gpu
@@ -3106,7 +2857,7 @@ def verify_max_roi_pool(x_shape, rois_shape, 
pooled_shape, spatial_scale, out_sh
     )
 
     model = helper.make_model(graph, producer_name="pool_test")
-    verify_with_ort(model, [x_shape, rois_shape], out_shape)
+    verify_with_ort(model, [x_shape, rois_shape], [out_shape])
 
 
 @tvm.testing.uses_gpu
@@ -3158,7 +2909,7 @@ def verify_lppool(x_shape, kernel_shape, p, strides, 
pads, out_shape, auto_pad="
     )
 
     model = helper.make_model(graph, producer_name="lppool_test")
-    verify_with_ort(model, [x_shape], out_shape, use_vm=True, 
convert_to_static=True)
+    verify_with_ort(model, [x_shape], [out_shape], use_vm=True, 
convert_to_static=True)
 
 
 @tvm.testing.uses_gpu
@@ -3350,18 +3101,7 @@ def verify_rnn(
 
     model = helper.make_model(graph, producer_name="rnn_test")
 
-    for target, ctx in tvm.testing.enabled_targets():
-        onnx_out = get_onnxruntime_output(model, input_values, "float32")
-        tvm_out = get_tvm_output(
-            model,
-            input_values,
-            target,
-            ctx,
-            output_shapes,
-            output_dtype=["float32"] * len(output_shapes),
-        )
-        for o_out, t_out in zip(onnx_out, tvm_out):
-            tvm.testing.assert_allclose(o_out, t_out, rtol=5e-3, atol=5e-3)
+    verify_with_ort_with_inputs(model, input_values, output_shapes, atol=1e-2, rtol=1e-2)
 
 
 @tvm.testing.uses_gpu
@@ -3566,7 +3306,7 @@ def test_resize():
 
         model = helper.make_model(graph, producer_name="resize_test")
 
-        verify_with_ort(model, [ishape], oshape, use_vm=True, opset=11, freeze_params=True)
+        verify_with_ort(model, [ishape], [oshape], use_vm=True, opset=11, freeze_params=True)
 
     # upsampling
     verify([1, 16, 32, 32], [1, 16, 64, 64], [], "nearest", "asymmetric")
@@ -3603,9 +3343,7 @@ def test_resize():
         )
 
         model = helper.make_model(graph, producer_name="resize_test")
-        model.opset_import[0].version = 10
-
-        verify_with_ort(model, [ishape], oshape, use_vm=True, freeze_params=True)
+        verify_with_ort(model, [ishape], [oshape], use_vm=True, freeze_params=True, opset=10)
 
     verify_opset_10([1, 16, 32, 32], [1, 1, 2, 2], "nearest")
     verify_opset_10([1, 16, 32, 32], [1, 1, 0.5, 0.5], "linear")
@@ -3674,11 +3412,7 @@ def test_topk():
         model = helper.make_model(graph, producer_name="topk_test")
 
         indata = np.random.uniform(-10, 10, input_dims).astype(np.float32)
-        onnx_out = get_onnxruntime_output(model, [indata, np.array([K])])
-
-        for target, ctx in [("llvm", tvm.cpu())]:
-            tvm_out = get_tvm_output_with_vm(model, [indata, np.array(K)], target, ctx)
-            tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
+        verify_with_ort_with_inputs(model, [indata, np.array([K])], use_vm=True)
 
     for n in [12, 32]:
         for shape in [[n], [n, n], [n, n, n]]:
@@ -3731,7 +3465,9 @@ def test_roi_align():
         np_rois = np.random.uniform(size=[num_roi, 4]).astype("float32") * input_dims[2]
         np_batch_indicies = np.random.randint(low=0, high=input_dims[0], size=num_roi)
 
-        verify_with_ort_with_inputs(model, [np_data, np_rois, np_batch_indicies], output_dims)
+        verify_with_ort_with_inputs(
+            model, [np_data, np_rois, np_batch_indicies], out_shape=[output_dims]
+        )
 
     verify_roi_align((1, 4, 16, 16), 32, 7, 7, sampling_ratio=0, 
spatial_scale=1.0)
     verify_roi_align((4, 4, 16, 32), 32, 7, 7, sampling_ratio=0, 
spatial_scale=1.0)
@@ -3914,12 +3650,7 @@ def verify_cond_loop():
     trip_count = np.array(40).astype(np.int64)
     cond = np.array(1).astype(np.bool)
     input_vals = [trip_count, cond, y]
-    onnx_out = get_onnxruntime_output(loop_model, input_vals)
-
-    for target, ctx in [("llvm", tvm.cpu())]:
-        tvm_out = get_tvm_output_with_vm(loop_model, input_vals, target, ctx, freeze_params=True)
-        for i in range(len(tvm_out)):
-            tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)
+    verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True)
 
 
 def verify_count_loop():
@@ -3974,12 +3705,7 @@ def verify_count_loop():
     trip_count = np.array(5).astype(np.int64)
     cond = np.array(1).astype(np.bool)
     input_vals = [trip_count, cond, y]
-    onnx_out = get_onnxruntime_output(loop_model, input_vals)
-
-    for target, ctx in [("llvm", tvm.cpu())]:
-        tvm_out = get_tvm_output_with_vm(loop_model, input_vals, target, ctx, freeze_params=True)
-        for i in range(len(tvm_out)):
-            tvm.testing.assert_allclose(onnx_out[i], tvm_out[i], rtol=1e-05, atol=1e-05)
+    verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True)
 
 
 def test_loop():
@@ -3999,11 +3725,11 @@ def verify_if(cond_array):
     y = np.array([5, 4, 3, 2, 1]).astype(np.float32)
 
     then_const_node = onnx.helper.make_node(
-        "Constant", inputs=[], outputs=["then_out"], value=onnx.numpy_helper.from_array(x)
+        "Constant", inputs=[], outputs=["then_out"], value=numpy_helper.from_array(x)
     )
 
     else_const_node = onnx.helper.make_node(
-        "Constant", inputs=[], outputs=["else_out"], value=onnx.numpy_helper.from_array(y)
+        "Constant", inputs=[], outputs=["else_out"], value=numpy_helper.from_array(y)
     )
 
     then_body = onnx.helper.make_graph([then_const_node], "then_body", [], [then_out])
@@ -4032,6 +3758,8 @@ def verify_if(cond_array):
         cond = np.array(1).astype("bool")
     correct_out = x if cond else y
 
+    # TODO(jwfromm): Onnxruntime 1.0.0 is buggy with If statements. Replace this with
+    # verify_with_ort once we update versions.
     for target, ctx in tvm.testing.enabled_targets():
         tvm_out = get_tvm_output_with_vm(if_model, [cond], target, ctx, freeze_params=True)
         for i in range(len(tvm_out)):
@@ -4204,15 +3932,12 @@ if __name__ == "__main__":
     test_pad()
     test_split()
     test_binary_ops()
-    test_single_ops()
+    test_unary_ops()
     test_leaky_relu()
     test_elu()
     test_selu()
     test_prelu()
     test_ThresholdedRelu()
-    test_ScaledTanh()
-    test_ParametricSoftplus()
-    test_Scale()
     test_LogSoftmax()
     test_resnet()
     test_inception()

Reply via email to