This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bf4b8f5c76 split test_forward_math_api function (#11537)
bf4b8f5c76 is described below
commit bf4b8f5c766be8320df8d792a8c063b7b42c69f5
Author: heliqi <[email protected]>
AuthorDate: Sun Jun 5 22:11:45 2022 -0500
split test_forward_math_api function (#11537)
---
tests/python/frontend/paddlepaddle/test_forward.py | 231 +++++++++++++++++----
1 file changed, 187 insertions(+), 44 deletions(-)
diff --git a/tests/python/frontend/paddlepaddle/test_forward.py
b/tests/python/frontend/paddlepaddle/test_forward.py
index 56ec3a4e54..8b696404e2 100644
--- a/tests/python/frontend/paddlepaddle/test_forward.py
+++ b/tests/python/frontend/paddlepaddle/test_forward.py
@@ -1358,7 +1358,9 @@ def test_forward_slice():
@tvm.testing.uses_gpu
-def test_forward_math_api():
+def run_math_api(api_name):
+ """Run verify_model on the math API named `api_name` for several input shapes."""
+
class MathAPI(nn.Layer):
def __init__(self, api_name):
super(MathAPI, self).__init__()
@@ -1371,52 +1373,193 @@ def test_forward_math_api():
def forward(self, inputs):
return self.func(inputs)
- api_list = [
- "abs",
- "acos",
- "asin",
- "atan",
- "ceil",
- "cos",
- "cosh",
- "elu",
- "erf",
- "exp",
- "floor",
- "hardshrink",
- "hardtanh",
- "log_sigmoid",
- "log_softmax",
- "log",
- "log2",
- "log10",
- "log1p",
- "reciprocal",
- "relu",
- "relu6",
- "round",
- "rsqrt",
- "selu",
- "sigmoid",
- "sign",
- "sin",
- "sinh",
- "softplus",
- "softsign",
- "sqrt",
- "square",
- "swish",
- "tan",
- "tanh",
- ]
input_shapes = [[128], [2, 100], [10, 2, 5], [7, 3, 4, 1]]
for input_shape in input_shapes:
input_data = paddle.rand(input_shape, dtype="float32")
- for api_name in api_list:
- if api_name in ["log", "log2", "log10", "reciprocal", "sqrt",
"rsqrt"]:
- # avoid illegal input, all elements should be positive
- input_data = paddle.uniform(input_shape, min=0.01, max=0.99)
- verify_model(MathAPI(api_name), input_data=input_data)
+ if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]:
+ # avoid illegal input, all elements should be positive
+ input_data = paddle.uniform(input_shape, min=0.01, max=0.99)
+ verify_model(MathAPI(api_name), input_data=input_data)
+
+
+@tvm.testing.uses_gpu
+def test_forward_abs():
+ run_math_api("abs")
+
+
+@tvm.testing.uses_gpu
+def test_forward_acos():
+ run_math_api("acos")
+
+
+@tvm.testing.uses_gpu
+def test_forward_asin():
+ run_math_api("asin")
+
+
+@tvm.testing.uses_gpu
+def test_forward_atan():
+ run_math_api("atan")
+
+
+@tvm.testing.uses_gpu
+def test_forward_ceil():
+ run_math_api("ceil")
+
+
+@tvm.testing.uses_gpu
+def test_forward_cos():
+ run_math_api("cos")
+
+
+@tvm.testing.uses_gpu
+def test_forward_cosh():
+ run_math_api("cosh")
+
+
+@tvm.testing.uses_gpu
+def test_forward_elu():
+ run_math_api("elu")
+
+
+@tvm.testing.uses_gpu
+def test_forward_erf():
+ run_math_api("erf")
+
+
+@tvm.testing.uses_gpu
+def test_forward_exp():
+ run_math_api("exp")
+
+
+@tvm.testing.uses_gpu
+def test_forward_floor():
+ run_math_api("floor")
+
+
+@tvm.testing.uses_gpu
+def test_forward_hardshrink():
+ run_math_api("hardshrink")
+
+
+@tvm.testing.uses_gpu
+def test_forward_hardtanh():
+ run_math_api("hardtanh")
+
+
+@tvm.testing.uses_gpu
+def test_forward_log_sigmoid():
+ run_math_api("log_sigmoid")
+
+
+@tvm.testing.uses_gpu
+def test_forward_log_softmax():
+ run_math_api("log_softmax")
+
+
+@tvm.testing.uses_gpu
+def test_forward_log():
+ run_math_api("log")
+
+
+@tvm.testing.uses_gpu
+def test_forward_log2():
+ run_math_api("log2")
+
+
+@tvm.testing.uses_gpu
+def test_forward_log10():
+ run_math_api("log10")
+
+
+@tvm.testing.uses_gpu
+def test_forward_log1p():
+ run_math_api("log1p")
+
+
+@tvm.testing.uses_gpu
+def test_forward_reciprocal():
+ run_math_api("reciprocal")
+
+
+@tvm.testing.uses_gpu
+def test_forward_relu():
+ run_math_api("relu")
+
+
+@tvm.testing.uses_gpu
+def test_forward_relu6():
+ run_math_api("relu6")
+
+
+@tvm.testing.uses_gpu
+def test_forward_round():
+ run_math_api("round")
+
+
+@tvm.testing.uses_gpu
+def test_forward_rsqrt():
+ run_math_api("rsqrt")
+
+
+@tvm.testing.uses_gpu
+def test_forward_selu():
+ run_math_api("selu")
+
+
+@tvm.testing.uses_gpu
+def test_forward_sigmoid():
+ run_math_api("sigmoid")
+
+
+@tvm.testing.uses_gpu
+def test_forward_sign():
+ run_math_api("sign")
+
+
+@tvm.testing.uses_gpu
+def test_forward_sin():
+ run_math_api("sin")
+
+
+@tvm.testing.uses_gpu
+def test_forward_sinh():
+ run_math_api("sinh")
+
+
+@tvm.testing.uses_gpu
+def test_forward_softplus():
+ run_math_api("softplus")
+
+
+@tvm.testing.uses_gpu
+def test_forward_softsign():
+ run_math_api("softsign")
+
+
+@tvm.testing.uses_gpu
+def test_forward_sqrt():
+ run_math_api("sqrt")
+
+
+@tvm.testing.uses_gpu
+def test_forward_square():
+ run_math_api("square")
+
+
+@tvm.testing.uses_gpu
+def test_forward_swish():
+ run_math_api("swish")
+
+
+@tvm.testing.uses_gpu
+def test_forward_tan():
+ run_math_api("tan")
+
+
+@tvm.testing.uses_gpu
+def test_forward_tanh():
+ run_math_api("tanh")
@tvm.testing.uses_gpu