This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d95f2c9c5 TVMC: Add new text/relay frontend (#10941)
4d95f2c9c5 is described below

commit 4d95f2c9c5614d55cf0f7cb306164e463c37c7d9
Author: Philipp van Kempen <[email protected]>
AuthorDate: Tue Jul 19 17:21:19 2022 +0200

    TVMC: Add new text/relay frontend (#10941)
    
    * TVMC: Add new text/relay frontend
    
    This feature enables passing a textual representation of a Relay module to the tvmc command line.
    
    Example: `tvmc compile relay.txt --model-format relay --target c --runtime=crt --executor=aot --executor-aot-unpacked-api=1 --pass-config tir.disable_vectorize=1 -f mlf`
    
    Currently it is not possible to supply parameters, as the frontend is mainly intended for testing individual Relay functions or operators. In the future (with minor changes to the tvmc frontend API), params could be passed via an additional file, e.g. `params.bin`.
    
    This commit also adds minimal unit testing of the added feature.
    
    Resolve PR comments
    
    TVMC: add warning if relay frontend is used
    
    * [TVMC] populate parameters with random values instead of ones
    
    * [TVMC] Relay frontend: do not populate input tensor buffers if --input-shapes is provided
    
    This prevents the model inputs from being treated as constants during constant folding,
    which would change the complexity of the model.
    
    If there were a way to distinguish between model inputs and parameters,
    this workaround would not be required.
    
    * [TVMC] Relay frontend: check provided file contents before calling tvm.parser.fromtext()
---
 python/tvm/driver/tvmc/frontends.py        | 70 ++++++++++++++++++++++++++++++
 tests/python/driver/tvmc/conftest.py       | 39 +++++++++++++++++
 tests/python/driver/tvmc/test_frontends.py | 13 ++++++
 3 files changed, 122 insertions(+)

diff --git a/python/tvm/driver/tvmc/frontends.py 
b/python/tvm/driver/tvmc/frontends.py
index cfe5a4ac7b..2da5483564 100644
--- a/python/tvm/driver/tvmc/frontends.py
+++ b/python/tvm/driver/tvmc/frontends.py
@@ -23,6 +23,7 @@ loading the tool.
 import logging
 import os
 import sys
+import re
 import importlib
 from abc import ABC
 from abc import abstractmethod
@@ -32,6 +33,7 @@ from pathlib import Path
 import numpy as np
 
 from tvm import relay
+from tvm import parser
 from tvm.driver.tvmc import TVMCException, TVMCImportError
 from tvm.driver.tvmc.model import TVMCModel
 
@@ -294,6 +296,73 @@ class PaddleFrontend(Frontend):
         return relay.frontend.from_paddle(prog, shape_dict=shape_dict, 
**kwargs)
 
 
+class RelayFrontend(Frontend):
+    """Relay frontend for TVMC"""
+
+    @staticmethod
+    def name():
+        return "relay"
+
+    @staticmethod
+    def suffixes():
+        return ["relay"]
+
+    def load(self, path, shape_dict=None, **kwargs):
+        with open(path, "r", encoding="utf-8") as relay_text:
+            text = relay_text.read()
+        if shape_dict is None:
+            # Without --input-shapes, every parameter of main() (inputs included)
+            # is filled with random data below and may be treated as a constant.
+            logger.warning(
+                "Specify --input-shapes to ensure that model inputs "
+                "will not be considered as constants."
+            )
+
+        def _validate_text(text):
+            """Check the provided file contents.
+            The relay.txt artifact contained in the MLF is missing the version header and
+            the metadata which is required to use meta[relay.Constant]."""
+            # The "#[version ...]" header must be on the first line (no re.DOTALL here).
+            if re.match(r".*#\[version", text) is None:
+                raise TVMCException(
+                    "The relay model does not include the required version information."
+                )
+            if re.search(r"meta\[.+\]", text, re.DOTALL):
+                if "#[metadata]" not in text:
+                    raise TVMCException(
+                        "The relay model does not include the required #[metadata] section. "
+                        "Use ir_mod.astext(show_meta_data=True) to export compatible code."
+                    )
+
+        _validate_text(text)
+
+        ir_mod = parser.fromtext(text)
+
+        # Names from --input-shapes are real model inputs; no params are generated for them.
+        input_names = set(shape_dict) if shape_dict else set()
+
+        def _gen_params(ir_mod, skip_names=None):
+            """Fill every main() param not listed in skip_names with random data."""
+            main_func = ir_mod["main"]
+            shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params}
+            type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params}
+            params = {}
+            for name, shape in shape_dict.items():
+                if skip_names and name in skip_names:
+                    continue
+
+                if "int" in type_dict[name]:
+                    data = np.random.randint(128, size=shape, dtype=type_dict[name])
+                else:
+                    data = np.random.uniform(-1, 1, size=shape).astype(type_dict[name])
+                params[name] = data
+            return params
+
+        params = _gen_params(ir_mod, skip_names=input_names)
+
+        return ir_mod, params
+
+
 ALL_FRONTENDS = [
     KerasFrontend,
     OnnxFrontend,
@@ -301,6 +370,7 @@ ALL_FRONTENDS = [
     TFLiteFrontend,
     PyTorchFrontend,
     PaddleFrontend,
+    RelayFrontend,
 ]
 
 
diff --git a/tests/python/driver/tvmc/conftest.py 
b/tests/python/driver/tvmc/conftest.py
index fcf079620e..48b465e507 100644
--- a/tests/python/driver/tvmc/conftest.py
+++ b/tests/python/driver/tvmc/conftest.py
@@ -17,6 +17,7 @@
 import os
 import pytest
 import tarfile
+import textwrap
 
 import numpy as np
 
@@ -229,3 +230,41 @@ def tflite_cnn_s_quantized(tmpdir_factory):
         "{}/{}".format(base_url, file_to_download), file_to_download, 
module=["tvmc"]
     )
     return model_file
+
+
+@pytest.fixture(scope="session")
+def relay_text_conv2d(tmpdir_factory):
+    file_path = os.path.join(tmpdir_factory.mktemp("model"), "relay.txt")
+    # Small two-conv2d Relay model in text form, written to disk for the relay frontend tests.
+    RELAY_MODEL = textwrap.dedent(
+        """\
+        #[version = "0.0.5"]
+        def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), int8]) {
+            %1 = nn.conv2d(
+                 %data,
+                 %weight,
+                 padding=[2, 2],
+                 channels=3,
+                 kernel_size=[5, 5],
+                 data_layout="NCHW",
+                 kernel_layout="OIHW",
+                 out_dtype="int32");
+            %2 = cast(nn.max_pool2d(%1, pool_size=[3, 3]), dtype="int8");
+            %3 = nn.conv2d(
+                 %2,
+                 %weight,
+                 padding=[2, 2],
+                 channels=3,
+                 kernel_size=[5, 5],
+                 data_layout="NCHW",
+                 kernel_layout="OIHW",
+                 out_dtype="int32");
+            %4 = nn.max_pool2d(%3, pool_size=[3, 3]);
+            %4
+        }
+    """
+    )
+
+    with open(file_path, "w", encoding="utf-8") as relay_text:
+        relay_text.write(RELAY_MODEL)
+    return file_path
diff --git a/tests/python/driver/tvmc/test_frontends.py 
b/tests/python/driver/tvmc/test_frontends.py
index b76066994c..1e6efb4a3b 100644
--- a/tests/python/driver/tvmc/test_frontends.py
+++ b/tests/python/driver/tvmc/test_frontends.py
@@ -106,6 +106,12 @@ def test_guess_frontend_paddle():
     assert type(sut) is tvmc.frontends.PaddleFrontend
 
 
+def test_guess_frontend_relay():
+    # A ".relay" suffix must select the Relay text frontend.
+    frontend = tvmc.frontends.guess_frontend("relay.relay")
+    assert type(frontend) is tvmc.frontends.RelayFrontend
+
+
 def test_guess_frontend_invalid():
     with pytest.raises(TVMCException):
         tvmc.frontends.guess_frontend("not/a/file.txt")
@@ -193,6 +199,13 @@ def test_load_model__paddle(paddle_resnet50):
     assert type(tvmc_model.params) is dict
 
 
+def test_load_model__relay(relay_text_conv2d):
+    # Loading the textual Relay fixture must yield a TVMCModel holding an IRModule and a params dict.
+    model = tvmc.load(relay_text_conv2d, model_format="relay")
+    assert type(model) is TVMCModel
+    assert type(model.mod) is IRModule and type(model.params) is dict
+
+
 def test_load_model___wrong_language__to_keras(tflite_mobilenet_v1_1_quant):
     # some CI environments wont offer TensorFlow/Keras, so skip in case it is 
not present
     pytest.importorskip("tensorflow")

Reply via email to