This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b7aada1441 [relay][frontend] aten::copy_ support for pytorch (#15502)
b7aada1441 is described below

commit b7aada1441abe79ec63eeae89a71a7033b23bac8
Author: jhlee525 <bmrcreativ...@gmail.com>
AuthorDate: Thu Oct 19 10:53:28 2023 +0900

    [relay][frontend] aten::copy_ support for pytorch (#15502)
    
    * add handling logic for aten::copy_
    
    * lint
    
    * add test case
    
    * lint
    
    * remove __init__
    
    * fix logic
    
    * lint
    
    * lint
    
    * lint
    
    * feedback
    
    * lint
    
    ---------
    
    Co-authored-by: jhlee525 <jh...@rebellions.ai>
---
 python/tvm/relay/frontend/pytorch.py          | 124 ++++++++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py |  26 ++++++
 2 files changed, 150 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 89dcad03e6..81392a08ec 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3747,6 +3747,95 @@ class PyTorchOpConverter:
         )
         return weight_g * (weight_v / norm_v)
 
+    def inplace_copy(self, inputs, input_types):
+        source = inputs[0]
+        values = inputs[1]
+        accumulate = inputs[2]
+        if not accumulate:
+            mode = "update"
+        else:
+            mode = "add"
+
+        # Track slice and select calls
+        slice_and_select_calls = []
+        while True:
+            if isinstance(source, _expr.Call) and source.op.name in [
+                "strided_slice",
+                "take",
+            ]:
+                slice_and_select_calls.append(source)
+                source = source.args[0]
+            else:
+                break
+        slice_and_select_calls = slice_and_select_calls[::-1]
+        source_shape = _infer_shape(source)
+
+        # Create index map
+        index_map = {}
+        squeezed_axes = []
+        for call in slice_and_select_calls:
+            if call.op.name == "strided_slice":
+                axes = call.attrs.axes
+                if axes is None:
+                    axes = list(range(len(source_shape)))
+                begins = call.attrs.begin
+                ends = call.attrs.end
+                for axis, begin, end in zip(axes, begins, ends):
+                    num_squeezed_axis = len([v for v in squeezed_axes if v <= axis])
+                    axis += num_squeezed_axis
+                    # Set range
+                    if begin < 0:
+                        begin = source_shape[axis] + begin
+                    if end < 0:
+                        end = source_shape[axis] + end
+                    if begin == 0 and end == source_shape[axis]:
+                        continue
+                    index_map[axis] = (begin.value, end.value)
+            elif call.op.name == "take":
+                num_squeezed_axis = len([v for v in squeezed_axes if v <= axis])
+                axis = call.attrs.axis.value + num_squeezed_axis
+                idx = call.args[1]
+                assert isinstance(idx, _expr.Constant)
+                idx = idx.data.numpy().item()
+                if idx < 0:
+                    idx = source_shape[axis] + idx
+                index_map[axis] = (idx, idx + 1)
+                values = _op.expand_dims(values, axis)
+                squeezed_axes.append(axis)
+            else:
+                pass
+        last_index_dim = np.max(list(index_map)).item()
+        for axis in range(last_index_dim + 1):
+            if axis not in index_map:
+                index_map[axis] = 0, source_shape[axis]
+
+        # Create indices
+        nelem = 1
+        for (begin, end) in index_map.values():
+            nelem *= end - begin
+        chunk_sizes = [nelem]
+        for i in range(1, last_index_dim + 1):
+            begin, end = index_map[i - 1]
+            chunk_sizes.append(chunk_sizes[-1] // (end - begin))
+        indices = []
+        for axis in range(last_index_dim + 1):
+            chunk_size = chunk_sizes[axis]
+            repeat = nelem // chunk_size
+            begin, end = index_map[axis]
+            step_size = chunk_size // (end - begin)
+            chunk = np.repeat(np.arange(begin, end), step_size)
+            chunk = np.concatenate([chunk] * repeat)
+            indices.append(chunk)
+        indices = np.stack(indices, axis=0).astype(np.int64)
+        new_shape = [indices.shape[0]] + [
+            index_map[i][1] - index_map[i][0] for i in range(last_index_dim + 1)
+        ]
+        indices = np.resize(indices, new_shape)
+        indices = _expr.const(indices)
+
+        # Return
+        return _op.scatter_nd(source, indices, values, mode)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -4018,6 +4107,7 @@ class PyTorchOpConverter:
             "aten::__rshift__": self.make_elemwise("right_shift"),
             "aten::multinomial": self.multinomial,
             "aten::_weight_norm": self.weight_norm,
+            "aten::copy_": self.inplace_copy,
         }
 
     def update_convert_map(self, custom_map):
@@ -4470,6 +4560,39 @@ def _run_jit_passes(graph, enable_lower_all_tuples=True):
         torch._C._jit_pass_lower_all_tuples(graph)
 
 
+def _redirect_inplace_output(graph):
+    """
+    This pass redirects the output node of the in-place op i.e. aten::copy_.
+    Before:
+      %1: ...
+      %2: ...
+      %3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
+      return (%input)
+    After:
+      %1: ...
+      %2: ...
+      %3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
+      return (%3)
+    """
+    for node in graph.nodes():
+        if node.kind() == "aten::copy_":
+            node_inputs = list(node.inputs())
+            src_node = node_inputs[0].node()
+            slice_and_select_nodes = []
+            while True:
+                if src_node.kind() in ["aten::slice", "aten::select", "aten::unsqueeze"]:
+                    src_node = list(src_node.inputs())[0].node()
+                    slice_and_select_nodes.append(src_node)
+                else:
+                    break
+            if src_node.kind() == "prim::Param":
+                # First one is "self"
+                src_value = list(src_node.outputs())[1]
+            else:
+                src_value = src_node.output()
+            src_value.replaceAllUsesAfterNodeWith(node, node.output())
+
+
 def _get_tensor_and_var(torch_tensor, name):
     tensor = tvm.nd.array(torch_tensor.cpu().numpy())
     var = _expr.var(name, shape=tensor.shape, dtype=tensor.dtype)
@@ -4971,6 +5094,7 @@ def from_pytorch(
             break
 
     _run_jit_passes(graph, enable_lower_all_tuples)
+    _redirect_inplace_output(graph)
 
     if custom_convert_map:
         converter.update_convert_map(custom_convert_map)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6bbb9ef5cc..abdbda8e40 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5355,6 +5355,32 @@ def test_exporting_renamed_c_graph():
         assert "%aten::_convolution_0" in graph
 
 
+def test_inplace_copy():
+    class SimpleInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[:5, 0, 5:] = x[:5, 0, 5:] + 1
+            return x
+
+    class NegativeSliceInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[5:-1, -1, :] = x[5:-1, -1, :] + 1
+            return x
+
+    class PartialDimensionInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[:5] = x[:5] + 1
+            x[0:5, ...] = x[0:5, ...] + 1
+            x[0:5, ..., -1] = x[0:5, ..., -1] + 1
+            return x
+
+    inputs = torch.randn(10, 10, 10)
+    verify_model(SimpleInplaceCopy(), [inputs])
+    inputs = torch.randn(10, 10, 10)
+    verify_model(NegativeSliceInplaceCopy(), [inputs])
+    inputs = torch.randn(10, 10, 10)
+    verify_model(PartialDimensionInplaceCopy(), [inputs])
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with span tagged."""
 

Reply via email to