This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new fe46cd9 [MXNET-1231] Allow not using Some in the Scala operators
(#13619)
fe46cd9 is described below
commit fe46cd96016a69e6d0a67af86926a32edbe77159
Author: Lanking <[email protected]>
AuthorDate: Thu Jan 3 12:39:33 2019 -0500
[MXNET-1231] Allow not using Some in the Scala operators (#13619)
* add initial commit
* update image classifier as well
* create Util class to make Some conversion
* add test changes
* address comments
* fix the spacing problem
* fix generator base
* change name to Option
---
.../org/apache/mxnet/util/OptionConversion.scala | 22 +++++++++++++++
.../test/scala/org/apache/mxnet/NDArraySuite.scala | 13 +++++++++
.../scala/org/apache/mxnet/infer/Classifier.scala | 1 -
.../scala/org/apache/mxnet/GeneratorBase.scala | 4 +--
.../org/apache/mxnet/utils/CToScalaUtils.scala | 32 +++++++++++++++-------
.../test/scala/org/apache/mxnet/MacrosSuite.scala | 5 ++--
6 files changed, 62 insertions(+), 15 deletions(-)
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
new file mode 100644
index 0000000..2cf453a
--- /dev/null
+++
b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util
+
+object OptionConversion {
+ implicit def someWrapper[A](noSome : A) : Option[A] = Option(noSome)
+}
diff --git
a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 7992a0e..2db9ff1 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -593,4 +593,17 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll
with Matchers {
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
}
+
+ test("Generated api") {
+ // Without SomeConversion
+ val arr3 = NDArray.ones(Shape(1, 2), dtype = DType.Float64)
+ val arr4 = NDArray.ones(Shape(1), dtype = DType.Float64)
+ val arr5 = NDArray.api.norm(arr3, ord = Some(1), out = Some(arr4))
+ // With SomeConversion
+ import org.apache.mxnet.util.OptionConversion._
+ val arr = NDArray.ones(Shape(1, 2), dtype = DType.Float64)
+ val arr2 = NDArray.ones(Shape(1), dtype = DType.Float64)
+ NDArray.api.norm(arr, ord = 1, out = arr2)
+ val result = NDArray.api.dot(arr2, arr2)
+ }
}
diff --git
a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
index cf55bc1..5208923 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
@@ -126,7 +126,6 @@ class Classifier(modelPathPrefix: String,
})
val predictResult = predictResultPar.toArray
-
var result: ListBuffer[IndexedSeq[(String, Float)]] =
ListBuffer.empty[IndexedSeq[(String, Float)]]
diff --git
a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
index 1c2c4fd..498c4e9 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -96,7 +96,7 @@ private[mxnet] abstract class GeneratorBase {
else if (isSymbol) "org.apache.mxnet.Symbol"
else "org.apache.mxnet.NDArray"
val typeAndOption =
- CToScalaUtils.argumentCleaner(argName, argType, family)
+ CToScalaUtils.argumentCleaner(argName, argType, family, isJava)
Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
}
val returnType =
@@ -191,7 +191,7 @@ private[mxnet] trait RandomHelpers {
// unify call targets (random_xyz and sample_xyz) and unify their argument
types
private def unifyRandom(func: Func, isSymbol: Boolean): Func = {
var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol",
- "java.lang.Float", "java.lang.Integer")
+ "Float", "Int")
func.copy(
name = func.name.replaceAll("(random|sample)_", ""),
diff --git
a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index 2fd8b2e..57c4cfb 100644
---
a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++
b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -18,23 +18,35 @@ package org.apache.mxnet.utils
private[mxnet] object CToScalaUtils {
-
+ private val javaType = Map(
+ "float" -> "java.lang.Float",
+ "int" -> "java.lang.Integer",
+ "long" -> "java.lang.Long",
+ "double" -> "java.lang.Double",
+ "bool" -> "java.lang.Boolean")
+ private val scalaType = Map(
+ "float" -> "Float",
+ "int" -> "Int",
+ "long" -> "Long",
+ "double" -> "Double",
+ "bool" -> "Boolean")
// Convert C++ Types to Scala Types
def typeConversion(in : String, argType : String = "", argName : String,
- returnType : String) : String = {
+ returnType : String, isJava : Boolean) : String = {
val header = returnType.split("\\.").dropRight(1)
+ val types = if (isJava) javaType else scalaType
in match {
case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" |
"SymbolorSymbol[]"
=> s"Array[$returnType]"
- case "float" | "real_t" | "floatorNone" => "java.lang.Float"
- case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer"
- case "long" | "long(non-negative)" => "java.lang.Long"
- case "double" | "doubleorNone" => "java.lang.Double"
+ case "float" | "real_t" | "floatorNone" => types("float")
+ case "int" | "intorNone" | "int(non-negative)" => types("int")
+ case "long" | "long(non-negative)" => types("long")
+ case "double" | "doubleorNone" => types("double")
case "string" => "String"
- case "boolean" | "booleanorNone" => "java.lang.Boolean"
+ case "boolean" | "booleanorNone" => types("bool")
case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" =>
"Any"
case default => throw new IllegalArgumentException(
s"Invalid type for args: $default\nString argType: $argType\nargName:
$argName")
@@ -54,7 +66,7 @@ private[mxnet] object CToScalaUtils {
* @return (Scala_Type, isOptional)
*/
def argumentCleaner(argName: String, argType : String,
- returnType : String) : (String, Boolean) = {
+ returnType : String, isJava : Boolean) : (String,
Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
@@ -72,9 +84,9 @@ private[mxnet] object CToScalaUtils {
s"""expected "optional" got ${commaRemoved(1)}""")
require(commaRemoved(2).startsWith("default="),
s"""expected "default=..." got ${commaRemoved(2)}""")
- (typeConversion(commaRemoved(0), argType, argName, returnType), true)
+ (typeConversion(commaRemoved(0), argType, argName, returnType, isJava),
true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
- val tempType = typeConversion(commaRemoved(0), argType, argName,
returnType)
+ val tempType = typeConversion(commaRemoved(0), argType, argName,
returnType, isJava)
val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
(tempType, tempOptional)
} else {
diff --git
a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
index 4404b08..4069bba 100644
--- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
+++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
@@ -36,14 +36,15 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
)
val output = List(
("org.apache.mxnet.Symbol", true),
- ("java.lang.Integer", false),
+ ("Int", false),
("org.apache.mxnet.Shape", true),
("String", true),
("Any", false)
)
for (idx <- input.indices) {
- val result = CToScalaUtils.argumentCleaner("Sample", input(idx),
"org.apache.mxnet.Symbol")
+ val result = CToScalaUtils.argumentCleaner("Sample", input(idx),
+ "org.apache.mxnet.Symbol", false)
assert(result._1 === output(idx)._1 && result._2 === output(idx)._2)
}
}