This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1642e928478c [SPARK-46823][CONNECT][PYTHON] `LocalDataToArrowConversion` should check the nullability
1642e928478c is described below

commit 1642e928478c8c20bae5203ecf2e4d659aca7692
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jan 24 00:43:41 2024 -0800

    [SPARK-46823][CONNECT][PYTHON] `LocalDataToArrowConversion` should check the nullability
    
    ### What changes were proposed in this pull request?
    Make `LocalDataToArrowConversion` check the nullability declared in the schema, so that a `None` value for a non-nullable field is rejected with a `PySparkValueError` during the local-data-to-Arrow conversion.
    
    ### Why are the changes needed?
    This check was missing: invalid local data was converted and sent to the server, where it only failed during analysis with an opaque Arrow `IllegalStateException`.
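    
    The core of the change is a per-field guard inside the converters. The following is a simplified sketch of the idea only, not the exact `LocalDataToArrowConversion` code (the helper name `make_converter` is made up for illustration, and a plain `ValueError` stands in for `PySparkValueError`):
    
    ```python
    from typing import Any, Callable

    def make_converter(data_type_name: str, nullable: bool) -> Callable[[Any], Any]:
        # Non-nullable fields reject None on the client, before the data is
        # handed to Arrow, instead of failing later on the server.
        def convert(value: Any) -> Any:
            if value is None and not nullable:
                raise ValueError(f"input for {data_type_name} must not be None")
            return value

        return convert

    age_conv = make_converter("IntegerType()", nullable=False)
    print(age_conv(42))   # prints 42
    # age_conv(None)      # would raise: input for IntegerType() must not be None
    ```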
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Putting `None` into a field that the schema declares as non-nullable now fails fast on the client with a clear `PySparkValueError`, instead of failing later on the server with an obscure Arrow error. For example, with the following data and schema:
    
    ```
            data = [("asd", None)]
            schema = StructType(
                [
                    StructField("name", StringType(), nullable=True),
                    StructField("age", IntegerType(), nullable=False),
                ]
            )
    ```
    
    Before this change, the error only surfaced on the server:
    ```
    In [3]: df = spark.createDataFrame([("asd", None)], schema)
    
    In [4]: df
    Out[4]: 24/01/24 12:08:28 ERROR ErrorUtils: Spark Connect RPC error during: analyze. UserId: ruifeng.zheng. SessionId: cd692bb1-d503-4043-a9db-d29cb5c16517.
    java.lang.IllegalStateException: Value at index is null
            at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
            at org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
            at org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
            at org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
            at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
            at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
            at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
            at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
            at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
            at scala.collection.immutable.List.prependedAll(List.scala:153)
            at scala.collection.immutable.List$.from(List.scala:684)
            at scala.collection.immutable.List$.from(List.scala:681)
            at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
            at scala.collection.immutable.Seq$.from(Seq.scala:42)
            at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
            at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
            at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
            at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
            at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
            at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
            at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
            at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
            at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
            at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
            at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
            at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
            at org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
            at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
            at org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
            at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
            at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
            at org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
            at org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
            at org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
            at org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
            at org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
            at org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
            at org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
            at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
            at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
            at java.base/java.lang.Thread.run(Thread.java:833)
    ---------------------------------------------------------------------------
    SparkConnectGrpcException                 Traceback (most recent call last)
    File ~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/core/formatters.py:708, in PlainTextFormatter.__call__(self, obj)
        701 stream = StringIO()
        702 printer = pretty.RepresentationPrinter(stream, self.verbose,
        703     self.max_width, self.newline,
        704     max_seq_length=self.max_seq_length,
        705     singleton_pprinters=self.singleton_printers,
        706     type_pprinters=self.type_printers,
        707     deferred_pprinters=self.deferred_printers)
    --> 708 printer.pretty(obj)
        709 printer.flush()
        710 return stream.getvalue()
    
    File ~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/lib/pretty.py:410, in RepresentationPrinter.pretty(self, obj)
        407                         return meth(obj, self, cycle)
        408                 if cls is not object \
        409                         and callable(cls.__dict__.get('__repr__')):
    --> 410                     return _repr_pprint(obj, self, cycle)
        412     return _default_pprint(obj, self, cycle)
        413 finally:
    
    File ~/.dev/miniconda3/envs/spark_dev_311/lib/python3.11/site-packages/IPython/lib/pretty.py:778, in _repr_pprint(obj, p, cycle)
        776 """A pprint that just redirects to the normal repr function."""
        777 # Find newlines and replace them with p.break_()
    --> 778 output = repr(obj)
        779 lines = output.splitlines()
        780 with p.group():
    
    File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:141, in DataFrame.__repr__(self)
        135     if repl_eager_eval_enabled == "true":
        136         return self._show_string(
        137             n=int(cast(str, repl_eager_eval_max_num_rows)),
        138             truncate=int(cast(str, repl_eager_eval_truncate)),
        139             vertical=False,
        140         )
    --> 141 return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
    
    File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:238, in DataFrame.dtypes(self)
        236 @property
        237 def dtypes(self) -> List[Tuple[str, str]]:
    --> 238     return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
    
    File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:1786, in DataFrame.schema(self)
       1783 @property
       1784 def schema(self) -> StructType:
       1785     query = self._plan.to_proto(self._session.client)
    -> 1786     return self._session.client.schema(query)
    
    File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:921, in SparkConnectClient.schema(self, plan)
        917 """
        918 Return schema for given plan.
        919 """
        920 logger.info(f"Schema for plan: {self._proto_to_string(plan)}")
    --> 921 schema = self._analyze(method="schema", plan=plan).schema
        922 assert schema is not None
        923 # Server side should populate the struct field which is the schema.
    
    File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1107, in SparkConnectClient._analyze(self, method, **kwargs)
       1105     raise SparkConnectException("Invalid state during retry exception handling.")
       1106 except Exception as error:
    -> 1107     self._handle_error(error)
    
    File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1525, in SparkConnectClient._handle_error(self, error)
       1523 self.thread_local.inside_error_handling = True
       1524 if isinstance(error, grpc.RpcError):
    -> 1525     self._handle_rpc_error(error)
       1526 elif isinstance(error, ValueError):
       1527     if "Cannot invoke RPC" in str(error) and "closed" in str(error):
    
    File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1595, in SparkConnectClient._handle_rpc_error(self, rpc_error)
       1592             info = error_details_pb2.ErrorInfo()
       1593             d.Unpack(info)
    -> 1595             raise convert_exception(
       1596                 info,
       1597                 status.message,
       1598                 self._fetch_enriched_error(info),
       1599                 self._display_server_stack_trace(),
       1600             ) from None
       1602     raise SparkConnectGrpcException(status.message) from None
       1603 else:
    
    SparkConnectGrpcException: (java.lang.IllegalStateException) Value at index is null
    
    JVM stacktrace:
    java.lang.IllegalStateException
            at org.apache.arrow.vector.IntVector.get(IntVector.java:107)
            at org.apache.spark.sql.vectorized.ArrowColumnVector$IntAccessor.getInt(ArrowColumnVector.java:338)
            at org.apache.spark.sql.vectorized.ArrowColumnVector.getInt(ArrowColumnVector.java:88)
            at org.apache.spark.sql.vectorized.ColumnarBatchRow.getInt(ColumnarBatchRow.java:109)
            at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(:-1)
            at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(:-1)
            at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
            at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
            at scala.collection.Iterator$$anon$9.next(Iterator.scala:584)
            at scala.collection.immutable.List.prependedAll(List.scala:153)
            at scala.collection.immutable.List$.from(List.scala:684)
            at scala.collection.immutable.List$.from(List.scala:681)
            at scala.collection.SeqFactory$Delegate.from(Factory.scala:306)
            at scala.collection.immutable.Seq$.from(Seq.scala:42)
            at scala.collection.IterableOnceOps.toSeq(IterableOnce.scala:1326)
            at scala.collection.IterableOnceOps.toSeq$(IterableOnce.scala:1326)
            at scala.collection.AbstractIterator.toSeq(Iterator.scala:1300)
            at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformLocalRelation(SparkConnectPlanner.scala:1239)
            at org.apache.spark.sql.connect.planner.SparkConnectPlanner.transformRelation(SparkConnectPlanner.scala:139)
            at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.process(SparkConnectAnalyzeHandler.scala:59)
            at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1(SparkConnectAnalyzeHandler.scala:43)
            at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.$anonfun$handle$1$adapted(SparkConnectAnalyzeHandler.scala:42)
            at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$2(SessionHolder.scala:289)
            at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:918)
            at org.apache.spark.sql.connect.service.SessionHolder.$anonfun$withSession$1(SessionHolder.scala:289)
            at org.apache.spark.JobArtifactSet$.withActiveJobArtifactState(JobArtifactSet.scala:94)
            at org.apache.spark.sql.artifact.ArtifactManager.$anonfun$withResources$1(ArtifactManager.scala:80)
            at org.apache.spark.util.Utils$.withContextClassLoader(Utils.scala:182)
            at org.apache.spark.sql.artifact.ArtifactManager.withResources(ArtifactManager.scala:79)
            at org.apache.spark.sql.connect.service.SessionHolder.withSession(SessionHolder.scala:288)
            at org.apache.spark.sql.connect.service.SparkConnectAnalyzeHandler.handle(SparkConnectAnalyzeHandler.scala:42)
            at org.apache.spark.sql.connect.service.SparkConnectService.analyzePlan(SparkConnectService.scala:95)
            at org.apache.spark.connect.proto.SparkConnectServiceGrpc$MethodHandlers.invoke(SparkConnectServiceGrpc.java:907)
            at org.sparkproject.connect.grpc.io.grpc.stub.ServerCalls$UnaryServerCallHandler$UnaryServerCallListener.onHalfClose(ServerCalls.java:182)
            at org.sparkproject.connect.grpc.io.grpc.internal.ServerCallImpl$ServerStreamListenerImpl.halfClosed(ServerCallImpl.java:351)
            at org.sparkproject.connect.grpc.io.grpc.internal.ServerImpl$JumpToApplicationThreadServerStreamListener$1HalfClosed.runInContext(ServerImpl.java:860)
            at org.sparkproject.connect.grpc.io.grpc.internal.ContextRunnable.run(ContextRunnable.java:37)
            at org.sparkproject.connect.grpc.io.grpc.internal.SerializingExecutor.run(SerializingExecutor.java:133)
            at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
            at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
            at java.lang.Thread.run(Thread.java:833)
    
    ```
    
    After this change, the client rejects the data immediately:
    ```
    ---------------------------------------------------------------------------
    PySparkValueError                         Traceback (most recent call last)
    Cell In[3], line 1
    ----> 1 df = spark.createDataFrame([("asd", None)], schema)
    
    File ~/Dev/spark/python/pyspark/sql/connect/session.py:538, in SparkSession.createDataFrame(self, data, schema)
        533     from pyspark.sql.connect.conversion import LocalDataToArrowConversion
        535     # Spark Connect will try its best to build the Arrow table with the
        536     # inferred schema in the client side, and then rename the columns and
        537     # cast the datatypes in the server side.
    --> 538     _table = LocalDataToArrowConversion.convert(_data, _schema)
        540 # TODO: Beside the validation on number of columns, we should also check
        541 # whether the Arrow Schema is compatible with the user provided Schema.
        542 if _num_cols is not None and _num_cols != _table.shape[1]:
    
    File ~/Dev/spark/python/pyspark/sql/connect/conversion.py:351, in LocalDataToArrowConversion.convert(data, schema)
        342             raise PySparkValueError(
        343                 error_class="AXIS_LENGTH_MISMATCH",
        344                 message_parameters={
       (...)
        347                 },
        348             )
        350         for i in range(len(column_names)):
    --> 351             pylist[i].append(column_convs[i](item[i]))
        353 pa_schema = to_arrow_schema(
        354     StructType(
        355         [
       (...)
        361     )
        362 )
        364 return pa.Table.from_arrays(pylist, schema=pa_schema)
    
    File ~/Dev/spark/python/pyspark/sql/connect/conversion.py:297, in LocalDataToArrowConversion._create_converter.<locals>.convert_other(value)
        295 def convert_other(value: Any) -> Any:
        296     if value is None:
    --> 297         raise PySparkValueError(f"input for {dataType} must not be None")
        298     return value
    
    PySparkValueError: input for IntegerType() must not be None
    ```
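    
    For callers, the failure is now an ordinary Python exception raised at `createDataFrame` time, so it can be caught on the client. A usage sketch (assuming a Spark Connect session is already available as `spark`):
    
    ```python
    from pyspark.errors import PySparkValueError
    from pyspark.sql.types import IntegerType, StringType, StructField, StructType

    schema = StructType(
        [
            StructField("name", StringType(), nullable=True),
            StructField("age", IntegerType(), nullable=False),
        ]
    )

    try:
        spark.createDataFrame([("asd", None)], schema)
    except PySparkValueError as e:
        # The client rejects the None value for the non-nullable `age` field up front.
        print(f"rejected: {e}")
    ```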
    
    ### How was this patch tested?
    Added a unit test (`test_create_df_nullability` in `test_connect_basic.py`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #44861 from zhengruifeng/connect_check_nullable.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/connect/conversion.py           | 78 +++++++++++++++++++---
 .../sql/tests/connect/test_connect_basic.py        | 12 ++++
 2 files changed, 80 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index fb5a2d4b17b1..c86ee9c75fec 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -61,14 +61,23 @@ class LocalDataToArrowConversion:
     """
 
     @staticmethod
-    def _need_converter(dataType: DataType) -> bool:
-        if isinstance(dataType, NullType):
+    def _need_converter(
+        dataType: DataType,
+        nullable: bool = True,
+    ) -> bool:
+        if not nullable:
+            # always check the nullability
+            return True
+        elif isinstance(dataType, NullType):
+            # always check the nullability
             return True
         elif isinstance(dataType, StructType):
             # Struct maybe rows, should convert to dict.
             return True
         elif isinstance(dataType, ArrayType):
-            return LocalDataToArrowConversion._need_converter(dataType.elementType)
+            return LocalDataToArrowConversion._need_converter(
+                dataType.elementType, dataType.containsNull
+            )
         elif isinstance(dataType, MapType):
             # Different from PySpark, here always needs conversion,
             # since an Arrow Map requires a list of tuples.
@@ -90,26 +99,41 @@ class LocalDataToArrowConversion:
             return False
 
     @staticmethod
-    def _create_converter(dataType: DataType) -> Callable:
+    def _create_converter(
+        dataType: DataType,
+        nullable: bool = True,
+    ) -> Callable:
         assert dataType is not None and isinstance(dataType, DataType)
+        assert isinstance(nullable, bool)
 
-        if not LocalDataToArrowConversion._need_converter(dataType):
+        if not LocalDataToArrowConversion._need_converter(dataType, nullable):
             return lambda value: value
 
         if isinstance(dataType, NullType):
-            return lambda value: None
+
+            def convert_null(value: Any) -> Any:
+                if value is not None:
+                    raise PySparkValueError(f"input for {dataType} must be None, but got {value}")
+                return None
+
+            return convert_null
 
         elif isinstance(dataType, StructType):
             field_names = dataType.fieldNames()
             dedup_field_names = _dedup_names(dataType.names)
 
             field_convs = [
-                LocalDataToArrowConversion._create_converter(field.dataType)
+                LocalDataToArrowConversion._create_converter(
+                    field.dataType,
+                    field.nullable,
+                )
                 for field in dataType.fields
             ]
 
             def convert_struct(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, (tuple, dict)) or hasattr(
@@ -143,10 +167,15 @@ class LocalDataToArrowConversion:
             return convert_struct
 
         elif isinstance(dataType, ArrayType):
-            element_conv = LocalDataToArrowConversion._create_converter(dataType.elementType)
+            element_conv = LocalDataToArrowConversion._create_converter(
+                dataType.elementType,
+                dataType.containsNull,
+            )
 
             def convert_array(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, (list, array.array))
@@ -156,10 +185,15 @@ class LocalDataToArrowConversion:
 
         elif isinstance(dataType, MapType):
             key_conv = LocalDataToArrowConversion._create_converter(dataType.keyType)
-            value_conv = LocalDataToArrowConversion._create_converter(dataType.valueType)
+            value_conv = LocalDataToArrowConversion._create_converter(
+                dataType.valueType,
+                dataType.valueContainsNull,
+            )
 
             def convert_map(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, dict)
@@ -176,6 +210,8 @@ class LocalDataToArrowConversion:
 
             def convert_binary(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, (bytes, bytearray))
@@ -187,6 +223,8 @@ class LocalDataToArrowConversion:
 
             def convert_timestamp(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, datetime.datetime)
@@ -198,6 +236,8 @@ class LocalDataToArrowConversion:
 
             def convert_timestamp_ntz(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, datetime.datetime) and value.tzinfo is None
@@ -209,6 +249,8 @@ class LocalDataToArrowConversion:
 
             def convert_decimal(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     assert isinstance(value, decimal.Decimal)
@@ -220,6 +262,8 @@ class LocalDataToArrowConversion:
 
             def convert_string(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     if isinstance(value, bool):
@@ -238,12 +282,22 @@ class LocalDataToArrowConversion:
 
             def convert_udt(value: Any) -> Any:
                 if value is None:
+                    if not nullable:
+                        raise PySparkValueError(f"input for {dataType} must not be None")
                     return None
                 else:
                     return conv(udt.serialize(value))
 
             return convert_udt
 
+        elif not nullable:
+
+            def convert_other(value: Any) -> Any:
+                if value is None:
+                    raise PySparkValueError(f"input for {dataType} must not be None")
+                return value
+
+            return convert_other
         else:
             return lambda value: value
 
@@ -256,7 +310,11 @@ class LocalDataToArrowConversion:
         column_names = schema.fieldNames()
 
         column_convs = [
-            LocalDataToArrowConversion._create_converter(field.dataType) for field in schema.fields
+            LocalDataToArrowConversion._create_converter(
+                field.dataType,
+                field.nullable,
+            )
+            for field in schema.fields
         ]
 
         pylist: List[List] = [[] for _ in range(len(column_names))]
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index fbc1debe7511..08b0a0be2dcf 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1125,6 +1125,18 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(cdf.schema, sdf.schema)
         self.assertEqual(cdf.collect(), sdf.collect())
 
+    def test_create_df_nullability(self):
+        data = [("asd", None)]
+        schema = StructType(
+            [
+                StructField("name", StringType(), nullable=True),
+                StructField("age", IntegerType(), nullable=False),
+            ]
+        )
+
+        with self.assertRaises(PySparkValueError):
+            self.spark.createDataFrame(data, schema)
+
     def test_simple_explain_string(self):
         df = self.connect.read.table(self.tbl_name).limit(10)
         result = df._explain_string()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
