This is an automated email from the ASF dual-hosted git repository.

tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit eb4b473ba27bcfcc5c6a62345a20bb65462054b8
Author: Tristan Konolige <[email protected]>
AuthorDate: Mon May 15 16:30:21 2023 +0000

    trying data dependent dynamic slice
---
 python/tvm/relax/frontend/torch/fx_translator.py |  6 ++++-
 src/relax/op/tensor/index.cc                     | 32 ++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 2d94b246de..2af651bb97 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -1347,7 +1347,11 @@ class TorchFXImporter:
                 axes.append(i)
                 i += 1
             print([type(x) for x in axes], [type(x) for x in begin], [type(x) for x in end], [type(x) for x in stride])
-            sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
+            if any([isinstance(x, relax.Var) for x in begin]) or any([isinstance(x, relax.Var) for x in end]):
+                print(end)
+                sliced = self.block_builder.emit(relax.op.data_dependent_strided_slice(x, axes, begin, relax.const(np.array(end, dtype=int))))
+            else:
+                sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
             sliced_shape = list(self.shape_of(sliced))
             for i in expand_dim:
                 sliced_shape.insert(i, 1)
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index d3bb34d21a..6510967992 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -239,6 +239,7 @@ TVM_REGISTER_OP("relax.strided_slice")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
+<<<<<<< HEAD
 /* relax.dynamic_strided_slice */
 Expr dynamic_strided_slice(Expr x,      //
                            Expr begin,  //
@@ -246,6 +247,37 @@ Expr dynamic_strided_slice(Expr x,      //
                            Expr strides) {
   static const Op& op = Op::Get("relax.dynamic_strided_slice");
   return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {});
+||||||| parent of f322af441 (trying data dependent dynamic slice)
+/* relax.data_dependent_strided_slice */
+// TODO: support dynamic number of axes
+TVM_REGISTER_NODE_TYPE(DataDependentStridedSliceAttrs);
+
+Expr data_dependent_strided_slice(Expr x,                 //
+                   Array<Integer> axes,    //
+                   Expr begin,  //
+                   Expr end)    {
+  int n_axis = axes.size();
+
+  ObjectPtr<DataDependentStridedSliceAttrs> attrs = make_object<DataDependentStridedSliceAttrs>();
+  attrs->axes = std::move(axes);
+
+  static const Op& op = Op::Get("relax.data_dependent_strided_slice");
+  return Call(op, {std::move(x), std::move(begin), std::move(end)}, Attrs(attrs), {});
+=======
+/* relax.data_dependent_strided_slice */
+// TODO: support dynamic number of axes
+TVM_REGISTER_NODE_TYPE(DataDependentStridedSliceAttrs);
+
+Expr data_dependent_strided_slice(Expr x,                 //
+                   Array<Integer> axes,    //
+                   Expr begin,  //
+                   Expr end)    {
+  ObjectPtr<DataDependentStridedSliceAttrs> attrs = make_object<DataDependentStridedSliceAttrs>();
+  attrs->axes = std::move(axes);
+
+  static const Op& op = Op::Get("relax.data_dependent_strided_slice");
+  return Call(op, {std::move(x), std::move(begin), std::move(end)}, Attrs(attrs), {});
+>>>>>>> f322af441 (trying data dependent dynamic slice)
 }
 
 TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice);

Reply via email to