Repository: spark
Updated Branches:
  refs/heads/master 0605ad761 -> 1edb3175d


[SPARK-21866][ML][PYSPARK] Adding Spark image reader

## What changes were proposed in this pull request?
Adds a Spark image reader: an implementation of a schema for representing
images in Spark DataFrames.

The code is taken from the Spark package located here:
(https://github.com/Microsoft/spark-images)

Please see the JIRA for more information 
(https://issues.apache.org/jira/browse/SPARK-21866)

Please see mailing list for SPIP vote and approval information:
(http://apache-spark-developers-list.1001551.n3.nabble.com/VOTE-SPIP-SPARK-21866-Image-support-in-Apache-Spark-td22510.html)

## Background and motivation
As Apache Spark sees broader adoption in industry, new use cases are emerging
for data formats beyond the traditional SQL types and the numerical types
(vectors and matrices). Deep Learning applications commonly deal with image
processing. A number of projects add Deep Learning capabilities to Spark, but
they struggle to communicate with each other or with MLlib pipelines because
there is no standard way to represent an image in Spark DataFrames. We propose
to federate efforts for representing images in Spark by defining a
representation that caters to the most common needs of users and library
developers.

This SPIP proposes a specification to represent images in Spark DataFrames and 
Datasets (based on existing industrial standards), and an interface for loading 
sources of images. It is not meant to be a full-fledged image processing 
library, but rather the core description that other libraries and users can 
rely on. Several packages already offer various processing facilities for 
transforming images or doing more complex operations, and each has various 
design tradeoffs that make them better as standalone solutions.

This project is a joint collaboration between Microsoft and Databricks, which 
have been testing this design in two open source packages: MMLSpark and Deep 
Learning Pipelines.

The proposed image format is an in-memory, decompressed representation that 
targets low-level applications. It is significantly more liberal in memory 
usage than compressed image representations such as JPEG, PNG, etc., but it 
allows easy communication with popular image processing libraries and has no 
decoding overhead.
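
As a minimal usage sketch (assuming a Spark build that includes this patch,
run from the repository root so the sample images added below are on the
local path):

    from pyspark.ml.image import ImageSchema

    # Load the sample images added in this commit; recursive=True descends
    # into the kittens/ and multi-channel/ subdirectories, and
    # dropImageFailures skips files that cannot be decoded (not-image.txt).
    df = ImageSchema.readImages("data/mllib/images", recursive=True,
                                dropImageFailures=True)
    df.printSchema()  # image: struct<origin,height,width,nChannels,mode,data>
    df.select("image.origin", "image.width", "image.height").show()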

## How was this patch tested?

Unit tests in Scala (ImageSchemaSuite) and unit tests in Python (pyspark.ml.tests)

Author: Ilya Matiach <il...@microsoft.com>
Author: hyukjinkwon <gurwls...@gmail.com>

Closes #19439 from imatiach-msft/ilmat/spark-images.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1edb3175
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1edb3175
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1edb3175

Branch: refs/heads/master
Commit: 1edb3175d8358c2f6bfc84a0d958342bd5337a62
Parents: 0605ad7
Author: Ilya Matiach <il...@microsoft.com>
Authored: Wed Nov 22 15:45:45 2017 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed Nov 22 15:45:45 2017 -0800

----------------------------------------------------------------------
 .../images/kittens/29.5.a_b_EGDP022204.jpg      | Bin 0 -> 27295 bytes
 data/mllib/images/kittens/54893.jpg             | Bin 0 -> 35914 bytes
 data/mllib/images/kittens/DP153539.jpg          | Bin 0 -> 26354 bytes
 data/mllib/images/kittens/DP802813.jpg          | Bin 0 -> 30432 bytes
 data/mllib/images/kittens/not-image.txt         |   1 +
 data/mllib/images/license.txt                   |  13 +
 data/mllib/images/multi-channel/BGRA.png        | Bin 0 -> 683 bytes
 data/mllib/images/multi-channel/chr30.4.184.jpg | Bin 0 -> 59472 bytes
 data/mllib/images/multi-channel/grayscale.jpg   | Bin 0 -> 36728 bytes
 dev/sparktestsupport/modules.py                 |   1 +
 .../org/apache/spark/ml/image/HadoopUtils.scala | 116 +++++++++
 .../org/apache/spark/ml/image/ImageSchema.scala | 257 +++++++++++++++++++
 .../spark/ml/image/ImageSchemaSuite.scala       | 108 ++++++++
 python/docs/pyspark.ml.rst                      |   8 +
 python/pyspark/ml/image.py                      | 198 ++++++++++++++
 python/pyspark/ml/tests.py                      |  19 ++
 16 files changed, 721 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg
----------------------------------------------------------------------
diff --git a/data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg b/data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg
new file mode 100644
index 0000000..435e7df
Binary files /dev/null and b/data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg differ

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/data/mllib/images/kittens/54893.jpg
----------------------------------------------------------------------
diff --git a/data/mllib/images/kittens/54893.jpg b/data/mllib/images/kittens/54893.jpg
new file mode 100644
index 0000000..825630c
Binary files /dev/null and b/data/mllib/images/kittens/54893.jpg differ

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/data/mllib/images/kittens/DP153539.jpg
----------------------------------------------------------------------
diff --git a/data/mllib/images/kittens/DP153539.jpg b/data/mllib/images/kittens/DP153539.jpg
new file mode 100644
index 0000000..571efe9
Binary files /dev/null and b/data/mllib/images/kittens/DP153539.jpg differ

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/data/mllib/images/kittens/DP802813.jpg
----------------------------------------------------------------------
diff --git a/data/mllib/images/kittens/DP802813.jpg b/data/mllib/images/kittens/DP802813.jpg
new file mode 100644
index 0000000..2d12359
Binary files /dev/null and b/data/mllib/images/kittens/DP802813.jpg differ

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/data/mllib/images/kittens/not-image.txt
----------------------------------------------------------------------
diff --git a/data/mllib/images/kittens/not-image.txt b/data/mllib/images/kittens/not-image.txt
new file mode 100644
index 0000000..283e5e9
--- /dev/null
+++ b/data/mllib/images/kittens/not-image.txt
@@ -0,0 +1 @@
+not an image

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/data/mllib/images/license.txt
----------------------------------------------------------------------
diff --git a/data/mllib/images/license.txt b/data/mllib/images/license.txt
new file mode 100644
index 0000000..052f302
--- /dev/null
+++ b/data/mllib/images/license.txt
@@ -0,0 +1,13 @@
+The images in the folder "kittens" are under the creative commons CC0 license, or no rights reserved:
+https://creativecommons.org/share-your-work/public-domain/cc0/
+The images are taken from:
+https://ccsearch.creativecommons.org/image/detail/WZnbJSJ2-dzIDiuUUdto3Q==
+https://ccsearch.creativecommons.org/image/detail/_TlKu_rm_QrWlR0zthQTXA==
+https://ccsearch.creativecommons.org/image/detail/OPNnHJb6q37rSZ5o_L5JHQ==
+https://ccsearch.creativecommons.org/image/detail/B2CVP_j5KjwZm7UAVJ3Hvw==
+
+The chr30.4.184.jpg and grayscale.jpg images are also under the CC0 license, taken from:
+https://ccsearch.creativecommons.org/image/detail/8eO_qqotBfEm2UYxirLntw==
+
+The image under "multi-channel" directory is under the CC BY-SA 4.0 license cropped from:
+https://en.wikipedia.org/wiki/Alpha_compositing#/media/File:Hue_alpha_falloff.png

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/data/mllib/images/multi-channel/BGRA.png
----------------------------------------------------------------------
diff --git a/data/mllib/images/multi-channel/BGRA.png b/data/mllib/images/multi-channel/BGRA.png
new file mode 100644
index 0000000..a944c6c
Binary files /dev/null and b/data/mllib/images/multi-channel/BGRA.png differ

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/data/mllib/images/multi-channel/chr30.4.184.jpg
----------------------------------------------------------------------
diff --git a/data/mllib/images/multi-channel/chr30.4.184.jpg b/data/mllib/images/multi-channel/chr30.4.184.jpg
new file mode 100644
index 0000000..7068b97
Binary files /dev/null and b/data/mllib/images/multi-channel/chr30.4.184.jpg differ

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/data/mllib/images/multi-channel/grayscale.jpg
----------------------------------------------------------------------
diff --git a/data/mllib/images/multi-channel/grayscale.jpg b/data/mllib/images/multi-channel/grayscale.jpg
new file mode 100644
index 0000000..621cdd1
Binary files /dev/null and b/data/mllib/images/multi-channel/grayscale.jpg differ

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/dev/sparktestsupport/modules.py
----------------------------------------------------------------------
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 91d5667..dacc89f 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -466,6 +466,7 @@ pyspark_ml = Module(
         "pyspark.ml.evaluation",
         "pyspark.ml.feature",
         "pyspark.ml.fpm",
+        "pyspark.ml.image",
         "pyspark.ml.linalg.__init__",
         "pyspark.ml.recommendation",
         "pyspark.ml.regression",

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala
new file mode 100644
index 0000000..8c975a2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/image/HadoopUtils.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.image
+
+import scala.language.existentials
+import scala.util.Random
+
+import org.apache.commons.io.FilenameUtils
+import org.apache.hadoop.conf.{Configuration, Configured}
+import org.apache.hadoop.fs.{Path, PathFilter}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+
+import org.apache.spark.sql.SparkSession
+
+private object RecursiveFlag {
+  /**
+   * Sets the spark recursive flag and then restores it.
+   *
+   * @param value Value to set
+   * @param spark Existing spark session
+   * @param f The function to evaluate after setting the flag
+   * @return Returns the evaluation result T of the function
+   */
+  def withRecursiveFlag[T](value: Boolean, spark: SparkSession)(f: => T): T = {
+    val flagName = FileInputFormat.INPUT_DIR_RECURSIVE
+    val hadoopConf = spark.sparkContext.hadoopConfiguration
+    val old = Option(hadoopConf.get(flagName))
+    hadoopConf.set(flagName, value.toString)
+    try f finally {
+      old match {
+        case Some(v) => hadoopConf.set(flagName, v)
+        case None => hadoopConf.unset(flagName)
+      }
+    }
+  }
+}
+
+/**
+ * Filter that allows loading a fraction of HDFS files.
+ */
+private class SamplePathFilter extends Configured with PathFilter {
+  val random = new Random()
+
+  // Ratio of files to be read from disk
+  var sampleRatio: Double = 1
+
+  override def setConf(conf: Configuration): Unit = {
+    if (conf != null) {
+      sampleRatio = conf.getDouble(SamplePathFilter.ratioParam, 1)
+      val seed = conf.getLong(SamplePathFilter.seedParam, 0)
+      random.setSeed(seed)
+    }
+  }
+
+  override def accept(path: Path): Boolean = {
+    // Note: checking fileSystem.isDirectory is very slow here, so we use basic rules instead
+    !SamplePathFilter.isFile(path) || random.nextDouble() < sampleRatio
+  }
+}
+
+private object SamplePathFilter {
+  val ratioParam = "sampleRatio"
+  val seedParam = "seed"
+
+  def isFile(path: Path): Boolean = FilenameUtils.getExtension(path.toString) != ""
+
+  /**
+   * Sets the HDFS PathFilter flag and then restores it.
+   * Only applies the filter if sampleRatio is less than 1.
+   *
+   * @param sampleRatio Fraction of the files that the filter picks
+   * @param spark Existing Spark session
+   * @param seed Random number seed
+   * @param f The function to evaluate after setting the flag
+   * @return Returns the evaluation result T of the function
+   */
+  def withPathFilter[T](
+      sampleRatio: Double,
+      spark: SparkSession,
+      seed: Long)(f: => T): T = {
+    val sampleImages = sampleRatio < 1
+    if (sampleImages) {
+      val flagName = FileInputFormat.PATHFILTER_CLASS
+      val hadoopConf = spark.sparkContext.hadoopConfiguration
+      val old = Option(hadoopConf.getClass(flagName, null))
+      hadoopConf.setDouble(SamplePathFilter.ratioParam, sampleRatio)
+      hadoopConf.setLong(SamplePathFilter.seedParam, seed)
+      hadoopConf.setClass(flagName, classOf[SamplePathFilter], classOf[PathFilter])
+      try f finally {
+        hadoopConf.unset(SamplePathFilter.ratioParam)
+        hadoopConf.unset(SamplePathFilter.seedParam)
+        old match {
+          case Some(v) => hadoopConf.setClass(flagName, v, classOf[PathFilter])
+          case None => hadoopConf.unset(flagName)
+        }
+      }
+    } else {
+      f
+    }
+  }
+}
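
From the public API these helpers are exercised through the recursive and
sampleRatio arguments of readImages; a hedged sketch (SamplePathFilter is
installed only while sampleRatio < 1, so the resulting count is approximate
and potentially non-deterministic, as noted in ImageSchema below):

    from pyspark.ml.image import ImageSchema

    # Expect roughly half of the image files, not an exact count.
    df = ImageSchema.readImages("data/mllib/images", recursive=True,
                                dropImageFailures=True, sampleRatio=0.5,
                                seed=42)
    print(df.count())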

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
new file mode 100644
index 0000000..f7850b2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/image/ImageSchema.scala
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.image
+
+import java.awt.Color
+import java.awt.color.ColorSpace
+import java.io.ByteArrayInputStream
+import javax.imageio.ImageIO
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.input.PortableDataStream
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
+import org.apache.spark.sql.types._
+
+/**
+ * :: Experimental ::
+ * Defines the image schema and methods to read and manipulate images.
+ */
+@Experimental
+@Since("2.3.0")
+object ImageSchema {
+
+  val undefinedImageType = "Undefined"
+
+  /**
+   * (Scala-specific) OpenCV type mapping supported
+   */
+  val ocvTypes: Map[String, Int] = Map(
+    undefinedImageType -> -1,
+    "CV_8U" -> 0, "CV_8UC1" -> 0, "CV_8UC3" -> 16, "CV_8UC4" -> 24
+  )
+
+  /**
+   * (Java-specific) OpenCV type mapping supported
+   */
+  val javaOcvTypes: java.util.Map[String, Int] = ocvTypes.asJava
+
+  /**
+   * Schema for the image column: Row(String, Int, Int, Int, Int, Array[Byte])
+   */
+  val columnSchema = StructType(
+    StructField("origin", StringType, true) ::
+    StructField("height", IntegerType, false) ::
+    StructField("width", IntegerType, false) ::
+    StructField("nChannels", IntegerType, false) ::
+    // OpenCV-compatible type: CV_8UC3 in most cases
+    StructField("mode", IntegerType, false) ::
+    // Bytes in OpenCV-compatible order: row-wise BGR in most cases
+    StructField("data", BinaryType, false) :: Nil)
+
+  val imageFields: Array[String] = columnSchema.fieldNames
+
+  /**
+   * DataFrame with a single column of images named "image" (nullable)
+   */
+  val imageSchema = StructType(StructField("image", columnSchema, true) :: Nil)
+
+  /**
+   * Gets the origin of the image
+   *
+   * @return The origin of the image
+   */
+  def getOrigin(row: Row): String = row.getString(0)
+
+  /**
+   * Gets the height of the image
+   *
+   * @return The height of the image
+   */
+  def getHeight(row: Row): Int = row.getInt(1)
+
+  /**
+   * Gets the width of the image
+   *
+   * @return The width of the image
+   */
+  def getWidth(row: Row): Int = row.getInt(2)
+
+  /**
+   * Gets the number of channels in the image
+   *
+   * @return The number of channels in the image
+   */
+  def getNChannels(row: Row): Int = row.getInt(3)
+
+  /**
+   * Gets the OpenCV representation as an int
+   *
+   * @return The OpenCV representation as an int
+   */
+  def getMode(row: Row): Int = row.getInt(4)
+
+  /**
+   * Gets the image data
+   *
+   * @return The image data
+   */
+  def getData(row: Row): Array[Byte] = row.getAs[Array[Byte]](5)
+
+  /**
+   * Default values for the invalid image
+   *
+   * @param origin Origin of the invalid image
+   * @return Row with the default values
+   */
+  private[spark] def invalidImageRow(origin: String): Row =
+    Row(Row(origin, -1, -1, -1, ocvTypes(undefinedImageType), Array.ofDim[Byte](0)))
+
+  /**
+   * Convert the compressed image (jpeg, png, etc.) into OpenCV
+   * representation and store it in DataFrame Row
+   *
+   * @param origin Arbitrary string that identifies the image
+   * @param bytes Image bytes (for example, jpeg)
+   * @return DataFrame Row or None (if the decompression fails)
+   */
+  private[spark] def decode(origin: String, bytes: Array[Byte]): Option[Row] = {
+
+    val img = ImageIO.read(new ByteArrayInputStream(bytes))
+
+    if (img == null) {
+      None
+    } else {
+      val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY
+      val hasAlpha = img.getColorModel.hasAlpha
+
+      val height = img.getHeight
+      val width = img.getWidth
+      val (nChannels, mode) = if (isGray) {
+        (1, ocvTypes("CV_8UC1"))
+      } else if (hasAlpha) {
+        (4, ocvTypes("CV_8UC4"))
+      } else {
+        (3, ocvTypes("CV_8UC3"))
+      }
+
+      val imageSize = height * width * nChannels
+      assert(imageSize < 1e9, "image is too large")
+      val decoded = Array.ofDim[Byte](imageSize)
+
+      // Grayscale images in Java require special handling to get the correct intensity
+      if (isGray) {
+        var offset = 0
+        val raster = img.getRaster
+        for (h <- 0 until height) {
+          for (w <- 0 until width) {
+            decoded(offset) = raster.getSample(w, h, 0).toByte
+            offset += 1
+          }
+        }
+      } else {
+        var offset = 0
+        for (h <- 0 until height) {
+          for (w <- 0 until width) {
+            val color = new Color(img.getRGB(w, h))
+
+            decoded(offset) = color.getBlue.toByte
+            decoded(offset + 1) = color.getGreen.toByte
+            decoded(offset + 2) = color.getRed.toByte
+            if (nChannels == 4) {
+              decoded(offset + 3) = color.getAlpha.toByte
+            }
+            offset += nChannels
+          }
+        }
+      }
+
+      // the internal "Row" is needed, because the image is a single DataFrame column
+      Some(Row(Row(origin, height, width, nChannels, mode, decoded)))
+    }
+  }
+
+  /**
+   * Read the directory of images from the local or remote source
+   *
+   * @note If multiple jobs are run in parallel with different sampleRatio or recursive flag,
+   * there may be a race condition where one job overwrites the hadoop configs of another.
+   * @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but
+   * potentially non-deterministic.
+   *
+   * @param path Path to the image directory
+   * @return DataFrame with a single column "image" of images;
+   *         see ImageSchema for the details
+   */
+  def readImages(path: String): DataFrame = readImages(path, null, false, -1, false, 1.0, 0)
+
+  /**
+   * Read the directory of images from the local or remote source
+   *
+   * @note If multiple jobs are run in parallel with different sampleRatio or recursive flag,
+   * there may be a race condition where one job overwrites the hadoop configs of another.
+   * @note If sample ratio is less than 1, sampling uses a PathFilter that is efficient but
+   * potentially non-deterministic.
+   *
+   * @param path Path to the image directory
+   * @param sparkSession Spark Session, if omitted gets or creates the session
+   * @param recursive Recursive path search flag
+   * @param numPartitions Number of the DataFrame partitions,
+   *                      if omitted uses defaultParallelism instead
+   * @param dropImageFailures Drop the files that are not valid images from the result
+   * @param sampleRatio Fraction of the files loaded
+   * @return DataFrame with a single column "image" of images;
+   *         see ImageSchema for the details
+   */
+  def readImages(
+      path: String,
+      sparkSession: SparkSession,
+      recursive: Boolean,
+      numPartitions: Int,
+      dropImageFailures: Boolean,
+      sampleRatio: Double,
+      seed: Long): DataFrame = {
+    require(sampleRatio <= 1.0 && sampleRatio >= 0, "sampleRatio should be between 0 and 1")
+
+    val session = if (sparkSession != null) sparkSession else SparkSession.builder().getOrCreate
+    val partitions =
+      if (numPartitions > 0) {
+        numPartitions
+      } else {
+        session.sparkContext.defaultParallelism
+      }
+
+    RecursiveFlag.withRecursiveFlag(recursive, session) {
+      SamplePathFilter.withPathFilter(sampleRatio, session, seed) {
+        val binResult = session.sparkContext.binaryFiles(path, partitions)
+        val streams = if (numPartitions == -1) binResult else binResult.repartition(partitions)
+        val convert = (origin: String, bytes: PortableDataStream) =>
+          decode(origin, bytes.toArray())
+        val images = if (dropImageFailures) {
+          streams.flatMap { case (origin, bytes) => convert(origin, bytes) }
+        } else {
+          streams.map { case (origin, bytes) =>
+            convert(origin, bytes).getOrElse(invalidImageRow(origin))
+          }
+        }
+        session.createDataFrame(images, imageSchema)
+      }
+    }
+  }
+}
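
The "data" field above stores pixels row-wise in OpenCV channel order (BGR,
or BGRA when an alpha channel is present). A sketch of indexing one pixel
from the flattened bytes (assuming a Spark build with this patch):

    from pyspark.ml.image import ImageSchema

    row = ImageSchema.readImages("data/mllib/images/multi-channel").first()[0]
    h, w, n = row.height, row.width, row.nChannels

    def pixel(y, x):
        # Channels are interleaved within a pixel; rows are contiguous.
        base = (y * w + x) * n
        return row.data[base:base + n]  # BGR (or BGRA) byte values

    print(pixel(0, 0))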

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala
new file mode 100644
index 0000000..dba61cd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/image/ImageSchemaSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.image
+
+import java.nio.file.Paths
+import java.util.Arrays
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.image.ImageSchema._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+
+class ImageSchemaSuite extends SparkFunSuite with MLlibTestSparkContext {
+  // Single column of images named "image"
+  private lazy val imagePath = "../data/mllib/images"
+
+  test("Smoke test: create basic ImageSchema dataframe") {
+    val origin = "path"
+    val width = 1
+    val height = 1
+    val nChannels = 3
+    val data = Array[Byte](0, 0, 0)
+    val mode = ocvTypes("CV_8UC3")
+
+    // Internal Row corresponds to image StructType
+    val rows = Seq(Row(Row(origin, height, width, nChannels, mode, data)),
+      Row(Row(null, height, width, nChannels, mode, data)))
+    val rdd = sc.makeRDD(rows)
+    val df = spark.createDataFrame(rdd, ImageSchema.imageSchema)
+
+    assert(df.count === 2, "incorrect image count")
+    assert(df.schema("image").dataType == columnSchema, "data do not fit ImageSchema")
+  }
+
+  test("readImages count test") {
+    var df = readImages(imagePath)
+    assert(df.count === 1)
+
+    df = readImages(imagePath, null, true, -1, false, 1.0, 0)
+    assert(df.count === 9)
+
+    df = readImages(imagePath, null, true, -1, true, 1.0, 0)
+    val countTotal = df.count
+    assert(countTotal === 7)
+
+    df = readImages(imagePath, null, true, -1, true, 0.5, 0)
+    // Random number about half of the size of the original dataset
+    val count50 = df.count
+    assert(count50 > 0 && count50 < countTotal)
+  }
+
+  test("readImages partition test") {
+    val df = readImages(imagePath, null, true, 3, true, 1.0, 0)
+    assert(df.rdd.getNumPartitions === 3)
+  }
+
+  // Images with the different number of channels
+  test("readImages pixel values test") {
+
+    val images = readImages(imagePath + "/multi-channel/").collect
+
+    images.foreach { rrow =>
+      val row = rrow.getAs[Row](0)
+      val filename = Paths.get(getOrigin(row)).getFileName().toString()
+      if (firstBytes20.contains(filename)) {
+        val mode = getMode(row)
+        val bytes20 = getData(row).slice(0, 20)
+
+        val (expectedMode, expectedBytes) = firstBytes20(filename)
+        assert(ocvTypes(expectedMode) === mode, "mode of the image is not read correctly")
+        assert(Arrays.equals(expectedBytes, bytes20), "incorrect numeric value for flattened image")
+      }
+    }
+  }
+
+  // number of channels and first 20 bytes of OpenCV representation
+  // - default representation for 3-channel RGB images is BGR row-wise:
+  //   (B00, G00, R00,      B10, G10, R10,      ...)
+  // - default representation for 4-channel RGB images is BGRA row-wise:
+  //   (B00, G00, R00, A00, B10, G10, R10, A10, ...)
+  private val firstBytes20 = Map(
+    "grayscale.jpg" ->
+      (("CV_8UC1", Array[Byte](-2, -33, -61, -60, -59, -59, -64, -59, -66, 
-67, -73, -73, -62,
+        -57, -60, -63, -53, -49, -55, -69))),
+    "chr30.4.184.jpg" -> (("CV_8UC3",
+      Array[Byte](-9, -3, -1, -43, -32, -28, -75, -60, -57, -78, -59, -56, -74, -59, -57,
+        -71, -58, -56, -73, -64))),
+    "BGRA.png" -> (("CV_8UC4",
+      Array[Byte](-128, -128, -8, -1, -128, -128, -8, -1, -128,
+        -128, -8, -1, 127, 127, -9, -1, 127, 127, -9, -1)))
+  )
+}
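
The negative constants in firstBytes20 are an artifact of Scala's signed
Byte: a pixel intensity v above 127 is stored as v - 256. A quick check of
the conversion (in Python, for illustration):

    # Scala Byte -2 corresponds to the unsigned pixel intensity 254.
    print(-2 & 0xFF)                 # 254
    print((254 + 128) % 256 - 128)   # -2, back to the signed representation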

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/python/docs/pyspark.ml.rst
----------------------------------------------------------------------
diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
index 01627ba..6a5d817 100644
--- a/python/docs/pyspark.ml.rst
+++ b/python/docs/pyspark.ml.rst
@@ -97,6 +97,14 @@ pyspark.ml.fpm module
     :undoc-members:
     :inherited-members:
 
+pyspark.ml.image module
+----------------------------
+
+.. automodule:: pyspark.ml.image
+    :members:
+    :undoc-members:
+    :inherited-members:
+
 pyspark.ml.util module
 ----------------------------
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/python/pyspark/ml/image.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/image.py b/python/pyspark/ml/image.py
new file mode 100644
index 0000000..7d14f05
--- /dev/null
+++ b/python/pyspark/ml/image.py
@@ -0,0 +1,198 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+.. attribute:: ImageSchema
+
+    An attribute of this module that contains the instance of :class:`_ImageSchema`.
+
+.. autoclass:: _ImageSchema
+   :members:
+"""
+
+import numpy as np
+from pyspark import SparkContext
+from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string
+from pyspark.sql import DataFrame, SparkSession
+
+
+class _ImageSchema(object):
+    """
+    Internal class for `pyspark.ml.image.ImageSchema` attribute. Meant to be private and
+    not to be instantiated. Use `pyspark.ml.image.ImageSchema` attribute to access the
+    APIs of this class.
+    """
+
+    def __init__(self):
+        self._imageSchema = None
+        self._ocvTypes = None
+        self._imageFields = None
+        self._undefinedImageType = None
+
+    @property
+    def imageSchema(self):
+        """
+        Returns the image schema.
+
+        :return: a :class:`StructType` with a single column of images
+               named "image" (nullable).
+
+        .. versionadded:: 2.3.0
+        """
+
+        if self._imageSchema is None:
+            ctx = SparkContext._active_spark_context
+            jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema()
+            self._imageSchema = _parse_datatype_json_string(jschema.json())
+        return self._imageSchema
+
+    @property
+    def ocvTypes(self):
+        """
+        Returns the OpenCV type mapping supported.
+
+        :return: a dictionary containing the OpenCV type mapping supported.
+
+        .. versionadded:: 2.3.0
+        """
+
+        if self._ocvTypes is None:
+            ctx = SparkContext._active_spark_context
+            self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes())
+        return self._ocvTypes
+
+    @property
+    def imageFields(self):
+        """
+        Returns field names of image columns.
+
+        :return: a list of field names.
+
+        .. versionadded:: 2.3.0
+        """
+
+        if self._imageFields is None:
+            ctx = SparkContext._active_spark_context
+            self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields())
+        return self._imageFields
+
+    @property
+    def undefinedImageType(self):
+        """
+        Returns the name of undefined image type for the invalid image.
+
+        .. versionadded:: 2.3.0
+        """
+
+        if self._undefinedImageType is None:
+            ctx = SparkContext._active_spark_context
+            self._undefinedImageType = \
+                ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType()
+        return self._undefinedImageType
+
+    def toNDArray(self, image):
+        """
+        Converts an image to an array with metadata.
+
+        :param image: The image to be converted.
+        :return: a `numpy.ndarray` that is an image.
+
+        .. versionadded:: 2.3.0
+        """
+
+        height = image.height
+        width = image.width
+        nChannels = image.nChannels
+        return np.ndarray(
+            shape=(height, width, nChannels),
+            dtype=np.uint8,
+            buffer=image.data,
+            strides=(width * nChannels, nChannels, 1))
+
+    def toImage(self, array, origin=""):
+        """
+        Converts an array with metadata to a two-dimensional image.
+
+        :param array array: The array to convert to image.
+        :param str origin: Path to the image, optional.
+        :return: a :class:`Row` that is a two dimensional image.
+
+        .. versionadded:: 2.3.0
+        """
+
+        if array.ndim != 3:
+            raise ValueError("Invalid array shape")
+        height, width, nChannels = array.shape
+        ocvTypes = ImageSchema.ocvTypes
+        if nChannels == 1:
+            mode = ocvTypes["CV_8UC1"]
+        elif nChannels == 3:
+            mode = ocvTypes["CV_8UC3"]
+        elif nChannels == 4:
+            mode = ocvTypes["CV_8UC4"]
+        else:
+            raise ValueError("Invalid number of channels")
+        data = bytearray(array.astype(dtype=np.uint8).ravel())
+        # Creating new Row with _create_row(), because Row(name = value, ... )
+        # orders fields by name, which conflicts with expected schema order
+        # when the new DataFrame is created by UDF
+        return _create_row(self.imageFields,
+                           [origin, height, width, nChannels, mode, data])
+
+    def readImages(self, path, recursive=False, numPartitions=-1,
+                   dropImageFailures=False, sampleRatio=1.0, seed=0):
+        """
+        Reads the directory of images from the local or remote source.
+
+        .. note:: If multiple jobs are run in parallel with different sampleRatio or recursive flag,
+            there may be a race condition where one job overwrites the hadoop configs of another.
+
+        .. note:: If sample ratio is less than 1, sampling uses a PathFilter that is efficient but
+            potentially non-deterministic.
+
+        :param str path: Path to the image directory.
+        :param bool recursive: Recursive search flag.
+        :param int numPartitions: Number of DataFrame partitions.
+        :param bool dropImageFailures: Drop the files that are not valid images.
+        :param float sampleRatio: Fraction of the images loaded.
+        :param int seed: Random number seed.
+        :return: a :class:`DataFrame` with a single column of "images",
+               see ImageSchema for details.
+
+        >>> df = ImageSchema.readImages('python/test_support/image/kittens', 
recursive=True)
+        >>> df.count()
+        4
+
+        .. versionadded:: 2.3.0
+        """
+
+        ctx = SparkContext._active_spark_context
+        spark = SparkSession(ctx)
+        image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema
+        jsession = spark._jsparkSession
+        jresult = image_schema.readImages(path, jsession, recursive, numPartitions,
+                                          dropImageFailures, float(sampleRatio), seed)
+        return DataFrame(jresult, spark._wrapped)
+
+
+ImageSchema = _ImageSchema()
+
+
+# Monkey patch to disallow instantiation of this class.
+def _disallow_instance(_):
+    raise RuntimeError("Creating instance of _ImageSchema class is disallowed.")
+_ImageSchema.__init__ = _disallow_instance
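
A round-trip sketch for the conversion helpers defined above (an active
SparkContext is required, since the helpers consult the JVM ImageSchema for
field names and OpenCV types):

    import numpy as np
    from pyspark.ml.image import ImageSchema

    # Synthetic 2x2 three-channel image, one byte per channel (BGR order).
    arr = np.arange(12, dtype=np.uint8).reshape(2, 2, 3)
    img = ImageSchema.toImage(arr, origin="synthetic")
    assert (ImageSchema.toNDArray(img) == arr).all()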

http://git-wip-us.apache.org/repos/asf/spark/blob/1edb3175/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 2f1f3af..2258d61 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -54,6 +54,7 @@ from pyspark.ml.evaluation import BinaryClassificationEvaluator, \
     MulticlassClassificationEvaluator, RegressionEvaluator
 from pyspark.ml.feature import *
 from pyspark.ml.fpm import FPGrowth, FPGrowthModel
+from pyspark.ml.image import ImageSchema
 from pyspark.ml.linalg import DenseMatrix, DenseMatrix, DenseVector, Matrices, MatrixUDT, \
     SparseMatrix, SparseVector, Vector, VectorUDT, Vectors
 from pyspark.ml.param import Param, Params, TypeConverters
@@ -1818,6 +1819,24 @@ class FPGrowthTests(SparkSessionTestCase):
         del self.data
 
 
+class ImageReaderTest(SparkSessionTestCase):
+
+    def test_read_images(self):
+        data_path = 'data/mllib/images/kittens'
+        df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
+        self.assertEqual(df.count(), 4)
+        first_row = df.take(1)[0][0]
+        array = ImageSchema.toNDArray(first_row)
+        self.assertEqual(len(array), first_row[1])
+        self.assertEqual(ImageSchema.toImage(array, origin=first_row[0]), first_row)
+        self.assertEqual(df.schema, ImageSchema.imageSchema)
+        expected = {'CV_8UC3': 16, 'Undefined': -1, 'CV_8U': 0, 'CV_8UC1': 0, 'CV_8UC4': 24}
+        self.assertEqual(ImageSchema.ocvTypes, expected)
+        expected = ['origin', 'height', 'width', 'nChannels', 'mode', 'data']
+        self.assertEqual(ImageSchema.imageFields, expected)
+        self.assertEqual(ImageSchema.undefinedImageType, "Undefined")
+
+
 class ALSTest(SparkSessionTestCase):
 
     def test_storage_levels(self):

