This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 61d3a570b1 [python] fix oss path handling for pyarrow 6.x (py36 compatibility) (#7180)
61d3a570b1 is described below
commit 61d3a570b164edd5678e645e600e424db11103f5
Author: XiaoHongbo <[email protected]>
AuthorDate: Mon Feb 2 20:29:27 2026 +0800
[python] fix oss path handling for pyarrow 6.x (py36 compatibility) (#7180)
PyArrow 6.x with OSS uses an endpoint_override that already includes the
bucket, so the path passed to the filesystem must be the object key only
(e.g. path/to/file.txt), not bucket/path/to/file.txt.
- to_filesystem_path: for OSS + PyArrow < 7.0.0, return path part only
(no netloc prefix); PyArrow >= 7 uses netloc/path like S3.
- new_output_stream: for OSS + PyArrow < 7.0.0, compute parent dir via
string split instead of Path(path_str).parent so mkdirs gets the correct
key prefix.
- S3: set force_virtual_addressing only when PyArrow >= 7.0.0.
**Add `file_io_test` to the py36 CI so py36 compatibility issues are
caught in CI instead of only discovered by users.**
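For illustration, a minimal standalone sketch of the path mapping described above.
The helper names (oss_filesystem_path, oss_parent_key) are hypothetical; the real
logic lives in PyArrowFileIO.to_filesystem_path and new_output_stream in the diff below.

from urllib.parse import urlparse

def oss_filesystem_path(uri, pyarrow_gte_7):
    parsed = urlparse(uri)         # e.g. oss://test-bucket/path/to/file.txt
    key = parsed.path.lstrip('/')  # -> "path/to/file.txt"
    if pyarrow_gte_7:
        # endpoint_override is the plain endpoint, so keep the bucket in the path.
        return f"{parsed.netloc}/{key}" if key else parsed.netloc
    # PyArrow 6.x: endpoint_override is "<bucket>.<endpoint>", so pass only the key.
    return key if key else '.'

def oss_parent_key(key):
    # Parent directory for mkdirs on PyArrow 6.x: a plain string split on the key,
    # not Path(key).parent, so the exact key prefix is preserved.
    return key.rsplit('/', 1)[0] if '/' in key else ''

assert oss_filesystem_path("oss://test-bucket/path/to/file.txt", False) == "path/to/file.txt"
assert oss_filesystem_path("oss://test-bucket/path/to/file.txt", True) == "test-bucket/path/to/file.txt"
assert oss_parent_key("path/to/file.txt") == "path/to"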
---
paimon-python/dev/lint-python.sh | 2 +-
.../pypaimon/filesystem/pyarrow_file_io.py | 47 +++++++++---
paimon-python/pypaimon/tests/file_io_test.py | 34 +++++++++
.../pypaimon/tests/rest/rest_base_test.py | 1 +
.../pypaimon/tests/rest/rest_simple_test.py | 85 ++++++++++++++--------
5 files changed, 125 insertions(+), 44 deletions(-)
diff --git a/paimon-python/dev/lint-python.sh b/paimon-python/dev/lint-python.sh
index 44be287149..46b25a2a9b 100755
--- a/paimon-python/dev/lint-python.sh
+++ b/paimon-python/dev/lint-python.sh
@@ -176,7 +176,7 @@ function pytest_check() {
# Determine test directory based on Python version
if [ "$PYTHON_VERSION" = "3.6" ]; then
- TEST_DIR="pypaimon/tests/py36"
+ TEST_DIR="pypaimon/tests/py36 pypaimon/tests/file_io_test.py"
echo "Running tests for Python 3.6: $TEST_DIR"
else
TEST_DIR="pypaimon/tests --ignore=pypaimon/tests/py36
--ignore=pypaimon/tests/e2e --ignore=pypaimon/tests/torch_read_test.py"
diff --git a/paimon-python/pypaimon/filesystem/pyarrow_file_io.py b/paimon-python/pypaimon/filesystem/pyarrow_file_io.py
index a2b83ed2cc..74cc26a208 100644
--- a/paimon-python/pypaimon/filesystem/pyarrow_file_io.py
+++ b/paimon-python/pypaimon/filesystem/pyarrow_file_io.py
@@ -42,9 +42,14 @@ class PyArrowFileIO(FileIO):
def __init__(self, path: str, catalog_options: Options):
self.properties = catalog_options
self.logger = logging.getLogger(__name__)
+ self._pyarrow_gte_7 = parse(pyarrow.__version__) >= parse("7.0.0")
+ self._pyarrow_gte_8 = parse(pyarrow.__version__) >= parse("8.0.0")
scheme, netloc, _ = self.parse_location(path)
self.uri_reader_factory = UriReaderFactory(catalog_options)
- if scheme in {"oss"}:
+ self._is_oss = scheme in {"oss"}
+ self._oss_bucket = None
+ if self._is_oss:
+ self._oss_bucket = self._extract_oss_bucket(path)
self.filesystem = self._initialize_oss_fs(path)
elif scheme in {"s3", "s3a", "s3n"}:
self.filesystem = self._initialize_s3_fs()
@@ -63,13 +68,13 @@ class PyArrowFileIO(FileIO):
else:
return uri.scheme, uri.netloc, f"{uri.netloc}{uri.path}"
- @staticmethod
def _create_s3_retry_config(
+ self,
max_attempts: int = 10,
request_timeout: int = 60,
connect_timeout: int = 60
) -> Dict[str, Any]:
- if parse(pyarrow.__version__) >= parse("8.0.0"):
+ if self._pyarrow_gte_8:
config = {
'request_timeout': request_timeout,
'connect_timeout': connect_timeout
@@ -114,12 +119,11 @@ class PyArrowFileIO(FileIO):
"region": self.properties.get(OssOptions.OSS_REGION),
}
- if parse(pyarrow.__version__) >= parse("7.0.0"):
+ if self._pyarrow_gte_7:
client_kwargs['force_virtual_addressing'] = True
client_kwargs['endpoint_override'] = self.properties.get(OssOptions.OSS_ENDPOINT)
else:
- oss_bucket = self._extract_oss_bucket(path)
- client_kwargs['endpoint_override'] = (oss_bucket + "." +
+ client_kwargs['endpoint_override'] = (self._oss_bucket + "." +
self.properties.get(OssOptions.OSS_ENDPOINT))
retry_config = self._create_s3_retry_config()
@@ -136,8 +140,9 @@ class PyArrowFileIO(FileIO):
"secret_key": self.properties.get(S3Options.S3_ACCESS_KEY_SECRET),
"session_token": self.properties.get(S3Options.S3_SECURITY_TOKEN),
"region": self.properties.get(S3Options.S3_REGION),
- "force_virtual_addressing": True,
}
+ if self._pyarrow_gte_7:
+ client_kwargs["force_virtual_addressing"] = True
retry_config = self._create_s3_retry_config()
client_kwargs.update(retry_config)
@@ -177,9 +182,20 @@ class PyArrowFileIO(FileIO):
def new_output_stream(self, path: str):
path_str = self.to_filesystem_path(path)
- parent_dir = Path(path_str).parent
- if str(parent_dir) and not self.exists(str(parent_dir)):
- self.mkdirs(str(parent_dir))
+
+ if self._is_oss and not self._pyarrow_gte_7:
+ # For PyArrow 6.x + OSS, path_str is already just the key part
+ if '/' in path_str:
+ parent_dir = '/'.join(path_str.split('/')[:-1])
+ else:
+ parent_dir = ''
+
+ if parent_dir and not self.exists(parent_dir):
+ self.mkdirs(parent_dir)
+ else:
+ parent_dir = Path(path_str).parent
+ if str(parent_dir) and not self.exists(str(parent_dir)):
+ self.mkdirs(str(parent_dir))
return self.filesystem.open_output_stream(path_str)
@@ -491,11 +507,18 @@ class PyArrowFileIO(FileIO):
if parsed.scheme:
if parsed.netloc:
path_part = normalized_path.lstrip('/')
- return f"{parsed.netloc}/{path_part}" if path_part else
parsed.netloc
+ if self._is_oss and not self._pyarrow_gte_7:
+ # For PyArrow 6.x + OSS, endpoint_override already contains bucket,
+ result = path_part if path_part else '.'
+ return result
+ else:
+ result = f"{parsed.netloc}/{path_part}" if path_part
else parsed.netloc
+ return result
else:
result = normalized_path.lstrip('/')
return result if result else '.'
- return str(path)
+ else:
+ return str(path)
if parsed.scheme:
if not normalized_path:
diff --git a/paimon-python/pypaimon/tests/file_io_test.py b/paimon-python/pypaimon/tests/file_io_test.py
index 6cdf713b0c..ca834040d1 100644
--- a/paimon-python/pypaimon/tests/file_io_test.py
+++ b/paimon-python/pypaimon/tests/file_io_test.py
@@ -26,6 +26,7 @@ import pyarrow
from pyarrow.fs import S3FileSystem
from pypaimon.common.options import Options
+from pypaimon.common.options.config import OssOptions
from pypaimon.filesystem.local_file_io import LocalFileIO
from pypaimon.filesystem.pyarrow_file_io import PyArrowFileIO
@@ -65,6 +66,30 @@ class FileIOTest(unittest.TestCase):
parent_str = str(Path(converted_path).parent)
self.assertEqual(file_io.to_filesystem_path(parent_str), parent_str)
+ from packaging.version import parse as parse_version
+ oss_io = PyArrowFileIO("oss://test-bucket/warehouse", Options({
+ OssOptions.OSS_ENDPOINT.key(): 'oss-cn-hangzhou.aliyuncs.com'
+ }))
+ lt7 = parse_version(pyarrow.__version__) < parse_version("7.0.0")
+ got = oss_io.to_filesystem_path("oss://test-bucket/path/to/file.txt")
+ expected_path = (
+ "path/to/file.txt" if lt7 else "test-bucket/path/to/file.txt")
+ self.assertEqual(got, expected_path)
+ nf = MagicMock(type=pyarrow.fs.FileType.NotFound)
+ mock_fs = MagicMock()
+ mock_fs.get_file_info.side_effect = [[nf], [nf]]
+ mock_fs.create_dir = MagicMock()
+ mock_fs.open_output_stream.return_value = MagicMock()
+ oss_io.filesystem = mock_fs
+ oss_io.new_output_stream("oss://test-bucket/path/to/file.txt")
+ mock_fs.create_dir.assert_called_once()
+ path_str = oss_io.to_filesystem_path("oss://test-bucket/path/to/file.txt")
+ if lt7:
+ expected_parent = '/'.join(path_str.split('/')[:-1]) if '/' in path_str else ''
+ else:
+ expected_parent = str(Path(path_str).parent)
+ self.assertEqual(mock_fs.create_dir.call_args[0][0], expected_parent)
+
def test_local_filesystem_path_conversion(self):
file_io = LocalFileIO("file:///tmp/warehouse", Options({}))
self.assertIsInstance(file_io, LocalFileIO)
@@ -213,6 +238,15 @@ class FileIOTest(unittest.TestCase):
file_io.delete_quietly("file:///some/path")
file_io.delete_directory_quietly("file:///some/path")
+
+ oss_io = PyArrowFileIO("oss://test-bucket/warehouse", Options({
+ OssOptions.OSS_ENDPOINT.key(): 'oss-cn-hangzhou.aliyuncs.com'
+ }))
+ mock_fs = MagicMock()
+ mock_fs.get_file_info.return_value = [
+ MagicMock(type=pyarrow.fs.FileType.NotFound)]
+ oss_io.filesystem = mock_fs
+
self.assertFalse(oss_io.exists("oss://test-bucket/path/to/file.txt"))
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
diff --git a/paimon-python/pypaimon/tests/rest/rest_base_test.py b/paimon-python/pypaimon/tests/rest/rest_base_test.py
index 2fc14eb111..19419428a5 100644
--- a/paimon-python/pypaimon/tests/rest/rest_base_test.py
+++ b/paimon-python/pypaimon/tests/rest/rest_base_test.py
@@ -161,6 +161,7 @@ class RESTBaseTest(unittest.TestCase):
logging.basicConfig(level=logging.INFO)
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
+ self.rest_catalog.drop_table("default.test_table", True)
self.rest_catalog.create_table("default.test_table", schema, False)
table = self.rest_catalog.get_table("default.test_table")
diff --git a/paimon-python/pypaimon/tests/rest/rest_simple_test.py b/paimon-python/pypaimon/tests/rest/rest_simple_test.py
index 2adce40496..3e29f950df 100644
--- a/paimon-python/pypaimon/tests/rest/rest_simple_test.py
+++ b/paimon-python/pypaimon/tests/rest/rest_simple_test.py
@@ -16,6 +16,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
+import sys
+
import pyarrow as pa
from pypaimon import Schema
@@ -27,6 +29,12 @@ from pypaimon.tests.rest.rest_base_test import RESTBaseTest
from pypaimon.write.row_key_extractor import FixedBucketRowKeyExtractor, DynamicBucketRowKeyExtractor, \
UnawareBucketRowKeyExtractor
+if sys.version_info[:2] == (3, 6):
+ from pypaimon.tests.py36.pyarrow_compat import table_sort_by
+else:
+ def table_sort_by(table: pa.Table, column_name: str, order: str = 'ascending') -> pa.Table:
+ return table.sort_by([(column_name, order)])
+
class RESTSimpleTest(RESTBaseTest):
def setUp(self):
@@ -47,6 +55,7 @@ class RESTSimpleTest(RESTBaseTest):
def test_with_shard_ao_unaware_bucket(self):
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
+ self.rest_catalog.drop_table('default.test_with_shard_ao_unaware_bucket', True)
self.rest_catalog.create_table('default.test_with_shard_ao_unaware_bucket',
schema, False)
table = self.rest_catalog.get_table('default.test_with_shard_ao_unaware_bucket')
write_builder = table.new_batch_write_builder()
@@ -82,7 +91,7 @@ class RESTSimpleTest(RESTBaseTest):
read_builder = table.new_read_builder()
table_read = read_builder.new_read()
splits = read_builder.new_scan().with_shard(2, 3).plan().splits()
- actual = table_read.to_arrow(splits).sort_by('user_id')
+ actual = table_sort_by(table_read.to_arrow(splits), 'user_id')
expected = pa.Table.from_pydict({
'user_id': [5, 7, 8, 9, 11, 13],
'item_id': [1005, 1007, 1008, 1009, 1011, 1013],
@@ -93,21 +102,22 @@ class RESTSimpleTest(RESTBaseTest):
# Get the three actual tables
splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits()
- actual1 = table_read.to_arrow(splits1).sort_by('user_id')
+ actual1 = table_sort_by(table_read.to_arrow(splits1), 'user_id')
splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits()
- actual2 = table_read.to_arrow(splits2).sort_by('user_id')
+ actual2 = table_sort_by(table_read.to_arrow(splits2), 'user_id')
splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits()
- actual3 = table_read.to_arrow(splits3).sort_by('user_id')
+ actual3 = table_sort_by(table_read.to_arrow(splits3), 'user_id')
# Concatenate the three tables
- actual = pa.concat_tables([actual1, actual2,
actual3]).sort_by('user_id')
- expected = self._read_test_table(read_builder).sort_by('user_id')
+ actual = table_sort_by(pa.concat_tables([actual1, actual2, actual3]),
'user_id')
+ expected = table_sort_by(self._read_test_table(read_builder),
'user_id')
self.assertEqual(actual, expected)
def test_with_shard_ao_unaware_bucket_manual(self):
"""Test shard_ao_unaware_bucket with setting bucket -1 manually"""
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
options={'bucket': '-1'})
+ self.rest_catalog.drop_table("default.test_with_shard_ao_unaware_bucket_manual", True)
self.rest_catalog.create_table('default.test_with_shard_ao_unaware_bucket_manual',
schema, False)
table = self.rest_catalog.get_table('default.test_with_shard_ao_unaware_bucket_manual')
write_builder = table.new_batch_write_builder()
@@ -134,7 +144,7 @@ class RESTSimpleTest(RESTBaseTest):
# Test first shard (0, 2) - should get first 3 rows
plan = read_builder.new_scan().with_shard(0, 2).plan()
- actual = table_read.to_arrow(plan.splits()).sort_by('user_id')
+ actual = table_sort_by(table_read.to_arrow(plan.splits()), 'user_id')
expected = pa.Table.from_pydict({
'user_id': [1, 2, 3],
'item_id': [1001, 1002, 1003],
@@ -145,7 +155,7 @@ class RESTSimpleTest(RESTBaseTest):
# Test second shard (1, 2) - should get last 3 rows
plan = read_builder.new_scan().with_shard(1, 2).plan()
- actual = table_read.to_arrow(plan.splits()).sort_by('user_id')
+ actual = table_sort_by(table_read.to_arrow(plan.splits()), 'user_id')
expected = pa.Table.from_pydict({
'user_id': [4, 5, 6],
'item_id': [1004, 1005, 1006],
@@ -157,6 +167,7 @@ class RESTSimpleTest(RESTBaseTest):
def test_with_shard_ao_fixed_bucket(self):
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
options={'bucket': '5',
'bucket-key': 'item_id'})
+ self.rest_catalog.drop_table('default.test_with_slice_ao_fixed_bucket', True)
self.rest_catalog.create_table('default.test_with_slice_ao_fixed_bucket',
schema, False)
table = self.rest_catalog.get_table('default.test_with_slice_ao_fixed_bucket')
write_builder = table.new_batch_write_builder()
@@ -192,7 +203,7 @@ class RESTSimpleTest(RESTBaseTest):
read_builder = table.new_read_builder()
table_read = read_builder.new_read()
splits = read_builder.new_scan().with_shard(0, 3).plan().splits()
- actual = table_read.to_arrow(splits).sort_by('user_id')
+ actual = table_sort_by(table_read.to_arrow(splits), 'user_id')
expected = pa.Table.from_pydict({
'user_id': [1, 2, 3, 5, 8, 12],
'item_id': [1001, 1002, 1003, 1005, 1008, 1012],
@@ -203,20 +214,21 @@ class RESTSimpleTest(RESTBaseTest):
# Get the three actual tables
splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits()
- actual1 = table_read.to_arrow(splits1).sort_by('user_id')
+ actual1 = table_sort_by(table_read.to_arrow(splits1), 'user_id')
splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits()
- actual2 = table_read.to_arrow(splits2).sort_by('user_id')
+ actual2 = table_sort_by(table_read.to_arrow(splits2), 'user_id')
splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits()
- actual3 = table_read.to_arrow(splits3).sort_by('user_id')
+ actual3 = table_sort_by(table_read.to_arrow(splits3), 'user_id')
# Concatenate the three tables
- actual = pa.concat_tables([actual1, actual2,
actual3]).sort_by('user_id')
- expected = self._read_test_table(read_builder).sort_by('user_id')
+ actual = table_sort_by(pa.concat_tables([actual1, actual2, actual3]),
'user_id')
+ expected = table_sort_by(self._read_test_table(read_builder),
'user_id')
self.assertEqual(actual, expected)
def test_with_shard_single_partition(self):
"""Test sharding with single partition - tests _filter_by_shard with
simple data"""
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
+ self.rest_catalog.drop_table('default.test_shard_single_partition',
True)
self.rest_catalog.create_table('default.test_shard_single_partition',
schema, False)
table = self.rest_catalog.get_table('default.test_shard_single_partition')
write_builder = table.new_batch_write_builder()
@@ -241,7 +253,7 @@ class RESTSimpleTest(RESTBaseTest):
# Test first shard (0, 2) - should get first 3 rows
plan = read_builder.new_scan().with_shard(0, 2).plan()
- actual = table_read.to_arrow(plan.splits()).sort_by('user_id')
+ actual = table_sort_by(table_read.to_arrow(plan.splits()), 'user_id')
expected = pa.Table.from_pydict({
'user_id': [1, 2, 3],
'item_id': [1001, 1002, 1003],
@@ -252,7 +264,7 @@ class RESTSimpleTest(RESTBaseTest):
# Test second shard (1, 2) - should get last 3 rows
plan = read_builder.new_scan().with_shard(1, 2).plan()
- actual = table_read.to_arrow(plan.splits()).sort_by('user_id')
+ actual = table_sort_by(table_read.to_arrow(plan.splits()), 'user_id')
expected = pa.Table.from_pydict({
'user_id': [4, 5, 6],
'item_id': [1004, 1005, 1006],
@@ -264,6 +276,7 @@ class RESTSimpleTest(RESTBaseTest):
def test_with_shard_uneven_distribution(self):
"""Test sharding with uneven row distribution across shards"""
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
+ self.rest_catalog.drop_table('default.test_shard_uneven', True)
self.rest_catalog.create_table('default.test_shard_uneven', schema,
False)
table = self.rest_catalog.get_table('default.test_shard_uneven')
write_builder = table.new_batch_write_builder()
@@ -288,7 +301,7 @@ class RESTSimpleTest(RESTBaseTest):
# Test sharding into 3 parts: 3, 2, 2 rows
plan1 = read_builder.new_scan().with_shard(0, 3).plan()
- actual1 = table_read.to_arrow(plan1.splits()).sort_by('user_id')
+ actual1 = table_sort_by(table_read.to_arrow(plan1.splits()), 'user_id')
expected1 = pa.Table.from_pydict({
'user_id': [1, 2, 3],
'item_id': [1001, 1002, 1003],
@@ -298,7 +311,7 @@ class RESTSimpleTest(RESTBaseTest):
self.assertEqual(actual1, expected1)
plan2 = read_builder.new_scan().with_shard(1, 3).plan()
- actual2 = table_read.to_arrow(plan2.splits()).sort_by('user_id')
+ actual2 = table_sort_by(table_read.to_arrow(plan2.splits()), 'user_id')
expected2 = pa.Table.from_pydict({
'user_id': [4, 5],
'item_id': [1004, 1005],
@@ -308,7 +321,7 @@ class RESTSimpleTest(RESTBaseTest):
self.assertEqual(actual2, expected2)
plan3 = read_builder.new_scan().with_shard(2, 3).plan()
- actual3 = table_read.to_arrow(plan3.splits()).sort_by('user_id')
+ actual3 = table_sort_by(table_read.to_arrow(plan3.splits()), 'user_id')
expected3 = pa.Table.from_pydict({
'user_id': [6, 7],
'item_id': [1006, 1007],
@@ -320,6 +333,7 @@ class RESTSimpleTest(RESTBaseTest):
def test_with_shard_single_shard(self):
"""Test sharding with only one shard - should return all data"""
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
+ self.rest_catalog.drop_table('default.test_shard_single', True)
self.rest_catalog.create_table('default.test_shard_single', schema,
False)
table = self.rest_catalog.get_table('default.test_shard_single')
write_builder = table.new_batch_write_builder()
@@ -343,13 +357,14 @@ class RESTSimpleTest(RESTBaseTest):
# Test single shard (0, 1) - should get all data
plan = read_builder.new_scan().with_shard(0, 1).plan()
- actual = table_read.to_arrow(plan.splits()).sort_by('user_id')
+ actual = table_sort_by(table_read.to_arrow(plan.splits()), 'user_id')
expected = pa.Table.from_pydict(data, schema=self.pa_schema)
self.assertEqual(actual, expected)
def test_with_shard_many_small_shards(self):
"""Test sharding with many small shards"""
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
+ self.rest_catalog.drop_table('default.test_shard_many_small', True)
self.rest_catalog.create_table('default.test_shard_many_small',
schema, False)
table = self.rest_catalog.get_table('default.test_shard_many_small')
write_builder = table.new_batch_write_builder()
@@ -381,6 +396,7 @@ class RESTSimpleTest(RESTBaseTest):
def test_with_shard_boundary_conditions(self):
"""Test sharding boundary conditions with edge cases"""
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
+ self.rest_catalog.drop_table('default.test_shard_boundary', True)
self.rest_catalog.create_table('default.test_shard_boundary', schema,
False)
table = self.rest_catalog.get_table('default.test_shard_boundary')
write_builder = table.new_batch_write_builder()
@@ -421,6 +437,7 @@ class RESTSimpleTest(RESTBaseTest):
"""Test with_shard method using 50000 rows of data to verify
performance and correctness"""
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'],
options={'bucket': '5',
'bucket-key': 'item_id'})
+ self.rest_catalog.drop_table('default.test_with_shard_large_dataset',
True)
self.rest_catalog.create_table('default.test_with_shard_large_dataset', schema,
False)
table = self.rest_catalog.get_table('default.test_with_shard_large_dataset')
write_builder = table.new_batch_write_builder()
@@ -463,11 +480,11 @@ class RESTSimpleTest(RESTBaseTest):
print(f"Shard {shard_idx}/{num_shards}: {shard_rows} rows")
# Verify that all shards together contain all the data
- concatenated_result = pa.concat_tables(shard_results).sort_by('user_id')
+ concatenated_result = table_sort_by(pa.concat_tables(shard_results),
'user_id')
# Read all data without sharding for comparison
all_splits = read_builder.new_scan().plan().splits()
- all_data = table_read.to_arrow(all_splits).sort_by('user_id')
+ all_data = table_sort_by(table_read.to_arrow(all_splits), 'user_id')
# Verify total row count
self.assertEqual(len(concatenated_result), len(all_data))
@@ -492,13 +509,13 @@ class RESTSimpleTest(RESTBaseTest):
shard_10_results.append(shard_result)
if shard_10_results:
- concatenated_10_shards = pa.concat_tables(shard_10_results).sort_by('user_id')
+ concatenated_10_shards = table_sort_by(pa.concat_tables(shard_10_results), 'user_id')
self.assertEqual(len(concatenated_10_shards), num_rows)
self.assertEqual(concatenated_10_shards, all_data)
# Test with single shard (should return all data)
single_shard_splits = read_builder.new_scan().with_shard(0, 1).plan().splits()
- single_shard_result = table_read.to_arrow(single_shard_splits).sort_by('user_id')
+ single_shard_result = table_sort_by(table_read.to_arrow(single_shard_splits), 'user_id')
self.assertEqual(len(single_shard_result), num_rows)
self.assertEqual(single_shard_result, all_data)
@@ -507,6 +524,7 @@ class RESTSimpleTest(RESTBaseTest):
def test_with_shard_large_dataset_one_commit(self):
"""Test with_shard method using 50000 rows of data to verify
performance and correctness"""
schema = Schema.from_pyarrow_schema(self.pa_schema)
+ self.rest_catalog.drop_table('default.test_with_shard_large_dataset',
True)
self.rest_catalog.create_table('default.test_with_shard_large_dataset', schema,
False)
table = self.rest_catalog.get_table('default.test_with_shard_large_dataset')
write_builder = table.new_batch_write_builder()
@@ -541,11 +559,11 @@ class RESTSimpleTest(RESTBaseTest):
print(f"Shard {shard_idx}/{num_shards}: {shard_rows} rows")
# Verify that all shards together contain all the data
- concatenated_result = pa.concat_tables(shard_results).sort_by('user_id')
+ concatenated_result = table_sort_by(pa.concat_tables(shard_results),
'user_id')
# Read all data without sharding for comparison
all_splits = read_builder.new_scan().plan().splits()
- all_data = table_read.to_arrow(all_splits).sort_by('user_id')
+ all_data = table_sort_by(table_read.to_arrow(all_splits), 'user_id')
# Verify total row count
self.assertEqual(len(concatenated_result), len(all_data))
@@ -564,6 +582,7 @@ class RESTSimpleTest(RESTBaseTest):
def test_with_shard_parameter_validation(self):
"""Test edge cases for parameter validation"""
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
+ self.rest_catalog.drop_table('default.test_shard_validation_edge',
True)
self.rest_catalog.create_table('default.test_shard_validation_edge',
schema, False)
table = self.rest_catalog.get_table('default.test_shard_validation_edge')
@@ -575,6 +594,7 @@ class RESTSimpleTest(RESTBaseTest):
def test_with_shard_pk_dynamic_bucket(self):
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['user_id'], primary_keys=['user_id', 'dt'])
+ self.rest_catalog.drop_table('default.test_with_shard', True)
self.rest_catalog.create_table('default.test_with_shard', schema,
False)
table = self.rest_catalog.get_table('default.test_with_shard')
@@ -592,6 +612,7 @@ class RESTSimpleTest(RESTBaseTest):
def test_with_shard_pk_fixed_bucket(self):
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['user_id'], primary_keys=['user_id', 'dt'],
options={'bucket': '5'})
+ self.rest_catalog.drop_table('default.test_with_shard', True)
self.rest_catalog.create_table('default.test_with_shard', schema,
False)
table = self.rest_catalog.get_table('default.test_with_shard')
@@ -625,6 +646,7 @@ class RESTSimpleTest(RESTBaseTest):
def test_with_shard_uniform_division(self):
schema = Schema.from_pyarrow_schema(self.pa_schema,
partition_keys=['dt'])
+ self.rest_catalog.drop_table('default.with_shard_uniform_division',
True)
self.rest_catalog.create_table('default.with_shard_uniform_division',
schema, False)
table = self.rest_catalog.get_table('default.with_shard_uniform_division')
write_builder = table.new_batch_write_builder()
@@ -647,17 +669,17 @@ class RESTSimpleTest(RESTBaseTest):
# Get the three actual tables
splits1 = read_builder.new_scan().with_shard(0, 3).plan().splits()
- actual1 = table_read.to_arrow(splits1).sort_by('user_id')
+ actual1 = table_sort_by(table_read.to_arrow(splits1), 'user_id')
splits2 = read_builder.new_scan().with_shard(1, 3).plan().splits()
- actual2 = table_read.to_arrow(splits2).sort_by('user_id')
+ actual2 = table_sort_by(table_read.to_arrow(splits2), 'user_id')
splits3 = read_builder.new_scan().with_shard(2, 3).plan().splits()
- actual3 = table_read.to_arrow(splits3).sort_by('user_id')
+ actual3 = table_sort_by(table_read.to_arrow(splits3), 'user_id')
self.assertEqual(5, len(actual1))
self.assertEqual(5, len(actual2))
self.assertEqual(4, len(actual3))
# Concatenate the three tables
- actual = pa.concat_tables([actual1, actual2,
actual3]).sort_by('user_id')
- expected = self._read_test_table(read_builder).sort_by('user_id')
+ actual = table_sort_by(pa.concat_tables([actual1, actual2, actual3]),
'user_id')
+ expected = table_sort_by(self._read_test_table(read_builder),
'user_id')
self.assertEqual(expected, actual)
def test_create_drop_database_table(self):
@@ -728,6 +750,7 @@ class RESTSimpleTest(RESTBaseTest):
options={},
comment="comment"
)
+ catalog.drop_table(identifier, True)
catalog.create_table(identifier, schema, False)
catalog.alter_table(