This is an automated email from the ASF dual-hosted git repository.

JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 021cd92b59 [python] Pre-repartition Ray writes by (partition, bucket) 
for fixed-bucket tables (#7813)
021cd92b59 is described below

commit 021cd92b5927e9cc86783f8f68ff9b03ff586ff5
Author: chaoyang <[email protected]>
AuthorDate: Sat May 16 00:28:00 2026 +0800

    [python] Pre-repartition Ray writes by (partition, bucket) for fixed-bucket 
tables (#7813)
    
    When `write_paimon` is given a Ray Dataset, Ray's default round-robin
    block distribution scatters rows that share the same `(partition,
    bucket)` across many Ray tasks. Each task opens its own writer and emits
    its own data file, so the write produces
    `partitions × buckets × ray_tasks` files instead of the
    `partitions × buckets` the writer would naturally produce.
    
    Spark and Flink already cluster rows by `(partition, bucket)` before
    writing — see `PaimonSparkWriter.repartitionByPartitionsAndBucket` and
    the `RowAssignerChannelComputer` / `RowWithBucketChannelComputer` chain.
    This PR brings the same pre-clustering to the Ray path.
---
 docs/content/pypaimon/ray-data.md                  |  14 ++
 paimon-python/pypaimon/ray/ray_paimon.py           |   9 +
 paimon-python/pypaimon/ray/shuffle.py              | 136 +++++++++++
 .../pypaimon/tests/ray_repartition_test.py         | 263 +++++++++++++++++++++
 .../pypaimon/tests/test_ray_shuffle_helper.py      | 210 ++++++++++++++++
 5 files changed, 632 insertions(+)

diff --git a/docs/content/pypaimon/ray-data.md 
b/docs/content/pypaimon/ray-data.md
index b0ea958849..209c7c67a9 100644
--- a/docs/content/pypaimon/ray-data.md
+++ b/docs/content/pypaimon/ray-data.md
@@ -207,6 +207,20 @@ write_paimon(
 )
 ```
 
+**Automatic (partition, bucket) clustering for HASH_FIXED tables:**
+
+For HASH_FIXED tables, `write_paimon` automatically clusters rows by
+`(partition_keys..., bucket)` before writing so each (partition,
+bucket) lands in a single Ray task — one writer, one file group. This
+avoids the small-file storm that Ray's default round-robin
+distribution would otherwise produce (`partitions × buckets ×
+ray_tasks` files instead of `partitions × buckets`).
+
+Bucket assignment uses the same hash routine as the writer, so the
+bucket used for grouping is identical to the one the writer computes
+for each row. No user configuration is required. Tables in any other
+bucket mode are written as-is, without pre-clustering.
+
 **Parameters:**
 - `dataset`: the Ray Dataset to write.
 - `table_identifier`: full table name, e.g. `"db_name.table_name"`.
diff --git a/paimon-python/pypaimon/ray/ray_paimon.py 
b/paimon-python/pypaimon/ray/ray_paimon.py
index bd81949394..86505097d8 100644
--- a/paimon-python/pypaimon/ray/ray_paimon.py
+++ b/paimon-python/pypaimon/ray/ray_paimon.py
@@ -117,6 +117,12 @@ def write_paimon(
 ) -> None:
     """Write a Ray Dataset to a Paimon table.
 
+    For HASH_FIXED tables, rows are automatically clustered by
+    ``(partition_keys..., bucket)`` before writing so that each
+    (partition, bucket) lands in a single Ray task. This avoids the
+    small-file storm that Ray's default round-robin distribution would
+    otherwise produce. No user configuration is required.
+
     Args:
         dataset: The Ray Dataset to write.
         table_identifier: Full table name, e.g. ``"db_name.table_name"``.
@@ -126,11 +132,14 @@ def write_paimon(
         ray_remote_args: Optional kwargs passed to ``ray.remote`` in write 
tasks.
     """
     from pypaimon.catalog.catalog_factory import CatalogFactory
+    from pypaimon.ray.shuffle import maybe_apply_repartition
     from pypaimon.write.ray_datasink import PaimonDatasink
 
     catalog = CatalogFactory.create(catalog_options)
     table = catalog.get_table(table_identifier)
 
+    dataset = maybe_apply_repartition(dataset, table)
+
     datasink = PaimonDatasink(table, overwrite=overwrite)
 
     write_kwargs = {}
diff --git a/paimon-python/pypaimon/ray/shuffle.py 
b/paimon-python/pypaimon/ray/shuffle.py
new file mode 100644
index 0000000000..b17f7a7ab1
--- /dev/null
+++ b/paimon-python/pypaimon/ray/shuffle.py
@@ -0,0 +1,136 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""Pre-repartition a Ray Dataset by (partition, bucket) before writing
+to a Paimon table.
+
+Without this, Ray's default round-robin block distribution scatters rows
+that share the same (partition, bucket) across many Ray tasks. Each
+task then opens its own writer and emits its own data file, producing
+``partitions × buckets × ray_tasks`` files instead of the
+``partitions × buckets`` the writer would naturally produce.
+
+For HASH_FIXED tables we group rows by ``(partition_keys..., bucket)``
+so every distinct group lands in a single Ray task. ``bucket`` is
+computed with the same ``FixedBucketRowKeyExtractor`` the writer
+uses, so the groupby sees exactly the bucket the writer will assign
+to each row. HASH_FIXED writes are always pre-clustered; no user
+opt-in is required.
+
+For any other bucket mode the dataset is returned unchanged.
+"""
+
+import uuid
+from typing import TYPE_CHECKING, List
+
+import pyarrow as pa
+
+from pypaimon.table.bucket_mode import BucketMode
+
+if TYPE_CHECKING:
+    import ray.data
+
+    from pypaimon.table.table import Table
+
+# Preferred name for the transient bucket column. When a user table
+# already owns a column with this name, ``_pick_bucket_col_name``
+# derives a randomised alternative so the write still works.
+BUCKET_KEY_COL = "__paimon_bucket__"
+
+
+def _pick_bucket_col_name(existing_names) -> str:
+    """Choose a transient bucket column name absent from ``existing_names``.
+
+    ``BUCKET_KEY_COL`` is used whenever it is free; otherwise names with
+    a random 8-hex-digit UUID suffix are drawn until one does not
+    collide.
+    """
+    candidate = BUCKET_KEY_COL
+    while candidate in existing_names:
+        candidate = "__paimon_bucket_{}_".format(uuid.uuid4().hex[:8])
+    return candidate
+
+
+def maybe_apply_repartition(
+        dataset: "ray.data.Dataset",
+        table: "Table",
+) -> "ray.data.Dataset":
+    """Group a dataset by ``(partition_keys..., bucket)`` ahead of a write.
+
+    Only HASH_FIXED tables are clustered, and no user opt-in is needed
+    for that. For every other bucket mode ``dataset`` is handed back
+    untouched.
+
+    Args:
+        dataset: The Ray Dataset about to be written.
+        table: The destination Paimon table.
+
+    Returns:
+        For HASH_FIXED tables, a dataset produced by ``map_groups`` over
+        ``(partition_keys..., bucket)`` with the transient bucket column
+        dropped again; otherwise ``dataset`` itself.
+    """
+    if table.bucket_mode() == BucketMode.HASH_FIXED:
+        schema = table.table_schema
+        bucket_col = _pick_bucket_col_name(
+            {field.name for field in schema.fields})
+        # The extractor is captured by the UDF closure, so every batch
+        # is bucketed by the same extractor instance's logic.
+        tagger = _make_bucket_udf(
+            table.create_row_key_extractor(), bucket_col)
+        tagged = dataset.map_batches(
+            tagger, batch_format="pyarrow", zero_copy_batch=True,
+        )
+        group_keys: List[str] = [*(schema.partition_keys or []), bucket_col]
+        clustered = tagged.groupby(group_keys).map_groups(
+            _identity_batch, batch_format="pyarrow",
+        )
+        return clustered.drop_columns([bucket_col])
+    return dataset
+
+
+def _identity_batch(batch: pa.Table) -> pa.Table:
+    """Per-group UDF for ``map_groups``: pass the group's rows through.
+
+    The only transformation is type narrowing. Some Ray versions widen
+    ``string`` to ``large_string`` (and ``binary`` to ``large_binary``)
+    while materialising blocks for ``groupby().map_groups``, and
+    Paimon's writer compares schemas with a strict ``!=``, so those
+    columns are cast back to the regular types. Every other Arrow type
+    passes through.
+    """
+    normalised = _coerce_large_string_types(batch)
+    return normalised
+
+
+def _coerce_large_string_types(batch: pa.Table) -> pa.Table:
+    """Cast ``large_string`` / ``large_binary`` columns back to
+    ``string`` / ``binary``.
+
+    The narrowing is applied recursively inside list, large_list,
+    fixed_size_list and struct types, so e.g. ``list<large_string>``
+    becomes ``list<string>``. Container kinds, child field names,
+    nullability, field metadata and schema-level metadata are
+    preserved. Map and dictionary types pass through unchanged.
+
+    Args:
+        batch: Arrow table produced for one ``map_groups`` group.
+
+    Returns:
+        ``batch`` itself when no column needs narrowing, otherwise a
+        cast copy carrying the narrowed schema.
+    """
+    fields = [_narrow_field(field) for field in batch.schema]
+    if all(new.type.equals(old.type)
+           for new, old in zip(fields, batch.schema)):
+        return batch
+    # pa.schema(fields) alone would drop the schema-level metadata
+    # (e.g. pandas metadata), so carry it over explicitly.
+    return batch.cast(pa.schema(fields, metadata=batch.schema.metadata))
+
+
+def _narrow_field(field: pa.Field) -> pa.Field:
+    """Return ``field`` with large string/binary types narrowed."""
+    return field.with_type(_narrow_type(field.type))
+
+
+def _narrow_type(data_type: pa.DataType) -> pa.DataType:
+    """Map large string/binary types, also when nested, to regular ones."""
+    if pa.types.is_large_string(data_type):
+        return pa.string()
+    if pa.types.is_large_binary(data_type):
+        return pa.binary()
+    if pa.types.is_list(data_type):
+        return pa.list_(_narrow_field(data_type.value_field))
+    if pa.types.is_large_list(data_type):
+        return pa.large_list(_narrow_field(data_type.value_field))
+    if pa.types.is_fixed_size_list(data_type):
+        return pa.list_(_narrow_field(data_type.value_field),
+                        data_type.list_size)
+    if pa.types.is_struct(data_type):
+        return pa.struct([_narrow_field(data_type.field(i))
+                          for i in range(data_type.num_fields)])
+    return data_type
+
+
+def _make_bucket_udf(extractor, bucket_col):
+    """Create the ``map_batches`` UDF that tags each row with its bucket.
+
+    The returned callable appends an ``int32`` column named
+    ``bucket_col``. Bucket values come from
+    ``extractor.extract_partition_bucket_batch`` so they match the
+    writer's bucket assignment for the same row.
+
+    Args:
+        extractor: Row-key extractor exposing
+            ``extract_partition_bucket_batch``.
+        bucket_col: Name of the column to append.
+    """
+    def _udf(batch: pa.Table) -> pa.Table:
+        if batch.num_rows:
+            # The extractor takes one RecordBatch, so merge every chunk
+            # first; otherwise it would only see part of the rows.
+            merged = batch.combine_chunks().to_batches()[0]
+            _, buckets = extractor.extract_partition_bucket_batch(merged)
+        else:
+            # Empty input: skip the extractor and append an empty column.
+            buckets = []
+        return batch.append_column(
+            bucket_col, pa.array(buckets, type=pa.int32())
+        )
+
+    return _udf
diff --git a/paimon-python/pypaimon/tests/ray_repartition_test.py 
b/paimon-python/pypaimon/tests/ray_repartition_test.py
new file mode 100644
index 0000000000..b66b014b4f
--- /dev/null
+++ b/paimon-python/pypaimon/tests/ray_repartition_test.py
@@ -0,0 +1,263 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""End-to-end tests for HASH_FIXED auto-clustering on ``write_paimon``.
+
+For HASH_FIXED tables, ``write_paimon`` automatically pre-clusters rows
+by ``(partition_keys..., bucket)`` (matching Spark/Flink). These tests
+cover:
+
+  * roundtrip correctness on a HASH_FIXED PK table.
+  * roundtrip correctness on a partitioned HASH_FIXED PK table.
+  * the transient bucket column is stripped from the sink-visible
+    schema.
+  * the output is one file per (partition, bucket) — i.e. the
+    small-file storm is eliminated.
+  * regression: a table whose schema already contains a column named
+    ``__paimon_bucket__`` still works (collision-safe column name).
+  * non-HASH_FIXED tables (BUCKET_UNAWARE etc.) pass through unchanged.
+"""
+
+import glob
+import os
+import shutil
+import tempfile
+import unittest
+
+import pyarrow as pa
+import ray
+
+from pypaimon import CatalogFactory, Schema
+
+
+class RayShuffleTest(unittest.TestCase):
+    """End-to-end checks of HASH_FIXED auto-clustering in ``write_paimon``."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tempdir = tempfile.mkdtemp()
+        cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
+        cls.catalog_options = {'warehouse': cls.warehouse}
+
+        catalog = CatalogFactory.create(cls.catalog_options)
+        catalog.create_database('default', True)
+
+        # Only shut Ray down in tearDownClass if this class started it;
+        # a cluster initialised by another test module or the
+        # surrounding session must survive this class finishing.
+        cls._owns_ray = not ray.is_initialized()
+        if cls._owns_ray:
+            # 4 CPUs gives us enough room to actually fan a multi-block
+            # write across multiple workers so the "small-file" claim
+            # is observable.
+            ray.init(ignore_reinit_error=True, num_cpus=4)
+
+    @classmethod
+    def tearDownClass(cls):
+        # Cleanup is best-effort: a failure here must not mask the
+        # outcome of the tests themselves.
+        try:
+            if getattr(cls, '_owns_ray', False) and ray.is_initialized():
+                ray.shutdown()
+        except Exception:
+            pass
+        shutil.rmtree(cls.tempdir, ignore_errors=True)
+
+    def _make_table(self, table_name, pa_schema, *, primary_keys=None,
+                    partition_keys=None, options=None):
+        """Create ``default.<table_name>`` and return its identifier."""
+        identifier = 'default.{}'.format(table_name)
+        schema = Schema.from_pyarrow_schema(
+            pa_schema,
+            primary_keys=primary_keys,
+            partition_keys=partition_keys,
+            options=options,
+        )
+        catalog = CatalogFactory.create(self.catalog_options)
+        catalog.create_table(identifier, schema, False)
+        return identifier
+
+    def _read_table(self, identifier):
+        """Read table data via the direct API (not ``read_paimon``).
+
+        This avoids going through ``RayDatasource._get_read_task`` which
+        has a pre-existing strict nullability check (``from_batches``
+        with Paimon schema) that rejects batches where the reader drops
+        ``not null`` (a raw-convertible PK split issue). Shuffle tests
+        care about *write* correctness, not the Ray read path.
+        """
+        catalog = CatalogFactory.create(self.catalog_options)
+        table = catalog.get_table(identifier)
+        rb = table.new_read_builder()
+        splits = rb.new_scan().plan().splits()
+        arrow = rb.new_read().to_arrow(splits)
+        if arrow is None:
+            return pa.table({}).to_pandas()
+        return arrow.to_pandas()
+
+    def _list_data_files(self, table_name):
+        """Return the paths of all data files under the table directory,
+        regardless of partition."""
+        root = os.path.join(self.warehouse, 'default.db', table_name)
+        files = []
+        for pattern in ('*.parquet', '*.orc', '*.avro'):
+            files.extend(glob.glob(
+                os.path.join(root, '**', 'bucket-*', pattern), recursive=True,
+            ))
+        return files
+
+    # ----- HASH_FIXED auto-clustering -----
+
+    def test_fixed_bucket_roundtrip(self):
+        from pypaimon.ray import write_paimon
+
+        pa_schema = pa.schema([
+            pa.field('id', pa.int32(), nullable=False),
+            ('name', pa.string()),
+        ])
+        table_name = 'test_fixed_bucket_roundtrip'
+        identifier = self._make_table(
+            table_name, pa_schema,
+            primary_keys=['id'], options={'bucket': '4'},
+        )
+
+        rows = pa.Table.from_pydict(
+            {'id': list(range(40)), 'name': [f'v{i}' for i in range(40)]},
+            schema=pa_schema,
+        )
+        ds = ray.data.from_arrow(rows).repartition(4)
+        write_paimon(ds, identifier, self.catalog_options)
+
+        result = self._read_table(identifier)
+        self.assertEqual(len(result), 40)
+        self.assertEqual(set(result['id']), set(range(40)))
+        self.assertNotIn('__paimon_bucket__', result.columns)
+
+    def test_partitioned_fixed_bucket_roundtrip(self):
+        """Partitioned table — confirms the post-groupby schema does not
+        end up with duplicated partition-key or bucket columns."""
+        from pypaimon.ray import write_paimon
+
+        pa_schema = pa.schema([
+            pa.field('id', pa.int32(), nullable=False),
+            ('dt', pa.string()),
+            ('value', pa.int64()),
+        ])
+        table_name = 'test_partitioned_fixed_bucket_roundtrip'
+        identifier = self._make_table(
+            table_name, pa_schema,
+            primary_keys=['id', 'dt'], partition_keys=['dt'],
+            options={'bucket': '4'},
+        )
+
+        rows = pa.Table.from_pydict({
+            'id': list(range(20)),
+            'dt': ['2026-01-01'] * 10 + ['2026-01-02'] * 10,
+            'value': list(range(20)),
+        }, schema=pa_schema)
+        ds = ray.data.from_arrow(rows).repartition(4)
+        write_paimon(ds, identifier, self.catalog_options)
+
+        result = self._read_table(identifier)
+        self.assertEqual(set(result.columns), {'id', 'dt', 'value'})
+        self.assertEqual(len(result), 20)
+        self.assertEqual(set(result['dt']), {'2026-01-01', '2026-01-02'})
+
+    def test_fixed_bucket_writes_one_file_per_bucket(self):
+        """With multiple input blocks, auto-clustering collapses per-task
+        files into per-bucket files."""
+        from pypaimon.ray import write_paimon
+
+        pa_schema = pa.schema([
+            pa.field('id', pa.int32(), nullable=False),
+            ('value', pa.int64()),
+        ])
+        rows = pa.Table.from_pydict(
+            {'id': list(range(200)), 'value': list(range(200))},
+            schema=pa_schema,
+        )
+
+        identifier = self._make_table(
+            'test_one_file_per_bucket', pa_schema,
+            primary_keys=['id'], options={'bucket': '4'},
+        )
+
+        # Materialise 4 input blocks. Without auto-clustering each task
+        # would emit one file per bucket it touched (up to 16 files).
+        write_paimon(
+            ray.data.from_arrow(rows).repartition(4),
+            identifier, self.catalog_options,
+        )
+
+        files = self._list_data_files('test_one_file_per_bucket')
+        # 4 buckets × 1 file each.
+        self.assertEqual(len(files), 4)
+
+    def test_fixed_bucket_with_colliding_column_name(self):
+        """A table that has a column named ``__paimon_bucket__`` must
+        still work — the helper picks a collision-free transient
+        column name."""
+        from pypaimon.ray import write_paimon
+
+        pa_schema = pa.schema([
+            pa.field('id', pa.int32(), nullable=False),
+            ('__paimon_bucket__', pa.string()),
+        ])
+        table_name = 'test_fixed_bucket_collide_col'
+        identifier = self._make_table(
+            table_name, pa_schema,
+            primary_keys=['id'], options={'bucket': '2'},
+        )
+
+        rows = pa.Table.from_pydict(
+            {'id': list(range(10)),
+             '__paimon_bucket__': [f'v{i}' for i in range(10)]},
+            schema=pa_schema,
+        )
+        ds = ray.data.from_arrow(rows).repartition(2)
+        write_paimon(ds, identifier, self.catalog_options)
+
+        result = self._read_table(identifier)
+        self.assertEqual(len(result), 10)
+        self.assertEqual(set(result.columns), {'id', '__paimon_bucket__'})
+
+    # ----- non-HASH_FIXED passthrough -----
+
+    def test_non_fixed_bucket_roundtrip(self):
+        """BUCKET_UNAWARE tables are written without pre-clustering;
+        roundtrip data must still be correct."""
+        from pypaimon.ray import read_paimon, write_paimon
+
+        pa_schema = pa.schema([
+            ('id', pa.int32()),
+            ('value', pa.int64()),
+        ])
+        # bucket=-1 + no primary keys → BUCKET_UNAWARE
+        table_name = 'test_non_fixed_bucket_roundtrip'
+        identifier = self._make_table(
+            table_name, pa_schema, options={'bucket': '-1'},
+        )
+
+        rows = pa.Table.from_pydict(
+            {'id': list(range(10)), 'value': list(range(10))},
+            schema=pa_schema,
+        )
+        write_paimon(
+            ray.data.from_arrow(rows), identifier, self.catalog_options,
+        )
+
+        result = read_paimon(identifier, self.catalog_options).to_pandas()
+        self.assertEqual(len(result), 10)
+        self.assertEqual(set(result['id']), set(range(10)))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/paimon-python/pypaimon/tests/test_ray_shuffle_helper.py 
b/paimon-python/pypaimon/tests/test_ray_shuffle_helper.py
new file mode 100644
index 0000000000..a849f2788e
--- /dev/null
+++ b/paimon-python/pypaimon/tests/test_ray_shuffle_helper.py
@@ -0,0 +1,210 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""Unit tests for the Ray pre-shuffle helper in pypaimon/ray/shuffle.py.
+
+These tests exercise the helper in isolation: the bucket-key UDF (with a
+stub extractor), the collision-safe column name picker, the
+large-type coercion, and the bucket-mode dispatch in
+``maybe_apply_repartition``. Ray-based end-to-end behaviour is covered
+in ``pypaimon/tests/ray_repartition_test.py``.
+"""
+
+import unittest
+from unittest.mock import MagicMock
+
+import pyarrow as pa
+
+from pypaimon.ray.shuffle import (BUCKET_KEY_COL, _coerce_large_string_types,
+                                  _make_bucket_udf, _pick_bucket_col_name,
+                                  maybe_apply_repartition)
+from pypaimon.table.bucket_mode import BucketMode
+
+
+class BucketUdfTest(unittest.TestCase):
+    """The bucket-tagging UDF appends a deterministic int32 column."""
+
+    def _make_extractor(self, buckets_per_row):
+        # Stub whose extract_partition_bucket_batch returns one empty
+        # partition tuple per row together with the given bucket ids.
+        stub = MagicMock()
+        stub.extract_partition_bucket_batch.return_value = (
+            [()] * len(buckets_per_row),
+            list(buckets_per_row),
+        )
+        return stub
+
+    def test_appends_int32_bucket_column(self):
+        extractor = self._make_extractor([0, 1, 0])
+        udf = _make_bucket_udf(extractor, BUCKET_KEY_COL)
+
+        out = udf(pa.table({"id": [10, 11, 12]}))
+
+        self.assertEqual(out.column_names, ["id", BUCKET_KEY_COL])
+        self.assertEqual(out.schema.field(BUCKET_KEY_COL).type, pa.int32())
+        self.assertEqual(out.column(BUCKET_KEY_COL).to_pylist(), [0, 1, 0])
+
+    def test_empty_batch_appends_empty_column(self):
+        extractor = self._make_extractor([])
+        udf = _make_bucket_udf(extractor, BUCKET_KEY_COL)
+
+        out = udf(pa.table({"id": pa.array([], type=pa.int32())}))
+
+        self.assertEqual(out.num_rows, 0)
+        self.assertEqual(out.column_names, ["id", BUCKET_KEY_COL])
+        # Empty input short-circuits: no chunk combining and no call
+        # into the extractor.
+        extractor.extract_partition_bucket_batch.assert_not_called()
+
+    def test_multichunk_batch_combines_before_extracting(self):
+        # A two-chunk table: the UDF has to combine chunks before calling
+        # the extractor, or the extractor would only see half the rows.
+        extractor = self._make_extractor([0, 1, 2, 3])
+        udf = _make_bucket_udf(extractor, BUCKET_KEY_COL)
+        batch = pa.Table.from_batches([
+            pa.record_batch({"id": [1, 2]}),
+            pa.record_batch({"id": [3, 4]}),
+        ])
+
+        out = udf(batch)
+
+        self.assertEqual(out.num_rows, 4)
+        self.assertEqual(out.column(BUCKET_KEY_COL).to_pylist(), [0, 1, 2, 3])
+        # Extractor is called exactly once, with all four rows.
+        extract = extractor.extract_partition_bucket_batch
+        extract.assert_called_once()
+        self.assertEqual(extract.call_args.args[0].num_rows, 4)
+
+
+class PickBucketColNameTest(unittest.TestCase):
+    """The transient column name never shadows an existing column."""
+
+    def test_default_name_when_no_collision(self):
+        chosen = _pick_bucket_col_name({"id", "name"})
+        self.assertEqual(chosen, BUCKET_KEY_COL)
+
+    def test_fallback_when_default_collides(self):
+        taken = {"id", BUCKET_KEY_COL}
+        chosen = _pick_bucket_col_name(taken)
+        self.assertNotEqual(chosen, BUCKET_KEY_COL)
+        self.assertTrue(chosen.startswith("__paimon_bucket_"))
+        self.assertNotIn(chosen, taken)
+
+
+class CoerceLargeStringTypesTest(unittest.TestCase):
+    """``_coerce_large_string_types`` (applied by ``_identity_batch``)
+    narrows the large_string / large_binary types that some Ray versions
+    introduce while materialising blocks for ``groupby().map_groups``.
+    Without it the Paimon writer's strict schema check would reject
+    those rows."""
+
+    def test_pass_through_when_no_large_variants(self):
+        batch = pa.table({
+            "id": pa.array([1, 2], type=pa.int32()),
+            "name": pa.array(["a", "b"], type=pa.string()),
+        })
+        self.assertEqual(_coerce_large_string_types(batch).schema,
+                         batch.schema)
+
+    def test_casts_large_string_back_to_string(self):
+        out = _coerce_large_string_types(pa.table({
+            "id": pa.array([1, 2], type=pa.int32()),
+            "name": pa.array(["x", "y"], type=pa.large_string()),
+        }))
+        self.assertEqual(out.schema.field("name").type, pa.string())
+        self.assertEqual(out.column("name").to_pylist(), ["x", "y"])
+
+    def test_casts_large_binary_back_to_binary(self):
+        out = _coerce_large_string_types(pa.table({
+            "blob": pa.array([b"x", b"y"], type=pa.large_binary()),
+        }))
+        self.assertEqual(out.schema.field("blob").type, pa.binary())
+
+
+class BucketModeDispatchTest(unittest.TestCase):
+    """Dispatch in ``maybe_apply_repartition``: HASH_FIXED tables are
+    clustered, every other bucket mode is passed through untouched."""
+
+    def _make_table(self, bucket_mode):
+        table = MagicMock()
+        table.bucket_mode.return_value = bucket_mode
+        return table
+
+    @staticmethod
+    def _field(name):
+        # Minimal stand-in for a schema field: only ``.name`` is read.
+        return type("F", (), {"name": name})()
+
+    def _fixed_table(self, partition_keys, field_names):
+        table = self._make_table(BucketMode.HASH_FIXED)
+        table.table_schema.partition_keys = partition_keys
+        table.table_schema.fields = [self._field(n) for n in field_names]
+        return table
+
+    def _assert_passthrough(self, bucket_mode):
+        sentinel = object()  # must come back unwrapped and unmutated
+        table = self._make_table(bucket_mode)
+        self.assertIs(maybe_apply_repartition(sentinel, table), sentinel)
+
+    def test_bucket_unaware_returns_dataset_unchanged(self):
+        self._assert_passthrough(BucketMode.BUCKET_UNAWARE)
+
+    def test_hash_dynamic_returns_dataset_unchanged(self):
+        self._assert_passthrough(BucketMode.HASH_DYNAMIC)
+
+    def test_cross_partition_returns_dataset_unchanged(self):
+        self._assert_passthrough(BucketMode.CROSS_PARTITION)
+
+    def test_hash_fixed_runs_map_batches_groupby_chain(self):
+        dataset = MagicMock(name="dataset")
+        mapped = dataset.map_batches.return_value
+        grouped = mapped.groupby.return_value
+        regrouped = grouped.map_groups.return_value
+        regrouped.drop_columns.return_value = "clustered"
+        table = self._fixed_table([], ["id", "value"])
+
+        out = maybe_apply_repartition(dataset, table)
+
+        self.assertEqual(out, "clustered")
+        # The helper tags rows with a transient bucket column, groups on
+        # it, runs the identity UDF per group, then drops the column
+        # again. Only the call chain is checked; kwargs are an
+        # implementation detail.
+        dataset.map_batches.assert_called_once()
+        mapped.groupby.assert_called_once()
+        grouped.map_groups.assert_called_once()
+        regrouped.drop_columns.assert_called_once_with([BUCKET_KEY_COL])
+
+    def test_hash_fixed_groups_include_partition_keys(self):
+        dataset = MagicMock(name="dataset")
+        table = self._fixed_table(["dt"], ["id", "dt"])
+
+        maybe_apply_repartition(dataset, table)
+
+        groupby_call = dataset.map_batches.return_value.groupby.call_args
+        self.assertEqual(groupby_call.args[0], ["dt", BUCKET_KEY_COL])
+
+
+if __name__ == "__main__":
+    unittest.main()

Reply via email to