Mousius commented on a change in pull request #8833:
URL: https://github.com/apache/tvm/pull/8833#discussion_r694965395



##########
File path: python/tvm/relay/backend/contrib/cmsisnn/codegen.py
##########
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Codegen for CMSIS-NN"""
+import tvm
+from tvm import relay
+from tvm.relay.expr_functor import ExprVisitor
+
+
+def generate_tir(name, func):
+    """Generates TIR"""
+
+    class GenerateTIR(ExprVisitor):
+        """Generates TIR module containing TIR primfuncs corresponding to the 
Relay operators.
+        Note: Relay operator to primfunc mapping may not be 1:1.
+        """
+
+        def __init__(self, name):
+            super().__init__()
+            self.name = name
+            self.tir_mod = None
+            self.scale = 1.0 / 256
+
+        def call_contains_op(self, call, op_name):
+            if not isinstance(call.op, tvm.ir.op.Op):
+                return False
+            if call.op.name != op_name:
+                return False
+            return True
+
+        def is_quantized_softmax(self, call):
+            """Checks for the following relay sequence
+            a = qnn.dequantize(in, scale, zero_point)
+            b = nn.softmax(a)
+            c = qnn.quantize(b, scale, zero_point)
+            """
+            if not self.call_contains_op(call, "qnn.quantize"):
+                return False
+            softmax_call = call.args[0]
+            if not self.call_contains_op(softmax_call, "nn.softmax"):
+                return False
+            dequantize_call = softmax_call.args[0]
+            if not self.call_contains_op(dequantize_call, "qnn.dequantize"):
+                return False
+            if not call.attrs.out_dtype == "int8":
+                return False
+            self.scale = dequantize_call.args[1].data.numpy().item(0)
+            return True
+
+        def emit_softmax_tir(self, call):
+            """Generates TIR extern_call for softmax"""
+            shape = call.checked_type.shape  # NHWC
+            dtype = call.checked_type.dtype
+            ir_builder = tvm.tir.ir_builder.create()
+            in_buf = tvm.tir.decl_buffer(shape=shape, dtype=dtype)
+            out_buf = tvm.tir.decl_buffer(shape=shape, dtype=dtype)
+            num_rows = shape[0] * shape[1] * shape[2]
+            row_size = shape[3]
+            ir_builder.emit(
+                tvm.tir.call_extern(
+                    dtype,
+                    "arm_softmax_s8",
+                    in_buf.data,
+                    num_rows,
+                    row_size,
+                    self.scale,
+                    out_buf.data,
+                )
+            )
+            prim_func = tvm.tir.PrimFunc([in_buf, out_buf], ir_builder.get())
+            prim_func = prim_func.with_attr("global_symbol", self.name)
+            prim_func = prim_func.with_attr("tir.noalias", True)
+            self.tir_mod = tvm.IRModule({self.name: prim_func})
+
+        def visit_call(self, call):
+            """Iterates over the relay operators within relay external 
function"""
+            super().visit_call(call)
+            if self.is_quantized_softmax(call):
+                self.emit_softmax_tir(call)
+
+        def generate_tir(self, func):
+            self.visit(func)
+            return self.tir_mod
+
+    tir_mod = GenerateTIR(name).generate_tir(func)
+    return tir_mod
+
+
+def relay_to_tir(name, func):
+    """Lower a Relay function to TIR for the CMSIS-NN target.
+
+    The Relay function should only contain operations supported
+    by the CMSIS-NN target. This is enforced by the graph partitioner
+    for CMSIS-NN.
+
+    Parameters
+    ----------
+    name: str
+        Name of the external relay function
+    func : tvm.relay.Function
+        The Relay function to lower.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The lowered TIR module.
+
+    """
+    tir_mod = generate_tir(name, func)
+    return tir_mod
+
+
+@tvm._ffi.register_func("relay.ext.cmsisnn")

Review comment:
       `@tvm.register_func` 

##########
File path: python/tvm/relay/backend/contrib/cmsisnn/codegen.py
##########
@@ -0,0 +1,140 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Codegen for CMSIS-NN"""
+import tvm
+from tvm import relay
+from tvm.relay.expr_functor import ExprVisitor
+
+
+def generate_tir(name, func):
+    """Generates TIR"""
+
+    class GenerateTIR(ExprVisitor):
+        """Generates TIR module containing TIR primfuncs corresponding to the 
Relay operators.
+        Note: Relay operator to primfunc mapping may not be 1:1.
+        """
+
+        def __init__(self, name):
+            super().__init__()
+            self.name = name
+            self.tir_mod = None
+            self.scale = 1.0 / 256
+
+        def call_contains_op(self, call, op_name):
+            if not isinstance(call.op, tvm.ir.op.Op):
+                return False
+            if call.op.name != op_name:
+                return False
+            return True
+
+        def is_quantized_softmax(self, call):
+            """Checks for the following relay sequence
+            a = qnn.dequantize(in, scale, zero_point)
+            b = nn.softmax(a)
+            c = qnn.quantize(b, scale, zero_point)
+            """
+            if not self.call_contains_op(call, "qnn.quantize"):
+                return False
+            softmax_call = call.args[0]
+            if not self.call_contains_op(softmax_call, "nn.softmax"):
+                return False
+            dequantize_call = softmax_call.args[0]
+            if not self.call_contains_op(dequantize_call, "qnn.dequantize"):
+                return False
+            if not call.attrs.out_dtype == "int8":
+                return False
+            self.scale = dequantize_call.args[1].data.numpy().item(0)
+            return True
+
+        def emit_softmax_tir(self, call):
+            """Generates TIR extern_call for softmax"""
+            shape = call.checked_type.shape  # NHWC
+            dtype = call.checked_type.dtype
+            ir_builder = tvm.tir.ir_builder.create()
+            in_buf = tvm.tir.decl_buffer(shape=shape, dtype=dtype)
+            out_buf = tvm.tir.decl_buffer(shape=shape, dtype=dtype)
+            num_rows = shape[0] * shape[1] * shape[2]
+            row_size = shape[3]
+            ir_builder.emit(
+                tvm.tir.call_extern(
+                    dtype,
+                    "arm_softmax_s8",
+                    in_buf.data,
+                    num_rows,
+                    row_size,
+                    self.scale,
+                    out_buf.data,
+                )
+            )
+            prim_func = tvm.tir.PrimFunc([in_buf, out_buf], ir_builder.get())
+            prim_func = prim_func.with_attr("global_symbol", self.name)
+            prim_func = prim_func.with_attr("tir.noalias", True)
+            self.tir_mod = tvm.IRModule({self.name: prim_func})
+
+        def visit_call(self, call):
+            """Iterates over the relay operators within relay external 
function"""
+            super().visit_call(call)
+            if self.is_quantized_softmax(call):
+                self.emit_softmax_tir(call)
+
+        def generate_tir(self, func):
+            self.visit(func)
+            return self.tir_mod
+
+    tir_mod = GenerateTIR(name).generate_tir(func)
+    return tir_mod
+
+
+def relay_to_tir(name, func):
+    """Lower a Relay function to TIR for the CMSIS-NN target.
+
+    The Relay function should only contain operations supported
+    by the CMSIS-NN target. This is enforced by the graph partitioner
+    for CMSIS-NN.
+
+    Parameters
+    ----------
+    name: str
+        Name of the external relay function
+    func : tvm.relay.Function
+        The Relay function to lower.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The lowered TIR module.
+
+    """
+    tir_mod = generate_tir(name, func)
+    return tir_mod
+
+
+@tvm._ffi.register_func("relay.ext.cmsisnn")
+def cmsisnn_compiler(relay_func):
+    """It compiles Relay's external function into equivalent TIR
+    and subsequently converts that into 'c' code. During the 'c'
+    code generation, it embeds CMSIS-NN APIs for the corresponding
+    operators.
+    """
+    assert isinstance(relay_func, tvm.ir.function.BaseFunc)

Review comment:
       What's the purpose of this assertion?

##########
File path: src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc
##########
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <dmlc/filesystem.h>
+#include <dmlc/logging.h>
+#include <dmlc/memory_io.h>
+#include <tvm/ir/expr.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/builtin.h>
+
+#include <cmath>
+#include <fstream>
+#include <map>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#include "../../../../target/source/codegen_c.h"
+#include "../../../../runtime/file_utils.h"
+#include "../../../qnn/utils.h"
+
+namespace tvm {
+namespace runtime {
+
+using namespace tir;
+
+class CodeGenCMSISNN : public tvm::codegen::CodeGenC {
+ public:
+
+  void Init(bool output_ssa) {
+    decl_stream << "#include <stdio.h>\n";
+    decl_stream << "#include <stdlib.h>\n";
+    decl_stream << "#include <dlpack/dlpack.h>\n";
+    decl_stream << "#include <tvm/runtime/crt/module.h>\n";
+    decl_stream << "#include <arm_nnfunctions.h>\n";
+    CodeGenC::Init(output_ssa);
+  }
+
+  
+  /*!
+   * \brief Emit code that offloads a subgraph to the Cortex-M
+   *
+   * \return string of code that offloads a subgraph to the Cortex-M
+   */
+  void AddFunction(PrimFunc& prim_func) {
+    PrintExternCPrefix(stream);
+    CodeGenC::AddFunction(prim_func);
+    PrintExternCPostfix(stream);
+  }
+
+private:
+  void VisitExpr_(const CallNode* op, std::ostream& os) {  // NOLINT(*)
+      if (not op->op.same_as(builtin::call_extern())) {
+        return;
+      }
+      std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;
+      if(cmsis_func_name.find("softmax") != std::string::npos) {

Review comment:
       Is there a more concrete name we can provide here? Like 
`cmsisnn_softmax` ?

##########
File path: tests/python/contrib/test_cmsisnn/test_softmax.py
##########
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN integration tests: softmax"""
+
+import sys
+import os
+import pathlib
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib import cmsisnn
+import numpy as np
+import pytest
+
+# AOT test runner is required for running CMSIS-NN tests
+# current file path: tests/python/contrib/test_cmsisnn
+# AOT Test runner: tests/python/relay/aot
+aot_tests_path = os.path.join(str(pathlib.Path(__file__).parent.resolve()), 
"../../relay")
+sys.path.insert(0, aot_tests_path)
+import aot

Review comment:
       Can you not use `from tests.python.relay.aot.aot_test_utils import ...` ?

##########
File path: tests/python/contrib/test_cmsisnn/test_softmax.py
##########
@@ -0,0 +1,197 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""CMSIS-NN integration tests: softmax"""
+
+import sys
+import os
+import pathlib
+import tvm
+from tvm import relay
+from tvm.relay.op.contrib import cmsisnn
+import numpy as np
+import pytest
+
+# AOT test runner is required for running CMSIS-NN tests
+# current file path: tests/python/contrib/test_cmsisnn
+# AOT Test runner: tests/python/relay/aot
+aot_tests_path = os.path.join(str(pathlib.Path(__file__).parent.resolve()), 
"../../relay")
+sys.path.insert(0, aot_tests_path)
+import aot
+from aot.aot_test_utils import (
+    AOTTestModel,
+    AOT_CORSTONE300_RUNNER,
+    generate_ref_data,
+    convert_to_relay,
+    compile_and_run,
+)
+
+
+def teardown_module(module):
+    """pytest is going to clean up the additional python paths required by 
tests in this file."""
+    sys.path.pop(0)
+
+
+def get_range_for_dtype_str(dtype):
+    """
+    Produce the min, max for a given data type.
+
+    Parameters
+    ----------
+    dtype : str
+        a type string (e.g., int8)
+
+    Returns
+    -------
+    type_info.min : int
+        the minimum of the range
+    type_info.max : int
+        the maximum of the range
+    """
+
+    try:
+        type_info = np.iinfo(dtype)
+    except ValueError:
+        type_info = np.finfo(dtype)
+    return type_info.min, type_info.max
+
+
+def count_num_calls(mod):
+    """Count number of CallNode in the IRModule"""
+
+    class CallCounter(relay.ExprVisitor):
+        def __init__(self):
+            super().__init__()
+            self.count = 0
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                self.count += 1
+
+            super().visit_call(call)
+
+    counter = CallCounter()
+    for var in mod.get_global_vars():
+        counter.visit(mod[var.name_hint])
+    return counter.count
+
+
+def make_module(func):
+    """Create IRModule from Function"""
+    func = relay.Function(relay.analysis.free_vars(func), func)
+    mod = tvm.IRModule.from_expr(func)
+    return relay.transform.InferType()(mod)
+
+
+def make_model(shape, zero_point, scale, in_dtype, out_dtype):
+    """Create a Relay Function / network model"""
+    a = relay.var("in0", shape=shape, dtype=in_dtype)
+    dequantize = relay.qnn.op.dequantize(
+        a,
+        input_scale=relay.const(scale, "float32"),
+        input_zero_point=relay.const(zero_point, "int32"),
+    )
+    softmax = relay.nn.softmax(dequantize)
+    model = relay.qnn.op.quantize(
+        softmax,
+        output_scale=relay.const(scale, "float32"),
+        output_zero_point=relay.const(zero_point, "int32"),
+        out_dtype=out_dtype,
+    )
+    return model
+
+
+def test_softmax_int8():
+    interface_api = "c"
+    use_unpacked_api = True
+    test_runner = AOT_CORSTONE300_RUNNER
+
+    dtype = "int8"
+    shape = [1, 16, 16, 3]
+    zero_point = -128
+    scale = 1.0 / 256
+    model = make_model(shape, zero_point, scale, dtype, dtype)
+    orig_mod = make_module(model)
+
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)
+
+    # validate pattern matching
+    attrs = [
+        cmsisnn_mod[var.name_hint].attrs
+        for var in cmsisnn_mod.get_global_vars()
+        if cmsisnn_mod[var.name_hint].attrs
+    ]
+    assert any(attrs), "At least one function with external attributes was 
expected."
+
+    compilers = [
+        key == "Compiler" and value == "cmsisnn" for attr in attrs for key, 
value in attr.items()
+    ]
+    assert any(compilers), "Module does not contain function for cmsisnn 
target."
+
+    assert count_num_calls(orig_mod) == count_num_calls(
+        cmsisnn_mod
+    ), "Number of calls changed during partitioning"
+
+    # validate the output
+    in_min, in_max = get_range_for_dtype_str(dtype)
+    np.random.seed(0)
+    input_data = np.random.randint(in_min, high=in_max, size=shape, 
dtype=dtype)
+    inputs = {"in0": input_data}
+    params = {}
+    output_list = generate_ref_data(orig_mod["main"], inputs, params)
+    compile_and_run(
+        AOTTestModel(module=cmsisnn_mod, inputs=inputs, outputs=output_list, 
params=params),
+        test_runner,
+        interface_api,
+        use_unpacked_api,
+    )
+
+
+def test_softmax_invalid_scale():

Review comment:
       Can we test `scale` and `zero_point` independently and in combination 
with data types? This test could fail to match based on one or many of these.

##########
File path: src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc
##########
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <dmlc/filesystem.h>
+#include <dmlc/logging.h>
+#include <dmlc/memory_io.h>
+#include <tvm/ir/expr.h>
+#include <tvm/ir/attrs.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/memory.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/object.h>
+#include <tvm/runtime/packed_func.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/function.h>

Review comment:
       I'm not sure all of these headers are needed in this file? For example, 
`tvm/runtime/ndarray.h` is also included from `tvm/tir/function.h`. Can you 
limit these to specifically what's being used in this file?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to