WeichenXu123 commented on code in PR #40724: URL: https://github.com/apache/spark/pull/40724#discussion_r1161659107
########## python/pyspark/ml/torch/data.py: ########## @@ -0,0 +1,78 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import numpy as np + + +class SparkPartitionTorchDataset(torch.utils.data.IterableDataset): + + def __init__(self, arrow_file_path, schema, num_samples): + self.arrow_file_path = arrow_file_path + self.num_samples = num_samples + self.field_types = [field.dataType.simpleString() for field in schema] + + @staticmethod + def _extract_field_value(value, field_type): + # TODO: avoid checking field type for every row. + if field_type == "vector": + if value['type'] == 1: + # dense vector + return value['values'] + if value['type'] == 0: + # sparse vector + size = int(value['size']) + np_array = np.zeros(size, dtype=np.float64) + for index, elem_value in zip(value['indices'], value['values']): + np_array[index] = elem_value + return np_array + if field_type in ["float", "double", "int", "bigint", "smallint"]: + return value + + raise ValueError( + "SparkPartitionTorchDataset does not support loading data from field of " + f"type {field_type}." 
+ ) + + def __iter__(self): + from pyspark.sql.pandas.serializers import ArrowStreamSerializer + serializer = ArrowStreamSerializer() + + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None and worker_info.num_workers > 1: + raise RuntimeError( + "`SparkPartitionTorchDataset` does not support multiple worker processes." + ) + + count = 0 + + while count < self.num_samples: + with open(self.arrow_file_path, "rb") as f: + batch_iter = serializer.load_stream(f) + for batch in batch_iter: + # TODO: we can optimize this further by directly extracting Review Comment: We shouldn't use it; look at how it handles the array type: ``` elif isinstance(dataType, ArrayType): element_conv = ArrowTableToRowsConversion._create_converter(dataType.elementType) def convert_array(value: Any) -> Any: if value is None: return None else: assert isinstance(value, list) return [element_conv(v) for v in value] return convert_array ``` It returns a Python list for an array value, which is not efficient, and we have the "vector" type. The pandas DataFrame conversion, by contrast, handles this well by creating a NumPy array in this case. ########## python/pyspark/ml/torch/distributor.py: ########## @@ -643,13 +664,48 @@ def _setup_files(train_fn: Callable, *args: Any) -> Generator[Tuple[str, str], N finally: TorchDistributor._cleanup_files(save_dir) + @staticmethod + @contextmanager + def _setup_spark_partition_data(partition_data_iterator, input_schema_json): + from pyspark.sql.pandas.serializers import ArrowStreamSerializer + import json + + if input_schema_json is None: + yield + return + + save_dir = TorchDistributor._create_save_dir() + + try: + serializer = ArrowStreamSerializer() Review Comment: https://github.com/apache/spark/pull/40724#discussion_r1161659107 -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
