This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1642e928478c [SPARK-46823][CONNECT][PYTHON]
`LocalDataToArrowConversion` should check the nullability
1642e928478c is described below
commit 1642e928478c8c20bae5203ecf2e4d659aca7692
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jan 24 00:43:41 2024 -0800
[SPARK-46823][CONNECT][PYTHON] `LocalDataToArrowConversion` should check
the nullability
### What changes were proposed in this pull request?
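Make `LocalDataToArrowConversion` check the nullability declared by the schema: top-level fields, nested struct fields, array elements (`containsNull`) and map values (`valueContainsNull`). It raises a `PySparkValueError` when `None` is supplied where the schema does not allow it. `NullType` columns now also reject non-`None` values.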
### Why are the changes needed?
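This check was missing. A `None` value for a non-nullable field was passed through to the server, which then failed with an opaque `java.lang.IllegalStateException: Value at index is null`.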
### Does this PR introduce _any_ user-facing change?
Yes. Creating a DataFrame from local data now validates nullability on the client. For example, given:
```
data = [("asd", None)]
schema = StructType(
[
StructField("name", StringType(), nullable=True),
StructField("age", IntegerType(), nullable=False),
]
)
```
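Before, the null value was sent to the server. The error only surfaced when the DataFrame was analyzed (here, when it was displayed), as a server-side exception: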
```
In [3]: df = spark.createDataFrame([("asd", None)], schema)
In [4]: df
Out[4]: 24/01/24 12:08:28 ERROR ErrorUtils: Spark Connect RPC error during:
analyze. UserId: ruifeng.zheng. SessionId: cd692bb1-d503-4043-a9db-d29cb5c16517.
java.lang.IllegalStateException: Value at index is null
at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
at
org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
at
org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
at
org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown
Source)
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown
Source)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.immutable.List.prependedAll(List.scala:153)
at scala.collection.immutable.List$.from(List.scala:684)
at scala.collection.immutable.List$.from(List.scala:681)
at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
at scala.collection.immutable.Seq$.from(Seq.scala:42)
at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
at
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
at
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
at
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
at
org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
at
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
at
org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
at
org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
at
org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
at
org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
at
org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
at
org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
at
org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
at
org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
at
org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
at
org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
at
org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
at
org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
at
java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
at
java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
at java.base/java.lang.Thread.run(Thread.java:833)
24/01/24 12:08:28 ERROR ErrorUtils: Spark Connect RPC error during:
analyze. UserId: ruifeng.zheng. SessionId: cd692bb1-d503-4043-a9db-d29cb5c16517.
java.lang.IllegalStateException: Value at index is null
at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
at
org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
at
org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
at
org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown
Source)
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown
Source)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.immutable.List.prependedAll(List.scala:153)
at scala.collection.immutable.List$.from(List.scala:684)
at scala.collection.immutable.List$.from(List.scala:681)
at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
at scala.collection.immutable.Seq$.from(Seq.scala:42)
at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
at
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
at
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
at
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
at
org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
at
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
at
org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
at
org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
at
org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
at
org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
at
org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
at
org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
at
org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
at
org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
at
org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
at
org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
at
org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
at
org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
at
java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
at
java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
at java.base/java.lang.Thread.run(Thread.java:833)
24/01/24 12:08:28 ERROR ErrorUtils: Spark Connect RPC error during:
analyze. UserId: ruifeng.zheng. SessionId: cd692bb1-d503-4043-a9db-d29cb5c16517.
java.lang.IllegalStateException: Value at index is null
at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
at
org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
at
org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
at
org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown
Source)
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown
Source)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.immutable.List.prependedAll(List.scala:153)
at scala.collection.immutable.List$.from(List.scala:684)
at scala.collection.immutable.List$.from(List.scala:681)
at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
at scala.collection.immutable.Seq$.from(Seq.scala:42)
at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
at
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
at
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
at
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
at
org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
at
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
at
org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
at
org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
at
org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
at
org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
at
org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
at
org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
at
org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
at
org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
at
org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
at
org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
at
org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
at
org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
at
java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
at
java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
at java.base/java.lang.Thread.run(Thread.java:833)
---------------------------------------------------------------------------
SparkConnectGrpcException Traceback (most recent call last)
File
~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/core/formatters.py:708,
in PlainTextFormatter.__call__(self, obj)
701 stream = StringIO()
702 printer = pretty.RepresentationPrinter(stream, self.verbose,
703 self.max_width, self.newline,
704 max_seq_length=self.max_seq_length,
705 singleton_pprinters=self.singleton_printers,
706 type_pprinters=self.type_printers,
707 deferred_pprinters=self.deferred_printers)
--> 708 printer.pretty(obj)
709 printer.flush()
710 return stream.getvalue()
File
~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/lib/pretty.py:410,
in RepresentationPrinter.pretty(self, obj)
407 return meth(obj, self, cycle)
408 if cls is not object \
409 and callable(cls.__dict__.get('__repr__')):
--> 410 return _repr_pprint(obj, self, cycle)
412 return _default_pprint(obj, self, cycle)
413 finally:
File
~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/lib/pretty.py:778,
in _repr_pprint(obj, p, cycle)
776 """A pprint that just redirects to the normal repr function."""
777 # Find newlines and replace them with p.break_()
--> 778 output = repr(obj)
779 lines = output.splitlines()
780 with p.group():
File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:141, in
DataFrame.__repr__(self)
135 if repl_eager_eval_enabled == "true":
136 return self._show_string(
137 n=int(cast(str, repl_eager_eval_max_num_rows)),
138 truncate=int(cast(str, repl_eager_eval_truncate)),
139 vertical=False,
140 )
--> 141 return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in
self.dtypes))
File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:238, in
DataFrame.dtypes(self)
236 property
237 def dtypes(self) -> List[Tuple[str, str]]:
--> 238 return [(str(f.name), f.dataType.simpleString()) for f in
self.schema.fields]
File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:1786, in
DataFrame.schema(self)
1783 property
1784 def schema(self) -> StructType:
1785 query = self._plan.to_proto(self._session.client)
-> 1786 return self._session.client.schema(query)
File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:921, in
SparkConnectClient.schema(self, plan)
917 """
918 Return schema for given plan.
919 """
920 logger.info(f"Schema for plan: {self._proto_to_string(plan)}")
--> 921 schema = self._analyze(method="schema", plan=plan).schema
922 assert schema is not None
923 # Server side should populate the struct field which is the schema.
File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1107, in
SparkConnectClient._analyze(self, method, **kwargs)
1105 raise SparkConnectException("Invalid state during retry
exception handling.")
1106 except Exception as error:
-> 1107 self._handle_error(error)
File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1525, in
SparkConnectClient._handle_error(self, error)
1523 self.thread_local.inside_error_handling = True
1524 if isinstance(error, grpc.RpcError):
-> 1525 self._handle_rpc_error(error)
1526 elif isinstance(error, ValueError):
1527 if "Cannot invoke RPC" in str(error) and "closed" in str(error):
File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1595, in
SparkConnectClient._handle_rpc_error(self, rpc_error)
1592 info = error_details_pb2.ErrorInfo()
1593 d.Unpack(info)
-> 1595 raise convert_exception(
1596 info,
1597 status.message,
1598 self._fetch_enriched_error(info),
1599 self._display_server_stack_trace(),
1600 ) from None
1602 raise SparkConnectGrpcException(status.message) from None
1603 else:
SparkConnectGrpcException: (java.lang.IllegalStateException) Value at index
is null
JVM stacktrace:
java.lang.IllegalStateException
at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
at
org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
at
org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
at
org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(:-1)
at
org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(:-1)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
at scala.collection.immutable.List.prependedAll(List.scala:153)
at scala.collection.immutable.List$.from(List.scala:684)
at scala.collection.immutable.List$.from(List.scala:681)
at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
at scala.collection.immutable.Seq$.from(Seq.scala:42)
at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
at
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
at
org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
at
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
at
org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
at
org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
at
org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
at
org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
at
org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
at
org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
at
org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
at
org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
at
org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
at
org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
at
org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
at
org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
at
org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
at
org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
at
org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
at
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
at
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
at java.lang.Thread.run(Thread.java:833)
```
After, `createDataFrame` itself raises a `PySparkValueError` on the client:
```
---------------------------------------------------------------------------
PySparkValueError Traceback (most recent call last)
Cell In[3], line 1
----> 1 df = spark.createDataFrame([("asd", None)], schema)
File ~/Dev/spark/python/pyspark/sql/connect/session.py:538, in
SparkSession.createDataFrame(self, data, schema)
533 from pyspark.sql.connect.conversion import
LocalDataToArrowConversion
535 # Spark Connect will try its best to build the Arrow table with
the
536 # inferred schema in the client side, and then rename the
columns and
537 # cast the datatypes in the server side.
--> 538 _table = LocalDataToArrowConversion.convert(_data, _schema)
540 # TODO: Beside the validation on number of columns, we should also
check
541 # whether the Arrow Schema is compatible with the user provided
Schema.
542 if _num_cols is not None and _num_cols != _table.shape[1]:
File ~/Dev/spark/python/pyspark/sql/connect/conversion.py:351, in
LocalDataToArrowConversion.convert(data, schema)
342 raise PySparkValueError(
343 error_class="AXIS_LENGTH_MISMATCH",
344 message_parameters={
(...)
347 },
348 )
350 for i in range(len(column_names)):
--> 351 pylist[i].append(column_convs[i](item[i]))
353 pa_schema = to_arrow_schema(
354 StructType(
355 [
(...)
361 )
362 )
364 return pa.Table.from_arrays(pylist, schema=pa_schema)
File ~/Dev/spark/python/pyspark/sql/connect/conversion.py:297, in
LocalDataToArrowConversion._create_converter.<locals>.convert_other(value)
295 def convert_other(value: Any) -> Any:
296 if value is None:
--> 297 raise PySparkValueError(f"input for {dataType} must not be
None")
298 return value
PySparkValueError: input for IntegerType() must not be None
```
### How was this patch tested?
Added a unit test, `test_create_df_nullability`, in `test_connect_basic.py`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44861 from zhengruifeng/connect_check_nullable.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/connect/conversion.py | 78 +++++++++++++++++++---
.../sql/tests/connect/test_connect_basic.py | 12 ++++
2 files changed, 80 insertions(+), 10 deletions(-)
diff --git a/python/pyspark/sql/connect/conversion.py
b/python/pyspark/sql/connect/conversion.py
index fb5a2d4b17b1..c86ee9c75fec 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -61,14 +61,23 @@ class LocalDataToArrowConversion:
"""
@staticmethod
- def _need_converter(dataType: DataType) -> bool:
- if isinstance(dataType, NullType):
+ def _need_converter(
+ dataType: DataType,
+ nullable: bool = True,
+ ) -> bool:
+ if not nullable:
+ # always check the nullability
+ return True
+ elif isinstance(dataType, NullType):
+ # always check the nullability
return True
elif isinstance(dataType, StructType):
# Struct maybe rows, should convert to dict.
return True
elif isinstance(dataType, ArrayType):
- return
LocalDataToArrowConversion._need_converter(dataType.elementType)
+ return LocalDataToArrowConversion._need_converter(
+ dataType.elementType, dataType.containsNull
+ )
elif isinstance(dataType, MapType):
# Different from PySpark, here always needs conversion,
# since an Arrow Map requires a list of tuples.
@@ -90,26 +99,41 @@ class LocalDataToArrowConversion:
return False
@staticmethod
- def _create_converter(dataType: DataType) -> Callable:
+ def _create_converter(
+ dataType: DataType,
+ nullable: bool = True,
+ ) -> Callable:
assert dataType is not None and isinstance(dataType, DataType)
+ assert isinstance(nullable, bool)
- if not LocalDataToArrowConversion._need_converter(dataType):
+ if not LocalDataToArrowConversion._need_converter(dataType, nullable):
return lambda value: value
if isinstance(dataType, NullType):
- return lambda value: None
+
+ def convert_null(value: Any) -> Any:
+ if value is not None:
+ raise PySparkValueError(f"input for {dataType} must be
None, but got {value}")
+ return None
+
+ return convert_null
elif isinstance(dataType, StructType):
field_names = dataType.fieldNames()
dedup_field_names = _dedup_names(dataType.names)
field_convs = [
- LocalDataToArrowConversion._create_converter(field.dataType)
+ LocalDataToArrowConversion._create_converter(
+ field.dataType,
+ field.nullable,
+ )
for field in dataType.fields
]
def convert_struct(value: Any) -> Any:
if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
return None
else:
assert isinstance(value, (tuple, dict)) or hasattr(
@@ -143,10 +167,15 @@ class LocalDataToArrowConversion:
return convert_struct
elif isinstance(dataType, ArrayType):
- element_conv =
LocalDataToArrowConversion._create_converter(dataType.elementType)
+ element_conv = LocalDataToArrowConversion._create_converter(
+ dataType.elementType,
+ dataType.containsNull,
+ )
def convert_array(value: Any) -> Any:
if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
return None
else:
assert isinstance(value, (list, array.array))
@@ -156,10 +185,15 @@ class LocalDataToArrowConversion:
elif isinstance(dataType, MapType):
key_conv =
LocalDataToArrowConversion._create_converter(dataType.keyType)
- value_conv =
LocalDataToArrowConversion._create_converter(dataType.valueType)
+ value_conv = LocalDataToArrowConversion._create_converter(
+ dataType.valueType,
+ dataType.valueContainsNull,
+ )
def convert_map(value: Any) -> Any:
if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
return None
else:
assert isinstance(value, dict)
@@ -176,6 +210,8 @@ class LocalDataToArrowConversion:
def convert_binary(value: Any) -> Any:
if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
return None
else:
assert isinstance(value, (bytes, bytearray))
@@ -187,6 +223,8 @@ class LocalDataToArrowConversion:
def convert_timestamp(value: Any) -> Any:
if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
return None
else:
assert isinstance(value, datetime.datetime)
@@ -198,6 +236,8 @@ class LocalDataToArrowConversion:
def convert_timestamp_ntz(value: Any) -> Any:
if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
return None
else:
assert isinstance(value, datetime.datetime) and
value.tzinfo is None
@@ -209,6 +249,8 @@ class LocalDataToArrowConversion:
def convert_decimal(value: Any) -> Any:
if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
return None
else:
assert isinstance(value, decimal.Decimal)
@@ -220,6 +262,8 @@ class LocalDataToArrowConversion:
def convert_string(value: Any) -> Any:
if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
return None
else:
if isinstance(value, bool):
@@ -238,12 +282,22 @@ class LocalDataToArrowConversion:
def convert_udt(value: Any) -> Any:
if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must
not be None")
return None
else:
return conv(udt.serialize(value))
return convert_udt
+ elif not nullable:
+
+ def convert_other(value: Any) -> Any:
+ if value is None:
+ raise PySparkValueError(f"input for {dataType} must not be
None")
+ return value
+
+ return convert_other
else:
return lambda value: value
@@ -256,7 +310,11 @@ class LocalDataToArrowConversion:
column_names = schema.fieldNames()
column_convs = [
- LocalDataToArrowConversion._create_converter(field.dataType) for
field in schema.fields
+ LocalDataToArrowConversion._create_converter(
+ field.dataType,
+ field.nullable,
+ )
+ for field in schema.fields
]
pylist: List[List] = [[] for _ in range(len(column_names))]
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index fbc1debe7511..08b0a0be2dcf 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1125,6 +1125,18 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.assertEqual(cdf.schema, sdf.schema)
self.assertEqual(cdf.collect(), sdf.collect())
+ def test_create_df_nullability(self):
+ data = [("asd", None)]
+ schema = StructType(
+ [
+ StructField("name", StringType(), nullable=True),
+ StructField("age", IntegerType(), nullable=False),
+ ]
+ )
+
+ with self.assertRaises(PySparkValueError):
+ self.spark.createDataFrame(data, schema)
+
def test_simple_explain_string(self):
df = self.connect.read.table(self.tbl_name).limit(10)
result = df._explain_string()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]