This is an automated email from the ASF dual-hosted git repository. tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new in repository https://gitbox.apache.org/repos/asf/tvm.git
commit eb4b473ba27bcfcc5c6a62345a20bb65462054b8 Author: Tristan Konolige <[email protected]> AuthorDate: Mon May 15 16:30:21 2023 +0000 trying data dependent dynamic slice --- python/tvm/relax/frontend/torch/fx_translator.py | 6 ++++- src/relax/op/tensor/index.cc | 32 ++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 2d94b246de..2af651bb97 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1347,7 +1347,11 @@ class TorchFXImporter: axes.append(i) i += 1 print([type(x) for x in axes], [type(x) for x in begin], [type(x) for x in end], [type(x) for x in stride]) - sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) + if any([isinstance(x, relax.Var) for x in begin]) or any([isinstance(x, relax.Var) for x in end]): + print(end) + sliced = self.block_builder.emit(relax.op.data_dependent_strided_slice(x, axes, begin, relax.const(np.array(end, dtype=int)))) + else: + sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride)) sliced_shape = list(self.shape_of(sliced)) for i in expand_dim: sliced_shape.insert(i, 1) diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc index d3bb34d21a..6510967992 100644 --- a/src/relax/op/tensor/index.cc +++ b/src/relax/op/tensor/index.cc @@ -239,6 +239,7 @@ TVM_REGISTER_OP("relax.strided_slice") .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice) .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow); +<<<<<<< HEAD /* relax.dynamic_strided_slice */ Expr dynamic_strided_slice(Expr x, // Expr begin, // @@ -246,6 +247,37 @@ Expr dynamic_strided_slice(Expr x, // Expr strides) { static const Op& op = Op::Get("relax.dynamic_strided_slice"); return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {}); +||||||| parent of f322af441 (trying data dependent dynamic slice) +/* relax.data_dependent_strided_slice */ +// TODO: support dynamic number of axes +TVM_REGISTER_NODE_TYPE(DataDependentStridedSliceAttrs); + +Expr data_dependent_strided_slice(Expr x, // + Array<Integer> axes, // + Expr begin, // + Expr end) { + int n_axis = axes.size(); + + ObjectPtr<DataDependentStridedSliceAttrs> attrs = make_object<DataDependentStridedSliceAttrs>(); + attrs->axes = std::move(axes); + + static const Op& op = Op::Get("relax.data_dependent_strided_slice"); + return Call(op, {std::move(x), std::move(begin), std::move(end)}, Attrs(attrs), {}); +======= +/* relax.data_dependent_strided_slice */ +// TODO: support dynamic number of axes +TVM_REGISTER_NODE_TYPE(DataDependentStridedSliceAttrs); + +Expr data_dependent_strided_slice(Expr x, // + Array<Integer> axes, // + Expr begin, // + Expr end) { + ObjectPtr<DataDependentStridedSliceAttrs> attrs = make_object<DataDependentStridedSliceAttrs>(); + attrs->axes = std::move(axes); + + static const Op& op = Op::Get("relax.data_dependent_strided_slice"); + return Call(op, {std::move(x), std::move(begin), std::move(end)}, Attrs(attrs), {}); +>>>>>>> f322af441 (trying data dependent dynamic slice) } TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice);
