kszucs commented on a change in pull request #439:
URL: https://github.com/apache/arrow-rs/pull/439#discussion_r655484390



##########
File path: arrow-pyarrow-integration-testing/tests/test_sql.py
##########
@@ -16,84 +16,195 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import unittest
-
-import pyarrow
-import arrow_pyarrow_integration_testing
-
-
-class TestCase(unittest.TestCase):
-    def test_primitive_python(self):
-        """
-        Python -> Rust -> Python
-        """
-        old_allocated = pyarrow.total_allocated_bytes()
-        a = pyarrow.array([1, 2, 3])
-        b = arrow_pyarrow_integration_testing.double(a)
-        self.assertEqual(b, pyarrow.array([2, 4, 6]))
-        del a
-        del b
-        # No leak of C++ memory
-        self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
-    def test_primitive_rust(self):
-        """
-        Rust -> Python -> Rust
-        """
-        old_allocated = pyarrow.total_allocated_bytes()
-
-        def double(array):
-            array = array.to_pylist()
-            return pyarrow.array([x * 2 if x is not None else None for x in array])
-
-        is_correct = arrow_pyarrow_integration_testing.double_py(double)
-        self.assertTrue(is_correct)
-        # No leak of C++ memory
-        self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
-    def test_string_python(self):
-        """
-        Python -> Rust -> Python
-        """
-        old_allocated = pyarrow.total_allocated_bytes()
-        a = pyarrow.array(["a", None, "ccc"])
-        b = arrow_pyarrow_integration_testing.substring(a, 1)
-        self.assertEqual(b, pyarrow.array(["", None, "cc"]))
-        del a
-        del b
-        # No leak of C++ memory
-        self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
-    def test_time32_python(self):
-        """
-        Python -> Rust -> Python
-        """
-        old_allocated = pyarrow.total_allocated_bytes()
-        a = pyarrow.array([None, 1, 2], pyarrow.time32('s'))
-        b = arrow_pyarrow_integration_testing.concatenate(a)
-        expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32('s'))
-        self.assertEqual(b, expected)
-        del a
-        del b
-        del expected
-        # No leak of C++ memory
-        self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
-    def test_list_array(self):
-        """
-        Python -> Rust -> Python
-        """
-        old_allocated = pyarrow.total_allocated_bytes()
-        a = pyarrow.array([[], None, [1, 2], [4, 5, 6]], pyarrow.list_(pyarrow.int64()))
-        b = arrow_pyarrow_integration_testing.round_trip(a)
-
-        b.validate(full=True)
-        assert a.to_pylist() == b.to_pylist()
-        assert a.type == b.type
-        del a
-        del b
-        # No leak of C++ memory
-        self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
+import contextlib
+import string
 
+import pytest
+import pyarrow as pa
 
+from arrow_pyarrow_integration_testing import PyDataType, PyField, PySchema
+import arrow_pyarrow_integration_testing as rust
 
+
+@contextlib.contextmanager
+def no_pyarrow_leak():
+    # No leak of C++ memory
+    old_allocation = pa.total_allocated_bytes()
+    try:
+        yield
+    finally:
+        assert pa.total_allocated_bytes() == old_allocation
+
+
+@pytest.fixture(autouse=True)
+def assert_pyarrow_leak():
+    # automatically applied to all test cases
+    with no_pyarrow_leak():
+        yield
+
+
+_supported_pyarrow_types = [
+    pa.null(),
+    pa.bool_(),
+    pa.int32(),
+    pa.time32("s"),
+    pa.time64("us"),
+    pa.date32(),
+    pa.float16(),
+    pa.float32(),
+    pa.float64(),
+    pa.string(),
+    pa.binary(),
+    pa.large_string(),
+    pa.large_binary(),
+    pa.list_(pa.int32()),
+    pa.large_list(pa.uint16()),
+    pa.struct(
+        [
+            pa.field("a", pa.int32()),
+            pa.field("b", pa.int8()),
+            pa.field("c", pa.string()),
+        ]
+    ),
+    pa.struct(
+        [
+            pa.field("a", pa.int32(), nullable=False),
+            pa.field("b", pa.int8(), nullable=False),
+            pa.field("c", pa.string()),
+        ]
+    ),
+]
+
+_unsupported_pyarrow_types = [

Review comment:
       https://github.com/apache/arrow-rs/issues/477




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to