This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dbd076a [BYORTL][Verilator] update ops and add MobileNet (#7972)
dbd076a is described below
commit dbd076a91b0ffad2a23d8ac14f17c69a686a58d1
Author: Luis Vega <[email protected]>
AuthorDate: Tue May 18 10:05:03 2021 -0700
[BYORTL][Verilator] update ops and add MobileNet (#7972)
* update
* update vta submodule
* cpp fmt
* python fmt
* skip if tflite is not available
* fmt
* change assertion
* update comment
---
3rdparty/vta-hw | 2 +-
src/runtime/contrib/verilator/verilator_kernel.h | 5 +-
src/runtime/contrib/verilator/verilator_runtime.cc | 19 +-
src/runtime/contrib/verilator/verilator_runtime.h | 5 +-
.../contrib/test_verilator/infrastructure.py | 128 ++++++++---
.../contrib/test_verilator/test_mobilenet.py | 240 +++++++++++++++++++++
.../test_verilator/test_verilator_codegen.py | 67 ------
.../contrib/test_verilator/test_verilator_ops.py | 191 ++++++++++++++++
8 files changed, 554 insertions(+), 103 deletions(-)
diff --git a/3rdparty/vta-hw b/3rdparty/vta-hw
index 4319417..dfe9f57 160000
--- a/3rdparty/vta-hw
+++ b/3rdparty/vta-hw
@@ -1 +1 @@
-Subproject commit 43194178b4e570a5f1dd4f3f9d37ee16fc1b65be
+Subproject commit dfe9f572a43d41e0c1ecdf036cea97042a0febfe
diff --git a/src/runtime/contrib/verilator/verilator_kernel.h
b/src/runtime/contrib/verilator/verilator_kernel.h
index f62097c..5735329 100644
--- a/src/runtime/contrib/verilator/verilator_kernel.h
+++ b/src/runtime/contrib/verilator/verilator_kernel.h
@@ -33,9 +33,12 @@ namespace tvm {
namespace runtime {
namespace contrib {
-extern "C" TVM_DLL void verilator_add(VerilatorHandle handle, int* data, int*
weight, int* out,
+extern "C" TVM_DLL void verilator_add(VerilatorHandle handle, int* left, int*
right, int* out,
int p_h_, int p_w_);
+extern "C" TVM_DLL void verilator_bias_add(VerilatorHandle handle, int* data,
int* bias, int* out,
+ int p_n_, int p_c_, int p_h_, int
p_w_);
+
} // namespace contrib
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/contrib/verilator/verilator_runtime.cc
b/src/runtime/contrib/verilator/verilator_runtime.cc
index 5dfb844..85172d4 100644
--- a/src/runtime/contrib/verilator/verilator_runtime.cc
+++ b/src/runtime/contrib/verilator/verilator_runtime.cc
@@ -80,7 +80,7 @@ VerilatorRuntime::~VerilatorRuntime() {
auto dealloc =
reinterpret_cast<VerilatorDeallocFunc>(lib_->GetSymbol("VerilatorDealloc"));
ICHECK(dealloc != nullptr);
dealloc(device_);
- delete lib_;
+ lib_->~VerilatorLibrary();
}
void VerilatorRuntime::SetLibrary(const std::string& lib_path) { lib_path_ =
lib_path; }
@@ -100,7 +100,6 @@ void VerilatorRuntime::Init(const Array<NDArray>& consts) {
ICHECK(reset != nullptr);
read_ =
reinterpret_cast<VerilatorReadFunc>(lib_->GetSymbol("VerilatorRead"));
ICHECK(read_ != nullptr);
- add_op_ =
reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));
// alloc verilator device
device_ = alloc();
@@ -108,7 +107,7 @@ void VerilatorRuntime::Init(const Array<NDArray>& consts) {
// enable profiler
if (prof_enable_) prof_ = VerilatorProfiler::ThreadLocal();
- // reset verilator device.
+ // reset verilator device
reset(device_, reset_cycles_);
CHECK_EQ(consts.size(), const_idx_.size())
@@ -136,11 +135,17 @@ void VerilatorRuntime::Run() {
if (node.GetOpType() == "kernel") {
CHECK_EQ(node.GetOpType(), "kernel");
auto op_name = node.GetOpName();
+ auto entry = node.GetInputs()[0];
+ auto shape = node.GetOpShape()[entry.index_];
if ("add" == op_name) {
- auto entry = node.GetInputs()[0];
- auto shape = nodes_[entry.id_].GetOpShape()[entry.index_];
- ICHECK(add_op_ != nullptr);
- add_op_(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
+ auto add =
reinterpret_cast<VerilatorAddFunc>(lib_->GetSymbol("verilator_add"));
+ ICHECK(add != nullptr);
+ add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0], shape[1]);
+ } else if ("nn.bias_add" == op_name) {
+ auto bias_add =
+
reinterpret_cast<VerilatorBiasAddFunc>(lib_->GetSymbol("verilator_bias_add"));
+ ICHECK(bias_add != nullptr);
+ bias_add(device_, in_ptr[0], in_ptr[1], out_ptr[0], shape[0],
shape[3], shape[1], shape[2]);
} else {
LOG(FATAL) << "Unsupported op: " << op_name;
}
diff --git a/src/runtime/contrib/verilator/verilator_runtime.h
b/src/runtime/contrib/verilator/verilator_runtime.h
index acdaa3b..664a041 100644
--- a/src/runtime/contrib/verilator/verilator_runtime.h
+++ b/src/runtime/contrib/verilator/verilator_runtime.h
@@ -50,8 +50,9 @@ using namespace tvm::runtime::json;
typedef VerilatorHandle (*VerilatorAllocFunc)();
typedef void (*VerilatorDeallocFunc)(VerilatorHandle);
typedef void (*VerilatorResetFunc)(VerilatorHandle, int);
-typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);
typedef int (*VerilatorReadFunc)(VerilatorHandle, int, int);
+typedef void (*VerilatorAddFunc)(VerilatorHandle, int*, int*, int*, int, int);
+typedef void (*VerilatorBiasAddFunc)(VerilatorHandle, int*, int*, int*, int,
int, int, int);
class VerilatorLibrary : public Library {
public:
@@ -122,8 +123,6 @@ class VerilatorRuntime : public JSONRuntimeBase {
VerilatorProfiler* prof_{nullptr};
/*! \brief the verilator read function */
VerilatorReadFunc read_{nullptr};
- /*! \brief the verilator add op function */
- VerilatorAddFunc add_op_{nullptr};
/*! \brief the verilator reset cycles */
int reset_cycles_{1};
/*! \brief the verilator profiler status */
diff --git a/tests/python/contrib/test_verilator/infrastructure.py
b/tests/python/contrib/test_verilator/infrastructure.py
index cf9f8bd..779f787 100644
--- a/tests/python/contrib/test_verilator/infrastructure.py
+++ b/tests/python/contrib/test_verilator/infrastructure.py
@@ -19,6 +19,7 @@
import os
import sys
import subprocess as sp
+import json
import tvm
from tvm import relay
@@ -48,6 +49,10 @@ def _register_verilator_op(op_name, supported=True):
return _func_wrapper
+_register_verilator_op("add")
+_register_verilator_op("nn.bias_add")
+
+
def skip_test():
"""Skip test if it requires the Verilator codegen and it's not present."""
if not tvm.get_global_func("relay.ext.verilator", True):
@@ -59,8 +64,33 @@ def skip_test():
return False
+def clear_stats():
+    """Clear the Verilator profiler statistics.
+
+    This is a no-op when the ``verilator.profiler_clear`` global function is
+    not registered (it is looked up with ``allow_missing=True``).
+    """
+    f = tvm.get_global_func("verilator.profiler_clear", True)
+    if f:
+        f()
+
+
+def stats():
+    """Get the Verilator profiler statistics.
+
+    Returns
+    -------
+    stats : Dict
+        The JSON string returned by ``verilator.profiler_status``, decoded.
+        The tests in this directory read its ``"cycle_counter"`` entry.
+    """
+
+    # Unlike clear_stats, this lookup is not allow_missing, so it is expected
+    # to raise if the profiler global function is not registered.
+    x = tvm.get_global_func("verilator.profiler_status")()
+    return json.loads(x)
+
+
def offload(mod):
- """Offload ops based on the registered ops"""
+ """Offload ops based on the registered ops
+
+ Paramters
+ ---------
+ mod : Module
+ The input module.
+
+ Returns
+ -------
+ mod : Module
+ The output module with offloaded ops.
+ """
backend = "verilator"
mod = transform.AnnotateTarget([backend])(mod)
@@ -69,7 +99,7 @@ def offload(mod):
def verilator_app_path():
- """Find verilator hardware app path"""
+ """Create verilator hardware app path."""
cur_dir = os.path.dirname(os.path.realpath(__file__))
return os.path.join(
@@ -82,37 +112,87 @@ def verilator_app_path():
"vta-hw",
"apps",
"verilator",
+ "add",
)
-def compile_hardware():
- """Compile hardware into shared library"""
+def compile_hardware(lanes):
+    """Compile the Verilator hardware into a shared library, if not yet built.
+
+    One library is built per lane count by running ``make`` in the app
+    directory; if the ``.so`` file already exists it is reused as-is.
+
+    Parameters
+    ----------
+    lanes : Int
+        The number of vector lanes.
+
+    Returns
+    -------
+    path : Str
+        The path of the shared library.
+    """
+    lib_name = "libverilator_{}".format(lanes)
+    lib_name_ext = "{}.so".format(lib_name)
+    lib = os.path.join(verilator_app_path(), lib_name_ext)
+    if not os.path.isfile(lib):
+        opt_lib_name = "LIB_NAME={}".format(lib_name)
+        opt_lanes = "LANES={}".format(lanes)
+        cmd = []
+        cmd.append("make")
+        cmd.append("--directory")
+        cmd.append(verilator_app_path())
+        cmd.append(opt_lib_name)
+        cmd.append(opt_lanes)
+        # make's stdout is discarded; check=True still raises on a failed build.
+        sp.run(cmd, check=True, stdout=sp.DEVNULL)
+    return lib
+
- cmd = []
- cmd.append("make")
- cmd.append("--directory")
- cmd.append(verilator_app_path())
- sp.run(cmd, check=True)
+def compiler_opts(lib):
+    """Create the Verilator codegen options.
+
+    Parameters
+    ----------
+    lib : Str
+        The path of the hardware shared library.
-def compile_module(mod):
-    """Compile Relay module and hardware library"""
+
+    Returns
+    -------
+    opts : Dict
+        The compiler options: the library path, the profiler enabled, and
+        ``profiler_cycle_counter_id`` set to 0.
+    """
+    opts = {
+        "lib_path": lib,
+        "profiler_enable": True,
+        "profiler_cycle_counter_id": 0,
+    }
+    return opts
- lib = os.path.join(verilator_app_path(), "libverilator.so")
- if not os.path.isfile(lib):
- compile_hardware()
- opts = {"lib_path": lib}
+def run_module(inp, mod, params=None, opts=None):
+    """Compile a Relay module with the VM compiler and run it on the CPU.
-    with tvm.transform.PassContext(opt_level=3,
config={"relay.ext.verilator.options": opts}):
-        exe = relay.vm.compile(mod, target="llvm", params=None)
-        code, lib = exe.save()
-    return runtime.vm.Executable.load_exec(code, lib)
+
+    Parameters
+    ----------
+    inp : Data
+        The input data, passed to the VM as keyword arguments.
+    mod : Module
+        The relay module.
-def run_module(exe, inputs):
-    """Run Relay module"""
+    params : Parameters
+        The model Parameters.
-    dev = tvm.cpu()
-    vm = runtime.vm.VirtualMachine(exe, dev)
-    return vm.run(**inputs)
+    opts : Dict
+        The compiler options, passed as ``relay.ext.verilator.options``.
+
+    Returns
+    -------
+    out : Data
+        The output data.
+    """
+
+    with tvm.transform.PassContext(opt_level=3,
config={"relay.ext.verilator.options": opts}):
+        lib = relay.vm.compile(mod, target="llvm", params=params)
+        # Round-trip through save/load_exec; `lib` is rebound here from the
+        # VM executable to the saved library.
+        code, lib = lib.save()
+    exe = runtime.vm.Executable.load_exec(code, lib)
+    vm = runtime.vm.VirtualMachine(exe, tvm.cpu())
+    out = vm.run(**inp)
+    return out
diff --git a/tests/python/contrib/test_verilator/test_mobilenet.py
b/tests/python/contrib/test_verilator/test_mobilenet.py
new file mode 100644
index 0000000..8447f19
--- /dev/null
+++ b/tests/python/contrib/test_verilator/test_mobilenet.py
@@ -0,0 +1,240 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+
+import numpy as np
+from PIL import Image
+
+import tvm
+from tvm import te, relay, transform
+from tvm.contrib.download import download_testdata
+from tvm.contrib import graph_executor as runtime
+
+from test_verilator.infrastructure import (
+    skip_test,
+    compile_hardware,
+    compiler_opts,
+    offload,
+    clear_stats,
+    stats,
+)
+
+
+def extract(path):
+ """Extract a tgz or gz file.
+
+ Paramters
+ ---------
+ path : Str
+ The path of the compressed file.
+ """
+ import tarfile
+
+ if path.endswith("tgz") or path.endswith("gz"):
+ dir_path = os.path.dirname(path)
+ tar = tarfile.open(path)
+ tar.extractall(path=dir_path)
+ tar.close()
+ else:
+ raise RuntimeError("Could not decompress the file: " + path)
+
+
+def get_real_image(im_height, im_width):
+ """Get a real image.
+
+ Paramters
+ ---------
+ im_height : Int
+ The image height.
+
+ im_width : Int
+ The image width.
+
+ Returns
+ -------
+ data: Data
+ The image array.
+ """
+ repo_base =
"https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/"
+ img_name = "elephant-299.jpg"
+ image_url = os.path.join(repo_base, img_name)
+ img_path = download_testdata(image_url, img_name, module="data")
+ image = Image.open(img_path).resize((im_height, im_width))
+ x = np.array(image).astype("uint8")
+ data = np.reshape(x, (1, im_height, im_width, 3))
+ return data
+
+
+def get_mobilenet_model():
+ """Return mobilenet model."""
+ model_url =
"https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz"
+ model_path = download_testdata(
+ model_url, "mobilenet_v1_1.0_224_quant.tgz", module=["tf", "official"]
+ )
+ model_dir = os.path.dirname(model_path)
+ extract(model_path)
+ tflite_model_file = os.path.join(model_dir,
"mobilenet_v1_1.0_224_quant.tflite")
+ tflite_model_buf = open(tflite_model_file, "rb").read()
+ try:
+ import tflite
+
+ return tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+ except AttributeError:
+ import tflite.Model
+
+ return tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+
+
+def get_input_tensor_name():
+    """Return the name of the model's input tensor."""
+    return "input"
+
+
+def compile_model_to_relay(model):
+    """Import the TFLite model into Relay.
+
+    The input tensor is declared with shape ``(1, 224, 224, 3)`` and dtype
+    ``uint8``.
+
+    Parameters
+    ----------
+    model : Model
+        The input model.
+
+    Returns
+    -------
+    mod: Module
+        The relay module.
+
+    params: Parameters
+        The model parameters.
+    """
+    input_tensor = get_input_tensor_name()
+    input_shape = (1, 224, 224, 3)
+    input_dtype = "uint8"
+    mod, params = relay.frontend.from_tflite(
+        model,
+        shape_dict={input_tensor: input_shape},
+        dtype_dict={input_tensor: input_dtype},
+    )
+    return mod, params
+
+
+def run_model(mod, params=None, opts=None):
+    """Build the model for the graph executor and run it on a real image.
+
+    The input tensor is fed with ``get_real_image(224, 224)``.
+
+    Parameters
+    ----------
+    mod: Module
+        The relay module.
+
+    params: Parameters
+        The model parameters.
+
+    opts: Dict
+        The compiler options.
+
+    Returns
+    -------
+    out: Data
+        The first output of the model, as a NumPy array.
+    """
+    with transform.PassContext(opt_level=3,
config={"relay.ext.verilator.options": opts}):
+        lib = relay.build(mod, target="llvm", params=params)
+    module = runtime.GraphModule(lib["default"](tvm.cpu()))
+    image_data = get_real_image(224, 224)
+    input_tensor = get_input_tensor_name()
+    module.set_input(input_tensor, image_data)
+    module.run()
+    out = module.get_output(0).asnumpy()
+    return out
+
+
+def get_labels():
+    """Download the MobileNet label file and return its lines.
+
+    Returns
+    -------
+    labels : List[Str]
+        One label per line, as returned by ``readlines`` (newlines kept).
+    """
+    label_file_url = "".join(
+        [
+            "https://raw.githubusercontent.com/",
+            "tensorflow/tensorflow/master/tensorflow/lite/java/demo/",
+            "app/src/main/assets/",
+            "labels_mobilenet_quant_v1_224.txt",
+        ]
+    )
+    label_file = "labels_mobilenet_quant_v1_224.txt"
+    label_path = download_testdata(label_file_url, label_file, module="data")
+    # List of 1001 classes
+    with open(label_path) as f:
+        labels = f.readlines()
+    return labels
+
+
+def check_result(res):
+ """Check prediction."""
+ labels = get_labels()
+ predictions = np.squeeze(res)
+ prediction = np.argmax(predictions)
+ # 387 is the elephant
+ assert prediction == 387
+
+
+def print_test_info(lanes, cycles):
+ """Print test info
+
+ Paramters
+ ---------
+ lanes : Int
+ The number of vector lanes.
+
+ cycles : Int
+ The number of cycles.
+ """
+ print(
+ "[mobilenet] vector-lanes:{} number of cycles:{} spent in
nn.bias_add".format(lanes, cycles)
+ )
+
+
+def is_tflite_available():
+ """Skip test if tensorflow-lite is not installed."""
+ try:
+ import tflite
+
+ return True
+ except:
+ return False
+
+
+def tmobilenet(lanes):
+ """Mobilenet test template.
+ Paramters
+ ---------
+ lanes : Int
+ The number of vector lanes.
+ """
+ if not is_tflite_available():
+ return
+ model = get_mobilenet_model()
+ mod, params = compile_model_to_relay(model)
+ mod = offload(mod)
+ lib = compile_hardware(lanes)
+ opts = compiler_opts(lib)
+ clear_stats()
+ res = run_model(mod, params, opts)
+ values = stats()
+ check_result(res)
+ print_test_info(lanes, values["cycle_counter"])
+
+
+def test_mobilenet():
+    """Run the MobileNet test template with 4 and with 32 vector lanes."""
+    tmobilenet(4)
+    tmobilenet(32)
diff --git a/tests/python/contrib/test_verilator/test_verilator_codegen.py
b/tests/python/contrib/test_verilator/test_verilator_codegen.py
deleted file mode 100644
index 664e254..0000000
--- a/tests/python/contrib/test_verilator/test_verilator_codegen.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Verilator codegen tests"""
-
-import numpy as np
-
-import tvm
-from tvm import relay
-
-from test_verilator.infrastructure import (
- _register_verilator_op,
- skip_test,
- compile_module,
- run_module,
- offload,
-)
-
-
-_register_verilator_op("add")
-
-
-def create_module_add(shape, dtype):
- x = relay.var("x", shape=shape, dtype=dtype)
- y = relay.var("y", shape=shape, dtype=dtype)
- z = relay.add(x, y)
- f = relay.Function([x, y], z)
- mod = tvm.IRModule()
- mod["main"] = f
- return mod
-
-
-def run_check_add(exe, shape, dtype):
- x_data = np.random.randint(5, size=shape, dtype=dtype)
- y_data = np.random.randint(5, size=shape, dtype=dtype)
- ref = x_data + y_data
- inputs = {"x": x_data, "y": y_data}
- out = run_module(exe, inputs)
- tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5)
-
-
-def test_add():
- if skip_test():
- return
- dtype = "int32"
- shape = (8, 4)
- mod = create_module_add(shape, dtype)
- mod = offload(mod)
- exe = compile_module(mod)
- run_check_add(exe, shape, dtype)
-
-
-if __name__ == "__main__":
- test_add()
diff --git a/tests/python/contrib/test_verilator/test_verilator_ops.py
b/tests/python/contrib/test_verilator/test_verilator_ops.py
new file mode 100644
index 0000000..19ed1f0
--- /dev/null
+++ b/tests/python/contrib/test_verilator/test_verilator_ops.py
@@ -0,0 +1,191 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Verilator codegen tests"""
+
+import numpy as np
+
+import tvm
+from tvm import relay
+
+from test_verilator.infrastructure import (
+ skip_test,
+ compile_hardware,
+ compiler_opts,
+ run_module,
+ offload,
+ clear_stats,
+ stats,
+)
+
+
+def create_module_add(shape, dtype):
+    """Create a module computing ``add(x, y)`` for two inputs of the same shape.
+
+    Parameters
+    ----------
+    shape : Tuple
+        The shape of both inputs.
+
+    dtype : Str
+        The data type.
+
+    Returns
+    -------
+    mod: Module
+        The relay module.
+    """
+    x = relay.var("x", shape=shape, dtype=dtype)
+    y = relay.var("y", shape=shape, dtype=dtype)
+    z = relay.add(x, y)
+    f = relay.Function([x, y], z)
+    mod = tvm.IRModule()
+    mod["main"] = f
+    return mod
+
+
+def create_module_bias_add(xshape, yshape, dtype):
+    """Create a module computing ``nn.bias_add(x, y, axis=3)``.
+
+    Parameters
+    ----------
+    xshape : Tuple
+        The x shape tuple (must have at least 4 dimensions for axis=3).
+
+    yshape : Tuple
+        The y (bias) shape tuple.
+
+    dtype : Str
+        The data type.
+
+    Returns
+    -------
+    mod: Module
+        The relay module.
+    """
+    x = relay.var("x", shape=xshape, dtype=dtype)
+    y = relay.var("y", shape=yshape, dtype=dtype)
+    z = relay.nn.bias_add(x, y, axis=3)
+    f = relay.Function([x, y], z)
+    mod = tvm.IRModule()
+    mod["main"] = f
+    return mod
+
+
+def run_and_check(xshape, yshape, dtype, mod, opts):
+    """Run the module on random inputs, check it, and return its cycle count.
+
+    Inputs are random integers in [0, 5). The reference is ``x + y`` with
+    NumPy broadcasting.
+
+    Parameters
+    ----------
+    xshape : Tuple
+        The x shape tuple.
+
+    yshape : Tuple
+        The y shape tuple.
+
+    dtype : Str
+        The data type.
+
+    mod: Module
+        The relay module.
+
+    opts: Dict
+        The compiler options.
+
+    Returns
+    -------
+    cycles: Int
+        The number of cycles.
+    """
+    x_data = np.random.randint(5, size=xshape, dtype=dtype)
+    y_data = np.random.randint(5, size=yshape, dtype=dtype)
+    # Broadcasting y against x's trailing axis matches bias_add on the last axis.
+    ref = x_data + y_data
+    inp = {"x": x_data, "y": y_data}
+    # Reset the profiler so the cycle count covers only this run.
+    clear_stats()
+    out = run_module(inp, mod, params=None, opts=opts)
+    values = stats()
+    # NOTE(review): only `import tvm` is visible in this file; confirm that
+    # `tvm.testing` is imported somewhere before this runs.
+    tvm.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-5)
+    return values["cycle_counter"]
+
+
+def print_test_info(test, lanes, cycles):
+    """Print the cycle count measured for one op test.
+
+    Parameters
+    ----------
+    test : Str
+        The name of the test.
+
+    lanes : Int
+        The number of vector lanes.
+
+    cycles : Int
+        The number of cycles.
+    """
+    print("test:{} vector-lanes:{} number of cycles:{}".format(test, lanes,
cycles))
+
+
+def tadd(lanes):
+    """Add test template.
+
+    Offloads an (8, 4) int32 ``add``, runs it on hardware built with ``lanes``
+    vector lanes, checks the result and prints the cycle count. Returns early
+    when skip_test() is true.
+
+    Parameters
+    ----------
+    lanes : Int
+        The number of vector lanes.
+    """
+    if skip_test():
+        return
+    dtype = "int32"
+    shape = (8, 4)
+    mod = create_module_add(shape, dtype)
+    mod = offload(mod)
+    lib = compile_hardware(lanes)
+    opts = compiler_opts(lib)
+    cycles = run_and_check(shape, shape, dtype, mod, opts)
+    print_test_info("add", lanes, cycles)
+
+
+def tbias(lanes):
+    """Bias-add test template.
+
+    Offloads an int32 ``nn.bias_add`` of a (1, 112, 112, 32) tensor with a
+    (32,) bias, runs it on hardware built with ``lanes`` vector lanes, checks
+    the result and prints the cycle count. Returns early when skip_test() is
+    true.
+
+    Parameters
+    ----------
+    lanes : Int
+        The number of vector lanes.
+    """
+    if skip_test():
+        return
+    dtype = "int32"
+    xshape = (1, 112, 112, 32)
+    yshape = (32,)
+    mod = create_module_bias_add(xshape, yshape, dtype)
+    mod = offload(mod)
+    lib = compile_hardware(lanes)
+    opts = compiler_opts(lib)
+    cycles = run_and_check(xshape, yshape, dtype, mod, opts)
+    print_test_info("nn.bias_add", lanes, cycles)
+
+
+def test_add():
+    """Run the add test template with 1 and with 4 vector lanes."""
+    tadd(1)
+    tadd(4)
+
+
+def test_bias_add():
+    """Run the bias_add test template with 1 and with 32 vector lanes."""
+    tbias(1)
+    tbias(32)