This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 98aa41e329 [Relay] Flexible shape dispatch transformation (#11199)
98aa41e329 is described below
commit 98aa41e329c20a5b8b34a34387fcc9067db5f22a
Author: Josh Fromm <[email protected]>
AuthorDate: Fri May 6 12:18:16 2022 -0700
[Relay] Flexible shape dispatch transformation (#11199)
* Added pass that creates a semi-dynamic dispatcher around a relay module.
* Added automatic padding feature.
* Output slicing working.
* Multiple input support working, I think.
* Added test file.
* Improve comments.
* Fix lint.
* Allow default values.
* Fix docstring.
* Improved documentation based on feedback.
* Add extra check for record loading.
* Improve variable names.
* Add type inference to make sure things worked.
* Added support for multiple outputs.
---
python/tvm/auto_scheduler/dispatcher.py | 29 +-
python/tvm/autotvm/task/dispatcher.py | 35 +-
python/tvm/relay/transform/__init__.py | 1 +
python/tvm/relay/transform/flexible_shape.py | 369 +++++++++++++++++++++
tests/python/relay/test_auto_scheduler_tuning.py | 6 +
.../relay/test_pass_flexible_shape_dispatch.py | 119 +++++++
tests/python/unittest/test_autotvm_record.py | 5 +
7 files changed, 545 insertions(+), 19 deletions(-)
diff --git a/python/tvm/auto_scheduler/dispatcher.py
b/python/tvm/auto_scheduler/dispatcher.py
index cc1e76b9fa..eceeba38e0 100644
--- a/python/tvm/auto_scheduler/dispatcher.py
+++ b/python/tvm/auto_scheduler/dispatcher.py
@@ -130,11 +130,13 @@ class ApplyHistoryBest(DispatchContext):
Parameters
----------
- records : str or iterator of (auto_scheduler.measure.MeasureInput,\
- auto_scheduler.measure.MeasureResult)
+ records : str, list of str, or iterator of
(auto_scheduler.measure.MeasureInput,\
+
auto_scheduler.measure.MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
- Each row of this file is an encoded record pair. Otherwise, it is an
iterator.
+ Each row of this file is an encoded record pair. If it is an iterator,
+ it can either be a set of str filenames which will be applied jointly,
+ or a set of (input, result) tuples.
n_lines: Optional[int]
if it is not None, only load the first `n_lines` lines of log.
include_compatible: bool
@@ -196,20 +198,29 @@ class ApplyHistoryBest(DispatchContext):
n_lines: Optional[int]
if it is not None, only load the first `n_lines` lines of log
"""
- if isinstance(records, pathlib.Path):
- records = str(records)
+ joint_records = []
+ if not isinstance(records, (list, tuple)):
+ records = [records]
- if isinstance(records, str):
- records = load_records(records)
+ for rec in records:
+ if isinstance(rec, pathlib.Path):
+ rec = str(rec)
+
+ if isinstance(rec, str):
+ rec = load_records(rec)
+ joint_records += rec
+ else:
+ if rec is not None:
+ joint_records.append(rec)
- if not records:
+ if not joint_records:
return
best_by_targetkey = self.best_by_targetkey
best_by_model = self.best_by_model
counter = 0
- for inp, res in records:
+ for inp, res in joint_records:
if n_lines is not None and counter >= n_lines:
break
counter += 1
diff --git a/python/tvm/autotvm/task/dispatcher.py
b/python/tvm/autotvm/task/dispatcher.py
index bed0258127..ffff50b9dc 100644
--- a/python/tvm/autotvm/task/dispatcher.py
+++ b/python/tvm/autotvm/task/dispatcher.py
@@ -184,10 +184,12 @@ class ApplyHistoryBest(DispatchContext):
Parameters
----------
- records : str or iterator of (autotvm.measure.MeasureInput,
autotvm.measure.MeasureResult)
+ records : str, list of str, or iterator of (autotvm.measure.MeasureInput,\
+ autotvm.measure.MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
- Each row of this file is an encoded record pair. Otherwise, it is an
iterator.
+ Each row of this file is an encoded record pair. If it is a list, it
can either be
+ a list of paths to log files that will be loaded jointly or an
iterator of records.
"""
def __init__(self, records):
@@ -205,28 +207,41 @@ class ApplyHistoryBest(DispatchContext):
Parameters
----------
- records : str or iterator of (autotvm.measure.MeasureInput,
autotvm.measure.MeasureResult)
+ records : str, list of str, or iterator of
(autotvm.measure.MeasureInput,\
+
autotvm.measure.MeasureResult)
Collection of tuning records.
If is str, then it should be the filename of a records log file.
- Each row of this file is an encoded record pair. Otherwise, it is
an iterator.
+ Each row of this file is an encoded record pair. If it is a list
+ it can either be a list of paths to logs that will be loaded jointly
or
+ an iterator of measurement results.
"""
# pylint: disable=import-outside-toplevel
from pathlib import Path
from ..record import load_from_file
- if isinstance(records, Path):
- records = str(records)
+ joint_records = []
+ if not isinstance(records, (list, tuple)):
+ records = [records]
- if isinstance(records, str):
- records = load_from_file(records)
- if not records:
+ for rec in records:
+ if isinstance(rec, Path):
+ rec = str(rec)
+
+ if isinstance(rec, str):
+ rec = load_from_file(rec)
+ joint_records += rec
+ else:
+ if rec is not None:
+ joint_records.append(rec)
+
+ if not joint_records:
return
best_by_targetkey = self.best_by_targetkey
best_by_model = self.best_by_model
counter = 0
- for inp, res in records:
+ for inp, res in joint_records:
counter += 1
if res.error_no != 0:
continue
diff --git a/python/tvm/relay/transform/__init__.py
b/python/tvm/relay/transform/__init__.py
index 378b0c38ff..c10b8f8ff3 100644
--- a/python/tvm/relay/transform/__init__.py
+++ b/python/tvm/relay/transform/__init__.py
@@ -20,3 +20,4 @@
from .transform import *
from .recast import recast
from . import fake_quantization_to_integer, mixed_precision
+from .flexible_shape import FlexibleShapeDispatch
diff --git a/python/tvm/relay/transform/flexible_shape.py
b/python/tvm/relay/transform/flexible_shape.py
new file mode 100644
index 0000000000..c38fde0e70
--- /dev/null
+++ b/python/tvm/relay/transform/flexible_shape.py
@@ -0,0 +1,369 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay functions for wrapping a module with flexible shape dispatch."""
+import tvm
+from tvm import relay
+
+
+def override_shape(tensor_type, axis, dim):
+ """Change a dimension in a tensor shape."""
+ # Handle multiple tensors by overriding the shape of each.
+ if isinstance(tensor_type, relay.TupleType):
+ tensor_type = tensor_type.fields
+ else:
+ tensor_type = [tensor_type]
+
+ # Create new tensortypes for each input.
+ new_types = []
+ for t_type in tensor_type:
+ new_dims = list(t_type.shape)
+ new_dims[axis] = dim
+ new_types.append(relay.TensorType(new_dims, t_type.dtype))
+
+ # Don't return a tuple if there is a single tensor.
+ if len(new_types) == 1:
+ return new_types[0]
+ return relay.TupleType(tvm.runtime.convert(new_types))
+
+
+def specialize_body(mod, function, axis, dim, input_indices,
affects_output=True):
+ """
+ Create a subgraph to handle specific input shapes
+
+ This function takes in a module and one of its functions and creates a
+ similar function with a specific input shape. It then attaches the new
function
+ to the module. Calling this function multiple times results in a module
that
+ contains several similar functions each specialized to a specific input
shape.
+ This allows a dispatch handler to be built on top of the module to deal
with
+ flexible shapes.
+
+ There are a few modes to this function. When the specialized function has
multiple
+ flexible inputs, the index of those inputs must be provided to the
input_indices argument.
+ In this case, the axis of the flexible dimension for each of those inputs
must be the same.
+
+ By default, this function assumes that the output shape is dependent on
the input
+ shape (as is the case in dynamic batching) and will also specialize the
output type
+ accordingly. If this is not true, the affects_output argument must be set
to False.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module that contains specialized functions and the dispatcher.
+ function: Function
+ The original non-specialized function that will be transformed.
+ axis: int
+ Which axis the flexible shape is on.
+ dim: int
+ The shape to specialize the new subgraph for along the axis dim.
+ input_indices: List[int]
+ Which inputs should be dispatched dynamically, provided by index. All
inputs
+ must share the same dynamic axis.
+ affects_output: Optional[bool]
+ Whether the change in input shape has a corresponding effect on the
output shape.
+ Batching for example affects both the input and output whereas
changing sequence
+ length in an NLP model typically does not.
+
+ Returns
+ -------
+ gvar : GlobalVar
+ The new variable for the specialized subgraph.
+ spec_types : List[TensorType]
+ A list of the new specialized types for each input in the graph.
+ """
+ # Iterate through specified inputs and construct specialized shapes for
each.
+ new_params = list(function.params)
+ data_binding = {}
+ dyn_data_array = []
+ for inp in input_indices:
+ data = function.params[inp]
+ flex_ty = override_shape(data.type_annotation, axis, dim)
+ dyn_data = relay.Var(data.name_hint, type_annotation=flex_ty)
+ new_params[inp] = dyn_data
+ data_binding[data] = dyn_data
+ dyn_data_array.append(dyn_data)
+
+ # Create a new function body for the modified shapes.
+ new_body = relay.expr.bind(function.body, data_binding)
+ # Only change the output shape if the input shape affects it.
+ if affects_output:
+ new_ret_ty = override_shape(function.ret_type, axis, dim)
+ else:
+ new_ret_ty = function.ret_type
+ gvar = relay.GlobalVar("main_" + str(dim))
+ # Add the new function to the main IRModule.
+ mod[gvar] = relay.Function(
+ new_params, new_body, new_ret_ty, function.type_params, function.attrs
+ )
+ return gvar, [d.type_annotation for d in dyn_data_array]
+
+
+def flexible_dispatch(
+ mod, buckets, axis=0, auto_pad=False, pad_value=0, input_indices=None,
affects_output=True
+):
+ """
+ Enable inference of multiple shaped inputs in one module.
+
+ This transformation adds a handler around a module that
+ checks input shapes and dispatches to a subgraph specialized
+ to handle the specific shapes of that input. If no exactly matching
+ subgraph is available, the input will be run using full dynamism.
+ For best performance, specify all the sizes the module will
+ be likely to see using the buckets argument.
+
+ By default, this function will dispatch shapes that exactly match one
+ of the buckets to a corresponding subgraph. All non-matching shapes
+ use the same fully dynamic fallback. This can be detrimental to performance
+ for those non-matching shapes. Setting auto_pad to True causes this
+ function to round-up the shape of non-matching inputs to the closest
+ bucket. This allows them to use the tuned kernels of bucket shapes
+ which can improve performance.
+
+ Functions that have multiple inputs sharing a dynamic axis, which
+ is common for batch size or sequence length dynamism, are supported
+ through the input_indices argument.
+
+ Many types of dynamism such as batching affect both the input and output
+ shape, however this is not always the case. If the output shape
+ is independent of the input, the affects_output argument of this
+ function must be set to False.
+
+ Parameters
+ ----------
+ buckets: list[int]
+ The sizes of the input dimension that should be explicitly handled.
+ Each value in buckets will have a corresponding subgraph constructed to
+ handle it.
+ axis: int
+ The dimension of the input that should be made flexible. This will
+ most often be used for the batch dimension.
+ auto_pad: Optional[bool]
+ If True, then padding will be inserted to values that don't match one
of
+ the provided buckets.
+ pad_value: Optional[float]
+ When auto_pad is true, padding will be done with this value.
+ input_indices: Optional[List[int]]
+ Which inputs should be dispatched dynamically, provided by index. All
inputs
+ must share the same dynamic axis.
+ affects_output: Optional[bool]
+ Whether the change in input shape has a corresponding effect on the
output shape.
+ Batching for example affects both the input and output whereas
changing sequence
+ length in an NLP model typically does not.
+
+ Returns
+ -------
+ mod : IRModule
+ The new module wrapped with a flexible shape dispatch handler.
+ """
+ main_fn = mod["main"]
+
+ # Default to single input if not specified.
+ if input_indices is None:
+ input_indices = [0]
+
+ # Extract all input data and create a new dynamic variable for each.
+ data = []
+ dyn_data = []
+ for i in input_indices:
+ data.append(main_fn.params[i])
+ dyn_shape = override_shape(data[i].type_annotation, axis, relay.Any())
+ dyn_data.append(relay.Var(data[i].name_hint,
type_annotation=dyn_shape))
+
+ # Extract the dynamic shape value from one of the inputs.
+ rt_sh = relay.op.shape_of(dyn_data[0])
+ flex_value = relay.op.take(rt_sh, relay.const(axis))
+
+ if_exprs = []
+
+ for i, bucket in enumerate(buckets):
+ input_data = dyn_data
+ check_dim = flex_value
+
+ # Apply automatic padding if specified.
+ if auto_pad:
+ input_data = []
+ # Construct padding expression for inputs.
+ for j, inp in enumerate(dyn_data):
+ pad_width = relay.const(bucket) - flex_value
+ rank = len(data[j].type_annotation.shape)
+ pads = relay.zeros([rank, 2], "int32")
+ pads = relay.scatter_nd(pads, relay.const([axis, 1]),
pad_width)
+ padded_value = relay.nn.pad(inp, pads, pad_value)
+
+ # Determine if this is the proper bucket to pad to. Do this by
checking if the
+ # input shape is between this bucket and the previous.
+ if i == 0:
+ padded_value = relay.If(
+ relay.op.less_equal(flex_value, relay.const(bucket)),
padded_value, inp
+ )
+ else:
+ padded_value = relay.If(
+ relay.op.logical_and(
+ relay.op.less_equal(flex_value,
relay.const(bucket)),
+ relay.op.greater(flex_value, relay.const(buckets[i
- 1])),
+ ),
+ padded_value,
+ inp,
+ )
+ # Update input value and test dimension to reflect possible
padding.
+ input_data.append(padded_value)
+ # Grab the new possibly padded shape for checking bucket size.
+ check_dim = relay.op.take(relay.op.shape_of(input_data[0]),
relay.const(axis))
+
+ # Create a specialized subgraph for the current bucket.
+ spec_call, spec_ty = specialize_body(
+ mod, main_fn, axis, bucket, input_indices=input_indices,
affects_output=affects_output
+ )
+ # Apply hard casting to shape to create statically typed graphs.
+ spec_data = []
+ for j, inp in enumerate(input_data):
+ spec_data.append(relay.op.reshape(inp, spec_ty[j].shape))
+
+ # Create a dispatch statement for the current specialized graph.
+ call_args = list(main_fn.params)
+ for j, inp in enumerate(input_indices):
+ call_args[inp] = spec_data[j]
+ new_call = spec_call(*call_args)
+
+ # Remove meaningless padded outputs if applicable.
+ if auto_pad and affects_output:
+ new_call = relay.take(
+ new_call,
+ relay.arange(start=relay.const(0), stop=flex_value,
dtype="int32"),
+ axis=axis,
+ )
+
+ # Add this new case to the dispatch handler.
+ if_exprs.append((relay.op.equal(check_dim, relay.const(bucket)),
new_call))
+
+ # Create a subgraph to handle all other shapes.
+ default_dyn_call, _ = specialize_body(
+ mod, main_fn, axis, relay.Any(), input_indices=input_indices,
affects_output=affects_output
+ )
+ call_args = list(main_fn.params)
+ for j, inp in enumerate(input_indices):
+ call_args[inp] = dyn_data[j]
+ new_body = default_dyn_call(*call_args)
+
+ # Create an If chain to dispatch shapes to the appropriate specialized
subgraph.
+ for cond, true_branch in if_exprs:
+ new_body = relay.If(cond, true_branch, new_body)
+
+ # Assign new parameters to the function.
+ new_params = list(main_fn.params)
+ for j, inp in enumerate(input_indices):
+ new_params[inp] = dyn_data[j]
+
+ # Update the output shape to be dynamic if needed.
+ if affects_output:
+ dyn_ret_type = override_shape(main_fn.ret_type, axis, relay.Any())
+ else:
+ dyn_ret_type = main_fn.ret_type
+
+ # Assign the handler as the new entrypoint in the module.
+ new_main = relay.Function(
+ new_params, new_body, dyn_ret_type, main_fn.type_params, main_fn.attrs
+ )
+ mod["main"] = new_main
+ # Do type inference to make sure everything worked.
+ mod = relay.transform.InferType()(mod)
+ return mod
+
+
+class FlexibleShapeDispatch(object):
+ """Enable inference of multiple shaped inputs in one module.
+
+ This transformation adds a handler around a module that
+ checks input shapes and dispatches to a subgraph specialized
+ to handle the specific shapes of that input. If no exactly matching
+ subgraph is available, the input will be run using full dynamism.
+ For best performance, specify all the sizes the module will
+ be likely to see using the buckets argument.
+
+ By default, this pass will dispatch shapes that exactly match one
+ of the buckets to a corresponding subgraph. All non-matching shapes
+ use the same fully dynamic fallback. This can be detrimental to performance
+ for those non-matching shapes. Setting auto_pad to True causes this
+ pass to round-up the shape of non-matching inputs to the closest
+ bucket. This allows them to use the tuned kernels of bucket shapes
+ which can improve performance.
+
+ Models that have multiple inputs sharing a dynamic axis, which
+ is common for batch size or sequence length dynamism, are supported
+ through the input_indices argument.
+
+ Many types of dynamism such as batching affect both the input and output
+ shape, however this is not always the case. If the output shape
+ is independent of the input, the affects_output argument of this
+ pass must be set to False.
+
+ Parameters
+ ----------
+ buckets: list[int]
+ The sizes of the input dimension that should be explicitly handled.
+ Each value in buckets will have a corresponding subgraph constructed to
+ handle it.
+ axis: int
+ The dimension of the input that should be made flexible. This will
+ most often be used for the batch dimension.
+ auto_pad: Optional[bool]
+ If True, then padding will be inserted to values that don't match one
of
+ the provided buckets.
+ pad_value: Optional[float]
+ When auto_pad is true, padding will be done with this value.
+ input_indices: Optional[List[int]]
+ Which inputs should be dispatched dynamically, provided by index. All
inputs
+ must share the same dynamic axis.
+ affects_output: Optional[bool]
+ Whether the change in input shape has a corresponding effect on the
output shape.
+ Batching for example affects both the input and output whereas
changing sequence
+ length in an NLP model typically does not.
+
+ Returns
+ -------
+ ret : FlexibleShapeDispatch
+ A pass that can be applied to a module to add flexible shape handling.
+ """
+
+ def __init__(
+ self,
+ buckets,
+ axis=0,
+ auto_pad=False,
+ pad_value=0,
+ input_indices=None,
+ affects_output=True,
+ ):
+ self.axis = axis
+ self.buckets = buckets
+ self.auto_pad = auto_pad
+ self.pad_value = pad_value
+ self.input_indices = input_indices
+ self.affects_output = affects_output
+ super(FlexibleShapeDispatch, self).__init__()
+
+ def __call__(self, mod):
+ # Shape information is required for this pass.
+ mod = relay.transform.InferType()(mod)
+ return flexible_dispatch(
+ mod,
+ self.buckets,
+ self.axis,
+ self.auto_pad,
+ self.pad_value,
+ self.input_indices,
+ self.affects_output,
+ )
diff --git a/tests/python/relay/test_auto_scheduler_tuning.py
b/tests/python/relay/test_auto_scheduler_tuning.py
index 1431824899..c9ce5b59ff 100644
--- a/tests/python/relay/test_auto_scheduler_tuning.py
+++ b/tests/python/relay/test_auto_scheduler_tuning.py
@@ -56,6 +56,12 @@ def tune_network(network, target):
):
lib = relay.build(mod, target=target, params=params)
+ # Also test that multiple log files can be loaded.
+ with auto_scheduler.ApplyHistoryBest([log_file, log_file]) as best:
+ assert isinstance(
+ best, auto_scheduler.dispatcher.ApplyHistoryBest
+ ), "Unable to load multiple log files jointly."
+
# Sample a schedule when missing
with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2):
with tvm.transform.PassContext(
diff --git a/tests/python/relay/test_pass_flexible_shape_dispatch.py
b/tests/python/relay/test_pass_flexible_shape_dispatch.py
new file mode 100644
index 0000000000..a6d547f4f5
--- /dev/null
+++ b/tests/python/relay/test_pass_flexible_shape_dispatch.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test flexible shape dispatch pass"""
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+from tvm.relay.testing.resnet import get_workload
+from tvm.relay import vm
+from tvm import runtime
+
+
+def test_end_to_end():
+ # Load a resnet model.
+ mod, params = get_workload()
+ # Apply flexible dispatch pass.
+ mod = relay.transform.FlexibleShapeDispatch(axis=0, buckets=[1, 4],
auto_pad=True)(mod)
+ # Compile and confirm result supports multiple shapes.
+ exe = relay.vm.compile(mod, "llvm", params=params)
+ vm = runtime.vm.VirtualMachine(exe, tvm.cpu())
+
+ # Evaluate various batch sizes
+ batch_1 = np.random.normal(size=[1, 3, 224, 224]).astype("float32")
+ assert list(vm.invoke("main", batch_1).shape) == [1, 1000]
+
+ batch_4 = np.random.normal(size=[4, 3, 224, 224]).astype("float32")
+ assert list(vm.invoke("main", batch_4).shape) == [4, 1000]
+
+ # Apply autopadding to an input.
+ batch_3 = np.random.normal(size=[3, 3, 224, 224]).astype("float32")
+ assert list(vm.invoke("main", batch_3).shape) == [3, 1000]
+
+
+def test_multiple_inputs():
+ # Create a small relay module with multiple inputs to dispatch over.
+ x = relay.var("x", shape=[10, 10], dtype="float32")
+ w = relay.var("w", shape=[10, 10], dtype="float32")
+ y = x + w
+ mod = tvm.IRModule.from_expr(y)
+
+ # Apply flexible dispatch to dim 1 for both inputs.
+ mod = relay.transform.FlexibleShapeDispatch(axis=1, buckets=[5, 10],
input_indices=[0, 1])(mod)
+
+ # Compile and confirm that output shapes are correct.
+ exe = relay.vm.compile(mod, "llvm")
+ vm = runtime.vm.VirtualMachine(exe, tvm.cpu())
+
+ x_w_5 = np.random.normal(size=[10, 5]).astype("float32")
+ assert list(vm.invoke("main", x_w_5, x_w_5).shape) == [10, 5]
+
+ x_w_10 = np.random.normal(size=[10, 10]).astype("float32")
+ assert list(vm.invoke("main", x_w_10, x_w_10).shape) == [10, 10]
+
+
+def test_fixed_output():
+ # Test a graph where the output shape is not based on input dynamism.
+ x = relay.var("x", shape=[10, 10], dtype="float32")
+ w = relay.var("w", shape=[10, 10], dtype="float32")
+ y = relay.nn.dense(x, w)
+ mod = tvm.IRModule.from_expr(y)
+
+ # Apply flexible dispatch to dimension 1 for both inputs.
+ mod = relay.transform.FlexibleShapeDispatch(
+ axis=1, buckets=[5, 7], input_indices=[0, 1], affects_output=False
+ )(mod)
+
+ # Compile and confirm that output shapes are correct.
+ exe = relay.vm.compile(mod, "llvm")
+ vm = runtime.vm.VirtualMachine(exe, tvm.cpu())
+
+ x_w_5 = np.random.normal(size=[10, 5]).astype("float32")
+ assert list(vm.invoke("main", x_w_5, x_w_5).shape) == [10, 10]
+
+ x_w_7 = np.random.normal(size=[10, 7]).astype("float32")
+ assert list(vm.invoke("main", x_w_7, x_w_7).shape) == [10, 10]
+
+ return
+
+
+def test_multiple_outputs():
+ # Create a graph with multiple outputs and test that it works.
+ x = relay.var("x", shape=[10, 10], dtype="float32")
+ y = relay.split(x, 2, axis=1)
+ mod = tvm.IRModule.from_expr(y.astuple())
+
+ # Apply flexible dispatch to batch dimension.
+ mod = relay.transform.FlexibleShapeDispatch(axis=0, buckets=[5, 10])(mod)
+
+ # Compile and confirm that both outputs are correct.
+ exe = relay.vm.compile(mod, "llvm")
+ vm = runtime.vm.VirtualMachine(exe, tvm.cpu())
+
+ x_5 = np.random.normal(size=[5, 10]).astype("float32")
+ result_5 = vm.invoke("main", x_5)
+ assert list(result_5[0].shape) == [5, 5]
+ assert list(result_5[1].shape) == [5, 5]
+
+ x_10 = np.random.normal(size=[10, 10]).astype("float32")
+ result_10 = vm.invoke("main", x_10)
+ assert list(result_10[0].shape) == [10, 5]
+ assert list(result_10[1].shape) == [10, 5]
+
+
+if __name__ == "__main__":
+ pytest.main([__file__])
diff --git a/tests/python/unittest/test_autotvm_record.py
b/tests/python/unittest/test_autotvm_record.py
index 65739df52c..2ee75cf18c 100644
--- a/tests/python/unittest/test_autotvm_record.py
+++ b/tests/python/unittest/test_autotvm_record.py
@@ -72,6 +72,11 @@ def test_file_io():
for x, y in zip(ref, autotvm.record.load_from_file(file_path)):
assert x[1] == y[1]
+ # Confirm functionality of multiple file loads
+ hist_best = ApplyHistoryBest([file_path, file_path])
+ x = hist_best.query(target, tsk.workload)
+ assert str(x) == str(inputs[0][2])
+
def test_apply_history_best():
tsk, target = get_sample_task()