This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch v1.4.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.4.x by this push:
new 69515c2 [v1.4.1] Java bug-fix cherry pick (#14834)
69515c2 is described below
commit 69515c2f9b1ac6fd4b661d5411a97de968cf4e2e
Author: Lanking <[email protected]>
AuthorDate: Mon Apr 29 16:03:36 2019 -0700
[v1.4.1] Java bug-fix cherry pick (#14834)
* clean up submodule (#14645)
* Scala/Java Predict API fix #14756 (#14804)
* add fix in the code
* add unit test
* update comments
* add fixes to code gen
---
.../scala/org/apache/mxnet/module/BaseModule.scala | 17 +-
.../java/org/apache/mxnet/javaapi/NDArrayTest.java | 4 +-
.../test/scala/org/apache/mxnet/ModuleSuite.scala | 28 ++++
.../scala/org/apache/mxnet/APIDocGenerator.scala | 184 ++++++++++++++-------
4 files changed, 173 insertions(+), 60 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index b73f4ad..73ccef2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -247,11 +247,23 @@ abstract class BaseModule {
/**
* Run prediction and collect the outputs.
- * @param evalData
+ * @param evalData DataIter to run inference on
* @param numBatch Default is -1, indicating running all the batches in the data iterator.
* @param reset Default is `True`, indicating whether we should reset the data iter before start
* doing prediction.
* @return The return value will be a list `[out1, out2, out3]`.
+ * The per-batch outputs are regrouped by output index and then concatenated:
+ * {{{
+ * outputBatches = [
+ * [a1, a2, a3], // batch a
+ * [b1, b2, b3] // batch b
+ * ]
+ * result = [
+ * NDArray, // [a1, b1]
+ * NDArray, // [a2, b2]
+ * NDArray, // [a3, b3]
+ * ]
+ * }}}
* Where each element is concatenation of the outputs for all the mini-batches.
*/
def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
@@ -264,7 +276,8 @@ abstract class BaseModule {
s"in mini-batches (${out.size})." +
"Maybe bucketing is used?")
)
- val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
+ val oBT = outputBatches.transpose
+ val concatenatedOutput = oBT.map(out => NDArray.concatenate(out))
outputBatches.foreach(_.foreach(_.dispose()))
concatenatedOutput
}
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
index 2659b78..5bbe8bb 100644
--- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -71,7 +71,7 @@ public class NDArrayTest {
NDArray$ NDArray = NDArray$.MODULE$;
float[] arr = new float[]{1.0f, 2.0f, 3.0f};
NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new
Context("cpu", 0));
- float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0];
+ float result = NDArray.norm(new normParam(nd))[0].toArray()[0];
float cal = 0.0f;
for (float ele : arr) {
cal += ele * ele;
@@ -79,7 +79,7 @@ public class NDArrayTest {
cal = (float) Math.sqrt(cal);
assertTrue(Math.abs(result - cal) < 1e-5);
NDArray dotResult = new NDArray(new float[]{0}, new Shape(new
int[]{1}), new Context("cpu", 0));
- NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult));
+ NDArray.dot(new dotParam(nd, nd).setOut(dotResult));
assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
}
}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 88e314e..e6ebfd3 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -23,6 +23,34 @@ import org.apache.mxnet.optimizer._
import org.apache.mxnet.io._
class ModuleSuite extends FunSuite with BeforeAndAfterAll {
+
+ class myModule(symbol : Symbol) extends Module (symbol) {
+ override def predictEveryBatch(evalData: DataIter,
+ numBatch: Int = 1, reset: Boolean = true):
+ IndexedSeq[IndexedSeq[NDArray]] = {
+ val data = IndexedSeq(
+ NDArray.ones(Shape(1, 10, 1)),
+ NDArray.ones(Shape(1, 10, 1)),
+ NDArray.ones(Shape(1, 10, 4))
+ )
+ List.fill(numBatch)(data).toIndexedSeq
+ }
+ }
+
+ test("predict") {
+ val sym = Symbol.Variable("data")
+ val mod = new myModule(sym)
+ val dummyIter = new NDArrayIter(IndexedSeq(NDArray.ones(1)))
+ var output = mod.predict(dummyIter, 1)
+ require(output(0).shape == Shape(1, 10, 1))
+ require(output(1).shape == Shape(1, 10, 1))
+ require(output(2).shape == Shape(1, 10, 4))
+ output = mod.predict(dummyIter, 2)
+ require(output(0).shape == Shape(2, 10, 1))
+ require(output(1).shape == Shape(2, 10, 1))
+ require(output(2).shape == Shape(2, 10, 4))
+ }
+
test ("model dtype") {
val dType = DType.Float32
val dShape = Shape(3, 8, 7)
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index ce12dc7..77a2704 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -23,12 +23,16 @@ import java.security.MessageDigest
import scala.collection.mutable.ListBuffer
/**
- * This object will generate the Scala documentation of the new Scala API
- * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
+ * This object will generate the Scala documentation of the Scala/Java APIs
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator extends GeneratorBase {
+ /**
+ * Main method used to generate code and write to files
+ * A hash check is placed at the end to verify changes
+ * @param args Input args
+ */
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
@@ -40,6 +44,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
val finalHash = hashCollector.mkString("\n")
}
+ /**
+ * Generate MD5 result from an input string
+ * Encoded in UTF-8
+ * @param input The input string
+ * @return A MD5 value from the string
+ */
def MD5Generator(input: String): String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
@@ -47,6 +57,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}
+ /**
+ * Type-safe class body generation for NDArray/Symbol
+ * @param FILE_PATH Directory path the generated file is written into
+ * @param isSymbol If true, write the Symbol API; otherwise write the NDArray API
+ * @return MD5 String
+ */
def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
.map { func =>
@@ -57,11 +73,22 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
writeFile(
FILE_PATH,
- if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
"package org.apache.mxnet",
+ if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
+ "import org.apache.mxnet.annotation.Experimental",
generated)
}
+ /**
+ * Non-type-safe interface of Scala Symbol/NDArray
+ * It includes the class definition, e.g. class SymbolBase,
+ * and function definitions, e.g. def softmax(...)(...)(...) : NDArray
+ * Users can call the API directly via NDArray.<function_name>
+ * It supports both positional and Map inputs
+ * @param FILE_PATH Directory path the generated file is written into
+ * @param isSymbol If true, write the Symbol API; otherwise write the NDArray API
+ * @return MD5 String
+ */
def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
.map { func =>
@@ -85,34 +112,53 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
writeFile(
FILE_PATH,
- if (isSymbol) "SymbolBase" else "NDArrayBase",
"package org.apache.mxnet",
+ if (isSymbol) "SymbolBase" else "NDArrayBase",
+ "import org.apache.mxnet.annotation.Experimental",
absFuncs)
}
- def javaClassGen(filePath : String) : String = {
+ /**
+ * Type-safe interface of Java NDArray
+ * @param FILE_PATH Base directory path; the file is written under its javaapi/ subdirectory
+ * @return MD5 String
+ */
+ def javaClassGen(FILE_PATH : String) : String = {
val notGenerated = Set("Custom")
val absClassFunctions = functionsToGenerate(false, false, true)
- val absFuncs = absClassFunctions.filterNot(ele =>
notGenerated.contains(ele.name))
- .groupBy(_.name.toLowerCase).map(ele => {
- /* Pattern matching for not generating deprecated method
- * Group all method name in lowercase
- * Kill the capital lettered method such as Cast vs cast
- * As it defined by default it deprecated
- */
- if (ele._2.length == 1) ele._2.head
- else {
- if (ele._2.head.name.head.isLower) ele._2.head
- else ele._2.last
- }
- }).map(absClassFunction => {
+ val (absFuncs, paramClassUncleaned) =
+ absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
+ .groupBy(_.name.toLowerCase).map(ele => {
+ /* Pattern matching for not generating deprecated method
+ * Group all method name in lowercase
+ * Kill the capital lettered method such as Cast vs cast
+ * As it defined by default it deprecated
+ */
+ if (ele._2.length == 1) ele._2.head
+ else {
+ if (ele._2.head.name.head.isLower) ele._2.head
+ else ele._2.last
+ }
+ }).map(absClassFunction => {
generateJavaAPISignature(absClassFunction)
- }).toSeq
+ }).toSeq.unzip
+ val paramClass = paramClassUncleaned.filterNot(_.isEmpty)
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
- writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
+ writeFile(
+ FILE_PATH + "javaapi/",
+ packageDef,
+ packageName,
+ "import org.apache.mxnet.annotation.Experimental",
+ absFuncs, Some(paramClass))
}
+ /**
+ * Generate Scala docs from the function description
+ * @param func The function case class
+ * @param withParam Whether to generate param field
+ * @return A formatted string for the function description
+ */
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String
= {
def fixDesc(desc: String): String = {
var curDesc = desc
@@ -146,7 +192,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}
}
- def generateAPISignature(func: Func, isSymbol: Boolean): String = {
+ /**
+ * Generate the function interface
+ * e.g: def softmax(data: NDArray, name ...): NDArrayFunctionReturn
+ * @param func The function case class
+ * @param isSymbol If true, generate a Symbol function; otherwise an NDArray function
+ * @param typeParameter Type param specifically used in Random Module
+ * @return Formatted string for the function
+ */
+ def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter:
String = ""): String = {
val argDef = ListBuffer[String]()
argDef ++= typedFunctionCommonArgDef(func)
@@ -162,10 +216,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
val returnType = func.returnType
s"""@Experimental
- |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
+ |def ${func.name}$typeParameter (${argDef.mkString(", ")}):
$returnType""".stripMargin
}
- def generateJavaAPISignature(func : Func) : String = {
+ /**
+ * Generate Java function interface
+ * @param func The function case class
+ * @return A formatted string for the function
+ */
+ def generateJavaAPISignature(func : Func) : (String, String) = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
@@ -204,54 +263,67 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
| }
| def getOut() = this.out
| """.stripMargin
- s"""$scalaDocNoParam
- | $experimentalTag
- | def ${func.name}(po: ${func.name}Param) : $returnType
- | /**
- | * This Param Object is specifically used for ${func.name}
- | ${requiredParam.mkString("\n")}
- | */
- | class ${func.name}Param(${argDef.mkString(",")}) {
- | ${classDef.mkString("\n ")}
- | }""".stripMargin
+ (s"""$scalaDocNoParam
+ | $experimentalTag
+ | def ${func.name}(po: ${func.name}Param) : $returnType
+ | """.stripMargin,
+ s"""/**
+ | * This Param Object is specifically used for ${func.name}
+ | ${requiredParam.mkString("\n")}
+ | */
+ | class ${func.name}Param(${argDef.mkString(",")}) {
+ | ${classDef.mkString("\n ")}
+ | }""".stripMargin)
} else {
argDef += "out : NDArray"
- s"""$scalaDoc
- |$experimentalTag
- | def ${func.name}(${argDef.mkString(", ")}) : $returnType
- | """.stripMargin
+ (s"""$scalaDoc
+ |$experimentalTag
+ | def ${func.name}(${argDef.mkString(", ")}) : $returnType
+ | """.stripMargin, "")
}
}
- def writeFile(FILE_PATH: String, className: String, packageDef: String,
- absFuncs: Seq[String]): String = {
+ /**
+ * Write the formatted string to file
+ * @param FILE_PATH Directory path the file is written into
+ * @param packageDef Package definition
+ * @param className Class name, also used as the generated file name
+ * @param imports Import statements placed after the package definition
+ * @param absFuncs All formatted functions
+ * @return A MD5 string
+ */
+ def writeFile(FILE_PATH: String, packageDef: String, className: String,
+ imports: String, absFuncs: Seq[String],
+ paramClass: Option[Seq[String]] = None): String = {
val finalStr =
s"""/*
- |* Licensed to the Apache Software Foundation (ASF) under one or more
- |* contributor license agreements. See the NOTICE file distributed
with
- |* this work for additional information regarding copyright ownership.
- |* The ASF licenses this file to You under the Apache License,
Version 2.0
- |* (the "License"); you may not use this file except in compliance
with
- |* the License. You may obtain a copy of the License at
- |*
- |* http://www.apache.org/licenses/LICENSE-2.0
- |*
- |* Unless required by applicable law or agreed to in writing, software
- |* distributed under the License is distributed on an "AS IS" BASIS,
- |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied.
- |* See the License for the specific language governing permissions and
- |* limitations under the License.
- |*/
+ | * Licensed to the Apache Software Foundation (ASF) under one or more
+ | * contributor license agreements. See the NOTICE file distributed
with
+ | * this work for additional information regarding copyright
ownership.
+ | * The ASF licenses this file to You under the Apache License,
Version 2.0
+ | * (the "License"); you may not use this file except in compliance
with
+ | * the License. You may obtain a copy of the License at
+ | *
+ | * http://www.apache.org/licenses/LICENSE-2.0
+ | *
+ | * Unless required by applicable law or agreed to in writing,
software
+ | * distributed under the License is distributed on an "AS IS" BASIS,
+ | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied.
+ | * See the License for the specific language governing permissions
and
+ | * limitations under the License.
+ | */
|
|$packageDef
|
- |import org.apache.mxnet.annotation.Experimental
+ |$imports
|
|// scalastyle:off
|abstract class $className {
|${absFuncs.mkString("\n")}
- |}""".stripMargin
+ |}
+ |${paramClass.getOrElse(Seq()).mkString("\n")}
+ |""".stripMargin
val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))