This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new bf071dea54 [Relay][Frontend] Preserve Pytorch Span Names (#16171)
bf071dea54 is described below

commit bf071dea54ceb2705c3819909723d8eb8c048dcb
Author: Navya Mehta <134946169+navya-encha...@users.noreply.github.com>
AuthorDate: Thu Dec 7 19:00:35 2023 -0500

    [Relay][Frontend] Preserve Pytorch Span Names (#16171)

    * Preserve Pytorch Span Names

    * Update pytorch.py

    * Add tests

    * WIP

    * Changes and tests

    * Michael Klaiber feedback

    * Linting fix

    * Linting Feedback Pt.2

    * Test changes

    * Modify to Pytorch 2.0
---
 python/tvm/relay/frontend/pytorch.py              | 119 ++++++++++++++++++----
 tests/python/frontend/pytorch/test_span_naming.py | 106 +++++++++++++++++++
 2 files changed, 206 insertions(+), 19 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 1a481e9300..9583575bfc 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -21,6 +21,8 @@
 """PT: PyTorch frontend."""
 import functools
 import itertools
+from abc import ABC
+from typing import Dict
 import math
 import re
 import sys
@@ -137,7 +139,9 @@ def _is_int_seq(seq):
 class PyTorchOpConverter:
     """A helper class for holding PyTorch op converters."""

-    def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
+    def __init__(
+        self, prelude, default_dtype, use_parser_friendly_name=False, preserve_pytorch_scopes=False
+    ):
         self.prelude = prelude
         self.default_dtype = default_dtype
         self.create_convert_map()
@@ -146,6 +150,7 @@ class PyTorchOpConverter:
         self.op_type_dict = {}  # map from op type to its presenting order
         self.current_op = []  # stack for recording current processing op
         self.use_parser_friendly_name = use_parser_friendly_name
+        self.preserve_pytorch_scopes = preserve_pytorch_scopes

         # this incrementally infers the type, see the comments on the type visitor
         # above.
@@ -4314,7 +4319,11 @@ class PyTorchOpConverter:
     def convert_block(self, block, outputs):
         """Translate Torch "Block", used for prim::If and prim::Loop"""
         ops = _get_operator_nodes(
-            block.nodes(), self.source_map, self.op_type_dict, self.use_parser_friendly_name
+            block.nodes(),
+            self.source_map,
+            self.op_type_dict,
+            self.use_parser_friendly_name,
+            self.preserve_pytorch_scopes,
         )
         ret_names = _get_input_names(block.returnNode())
         return self.convert_operators(ops, outputs, ret_names)
@@ -4881,25 +4890,76 @@ def _get_constant(node):
         return None


-def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name):
-    """Rewrite debug name of node outputs with its operator type"""
+class NodeNamer(ABC):
+    """Name each node and output edge in the relay graph"""

-    def _get_source_name(op_type):
+    def __init__(self, op_counter_dict: Dict[str, int]):
+        self.op_counter_dict = op_counter_dict
+
+    def increment_counter(self, identifier: str) -> int:
         op_idx = 0
-        if op_type in op_type_dict:
-            op_idx = op_type_dict[op_type] + 1
-        op_type_dict[op_type] = op_idx
-        return "_".join([op_type, str(op_idx)])
+        if identifier in self.op_counter_dict:
+            op_idx = self.op_counter_dict[identifier] + 1
+        self.op_counter_dict[identifier] = op_idx
+        return op_idx

-    # get source name of operator and rename all of its outputs
+    def get_node_source_name(self, node) -> str:
+        raise NotImplementedError()
+
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
+        raise NotImplementedError()
+
+
+class DefaultNodeKindNamer(NodeNamer):
+    """
+    Namer that uses a default naming based on the "type"/kind of node
+    # e.g. node.kind(): aten::adaptive_max_pool2d
+    # node_src_name -> aten::adaptive_max_pool2d_x
+    # output_1 -> aten::adaptive_max_pool2d_x_0
+    # output_2 -> aten::adaptive_max_pool2d_x_1
+    """
+
+    def get_node_source_name(self, node) -> str:
+        op_idx = self.increment_counter(node.kind())
+        return "_".join([node.kind(), str(op_idx)])
+
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
+        return "_".join([node_src_name, str(index)])
+
+
+class PytorchScopePreservingNamer(NodeNamer):
+    """
+    Namer that uses the Pytorch scope to name nodes.
+    eg. node could be called "bert.encoder.layer.11.output.dense"
+    """
+
+    def get_node_source_name(self, node) -> str:
+        # This works per the scope naming in Pytorch 2.0 and beyond.
+        scope_name_parts = node.scopeName().split("/")
+        imp_parts = [part.split("::")[-1] for part in scope_name_parts]
+        node_src_name = ".".join([part for part in imp_parts if part])
+        return node_src_name
+
+    def get_node_output_name(self, node_src_name: str, index: int) -> str:
+        op_idx = self.increment_counter(node_src_name)
+        return "_".join([node_src_name, str(op_idx), str(index)])
+
+
+def _rename_outputs(
+    node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+):
+    """Rewrite debug name of node outputs with its operator type"""
+    namer = (
+        PytorchScopePreservingNamer(op_type_dict)
+        if preserve_pytorch_scopes
+        else DefaultNodeKindNamer(op_type_dict)
+    )
+    # get source name of operator and rename all of its outputs
     if node.kind() != "prim::GetAttr":
-        node_src_name = _get_source_name(node.kind())
+        node_src_name = namer.get_node_source_name(node)
         for index, output in enumerate(node.outputs()):
-            output.setDebugName("_".join([node_src_name, str(index)]))
+            name = namer.get_node_output_name(node_src_name, index)
+            output.setDebugName(name)

         # update source map
         # if use_parser_friendly_name is True: e.g. prim::Constant_0 -> prim__Constant_0
         if use_parser_friendly_name:
@@ -4907,7 +4967,7 @@ def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name):
         source_map[node] = node_src_name


-def _debug_rename(graph, use_parser_friendly_name):
+def _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes):
     """Returns map between node and source name"""
     source_map, op_type_dict = {}, {}
     prim_with_blocks = ["prim::If", "prim::Loop"]
@@ -4919,13 +4979,21 @@ def _debug_rename(graph, use_parser_friendly_name):
             if node.kind() in prim_with_blocks:
                 for block in node.blocks():
                     _traverse_graph(block.nodes())
-            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
+            _rename_outputs(
+                node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+            )

     _traverse_graph(graph.nodes())
     return source_map


-def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_friendly_name=False):
+def _get_operator_nodes(
+    nodes,
+    source_map=None,
+    op_type_dict=None,
+    use_parser_friendly_name=False,
+    preserve_pytorch_scopes=False,
+):
     """Returns torch IR nodes that need conversion to Relay"""
     ops, should_rename_graph = [], all([source_map, op_type_dict]) is not None
@@ -4935,7 +5003,9 @@ def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_fr
             continue

         if should_rename_graph:
-            _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
+            _rename_outputs(
+                node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
+            )

         if node.outputsSize() > 1:
             node_name = "_".join(_get_output_names(node))
@@ -5197,6 +5267,7 @@ def from_pytorch(
     use_parser_friendly_name=False,
     keep_quantized_weight=False,
     export_renamed_c_graph_path=None,
+    preserve_pytorch_scopes=False,
 ):
     """Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
     The companion parameters will be handled automatically.
@@ -5244,6 +5315,10 @@
         During the conversion, variable names in torch._C.Graph will be assigned
         based on their op types. The exported text file can be the reference to spans.

+    preserve_pytorch_scopes : bool
+        When naming the nodes in the Relay graph, use the "scope name" from the Pytorch model.
+        If false, a default namer is used that does not preserve the Pytorch scope names.
+
     Returns
     -------
     mod : tvm.IRModule
@@ -5258,7 +5333,9 @@
     prelude = Prelude(mod)
     enable_lower_all_tuples = True

-    converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name)
+    converter = PyTorchOpConverter(
+        prelude, default_dtype, use_parser_friendly_name, preserve_pytorch_scopes
+    )

     graph = script_module.graph.copy()

@@ -5290,7 +5367,7 @@
     # rename _C.Graph here for constructing meaningful source name of graph nodes
     # by doing so, we could Use source_map as the reference to rename model parameters
-    source_map = _debug_rename(graph, use_parser_friendly_name)
+    source_map = _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes)
     param_vars, tensors, packed_param_map, param_debug_name_map = convert_params(
         graph, params, source_map, use_parser_friendly_name
     )
@@ -5318,7 +5395,11 @@
         converter.update_convert_map(qnn_torch.convert_map)

     operator_nodes = _get_operator_nodes(
-        graph.nodes(), converter.source_map, converter.op_type_dict, use_parser_friendly_name
+        graph.nodes(),
+        converter.source_map,
+        converter.op_type_dict,
+        use_parser_friendly_name,
+        preserve_pytorch_scopes,
     )
     ret_name = _get_input_names(graph.return_node())
     outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
diff --git a/tests/python/frontend/pytorch/test_span_naming.py b/tests/python/frontend/pytorch/test_span_naming.py
new file mode 100644
index 0000000000..fb39ddf4f0
--- /dev/null
+++ b/tests/python/frontend/pytorch/test_span_naming.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
+# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
+# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
+# pylint: disable=missing-function-docstring, redefined-builtin, use-implicit-booleaness-not-comparison
+"""Tests to ensure span names are correctly populated when importing Pytorch"""
+from torch import nn
+import torch
+import tvm
+
+
+class NestedConvModule(nn.Module):
+    """Module that performs Conv2d and relu activation"""
+
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        x = self.relu(self.conv(x))
+        return x
+
+
+class NestedFinalModule(nn.Module):
+    """Simple module that adds 2 inputs"""
+
+    def forward(self, x, y):
+        return x + y
+
+
+class SimpleTwoConvModule(nn.Module):
+    """
+    ML model that performs 2 convolutions and adds them together.
+    All operations are inside nested modules to make scope names interesting.
+    """
+
+    def __init__(self):
+        super().__init__()
+        # First convolutional module
+        self.image_block1 = NestedConvModule(in_channels=3, out_channels=64)
+        # Second convolutional module
+        self.image_block2 = NestedConvModule(in_channels=64, out_channels=64)
+        self.final_block = NestedFinalModule()
+
+    def forward(self, x):
+        # Forward pass through the first convolutional module
+        x1 = self.image_block1(x)
+        # Forward pass through the second convolutional module
+        x2 = self.image_block2(x1)
+        # Add the outputs of the two convolutional modules
+        return self.final_block(x1, x2)
+
+
+def test_pytorch_scope_based_span_names():
+    model = SimpleTwoConvModule()
+    sample_input = torch.zeros((1, 3, 64, 64), dtype=torch.float32)
+    with torch.no_grad():
+        traced_torch_model = torch.jit.trace(model, sample_input)
+    import_input = [("model_input", (1, 3, 64, 64))]
+    relay_model_ir, relay_model_params = tvm.relay.frontend.from_pytorch(
+        traced_torch_model, import_input, preserve_pytorch_scopes=True
+    )
+    # If specified, we are preserving the pytorch named spans
+    for block in [1, 2]:
+        for key in ["weight", "bias"]:
+            assert f"image_block{block}.conv.{key}" in relay_model_params.keys()
+    # Manually check all span names since asserting structural equality is not sufficient
+    current_call = relay_model_ir["main"].body
+    assert current_call.op.name == "add"
+    assert current_call.span is not None and current_call.span.source_name.name == "final_block"
+    current_call = current_call.args[1]
+    for block in [2, 1]:
+        assert current_call.op.name == "nn.relu"
+        assert (
+            current_call.span is not None
+            and current_call.span.source_name.name == f"image_block{block}.relu"
+        )
+        current_call = current_call.args[0]
+        assert current_call.op.name == "nn.bias_add"
+        assert (
+            current_call.span is not None
+            and current_call.span.source_name.name == f"image_block{block}.conv"
+        )
+        current_call = current_call.args[0]
+        assert current_call.op.name == "nn.conv2d"
+        assert (
+            current_call.span is not None
+            and current_call.span.source_name.name == f"image_block{block}.conv"
+        )
+        current_call = current_call.args[0]
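
For readers skimming the diff, a minimal standalone sketch of the two naming schemes side by side. Only the string handling is lifted from the patch; the example scope string is an assumption about the "ClassName::attribute" format that node.scopeName() emits for traced modules in PyTorch 2.0 and beyond:

    # Default scheme (DefaultNodeKindNamer): node kind plus a per-kind counter.
    op_counter = {}
    kind = "aten::adaptive_max_pool2d"
    op_idx = op_counter.get(kind, -1) + 1  # same bookkeeping as NodeNamer.increment_counter
    op_counter[kind] = op_idx
    print("_".join([kind, str(op_idx)]))  # aten::adaptive_max_pool2d_0

    # Scope-preserving scheme (PytorchScopePreservingNamer): drop the "ClassName::"
    # prefixes and join the remaining attribute path with dots.
    scope = "SimpleTwoConvModule::/NestedConvModule::image_block1/Conv2d::conv"  # assumed format
    imp_parts = [part.split("::")[-1] for part in scope.split("/")]
    print(".".join(part for part in imp_parts if part))  # image_block1.conv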
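
And a condensed usage sketch of the new flag, adapted from test_pytorch_scope_based_span_names above (SimpleTwoConvModule is the test module defined in this patch; the kind-based fallback name in the comment is illustrative):

    import torch
    import tvm

    model = SimpleTwoConvModule()
    with torch.no_grad():
        traced = torch.jit.trace(model, torch.zeros((1, 3, 64, 64), dtype=torch.float32))
    mod, params = tvm.relay.frontend.from_pytorch(
        traced, [("model_input", (1, 3, 64, 64))], preserve_pytorch_scopes=True
    )
    # Span source names now carry PyTorch scopes, e.g. "image_block1.conv",
    # rather than kind-based names such as "aten::_convolution_0".
    print(mod["main"])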