Github user HyukjinKwon commented on a diff in the pull request:
https://github.com/apache/spark/pull/19439#discussion_r150376997
--- Diff: python/pyspark/ml/image.py ---
@@ -0,0 +1,196 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+.. attribute:: ImageSchema
+
+ An attribute of :class:`_ImageSchema` in this module.
+
+.. autoclass:: _ImageSchema
+ :members:
+"""
+
+import numpy as np
+from pyspark import SparkContext
+from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string
+from pyspark.sql import DataFrame, SparkSession
+
+
+class _ImageSchema(object):
+ """
+ Internal class for `pyspark.ml.image.ImageSchema` attribute. Meant to
be private and
+ not to be instantiated. Use `pyspark.ml.image.ImageSchema` attribute to
access the
+ APIs of this class.
+ """
+
+ def __init__(self):
+ self._imageSchema = None
+ self._ocvTypes = None
+ self._imageFields = None
+ self._undefinedImageType = None
+
+ @property
+ def imageSchema(self):
+ """
+ Returns the image schema.
+
+ :rtype StructType: a DataFrame with a single column of images
+ named "image" (nullable)
+
+ .. versionadded:: 2.3.0
+ """
+
+ if self._imageSchema is None:
+ ctx = SparkContext._active_spark_context
+ jschema =
ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema()
+ self._imageSchema = _parse_datatype_json_string(jschema.json())
+ return self._imageSchema
+
+ @property
+ def ocvTypes(self):
+ """
+ Returns the OpenCV type mapping supported
+
+ :rtype dict: The OpenCV type mapping supported
+
+ .. versionadded:: 2.3.0
+ """
+
+ if self._ocvTypes is None:
+ ctx = SparkContext._active_spark_context
+ self._ocvTypes =
dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes())
+ return self._ocvTypes
+
+ @property
+ def imageFields(self):
+ """
+ Returns field names of image columns.
+
+ :rtype list: a list of field names.
+
+ .. versionadded:: 2.3.0
+ """
+
+ if self._imageFields is None:
+ ctx = SparkContext._active_spark_context
+ self._imageFields =
list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields())
+ return self._imageFields
+
+ @property
+ def undefinedImageType(self):
+ """
+ Returns the name of undefined image type for the invalid image.
+
+ .. versionadded:: 2.3.0
+ """
+
+ if self._undefinedImageType is None:
+ ctx = SparkContext._active_spark_context
+ self._undefinedImageType = \
+
ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType()
+ return self._undefinedImageType
+
+ def toNDArray(self, image):
+ """
+ Converts an image to a one-dimensional array.
+
+ :param image: The image to be converted
+ :rtype array: The image as a one-dimensional array
+
+ .. versionadded:: 2.3.0
+ """
+
+ height = image.height
+ width = image.width
+ nChannels = image.nChannels
+ return np.ndarray(
+ shape=(height, width, nChannels),
+ dtype=np.uint8,
+ buffer=image.data,
+ strides=(width * nChannels, nChannels, 1))
+
+ def toImage(self, array, origin=""):
+ """
+ Converts an array with metadata to a two-dimensional image.
+
+ :param array array: The array to convert to image
+ :param str origin: Path to the image, optional
+ :rtype object: Two dimensional image
+
+ .. versionadded:: 2.3.0
+ """
+
+ if array.ndim != 3:
+ raise ValueError("Invalid array shape")
+ height, width, nChannels = array.shape
+ ocvTypes = ImageSchema.ocvTypes
+ if nChannels == 1:
+ mode = ocvTypes["CV_8UC1"]
+ elif nChannels == 3:
+ mode = ocvTypes["CV_8UC3"]
+ elif nChannels == 4:
+ mode = ocvTypes["CV_8UC4"]
+ else:
+ raise ValueError("Invalid number of channels")
+ data = bytearray(array.astype(dtype=np.uint8).ravel())
+ # Creating new Row with _create_row(), because Row(name = value,
... )
+ # orders fields by name, which conflicts with expected schema order
+ # when the new DataFrame is created by UDF
+ return _create_row(self.imageFields,
+ [origin, height, width, nChannels, mode, data])
+
+ def readImages(self, path, recursive=False, numPartitions=-1,
+ dropImageFailures=False, sampleRatio=1.0, seed=0):
+ """
+ Reads the directory of images from the local or remote source.
+
+ WARNINGS:
+ - If multiple jobs are run in parallel with different
sampleRatio or recursive flag,
+ there may be a race condition where one job overwrites the
hadoop configs of another.
+
+ :param str path: Path to the image directory
+ :param bool recursive: Recursive search flag
+ :param int numPartitions: Number of DataFrame partitions
+ :param bool dropImageFailures: Drop the files that are not valid
images
+ :param float sampleRatio: Fraction of the images loaded
+ :param int seed: Random number seed
+ :rtype DataFrame: DataFrame with a single column of "images",
--- End diff --
I am sorry @imatiach-msft for the back-and-forth. I was confused between my
production code and the Spark code. It looks like we should use `:return:` just
for consistency with the other docs.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]