zhengruifeng commented on code in PR #52637:
URL: https://github.com/apache/spark/pull/52637#discussion_r2438429853


##########
python/pyspark/sql/tests/test_python_datasource.py:
##########
@@ -755,6 +755,28 @@ def write(self, iterator):
         ):
             df.write.format("test").mode("append").saveAsTable("test_table")
 
+    def test_decimal_round(self):
+        class SimpleDataSource(DataSource):
+            @classmethod
+            def name(cls) -> str:
+                return "simple_decimal"
+
+            def schema(self) -> StructType:
+                return StructType(
+                    [StructField("i", IntegerType()), StructField("d", DecimalType(38, 18))]
+                )
+
+            def reader(self, schema: StructType) -> DataSourceReader:
+                return SimpleDataSourceReader()
+
+        class SimpleDataSourceReader(DataSourceReader):
+            def read(self, partition: InputPartition) -> Iterator[Tuple]:
+                yield (1, Decimal(1.234))
+
+        self.spark.dataSource.register(SimpleDataSource)
+        df = self.spark.read.format("simple_decimal").load()
+        self.assertEqual(df.select("d").first().d, Decimal("1.233999999999999986"))

Review Comment:
   This test fails with
   ```
   ======================================================================
   ERROR [3.713s]: test_decimal_round (pyspark.sql.tests.test_python_datasource.PythonDataSourceTests.test_decimal_round)
   ----------------------------------------------------------------------
   Traceback (most recent call last):
     File "/Users/ruifeng.zheng/spark/python/pyspark/sql/tests/test_python_datasource.py", line 778, in test_decimal_round
       self.assertEqual(df.select("d").first().d, Decimal("1.233999999999999986"))
                        ~~~~~~~~~~~~~~~~~~~~^^
     File "/Users/ruifeng.zheng/spark/python/pyspark/sql/classic/dataframe.py", 
line 
   
   ...
   
     File "/Users/ruifeng.zheng/spark/python/pyspark/worker.py", line 3258, in main
       process()
       ~~~~~~~^^
     File "/Users/ruifeng.zheng/spark/python/pyspark/worker.py", line 3250, in process
       serializer.dump_stream(out_iter, outfile)
       ^^^^^^^
     File "/Users/ruifeng.zheng/spark/python/pyspark/sql/pandas/serializers.py", line 187, in dump_stream
       return super(ArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream)
       ^^^^^^^^^^^^^^^
     File "/Users/ruifeng.zheng/spark/python/pyspark/sql/pandas/serializers.py", line 121, in dump_stream
       for batch in iterator:
       ^^^^^^^^^^^
     File "/Users/ruifeng.zheng/spark/python/pyspark/sql/pandas/serializers.py", line 167, in wrap_and_init_stream
       for batch, _ in iterator:
       ^^^^^^^^^^^
     File "/Users/ruifeng.zheng/spark/python/pyspark/worker.py", line 2803, in func
       for result_batch, result_type in result_iter:
       ^^^^^^^^^^^
     File "/Users/ruifeng.zheng/spark/python/pyspark/sql/worker/plan_data_source_read.py", line 167, in records_to_arrow_batches
       batch = pa.RecordBatch.from_arrays(pylist, schema=pa_schema)
       ^^^^^^^^^^^^^^^
   ...
   
   pyarrow.lib.ArrowInvalid: Rescaling Decimal value would cause data loss
   ```
   before this change



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to