This is an automated email from the ASF dual-hosted git repository.

tkonolige pushed a commit to branch tkonolige/relax_pad_etc_new
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit a54064659bd6e1239ec612864130428f78503bc3
Author: Tristan Konolige <[email protected]>
AuthorDate: Fri May 12 21:48:50 2023 +0000

    working
---
 include/tvm/relax/attrs/index.h                  |   8 +
 python/tvm/relax/expr.py                         |   2 +-
 python/tvm/relax/frontend/torch/fx_translator.py | 204 +++++++++++++++++++++--
 src/relax/op/tensor/index.cc                     |   1 -
 4 files changed, 203 insertions(+), 12 deletions(-)

diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h
index c95395a803..e31a334bf3 100644
--- a/include/tvm/relax/attrs/index.h
+++ b/include/tvm/relax/attrs/index.h
@@ -56,6 +56,14 @@ struct StridedSliceAttrs : public 
tvm::AttrsNode<StridedSliceAttrs> {
   }
 };  // struct StridedSliceAttrs
 
+/*! \brief Attributes used in the data-dependent strided_slice operator */
+struct DataDependentStridedSliceAttrs : public 
tvm::AttrsNode<DataDependentStridedSliceAttrs> {
+  Array<Integer> axes;
+  TVM_DECLARE_ATTRS(DataDependentStridedSliceAttrs, 
"relax.attrs.DataDependentStridedSliceAttrs") {
+    TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied.");
+  }
+};
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index fdf98c179b..d3550741f7 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -652,7 +652,7 @@ def const(
         value = _nd.array(value)
 
     if not isinstance(value, _nd.NDArray):
-        raise ValueError("value has to be scalar or NDArray")
+        raise ValueError(f"value has to be scalar or NDArray but it is 
{type(value)}")
 
     return Constant(value)
 
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index 7166b312a0..2d94b246de 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -23,7 +23,105 @@ from functools import reduce
 
 import tvm
 from tvm import relax
-
+import numpy as np
+
+def _pytorch_result_type(dtypes, non_tensor_inputs):
+    """This promotes TVM dtypes as PyTorch would"""
+    import torch
+
+    dtype_map = {
+        "float64": torch.float64,
+        "float32": torch.float32,
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "int64": torch.int64,
+        "int32": torch.int32,
+        "int16": torch.int16,
+        "int8": torch.int8,
+        "uint8": torch.uint8,
+        "bool": torch.bool,
+    }
+    if len(dtypes) > 0:
+        result_type = dtypes[0]
+        for dt in dtypes[1:]:
+            if dt != result_type:  # we don't want to work with same types as 
we
+                # don't do quantized here (which cannot be promoted?)
+                result_type = _convert_data_type(
+                    str(
+                        torch.result_type(
+                            torch.zeros((), dtype=dtype_map[result_type]),
+                            torch.zeros((), dtype=dtype_map[dt]),
+                        )
+                    )
+                )
+    else:
+        result_type = "bool"  # this is the smallest type...
+    for inp in non_tensor_inputs:
+        result_type = _convert_data_type(
+            str(torch.result_type(torch.zeros((), 
dtype=dtype_map[result_type]), inp))
+        )
+    return result_type
+
+
+# Helper functions for operator implementation
+def _convert_dtype_value(val):
+    """converts a PyTorch numeric type id to a torch scalar 
type."""
+    convert_torch_dtype_map = {
+        11: "torch.bool",
+        7: "torch.float64",
+        6: "torch.float32",
+        5: "torch.float16",
+        4: "torch.int64",
+        3: "torch.int32",
+        2: "torch.int16",
+        1: "torch.int8",
+        0: "torch.uint8",
+        None: "torch.int64",
+    }  # Default is torch.int64
+    if val in convert_torch_dtype_map:
+        return _convert_data_type(convert_torch_dtype_map[val])
+    else:
+        msg = "Torch data type value %d is not handled yet." % (val)
+        raise NotImplementedError(msg)
+
+
+def _convert_data_type(input_type, default_dtype=None):
+    """converts the PyTorch scalar type input_type to a TVM dtype.
+    optionally, default_dtype can be a TVM dtype that is used
+    if input_type is None (but not when it is unknown)"""
+    if input_type is None and default_dtype is not None:
+        return default_dtype
+
+    input_type = input_type.lower()
+    if input_type in ["double", "float64", "torch.float64"]:
+        return "float64"
+    elif input_type in ["float", "float32", "torch.float32"]:
+        return "float32"
+    elif input_type in ["half", "float16", "torch.float16"]:
+        return "float16"
+    elif input_type in ["long", "int64", "torch.int64"]:
+        return "int64"
+    elif input_type in ["int", "int32", "torch.int32"]:
+        return "int32"
+    elif input_type in ["short", "int16", "torch.int16"]:
+        return "int16"
+    elif input_type in ["char", "int8", "torch.int8"]:
+        return "int8"
+    elif input_type in ["byte", "uint8", "torch.uint8"]:
+        return "uint8"
+    elif input_type in ["quint8", "torch.quint8"]:
+        return "quint8"
+    elif input_type in ["qint8", "torch.qint8"]:
+        return "qint8"
+    elif input_type in ["qint32", "torch.qint32"]:
+        return "qint32"
+    elif input_type in ["bool", "torch.bool"]:
+        return "bool"
+    elif input_type in ["str"]:
+        return "str"
+    else:
+        raise NotImplementedError("input_type {} is not handled 
yet".format(input_type))
+    return "float32"  # Never reached
 
 class TorchFXImporter:
     """An importer from PyTorch FX to Relax."""
@@ -86,6 +184,23 @@ class TorchFXImporter:
         dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype))
         return relax.const(tensor.data.numpy(), dtype)
 
+    @staticmethod
+    def _promote_types(inputs):
+        """This promotes TVM inputs (carrying TVM dtypes) as PyTorch 
would"""
+        dtypes = [inp.struct_info.dtype for inp in inputs]
+        tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not 
np.isscalar(inp)]
+        non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)]
+        result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs)
+        results = []
+        for inp, dt in zip(inputs, dtypes):
+            if np.isscalar(inp):
+                results.append(relax.const(inp, dtype=result_type))
+            elif dt == result_type:
+                results.append(inp)
+            else:
+                results.append(relax.op.astype(inp, result_type))
+        return results
+
     @staticmethod
     def shape_of(tensor):
         """Get the shape of a tensor."""
@@ -119,7 +234,8 @@ class TorchFXImporter:
     @staticmethod
     def _promote_binary_op_args(lhs, rhs):
         if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
-            return lhs, rhs
+            lhs_, rhs_ = TorchFXImporter._promote_types((lhs, rhs))
+            return lhs_, rhs_
         elif isinstance(lhs, relax.Expr):
             assert isinstance(lhs.struct_info, relax.TensorStructInfo)
             return lhs, relax.const(rhs, lhs.struct_info.dtype)
@@ -151,6 +267,8 @@ class TorchFXImporter:
         arg = self.env[node.args[0]]
         if isinstance(arg, (int, float)):
             arg = relax.const(arg, "float32")
+        if isinstance(arg, (tvm.tir.FloatImm, tvm.tir.IntImm)):
+            arg = relax.const(arg.value, "float32")
         return self.block_builder.emit(relax.op.sqrt(arg))
 
     def _rsqrt(self, node: fx.node.Node) -> relax.Expr:
@@ -180,9 +298,8 @@ class TorchFXImporter:
         return lhs + rhs
 
     def _max(self, node: fx.node.Node) -> relax.Expr:
-        lhs, rhs = self.retrieve_args(node)
-        if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
-            return self._call_binary_op(relax.op.maximum, lhs, rhs)
+        lhs, = self.retrieve_args(node)
+        return self.block_builder.emit(relax.op.max(lhs))
 
     def _floordiv(self, node: fx.node.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
@@ -190,6 +307,10 @@ class TorchFXImporter:
             return self._call_binary_op(relax.op.floor_divide, lhs, rhs)
         return lhs // rhs
 
+    def _floor(self, node: fx.node.Node) -> relax.Expr:
+        lhs, = self.retrieve_args(node)
+        return self.block_builder.emit(relax.op.floor(lhs))
+
     def _mul(self, node: fx.node.Node) -> relax.Expr:
         lhs, rhs = self.retrieve_args(node)
         if isinstance(lhs, relax.Var) or isinstance(rhs, relax.Var):
@@ -248,6 +369,10 @@ class TorchFXImporter:
         lhs, rhs = self.retrieve_args(node)
         return self._call_binary_op(relax.op.equal, lhs, rhs)
 
+    def _ne(self, node: fx.node.Node) -> relax.Expr:
+        lhs, rhs = self.retrieve_args(node)
+        return self._call_binary_op(relax.op.not_equal, lhs, rhs)
+
     ########## Creation ##########
 
     def _arange(self, node: fx.node.Node) -> relax.Var:
@@ -501,7 +626,13 @@ class TorchFXImporter:
 
     def _expand(self, node: fx.node.Node) -> relax.Var:
         args = self.retrieve_args(node)
-        return self.block_builder.emit(relax.op.broadcast_to(args[0], 
args[1:]))
+        # -1 indicates dimension remains the same
+        fixed_dims = [args[0].struct_info.shape[i] if d == -1 else d for i, d 
in enumerate(args[1:])]
+        return self.block_builder.emit(relax.op.broadcast_to(args[0], 
fixed_dims))
+
+    def _expand_as(self, node: fx.node.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        return self.block_builder.emit(relax.op.broadcast_to(args[0], 
self.shape_of(args[1])))
 
     def _flatten(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
@@ -574,6 +705,19 @@ class TorchFXImporter:
             dim = None
         return self.block_builder.emit(relax.op.squeeze(x, dim))
 
+    def _repeat_interleave(self, node: fx.node.Node) -> relax.Var:
+        data, repeats = self.retrieve_args(node)
+        if not "dim" in node.kwargs:
+            data = relax.op.reshape(args[0], [-1])
+            axis = 0
+        else:
+            axis = node.kwargs["dim"]
+        return self.block_builder.emit(relax.op.repeat(data, 
repeats=repeats.value, axis=axis))
+
+    def _stack(self, node: fx.node.Node) -> relax.Var:
+        args = [self.env[x] for x in node.args[0]]
+        return self.block_builder.emit(relax.op.concat(args, 
node.kwargs["dim"]))
+
     def _cumsum(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
 
@@ -711,6 +855,34 @@ class TorchFXImporter:
 
         return self.block_builder.emit(relax.op.add(conv2d, bias))
 
+    def _convtranpose2d(self, node: fx.node.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        weight = self.params[module.weight]
+
+        conv2d = self.block_builder.emit(
+            relax.op.nn.conv2d_transpose(
+                x,
+                weight,
+                strides=module.stride,
+                padding=module.padding,
+                dilation=module.dilation,
+                groups=module.groups,
+                data_layout="NCHW",
+                kernel_layout="IOHW",
+                out_dtype="float32",
+            )
+        )
+
+        if module.bias is None:
+            return conv2d
+
+        bias = self.params[module.bias]
+        assert len(self.shape_of(bias)) == 1
+        bias = relax.op.reshape(bias, (1, -1, 1, 1))
+
+        return self.block_builder.emit(relax.op.add(conv2d, bias))
+
     def _max_pool2d(self, node: fx.node.Node) -> relax.Var:
         x = self.env[node.args[0]]
         if node.target in self.named_modules:
@@ -1135,10 +1307,11 @@ class TorchFXImporter:
             i = 0
             shape = self.shape_of(x)
             non_ellipsis_cnt = 0
-            for index in node.args[1]:
+            idxs = [node.args[1]] if isinstance(node.args[1], int) else 
node.args[1]
+            for index in idxs:
                 if isinstance(index, (int, slice)):
                     non_ellipsis_cnt += 1
-            for index in node.args[1]:
+            for index in idxs:
                 if isinstance(index, int):
                     begin.append(index)
                     end.append(index + 1)
@@ -1147,7 +1320,12 @@ class TorchFXImporter:
                     i = i + 1
                 elif isinstance(index, slice):
                     begin.append(0 if index.start is None else index.start)
-                    end.append(shape[i] if index.stop is None else index.stop)
+                    if index.stop is None:
+                        end.append(shape[i])
+                    elif isinstance(index.stop, fx.node.Node):
+                        end.append(self.env[index.stop])
+                    else:
+                        end.append(index.stop)
                     stride.append(1 if index.step is None else index.step)
                     axes.append(i)
                     i = i + 1
@@ -1168,6 +1346,7 @@ class TorchFXImporter:
                 stride.append(1)
                 axes.append(i)
                 i += 1
+            print([type(x) for x in axes], [type(x) for x in begin], [type(x) 
for x in end], [type(x) for x in stride])
             sliced = self.block_builder.emit(relax.op.strided_slice(x, axes, 
begin, end, stride))
             sliced_shape = list(self.shape_of(sliced))
             for i in expand_dim:
@@ -1181,7 +1360,6 @@ class TorchFXImporter:
                 dtype = x.struct_info.dtype
                 return relax.const(x.data.numpy()[node.args[1]], dtype)
         else:
-            import IPython;IPython.embed()
             raise ValueError(f"Unsupported type {type(x)} for _getitem, should 
be list, tuple, ShapeExpr, Tuple, Var, or Constant")
 
     def create_convert_map(self):
@@ -1193,6 +1371,7 @@ class TorchFXImporter:
             nn.Linear: self._linear,
             nn.Conv1d: self._conv1d,
             nn.Conv2d: self._conv2d,
+            nn.ConvTranspose2d: self._convtranpose2d,
             nn.MaxPool2d: self._max_pool2d,
             nn.AvgPool2d: self._avg_pool2d,
             nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d(is_module=True),
@@ -1218,6 +1397,7 @@ class TorchFXImporter:
             "iadd": self._add,
             "add": self._add,
             "floordiv": self._floordiv,
+            "floor": self._floor,
             "mul": self._mul,
             "sub": self._sub,
             "pow": self._pow,
@@ -1226,6 +1406,7 @@ class TorchFXImporter:
             "round": self._round,
             "lt": self._lt,
             "eq": self._eq,
+            "ne": self._ne,
             "truediv": self._truediv,
             "fill_": self._inplace_fill,
             "new_ones": self._new_ones,
@@ -1265,6 +1446,8 @@ class TorchFXImporter:
             "unsqueeze": lambda node: self.block_builder.emit(
                 relax.op.expand_dims(self.env[node.args[0]], node.args[1])
             ),
+            "repeat_interleave": self._repeat_interleave,
+            "stack": self._stack,
             "view": self._reshape,
             "argmax": self._argmax_argmin(relax.op.argmax),
             "argmin": self._argmax_argmin(relax.op.argmin),
@@ -1298,6 +1481,7 @@ class TorchFXImporter:
             "pad": self._pad,
             "unbind": self._unbind,
             "einsum": self._einsum,
+            "expand_as": self._expand_as,
         }
 
     def from_fx(
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index c3d38db4e1..d3bb34d21a 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -312,6 +312,5 @@ TVM_REGISTER_OP("relax.dynamic_strided_slice")
     .add_argument("end", "Tensor", "Indices indicating end of the slice.")
     .add_argument("strides", "Tensor", "The stride values.")
     .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoDynStridedSlice);
-
 }  // namespace relax
 }  // namespace tvm

Reply via email to