masahi commented on a change in pull request #6449:
URL: https://github.com/apache/incubator-tvm/pull/6449#discussion_r489066784
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -2043,6 +2201,151 @@ def _impl(inputs, input_types):
return _impl
def _roi_align(prelude):
    """Return a converter for ``torchvision::roi_align``.

    The returned ``_impl(inputs, input_types)`` reads ``inputs`` as
    ``[data, boxes, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned]``
    (``aligned`` is optional and treated as False when absent) and lowers the
    call to ``relay.vision.roi_align``.

    ``prelude`` is accepted for signature parity with the other converters and
    is not used here.
    """

    def _impl(inputs, input_types):
        data = inputs[0]
        boxes = inputs[1]

        output_size = (inputs[3], inputs[4])
        spatial_scale = inputs[2]
        sample_ratio = inputs[5]
        aligned = False if len(inputs) < 7 else inputs[6]

        if aligned:
            # aligned=True shifts the roi coordinates by -0.5 pixel *after*
            # scaling; subtracting 0.5 / spatial_scale before scaling is the
            # same thing. relay's roi_align takes rois laid out as
            # [batch_index, x1, y1, x2, y2], so column 0 must not be shifted
            # (doing so would corrupt the batch index). The per-column offset
            # below broadcasts over every roi row.
            # NOTE(review): assumes boxes are float32 rois of shape (K, 5),
            # matching the float32 constant used before — confirm.
            offset = 0.5 / spatial_scale
            shift = np.array([0.0, offset, offset, offset, offset], dtype="float32")
            boxes = boxes - _expr.const(shift)

        return _op.vision.roi_align(data, boxes, output_size, spatial_scale, sample_ratio)

    return _impl
+
+
def _unbind():
    """Return a converter for ``aten::unbind``.

    The returned ``_impl(inputs, input_types)`` reads ``inputs`` as
    ``[data, dim]``. It splits ``data`` into ``shape[dim]`` pieces along
    ``dim``, squeezes ``dim`` out of each piece, and returns the pieces as a
    ``TupleWrapper``. Negative ``dim`` counts from the last axis, as in PyTorch.

    Raises AttributeError when ``dim`` is outside ``[-rank, rank)``.
    """

    def _impl(inputs, input_types):
        data = inputs[0]
        dim = int(inputs[1])
        ishapes = _infer_shape(data)
        rank = len(ishapes)
        # Normalize negative axes up front so split/squeeze and the bounds
        # check all see the same non-negative axis.
        axis = dim + rank if dim < 0 else dim
        if not 0 <= axis < rank:
            msg = "Please check input dim {}, it should be in the range [{}, {}).".format(
                dim, -rank, rank
            )
            raise AttributeError(msg)

        selections = ishapes[axis]
        res_split = _op.split(data, selections, axis)
        # squeeze each split piece to get same shape as aten::unbind
        # TODO (yongwww): add new op to avoid the squeeze overhead
        pieces = [_op.transform.squeeze(res_split[i], axis=[axis]) for i in range(selections)]
        return _expr.TupleWrapper(_expr.Tuple(pieces), selections)

    return _impl
+
+
def _shape_as_tensor(prelude):
    """Return a converter for ``aten::_shape_as_tensor``.

    The shape of ``inputs[0]`` is inferred against ``prelude.mod``. When every
    dimension is a static integer, the shape is folded into an int64 constant;
    otherwise a runtime ``shape_of`` (int64) is emitted.
    """

    def _impl(inputs, input_types):
        tensor = inputs[0]
        dims = _infer_shape(tensor, prelude.mod)
        all_static = all(isinstance(extent, (int, tvm.tir.IntImm)) for extent in dims)
        if not all_static:
            # At least one dimension is symbolic: compute the shape at runtime.
            return _op.shape_of(tensor, dtype="int64")
        return _expr.const(np.array(dims), dtype="int64")

    return _impl
+
+
def _logical_and():
    """Return a converter for ``aten::logical_and``.

    Both operands are cast to ``bool`` before ``logical_and`` is applied, so
    non-boolean inputs are accepted.
    """

    def _impl(inputs, input_types):
        left, right = _op.cast(inputs[0], "bool"), _op.cast(inputs[1], "bool")
        return _op.logical_and(left, right)

    return _impl
+
+
def _nonzero(is_numpy_style):
    """Return a converter for ``aten::nonzero`` and its numpy-style variant.

    The input is lowered to ``argwhere``. Tuple-style output is not supported
    yet and raises RuntimeError. Tuple output is requested either by
    ``is_numpy_style`` or by a truthy second input (``as_tuple``).
    """

    def _impl(inputs, input_types):
        indices = _op.transform.argwhere(inputs[0])

        wants_tuple = is_numpy_style or (len(inputs) > 1 and inputs[1])
        if wants_tuple:
            # TODO(kevinthesun): Support this by adding unbind op
            raise RuntimeError("as_tuple is not supported yet for nonzero.")
        return indices

    return _impl
+
+
def _scatter():
    """Return a converter for ``aten::scatter``.

    The returned ``_impl(inputs, input_types)`` reads ``inputs`` as
    ``[data, dim, index, src]`` and forwards them to ``relay.scatter``, with
    ``dim`` converted to a Python int.
    """

    def _impl(inputs, input_types):
        data, axis = inputs[0], int(inputs[1])
        index, src = inputs[2], inputs[3]
        return _op.transform.scatter(data, index, src, axis)

    return _impl
+
+
+def _scalar_tensor():
+ def _impl(inputs, input_types):
+ data = inputs[0]
+ cast_map = {
+ 6: "float32",
+ 7: "float64",
+ 3: "int32",
+ 4: "int64",
+ }
+ type_key = inputs[1]
+ if isinstance(data, _expr.Constant):
+ data = data.data.asnumpy().tolist()
+ return _expr.const(data, cast_map[type_key])
+
+ return _impl
+
+
+def _interpolate():
+ def _impl(inputs, input_types):
+ if isinstance(inputs[1], _expr.Expr):
+ out_size = inputs[1]
+ elif isinstance(inputs[1], list):
+ try:
+ infer_res = [_infer_value(size, {}) for size in inputs[1]]
+ out_size = [np.asscalar(res.asnumpy().astype(np.int)) for res
in infer_res]
+ except Exception:
+ h = _op.expand_dims(inputs[1][0], axis=0)
+ w = _op.expand_dims(inputs[1][1], axis=0)
+ out_size = _op.concatenate([h, w], axis=0)
+
+ data = inputs[0]
+ align_corners = inputs[4]
+ method = inputs[3]
+ if method.startswith("nearest"):
+ method = "nearest_neighbor"
+
+ if method == "nearest_neighbor":
+ coord_trans = "asymmetric"
+ elif align_corners:
+ coord_trans = "align_corners"
+ else:
+ coord_trans = "half_pixel"
+
+ def func(x):
Review comment:
remove this `func`
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]