jiangjiajun commented on a change in pull request #9126:
URL: https://github.com/apache/tvm/pull/9126#discussion_r717359304



##########
File path: python/tvm/relay/frontend/paddlepaddle.py
##########
@@ -40,28 +41,108 @@
 __all__ = ["from_paddle"]
 
 
+def _get_pad_size(in_size, dilated_kernel_size, stride_size):
+    """calculate the paddings size"""

Review comment:
       Done

##########
File path: python/tvm/relay/frontend/paddlepaddle.py
##########
@@ -40,28 +41,108 @@
 __all__ = ["from_paddle"]
 
 
+def _get_pad_size(in_size, dilated_kernel_size, stride_size):
+    """calculate the paddings size"""
+
+    if stride_size == 1 or in_size % stride_size == 0:
+        pad = max(dilated_kernel_size - stride_size, 0)
+    else:
+        pad = max(dilated_kernel_size - (in_size % stride_size), 0)
+
+    pad_before = pad // 2
+    pad_after = pad - pad_before
+
+    return [pad_before, pad_after]
+
+
+def _dtype_shape_promotion(inputs):
+    """promote data type and shape for list of tensors."""
+
+    dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"]
+
+    ranks = [len(infer_shape(x)) for x in inputs]
+    if set(ranks) == set([1, 0]):
+        for i, r in enumerate(ranks):
+            if r == 0:
+                inputs[i] = _op.expand_dims(inputs[i], axis=0)
+
+    dtypes = set(dtype_order.index(infer_type(x).checked_type.dtype) for x in inputs)
+    if len(dtypes) == 1:
+        return inputs
+    max_dtype = dtype_order[max(dtypes)]
+    for i, input_op in enumerate(inputs):
+        if infer_type(input_op).checked_type.dtype != max_dtype:
+            inputs[i] = input_op.astype(max_dtype)
+    return inputs
+
+
 def shape_of(x, dtype="int32"):
     """Get shape of a tensor"""
 
     ttype = infer_type(x).checked_type
     if not _ty.is_dynamic(ttype):
         shape = list(ttype.shape)
-        return _expr.const(shape, dtype)
+        return _expr.const(np.array(shape), dtype)
     return _op.shape_of(x, dtype)
 
 
-def _get_pad_size(in_size, dilated_kernel_size, stride_size):
-    """calculate the paddings size"""
+def _infer_value(x, params):
+    """Try running infer_value, and if successful, return the inferred value.
+    Otherwise, return input"""
 
-    if stride_size == 1 or in_size % stride_size == 0:
-        pad = max(dilated_kernel_size - stride_size, 0)
+    try:
+        value = infer_value(x, params)
+        return value.numpy().tolist()
+    except Exception:  # pylint: disable=broad-except
+        return x
+
+
+def _convert_dtype_value(val):
+    """converts a Paddle type id to a string."""
+
+    convert_dtype_map = {
+        21: "int8",
+        20: "uint8",
+        6: "float64",
+        5: "float32",
+        4: "float16",
+        3: "int64",
+        2: "int32",
+        1: "int16",
+        0: "bool",
+    }
+    if val not in convert_dtype_map:
+        msg = "Paddle data type value %d is not handled yet." % (val)
+        raise NotImplementedError(msg)
+    return convert_dtype_map[val]
+
+
+def convert_unary_op(g, op, block):
+    """Operator converter for all the unary operators."""
+
+    # op_map stores mapping relationship between paddlepaddle and relay
+    op_map = {
+        "isinf_v2": _op.isinf,
+        "isfinite_v2": _op.isfinite,
+        "isnan_v2": _op.isnan,
+    }
+    if op.type in op_map:
+        unary_func = op_map[op.type]
     else:
-        pad = max(dilated_kernel_size - (in_size % stride_size), 0)
+        # while paddle operator's name is same with relay
+        unary_func = get_relay_op(op.type)
+    out = unary_func(g.get_node(op.input("X")[0]))
+    g.add_node(op.output("Out")[0], out)
 
-    pad_before = pad // 2
-    pad_after = pad - pad_before
 
-    return [pad_before, pad_after]
+def convert_binary_logical_op(g, op, block):
+    """Operator converter for logical op."""
+
+    ipt0 = g.get_node(op.input("X")[0])
+    ipt1 = g.get_node(op.input("Y")[0])
+    op_func = get_relay_op(op.type)
+    out = op_func(ipt0, ipt1)
+    g.add_node(op.output("Out")[0], out)
 
 
 def convert_arg_max(g, op, block):

Review comment:
       Done

##########
File path: python/tvm/relay/frontend/paddlepaddle.py
##########
@@ -70,20 +151,73 @@ def convert_arg_max(g, op, block):
     axis = op.attr("axis")
     keepdims = op.attr("keepdims")
     flatten = op.attr("flatten")
+    dtype = op.attr("dtype")
+    dtype = _convert_dtype_value(dtype)
 
     x = g.get_node(op.input("X")[0])
     if axis is None or flatten:
         x = _op.reshape(x, [-1])
         out = _op.argmax(x, axis=None, keepdims=True)
     else:
         out = _op.argmax(x, axis=axis, keepdims=keepdims)
+    if dtype != infer_type(out).checked_type.dtype:
+        out = _op.cast(out, dtype)
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_arg_min(g, op, block):
+    """Operator converter for arg_min."""
+
+    axis = op.attr("axis")
+    keepdims = op.attr("keepdims")
+    flatten = op.attr("flatten")
+    dtype = op.attr("dtype")
+    dtype = _convert_dtype_value(dtype)
+
+    x = g.get_node(op.input("X")[0])
+    if axis is None or flatten:
+        x = _op.reshape(x, [-1])
+        out = _op.argmin(x, axis=None, keepdims=True)
+    else:
+        out = _op.argmin(x, axis=axis, keepdims=keepdims)
+    if dtype != infer_type(out).checked_type.dtype:
+        out = _op.cast(out, dtype)
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_argsort(g, op, block):
+    """Operator converter for argsort."""
+
+    x = g.get_node(op.input("X")[0])
+    axis = op.attr("axis")
+    descending = op.attr("descending")
+
+    out = _op.sort(x, axis, not descending)

Review comment:
       Done

##########
File path: python/tvm/relay/frontend/paddlepaddle.py
##########
@@ -70,20 +151,73 @@ def convert_arg_max(g, op, block):
     axis = op.attr("axis")
     keepdims = op.attr("keepdims")
     flatten = op.attr("flatten")
+    dtype = op.attr("dtype")
+    dtype = _convert_dtype_value(dtype)
 
     x = g.get_node(op.input("X")[0])
     if axis is None or flatten:
         x = _op.reshape(x, [-1])
         out = _op.argmax(x, axis=None, keepdims=True)
     else:
         out = _op.argmax(x, axis=axis, keepdims=keepdims)
+    if dtype != infer_type(out).checked_type.dtype:
+        out = _op.cast(out, dtype)
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_arg_min(g, op, block):
+    """Operator converter for arg_min."""
+
+    axis = op.attr("axis")
+    keepdims = op.attr("keepdims")
+    flatten = op.attr("flatten")
+    dtype = op.attr("dtype")
+    dtype = _convert_dtype_value(dtype)
+
+    x = g.get_node(op.input("X")[0])
+    if axis is None or flatten:
+        x = _op.reshape(x, [-1])
+        out = _op.argmin(x, axis=None, keepdims=True)
+    else:
+        out = _op.argmin(x, axis=axis, keepdims=keepdims)
+    if dtype != infer_type(out).checked_type.dtype:
+        out = _op.cast(out, dtype)
+    g.add_node(op.output("Out")[0], out)
+
+
+def convert_argsort(g, op, block):
+    """Operator converter for argsort."""
+
+    x = g.get_node(op.input("X")[0])
+    axis = op.attr("axis")
+    descending = op.attr("descending")
+
+    out = _op.sort(x, axis, not descending)
+    out_indice = _op.argsort(x, axis, not descending, dtype="int64")

Review comment:
       Done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to