This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8775a80 [TVMC] Support compiling and running with VM (#10722)
8775a80 is described below
commit 8775a805d17d9e46b1cdcadd00f22d3471270ef1
Author: Margaret Qian <[email protected]>
AuthorDate: Thu Mar 31 12:30:09 2022 -0400
[TVMC] Support compiling and running with VM (#10722)
* introduce vm compile path
* support vm in tvmc
* cleanup + lint
* add profiler + simplify vm case in tvmcpackage
* address comments + parametrize tests
Co-authored-by: Margaret Qian <[email protected]>
---
python/tvm/driver/tvmc/compiler.py | 77 +++++++++++++++---
python/tvm/driver/tvmc/model.py | 94 ++++++++++++++++++----
python/tvm/driver/tvmc/runner.py | 127 +++++++++++++++++++-----------
tests/python/driver/tvmc/test_compiler.py | 50 ++++++------
tests/python/driver/tvmc/test_model.py | 21 +++--
tests/python/driver/tvmc/test_runner.py | 8 +-
6 files changed, 274 insertions(+), 103 deletions(-)
diff --git a/python/tvm/driver/tvmc/compiler.py
b/python/tvm/driver/tvmc/compiler.py
index 66c7b35..8f24dd4 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -201,6 +201,7 @@ def compile_model(
disabled_pass: Optional[str] = None,
pass_context_configs: Optional[List[str]] = None,
additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None,
+ use_vm: bool = False,
):
"""Compile a model from a supported framework into a TVM module.
@@ -248,7 +249,8 @@ def compile_model(
PassContext.
additional_target_options: Optional[Dict[str, Dict[str, Any]]]
Additional target options in a dictionary to combine with initial
Target arguments
-
+ use_vm: bool
+ Whether to use the VM to compile the model as opposed to the graph
executor
Returns
-------
@@ -291,8 +293,13 @@ def compile_model(
opt_level=opt_level, config=config,
disabled_pass=disabled_pass
):
logger.debug("building relay graph with autoscheduler")
- graph_module = relay.build(
- mod, target=tvm_target, executor=executor,
runtime=runtime, params=params
+ graph_module = build(
+ mod,
+ tvm_target=tvm_target,
+ executor=executor,
+ runtime=runtime,
+ params=params,
+ use_vm=use_vm,
)
else:
with autotvm.apply_history_best(tuning_records):
@@ -300,16 +307,26 @@ def compile_model(
opt_level=opt_level, config=config,
disabled_pass=disabled_pass
):
logger.debug("building relay graph with tuning records")
- graph_module = relay.build(
- mod, target=tvm_target, executor=executor,
runtime=runtime, params=params
+ graph_module = build(
+ mod,
+ tvm_target=tvm_target,
+ executor=executor,
+ runtime=runtime,
+ params=params,
+ use_vm=use_vm,
)
else:
with tvm.transform.PassContext(
opt_level=opt_level, config=config, disabled_pass=disabled_pass
):
logger.debug("building relay graph (no tuning records provided)")
- graph_module = relay.build(
- mod, target=tvm_target, executor=executor, runtime=runtime,
params=params
+ graph_module = build(
+ mod,
+ tvm_target=tvm_target,
+ executor=executor,
+ runtime=runtime,
+ params=params,
+ use_vm=use_vm,
)
# Generate output dump files with sources
@@ -319,7 +336,10 @@ def compile_model(
dump_code = [dump_code]
dumps = {}
for source_type in dump_code:
- lib = graph_module.get_lib()
+ if use_vm:
+ lib = graph_module.lib
+ else:
+ lib = graph_module.get_lib()
# TODO lib.get_source call have inconsistent behavior for unsupported
# formats (@leandron).
source = str(mod) if source_type == "relay" else
lib.get_source(source_type)
@@ -327,11 +347,7 @@ def compile_model(
# Create a new tvmc model package object from the graph definition.
package_path = tvmc_model.export_package(
- graph_module,
- package_path,
- cross,
- cross_options,
- output_format,
+ graph_module, package_path, cross, cross_options, output_format
)
# Write dumps to file.
@@ -341,6 +357,41 @@ def compile_model(
return TVMCPackage(package_path)
+def build(
+ mod: tvm.IRModule,
+ tvm_target: str,
+ executor: Executor,
+ runtime: Runtime,
+ params: Dict[str, tvm.nd.NDArray],
+ use_vm: bool,
+):
+ """
+ Builds the model with the provided executor.
+
+ Parameters
+ ----------
+ mod : tvm.IRModule
+ The relay module corresponding to this model.
+ tvm_target : str
+ The target for which to compile. Can be a plain string or
+ a path.
+ executor : Executor
+ The graph executor used to build the model when use_vm is not True
+ runtime : Runtime
+ The runtime configuration.
+ params : dict
+ A parameter dictionary for the model.
+ use_vm: bool
+ Whether to use the VM to compile the model as opposed to the graph
executor
+
+ """
+ if use_vm:
+ logger.debug("building with vm compile")
+ return relay.vm.compile(mod, target=tvm_target, params=params)
+ logger.debug("building with relay build")
+ return relay.build(mod, target=tvm_target, executor=executor,
runtime=runtime, params=params)
+
+
def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."):
"""
Serialize dump files to the disk.
diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py
index 9a2617f..93ca27c 100644
--- a/python/tvm/driver/tvmc/model.py
+++ b/python/tvm/driver/tvmc/model.py
@@ -57,6 +57,8 @@ from tvm.contrib import utils
from tvm.driver.tvmc import TVMCException
from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule
from tvm.runtime.module import BenchmarkResult
+from tvm.runtime.vm import Executable
+
try:
from tvm.micro import export_model_library_format
@@ -182,6 +184,42 @@ class TVMCModel(object):
"""
return self._tmp_dir.relpath("model_package.tar")
+ def export_vm_format(
+ self,
+ vm_exec: Executable,
+ package_path: Optional[str] = None,
+ lib_format: str = "so",
+ ):
+ """Save this TVMCModel compiled via vm to file.
+ Parameters
+ ----------
+ vm_exec : vm.Executable
+ The VM Executable containing the compiled artifacts
needed to run this model.
+ package_path : str, None
+ Where the model should be saved. Note that it will be packaged as
a .tar file.
+ If not provided, the package will be saved to a generically named
file in tmp.
+ lib_format : str
+ How to export the module's function library. Must be one of "so" or
"tar".
+
+ Returns
+ -------
+ package_path : str
+ The path that the package was saved to.
+ """
+ lib_name = "lib." + lib_format
+ temp = self._tmp_dir
+ if package_path is None:
+ package_path = self.default_package_path()
+
+ path_lib = temp.relpath(lib_name)
+ vm_exec.mod.export_library(path_lib)
+ self.lib_path = path_lib
+ # Package up all the temp files into a tar file.
+ with tarfile.open(package_path, "w") as tar:
+ tar.add(path_lib, lib_name)
+
+ return package_path
+
def export_classic_format(
self,
executor_factory: GraphExecutorFactoryModule,
@@ -248,7 +286,7 @@ class TVMCModel(object):
def export_package(
self,
- executor_factory: GraphExecutorFactoryModule,
+ executor_factory: Union[GraphExecutorFactoryModule, Executable],
package_path: Optional[str] = None,
cross: Optional[Union[str, Callable]] = None,
cross_options: Optional[str] = None,
@@ -281,7 +319,9 @@ class TVMCModel(object):
if output_format == "mlf" and cross:
raise TVMCException("Specifying the MLF output and a cross
compiler is not supported.")
- if output_format in ["so", "tar"]:
+ if isinstance(executor_factory, Executable):
+ package_path = self.export_vm_format(executor_factory,
package_path, output_format)
+ elif output_format in ["so", "tar"]:
package_path = self.export_classic_format(
executor_factory, package_path, cross, cross_options,
output_format
)
@@ -314,9 +354,16 @@ class TVMCPackage(object):
project_dir : Path, str
If given and loading a MLF file, the path to the project directory
that contains the file.
+
+ use_vm : bool
+ Whether the graph module was compiled with vm or not.
"""
- def __init__(self, package_path: str, project_dir: Optional[Union[Path,
str]] = None):
+ def __init__(
+ self,
+ package_path: str,
+ project_dir: Optional[Union[Path, str]] = None,
+ ):
self._tmp_dir = utils.tempdir()
self.package_path = package_path
self.import_package(self.package_path)
@@ -351,23 +398,40 @@ class TVMCPackage(object):
self.type = "mlf"
else:
# Classic format
- lib_name_so = "mod.so"
- lib_name_tar = "mod.tar"
- if os.path.exists(temp.relpath(lib_name_so)):
- self.lib_name = lib_name_so
- elif os.path.exists(temp.relpath(lib_name_tar)):
- self.lib_name = lib_name_tar
+ classic_lib_name_so = "mod.so"
+ classic_lib_name_tar = "mod.tar"
+
+ # VM format
+ vm_lib_name_so = "lib.so"
+ vm_lib_name_tar = "lib.tar"
+
+ if os.path.exists(temp.relpath(classic_lib_name_so)):
+ self.lib_name = classic_lib_name_so
+ self.type = "classic"
+ elif os.path.exists(temp.relpath(classic_lib_name_tar)):
+ self.lib_name = classic_lib_name_tar
+ self.type = "classic"
+ elif os.path.exists(temp.relpath(vm_lib_name_so)):
+ self.lib_name = vm_lib_name_so
+ self.type = "vm"
+ elif os.path.exists(temp.relpath(vm_lib_name_tar)):
+ self.lib_name = vm_lib_name_tar
+ self.type = "vm"
else:
raise TVMCException("Couldn't find exported library in the
package.")
- self.lib_path = temp.relpath(self.lib_name)
- graph = temp.relpath("mod.json")
- params = temp.relpath("mod.params")
+ self.lib_path = temp.relpath(self.lib_name)
- self.type = "classic"
+ graph, params = None, None
+ if self.type == "classic":
+ graph = temp.relpath("mod.json")
+ params = temp.relpath("mod.params")
- with open(params, "rb") as param_file:
- self.params = bytearray(param_file.read())
+ if params is not None:
+ with open(params, "rb") as param_file:
+ self.params = bytearray(param_file.read())
+ else:
+ self.params = None
if graph is not None:
with open(graph) as graph_file:
diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py
index 8db1272..1b6d823 100644
--- a/python/tvm/driver/tvmc/runner.py
+++ b/python/tvm/driver/tvmc/runner.py
@@ -28,9 +28,11 @@ import numpy as np
import tvm
from tvm import rpc
+from tvm.runtime import vm
from tvm.autotvm.measure import request_remote
from tvm.contrib import graph_executor as executor
from tvm.contrib.debugger import debug_executor
+from tvm.runtime import profiler_vm
from . import TVMCException
from .arguments import TVMCSuppressedArgumentParser
from .project import (
@@ -530,58 +532,93 @@ def run_module(
assert device == "cpu"
dev = session.cpu()
- # TODO(gromero): Adjust for micro targets.
- if profile:
- logger.debug("Creating executor with profiling enabled.")
- module = debug_executor.create(tvmc_package.graph, lib, dev,
dump_root="./prof")
+ if tvmc_package.type == "vm":
+ assert inputs is not None, "vm runner requires inputs to be
provided as a dict"
+
+ input_tensor = {}
+ for e, i in inputs.items():
+ input_tensor[e] = tvm.nd.array(i, dev)
+
+ if profile:
+ logger.debug("Creating vm with profile enabled.")
+ exe = profiler_vm.VirtualMachineProfiler(lib, dev)
+ res = exe.profile(**input_tensor, func_name="main")
+ # This print is intentional
+ print(res)
+ else:
+ exe = vm.VirtualMachine(lib, dev)
+
+ exe_outputs = exe.invoke("main", **input_tensor)
+ times = exe.benchmark(
+ dev,
+ **input_tensor,
+ func_name="main",
+ repeat=repeat,
+ number=number,
+ end_to_end=end_to_end,
+ )
+
+ # Special handling if the output only has a single value
+ if not isinstance(exe_outputs, list):
+ exe_outputs = [exe_outputs]
+
+ outputs = {}
+ for i, val in enumerate(exe_outputs):
+ output_name = "output_{}".format(i)
+ outputs[output_name] = val.numpy()
else:
- if device == "micro":
- logger.debug("Creating executor (micro) with profiling
disabled.")
- module =
tvm.micro.create_local_graph_executor(tvmc_package.graph, lib, dev)
+ # TODO(gromero): Adjust for micro targets.
+ if profile:
+ logger.debug("Creating runtime with profiling enabled.")
+ module = debug_executor.create(tvmc_package.graph, lib, dev,
dump_root="./prof")
else:
- logger.debug("Creating executor with profiling disabled.")
- module = executor.create(tvmc_package.graph, lib, dev)
+ if device == "micro":
+ logger.debug("Creating runtime (micro) with profiling
disabled.")
+ module =
tvm.micro.create_local_graph_executor(tvmc_package.graph, lib, dev)
+ else:
+ logger.debug("Creating runtime with profiling disabled.")
+ module = executor.create(tvmc_package.graph, lib, dev)
- logger.debug("Loading params into the runtime module.")
- module.load_params(tvmc_package.params)
+ logger.debug("Loading params into the runtime module.")
+ module.load_params(tvmc_package.params)
- logger.debug("Collecting graph input shape and type:")
- shape_dict, dtype_dict = module.get_input_info()
- logger.debug("Graph input shape: %s", shape_dict)
- logger.debug("Graph input type: %s", dtype_dict)
+ logger.debug("Collecting graph input shape and type:")
+ shape_dict, dtype_dict = module.get_input_info()
+ logger.debug("Graph input shape: %s", shape_dict)
+ logger.debug("Graph input type: %s", dtype_dict)
- inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs,
fill_mode)
+ inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs,
fill_mode)
- logger.debug("Setting inputs to the module.")
- module.set_input(**inputs_dict)
+ logger.debug("Setting inputs to the module.")
+ module.set_input(**inputs_dict)
- # Run must be called explicitly if profiling
- if profile:
- logger.info("Running the module with profiling enabled.")
- report = module.profile()
- # This print is intentional
- print(report)
+ # Run must be called explicitly if profiling
+ if profile:
+ logger.info("Running the module with profiling enabled.")
+ report = module.profile()
+ # This print is intentional
+ print(report)
- if device == "micro":
- # TODO(gromero): Fix time_evaluator() for micro targets. Once it's
- # fixed module.benchmark() can be used instead and this if/else can
- # be removed.
- module.run()
- times = []
- else:
- # Call the benchmarking function of the executor.
- # Optionally measure e2e data transfers from the
- # CPU to device memory overheads (e.g. PCIE
- # overheads if the device is a discrete GPU).
- if end_to_end:
- dev = session.cpu()
- times = module.benchmark(dev, number=number, repeat=repeat,
end_to_end=end_to_end)
-
- logger.debug("Collecting the output tensors.")
- num_outputs = module.get_num_outputs()
- outputs = {}
- for i in range(num_outputs):
- output_name = "output_{}".format(i)
- outputs[output_name] = module.get_output(i).numpy()
+ if device == "micro":
+ # TODO(gromero): Fix time_evaluator() for micro targets. Once
it's
+ # fixed module.benchmark() can be used instead and this
if/else can
+ # be removed.
+ module.run()
+ times = []
+ else:
+ # Call the benchmarking function of the executor.
+ # Optionally measure e2e data transfers from the
+ # CPU to device memory overheads (e.g. PCIE
+ # overheads if the device is a discrete GPU).
+ if end_to_end:
+ dev = session.cpu()
+ times = module.benchmark(dev, number=number, repeat=repeat,
end_to_end=end_to_end)
+
+ logger.debug("Collecting the output tensors.")
+ num_outputs = module.get_num_outputs()
+ outputs = {}
+ for i in range(num_outputs):
+ output_name = "output_{}".format(i)
+ outputs[output_name] = module.get_output(i).numpy()
return TVMCResult(outputs, times)
diff --git a/tests/python/driver/tvmc/test_compiler.py
b/tests/python/driver/tvmc/test_compiler.py
index 4b21f4e..bc836de 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -49,21 +49,32 @@ def test_save_dumps(tmpdir_factory):
# End to end tests for compilation
-def verify_compile_tflite_module(model, shape_dict=None):
- pytest.importorskip("tflite")
- tvmc_model = tvmc.load(model, shape_dict=shape_dict)
- tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll",
desired_layout="NCHW")
- dumps_path = tvmc_package.package_path + ".ll"
-
+def verify_tvmc_package(tvmc_package, dumps_path, use_vm=False):
# check for output types
assert type(tvmc_package) is TVMCPackage
- assert type(tvmc_package.graph) is str
- assert type(tvmc_package.lib_path) is str
- assert type(tvmc_package.params) is bytearray
assert os.path.exists(dumps_path)
+ assert type(tvmc_package.lib_path) is str
+
+ if use_vm:
+ assert tvmc_package.graph is None
+ assert tvmc_package.params is None
+ else:
+ assert type(tvmc_package.graph) is str
+ assert type(tvmc_package.params) is bytearray
+
+
+def verify_compile_tflite_module(model, shape_dict=None, use_vm=False):
+ pytest.importorskip("tflite")
+ tvmc_model = tvmc.load(model, shape_dict=shape_dict)
+ tvmc_package = tvmc.compile(
+ tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW",
use_vm=use_vm
+ )
+ dumps_path = tvmc_package.package_path + ".ll"
+ verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm)
-def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
[email protected]("use_vm", [True, False])
+def test_compile_tflite_module(use_vm, tflite_mobilenet_v1_1_quant):
# some CI environments wont offer tflite, so skip in case it is not present
pytest.importorskip("tflite")
# Check default compilation.
@@ -71,7 +82,7 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant):
# Check with manual shape override
shape_string = "input:[1,224,224,3]"
shape_dict = tvmc.shape_parser.parse_shape_string(shape_string)
- verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict)
+ verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict,
use_vm=use_vm)
# This test will be skipped if the AArch64 cross-compilation toolchain is not
installed.
@@ -198,28 +209,23 @@ def
test_cross_compile_options_aarch64_keras_module(keras_resnet50):
assert os.path.exists(dumps_path)
-def verify_compile_onnx_module(model, shape_dict=None):
+def verify_compile_onnx_module(model, shape_dict=None, use_vm=False):
# some CI environments wont offer onnx, so skip in case it is not present
pytest.importorskip("onnx")
tvmc_model = tvmc.load(model, shape_dict=shape_dict)
- tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll")
+ tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll",
use_vm=use_vm)
dumps_path = tvmc_package.package_path + ".ll"
-
- # check for output types
- assert type(tvmc_package) is TVMCPackage
- assert type(tvmc_package.graph) is str
- assert type(tvmc_package.lib_path) is str
- assert type(tvmc_package.params) is bytearray
- assert os.path.exists(dumps_path)
+ verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm)
-def test_compile_onnx_module(onnx_resnet50):
[email protected]("use_vm", [True, False])
+def test_compile_onnx_module(use_vm, onnx_resnet50):
# Test default compilation
verify_compile_onnx_module(onnx_resnet50)
# Test with manual shape dict
shape_string = "data:[1,3,200,200]"
shape_dict = tvmc.shape_parser.parse_shape_string(shape_string)
- verify_compile_onnx_module(onnx_resnet50, shape_dict)
+ verify_compile_onnx_module(onnx_resnet50, shape_dict, use_vm=use_vm)
# This test will be skipped if the AArch64 cross-compilation toolchain is not
installed.
diff --git a/tests/python/driver/tvmc/test_model.py
b/tests/python/driver/tvmc/test_model.py
index 5fccfea..74c1c4d 100644
--- a/tests/python/driver/tvmc/test_model.py
+++ b/tests/python/driver/tvmc/test_model.py
@@ -17,6 +17,7 @@
import platform
import pytest
import os
+import numpy as np
from os import path
@@ -29,13 +30,22 @@ from tvm.runtime.module import BenchmarkResult
platform.machine() == "aarch64",
reason="Currently failing on AArch64 - see
https://github.com/apache/tvm/issues/10673",
)
-def test_tvmc_workflow(keras_simple):
[email protected]("use_vm", [True, False])
+def test_tvmc_workflow(use_vm, keras_simple):
pytest.importorskip("tensorflow")
+ import tensorflow as tf
+
+ # Reset so the input name remains consistent across unit test runs
+ tf.keras.backend.clear_session()
tvmc_model = tvmc.load(keras_simple)
tuning_records = tvmc.tune(tvmc_model, target="llvm",
enable_autoscheduler=True, trials=2)
- tvmc_package = tvmc.compile(tvmc_model, tuning_records=tuning_records,
target="llvm")
- result = tvmc.run(tvmc_package, device="cpu", end_to_end=True)
+ tvmc_package = tvmc.compile(
+ tvmc_model, tuning_records=tuning_records, target="llvm", use_vm=use_vm
+ )
+ input_dict = {"input_1": np.random.uniform(size=(1, 32, 32,
3)).astype("float32")}
+
+ result = tvmc.run(tvmc_package, device="cpu", end_to_end=True,
inputs=input_dict)
assert type(tvmc_model) is TVMCModel
assert type(tvmc_package) is TVMCPackage
assert type(result) is TVMCResult
@@ -45,7 +55,8 @@ def test_tvmc_workflow(keras_simple):
assert "output_0" in result.outputs.keys()
-def test_save_load_model(keras_simple, tmpdir_factory):
[email protected]("use_vm", [True, False])
+def test_save_load_model(use_vm, keras_simple, tmpdir_factory):
pytest.importorskip("onnx")
tmpdir = tmpdir_factory.mktemp("data")
@@ -55,7 +66,7 @@ def test_save_load_model(keras_simple, tmpdir_factory):
tvmc.tune(tvmc_model, target="llvm", trials=2)
# Create package artifacts
- tvmc.compile(tvmc_model, target="llvm")
+ tvmc.compile(tvmc_model, target="llvm", use_vm=use_vm)
# Save the model to disk
model_path = os.path.join(tmpdir, "saved_model.tar")
diff --git a/tests/python/driver/tvmc/test_runner.py
b/tests/python/driver/tvmc/test_runner.py
index 30ce2c6..3f4ab11 100644
--- a/tests/python/driver/tvmc/test_runner.py
+++ b/tests/python/driver/tvmc/test_runner.py
@@ -72,18 +72,20 @@ def test_get_top_results_keep_results():
assert len(sut[1]) == expected_number_of_results_per_line
[email protected]("use_vm", [True, False])
def test_run_tflite_module__with_profile__valid_input(
- tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat
+ use_vm, tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat
):
# some CI environments wont offer TFLite, so skip in case it is not present
pytest.importorskip("tflite")
inputs = np.load(imagenet_cat)
+ input_dict = {"input": inputs["input"].astype("uint8")}
- tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant)
+ tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant,
use_vm=use_vm)
result = tvmc.run(
tflite_compiled_model,
- inputs=inputs,
+ inputs=input_dict,
hostname=None,
device="cpu",
profile=True,