This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 98aa41e329 [Relay] Flexible shape dispatch transformation (#11199)
98aa41e329 is described below

commit 98aa41e329c20a5b8b34a34387fcc9067db5f22a
Author: Josh Fromm <[email protected]>
AuthorDate: Fri May 6 12:18:16 2022 -0700

    [Relay] Flexible shape dispatch transformation (#11199)
    
    * Added pass that creates a semi-dynamic dispatcher around a relay module.
    
    * Added automatic padding feature.
    
    * Output slicing working.
    
    * Multiple input support working i think.
    
    * Added test file.
    
    * Improve comments.
    
    * Fix lint.
    
    * Allow default values.
    
    * Fix docstring.
    
    * Improved documentation based on feedback.
    
    * Add extra check for record loading.
    
    * Improve variable names.
    
    * Add type inference to make sure things worked.
    
    * Added support for multiple outputs.
---
 python/tvm/auto_scheduler/dispatcher.py            |  29 +-
 python/tvm/autotvm/task/dispatcher.py              |  35 +-
 python/tvm/relay/transform/__init__.py             |   1 +
 python/tvm/relay/transform/flexible_shape.py       | 369 +++++++++++++++++++++
 tests/python/relay/test_auto_scheduler_tuning.py   |   6 +
 .../relay/test_pass_flexible_shape_dispatch.py     | 119 +++++++
 tests/python/unittest/test_autotvm_record.py       |   5 +
 7 files changed, 545 insertions(+), 19 deletions(-)

diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py
index cc1e76b9fa..eceeba38e0 100644
--- a/python/tvm/auto_scheduler/dispatcher.py
+++ b/python/tvm/auto_scheduler/dispatcher.py
@@ -130,11 +130,13 @@ class ApplyHistoryBest(DispatchContext):
 
     Parameters
     ----------
-    records : str or iterator of (auto_scheduler.measure.MeasureInput,\
-                                  auto_scheduler.measure.MeasureResult)
+    records : str, list of str, or iterator of 
(auto_scheduler.measure.MeasureInput,\
+                                                
auto_scheduler.measure.MeasureResult)
         Collection of tuning records.
         If is str, then it should be the filename of a records log file.
-        Each row of this file is an encoded record pair. Otherwise, it is an 
iterator.
+        Each row of this file is an encoded record pair. If it is an iterator,
+        it can either be a set of str filenames which will be applied jointly,
+        or a set of (input, result) tuples.
     n_lines: Optional[int]
         if it is not None, only load the first `n_lines` lines of log.
     include_compatible: bool
@@ -196,20 +198,29 @@ class ApplyHistoryBest(DispatchContext):
         n_lines: Optional[int]
             if it is not None, only load the first `n_lines` lines of log
         """
-        if isinstance(records, pathlib.Path):
-            records = str(records)
+        joint_records = []
+        if not isinstance(records, (list, tuple)):
+            records = [records]
 
-        if isinstance(records, str):
-            records = load_records(records)
+        for rec in records:
+            if isinstance(rec, pathlib.Path):
+                rec = str(rec)
+
+            if isinstance(rec, str):
+                rec = load_records(rec)
+                joint_records += rec
+            else:
+                if rec is not None:
+                    joint_records.append(rec)
 
-        if not records:
+        if not joint_records:
             return
 
         best_by_targetkey = self.best_by_targetkey
         best_by_model = self.best_by_model
 
         counter = 0
-        for inp, res in records:
+        for inp, res in joint_records:
             if n_lines is not None and counter >= n_lines:
                 break
             counter += 1
diff --git a/python/tvm/autotvm/task/dispatcher.py b/python/tvm/autotvm/task/dispatcher.py
index bed0258127..ffff50b9dc 100644
--- a/python/tvm/autotvm/task/dispatcher.py
+++ b/python/tvm/autotvm/task/dispatcher.py
@@ -184,10 +184,12 @@ class ApplyHistoryBest(DispatchContext):
 
     Parameters
     ----------
-    records : str or iterator of (autotvm.measure.MeasureInput, 
autotvm.measure.MeasureResult)
+    records : str, list of str, or iterator of (autotvm.measure.MeasureInput,\
+                                                autotvm.measure.MeasureResult)
         Collection of tuning records.
         If is str, then it should be the filename of a records log file.
-        Each row of this file is an encoded record pair. Otherwise, it is an 
iterator.
+        Each row of this file is an encoded record pair. If it is a list, it 
can either be
+        a list of paths to log files that will be loaded jointly or an 
iterator or records.
     """
 
     def __init__(self, records):
@@ -205,28 +207,41 @@ class ApplyHistoryBest(DispatchContext):
 
         Parameters
         ----------
-        records : str or iterator of (autotvm.measure.MeasureInput, 
autotvm.measure.MeasureResult)
+        records : str, list of str, or iterator of 
(autotvm.measure.MeasureInput,\
+                                                    
autotvm.measure.MeasureResult)
             Collection of tuning records.
             If is str, then it should be the filename of a records log file.
-            Each row of this file is an encoded record pair. Otherwise, it is 
an iterator.
+            Each row of this file is an encoded record pair. If it is a list
+            it can either be a list of paths to logs that will loaded jointly 
or
+            an iterator of measurement results.
         """
         # pylint: disable=import-outside-toplevel
         from pathlib import Path
         from ..record import load_from_file
 
-        if isinstance(records, Path):
-            records = str(records)
+        joint_records = []
+        if not isinstance(records, (list, tuple)):
+            records = [records]
 
-        if isinstance(records, str):
-            records = load_from_file(records)
-        if not records:
+        for rec in records:
+            if isinstance(rec, Path):
+                rec = str(rec)
+
+            if isinstance(rec, str):
+                rec = load_from_file(rec)
+                joint_records += rec
+            else:
+                if rec is not None:
+                    joint_records.append(rec)
+
+        if not joint_records:
             return
 
         best_by_targetkey = self.best_by_targetkey
         best_by_model = self.best_by_model
 
         counter = 0
-        for inp, res in records:
+        for inp, res in joint_records:
             counter += 1
             if res.error_no != 0:
                 continue
diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py
index 378b0c38ff..c10b8f8ff3 100644
--- a/python/tvm/relay/transform/__init__.py
+++ b/python/tvm/relay/transform/__init__.py
@@ -20,3 +20,4 @@
 from .transform import *
 from .recast import recast
 from . import fake_quantization_to_integer, mixed_precision
+from .flexible_shape import FlexibleShapeDispatch
diff --git a/python/tvm/relay/transform/flexible_shape.py b/python/tvm/relay/transform/flexible_shape.py
new file mode 100644
index 0000000000..c38fde0e70
--- /dev/null
+++ b/python/tvm/relay/transform/flexible_shape.py
@@ -0,0 +1,369 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay functions for wrapping a module with flexible shape dispatch."""
+import tvm
+from tvm import relay
+
+
+def override_shape(tensor_type, axis, dim):
+    """Replace one dimension of a tensor type, or of every field of a tuple type.
+
+    Parameters
+    ----------
+    tensor_type: TensorType or TupleType
+        The type to modify. For a TupleType, every field is treated as a
+        TensorType and has its dimension at ``axis`` replaced.
+    axis: int
+        Index of the dimension to replace.
+    dim: int or relay.Any
+        The new value for that dimension.
+
+    Returns
+    -------
+    new_type: TensorType or TupleType
+        A type of the same kind as ``tensor_type`` with the dimension overridden.
+    """
+    is_tuple = isinstance(tensor_type, relay.TupleType)
+    # Handle multiple tensors by overriding the shape of each.
+    field_types = tensor_type.fields if is_tuple else [tensor_type]
+
+    # Create new tensortypes for each field.
+    new_types = []
+    for t_type in field_types:
+        new_dims = list(t_type.shape)
+        new_dims[axis] = dim
+        new_types.append(relay.TensorType(new_dims, t_type.dtype))
+
+    # Keep the kind of the input type. A single-field tuple must stay a tuple;
+    # otherwise a function returning a 1-tuple would get a TensorType return type.
+    if not is_tuple:
+        return new_types[0]
+    return relay.TupleType(tvm.runtime.convert(new_types))
+
+
+def specialize_body(mod, function, axis, dim, input_indices, affects_output=True):
+    """
+    Create a subgraph to handle specific input shapes
+
+    This function takes in a module and one of its functions and creates a
+    similar function with a specific input shape. It then attaches the new function
+    to the module. Calling this function multiple times results in a module that
+    contains several similar functions, each specialized to a specific input shape.
+    This allows a dispatch handler to be built on top of the module to deal with
+    flexible shapes.
+
+    There are a few modes to this function. When the specialized function has multiple
+    flexible inputs, the index of those inputs must be provided to the input_indices argument.
+    In this case, the axis of the flexible dimension for each of those inputs must be the same.
+
+    By default, this function assumes that the output shape is dependent on the input
+    shape (as is the case in dynamic batching) and will also specialize the output type
+    accordingly. If this is not true, the affects_output argument must be set to False.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module that contains specialized functions and the dispatcher.
+        It is modified in place: the new specialized function is added to it.
+    function: Function
+        The original non-specialized function that will be transformed.
+    axis: int
+        Which axis the flexible shape is on.
+    dim: int or relay.Any
+        The size to specialize the new subgraph for along ``axis``.
+    input_indices: List[int]
+        Which inputs should be dispatched dynamically, provided by index. All inputs
+        must share the same dynamic axis.
+    affects_output: Optional[bool]
+        Whether the change in input shape has a corresponding effect on the output shape.
+        Batching for example affects both the input and output whereas changing sequence
+        length in an NLP model typically does not.
+
+    Returns
+    -------
+    gvar : GlobalVar
+        The new variable for the specialized subgraph, named ``"main_" + str(dim)``.
+    spec_types : List[Type]
+        The specialized types of the flexible inputs only, in the same order as
+        ``input_indices``. Parameters not listed in ``input_indices`` are not included.
+    """
+    # Iterate through specified inputs and construct specialized shapes for each.
+    new_params = list(function.params)
+    data_binding = {}
+    spec_types = []
+    for inp in input_indices:
+        data = function.params[inp]
+        flex_ty = override_shape(data.type_annotation, axis, dim)
+        dyn_data = relay.Var(data.name_hint, type_annotation=flex_ty)
+        new_params[inp] = dyn_data
+        # Every use of the old parameter in the body is rebound to the new variable.
+        data_binding[data] = dyn_data
+        spec_types.append(dyn_data.type_annotation)
+
+    # Create a new function body for the modified shapes.
+    new_body = relay.expr.bind(function.body, data_binding)
+    # Only change the output shape if the input shape affects it.
+    if affects_output:
+        new_ret_ty = override_shape(function.ret_type, axis, dim)
+    else:
+        new_ret_ty = function.ret_type
+    gvar = relay.GlobalVar("main_" + str(dim))
+    # Add the new function to the main IRModule.
+    mod[gvar] = relay.Function(
+        new_params, new_body, new_ret_ty, function.type_params, function.attrs
+    )
+    return gvar, spec_types
+
+
+def flexible_dispatch(
+    mod, buckets, axis=0, auto_pad=False, pad_value=0, input_indices=None, affects_output=True
+):
+    """
+    Enable inference of multiple shaped inputs in one module.
+
+    This transformation adds a handler around a module that
+    checks input shapes and dispatches to a subgraph specialized
+    to handle the specific shapes of that input. If no exactly matching
+    subgraph is available, the input will be run using full dynamism.
+    For best performance, specify all the sizes the module will
+    be likely to see using the buckets argument.
+
+    By default, this function will dispatch shapes that exactly match one
+    of the buckets to a corresponding subgraph. All non-matching shapes
+    use the same fully dynamic fallback. This can be detrimental to performance
+    for those non-matching shapes. Setting auto_pad to True causes this
+    function to round up the shape of non-matching inputs to the smallest
+    bucket that can hold them (inputs larger than every bucket still use the
+    dynamic fallback). This allows them to use the tuned kernels of bucket
+    shapes, which can improve performance.
+
+    Functions that have multiple inputs sharing a dynamic axis, which
+    is common for batch size or sequence length dynamism, are supported
+    through the input_indices argument.
+
+    Many types of dynamism such as batching affect both the input and output
+    shape, however this is not always the case. If the output shape
+    is independent of the input, the affects_output argument of this
+    function must be set to False.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to transform. Its "main" function is read for parameter and
+        return types, so it should already be type-inferred (FlexibleShapeDispatch
+        runs InferType first). The module is modified: specialized subgraphs are
+        added and "main" is replaced by the dispatch handler.
+    buckets: list[int]
+        The sizes of the input dimension that should be explicitly handled.
+        Each distinct value in buckets will have a corresponding subgraph constructed
+        to handle it. Buckets are processed in ascending order; duplicates are ignored.
+    axis: int
+        The dimension of the input that should be made flexible. This will
+        most often be used for the batch dimension.
+    auto_pad: Optional[bool]
+        If True, then padding will be inserted to values that don't match one of
+        the provided buckets.
+    pad_value: Optional[float]
+        When auto_pad is true, padding will be done with this value.
+    input_indices: Optional[List[int]]
+        Which inputs should be dispatched dynamically, provided by index. All inputs
+        must share the same dynamic axis. Defaults to ``[0]``.
+    affects_output: Optional[bool]
+        Whether the change in input shape has a corresponding effect on the output shape.
+        Batching for example affects both the input and output whereas changing sequence
+        length in an NLP model typically does not.
+
+    Returns
+    -------
+    mod : IRModule
+        The new module wrapped with a flexible shape dispatch handler.
+    """
+    main_fn = mod["main"]
+
+    # Default to single input if not specified.
+    if input_indices is None:
+        input_indices = [0]
+
+    # Padding rounds an input up to the next bucket by comparing against the
+    # previous bucket, which requires ascending order. Duplicate buckets would
+    # also produce two subgraphs with the same "main_<dim>" name.
+    buckets = sorted(set(buckets))
+
+    # Extract all input data and create a new dynamic variable for each.
+    # ``data`` and ``dyn_data`` are ordered like ``input_indices``.
+    data = []
+    dyn_data = []
+    for i in input_indices:
+        param = main_fn.params[i]
+        data.append(param)
+        # Use ``param`` directly. ``data[i]`` would mix a parameter index with a
+        # position in ``data`` and fails for e.g. input_indices=[1].
+        dyn_shape = override_shape(param.type_annotation, axis, relay.Any())
+        dyn_data.append(relay.Var(param.name_hint, type_annotation=dyn_shape))
+
+    # Extract the dynamic shape value from one of the inputs.
+    flex_value = relay.op.take(relay.op.shape_of(dyn_data[0]), relay.const(axis))
+
+    # Original return type; used to slice every field of a tuple output.
+    ret_type = main_fn.ret_type
+
+    if_exprs = []
+    for i, bucket in enumerate(buckets):
+        input_data = dyn_data
+        check_dim = flex_value
+
+        # Apply automatic padding if specified.
+        if auto_pad:
+            bucket_const = relay.const(bucket)
+            # Amount needed to grow the flexible axis up to this bucket.
+            pad_width = bucket_const - flex_value
+            # Pad only when the input size lies in (previous bucket, this bucket].
+            in_range = relay.op.less_equal(flex_value, bucket_const)
+            if i > 0:
+                in_range = relay.op.logical_and(
+                    in_range, relay.op.greater(flex_value, relay.const(buckets[i - 1]))
+                )
+
+            input_data = []
+            for param, inp in zip(data, dyn_data):
+                rank = len(param.type_annotation.shape)
+                # [rank, 2] pad widths; only the "after" entry of ``axis`` is non-zero.
+                pads = relay.zeros([rank, 2], "int32")
+                pads = relay.scatter_nd(pads, relay.const([axis, 1]), pad_width)
+                padded_value = relay.nn.pad(inp, pads, pad_value)
+                input_data.append(relay.If(in_range, padded_value, inp))
+            # Grab the new possibly padded shape for checking bucket size.
+            check_dim = relay.op.take(relay.op.shape_of(input_data[0]), relay.const(axis))
+
+        # Create a specialized subgraph for the current bucket.
+        spec_call, spec_ty = specialize_body(
+            mod, main_fn, axis, bucket, input_indices=input_indices, affects_output=affects_output
+        )
+        # Apply hard casting to shape to create statically typed graphs.
+        spec_data = [relay.op.reshape(inp, ty.shape) for inp, ty in zip(input_data, spec_ty)]
+
+        # Create a dispatch statement for the current specialized graph.
+        call_args = list(main_fn.params)
+        for idx, arg in zip(input_indices, spec_data):
+            call_args[idx] = arg
+        new_call = spec_call(*call_args)
+
+        # Remove meaningless padded outputs if applicable.
+        if auto_pad and affects_output:
+            keep = relay.arange(start=relay.const(0), stop=flex_value, dtype="int32")
+            if isinstance(ret_type, relay.TupleType):
+                # ``take`` operates on a single tensor, so slice each output separately.
+                new_call = relay.Tuple(
+                    [
+                        relay.take(relay.TupleGetItem(new_call, k), keep, axis=axis)
+                        for k in range(len(ret_type.fields))
+                    ]
+                )
+            else:
+                new_call = relay.take(new_call, keep, axis=axis)
+
+        # Add this new case to the dispatch handler.
+        if_exprs.append((relay.op.equal(check_dim, relay.const(bucket)), new_call))
+
+    # Create a subgraph to handle all other shapes.
+    default_dyn_call, _ = specialize_body(
+        mod, main_fn, axis, relay.Any(), input_indices=input_indices, affects_output=affects_output
+    )
+    # Parameters of the new entry point: the flexible inputs are replaced by their
+    # dynamically shaped variables, and every other parameter is kept unchanged.
+    new_params = list(main_fn.params)
+    for idx, dyn_var in zip(input_indices, dyn_data):
+        new_params[idx] = dyn_var
+    new_body = default_dyn_call(*new_params)
+
+    # Create an If chain to dispatch shapes to the appropriate specialized subgraph.
+    # The check for the last bucket ends up outermost.
+    for cond, true_branch in if_exprs:
+        new_body = relay.If(cond, true_branch, new_body)
+
+    # Update the output shape to be dynamic if needed.
+    if affects_output:
+        dyn_ret_type = override_shape(ret_type, axis, relay.Any())
+    else:
+        dyn_ret_type = ret_type
+
+    # Assign the handler as the new entrypoint in the module.
+    mod["main"] = relay.Function(
+        new_params, new_body, dyn_ret_type, main_fn.type_params, main_fn.attrs
+    )
+    # Do type inference to make sure everything worked.
+    return relay.transform.InferType()(mod)
+
+
+class FlexibleShapeDispatch:
+    """Enable inference of multiple shaped inputs in one module.
+
+    This transformation adds a handler around a module that
+    checks input shapes and dispatches to a subgraph specialized
+    to handle the specific shapes of that input. If no exactly matching
+    subgraph is available, the input will be run using full dynamism.
+    For best performance, specify all the sizes the module will
+    be likely to see using the buckets argument.
+
+    By default, this transformation will dispatch shapes that exactly match one
+    of the buckets to a corresponding subgraph. All non-matching shapes
+    use the same fully dynamic fallback. This can be detrimental to performance
+    for those non-matching shapes. Setting auto_pad to True causes this
+    transformation to round up the shape of non-matching inputs to the closest
+    bucket. This allows them to use the tuned kernels of bucket shapes
+    which can improve performance.
+
+    Models that have multiple inputs sharing a dynamic axis, which
+    is common for batch size or sequence length dynamism, are supported
+    through the input_indices argument.
+
+    Many types of dynamism such as batching affect both the input and output
+    shape, however this is not always the case. If the output shape
+    is independent of the input, the affects_output argument of this
+    transformation must be set to False.
+
+    Parameters
+    ----------
+    buckets: list[int]
+        The sizes of the input dimension that should be explicitly handled.
+        Each value in buckets will have a corresponding subgraph constructed to
+        handle it.
+    axis: int
+        The dimension of the input that should be made flexible. This will
+        most often be used for the batch dimension.
+    auto_pad: Optional[bool]
+        If True, then padding will be inserted to values that don't match one of
+        the provided buckets.
+    pad_value: Optional[float]
+        When auto_pad is true, padding will be done with this value.
+    input_indices: Optional[List[int]]
+        Which inputs should be dispatched dynamically, provided by index. All inputs
+        must share the same dynamic axis.
+    affects_output: Optional[bool]
+        Whether the change in input shape has a corresponding effect on the output shape.
+        Batching for example affects both the input and output whereas changing sequence
+        length in an NLP model typically does not.
+
+    Returns
+    -------
+    ret : FlexibleShapeDispatch
+        A callable (not a ``tvm.transform.Pass`` subclass) that can be applied to a
+        module to add flexible shape handling.
+    """
+
+    def __init__(
+        self,
+        buckets,
+        axis=0,
+        auto_pad=False,
+        pad_value=0,
+        input_indices=None,
+        affects_output=True,
+    ):
+        self.axis = axis
+        self.buckets = buckets
+        self.auto_pad = auto_pad
+        self.pad_value = pad_value
+        self.input_indices = input_indices
+        self.affects_output = affects_output
+        super().__init__()
+
+    def __call__(self, mod):
+        """Apply the flexible shape dispatch transformation to ``mod``.
+
+        Parameters
+        ----------
+        mod: IRModule
+            The module to transform. Type inference is run on it first because the
+            transformation reads the parameter and return types of "main".
+
+        Returns
+        -------
+        mod: IRModule
+            The module whose "main" dispatches to shape-specialized subgraphs.
+        """
+        # Shape information is required for this pass.
+        mod = relay.transform.InferType()(mod)
+        return flexible_dispatch(
+            mod,
+            self.buckets,
+            self.axis,
+            self.auto_pad,
+            self.pad_value,
+            self.input_indices,
+            self.affects_output,
+        )
diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py
index 1431824899..c9ce5b59ff 100644
--- a/tests/python/relay/test_auto_scheduler_tuning.py
+++ b/tests/python/relay/test_auto_scheduler_tuning.py
@@ -56,6 +56,12 @@ def tune_network(network, target):
             ):
                 lib = relay.build(mod, target=target, params=params)
 
+        # Also test that multiple log files can be loaded.
+        with auto_scheduler.ApplyHistoryBest([log_file, log_file]) as best:
+            assert isinstance(
+                best, auto_scheduler.dispatcher.ApplyHistoryBest
+            ), "Unable to load multiple log files jointly."
+
         # Sample a schedule when missing
         with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2):
             with tvm.transform.PassContext(
diff --git a/tests/python/relay/test_pass_flexible_shape_dispatch.py b/tests/python/relay/test_pass_flexible_shape_dispatch.py
new file mode 100644
index 0000000000..a6d547f4f5
--- /dev/null
+++ b/tests/python/relay/test_pass_flexible_shape_dispatch.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test flexible shape dispatch pass"""
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tvm.relay.testing.resnet import get_workload
+from tvm.relay import vm
+from tvm import runtime
+
+
+def test_end_to_end():
+    """ResNet with batch buckets [1, 4] and auto padding handles batches 1, 3 and 4."""
+    # Load a resnet model.
+    mod, params = get_workload()
+    # Apply flexible dispatch pass.
+    mod = relay.transform.FlexibleShapeDispatch(axis=0, buckets=[1, 4], auto_pad=True)(mod)
+    # Compile and confirm result supports multiple shapes.
+    exe = relay.vm.compile(mod, "llvm", params=params)
+    # Named so it does not shadow the ``vm`` module imported at the top of the file.
+    vm_exec = runtime.vm.VirtualMachine(exe, tvm.cpu())
+
+    # Evaluate various batch sizes
+    batch_1 = np.random.normal(size=[1, 3, 224, 224]).astype("float32")
+    assert list(vm_exec.invoke("main", batch_1).shape) == [1, 1000]
+
+    batch_4 = np.random.normal(size=[4, 3, 224, 224]).astype("float32")
+    assert list(vm_exec.invoke("main", batch_4).shape) == [4, 1000]
+
+    # Apply autopadding to an input.
+    batch_3 = np.random.normal(size=[3, 3, 224, 224]).astype("float32")
+    result_3 = vm_exec.invoke("main", batch_3)
+    assert list(result_3.shape) == [3, 1000]
+
+    # Auto padding (pad_value defaults to 0) must only append rows that are sliced
+    # away again: running the explicitly zero-padded batch and keeping the first
+    # three rows has to give the same values.
+    padded_3 = np.concatenate([batch_3, np.zeros([1, 3, 224, 224], "float32")])
+    expected_3 = vm_exec.invoke("main", padded_3).numpy()[:3]
+    np.testing.assert_allclose(result_3.numpy(), expected_3, rtol=1e-5, atol=1e-5)
+
+
+def test_multiple_inputs():
+    """Two inputs sharing flexible axis 1 are dispatched to the matching bucket."""
+    # Create a small relay module with multiple inputs to dispatch over.
+    x = relay.var("x", shape=[10, 10], dtype="float32")
+    w = relay.var("w", shape=[10, 10], dtype="float32")
+    y = x + w
+    mod = tvm.IRModule.from_expr(y)
+
+    # Apply flexible dispatch to dim 1 for both inputs.
+    mod = relay.transform.FlexibleShapeDispatch(axis=1, buckets=[5, 10], input_indices=[0, 1])(mod)
+
+    # Compile and confirm that output shapes are correct.
+    exe = relay.vm.compile(mod, "llvm")
+    # Named so it does not shadow the ``vm`` module imported at the top of the file.
+    vm_exec = runtime.vm.VirtualMachine(exe, tvm.cpu())
+
+    x_w_5 = np.random.normal(size=[10, 5]).astype("float32")
+    assert list(vm_exec.invoke("main", x_w_5, x_w_5).shape) == [10, 5]
+
+    x_w_10 = np.random.normal(size=[10, 10]).astype("float32")
+    assert list(vm_exec.invoke("main", x_w_10, x_w_10).shape) == [10, 10]
+
+
+def test_fixed_output():
+    """With affects_output=False the output shape stays static across buckets."""
+    # Test a graph where the output shape is not based on input dynamism.
+    x = relay.var("x", shape=[10, 10], dtype="float32")
+    w = relay.var("w", shape=[10, 10], dtype="float32")
+    y = relay.nn.dense(x, w)
+    mod = tvm.IRModule.from_expr(y)
+
+    # Apply flexible dispatch to dimension 1 for both inputs.
+    mod = relay.transform.FlexibleShapeDispatch(
+        axis=1, buckets=[5, 7], input_indices=[0, 1], affects_output=False
+    )(mod)
+
+    # Compile and confirm that output shapes are correct.
+    exe = relay.vm.compile(mod, "llvm")
+    # Named so it does not shadow the ``vm`` module imported at the top of the file.
+    vm_exec = runtime.vm.VirtualMachine(exe, tvm.cpu())
+
+    x_w_5 = np.random.normal(size=[10, 5]).astype("float32")
+    assert list(vm_exec.invoke("main", x_w_5, x_w_5).shape) == [10, 10]
+
+    x_w_7 = np.random.normal(size=[10, 7]).astype("float32")
+    assert list(vm_exec.invoke("main", x_w_7, x_w_7).shape) == [10, 10]
+
+
+def test_multiple_outputs():
+    """A tuple-returning graph gets every output resized along the flexible axis."""
+    # Create a graph with multiple outputs and test that it works.
+    x = relay.var("x", shape=[10, 10], dtype="float32")
+    y = relay.split(x, 2, axis=1)
+    mod = tvm.IRModule.from_expr(y.astuple())
+
+    # Apply flexible dispatch to batch dimension.
+    mod = relay.transform.FlexibleShapeDispatch(axis=0, buckets=[5, 10])(mod)
+
+    # Compile and confirm that both outputs are correct.
+    exe = relay.vm.compile(mod, "llvm")
+    # Named so it does not shadow the ``vm`` module imported at the top of the file.
+    vm_exec = runtime.vm.VirtualMachine(exe, tvm.cpu())
+
+    x_5 = np.random.normal(size=[5, 10]).astype("float32")
+    result_5 = vm_exec.invoke("main", x_5)
+    assert list(result_5[0].shape) == [5, 5]
+    assert list(result_5[1].shape) == [5, 5]
+
+    x_10 = np.random.normal(size=[10, 10]).astype("float32")
+    result_10 = vm_exec.invoke("main", x_10)
+    assert list(result_10[0].shape) == [10, 5]
+    assert list(result_10[1].shape) == [10, 5]
+
+
+if __name__ == "__main__":
+    # Propagate pytest's status so running this file directly fails on test failures.
+    raise SystemExit(pytest.main([__file__]))
diff --git a/tests/python/unittest/test_autotvm_record.py b/tests/python/unittest/test_autotvm_record.py
index 65739df52c..2ee75cf18c 100644
--- a/tests/python/unittest/test_autotvm_record.py
+++ b/tests/python/unittest/test_autotvm_record.py
@@ -72,6 +72,11 @@ def test_file_io():
     for x, y in zip(ref, autotvm.record.load_from_file(file_path)):
         assert x[1] == y[1]
 
+    # Confirm functionality of multiple file loads
+    hist_best = ApplyHistoryBest([file_path, file_path])
+    x = hist_best.query(target, tsk.workload)
+    assert str(x) == str(inputs[0][2])
+
 
 def test_apply_history_best():
     tsk, target = get_sample_task()

Reply via email to