This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e3c8f2bc2d [Frontend][Paddle]add take_along_axis and topk converter for paddle frontend (#14170)
e3c8f2bc2d is described below

commit e3c8f2bc2d9ab3f5ba1c0db0755e455bf59a39b8
Author: GaoYuYang <[email protected]>
AuthorDate: Sun Mar 12 22:23:13 2023 +0800

    [Frontend][Paddle]add take_along_axis and topk converter for paddle frontend (#14170)
    
    add take_along_axis and topk converter for paddle frontend
---
 python/tvm/relay/frontend/paddlepaddle.py          | 18 +++++++++++++++--
 tests/python/frontend/paddlepaddle/test_forward.py | 23 ++++++++++++++++++++++
 2 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py
index 4b849987ed..3c6429246a 100755
--- a/python/tvm/relay/frontend/paddlepaddle.py
+++ b/python/tvm/relay/frontend/paddlepaddle.py
@@ -2071,6 +2071,14 @@ def convert_swish(g, op, block):
     g.add_node(op.output("Out")[0], out)
 
 
+def convert_take_along_axis(g, op, block):
+    """Operator converter for take_along_axis (lowered to relay gather on attr Axis)."""
+    x = g.get_node(op.input("Input")[0])
+    idx = g.get_node(op.input("Index")[0])
+    out = _op.gather(x, op.attr("Axis"), idx)
+    g.add_node(op.output("Result")[0], out)
+
+
 def convert_tile(g, op, block):
     """Operator converter for tile."""
 
@@ -2111,9 +2119,13 @@ def convert_topk(g, op, block):
     else:
         k = op.attr("k")
 
-    largest = op.attr("largest")
+    largest = True
+    axis = -1
+    if op.has_attr("axis"):
+        axis = op.attr("axis")
+    if op.has_attr("largest"):
+        largest = op.attr("largest")
     is_ascend = not largest
-    axis = op.attr("axis")
 
     value_names = op.output("Out")
     indice_names = op.output("Indices")
@@ -2317,8 +2329,10 @@ _convert_map = {
     "square": convert_square,
     "squeeze2": convert_squeeze,
     "swish": convert_swish,
+    "take_along_axis": convert_take_along_axis,
     "tan": convert_unary_op,
     "tanh": convert_unary_op,
+    "top_k": convert_topk,
     "tile": convert_tile,
     "top_k_v2": convert_topk,
     "transpose2": convert_transpose,
diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py
index d21323d7ba..392db76942 100755
--- a/tests/python/frontend/paddlepaddle/test_forward.py
+++ b/tests/python/frontend/paddlepaddle/test_forward.py
@@ -1738,13 +1738,25 @@ def test_forward_topk():
     def topk6(inputs):
         return paddle.topk(inputs, k=1, axis=0)
 
+    # paddle.fluid.layers.topk
+    @paddle.jit.to_static
+    def topk7(inputs):
+        return paddle.fluid.layers.topk(inputs, k=1)
+
+    @paddle.jit.to_static
+    def topk8(inputs):
+        return paddle.fluid.layers.topk(inputs, k=2)
+
     input_data = paddle.to_tensor([[1, 4, 5, 7], [3, 6, 2, 5]], 
dtype=paddle.int32)
+    input_data_fp32 = paddle.to_tensor([[1, 4, 5, 7], [3, 6, 2, 5]], 
dtype=paddle.float32)
     verify_model(topk1, input_data=input_data)
     # verify_model(topk2, input_data=input_data)
     verify_model(topk3, input_data=input_data)
     verify_model(topk4, input_data=input_data)
     verify_model(topk5, input_data=input_data)
     verify_model(topk6, input_data=input_data)
+    verify_model(topk7, input_data=input_data_fp32)
+    verify_model(topk8, input_data=input_data_fp32)
 
 
 @tvm.testing.uses_gpu
@@ -1783,6 +1795,17 @@ def test_forward_where_index():
     verify_model(where_index_1, input_data=input_data, use_vm=True)
 
 
+@tvm.testing.uses_gpu
+def test_forward_take_along_axis():
+    @paddle.jit.to_static
+    def take_along_axis_1(inputs, index):
+        return paddle.take_along_axis(inputs, index, 0)
+
+    input_data = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    index = paddle.to_tensor([[0]])
+    verify_model(take_along_axis_1, input_data=[input_data, index])
+
+
 @tvm.testing.uses_gpu
 def test_forward_stack():
     class Stack1(nn.Layer):

Reply via email to