This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 794e1e3  [Testing] Add model loader for int8 BERT (#10622)
794e1e3 is described below

commit 794e1e36ec57f59cdc7431845f21cba01296b1e0
Author: Masahiro Masuda <[email protected]>
AuthorDate: Wed Mar 16 14:19:12 2022 +0900

    [Testing] Add model loader for int8 BERT (#10622)
    
    * add model loader for qat bert-base
    
    * add test
    
    * pylint
    
    * ignore mypy
    
    * Update python/tvm/meta_schedule/testing/tlcbench.py
    
    Co-authored-by: Junru Shao <[email protected]>
    
    * use a dedicated process for converting
    
    * return input info
    
    * encode batch size and seq_len information in cached file path
    
    Co-authored-by: Junru Shao <[email protected]>
---
 python/tvm/meta_schedule/testing/tlcbench.py       | 124 +++++++++++++++++++++
 .../unittest/test_meta_schedule_integration.py     |  17 +++
 2 files changed, 141 insertions(+)

diff --git a/python/tvm/meta_schedule/testing/tlcbench.py b/python/tvm/meta_schedule/testing/tlcbench.py
new file mode 100644
index 0000000..c6ab672
--- /dev/null
+++ b/python/tvm/meta_schedule/testing/tlcbench.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,import-outside-toplevel
+# type: ignore
+"""Model loader for TLCBench."""
+import multiprocessing
+import os
+import logging
+import tvm
+from tvm import relay
+from tvm.error import TVMError
+from tvm.contrib.download import download_testdata
+
+
+# Module-level logger; used below to report download and ONNX->QNN conversion progress.
+log = logging.getLogger(__name__)
+
+
+def _convert(args):
+    onnx_model, shape_dict, json_path, params_path = args
+    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, 
freeze_params=True)
+
+    seq = tvm.transform.Sequential(
+        [relay.transform.InferType(), 
relay.transform.FakeQuantizationToInteger(use_qat=True)]
+    )
+    mod = seq(mod)
+
+    with open(json_path, "w") as fo:
+        fo.write(tvm.ir.save_json(mod))
+
+    with open(params_path, "wb") as fo:
+        fo.write(relay.save_param_dict(params))
+
+
+def convert_to_qnn(onnx_path, json_path, params_path, input_info):
+    """Run the ONNX frontend and the FQ2I pass. The output is serialized to 
disk."""
+    import onnx
+
+    onnx_model = onnx.load(onnx_path)
+
+    shape_dict = dict(input_info)
+
+    log.info("Converting te ONNX model to Relay and running the FQ2I pass, it 
may take a while...")
+
+    with multiprocessing.Pool(processes=1) as pool:
+        pool.map(_convert, [(onnx_model, shape_dict, json_path, params_path)])
+
+
+def deserialize_relay(json_path, params_path):
+    with open(json_path, "r") as fi:
+        mod = tvm.ir.load_json(fi.read())
+
+    with open(params_path, "rb") as fi:
+        params = relay.load_param_dict(fi.read())
+
+    return mod, params
+
+
+def load_quantized_bert_base(batch_size=1, seq_len=384):
+    """
+    Load the quantized bert-base model from TLCBench, possibly downloading it 
from github
+    and caching the converted int8 QNN module to disk.
+
+    In addition to returing the relay module and its parameters, it also 
returns input name
+    and shape information, which can be used at the deployment time as follows:
+
+    ```
+    mod, params, input_info = load_quantized_bert_base()
+
+    ...
+
+    runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+    for name, shape in input_info:
+        arr = np.random.uniform(1, 10, size=shape).astype("int64")
+        runtime.set_input(name, arr)
+
+    runtime.run()
+    ```
+
+    """
+    url = 
"https://github.com/tlc-pack/TLCBench/raw/main/models/bert-base-qat.onnx";
+    log.info("Downloading quantized bert-base model.")
+    onnx_path = download_testdata(url, "bert-base-qat.onnx", module="tlcbench")
+    data_dir = os.path.dirname(onnx_path)
+
+    json_path = os.path.join(data_dir, "bert_base_int8_b%d_s%d.json" % 
(batch_size, seq_len))
+    params_path = os.path.join(data_dir, "bert_base_int8_b%d_s%d.params" % 
(batch_size, seq_len))
+
+    # Input names and order encoded in the ONNX model
+    input_info = [
+        ("input_ids", (batch_size, seq_len)),
+        ("segment_ids", (batch_size, seq_len)),
+        ("input_mask", (batch_size, seq_len)),
+    ]
+
+    if not os.path.exists(json_path) or not os.path.exists(params_path):
+        convert_to_qnn(onnx_path, json_path, params_path, input_info)
+
+    def deserialize():
+        try:
+            return deserialize_relay(json_path, params_path)
+        except TVMError:
+            # A serialized Relay json file may become invalid after TVM bump
+            # Update the serialized model and try loading again
+            convert_to_qnn(onnx_path, json_path, params_path, input_info)
+            return deserialize_relay(json_path, params_path)
+
+    mod, params = deserialize()
+
+    return mod, params, input_info
diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py
index 4620e83..5e375a7 100644
--- a/tests/python/unittest/test_meta_schedule_integration.py
+++ b/tests/python/unittest/test_meta_schedule_integration.py
@@ -32,6 +32,8 @@ from tvm.meta_schedule.utils import derived_object
 from tvm.script import tir as T
 from tvm.target import Target
 from tvm.tir import Schedule
+from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
+from tvm.meta_schedule.tune import extract_task_from_relay
 
 # pylint: disable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument,missing-docstring,invalid-name
 
@@ -154,5 +156,20 @@ def test_meta_schedule_integration_apply_history_best():
     assert tvm.ir.structural_equal(mod, workload.mod)
 
 
[email protected]("Too slow on CI")
+def extract_task_qbert():
+    mod, params, _ = load_quantized_bert_base(batch_size=1, seq_len=128)
+    target = "llvm"
+    extracted_tasks = extract_task_from_relay(mod, target, params)
+    tune_tasks = list(
+        filter(
+            lambda task: "dense" in task.task_name or "batch_matmul" in 
task.task_name,
+            extracted_tasks,
+        )
+    )
+    # three int8 dense, two int8 bmm, and one fp32 dense
+    assert len(tune_tasks) == 6
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))

Reply via email to