This is an automated email from the ASF dual-hosted git repository. manupa pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new befa7f3 [microNPU] Fix bug with re-reading in EncodeConstants (#9646) befa7f3 is described below commit befa7f33587068b6bc2ee7f57dc726d7bb1dc365 Author: Matthew Barrett <55580676+mba...@users.noreply.github.com> AuthorDate: Sat Dec 4 07:43:56 2021 +0000 [microNPU] Fix bug with re-reading in EncodeConstants (#9646) When a striping strategy that leads to weights being re-read was deployed, the logic in EncodeConstants failed. This adds a test for that case and fixes the pass so it handles it correctly. Change-Id: I6f54cdb7be69428e49c3b4208271cd3e6c192e5d --- .../tvm/relay/backend/contrib/ethosu/tir/passes.py | 10 +++- .../contrib/test_ethosu/test_encode_constants.py | 66 ++++++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 41a6832..0a6dcd1 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -331,11 +331,15 @@ def EncodeConstants(const_dict): def _new_buffer(old_buffer, new_value): """Create a new buffer and add the old buffer and its pointer to the rewriting maps.""" - new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype)) - pointer_to_buffer[new_buffer.data] = new_buffer + if old_buffer in rewrite_buffer: + new_buffer = rewrite_buffer[old_buffer] + else: + new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype)) + pointer_to_buffer[new_buffer.data] = new_buffer + buffer_to_const[new_buffer] = new_value + rewrite_buffer[old_buffer] = new_buffer rewrite_pointer[old_buffer.data] = new_buffer.data - buffer_to_const[new_buffer] = new_value def _visit_encode_pre(stmt): if isinstance(stmt, tvm.tir.Call): diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index
de8a7f9..7f5eeb1 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -110,6 +110,72 @@ def test_weight_stream_only(): # fmt: off @tvm.script.ir_module +class RereadWeights: + @T.prim_func + def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + placeholder_3 = T.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8") + buffer = T.match_buffer(placeholder_1, [304], dtype="uint8") + buffer_1 = T.match_buffer(placeholder_2, [80], dtype="uint8") + ethosu_write_1 = T.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8") + # body + placeholder_global = T.allocate([304], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer.data, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1.data, 0), 80, T.load("uint8", placeholder_d_global, 0), 
dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder_3.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write_1.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_re_read_weights(): + def _cascader(cached_func, const_dict, sch): + weights = cached_func.inputs[1] + bias = cached_func.inputs[2] + out = cached_func.outputs[0] + conv_compute = Convolution2DCompute.from_output(out) + co = conv_compute.split(sch, 2, 8) + cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d]) + cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d]) + sch[cache_weights].compute_at(sch[out], co) + sch[cache_bias].compute_at(sch[out], co) + + def _get_func(): + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") + conv = make_ethosu_conv2d( + ifm, + 32, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, consts = lower_to_tir(func, cascader=_cascader) + script = mod.script(show_meta=True) + test_mod = tvm.script.from_source(script) + reference_mod = RereadWeights + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + reference_const_sizes = {1: 304, 2: 80} + test_const_sizes = {} + for key, value in consts.items(): + test_const_sizes[key] = len(value) + + assert reference_const_sizes == test_const_sizes + + +# fmt: off +@tvm.script.ir_module class DirectReadOnly: @T.prim_func def main(placeholder: T.handle, placeholder_1: T.handle, placeholder_2: T.handle, placeholder_3: T.handle, placeholder_4: T.handle, ethosu_write: 
T.handle) -> None: