This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 021cd92b59 [python] Pre-repartition Ray writes by (partition, bucket)
for fixed-bucket tables (#7813)
021cd92b59 is described below
commit 021cd92b5927e9cc86783f8f68ff9b03ff586ff5
Author: chaoyang <[email protected]>
AuthorDate: Sat May 16 00:28:00 2026 +0800
[python] Pre-repartition Ray writes by (partition, bucket) for fixed-bucket
tables (#7813)
When `write_paimon` is given a Ray Dataset, Ray's default round-robin
block distribution scatters rows that share the same `(partition,
bucket)` across many Ray tasks. Each task opens its own writer and emits
its own data file, so the write produces
`partitions × buckets × ray_tasks` files instead of the
`partitions × buckets` the writer would naturally produce.
Spark and Flink already cluster rows by `(partition, bucket)` before
writing — see `PaimonSparkWriter.repartitionByPartitionsAndBucket` and
the `RowAssignerChannelComputer` / `RowWithBucketChannelComputer` chain.
This PR brings the same pre-clustering to the Ray path.
---
docs/content/pypaimon/ray-data.md | 14 ++
paimon-python/pypaimon/ray/ray_paimon.py | 9 +
paimon-python/pypaimon/ray/shuffle.py | 136 +++++++++++
.../pypaimon/tests/ray_repartition_test.py | 263 +++++++++++++++++++++
.../pypaimon/tests/test_ray_shuffle_helper.py | 210 ++++++++++++++++
5 files changed, 632 insertions(+)
diff --git a/docs/content/pypaimon/ray-data.md
b/docs/content/pypaimon/ray-data.md
index b0ea958849..209c7c67a9 100644
--- a/docs/content/pypaimon/ray-data.md
+++ b/docs/content/pypaimon/ray-data.md
@@ -207,6 +207,20 @@ write_paimon(
)
```
+**Automatic (partition, bucket) clustering for HASH_FIXED tables:**
+
+For HASH_FIXED tables, `write_paimon` automatically clusters rows by
+`(partition_keys..., bucket)` before writing so each (partition,
+bucket) lands in a single Ray task — one writer, one file group. This
+avoids the small-file storm that Ray's default round-robin
+distribution would otherwise produce (`partitions × buckets ×
+ray_tasks` files instead of `partitions × buckets`).
+
+Bucket assignment uses the same hash routine the writer uses, so the
+bucket id computed for the groupby is identical to the one the writer
+would compute. No user configuration is required. For non-HASH_FIXED
+tables the dataset is written as-is.
+
**Parameters:**
- `dataset`: the Ray Dataset to write.
- `table_identifier`: full table name, e.g. `"db_name.table_name"`.
diff --git a/paimon-python/pypaimon/ray/ray_paimon.py
b/paimon-python/pypaimon/ray/ray_paimon.py
index bd81949394..86505097d8 100644
--- a/paimon-python/pypaimon/ray/ray_paimon.py
+++ b/paimon-python/pypaimon/ray/ray_paimon.py
@@ -117,6 +117,12 @@ def write_paimon(
) -> None:
"""Write a Ray Dataset to a Paimon table.
+ For HASH_FIXED tables, rows are automatically clustered by
+ ``(partition_keys..., bucket)`` before writing so that each
+ (partition, bucket) lands in a single Ray task. This avoids the
+ small-file storm that Ray's default round-robin distribution would
+ otherwise produce. No user configuration is required.
+
Args:
dataset: The Ray Dataset to write.
table_identifier: Full table name, e.g. ``"db_name.table_name"``.
@@ -126,11 +132,14 @@ def write_paimon(
ray_remote_args: Optional kwargs passed to ``ray.remote`` in write
tasks.
"""
from pypaimon.catalog.catalog_factory import CatalogFactory
+ from pypaimon.ray.shuffle import maybe_apply_repartition
from pypaimon.write.ray_datasink import PaimonDatasink
catalog = CatalogFactory.create(catalog_options)
table = catalog.get_table(table_identifier)
+ dataset = maybe_apply_repartition(dataset, table)
+
datasink = PaimonDatasink(table, overwrite=overwrite)
write_kwargs = {}
diff --git a/paimon-python/pypaimon/ray/shuffle.py
b/paimon-python/pypaimon/ray/shuffle.py
new file mode 100644
index 0000000000..b17f7a7ab1
--- /dev/null
+++ b/paimon-python/pypaimon/ray/shuffle.py
@@ -0,0 +1,136 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""Pre-repartition a Ray Dataset by (partition, bucket) before writing
+to a Paimon table.
+
+Without this, Ray's default round-robin block distribution scatters rows
+that share the same (partition, bucket) across many Ray tasks. Each
+task then opens its own writer and emits its own data file, producing
+``partitions × buckets × ray_tasks`` files instead of the
+``partitions × buckets`` the writer would naturally produce.
+
+For HASH_FIXED tables we group rows by ``(partition_keys..., bucket)``
+so every distinct group lands in a single Ray task. ``bucket`` is
+computed using the same ``FixedBucketRowKeyExtractor`` the writer
+uses, so the bucket id seen by the groupby is identical to the
+writer's. HASH_FIXED writes are always pre-clustered; no user
+opt-in is required.
+
+For any other bucket mode the dataset is returned unchanged.
+"""
+
+import uuid
+from typing import TYPE_CHECKING, List
+
+import pyarrow as pa
+
+from pypaimon.table.bucket_mode import BucketMode
+
+if TYPE_CHECKING:
+ import ray.data
+
+ from pypaimon.table.table import Table
+
+# Default name of the transient bucket column that is appended before
+# the groupby. ``_pick_bucket_col_name`` returns a different,
+# collision-free name when this one is already taken by a table column.
+BUCKET_KEY_COL = "__paimon_bucket__"
+
+
+def _pick_bucket_col_name(existing_names) -> str:
+ """Return a bucket column name guaranteed not to collide with
+ ``existing_names``. Falls back to a UUID suffix on collision."""
+ if BUCKET_KEY_COL not in existing_names:
+ return BUCKET_KEY_COL
+ while True:
+ candidate = "__paimon_bucket_{}_".format(uuid.uuid4().hex[:8])
+ if candidate not in existing_names:
+ return candidate
+
+
+def maybe_apply_repartition(
+ dataset: "ray.data.Dataset",
+ table: "Table",
+) -> "ray.data.Dataset":
+ """Cluster rows by ``(partition_keys..., bucket)`` for HASH_FIXED tables.
+
+ For any other bucket mode the dataset is returned unchanged.
+ HASH_FIXED writes are always pre-clustered, with no user opt-in
+ required.
+ """
+ if table.bucket_mode() != BucketMode.HASH_FIXED:
+ return dataset
+
+ partition_keys = list(table.table_schema.partition_keys or [])
+ extractor = table.create_row_key_extractor()
+ col_names = set(f.name for f in table.table_schema.fields)
+ bucket_col = _pick_bucket_col_name(col_names)
+ bucket_udf = _make_bucket_udf(extractor, bucket_col)
+
+ ds_with_bucket = dataset.map_batches(
+ bucket_udf, batch_format="pyarrow", zero_copy_batch=True,
+ )
+ group_keys: List[str] = partition_keys + [bucket_col]
+ grouped = ds_with_bucket.groupby(group_keys)
+ regrouped = grouped.map_groups(_identity_batch, batch_format="pyarrow")
+ return regrouped.drop_columns([bucket_col])
+
+
+def _identity_batch(batch: pa.Table) -> pa.Table:
+    """Per-group function passed to ``groupby().map_groups``.
+
+    Returns the group's rows with ``large_string`` / ``large_binary``
+    columns cast back to ``string`` / ``binary``; nothing else is
+    changed.
+    """
+    # Some Ray versions promote ``string`` to ``large_string`` (and
+    # ``binary`` to ``large_binary``) while materialising blocks for
+    # ``groupby().map_groups``. Paimon's writer compares schemas with a
+    # strict ``!=`` and rejects the large variants, so coerce them back
+    # to the regular types here. Other Arrow types pass through.
+    return _coerce_large_string_types(batch)
+
+
+def _coerce_large_string_types(batch: pa.Table) -> pa.Table:
+ needs_cast = False
+ fields = []
+ for field in batch.schema:
+ if pa.types.is_large_string(field.type):
+ fields.append(field.with_type(pa.string()))
+ needs_cast = True
+ elif pa.types.is_large_binary(field.type):
+ fields.append(field.with_type(pa.binary()))
+ needs_cast = True
+ else:
+ fields.append(field)
+ return batch.cast(pa.schema(fields)) if needs_cast else batch
+
+
+def _make_bucket_udf(extractor, bucket_col):
+ """Build a map_batches UDF that appends a transient bucket column.
+
+ The bucket value comes from ``extract_partition_bucket_batch`` so it
+ matches the writer's bucket assignment for the same row exactly.
+ """
+ def _udf(batch: pa.Table) -> pa.Table:
+ if batch.num_rows == 0:
+ return batch.append_column(
+ bucket_col, pa.array([], type=pa.int32())
+ )
+ record_batch = batch.combine_chunks().to_batches()[0]
+ _, buckets = extractor.extract_partition_bucket_batch(record_batch)
+ return batch.append_column(
+ bucket_col, pa.array(buckets, type=pa.int32())
+ )
+
+ return _udf
diff --git a/paimon-python/pypaimon/tests/ray_repartition_test.py
b/paimon-python/pypaimon/tests/ray_repartition_test.py
new file mode 100644
index 0000000000..b66b014b4f
--- /dev/null
+++ b/paimon-python/pypaimon/tests/ray_repartition_test.py
@@ -0,0 +1,263 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""End-to-end tests for HASH_FIXED auto-clustering on ``write_paimon``.
+
+For HASH_FIXED tables, ``write_paimon`` automatically pre-clusters rows
+by ``(partition_keys..., bucket)`` (matching Spark/Flink). These tests
+cover:
+
+ * roundtrip correctness on a HASH_FIXED PK table.
+ * roundtrip correctness on a partitioned HASH_FIXED PK table.
+ * the transient bucket column is stripped from the sink-visible
+ schema.
+ * the output is one file per (partition, bucket) — i.e. the
+ small-file storm is eliminated.
+ * regression: a table whose schema already contains a column named
+ ``__paimon_bucket__`` still works (collision-safe column name).
+ * non-HASH_FIXED tables (BUCKET_UNAWARE etc.) pass through unchanged.
+"""
+
+import glob
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+import ray
+
+from pypaimon import CatalogFactory, Schema
+
+
+class RayShuffleTest(unittest.TestCase):
+    """End-to-end ``write_paimon`` tests against a temporary warehouse.
+
+    A temporary directory holds the warehouse for the whole class, and
+    a local Ray runtime is started in ``setUpClass`` when none is
+    running yet.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        """Create the temp warehouse, the ``default`` database and Ray."""
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog_options = {'warehouse': cls.warehouse}
+
+        catalog = CatalogFactory.create(cls.catalog_options)
+        catalog.create_database('default', True)
+
+        if not ray.is_initialized():
+            # 4 CPUs gives us enough room to actually fan a multi-block
+            # write across multiple workers so the "small-file" claim
+            # is observable.
+            ray.init(ignore_reinit_error=True, num_cpus=4)
+
+    @classmethod
+    def tearDownClass(cls):
+        """Best-effort cleanup: shut Ray down and delete the temp dir."""
+        # Failures during shutdown are deliberately ignored so that the
+        # temp directory removal below is still attempted.
+        try:
+            if ray.is_initialized():
+                ray.shutdown()
+        except Exception:
+            pass
+        try:
+            shutil.rmtree(cls.tempdir)
+        except OSError:
+            pass
+
+    def _make_table(self, table_name, pa_schema, *, primary_keys=None,
+                    partition_keys=None, options=None):
+        """Create ``default.<table_name>`` from a PyArrow schema.
+
+        Returns the full identifier string ``'default.<table_name>'``.
+        """
+        identifier = 'default.{}'.format(table_name)
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            primary_keys=primary_keys,
+            partition_keys=partition_keys,
+            options=options,
+        )
+        catalog = CatalogFactory.create(self.catalog_options)
+        # NOTE(review): the third positional argument is presumably
+        # ``ignore_if_exists`` - confirm against the catalog API.
+        catalog.create_table(identifier, schema, False)
+        return identifier
+
+    def _read_table(self, identifier):
+        """Read table data via the direct API (not ``read_paimon``).
+
+        This avoids going through ``RayDatasource._get_read_task`` which
+        has a pre-existing strict nullability check (``from_batches``
+        with Paimon schema) that rejects batches where the reader drops
+        ``not null`` (a raw-convertible PK split issue). Shuffle tests
+        care about *write* correctness, not the Ray read path.
+
+        Returns a pandas DataFrame; an empty one when the read yields
+        ``None``.
+        """
+        catalog = CatalogFactory.create(self.catalog_options)
+        table = catalog.get_table(identifier)
+        rb = table.new_read_builder()
+        splits = rb.new_scan().plan().splits()
+        arrow = rb.new_read().to_arrow(splits)
+        return arrow.to_pandas() if arrow is not None else pa.table({}).to_pandas()
+
+    def _count_data_files(self, table_name):
+        """All data files under the table directory, regardless of partition.
+
+        Despite the name this returns the list of matching paths, not a
+        number; callers take ``len()`` of the result.
+        """
+        root = os.path.join(self.warehouse, 'default.db', table_name)
+        patterns = ['*.parquet', '*.orc', '*.avro']
+        files = []
+        for pattern in patterns:
+            # ``**`` with ``recursive=True`` also matches ``bucket-*``
+            # directories nested below partition directories.
+            files.extend(glob.glob(
+                os.path.join(root, '**', 'bucket-*', pattern), recursive=True,
+            ))
+        return files
+
+    # ----- HASH_FIXED auto-clustering -----
+
+    def test_fixed_bucket_roundtrip(self):
+        """Unpartitioned PK table with 4 buckets: all 40 rows come back
+        and the transient bucket column is not part of the result."""
+        from pypaimon.ray import write_paimon
+
+        pa_schema = pa.schema([
+            pa.field('id', pa.int32(), nullable=False),
+            ('name', pa.string()),
+        ])
+        table_name = 'test_fixed_bucket_roundtrip'
+        identifier = self._make_table(
+            table_name, pa_schema,
+            primary_keys=['id'], options={'bucket': '4'},
+        )
+
+        rows = pa.Table.from_pydict(
+            {'id': list(range(40)), 'name': [f'v{i}' for i in range(40)]},
+            schema=pa_schema,
+        )
+        # Split the input into 4 blocks before writing.
+        ds = ray.data.from_arrow(rows).repartition(4)
+        write_paimon(ds, identifier, self.catalog_options)
+
+        result = self._read_table(identifier)
+        self.assertEqual(len(result), 40)
+        self.assertEqual(set(result['id']), set(range(40)))
+        self.assertNotIn('__paimon_bucket__', result.columns)
+
+    def test_partitioned_fixed_bucket_roundtrip(self):
+        """Partitioned table — confirms the post-groupby schema does not
+        end up with duplicated partition-key or bucket columns."""
+        from pypaimon.ray import write_paimon
+
+        pa_schema = pa.schema([
+            pa.field('id', pa.int32(), nullable=False),
+            ('dt', pa.string()),
+            ('value', pa.int64()),
+        ])
+        table_name = 'test_partitioned_fixed_bucket_roundtrip'
+        identifier = self._make_table(
+            table_name, pa_schema,
+            primary_keys=['id', 'dt'], partition_keys=['dt'],
+            options={'bucket': '4'},
+        )
+
+        # Two partitions with 10 rows each.
+        rows = pa.Table.from_pydict({
+            'id': list(range(20)),
+            'dt': ['2026-01-01'] * 10 + ['2026-01-02'] * 10,
+            'value': list(range(20)),
+        }, schema=pa_schema)
+        ds = ray.data.from_arrow(rows).repartition(4)
+        write_paimon(ds, identifier, self.catalog_options)
+
+        result = self._read_table(identifier)
+        self.assertEqual(set(result.columns), {'id', 'dt', 'value'})
+        self.assertEqual(len(result), 20)
+        self.assertEqual(set(result['dt']), {'2026-01-01', '2026-01-02'})
+
+    def test_fixed_bucket_writes_one_file_per_bucket(self):
+        """With multiple input blocks, auto-clustering collapses per-task
+        files into per-bucket files."""
+        from pypaimon.ray import write_paimon
+
+        pa_schema = pa.schema([
+            pa.field('id', pa.int32(), nullable=False),
+            ('value', pa.int64()),
+        ])
+        rows = pa.Table.from_pydict(
+            {'id': list(range(200)), 'value': list(range(200))},
+            schema=pa_schema,
+        )
+
+        identifier = self._make_table(
+            'test_one_file_per_bucket', pa_schema,
+            primary_keys=['id'], options={'bucket': '4'},
+        )
+
+        # Materialise 4 input blocks. Without auto-clustering each task
+        # would emit one file per bucket it touched (up to 16 files).
+        write_paimon(
+            ray.data.from_arrow(rows).repartition(4),
+            identifier, self.catalog_options,
+        )
+
+        files = self._count_data_files('test_one_file_per_bucket')
+        # 4 buckets × 1 file each.
+        self.assertEqual(len(files), 4)
+
+    def test_fixed_bucket_with_colliding_column_name(self):
+        """A table that has a column named ``__paimon_bucket__`` must
+        still work — the helper picks a collision-free transient
+        column name."""
+        from pypaimon.ray import write_paimon
+
+        pa_schema = pa.schema([
+            pa.field('id', pa.int32(), nullable=False),
+            ('__paimon_bucket__', pa.string()),
+        ])
+        table_name = 'test_fixed_bucket_collide_col'
+        identifier = self._make_table(
+            table_name, pa_schema,
+            primary_keys=['id'], options={'bucket': '2'},
+        )
+
+        rows = pa.Table.from_pydict(
+            {'id': list(range(10)),
+             '__paimon_bucket__': [f'v{i}' for i in range(10)]},
+            schema=pa_schema,
+        )
+        ds = ray.data.from_arrow(rows).repartition(2)
+        write_paimon(ds, identifier, self.catalog_options)
+
+        result = self._read_table(identifier)
+        self.assertEqual(len(result), 10)
+        # The user's own ``__paimon_bucket__`` column must survive.
+        self.assertEqual(set(result.columns), {'id', '__paimon_bucket__'})
+
+    # ----- non-HASH_FIXED passthrough -----
+
+    def test_non_fixed_bucket_roundtrip(self):
+        """BUCKET_UNAWARE tables are written without pre-clustering;
+        roundtrip data must still be correct."""
+        from pypaimon.ray import read_paimon, write_paimon
+
+        pa_schema = pa.schema([
+            ('id', pa.int32()),
+            ('value', pa.int64()),
+        ])
+        # bucket=-1 + no primary keys → BUCKET_UNAWARE
+        table_name = 'test_non_fixed_bucket_roundtrip'
+        identifier = self._make_table(
+            table_name, pa_schema, options={'bucket': '-1'},
+        )
+
+        rows = pa.Table.from_pydict(
+            {'id': list(range(10)), 'value': list(range(10))},
+            schema=pa_schema,
+        )
+        write_paimon(
+            ray.data.from_arrow(rows), identifier, self.catalog_options,
+        )
+
+        # Unlike the HASH_FIXED tests, this one reads back through the
+        # Ray read path (``read_paimon``).
+        result = read_paimon(identifier, self.catalog_options).to_pandas()
+        self.assertEqual(len(result), 10)
+        self.assertEqual(set(result['id']), set(range(10)))
+
+
+# Allow running this test module directly as a script.
+if __name__ == '__main__':
+    unittest.main()
diff --git a/paimon-python/pypaimon/tests/test_ray_shuffle_helper.py
b/paimon-python/pypaimon/tests/test_ray_shuffle_helper.py
new file mode 100644
index 0000000000..a849f2788e
--- /dev/null
+++ b/paimon-python/pypaimon/tests/test_ray_shuffle_helper.py
@@ -0,0 +1,210 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""Unit tests for the Ray pre-shuffle helper in pypaimon/ray/shuffle.py.
+
+These tests exercise the helper in isolation: the bucket-key UDF (with a
+stub extractor), the collision-safe column name picker, the
+large-type coercion, and the bucket-mode dispatch in
+``maybe_apply_repartition``. Ray-based end-to-end behaviour is covered
+in ``pypaimon/tests/ray_repartition_test.py``.
+"""
+
+import unittest
+from unittest.mock import MagicMock
+
+import pyarrow as pa
+
+from pypaimon.ray.shuffle import (BUCKET_KEY_COL, _coerce_large_string_types,
+ _make_bucket_udf, _pick_bucket_col_name,
+ maybe_apply_repartition)
+from pypaimon.table.bucket_mode import BucketMode
+
+
+class BucketUdfTest(unittest.TestCase):
+    """The bucket-key UDF appends a deterministic int32 column."""
+
+    def _make_extractor(self, buckets_per_row):
+        """Return a mock extractor whose
+        ``extract_partition_bucket_batch`` yields one empty partition
+        tuple per row plus the given bucket ids."""
+        extractor = MagicMock()
+        extractor.extract_partition_bucket_batch.return_value = (
+            [() for _ in buckets_per_row],
+            list(buckets_per_row),
+        )
+        return extractor
+
+    def test_appends_int32_bucket_column(self):
+        """The bucket column is appended last, typed int32, and carries
+        the extractor's values in row order."""
+        extractor = self._make_extractor([0, 1, 0])
+        udf = _make_bucket_udf(extractor, BUCKET_KEY_COL)
+        batch = pa.table({"id": [10, 11, 12]})
+
+        out = udf(batch)
+
+        self.assertEqual(out.column_names, ["id", BUCKET_KEY_COL])
+        self.assertEqual(out.schema.field(BUCKET_KEY_COL).type, pa.int32())
+        self.assertEqual(out.column(BUCKET_KEY_COL).to_pylist(), [0, 1, 0])
+
+    def test_empty_batch_appends_empty_column(self):
+        """A zero-row batch still gets the bucket column, without
+        calling the extractor."""
+        extractor = self._make_extractor([])
+        udf = _make_bucket_udf(extractor, BUCKET_KEY_COL)
+        batch = pa.table({"id": pa.array([], type=pa.int32())})
+
+        out = udf(batch)
+
+        self.assertEqual(out.num_rows, 0)
+        self.assertEqual(out.column_names, ["id", BUCKET_KEY_COL])
+        # The extractor is short-circuited on empty input — we don't pay
+        # the cost of combining empty chunks just to call into it.
+        extractor.extract_partition_bucket_batch.assert_not_called()
+
+    def test_multichunk_batch_combines_before_extracting(self):
+        """A table made of several record batches is passed to the
+        extractor as one batch containing every row."""
+        # Two record batches in the same table — the UDF must combine
+        # before calling the extractor, otherwise the extractor sees
+        # half the rows.
+        extractor = self._make_extractor([0, 1, 2, 3])
+        udf = _make_bucket_udf(extractor, BUCKET_KEY_COL)
+        rb1 = pa.record_batch({"id": [1, 2]})
+        rb2 = pa.record_batch({"id": [3, 4]})
+        batch = pa.Table.from_batches([rb1, rb2])
+
+        out = udf(batch)
+
+        self.assertEqual(out.num_rows, 4)
+        self.assertEqual(out.column(BUCKET_KEY_COL).to_pylist(), [0, 1, 2, 3])
+        # The batch handed to the extractor must contain all four rows.
+        call = extractor.extract_partition_bucket_batch.call_args
+        passed_batch = call.args[0]
+        self.assertEqual(passed_batch.num_rows, 4)
+
+
+class PickBucketColNameTest(unittest.TestCase):
+    """``_pick_bucket_col_name`` avoids collision with user columns."""
+
+    def test_default_name_when_no_collision(self):
+        """The default name is used when no existing column has it."""
+        self.assertEqual(
+            _pick_bucket_col_name({"id", "name"}), BUCKET_KEY_COL)
+
+    def test_fallback_when_default_collides(self):
+        """When the default is taken, a different name with the
+        ``__paimon_bucket_`` prefix is returned."""
+        name = _pick_bucket_col_name({"id", BUCKET_KEY_COL})
+        self.assertNotEqual(name, BUCKET_KEY_COL)
+        self.assertTrue(name.startswith("__paimon_bucket_"))
+        self.assertNotIn(name, {"id", BUCKET_KEY_COL})
+
+
+class CoerceLargeStringTypesTest(unittest.TestCase):
+    """``_coerce_large_string_types`` (the helper behind
+    ``_identity_batch``) casts back the large_string / large_binary
+    types that some Ray versions introduce when materialising blocks
+    during ``groupby().map_groups``. The Paimon writer's strict schema
+    check would otherwise reject those rows."""
+
+    def test_pass_through_when_no_large_variants(self):
+        """A batch without large types keeps its schema."""
+        batch = pa.table({"id": pa.array([1, 2], type=pa.int32()),
+                          "name": pa.array(["a", "b"], type=pa.string())})
+        out = _coerce_large_string_types(batch)
+        self.assertEqual(out.schema, batch.schema)
+
+    def test_casts_large_string_back_to_string(self):
+        """``large_string`` becomes ``string`` and values are kept."""
+        batch = pa.table({
+            "id": pa.array([1, 2], type=pa.int32()),
+            "name": pa.array(["x", "y"], type=pa.large_string()),
+        })
+        out = _coerce_large_string_types(batch)
+        self.assertEqual(out.schema.field("name").type, pa.string())
+        self.assertEqual(out.column("name").to_pylist(), ["x", "y"])
+
+    def test_casts_large_binary_back_to_binary(self):
+        """``large_binary`` becomes ``binary``."""
+        batch = pa.table({
+            "blob": pa.array([b"x", b"y"], type=pa.large_binary()),
+        })
+        out = _coerce_large_string_types(batch)
+        self.assertEqual(out.schema.field("blob").type, pa.binary())
+
+
+class BucketModeDispatchTest(unittest.TestCase):
+    """``maybe_apply_repartition`` clusters HASH_FIXED tables and
+    returns other bucket modes unchanged."""
+
+    def _make_table(self, bucket_mode):
+        """Return a mock table reporting the given bucket mode."""
+        table = MagicMock()
+        table.bucket_mode.return_value = bucket_mode
+        return table
+
+    def test_bucket_unaware_returns_dataset_unchanged(self):
+        """BUCKET_UNAWARE: the very same dataset object is returned."""
+        dataset = object()  # sentinel; must not be wrapped or mutated
+        table = self._make_table(BucketMode.BUCKET_UNAWARE)
+
+        self.assertIs(maybe_apply_repartition(dataset, table), dataset)
+
+    def test_hash_dynamic_returns_dataset_unchanged(self):
+        """HASH_DYNAMIC: the very same dataset object is returned."""
+        dataset = object()
+        table = self._make_table(BucketMode.HASH_DYNAMIC)
+
+        self.assertIs(maybe_apply_repartition(dataset, table), dataset)
+
+    def test_cross_partition_returns_dataset_unchanged(self):
+        """CROSS_PARTITION: the very same dataset object is returned."""
+        dataset = object()
+        table = self._make_table(BucketMode.CROSS_PARTITION)
+
+        self.assertIs(maybe_apply_repartition(dataset, table), dataset)
+
+    def test_hash_fixed_runs_map_batches_groupby_chain(self):
+        """HASH_FIXED: the result comes from the map_batches -> groupby
+        -> map_groups -> drop_columns chain."""
+        dataset = MagicMock(name="dataset")
+        # Wire the end of the mocked call chain to a recognisable value.
+        dataset.map_batches.return_value.groupby.return_value \
+            .map_groups.return_value.drop_columns.return_value = "clustered"
+        table = MagicMock()
+        table.bucket_mode.return_value = BucketMode.HASH_FIXED
+        table.table_schema.partition_keys = []
+        # Minimal stand-ins for schema fields: only ``name`` is read.
+        table.table_schema.fields = [
+            type("F", (), {"name": "id"})(),
+            type("F", (), {"name": "value"})(),
+        ]
+
+        out = maybe_apply_repartition(dataset, table)
+
+        self.assertEqual(out, "clustered")
+        # The helper appends a transient bucket column, groups by it,
+        # runs the identity batch over each group, then drops the
+        # transient column. We assert the call chain, not its kwargs,
+        # since defaults are an implementation detail.
+        dataset.map_batches.assert_called_once()
+        dataset.map_batches.return_value.groupby.assert_called_once()
+        dataset.map_batches.return_value.groupby.return_value \
+            .map_groups.assert_called_once()
+        dataset.map_batches.return_value.groupby.return_value \
+            .map_groups.return_value.drop_columns.assert_called_once_with(
+                [BUCKET_KEY_COL]
+            )
+
+    def test_hash_fixed_groups_include_partition_keys(self):
+        """The groupby keys are the partition keys followed by the
+        bucket column, passed as the first positional argument."""
+        dataset = MagicMock(name="dataset")
+        table = MagicMock()
+        table.bucket_mode.return_value = BucketMode.HASH_FIXED
+        table.table_schema.partition_keys = ["dt"]
+        table.table_schema.fields = [
+            type("F", (), {"name": "id"})(),
+            type("F", (), {"name": "dt"})(),
+        ]
+
+        maybe_apply_repartition(dataset, table)
+
+        group_call = dataset.map_batches.return_value.groupby.call_args
+        passed_keys = group_call.args[0]
+        self.assertEqual(passed_keys, ["dt", BUCKET_KEY_COL])
+
+
+# Allow running this test module directly as a script.
+if __name__ == "__main__":
+    unittest.main()