Github user HyukjinKwon commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19459#discussion_r150228512
  
    --- Diff: python/pyspark/sql/session.py ---
    @@ -454,13 +454,60 @@ def _convert_from_pandas(self, pdf, schema):
     
             # Check if any columns need to be fixed for Spark to infer properly
             if len(np_records) > 0:
    -            record_type_list = self._get_numpy_record_dtypes(np_records[0])
    -            if record_type_list is not None:
    -                return [r.astype(record_type_list).tolist() for r in np_records], schema
    +            record_dtype = self._get_numpy_record_dtype(np_records[0])
    +            if record_dtype is not None:
    +                return [r.astype(record_dtype).tolist() for r in np_records], schema
     
             # Convert list of numpy records to python lists
             return [r.tolist() for r in np_records], schema
     
    +    def _create_from_pandas_with_arrow(self, pdf, schema):
    +        """
    +        Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
    +        to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
    +        data types will be used to coerce the data in Pandas to Arrow conversion.
    +        """
    +        from pyspark.serializers import ArrowSerializer, _create_batch
    +        from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
    +        from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
    +
    +        # Determine arrow types to coerce data when creating batches
    +        if isinstance(schema, StructType):
    +            arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
    +        elif isinstance(schema, DataType):
    +            raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
    +        else:
    +            # Any timestamps must be coerced to be compatible with Spark
    +            arrow_types = [to_arrow_type(TimestampType())
    +                           if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
    +                           for t in pdf.dtypes]
    +
    +        # Slice the DataFrame to be batched
    +        step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # round int up
    +        pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
    +
    +        # Create Arrow record batches
    +        batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)])
    +                   for pdf_slice in pdf_slices]
    +
    +        # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
    +        if schema is None or isinstance(schema, list):
    +            schema_from_arrow = from_arrow_schema(batches[0].schema)
    +            names = pdf.columns if schema is None else schema
    --- End diff --
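    
    (Side note on the slicing step in the diff above: `step = -(-len(pdf) // self.sparkContext.defaultParallelism)` is the negate/floor-divide/negate idiom for ceiling division, so the DataFrame is cut into at most `defaultParallelism` slices and the last slice absorbs any remainder rows. A minimal, self-contained sketch with made-up numbers, not the PR's code:)
    
    ```python
    # Hypothetical stand-ins for len(pdf) and self.sparkContext.defaultParallelism.
    rows, parallelism = 10, 3
    step = -(-rows // parallelism)  # ceiling division: ceil(10 / 3) == 4
    slices = [(start, min(start + step, rows)) for start in range(0, rows, step)]
    # slices == [(0, 4), (4, 8), (8, 10)], so every row lands in exactly one slice
    ```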
    
    Could we maybe make this resemble
    
    https://github.com/apache/spark/blob/1d341042d6948e636643183da9bf532268592c6a/python/pyspark/sql/session.py#L403-L411
    
    just to be a bit more readable?
    
    ```python
    if schema is None or isinstance(schema, (list, tuple)):
        struct = from_arrow_schema(batches[0].schema)
        if isinstance(schema, (list, tuple)):
            for i, name in enumerate(schema):
                struct.fields[i].name = name
                struct.names[i] = name
        schema = struct
    ```
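    
    (For reference, a minimal sketch of what that renaming does, using a hand-built `StructType` in place of the one `from_arrow_schema` would infer; the field and column names here are hypothetical:)
    
    ```python
    from pyspark.sql.types import StructType, StructField, LongType, StringType
    
    # Stand-in for from_arrow_schema(batches[0].schema).
    struct = StructType([StructField("_1", LongType()), StructField("_2", StringType())])
    schema = ["id", "value"]  # user-supplied column names passed as a list
    
    for i, name in enumerate(schema):
        struct.fields[i].name = name
        struct.names[i] = name
    
    schema = struct
    # schema.fieldNames() is now ['id', 'value']
    ```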


---
