This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b7aada1441 [relay][frontend] aten::copy_ support for pytorch (#15502)
b7aada1441 is described below
commit b7aada1441abe79ec63eeae89a71a7033b23bac8
Author: jhlee525 <[email protected]>
AuthorDate: Thu Oct 19 10:53:28 2023 +0900
[relay][frontend] aten::copy_ support for pytorch (#15502)
* add handling logic for aten::copy_
* lint
* add test case
* lint
* remove __init__
* fix logic
* lint
* lint
* lint
* feedback
* lint
---------
Co-authored-by: jhlee525 <[email protected]>
---
python/tvm/relay/frontend/pytorch.py | 124 ++++++++++++++++++++++++++
tests/python/frontend/pytorch/test_forward.py | 26 ++++++
2 files changed, 150 insertions(+)
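For context, aten::copy_ is the in-place write that TorchScript emits for
tensor slice assignment; before this patch the PyTorch frontend had no
mapping for it. A minimal reproducer of the pattern being imported (a
sketch, not part of the patch):

    import torch

    class SimpleInplaceCopy(torch.nn.Module):
        def forward(self, x):
            # traced as aten::slice / aten::select feeding aten::copy_
            x[:5, 0, 5:] = x[:5, 0, 5:] + 1
            return x

    print(torch.jit.trace(SimpleInplaceCopy(), torch.randn(10, 10, 10)).graph)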
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 89dcad03e6..81392a08ec 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3747,6 +3747,95 @@ class PyTorchOpConverter:
             )
         return weight_g * (weight_v / norm_v)
 
+    def inplace_copy(self, inputs, input_types):
+        source = inputs[0]
+        values = inputs[1]
+        accumulate = inputs[2]
+        if not accumulate:
+            mode = "update"
+        else:
+            mode = "add"
+
+        # Track slice and select calls
+        slice_and_select_calls = []
+        while True:
+            if isinstance(source, _expr.Call) and source.op.name in [
+                "strided_slice",
+                "take",
+            ]:
+                slice_and_select_calls.append(source)
+                source = source.args[0]
+            else:
+                break
+        slice_and_select_calls = slice_and_select_calls[::-1]
+        source_shape = _infer_shape(source)
+
+        # Create index map
+        index_map = {}
+        squeezed_axes = []
+        for call in slice_and_select_calls:
+            if call.op.name == "strided_slice":
+                axes = call.attrs.axes
+                if axes is None:
+                    axes = list(range(len(source_shape)))
+                begins = call.attrs.begin
+                ends = call.attrs.end
+                for axis, begin, end in zip(axes, begins, ends):
+                    num_squeezed_axis = len([v for v in squeezed_axes if v <= axis])
+                    axis += num_squeezed_axis
+                    # Set range
+                    if begin < 0:
+                        begin = source_shape[axis] + begin
+                    if end < 0:
+                        end = source_shape[axis] + end
+                    if begin == 0 and end == source_shape[axis]:
+                        continue
+                    index_map[axis] = (begin.value, end.value)
+            elif call.op.name == "take":
+                axis = call.attrs.axis.value
+                axis += len([v for v in squeezed_axes if v <= axis])
+                idx = call.args[1]
+                assert isinstance(idx, _expr.Constant)
+                idx = idx.data.numpy().item()
+                if idx < 0:
+                    idx = source_shape[axis] + idx
+                index_map[axis] = (idx, idx + 1)
+                values = _op.expand_dims(values, axis)
+                squeezed_axes.append(axis)
+            else:
+                pass
+        last_index_dim = np.max(list(index_map)).item()
+        for axis in range(last_index_dim + 1):
+            if axis not in index_map:
+                index_map[axis] = 0, source_shape[axis]
+
+        # Create indices
+        nelem = 1
+        for (begin, end) in index_map.values():
+            nelem *= end - begin
+        chunk_sizes = [nelem]
+        for i in range(1, last_index_dim + 1):
+            begin, end = index_map[i - 1]
+            chunk_sizes.append(chunk_sizes[-1] // (end - begin))
+        indices = []
+        for axis in range(last_index_dim + 1):
+            chunk_size = chunk_sizes[axis]
+            repeat = nelem // chunk_size
+            begin, end = index_map[axis]
+            step_size = chunk_size // (end - begin)
+            chunk = np.repeat(np.arange(begin, end), step_size)
+            chunk = np.concatenate([chunk] * repeat)
+            indices.append(chunk)
+        indices = np.stack(indices, axis=0).astype(np.int64)
+        new_shape = [indices.shape[0]] + [
+            index_map[i][1] - index_map[i][0] for i in range(last_index_dim + 1)
+        ]
+        indices = np.resize(indices, new_shape)
+        indices = _expr.const(indices)
+
+        # Return
+        return _op.scatter_nd(source, indices, values, mode)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
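For intuition (a sketch, not part of the patch): the chunk/repeat
construction in inplace_copy materializes, axis by axis, the row-major
coordinates of the written region, which is equivalent to stacking a dense
meshgrid over the region's ranges. The index_map below mirrors
x[:5, 0, 5:] on a (10, 10, 10) tensor; the assert checks the equivalence:

    import numpy as np

    index_map = {0: (0, 5), 1: (0, 1), 2: (5, 10)}  # region written by x[:5, 0, 5:]
    sizes = [end - begin for begin, end in index_map.values()]

    nelem = int(np.prod(sizes))
    chunk_sizes = [nelem]
    for size in sizes[:-1]:
        chunk_sizes.append(chunk_sizes[-1] // size)

    indices = []
    for axis, (begin, end) in index_map.items():
        step = chunk_sizes[axis] // (end - begin)  # inner repeats of each coordinate
        repeat = nelem // chunk_sizes[axis]        # outer tiles over leading axes
        chunk = np.repeat(np.arange(begin, end), step)
        indices.append(np.concatenate([chunk] * repeat))
    indices = np.stack(indices).reshape([len(sizes)] + sizes)

    expected = np.stack(
        np.meshgrid(*[np.arange(b, e) for b, e in index_map.values()], indexing="ij")
    )
    assert (indices == expected).all()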
@@ -4018,6 +4107,7 @@ class PyTorchOpConverter:
"aten::__rshift__": self.make_elemwise("right_shift"),
"aten::multinomial": self.multinomial,
"aten::_weight_norm": self.weight_norm,
+ "aten::copy_": self.inplace_copy,
}
def update_convert_map(self, custom_map):
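With "aten::copy_" registered in convert_map above, a traced module that
assigns into a slice should import without any custom_convert_map entry. A
minimal end-to-end sketch (shapes and names are illustrative assumptions):

    import torch
    from tvm import relay

    class M(torch.nn.Module):
        def forward(self, x):
            x[:5] = x[:5] + 1
            return x

    traced = torch.jit.trace(M(), torch.randn(10, 10))
    mod, params = relay.frontend.from_pytorch(traced, [("x", (10, 10))])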
@@ -4470,6 +4560,39 @@ def _run_jit_passes(graph, enable_lower_all_tuples=True):
         torch._C._jit_pass_lower_all_tuples(graph)
 
 
+def _redirect_inplace_output(graph):
+    """
+    Redirect downstream users of an in-place op's input (i.e. aten::copy_) to its output.
+    Before:
+        %1: ...
+        %2: ...
+        %3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
+        return (%input)
+    After:
+        %1: ...
+        %2: ...
+        %3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
+        return (%3)
+    """
+    for node in graph.nodes():
+        if node.kind() == "aten::copy_":
+            node_inputs = list(node.inputs())
+            src_node = node_inputs[0].node()
+            slice_and_select_nodes = []
+            while True:
+                if src_node.kind() in ["aten::slice", "aten::select", "aten::unsqueeze"]:
+                    slice_and_select_nodes.append(src_node)
+                    src_node = list(src_node.inputs())[0].node()
+                else:
+                    break
+            if src_node.kind() == "prim::Param":
+                # outputs()[0] is "self"; the first real graph input is at index 1
+                src_value = list(src_node.outputs())[1]
+            else:
+                src_value = src_node.output()
+            src_value.replaceAllUsesAfterNodeWith(node, node.output())
+
+
 def _get_tensor_and_var(torch_tensor, name):
     tensor = tvm.nd.array(torch_tensor.cpu().numpy())
     var = _expr.var(name, shape=tensor.shape, dtype=tensor.dtype)
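For intuition on what the pass above walks over (a sketch, not part of the
patch): tracing a slice assignment produces a chain of aten::slice /
aten::select nodes whose result feeds aten::copy_, which can be inspected
directly on the TorchScript graph:

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            x[5:-1, -1, :] = x[5:-1, -1, :] + 1
            return x

    graph = torch.jit.trace(M(), torch.randn(10, 10, 10)).graph
    for node in graph.nodes():
        if node.kind() == "aten::copy_":
            # the destination chain the pass walks back through
            print(node.kind(), "<-", list(node.inputs())[0].node().kind())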
@@ -4971,6 +5094,7 @@ def from_pytorch(
             break
 
     _run_jit_passes(graph, enable_lower_all_tuples)
+    _redirect_inplace_output(graph)
 
     if custom_convert_map:
         converter.update_convert_map(custom_convert_map)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6bbb9ef5cc..abdbda8e40 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5355,6 +5355,32 @@ def test_exporting_renamed_c_graph():
assert "%aten::_convolution_0" in graph
+def test_inplace_copy():
+    class SimpleInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[:5, 0, 5:] = x[:5, 0, 5:] + 1
+            return x
+
+    class NegativeSliceInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[5:-1, -1, :] = x[5:-1, -1, :] + 1
+            return x
+
+    class PartialDimensionInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[:5] = x[:5] + 1
+            x[0:5, ...] = x[0:5, ...] + 1
+            x[0:5, ..., -1] = x[0:5, ..., -1] + 1
+            return x
+
+    inputs = torch.randn(10, 10, 10)
+    verify_model(SimpleInplaceCopy(), [inputs])
+    inputs = torch.randn(10, 10, 10)
+    verify_model(NegativeSliceInplaceCopy(), [inputs])
+    inputs = torch.randn(10, 10, 10)
+    verify_model(PartialDimensionInplaceCopy(), [inputs])
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with
     span tagged."""