zhengruifeng commented on code in PR #52637:
URL: https://github.com/apache/spark/pull/52637#discussion_r2438429853
##########
python/pyspark/sql/tests/test_python_datasource.py:
##########
@@ -755,6 +755,28 @@ def write(self, iterator):
):
df.write.format("test").mode("append").saveAsTable("test_table")
+ def test_decimal_round(self):
+ class SimpleDataSource(DataSource):
+ @classmethod
+ def name(cls) -> str:
+ return "simple_decimal"
+
+ def schema(self) -> StructType:
+ return StructType(
+ [StructField("i", IntegerType()), StructField("d",
DecimalType(38, 18))]
+ )
+
+ def reader(self, schema: StructType) -> DataSourceReader:
+ return SimpleDataSourceReader()
+
+ class SimpleDataSourceReader(DataSourceReader):
+ def read(self, partition: InputPartition) -> Iterator[Tuple]:
+ yield (1, Decimal(1.234))
+
+ self.spark.dataSource.register(SimpleDataSource)
+ df = self.spark.read.format("simple_decimal").load()
+ self.assertEqual(df.select("d").first().d,
Decimal("1.233999999999999986"))
Review Comment:
This test fails with the following error when run without this change applied:
```
======================================================================
ERROR [3.713s]: test_decimal_round
(pyspark.sql.tests.test_python_datasource.PythonDataSourceTests.test_decimal_round)
----------------------------------------------------------------------
Traceback (most recent call last):
File
"/Users/ruifeng.zheng/spark/python/pyspark/sql/tests/test_python_datasource.py",
line 778, in test_decimal_round
self.assertEqual(df.select("d").first().d,
Decimal("1.233999999999999986"))
~~~~~~~~~~~~~~~~~~~~^^
File "/Users/ruifeng.zheng/spark/python/pyspark/sql/classic/dataframe.py",
line
...
File "/Users/ruifeng.zheng/spark/python/pyspark/worker.py", line 3258, in
main
process()
~~~~~~~^^
File "/Users/ruifeng.zheng/spark/python/pyspark/worker.py", line 3250, in
process
serializer.dump_stream(out_iter, outfile)
^^^^^^^
File
"/Users/ruifeng.zheng/spark/python/pyspark/sql/pandas/serializers.py", line
187, in dump_stream
return super(ArrowStreamUDFSerializer,
self).dump_stream(wrap_and_init_stream(), stream)
^^^^^^^^^^^^^^^
File
"/Users/ruifeng.zheng/spark/python/pyspark/sql/pandas/serializers.py", line
121, in dump_stream
for batch in iterator:
^^^^^^^^^^^
File
"/Users/ruifeng.zheng/spark/python/pyspark/sql/pandas/serializers.py", line
167, in wrap_and_init_stream
for batch, _ in iterator:
^^^^^^^^^^^
File "/Users/ruifeng.zheng/spark/python/pyspark/worker.py", line 2803, in
func
for result_batch, result_type in result_iter:
^^^^^^^^^^^
File
"/Users/ruifeng.zheng/spark/python/pyspark/sql/worker/plan_data_source_read.py",
line 167, in records_to_arrow_batches
batch = pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
^^^^^^^^^^^^^^^
...
pyarrow.lib.ArrowInvalid: Rescaling Decimal value would cause data loss
```
That is, the test reproduces the bug on the base branch and passes only with this change, confirming the fix.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]