This is an automated email from the ASF dual-hosted git repository.
JingsongLi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/paimon.git
The following commit(s) were added to refs/heads/master by this push:
new 94f5497780 [python] Fix Daft Paimon write column alignment (#7947)
94f5497780 is described below
commit 94f5497780f6be434ed305a8881cc520d93ebdcf
Author: QuakeWang <[email protected]>
AuthorDate: Mon May 25 19:01:51 2026 +0800
[python] Fix Daft Paimon write column alignment (#7947)
- Align Daft write batches to the Paimon target schema by column name
before casting.
- Reject missing, extra, or duplicate fields instead of silently writing
wrong data.
- Add regression coverage with native pypaimon reader verification.
---
paimon-python/pypaimon/daft/daft_datasink.py | 51 ++++++++----
.../pypaimon/tests/daft/daft_sink_test.py | 91 ++++++++++++++++++++++
2 files changed, 129 insertions(+), 13 deletions(-)
diff --git a/paimon-python/pypaimon/daft/daft_datasink.py
b/paimon-python/pypaimon/daft/daft_datasink.py
index 1f37cc4165..c019b16a31 100644
--- a/paimon-python/pypaimon/daft/daft_datasink.py
+++ b/paimon-python/pypaimon/daft/daft_datasink.py
@@ -68,28 +68,53 @@ class PaimonDataSink(DataSink[list[Any]]):
]
)
+    def _validate_input_schema(self, input_schema: pa.Schema) -> None:
+        """Check that ``input_schema`` can be aligned to the target schema by name.
+
+        Field order does not matter (``_align_batch_to_target_schema`` reorders
+        columns by name), but the input must carry exactly the target's field
+        names, each exactly once.
+
+        Raises:
+            ValueError: if the input or the target schema contains duplicate
+                field names, or if the input is missing target fields or has
+                fields the target does not declare.
+        """
+        target_names = self._target_schema.names
+        input_names = input_schema.names
+        target_set = set(target_names)
+        input_set = set(input_names)
+
+        # Name-based alignment is ambiguous when a name appears more than once.
+        if len(input_set) != len(input_names):
+            raise ValueError(
+                f"Cannot write to Paimon with duplicate input field names: "
+                f"{input_names}"
+            )
+        if len(target_set) != len(target_names):
+            raise ValueError(
+                f"Cannot write to Paimon with duplicate target field names: "
+                f"{target_names}"
+            )
+
+        # Membership goes through the sets (linear overall); iterating the
+        # lists keeps the reported names in schema order for stable messages.
+        missing = [name for name in target_names if name not in input_set]
+        extra = [name for name in input_names if name not in target_set]
+        if missing or extra:
+            details = []
+            if missing:
+                details.append(f"missing fields: {missing}")
+            if extra:
+                details.append(f"extra fields: {extra}")
+            detail = "; ".join(details)
+            raise ValueError(f"Paimon write schema mismatch: {detail}")
+
+    def _align_batch_to_target_schema(self, batch: pa.RecordBatch) -> pa.RecordBatch:
+        """Return ``batch`` reordered and cast so its schema equals the target's.
+
+        Columns are looked up by field name, so the input column order does
+        not matter; callers are expected to have checked the names with
+        ``_validate_input_schema`` first. Columns are cast one at a time with
+        ``Array.cast`` instead of ``RecordBatch.cast``, because the latter is
+        only available in PyArrow >= 16.
+
+        Raises:
+            KeyError: if a target field is missing from, or duplicated in, the
+                batch (i.e. the schema was not validated beforehand).
+            ValueError: if a column contains nulls but its target field is
+                non-nullable.
+        """
+        target = self._target_schema
+        if batch.schema == target:
+            return batch
+
+        arrays = []
+        for field in target:
+            index = batch.schema.get_field_index(field.name)
+            if index < 0:
+                # get_field_index returns -1 for a missing or ambiguous name;
+                # fail loudly rather than let -1 address the last column.
+                raise KeyError(
+                    f"Cannot align batch to Paimon schema: field "
+                    f"{field.name!r} is missing or duplicated"
+                )
+            column = batch.column(index)
+            # Same guard RecordBatch.cast applies before casting.
+            if not field.nullable and column.null_count > 0:
+                raise ValueError(
+                    f"Cannot write null values to non-nullable Paimon field "
+                    f"{field.name!r}"
+                )
+            if column.type != field.type:
+                column = column.cast(field.type)
+            arrays.append(column)
+        return pa.RecordBatch.from_arrays(arrays, schema=target)
+
def write(self, micropartitions: Iterator[MicroPartition]) ->
Iterator[WriteResult[list[Any]]]:
table_write = self._write_builder.new_write()
- cast_fields: list[tuple[int, pa.DataType]] | None = None
-
total_rows = 0
total_bytes = 0
+ last_input_schema: pa.Schema | None = None
try:
for mp in micropartitions:
for rb in mp.get_record_batches():
batch = rb.to_arrow_record_batch()
- if cast_fields is None:
- cast_fields = [
- (i, field.type)
- for i, field in enumerate(self._target_schema)
- if batch.column(i).type != field.type
- ]
- if cast_fields:
- arrays = list(batch.columns)
- for i, target_type in cast_fields:
- arrays[i] = arrays[i].cast(target_type)
- batch = pa.RecordBatch.from_arrays(arrays,
schema=self._target_schema)
+ if batch.schema != last_input_schema:
+ self._validate_input_schema(batch.schema)
+ last_input_schema = batch.schema
+ batch = self._align_batch_to_target_schema(batch)
table_write.write_arrow_batch(batch)
total_rows += batch.num_rows
total_bytes += batch.nbytes
diff --git a/paimon-python/pypaimon/tests/daft/daft_sink_test.py
b/paimon-python/pypaimon/tests/daft/daft_sink_test.py
index b5b9a171d3..3cca5487d9 100644
--- a/paimon-python/pypaimon/tests/daft/daft_sink_test.py
+++ b/paimon-python/pypaimon/tests/daft/daft_sink_test.py
@@ -33,6 +33,7 @@ daft = pytest.importorskip("daft")
from pypaimon.daft.daft_compat import has_file_range_reads
from pypaimon.daft.daft_catalog import PaimonTable
+from pypaimon.daft.daft_datasink import PaimonDataSink
from pypaimon.daft.daft_paimon import _read_table, _write_table
requires_blob = pytest.mark.skipif(not has_file_range_reads(), reason="BLOB
support requires daft >= 0.7.11")
@@ -58,6 +59,18 @@ def _write_to_paimon(table, arrow_table, mode="append",
overwrite_partition=None
table_commit.close()
+def _create_id_dt_table(catalog, table_name: str):
+    """Create a table with an int64 ``id`` and a string ``dt`` column and return it.
+
+    The table uses one bucket keyed on ``id`` and parquet data files. It is
+    created with ``ignore_if_exists=False``, so an existing table of the same
+    name is not silently reused.
+    """
+    arrow_schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("dt", pa.string()),
+        ]
+    )
+    table_options = {"bucket": "1", "file.format": "parquet", "bucket-key": "id"}
+    paimon_schema = pypaimon.Schema.from_pyarrow_schema(
+        arrow_schema, options=table_options
+    )
+    catalog.create_table(table_name, paimon_schema, ignore_if_exists=False)
+    return catalog.get_table(table_name)
+
+
@pytest.fixture(scope="function")
def local_paimon_catalog(tmp_path):
catalog = pypaimon.CatalogFactory.create({"warehouse": str(tmp_path)})
@@ -200,6 +213,39 @@ def
test_write_paimon_roundtrip_native_verify(append_only_table):
assert ids == [7, 8, 9]
+def test_write_paimon_aligns_columns_by_name(local_paimon_catalog):
+    """Input column order must not change which values land in which column.
+
+    The frame lists ``dt`` before ``id`` while the table declares ``id, dt``;
+    both the Daft reader and pypaimon's native reader must see the pairs as
+    written.
+    """
+    catalog, _ = local_paimon_catalog
+    table = _create_id_dt_table(catalog, "test_db.column_order")
+
+    reversed_columns = {"dt": ["101", "202"], "id": [1, 2]}
+    _write_table(daft.from_pydict(reversed_columns), table)
+
+    expected = {"id": [1, 2], "dt": ["101", "202"]}
+    assert _read_table(table).sort("id").to_pydict() == expected
+
+    # Cross-check with the native reader so the Daft read path cannot mask a
+    # write-side misalignment.
+    read_builder = table.new_read_builder()
+    scan = read_builder.new_scan()
+    reader = read_builder.new_read()
+    native = reader.to_arrow(scan.plan().splits())
+    ids = native.column("id").to_pylist()
+    dts = native.column("dt").to_pylist()
+    assert sorted(zip(ids, dts)) == [(1, "101"), (2, "202")]
+
+
# ---------------------------------------------------------------------------
# Overwrite
# ---------------------------------------------------------------------------
@@ -257,6 +303,18 @@ def test_write_paimon_invalid_mode(append_only_table):
_write_table(df, table, mode="upsert")
+def test_write_paimon_rejects_extra_columns(local_paimon_catalog):
+    """Extra input columns should fail instead of being silently dropped.
+
+    The sink itself raises ``ValueError``; whether that surfaces unchanged or
+    wrapped in a ``RuntimeError`` is a detail of the Daft execution engine,
+    so accept either and pin the message instead.
+    """
+    catalog, _ = local_paimon_catalog
+    table = _create_id_dt_table(catalog, "test_db.extra_columns")
+    df = daft.from_pydict(
+        {"id": [1], "dt": ["2024-01-01"], "extra": ["unused"]}
+    )
+
+    with pytest.raises(
+        (RuntimeError, ValueError), match="Paimon write schema mismatch"
+    ):
+        _write_table(df, table)
+
+
def test_write_paimon_pk_table(pk_table):
"""Writing to a PK table should work and be readable back."""
table, _ = pk_table
@@ -282,6 +340,39 @@ def test_write_paimon_pk_table(pk_table):
class TestSchemaConversion:
"""Tests for schema conversion utilities."""
+    def test_align_batch_to_target_schema_by_name(self):
+        """Columns are matched to the target schema by name, then cast
+        (here ``large_string`` -> ``string``)."""
+        target = pa.schema([("id", pa.int64()), ("dt", pa.string())])
+        # Bypass __init__: the methods under test only read _target_schema.
+        sink = PaimonDataSink.__new__(PaimonDataSink)
+        sink._target_schema = target
+
+        dt_values = pa.array(["101", "202"], type=pa.large_string())
+        id_values = pa.array([1, 2], type=pa.int64())
+        batch = pa.record_batch([dt_values, id_values], names=["dt", "id"])
+
+        sink._validate_input_schema(batch.schema)
+        aligned = sink._align_batch_to_target_schema(batch)
+
+        assert aligned.schema == target
+        assert aligned.to_pydict() == {"id": [1, 2], "dt": ["101", "202"]}
+
+    def test_validate_input_schema_rejects_mismatch(self):
+        """A schema lacking ``dt`` and adding ``extra`` is rejected, and the
+        error lists both problems."""
+        sink = PaimonDataSink.__new__(PaimonDataSink)
+        sink._target_schema = pa.schema([("id", pa.int64()), ("dt", pa.string())])
+        bad_schema = pa.schema([("id", pa.int64()), ("extra", pa.string())])
+        expected = r"missing fields: \['dt'\]; extra fields: \['extra'\]"
+
+        with pytest.raises(ValueError, match=expected):
+            sink._validate_input_schema(bad_schema)
+
def test_write_large_string_conversion(self, local_paimon_catalog):
"""Test that large_string columns are converted to string for
pypaimon."""
catalog, tmp_path = local_paimon_catalog