FrozenGene commented on a change in pull request #5915:
URL: https://github.com/apache/incubator-tvm/pull/5915#discussion_r456186767



##########
File path: tests/python/contrib/test_arm_compute_lib/infrastructure.py
##########
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from itertools import zip_longest, combinations
+import json
+
+import tvm
+from tvm import relay
+from tvm import rpc
+from tvm.contrib import graph_runtime
+from tvm.relay.op.contrib import arm_compute_lib
+from tvm.contrib import util
+
+
+class Device:
+    """Adjust the following settings to connect to and use a remote device for 
tests."""
+    use_remote = False
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon"
+    # Enable cross compilation when connecting a remote device from a non-arm 
platform.
+    cross_compile = None
+    # cross_compile = "aarch64-linux-gnu-g++"
+
+    def __init__(self):
+        """Keep remote device for lifetime of object."""
+        self.device = self._get_remote()
+
+    @classmethod
+    def _get_remote(cls):
+        """Get a remote (or local) device to use for testing."""
+        if cls.use_remote:
+            # Here you may adjust settings to run the ACL unit tests via a 
remote
+            # device using the RPC mechanism. Use this in the case you want to 
compile
+            # an ACL module on a different machine to what you run the module 
on i.e.
+            # x86 -> AArch64.
+            #
+            # Use the following to connect directly to a remote device:
+            # device = rpc.connect(
+            #     hostname="0.0.0.0",
+            #     port=9090)
+            #
+            # Or connect via a tracker:
+            # device = tvm.autotvm.measure.request_remote(
+            #     host="0.0.0.0",
+            #     port=9090,
+            #     device_key="device_key",
+            #     timeout=1000)
+            #
+            # return device
+            raise NotImplementedError(
+                "Please adjust these settings to connect to your remote 
device.")
+        else:
+            device = rpc.LocalSession()
+            return device
+
+
+def get_cpu_op_count(mod):
+    """Traverse graph counting ops offloaded to TVM."""
+    class Counter(tvm.relay.ExprVisitor):
+        def __init__(self):
+            super().__init__()
+            self.count = 0
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                self.count += 1
+
+            super().visit_call(call)
+
+    c = Counter()
+    c.visit(mod["main"])
+    return c.count
+
+
+def skip_runtime_test():
+    """Skip test if it requires the runtime and it's not present."""
+    # ACL codegen not present.
+    if not tvm.get_global_func("relay.ext.arm_compute_lib", True):
+        print("Skip because Arm Compute Library codegen is not available.")
+        return True
+
+    # Remote device is in use or ACL runtime not present
+    if not Device.use_remote and not 
arm_compute_lib.is_arm_compute_runtime_enabled():
+        print("Skip because runtime isn't present or a remote device isn't 
being used.")
+        return True
+
+
+def skip_codegen_test():
+    """Skip test if it requires the ACL codegen and it's not present."""
+    if not tvm.get_global_func("relay.ext.arm_compute_lib", True):
+        print("Skip because Arm Compute Library codegen is not available.")
+        return True
+
+
+def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, 
acl_partitions=1):
+    """Build module with option to build for ACL."""
+    if isinstance(mod, tvm.relay.expr.Call):
+        mod = tvm.IRModule.from_expr(mod)
+    with tvm.transform.PassContext(opt_level=3, 
disabled_pass=["AlterOpLayout"]):
+        if enable_acl:
+            mod = arm_compute_lib.partition_for_arm_compute_lib(mod, params)
+            tvm_op_count = get_cpu_op_count(mod)
+            assert tvm_op_count == tvm_ops, \
+                "Got {} TVM operators, expected {}".format(tvm_op_count, 
tvm_ops)
+            partition_count = 0
+            for global_var in mod.get_global_vars():
+                if "arm_compute_lib" in global_var.name_hint:
+                    partition_count += 1
+
+            assert acl_partitions == partition_count, \
+                "Got {} Arm Compute Library partitions, expected {}".format(
+                    partition_count, acl_partitions)
+        relay.backend.compile_engine.get().clear()
+        return relay.build(mod, target=target, params=params)
+
+
+def build_and_run(mod, inputs, outputs, params, device, enable_acl=True, 
no_runs=1,
+                  tvm_ops=0, acl_partitions=1):
+    """Build and run the relay module."""
+    lib = build_module(mod, device.target, params, enable_acl, tvm_ops, 
acl_partitions)
+    lib = update_lib(lib, device.device, device.cross_compile)
+    gen_module = 
graph_runtime.GraphModule(lib['default'](device.device.cpu(0)))
+    gen_module.set_input(**inputs)
+    out = []
+    for _ in range(no_runs):
+        gen_module.run()
+        out.append([gen_module.get_output(i) for i in range(outputs)])
+    return out
+
+
+def update_lib(lib, device, cross_compile):
+    """Export the library to the remote/local device."""
+    lib_name = "mod.so"
+    temp = util.tempdir()
+    lib_path = temp.relpath(lib_name)
+    if cross_compile:
+        lib.export_library(lib_path, cc=cross_compile)
+    else:
+        lib.export_library(lib_path)
+    device.upload(lib_path)
+    lib = device.load_module(lib_name)
+    return lib
+
+
+def verify(answers, atol, rtol):
+    """Compare the array of answers. Each entry is a list of outputs."""
+    if len(answers) < 2:
+        raise RuntimeError(
+            f"No results to compare: expected at least two, found 
{len(answers)}")
+    for answer in zip_longest(*answers):
+        for outs in combinations(answer, 2):
+            tvm.testing.assert_allclose(
+               outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol)
+
+
+def extract_acl_modules(module):
+    """Get the ACL module(s) from llvm module."""
+    return list(filter(lambda mod: mod.type_key == "arm_compute_lib",
+                       module.lib.imported_modules))

Review comment:
       Let us add one `get_lib` function instead of using the attribute 
directly inside `GraphRuntimeFactoryModule` if it is a must.

##########
File path: tests/python/contrib/test_arm_compute_lib/infrastructure.py
##########
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from itertools import zip_longest, combinations
+import json
+
+import tvm
+from tvm import relay
+from tvm import rpc
+from tvm.contrib import graph_runtime
+from tvm.relay.op.contrib import arm_compute_lib
+from tvm.contrib import util
+
+
+class Device:
+    """Adjust the following settings to connect to and use a remote device for 
tests."""
+    use_remote = False
+    target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon"
+    # Enable cross compilation when connecting a remote device from a non-arm 
platform.
+    cross_compile = None
+    # cross_compile = "aarch64-linux-gnu-g++"
+
+    def __init__(self):
+        """Keep remote device for lifetime of object."""
+        self.device = self._get_remote()
+
+    @classmethod
+    def _get_remote(cls):
+        """Get a remote (or local) device to use for testing."""
+        if cls.use_remote:
+            # Here you may adjust settings to run the ACL unit tests via a 
remote
+            # device using the RPC mechanism. Use this in the case you want to 
compile
+            # an ACL module on a different machine to what you run the module 
on i.e.
+            # x86 -> AArch64.
+            #
+            # Use the following to connect directly to a remote device:
+            # device = rpc.connect(
+            #     hostname="0.0.0.0",
+            #     port=9090)
+            #
+            # Or connect via a tracker:
+            # device = tvm.autotvm.measure.request_remote(
+            #     host="0.0.0.0",
+            #     port=9090,
+            #     device_key="device_key",
+            #     timeout=1000)
+            #
+            # return device
+            raise NotImplementedError(
+                "Please adjust these settings to connect to your remote 
device.")
+        else:
+            device = rpc.LocalSession()
+            return device
+
+
+def get_cpu_op_count(mod):
+    """Traverse graph counting ops offloaded to TVM."""
+    class Counter(tvm.relay.ExprVisitor):
+        def __init__(self):
+            super().__init__()
+            self.count = 0
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                self.count += 1
+
+            super().visit_call(call)
+
+    c = Counter()
+    c.visit(mod["main"])
+    return c.count
+
+
+def skip_runtime_test():
+    """Skip test if it requires the runtime and it's not present."""
+    # ACL codegen not present.
+    if not tvm.get_global_func("relay.ext.arm_compute_lib", True):
+        print("Skip because Arm Compute Library codegen is not available.")
+        return True
+
+    # Remote device is in use or ACL runtime not present
+    if not Device.use_remote and not 
arm_compute_lib.is_arm_compute_runtime_enabled():
+        print("Skip because runtime isn't present or a remote device isn't 
being used.")
+        return True
+
+
+def skip_codegen_test():
+    """Skip test if it requires the ACL codegen and it's not present."""
+    if not tvm.get_global_func("relay.ext.arm_compute_lib", True):
+        print("Skip because Arm Compute Library codegen is not available.")
+        return True
+
+
+def build_module(mod, target, params=None, enable_acl=True, tvm_ops=0, 
acl_partitions=1):
+    """Build module with option to build for ACL."""
+    if isinstance(mod, tvm.relay.expr.Call):
+        mod = tvm.IRModule.from_expr(mod)
+    with tvm.transform.PassContext(opt_level=3, 
disabled_pass=["AlterOpLayout"]):
+        if enable_acl:
+            mod = arm_compute_lib.partition_for_arm_compute_lib(mod, params)
+            tvm_op_count = get_cpu_op_count(mod)
+            assert tvm_op_count == tvm_ops, \
+                "Got {} TVM operators, expected {}".format(tvm_op_count, 
tvm_ops)
+            partition_count = 0
+            for global_var in mod.get_global_vars():
+                if "arm_compute_lib" in global_var.name_hint:
+                    partition_count += 1
+
+            assert acl_partitions == partition_count, \
+                "Got {} Arm Compute Library partitions, expected {}".format(
+                    partition_count, acl_partitions)
+        relay.backend.compile_engine.get().clear()
+        return relay.build(mod, target=target, params=params)
+
+
+def build_and_run(mod, inputs, outputs, params, device, enable_acl=True, 
no_runs=1,
+                  tvm_ops=0, acl_partitions=1):
+    """Build and run the relay module."""
+    lib = build_module(mod, device.target, params, enable_acl, tvm_ops, 
acl_partitions)
+    lib = update_lib(lib, device.device, device.cross_compile)
+    gen_module = 
graph_runtime.GraphModule(lib['default'](device.device.cpu(0)))
+    gen_module.set_input(**inputs)
+    out = []
+    for _ in range(no_runs):
+        gen_module.run()
+        out.append([gen_module.get_output(i) for i in range(outputs)])
+    return out
+
+
+def update_lib(lib, device, cross_compile):
+    """Export the library to the remote/local device."""
+    lib_name = "mod.so"
+    temp = util.tempdir()
+    lib_path = temp.relpath(lib_name)
+    if cross_compile:
+        lib.export_library(lib_path, cc=cross_compile)
+    else:
+        lib.export_library(lib_path)
+    device.upload(lib_path)
+    lib = device.load_module(lib_name)
+    return lib
+
+
+def verify(answers, atol, rtol):
+    """Compare the array of answers. Each entry is a list of outputs."""
+    if len(answers) < 2:
+        raise RuntimeError(
+            f"No results to compare: expected at least two, found 
{len(answers)}")
+    for answer in zip_longest(*answers):
+        for outs in combinations(answer, 2):
+            tvm.testing.assert_allclose(
+               outs[0].asnumpy(), outs[1].asnumpy(), rtol=rtol, atol=atol)
+
+
+def extract_acl_modules(module):
+    """Get the ACL module(s) from llvm module."""
+    return list(filter(lambda mod: mod.type_key == "arm_compute_lib",
+                       module.lib.imported_modules))

Review comment:
       Let us add one `get_lib` function instead of using the attribute 
directly inside `GraphRuntimeFactoryModule` if it is a must.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to