This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 6535c06 [MXNET-1180] Scala Image API (#12995)
6535c06 is described below
commit 6535c061e7674dd72e001d943e17b3804e60646d
Author: Lanking <[email protected]>
AuthorDate: Fri Nov 2 08:02:48 2018 -0700
[MXNET-1180] Scala Image API (#12995)
* add image and image suite
* apply toImage function and tests
* bug fix
* apply the commented change
* add test to apply border
* fix scalastyle
---
.../src/main/scala/org/apache/mxnet/Image.scala | 185 +++++++++++++++++++++
.../test/scala/org/apache/mxnet/ImageSuite.scala | 100 +++++++++++
2 files changed, 285 insertions(+)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
new file mode 100644
index 0000000..43f81a2
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Image.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+// scalastyle:off
+import java.awt.image.BufferedImage
+// scalastyle:on
+import java.io.InputStream
+
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+
+/**
+  * Image API of Scala package
+  * enable OpenCV feature
+  */
+object Image {
+
+  /**
+    * Decode image with OpenCV.
+    * Note: return image in RGB by default, instead of OpenCV's default BGR.
+    * @param buf    Buffer containing binary encoded image
+    * @param flag   Convert decoded image to grayscale (0) or color (1).
+    * @param to_rgb Whether to convert decoded image
+    *               to mxnet's default RGB format (instead of opencv's default BGR).
+    * @param out    Optional NDArray passed through to the operator as its `out` argument.
+    * @return NDArray in HWC format
+    */
+  def imDecode(buf: Array[Byte], flag: Int,
+               to_rgb: Boolean,
+               out: Option[NDArray]): NDArray = {
+    // JVM bytes are signed. Mask to 0..255 before widening, otherwise every encoded byte
+    // above 0x7F turns into a negative float, which the uint8 cast below is not guaranteed
+    // to map back to the original byte value.
+    val nd = NDArray.array(buf.map(b => (b & 0xFF).toFloat), Shape(buf.length))
+    val byteND = NDArray.api.cast(nd, "uint8")
+    val args = ListBuffer[Any](byteND)
+    val kwargs = mutable.Map[String, Any]("flag" -> flag, "to_rgb" -> to_rgb)
+    out.foreach(o => kwargs.update("out", o))
+    NDArray.genericNDArrayFunctionInvoke("_cvimdecode", args, kwargs.toMap)
+  }
+
+  /**
+    * Same imageDecode with InputStream.
+    * The stream is read until end-of-stream; it is not closed by this method.
+    * @param inputStream the inputStream of the image
+    * @param flag        Convert decoded image to grayscale (0) or color (1).
+    * @param to_rgb      Whether to convert decoded image to RGB instead of BGR.
+    * @param out         Optional NDArray passed through to the operator as its `out` argument.
+    * @return NDArray in HWC format
+    */
+  def imDecode(inputStream: InputStream, flag: Int = 1,
+               to_rgb: Boolean = true,
+               out: Option[NDArray] = None): NDArray = {
+    val chunk = new Array[Byte](2048)
+    val bytes = ArrayBuffer[Byte]()
+    // read() signals end-of-stream with -1; only the filled prefix of the chunk is appended.
+    Iterator.continually(inputStream.read(chunk))
+      .takeWhile(_ != -1)
+      .foreach(count => bytes ++= chunk.slice(0, count))
+    imDecode(bytes.toArray, flag, to_rgb, out)
+  }
+
+  /**
+    * Read and decode image with OpenCV.
+    * Note: return image in RGB by default, instead of OpenCV's default BGR.
+    * @param filename Name of the image file to be loaded.
+    * @param flag     Convert decoded image to grayscale (0) or color (1).
+    * @param to_rgb   Whether to convert decoded image to mxnet's default RGB format
+    *                 (instead of opencv's default BGR).
+    * @param out      Optional NDArray passed through to the operator as its `out` argument.
+    * @return org.apache.mxnet.NDArray in HWC format
+    */
+  def imRead(filename: String, flag: Option[Int] = None,
+             to_rgb: Option[Boolean] = None,
+             out: Option[NDArray] = None): NDArray = {
+    val args = ListBuffer[Any]()
+    val kwargs = mutable.Map[String, Any]("filename" -> filename)
+    // Unset options are left out so the operator applies its own defaults.
+    flag.foreach(f => kwargs.update("flag", f))
+    to_rgb.foreach(rgb => kwargs.update("to_rgb", rgb))
+    out.foreach(o => kwargs.update("out", o))
+    NDArray.genericNDArrayFunctionInvoke("_cvimread", args, kwargs.toMap)
+  }
+
+  /**
+    * Resize image with OpenCV.
+    * @param src    source image in NDArray
+    * @param w      Width of resized image.
+    * @param h      Height of resized image.
+    * @param interp Interpolation method (default=cv2.INTER_LINEAR).
+    * @param out    Optional NDArray passed through to the operator as its `out` argument.
+    * @return org.apache.mxnet.NDArray
+    */
+  def imResize(src: org.apache.mxnet.NDArray, w: Int, h: Int,
+               interp: Option[Int] = None,
+               out: Option[NDArray] = None): NDArray = {
+    val args = ListBuffer[Any](src)
+    val kwargs = mutable.Map[String, Any]("w" -> w, "h" -> h)
+    interp.foreach(i => kwargs.update("interp", i))
+    out.foreach(o => kwargs.update("out", o))
+    NDArray.genericNDArrayFunctionInvoke("_cvimresize", args, kwargs.toMap)
+  }
+
+  /**
+    * Pad image border with OpenCV.
+    * @param src    source image
+    * @param top    Top margin.
+    * @param bot    Bottom margin.
+    * @param left   Left margin.
+    * @param right  Right margin.
+    * @param typeOf Filling type (default=cv2.BORDER_CONSTANT).
+    * @param value  (Deprecated! Use ``values`` instead.) Fill with single value.
+    * @param values Fill with value(RGB[A] or gray), up to 4 channels.
+    * @param out    Optional NDArray passed through to the operator as its `out` argument.
+    * @return org.apache.mxnet.NDArray
+    */
+  def copyMakeBorder(src: org.apache.mxnet.NDArray, top: Int, bot: Int,
+                     left: Int, right: Int, typeOf: Option[Int] = None,
+                     value: Option[Double] = None, values: Option[Any] = None,
+                     out: Option[NDArray] = None): NDArray = {
+    val args = ListBuffer[Any](src)
+    val kwargs = mutable.Map[String, Any](
+      "top" -> top, "bot" -> bot, "left" -> left, "right" -> right)
+    // The operator's keyword is "type"; the Scala parameter is typeOf because `type` is reserved.
+    typeOf.foreach(t => kwargs.update("type", t))
+    value.foreach(v => kwargs.update("value", v))
+    values.foreach(vs => kwargs.update("values", vs))
+    out.foreach(o => kwargs.update("out", o))
+    NDArray.genericNDArrayFunctionInvoke("_cvcopyMakeBorder", args, kwargs.toMap)
+  }
+
+  /**
+    * Do a fixed crop on the image
+    * @param src Src image in NDArray (indexed as height, width, channel)
+    * @param x0  starting x point
+    * @param y0  starting y point
+    * @param w   width of the image
+    * @param h   height of the image
+    * @return cropped NDArray
+    */
+  def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
+    // Begin/end are given per axis as (row, column, channel); all channels are kept.
+    NDArray.api.crop(src, Shape(y0, x0, 0), Shape(y0 + h, x0 + w, src.shape.get(2)))
+  }
+
+  /**
+    * Convert a NDArray image to a real image
+    * The time cost will increase if the image resolution is big
+    * @param src Source image in RGB, uint8, with shape (height, width, channel).
+    *            Only the first three channels are used.
+    * @return Buffered Image
+    */
+  def toImage(src: NDArray): BufferedImage = {
+    require(src.dtype == DType.UInt8, "The input NDArray must be bytes")
+    require(src.shape.length == 3, "The input should contains height, width and channel")
+    val height = src.shape.get(0)
+    val width = src.shape.get(1)
+    val channels = src.shape.get(2)
+    require(channels >= 3, "The input should have at least three channels (RGB)")
+    val img = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB)
+    // Copy the whole tensor once instead of slicing an NDArray per pixel: the per-pixel
+    // at(r).at(c) slices were never disposed and made the conversion very slow.
+    // NOTE(review): assumes toArray yields elements in row-major (height, width, channel)
+    // order, i.e. the same order at(r).at(c) walks them - confirm against NDArray.toArray.
+    val data = src.toArray
+    val rgb = Array.tabulate(height * width) { pixel =>
+      val offset = pixel * channels
+      // NDArray in RGB
+      val red = data(offset).toByte & 0xFF
+      val green = data(offset + 1).toByte & 0xFF
+      val blue = data(offset + 2).toByte & 0xFF
+      (red << 16) | (green << 8) | blue
+    }
+    // Single bulk write, performed on the calling thread only.
+    img.setRGB(0, 0, width, height, rgb, 0, width)
+    img
+  }
+
+}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala
new file mode 100644
index 0000000..67815ad
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ImageSuite.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.io.File
+import java.net.URL
+
+import javax.imageio.ImageIO
+import org.apache.commons.io.FileUtils
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+/**
+  * Tests for the [[Image]] API. The fixture image is downloaded once into the JVM temp
+  * directory before the tests run.
+  */
+class ImageSuite extends FunSuite with BeforeAndAfterAll {
+  private var imLocation = ""
+  private val logger = LoggerFactory.getLogger(classOf[ImageSuite])
+
+  /**
+    * Download `url` to `filePath` unless that file already exists.
+    * @param url      location to fetch
+    * @param filePath destination file
+    * @param maxRetry number of attempts, 3 when not given
+    * @throws Exception when the file is missing and every attempt failed; the last
+    *                   failure is attached as the cause
+    */
+  private def downloadUrl(url: String, filePath: String,
+                          maxRetry: Option[Int] = None) : Unit = {
+    val tmpFile = new File(filePath)
+    var retry = maxRetry.getOrElse(3)
+    var success = tmpFile.exists()
+    var lastError: Option[Throwable] = None
+    while (retry > 0 && !success) {
+      try {
+        FileUtils.copyURLToFile(new URL(url), tmpFile)
+        success = true
+      } catch {
+        case e: Exception =>
+          retry -= 1
+          lastError = Some(e)
+          // Remove a partial download, otherwise a later run would see the file and
+          // treat the truncated image as a valid fixture.
+          FileUtils.deleteQuietly(tmpFile)
+          logger.warn(s"Failed to download $url, $retry attempt(s) left", e)
+      }
+    }
+    if (!success) throw new Exception(s"$url Download failed!", lastError.orNull)
+  }
+
+  override def beforeAll(): Unit = {
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    imLocation = tempDirPath + "/inputImages/Pug-Cookie.jpg"
+    downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg",
+      imLocation)
+  }
+
+  test("Test load image") {
+    val nd = Image.imRead(imLocation)
+    logger.info(s"OpenCV load image with shape: ${nd.shape}")
+    assert(nd.shape == Shape(576, 1024, 3), "image shape not Match!")
+  }
+
+  test("Test load image from Socket") {
+    val url = new URL("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg")
+    val inputStream = url.openStream
+    // imDecode does not close the stream, so release the connection here.
+    val nd = try Image.imDecode(inputStream) finally inputStream.close()
+    logger.info(s"OpenCV load image with shape: ${nd.shape}")
+    assert(nd.shape == Shape(576, 1024, 3), "image shape not Match!")
+  }
+
+  test("Test resize image") {
+    val nd = Image.imRead(imLocation)
+    val resizeIm = Image.imResize(nd, 224, 224)
+    logger.info(s"OpenCV resize image with shape: ${resizeIm.shape}")
+    assert(resizeIm.shape == Shape(224, 224, 3), "image shape not Match!")
+  }
+
+  test("Test crop image") {
+    val nd = Image.imRead(imLocation)
+    val nd2 = Image.fixedCrop(nd, 0, 0, 224, 224)
+    assert(nd2.shape == Shape(224, 224, 3), "image shape not Match!")
+  }
+
+  test("Test apply border") {
+    val nd = Image.imRead(imLocation)
+    val nd2 = Image.copyMakeBorder(nd, 1, 1, 1, 1)
+    assert(nd2.shape == Shape(578, 1026, 3), "image shape not Match!")
+  }
+
+  test("Test convert to Image") {
+    val nd = Image.imRead(imLocation)
+    val resizeIm = Image.imResize(nd, 224, 224)
+    val tempDirPath = System.getProperty("java.io.tmpdir")
+    val outPath = tempDirPath + "/inputImages/out.png"
+    val img = Image.toImage(resizeIm)
+    assert(img.getWidth == 224 && img.getHeight == 224, "converted image size not Match!")
+    // ImageIO.write returns false when no writer for the format is available.
+    assert(ImageIO.write(img, "png", new File(outPath)), "no PNG writer available")
+    logger.info(s"converted image stored in $outPath")
+  }
+
+}