This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bf071dea54 [Relay][Frontend] Preserve Pytorch Span Names (#16171)
bf071dea54 is described below

commit bf071dea54ceb2705c3819909723d8eb8c048dcb
Author: Navya Mehta <134946169+navya-encha...@users.noreply.github.com>
AuthorDate: Thu Dec 7 19:00:35 2023 -0500

    [Relay][Frontend] Preserve Pytorch Span Names (#16171)
    
    * Preserve Pytorch Span Names
    
    * Update pytorch.py
    
    * Add tests
    
    * WIP
    
    * Changes and tests
    
    * Michael Klaiber feedback
    
    * Linting fix
    
    * Linting Feedback Pt.2
    
    * Test changes
    
    * Modify to Pytorch 2.0
---
 python/tvm/relay/frontend/pytorch.py              | 119 ++++++++++++++++++----
 tests/python/frontend/pytorch/test_span_naming.py | 106 +++++++++++++++++++
 2 files changed, 206 insertions(+), 19 deletions(-)
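
For reference, a minimal usage sketch of the new flag (illustrative only, not part of the
commit; it reuses the SimpleTwoConvModule defined in the test file added below):

    import torch
    import tvm

    model = SimpleTwoConvModule()
    example_input = torch.zeros((1, 3, 64, 64), dtype=torch.float32)
    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)

    # With preserve_pytorch_scopes=True, span/source names follow the PyTorch module
    # scope (e.g. "image_block1.conv") instead of the default kind-based scheme
    # (e.g. "aten::adaptive_max_pool2d_0").
    mod, params = tvm.relay.frontend.from_pytorch(
        traced, [("model_input", (1, 3, 64, 64))], preserve_pytorch_scopes=True
    )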

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1a481e9300..9583575bfc 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -21,6 +21,8 @@
 """PT: PyTorch frontend."""
 import functools
 import itertools
+from abc import ABC
+from typing import Dict
 import math
 import re
 import sys
@@ -137,7 +139,9 @@ def _is_int_seq(seq):
 class PyTorchOpConverter:
     """A helper class for holding PyTorch op converters."""
 
-    def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
+    def __init__(
+        self, prelude, default_dtype, use_parser_friendly_name=False, preserve_pytorch_scopes=False
+    ):
         self.prelude = prelude
         self.default_dtype = default_dtype
         self.create_convert_map()
@@ -146,6 +150,7 @@ class PyTorchOpConverter:
         self.op_type_dict = {}  # map from op type to its presenting order
         self.current_op = []  # stack for recording current processing op
         self.use_parser_friendly_name = use_parser_friendly_name
+        self.preserve_pytorch_scopes = preserve_pytorch_scopes
 
     # this incrementally infers the type, see the comments on the type visitor
     # above.
@@ -4314,7 +4319,11 @@ class PyTorchOpConverter:
     def convert_block(self, block, outputs):
         """Translate Torch "Block", used for prim::If and prim::Loop"""
         ops = _get_operator_nodes(
-            block.nodes(), self.source_map, self.op_type_dict, self.use_parser_friendly_name
+            block.nodes(),
+            self.source_map,
+            self.op_type_dict,
+            self.use_parser_friendly_name,
+            self.preserve_pytorch_scopes,
         )
         ret_names = _get_input_names(block.returnNode())
         return self.convert_operators(ops, outputs, ret_names)
@@ -4881,25 +4890,76 @@ def _get_constant(node):
         return None
 
 
-def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name):
-    """Rewrite debug name of node outputs with its operator type"""
+class NodeNamer(ABC):
+    """Name each node and output edge in the relay graph"""
 
-    def _get_source_name(op_type):
+    def __init__(self, op_counter_dict: Dict[str, int]):
+        self.op_counter_dict = op_counter_dict
+
+    def increment_counter(self, identifier: str) -> int:
         op_idx = 0
-        if op_type in op_type_dict:
-            op_idx = op_type_dict[op_type] + 1
-        op_type_dict[op_type] = op_idx
-        return "_".join([op_type, str(op_idx)])
+        if identifier in self.op_counter_dict:
+            op_idx = self.op_counter_dict[identifier] + 1
+        self.op_counter_dict[identifier] = op_idx
+        return op_idx
 
-    # get source name of operator and rename all of its outputs
+    def get_node_source_name(self, node) -> str:
+        raise NotImplementedError()
+
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
+        raise NotImplementedError()
+
+
+class DefaultNodeKindNamer(NodeNamer):
+    """
+    Namer that uses a default naming based on the "type"/kind of node
     # e.g. node.kind(): aten::adaptive_max_pool2d
     # node_src_name -> aten::adaptive_max_pool2d_x
     # output_1 -> aten::adaptive_max_pool2d_x_0
     # output_2 -> aten::adaptive_max_pool2d_x_1
+    """
+
+    def get_node_source_name(self, node) -> str:
+        op_idx = self.increment_counter(node.kind())
+        return "_".join([node.kind(), str(op_idx)])
+
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
+        return "_".join([node_src_name, str(index)])
+
+
+class PytorchScopePreservingNamer(NodeNamer):
+    """
+    Namer that uses the Pytorch scope to name nodes.
+    eg. node could be called "bert.encoder.layer.11.output.dense"
+    """
+
+    def get_node_source_name(self, node) -> str:
+        # This works per the scope naming in Pytorch 2.0 and beyond.
+        scope_name_parts = node.scopeName().split("/")
+        imp_parts = [part.split("::")[-1] for part in scope_name_parts]
+        node_src_name = ".".join([part for part in imp_parts if part])
+        return node_src_name
+
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
+        op_idx = self.increment_counter(node_src_name)
+        return "_".join([node_src_name, str(op_idx), str(index)])
+
+
+def _rename_outputs(
+    node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+):
+    """Rewrite debug name of node outputs with its operator type"""
+    namer = (
+        PytorchScopePreservingNamer(op_type_dict)
+        if preserve_pytorch_scopes
+        else DefaultNodeKindNamer(op_type_dict)
+    )
+    # get source name of operator and rename all of its outputs
     if node.kind() != "prim::GetAttr":
-        node_src_name = _get_source_name(node.kind())
+        node_src_name = namer.get_node_source_name(node)
         for index, output in enumerate(node.outputs()):
-            output.setDebugName("_".join([node_src_name, str(index)]))
+            name = namer.get_node_output_name(node_src_name, index)
+            output.setDebugName(name)
         # update source map
         # if use_parser_friendly_name is True: e.g. prim::Constant_0 -> prim__Constant_0
         if use_parser_friendly_name:
@@ -4907,7 +4967,7 @@ def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name):
         source_map[node] = node_src_name
 
 
-def _debug_rename(graph, use_parser_friendly_name):
+def _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes):
     """Returns map between node and source name"""
     source_map, op_type_dict = {}, {}
     prim_with_blocks = ["prim::If", "prim::Loop"]
@@ -4919,13 +4979,21 @@ def _debug_rename(graph, use_parser_friendly_name):
             if node.kind() in prim_with_blocks:
                 for block in node.blocks():
                     _traverse_graph(block.nodes())
-            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
+            _rename_outputs(
+                node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+            )
 
     _traverse_graph(graph.nodes())
     return source_map
 
 
-def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_friendly_name=False):
+def _get_operator_nodes(
+    nodes,
+    source_map=None,
+    op_type_dict=None,
+    use_parser_friendly_name=False,
+    preserve_pytorch_scopes=False,
+):
     """Returns torch IR nodes that need conversion to Relay"""
     ops, should_rename_graph = [], all([source_map, op_type_dict]) is not None
 
@@ -4935,7 +5003,9 @@ def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_fr
             continue
 
         if should_rename_graph:
-            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
+            _rename_outputs(
+                node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+            )
 
         if node.outputsSize() > 1:
             node_name = "_".join(_get_output_names(node))
@@ -5197,6 +5267,7 @@ def from_pytorch(
     use_parser_friendly_name=False,
     keep_quantized_weight=False,
     export_renamed_c_graph_path=None,
+    preserve_pytorch_scopes=False,
 ):
     """Load PyTorch model in the form of a scripted PyTorch model and convert 
into relay.
     The companion parameters will be handled automatically.
@@ -5244,6 +5315,10 @@ def from_pytorch(
         During the conversion, variable names in torch._C.Graph will be assigned based on their op
         types. The exported text file can be the reference to spans.
 
+    preserve_pytorch_scopes : bool
+        When naming the nodes in the Relay graph, use the "scope name" from the Pytorch model.
+        If false, a default namer is used that does not preserve the Pytorch scope names.
+
     Returns
     -------
     mod : tvm.IRModule
@@ -5258,7 +5333,9 @@ def from_pytorch(
     prelude = Prelude(mod)
     enable_lower_all_tuples = True
 
-    converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name)
+    converter = PyTorchOpConverter(
+        prelude, default_dtype, use_parser_friendly_name, preserve_pytorch_scopes
+    )
 
     graph = script_module.graph.copy()
 
@@ -5290,7 +5367,7 @@ def from_pytorch(
 
     # rename _C.Graph here for constructing meaningful source name of graph nodes
     # by doing so, we could Use source_map as the reference to rename model parameters
-    source_map = _debug_rename(graph, use_parser_friendly_name)
+    source_map = _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes)
     param_vars, tensors, packed_param_map, param_debug_name_map = convert_params(
         graph, params, source_map, use_parser_friendly_name
     )
@@ -5318,7 +5395,11 @@ def from_pytorch(
         converter.update_convert_map(qnn_torch.convert_map)
 
     operator_nodes = _get_operator_nodes(
-        graph.nodes(), converter.source_map, converter.op_type_dict, use_parser_friendly_name
+        graph.nodes(),
+        converter.source_map,
+        converter.op_type_dict,
+        use_parser_friendly_name,
+        preserve_pytorch_scopes,
     )
     ret_name = _get_input_names(graph.return_node())
     outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
diff --git a/tests/python/frontend/pytorch/test_span_naming.py b/tests/python/frontend/pytorch/test_span_naming.py
new file mode 100644
index 0000000000..fb39ddf4f0
--- /dev/null
+++ b/tests/python/frontend/pytorch/test_span_naming.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
+# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
+# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
+# pylint: disable=missing-function-docstring, redefined-builtin, use-implicit-booleaness-not-comparison
+"""Tests to ensure span names are correctly populated when importing Pytorch"""
+from torch import nn
+import torch
+import tvm
+
+
+class NestedConvModule(nn.Module):
+    """Module that performs Conv2d and relu activation"""
+
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        x = self.relu(self.conv(x))
+        return x
+
+
+class NestedFinalModule(nn.Module):
+    """Simple module that adds 2 inputs"""
+
+    def forward(self, x, y):
+        return x + y
+
+
+class SimpleTwoConvModule(nn.Module):
+    """
+    ML model that performs 2 convolutions and adds them together.
+    All operations are inside nested modules to make scope names interesting.
+    """
+
+    def __init__(self):
+        super().__init__()
+        # First convolutional module
+        self.image_block1 = NestedConvModule(in_channels=3, out_channels=64)
+        # Second convolutional module
+        self.image_block2 = NestedConvModule(in_channels=64, out_channels=64)
+        self.final_block = NestedFinalModule()
+
+    def forward(self, x):
+        # Forward pass through the first convolutional module
+        x1 = self.image_block1(x)
+        # Forward pass through the second convolutional module
+        x2 = self.image_block2(x1)
+        # Add the outputs of the two convolutional modules
+        return self.final_block(x1, x2)
+
+
+def test_pytorch_scope_based_span_names():
+    model = SimpleTwoConvModule()
+    sample_input = torch.zeros((1, 3, 64, 64), dtype=torch.float32)
+    with torch.no_grad():
+        traced_torch_model = torch.jit.trace(model, sample_input)
+    import_input = [("model_input", (1, 3, 64, 64))]
+    relay_model_ir, relay_model_params = tvm.relay.frontend.from_pytorch(
+        traced_torch_model, import_input, preserve_pytorch_scopes=True
+    )
+    # If specified, we are preserving the pytorch named spans
+    for block in [1, 2]:
+        for key in ["weight", "bias"]:
+            assert f"image_block{block}.conv.{key}" in 
relay_model_params.keys()
+    # Manually check all span names since asserting structural equality is not sufficient
+    current_call = relay_model_ir["main"].body
+    assert current_call.op.name == "add"
+    assert current_call.span is not None and current_call.span.source_name.name == "final_block"
+    current_call = current_call.args[1]
+    for block in [2, 1]:
+        assert current_call.op.name == "nn.relu"
+        assert (
+            current_call.span is not None
+            and current_call.span.source_name.name == f"image_block{block}.relu"
+        )
+        current_call = current_call.args[0]
+        assert current_call.op.name == "nn.bias_add"
+        assert (
+            current_call.span is not None
+            and current_call.span.source_name.name == f"image_block{block}.conv"
+        )
+        current_call = current_call.args[0]
+        assert current_call.op.name == "nn.conv2d"
+        assert (
+            current_call.span is not None
+            and current_call.span.source_name.name == f"image_block{block}.conv"
+        )
+        current_call = current_call.args[0]
