quic-sanirudh commented on a change in pull request #9017:
URL: https://github.com/apache/tvm/pull/9017#discussion_r710050671



##########
File path: tests/python/frontend/onnx/test_forward.py
##########
@@ -3056,6 +3056,152 @@ def verify_global_pooling(x_shape, mode):
         verify_global_pooling([4, 1, 2, 6, 4], mode)
 
 
+@tvm.testing.parametrize_targets
+def test_qlinear_average_pool(target, dev):
+    def verify_qlinear_average_pool(
+        x_shape, kernel_shape, strides, pads, out_shape, auto_pad="NOTSET"
+    ):
+        input_nodes = [
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, list(x_shape)),
+        ]
+
+        output_nodes = [
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, list(out_shape)),
+        ]
+
+        input_names = ["X"]
+
+        node = helper.make_node(
+            "AveragePool",
+            inputs=input_names,
+            outputs=["Y"],
+            kernel_shape=kernel_shape,
+            strides=strides,
+        )
+
+        if pads is None:
+            pad_attr = helper.make_attribute("auto_pad", auto_pad)
+        else:
+            pad_attr = helper.make_attribute("pads", pads)
+        node.attribute.append(pad_attr)
+
+        graph = helper.make_graph(
+            [node],
+            "qlinear_average_pool_test",
+            inputs=input_nodes,
+            outputs=output_nodes,
+        )
+
+        model = helper.make_model(graph, producer_name="qlinear_average_pool_Test")
+        quantize_and_verify_with_ort(model, input_names, [x_shape], target, dev)
+
+    # Pool1D
+    verify_qlinear_average_pool(
+        x_shape=[1, 1, 32],
+        kernel_shape=[3],
+        strides=[1],
+        pads=[1, 1],
+        out_shape=[1, 1, 32],
+    )
+    # Pool2D
+    verify_qlinear_average_pool(
+        x_shape=[1, 1, 32, 32],
+        kernel_shape=[3, 3],
+        strides=[1, 1],
+        pads=[1, 1, 1, 1],
+        out_shape=[1, 1, 32, 32],
+    )
+
+    # Pool1D with stride
+    verify_qlinear_average_pool(
+        x_shape=[1, 1, 32],
+        kernel_shape=[3],
+        strides=[2],
+        pads=[1, 1],
+        out_shape=[1, 1, 16],
+    )
+    # Pool2D with stride
+    verify_qlinear_average_pool(
+        x_shape=[1, 1, 32, 32],
+        kernel_shape=[3, 3],
+        strides=[2, 2],
+        pads=[1, 1, 1, 1],
+        out_shape=[1, 1, 16, 16],
+    )
+
+    # Pool1D with stride and autopadding
+    verify_qlinear_average_pool(
+        x_shape=[1, 1, 32],
+        kernel_shape=[3],
+        strides=[2],
+        pads=None,
+        out_shape=[1, 1, 16],
+        auto_pad="SAME_UPPER",
+    )
+    # Pool2D with stride and autopadding
+    verify_qlinear_average_pool(
+        x_shape=[1, 1, 32, 32],
+        kernel_shape=[3, 3],
+        strides=[2, 2],
+        pads=None,
+        out_shape=[1, 1, 16, 16],
+        auto_pad="SAME_UPPER",
+    )
+
+    # Pool3D with stride
+    verify_qlinear_average_pool(
+        x_shape=[1, 1, 32, 32, 32],
+        kernel_shape=[3, 3, 3],
+        strides=[2, 2, 2],
+        pads=[1, 1, 1, 1, 1, 1],
+        out_shape=[1, 1, 16, 16, 16],
+    )
+
+    # Pool3D with stride and autopadding
+    verify_qlinear_average_pool(
+        x_shape=[1, 1, 32, 32, 32],
+        kernel_shape=[3, 3, 3],
+        strides=[2, 2, 2],
+        pads=None,
+        out_shape=[1, 1, 16, 16, 16],
+        auto_pad="SAME_UPPER",
+    )
+
+
+@tvm.testing.parametrize_targets
+def test_qlinear_global_average_pool(target, dev):
+    def verify_qlinear_global_average_pool(x_shape):
+        out_shape = x_shape[:2] + [1] * (len(x_shape) - 2)
+
+        node_type = "GlobalAveragePool"
+
+        input_names = ["X"]
+
+        pool_node = helper.make_node(node_type, inputs=input_names, outputs=["Y"])

Review comment:
       It should be the floating-point op name, as mentioned in the comment above.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to