This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 794e1e3 [Testing] Add model loader for int8 BERT (#10622)
794e1e3 is described below
commit 794e1e36ec57f59cdc7431845f21cba01296b1e0
Author: Masahiro Masuda <[email protected]>
AuthorDate: Wed Mar 16 14:19:12 2022 +0900
[Testing] Add model loader for int8 BERT (#10622)
* add model loader for qat bert-base
* add test
* pylint
* ignore mypy
* Update python/tvm/meta_schedule/testing/tlcbench.py
Co-authored-by: Junru Shao <[email protected]>
* use a dedicated process for converting
* return input info
* encode batch size and seq_len information in cached file path
Co-authored-by: Junru Shao <[email protected]>
---
python/tvm/meta_schedule/testing/tlcbench.py | 124 +++++++++++++++++++++
.../unittest/test_meta_schedule_integration.py | 17 +++
2 files changed, 141 insertions(+)
diff --git a/python/tvm/meta_schedule/testing/tlcbench.py
b/python/tvm/meta_schedule/testing/tlcbench.py
new file mode 100644
index 0000000..c6ab672
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/tlcbench.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,import-outside-toplevel
+# type: ignore
+"""Model loader for TLCBench."""
+import multiprocessing
+import os
+import logging
+import tvm
+from tvm import relay
+from tvm.error import TVMError
+from tvm.contrib.download import download_testdata
+
+
+log = logging.getLogger(__name__)
+
+
def _convert(args):
    """Worker entry point: import an ONNX model into Relay, quantize it, and
    serialize the result to disk.

    ``args`` is a tuple of (onnx_model, shape_dict, json_path, params_path);
    packed into one tuple so this function can be used with ``Pool.map``.
    """
    onnx_model, shape_dict, json_path, params_path = args

    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)

    # FQ2I needs type information, so InferType runs first in the pipeline.
    passes = [
        relay.transform.InferType(),
        relay.transform.FakeQuantizationToInteger(use_qat=True),
    ]
    mod = tvm.transform.Sequential(passes)(mod)

    with open(json_path, "w") as json_file:
        json_file.write(tvm.ir.save_json(mod))

    with open(params_path, "wb") as params_file:
        params_file.write(relay.save_param_dict(params))
+
+
def convert_to_qnn(onnx_path, json_path, params_path, input_info):
    """Run the ONNX frontend and the FQ2I pass. The output is serialized to disk.

    Parameters
    ----------
    onnx_path : str
        Path to the ONNX model file to load.
    json_path : str
        Destination path for the serialized Relay module (JSON).
    params_path : str
        Destination path for the serialized parameter dict.
    input_info : list of (str, tuple)
        (input name, shape) pairs; converted into the shape dict that the
        ONNX frontend expects.
    """
    import onnx

    onnx_model = onnx.load(onnx_path)

    shape_dict = dict(input_info)

    log.info("Converting the ONNX model to Relay and running the FQ2I pass, it may take a while...")

    # Run the conversion in a dedicated single-worker process so that the
    # (large) memory used by the frontend and passes is returned to the OS
    # when the pool exits.
    with multiprocessing.Pool(processes=1) as pool:
        pool.map(_convert, [(onnx_model, shape_dict, json_path, params_path)])
+
+
def deserialize_relay(json_path, params_path):
    """Load a Relay module and its parameter dict previously saved to disk."""
    with open(json_path, "r") as json_file:
        module = tvm.ir.load_json(json_file.read())

    with open(params_path, "rb") as params_file:
        parameters = relay.load_param_dict(params_file.read())

    return module, parameters
+
+
def load_quantized_bert_base(batch_size=1, seq_len=384):
    """
    Load the quantized bert-base model from TLCBench, possibly downloading it from github
    and caching the converted int8 QNN module to disk.

    In addition to returning the relay module and its parameters, it also returns input name
    and shape information, which can be used at the deployment time as follows:

    ```
    mod, params, input_info = load_quantized_bert_base()

    ...

    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

    for name, shape in input_info:
        arr = np.random.uniform(1, 10, size=shape).astype("int64")
        runtime.set_input(name, arr)

    runtime.run()
    ```

    Parameters
    ----------
    batch_size : int
        Batch size baked into the converted module's input shapes.
    seq_len : int
        Sequence length baked into the converted module's input shapes.

    Returns
    -------
    mod, params, input_info
        The Relay module, its parameters, and a list of (name, shape) pairs
        describing the model inputs.
    """
    url = "https://github.com/tlc-pack/TLCBench/raw/main/models/bert-base-qat.onnx"
    log.info("Downloading quantized bert-base model.")
    onnx_path = download_testdata(url, "bert-base-qat.onnx", module="tlcbench")
    data_dir = os.path.dirname(onnx_path)

    # Encode batch size and seq_len in the cache file names so that modules
    # converted for different configurations do not clobber each other.
    json_path = os.path.join(data_dir, "bert_base_int8_b%d_s%d.json" % (batch_size, seq_len))
    params_path = os.path.join(data_dir, "bert_base_int8_b%d_s%d.params" % (batch_size, seq_len))

    # Input names and order encoded in the ONNX model
    input_info = [
        ("input_ids", (batch_size, seq_len)),
        ("segment_ids", (batch_size, seq_len)),
        ("input_mask", (batch_size, seq_len)),
    ]

    if not os.path.exists(json_path) or not os.path.exists(params_path):
        convert_to_qnn(onnx_path, json_path, params_path, input_info)

    def deserialize():
        try:
            return deserialize_relay(json_path, params_path)
        except TVMError:
            # A serialized Relay json file may become invalid after TVM bump
            # Update the serialized model and try loading again
            convert_to_qnn(onnx_path, json_path, params_path, input_info)
            return deserialize_relay(json_path, params_path)

    mod, params = deserialize()

    return mod, params, input_info
diff --git a/tests/python/unittest/test_meta_schedule_integration.py
b/tests/python/unittest/test_meta_schedule_integration.py
index 4620e83..5e375a7 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -32,6 +32,8 @@ from tvm.meta_schedule.utils import derived_object
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir import Schedule
+from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
+from tvm.meta_schedule.tune import extract_task_from_relay
# pylint:
disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name
@@ -154,5 +156,20 @@ def test_meta_schedule_integration_apply_history_best():
assert tvm.ir.structural_equal(mod, workload.mod)
@pytest.mark.skip("Too slow on CI")
def extract_task_qbert():
    """Extract tuning tasks from the int8 BERT base model and verify that the
    expected dense / batch_matmul tasks are found."""
    mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128)
    target = "llvm"
    extracted_tasks = extract_task_from_relay(mod, target, params)
    tune_tasks = list(
        filter(
            lambda task: "dense" in task.task_name or "batch_matmul" in task.task_name,
            extracted_tasks,
        )
    )
    # three int8 dense, two int8 bmm, and one fp32 dense
    assert len(tune_tasks) == 6
+
+
if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))