This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new b7aada1441 [relay][frontend] aten::copy_ support for pytorch (#15502)
b7aada1441 is described below

commit b7aada1441abe79ec63eeae89a71a7033b23bac8
Author: jhlee525 <bmrcreativ...@gmail.com>
AuthorDate: Thu Oct 19 10:53:28 2023 +0900

    [relay][frontend] aten::copy_ support for pytorch (#15502)

    * add handling logic for aten::copy_

    * lint

    * add test case

    * lint

    * remove __init__

    * fix logic

    * lint

    * lint

    * lint

    * feedback

    * lint

    ---------

    Co-authored-by: jhlee525 <jh...@rebellions.ai>
---
 python/tvm/relay/frontend/pytorch.py          | 124 ++++++++++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py |  26 ++++++
 2 files changed, 150 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 89dcad03e6..81392a08ec 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3747,6 +3747,95 @@ class PyTorchOpConverter:
         )
         return weight_g * (weight_v / norm_v)
 
+    def inplace_copy(self, inputs, input_types):
+        source = inputs[0]
+        values = inputs[1]
+        accumulate = inputs[2]
+        if not accumulate:
+            mode = "update"
+        else:
+            mode = "add"
+
+        # Track slice and select calls
+        slice_and_select_calls = []
+        while True:
+            if isinstance(source, _expr.Call) and source.op.name in [
+                "strided_slice",
+                "take",
+            ]:
+                slice_and_select_calls.append(source)
+                source = source.args[0]
+            else:
+                break
+        slice_and_select_calls = slice_and_select_calls[::-1]
+        source_shape = _infer_shape(source)
+
+        # Create index map
+        index_map = {}
+        squeezed_axes = []
+        for call in slice_and_select_calls:
+            if call.op.name == "strided_slice":
+                axes = call.attrs.axes
+                if axes is None:
+                    axes = list(range(len(source_shape)))
+                begins = call.attrs.begin
+                ends = call.attrs.end
+                for axis, begin, end in zip(axes, begins, ends):
+                    num_squeezed_axis = len([v for v in squeezed_axes if v <= axis])
+                    axis += num_squeezed_axis
+                    # Set range
+                    if begin < 0:
+                        begin = source_shape[axis] + begin
+                    if end < 0:
+                        end = source_shape[axis] + end
+                    if begin == 0 and end == source_shape[axis]:
+                        continue
+                    index_map[axis] = (begin.value, end.value)
+            elif call.op.name == "take":
+                num_squeezed_axis = len([v for v in squeezed_axes if v <= axis])
+                axis = call.attrs.axis.value + num_squeezed_axis
+                idx = call.args[1]
+                assert isinstance(idx, _expr.Constant)
+                idx = idx.data.numpy().item()
+                if idx < 0:
+                    idx = source_shape[axis] + idx
+                index_map[axis] = (idx, idx + 1)
+                values = _op.expand_dims(values, axis)
+                squeezed_axes.append(axis)
+            else:
+                pass
+        last_index_dim = np.max(list(index_map)).item()
+        for axis in range(last_index_dim + 1):
+            if axis not in index_map:
+                index_map[axis] = 0, source_shape[axis]
+
+        # Create indices
+        nelem = 1
+        for (begin, end) in index_map.values():
+            nelem *= end - begin
+        chunk_sizes = [nelem]
+        for i in range(1, last_index_dim + 1):
+            begin, end = index_map[i - 1]
+            chunk_sizes.append(chunk_sizes[-1] // (end - begin))
+        indices = []
+        for axis in range(last_index_dim + 1):
+            chunk_size = chunk_sizes[axis]
+            repeat = nelem // chunk_size
+            begin, end = index_map[axis]
+            step_size = chunk_size // (end - begin)
+            chunk = np.repeat(np.arange(begin, end), step_size)
+            chunk = np.concatenate([chunk] * repeat)
+            indices.append(chunk)
+        indices = np.stack(indices, axis=0).astype(np.int64)
+        new_shape = [indices.shape[0]] + [
+            index_map[i][1] - index_map[i][0] for i in range(last_index_dim + 1)
+        ]
+        indices = np.resize(indices, new_shape)
+        indices = _expr.const(indices)
+
+        # Return
+        return _op.scatter_nd(source, indices, values, mode)
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -4018,6 +4107,7 @@ class PyTorchOpConverter:
             "aten::__rshift__": self.make_elemwise("right_shift"),
             "aten::multinomial": self.multinomial,
             "aten::_weight_norm": self.weight_norm,
+            "aten::copy_": self.inplace_copy,
         }
 
     def update_convert_map(self, custom_map):
@@ -4470,6 +4560,39 @@ def _run_jit_passes(graph, enable_lower_all_tuples=True):
         torch._C._jit_pass_lower_all_tuples(graph)
 
 
+def _redirect_inplace_output(graph):
+    """
+    This pass redirects the output node of the in-place op, i.e. aten::copy_.
+    Before:
+        %1: ...
+        %2: ...
+        %3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
+        return (%input)
+    After:
+        %1: ...
+        %2: ...
+        %3: Float(requires_grad=0, device=cpu) = aten::copy_(%input, %1, %2)
+        return (%3)
+    """
+    for node in graph.nodes():
+        if node.kind() == "aten::copy_":
+            node_inputs = list(node.inputs())
+            src_node = node_inputs[0].node()
+            slice_and_select_nodes = []
+            while True:
+                if src_node.kind() in ["aten::slice", "aten::select", "aten::unsqueeze"]:
+                    src_node = list(src_node.inputs())[0].node()
+                    slice_and_select_nodes.append(src_node)
+                else:
+                    break
+            if src_node.kind() == "prim::Param":
+                # First one is "self"
+                src_value = list(src_node.outputs())[1]
+            else:
+                src_value = src_node.output()
+            src_value.replaceAllUsesAfterNodeWith(node, node.output())
+
+
 def _get_tensor_and_var(torch_tensor, name):
     tensor = tvm.nd.array(torch_tensor.cpu().numpy())
     var = _expr.var(name, shape=tensor.shape, dtype=tensor.dtype)
@@ -4971,6 +5094,7 @@ def from_pytorch(
             break
 
     _run_jit_passes(graph, enable_lower_all_tuples)
+    _redirect_inplace_output(graph)
 
     if custom_convert_map:
         converter.update_convert_map(custom_convert_map)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 6bbb9ef5cc..abdbda8e40 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5355,6 +5355,32 @@ def test_exporting_renamed_c_graph():
     assert "%aten::_convolution_0" in graph
 
 
+def test_inplace_copy():
+    class SimpleInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[:5, 0, 5:] = x[:5, 0, 5:] + 1
+            return x
+
+    class NegativeSliceInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[5:-1, -1, :] = x[5:-1, -1, :] + 1
+            return x
+
+    class PartialDimensionInplaceCopy(torch.nn.Module):
+        def forward(self, x):
+            x[:5] = x[:5] + 1
+            x[0:5, ...] = x[0:5, ...] + 1
+            x[0:5, ..., -1] = x[0:5, ..., -1] + 1
+            return x
+
+    inputs = torch.randn(10, 10, 10)
+    verify_model(SimpleInplaceCopy(), [inputs])
+    inputs = torch.randn(10, 10, 10)
+    verify_model(NegativeSliceInplaceCopy(), [inputs])
+    inputs = torch.randn(10, 10, 10)
+    verify_model(PartialDimensionInplaceCopy(), [inputs])
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with span tagged."""
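
Editor's note: the index construction in inplace_copy above can be illustrated
outside of TVM. The following is a minimal standalone NumPy sketch, not part of
the commit; the (4, 5) source shape, the slice assignment x[1:3, 2:4] = v, and
all variable names are illustrative. It shows how the per-axis (begin, end)
ranges are expanded, via the same repeat/concatenate scheme, into the
coordinate grid that scatter_nd consumes.

    # Standalone NumPy illustration (assumed example, not TVM code):
    # source shape (4, 5), slice assignment x[1:3, 2:4] = v.
    import numpy as np

    index_map = {0: (1, 3), 1: (2, 4)}  # axis -> (begin, end)
    last_index_dim = max(index_map)

    # Total number of updated elements.
    nelem = 1
    for begin, end in index_map.values():
        nelem *= end - begin  # 2 * 2 = 4

    # chunk_sizes[i]: how many consecutive elements share one index
    # value on axis i, in row-major order.
    chunk_sizes = [nelem]
    for i in range(1, last_index_dim + 1):
        begin, end = index_map[i - 1]
        chunk_sizes.append(chunk_sizes[-1] // (end - begin))

    # Build one coordinate row per axis.
    indices = []
    for axis in range(last_index_dim + 1):
        begin, end = index_map[axis]
        step_size = chunk_sizes[axis] // (end - begin)
        repeat = nelem // chunk_sizes[axis]
        chunk = np.repeat(np.arange(begin, end), step_size)
        indices.append(np.concatenate([chunk] * repeat))

    grid = np.stack(indices)
    print(grid)
    # [[1 1 2 2]
    #  [2 3 2 3]]
    # Column k is the coordinate of the k-th updated element:
    # (1,2), (1,3), (2,2), (2,3) -- exactly the cells of x[1:3, 2:4].
    # The converter then reshapes this to (num_axes, *slice_dims)
    # before wrapping it in _expr.const for _op.scatter_nd.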
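To exercise the new conversion end to end, a sketch along these lines should
work (the module, input shape, and variable names are illustrative; the
verify_model helper used in the test file above wraps the same
trace-and-compare flow):

    # Sketch: importing a model that uses in-place slice assignment.
    import torch
    from tvm import relay

    class InplaceAdd(torch.nn.Module):
        def forward(self, x):
            # Lowered as aten::slice/aten::select + aten::copy_,
            # which the frontend now maps to relay scatter_nd.
            x[:5, 0, 5:] = x[:5, 0, 5:] + 1
            return x

    inp = torch.randn(10, 10, 10)
    traced = torch.jit.trace(InplaceAdd().eval(), inp)

    # from_pytorch runs the JIT passes (including the new
    # _redirect_inplace_output) before converting the graph to Relay.
    mod, params = relay.frontend.from_pytorch(traced, [("x", (10, 10, 10))])
    print(mod["main"])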