This is an automated email from the ASF dual-hosted git repository. tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new in repository https://gitbox.apache.org/repos/asf/tvm.git
commit cb458f833253b319122e9618e13e58c109f8b4e2 Author: Tristan Konolige <[email protected]> AuthorDate: Mon May 15 18:00:50 2023 +0000 data dependent strided slice --- python/tvm/relax/frontend/torch/fx_translator.py | 29 ++++++++++++++++----- src/relax/op/tensor/index.cc | 32 ------------------------ 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 2af651bb97..bdf2c3375f 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -18,7 +18,7 @@ # pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck # pylint: disable=import-outside-toplevel """PyTorch FX frontend of Relax.""" -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union, Sequence from functools import reduce import tvm @@ -123,6 +123,17 @@ def _convert_data_type(input_type, default_dtype=None): raise NotImplementedError("input_type {} is not handled yet".format(input_type)) return "float32" # Never reached +def _constify_array(ary: Sequence[Union[int, tvm.tir.IntImm, relax.Var]]) -> relax.Expr: + const = [] + for i in ary: + if isinstance(i, int): + const.append(relax.const(tvm.nd.array([i]))) + elif isinstance(i, tvm.tir.IntImm): + const.append(relax.const(tvm.nd.array([i.value]))) + else: + const.append(i) + return relax.op.concat(const) + class TorchFXImporter: """An importer from PyTorch FX to Relax.""" @@ -1141,6 +1152,11 @@ class TorchFXImporter: assert isinstance(shape, relax.ShapeExpr) size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape))) + from torch import fx + if any([isinstance(s, fx.node.Node) for s in size]): + size = relax.op.concat([self.env[s] for s in size]) + # import IPython; IPython.embed() + if method.startswith("nearest"): method = "nearest_neighbor" elif method[0:2] == "bi": @@ -1346,13 +1362,14 @@ class TorchFXImporter: stride.append(1) axes.append(i) i += 1 - print([type(x) for x in axes], [type(x) for x in begin], [type(x) for x in end], [type(x) for x in stride]) - if any([isinstance(x, relax.Var) for x in begin]) or any([isinstance(x, relax.Var) for x in end]): - print(end) - sliced = self.block_builder.emit(relax.op.data_dependent_strided_slice(x, axes, begin, relax.const(np.array(end, dtype=int)))) + # Handle case where slice is data dependent + if any([isinstance(y, relax.Var) for y in begin]) or any([isinstance(y, relax.Var) for y in end]) or any([isinstance(y, relax.Var) for y in stride]): + assert len(axes) == len(x.struct_info.shape), f"Dynamic strided slice must be provided with all the axes ({len(x.struct_info.shape)}) of the sliced tensor ({len(axes)} provided)" + sliced = self.block_builder.emit(relax.op.dynamic_strided_slice(x, _constify_array(begin), _constify_array(end), _constify_array(stride))) + sliced_shape = list(self.shape_of(x)) else: sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) - sliced_shape = list(self.shape_of(sliced)) + sliced_shape = list(self.shape_of(sliced)) for i in expand_dim: sliced_shape.insert(i, 1) return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape)) diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index 6510967992..d3bb34d21a 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -239,7 +239,6 @@ TVM_REGISTER_OP("relax.strided_slice") .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice) .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); -<<<<<<< HEAD /* relax.dynamic_strided_slice */ Expr dynamic_strided_slice(Expr x, // Expr begin, // @@ -247,37 +246,6 @@ Expr dynamic_strided_slice(Expr x, // Expr strides) { static const Op& op = Op::Get("relax.dynamic_strided_slice"); return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {}); -||||||| parent of f322af441 (trying data dependent dynamic slice) -/* relax.data_dependent_strided_slice */ -// TODO: support dynamic number of axes -TVM_REGISTER_NODE_TYPE(DataDependentStridedSliceAttrs); - -Expr data_dependent_strided_slice(Expr x, // - Array<Integer> axes, // - Expr begin, // - Expr end) { - int n_axis = axes.size(); - - ObjectPtr<DataDependentStridedSliceAttrs> attrs = make_object<DataDependentStridedSliceAttrs>(); - attrs->axes = std::move(axes); - - static const Op& op = Op::Get("relax.data_dependent_strided_slice"); - return Call(op, {std::move(x), std::move(begin), std::move(end)}, Attrs(attrs), {}); -======= -/* relax.data_dependent_strided_slice */ -// TODO: support dynamic number of axes -TVM_REGISTER_NODE_TYPE(DataDependentStridedSliceAttrs); - -Expr data_dependent_strided_slice(Expr x, // - Array<Integer> axes, // - Expr begin, // - Expr end) { - ObjectPtr<DataDependentStridedSliceAttrs> attrs = make_object<DataDependentStridedSliceAttrs>(); - attrs->axes = std::move(axes); - - static const Op& op = Op::Get("relax.data_dependent_strided_slice"); - return Call(op, {std::move(x), std::move(begin), std::move(end)}, Attrs(attrs), {}); ->>>>>>> f322af441 (trying data dependent dynamic slice) } TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice);
