This is an automated email from the ASF dual-hosted git repository.

tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit cb458f833253b319122e9618e13e58c109f8b4e2
Author: Tristan Konolige <[email protected]>
AuthorDate: Mon May 15 18:00:50 2023 +0000

    data dependent strided slice
---
 python/tvm/relax/frontend/torch/fx_translator.py | 29 ++++++++++++++++-----
 src/relax/op/tensor/index.cc                     | 32 ------------------------
 2 files changed, 23 insertions(+), 38 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 2af651bb97..bdf2c3375f 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -18,7 +18,7 @@
 # pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck
 # pylint: disable=import-outside-toplevel
 """PyTorch FX frontend of Relax."""
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union, Sequence
 from functools import reduce
 
 import tvm
@@ -123,6 +123,17 @@ def _convert_data_type(input_type, default_dtype=None):
         raise NotImplementedError("input_type {} is not handled yet".format(input_type))
     return "float32"  # Never reached
 
+def _constify_array(ary: Sequence[Union[int, tvm.tir.IntImm, relax.Var]]) -> relax.Expr:
+    const = []
+    for i in ary:
+        if isinstance(i, int):
+            const.append(relax.const(tvm.nd.array([i])))
+        elif isinstance(i, tvm.tir.IntImm):
+            const.append(relax.const(tvm.nd.array([i.value])))
+        else:
+            const.append(i)
+    return relax.op.concat(const)
+
 class TorchFXImporter:
     """An importer from PyTorch FX to Relax."""
 
@@ -1141,6 +1152,11 @@ class TorchFXImporter:
             assert isinstance(shape, relax.ShapeExpr)
             size = tuple(int(shape[i].value * scale_factor) for i in range(2, len(shape)))
 
+        from torch import fx
+        if any([isinstance(s, fx.node.Node) for s in size]):
+            size = relax.op.concat([self.env[s] for s in size])
+            # import IPython; IPython.embed()
+
         if method.startswith("nearest"):
             method = "nearest_neighbor"
         elif method[0:2] == "bi":
@@ -1346,13 +1362,14 @@ class TorchFXImporter:
                 stride.append(1)
                 axes.append(i)
                 i += 1
-            print([type(x) for x in axes], [type(x) for x in begin], [type(x) for x in end], [type(x) for x in stride])
-            if any([isinstance(x, relax.Var) for x in begin]) or any([isinstance(x, relax.Var) for x in end]):
-                print(end)
-                sliced = self.block_builder.emit(relax.op.data_dependent_strided_slice(x, axes, begin, relax.const(np.array(end, dtype=int))))
+            # Handle case where slice is data dependent
+            if any([isinstance(y, relax.Var) for y in begin]) or any([isinstance(y, relax.Var) for y in end]) or any([isinstance(y, relax.Var) for y in stride]):
+                assert len(axes) == len(x.struct_info.shape), f"Dynamic strided slice must be provided with all the axes ({len(x.struct_info.shape)}) of the sliced tensor ({len(axes)} provided)"
+                sliced = self.block_builder.emit(relax.op.dynamic_strided_slice(x, _constify_array(begin), _constify_array(end), _constify_array(stride)))
+                sliced_shape = list(self.shape_of(x))
             else:
                 sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, begin, end, stride))
-            sliced_shape = list(self.shape_of(sliced))
+                sliced_shape = list(self.shape_of(sliced))
             for i in expand_dim:
                 sliced_shape.insert(i, 1)
             return self.block_builder.emit(relax.op.reshape(sliced, sliced_shape))
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index 6510967992..d3bb34d21a 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -239,7 +239,6 @@ TVM_REGISTER_OP("relax.strided_slice")
     .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
 
-<<<<<<< HEAD
 /* relax.dynamic_strided_slice */
 Expr dynamic_strided_slice(Expr x,      //
                            Expr begin,  //
@@ -247,37 +246,6 @@ Expr dynamic_strided_slice(Expr x,      //
                            Expr strides) {
   static const Op& op = Op::Get("relax.dynamic_strided_slice");
   return Call(op, {std::move(x), std::move(begin), std::move(end), std::move(strides)}, {});
-||||||| parent of f322af441 (trying data dependent dynamic slice)
-/* relax.data_dependent_strided_slice */
-// TODO: support dynamic number of axes
-TVM_REGISTER_NODE_TYPE(DataDependentStridedSliceAttrs);
-
-Expr data_dependent_strided_slice(Expr x,                 //
-                   Array<Integer> axes,    //
-                   Expr begin,  //
-                   Expr end)    {
-  int n_axis = axes.size();
-
-  ObjectPtr<DataDependentStridedSliceAttrs> attrs = make_object<DataDependentStridedSliceAttrs>();
-  attrs->axes = std::move(axes);
-
-  static const Op& op = Op::Get("relax.data_dependent_strided_slice");
-  return Call(op, {std::move(x), std::move(begin), std::move(end)}, Attrs(attrs), {});
-=======
-/* relax.data_dependent_strided_slice */
-// TODO: support dynamic number of axes
-TVM_REGISTER_NODE_TYPE(DataDependentStridedSliceAttrs);
-
-Expr data_dependent_strided_slice(Expr x,                 //
-                   Array<Integer> axes,    //
-                   Expr begin,  //
-                   Expr end)    {
-  ObjectPtr<DataDependentStridedSliceAttrs> attrs = make_object<DataDependentStridedSliceAttrs>();
-  attrs->axes = std::move(axes);
-
-  static const Op& op = Op::Get("relax.data_dependent_strided_slice");
-  return Call(op, {std::move(x), std::move(begin), std::move(end)}, Attrs(attrs), {});
->>>>>>> f322af441 (trying data dependent dynamic slice)
 }
 
 TVM_REGISTER_GLOBAL("relax.op.dynamic_strided_slice").set_body_typed(dynamic_strided_slice);

Reply via email to