ashutosh-arm commented on a change in pull request #8833: URL: https://github.com/apache/tvm/pull/8833#discussion_r695629688
########## File path: tests/python/contrib/test_cmsisnn/test_softmax.py ########## @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: softmax""" + +import sys +import os +import pathlib +import tvm +from tvm import relay +from tvm.relay.op.contrib import cmsisnn +import numpy as np +import pytest + +# AOT test runner is required for running CMSIS-NN tests +# current file path: tests/python/contrib/test_cmsisnn +# AOT Test runner: tests/python/relay/aot +aot_tests_path = os.path.join(str(pathlib.Path(__file__).parent.resolve()), "../../relay") +sys.path.insert(0, aot_tests_path) +import aot +from aot.aot_test_utils import ( + AOTTestModel, + AOT_CORSTONE300_RUNNER, + generate_ref_data, + convert_to_relay, + compile_and_run, +) + + +def teardown_module(module): + """pytest is going to clean up the additional python paths set required by tests in this file.""" + sys.path.pop(0) + + +def get_range_for_dtype_str(dtype): + """ + Produce the min, max for a given data type. 
+ + Parameters + ---------- + dtype : str + a type string (e.g., int8) + + Returns + ------- + type_info.min : int + the minimum of the range + type_info.max : int + the maximum of the range + """ + + try: + type_info = np.iinfo(dtype) + except ValueError: + type_info = np.finfo(dtype) + return type_info.min, type_info.max + + +def count_num_calls(mod): + """Count number of CallNode in the IRModule""" + + class CallCounter(relay.ExprVisitor): + def __init__(self): + super().__init__() + self.count = 0 + + def visit_call(self, call): + if isinstance(call.op, tvm.ir.Op): + self.count += 1 + + super().visit_call(call) + + counter = CallCounter() + for var in mod.get_global_vars(): + counter.visit(mod[var.name_hint]) + return counter.count + + +def make_module(func): + """Create IRModule from Function""" + func = relay.Function(relay.analysis.free_vars(func), func) + mod = tvm.IRModule.from_expr(func) + return relay.transform.InferType()(mod) + + +def make_model(shape, zero_point, scale, in_dtype, out_dtype): + """Create a Relay Function / network model""" + a = relay.var("in0", shape=shape, dtype=in_dtype) + dequantize = relay.qnn.op.dequantize( + a, + input_scale=relay.const(scale, "float32"), + input_zero_point=relay.const(zero_point, "int32"), + ) + softmax = relay.nn.softmax(dequantize) + model = relay.qnn.op.quantize( + softmax, + output_scale=relay.const(scale, "float32"), + output_zero_point=relay.const(zero_point, "int32"), + out_dtype=out_dtype, + ) + return model + + +def test_softmax_int8(): + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_CORSTONE300_RUNNER + + dtype = "int8" + shape = [1, 16, 16, 3] + zero_point = -128 + scale = 1.0 / 256 + model = make_model(shape, zero_point, scale, dtype, dtype) + orig_mod = make_module(model) + + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + # validate pattern matching + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if 
cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsisnn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + # validate the output + in_min, in_max = get_range_for_dtype_str(dtype) + np.random.seed(0) + input_data = np.random.randint(in_min, high=in_max, size=shape, dtype=dtype) + inputs = {"in0": input_data} + params = {} + output_list = generate_ref_data(orig_mod["main"], inputs, params) + compile_and_run( + AOTTestModel(module=cmsisnn_mod, inputs=inputs, outputs=output_list, params=params), + test_runner, + interface_api, + use_unpacked_api, + ) + + +def test_softmax_invalid_scale(): Review comment: Done! -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
