This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new fe46cd9  [MXNET-1231] Allow not using Some in the Scala operators (#13619)
fe46cd9 is described below

commit fe46cd96016a69e6d0a67af86926a32edbe77159
Author: Lanking <[email protected]>
AuthorDate: Thu Jan 3 12:39:33 2019 -0500

    [MXNET-1231] Allow not using Some in the Scala operators (#13619)
    
    * add initial commit
    
    * update image classifier as well
    
    * create a Util object providing the implicit Some conversion
    
    * add test changes
    
    * address comments
    
    * fix the spacing problem
    
    * fix generator base
    
    * rename the conversion object to OptionConversion
---
 .../org/apache/mxnet/util/OptionConversion.scala   | 22 +++++++++++++++
 .../test/scala/org/apache/mxnet/NDArraySuite.scala | 13 +++++++++
 .../scala/org/apache/mxnet/infer/Classifier.scala  |  1 -
 .../scala/org/apache/mxnet/GeneratorBase.scala     |  4 +--
 .../org/apache/mxnet/utils/CToScalaUtils.scala     | 32 +++++++++++++++-------
 .../test/scala/org/apache/mxnet/MacrosSuite.scala  |  5 ++--
 6 files changed, 62 insertions(+), 15 deletions(-)

diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
 
b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
new file mode 100644
index 0000000..2cf453a
--- /dev/null
+++ 
b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util
+
+/**
+ * Opt-in implicit conversion for the generated operator APIs (e.g. `NDArray.api`).
+ * With `import org.apache.mxnet.util.OptionConversion._` in scope, a plain value
+ * can be passed where an `Option[A]` argument is expected, instead of `Some(value)`.
+ *
+ * The value is wrapped with `Option(...)`, so a `null` argument becomes `None`
+ * rather than `Some(null)`. Because the conversion applies to every type `A`,
+ * import it only in the narrowest scope that needs it.
+ */
+object OptionConversion {
+  // Enable the implicit-conversion language feature for the definition below;
+  // without it scalac reports an "implicit conversion ... should be enabled" feature warning.
+  import scala.language.implicitConversions
+
+  /** Wraps `noSome` in an `Option`; a `null` value yields `None`. */
+  implicit def someWrapper[A](noSome : A) : Option[A] = Option(noSome)
+}
diff --git 
a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala 
b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 7992a0e..2db9ff1 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -593,4 +593,17 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll 
with Matchers {
     assert(rnd.shape === Shape(1, 2, 3, 4))
     assert(rnd2.shape === Shape(3, 4))
   }
+
+  test("Generated api") {
+    // Without OptionConversion: optional operator arguments must be wrapped in Some(...)
+    val arr3 = NDArray.ones(Shape(1, 2), dtype = DType.Float64)
+    val arr4 = NDArray.ones(Shape(1), dtype = DType.Float64)
+    val arr5 = NDArray.api.norm(arr3, ord = Some(1), out = Some(arr4))
+    assert(arr5.shape === Shape(1))
+    // With OptionConversion imported: plain values are implicitly wrapped in Option(...)
+    import org.apache.mxnet.util.OptionConversion._
+    val arr = NDArray.ones(Shape(1, 2), dtype = DType.Float64)
+    val arr2 = NDArray.ones(Shape(1), dtype = DType.Float64)
+    NDArray.api.norm(arr, ord = 1, out = arr2)
+    val result = NDArray.api.dot(arr2, arr2)
+    // dot of two 1-D arrays is an inner product, returned as a shape-(1) array
+    assert(result.shape === Shape(1))
+  }
 }
diff --git 
a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala 
b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
index cf55bc1..5208923 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
@@ -126,7 +126,6 @@ class Classifier(modelPathPrefix: String,
     })
 
     val predictResult = predictResultPar.toArray
-
     var result: ListBuffer[IndexedSeq[(String, Float)]] =
       ListBuffer.empty[IndexedSeq[(String, Float)]]
 
diff --git 
a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala 
b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
index 1c2c4fd..498c4e9 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -96,7 +96,7 @@ private[mxnet] abstract class GeneratorBase {
       else if (isSymbol) "org.apache.mxnet.Symbol"
       else "org.apache.mxnet.NDArray"
       val typeAndOption =
-        CToScalaUtils.argumentCleaner(argName, argType, family)
+        CToScalaUtils.argumentCleaner(argName, argType, family, isJava)
       Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
     }
     val returnType =
@@ -191,7 +191,7 @@ private[mxnet] trait RandomHelpers {
   // unify call targets (random_xyz and sample_xyz) and unify their argument 
types
   private def unifyRandom(func: Func, isSymbol: Boolean): Func = {
     var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol",
-      "java.lang.Float", "java.lang.Integer")
+      "Float", "Int")
 
     func.copy(
       name = func.name.replaceAll("(random|sample)_", ""),
diff --git 
a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
 
b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index 2fd8b2e..57c4cfb 100644
--- 
a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ 
b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -18,23 +18,35 @@ package org.apache.mxnet.utils
 
 private[mxnet] object CToScalaUtils {
 
-
+  private val javaType = Map(
+    "float" -> "java.lang.Float",
+    "int" -> "java.lang.Integer",
+    "long" -> "java.lang.Long",
+    "double" -> "java.lang.Double",
+    "bool" -> "java.lang.Boolean")
+  private val scalaType = Map(
+    "float" -> "Float",
+    "int" -> "Int",
+    "long" -> "Long",
+    "double" -> "Double",
+    "bool" -> "Boolean")
 
   // Convert C++ Types to Scala Types
   def typeConversion(in : String, argType : String = "", argName : String,
-                     returnType : String) : String = {
+                     returnType : String, isJava : Boolean) : String = {
     val header = returnType.split("\\.").dropRight(1)
+    val types = if (isJava) javaType else scalaType
     in match {
       case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape"
       case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
       case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | 
"SymbolorSymbol[]"
       => s"Array[$returnType]"
-      case "float" | "real_t" | "floatorNone" => "java.lang.Float"
-      case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer"
-      case "long" | "long(non-negative)" => "java.lang.Long"
-      case "double" | "doubleorNone" => "java.lang.Double"
+      case "float" | "real_t" | "floatorNone" => types("float")
+      case "int" | "intorNone" | "int(non-negative)" => types("int")
+      case "long" | "long(non-negative)" => types("long")
+      case "double" | "doubleorNone" => types("double")
       case "string" => "String"
-      case "boolean" | "booleanorNone" => "java.lang.Boolean"
+      case "boolean" | "booleanorNone" => types("bool")
       case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" => 
"Any"
       case default => throw new IllegalArgumentException(
         s"Invalid type for args: $default\nString argType: $argType\nargName: 
$argName")
@@ -54,7 +66,7 @@ private[mxnet] object CToScalaUtils {
     * @return (Scala_Type, isOptional)
     */
   def argumentCleaner(argName: String, argType : String,
-                      returnType : String) : (String, Boolean) = {
+                      returnType : String, isJava : Boolean) : (String, 
Boolean) = {
     val spaceRemoved = argType.replaceAll("\\s+", "")
     var commaRemoved : Array[String] = new Array[String](0)
     // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
@@ -72,9 +84,9 @@ private[mxnet] object CToScalaUtils {
         s"""expected "optional" got ${commaRemoved(1)}""")
       require(commaRemoved(2).startsWith("default="),
         s"""expected "default=..." got ${commaRemoved(2)}""")
-      (typeConversion(commaRemoved(0), argType, argName, returnType), true)
+      (typeConversion(commaRemoved(0), argType, argName, returnType, isJava), 
true)
     } else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
-      val tempType = typeConversion(commaRemoved(0), argType, argName, 
returnType)
+      val tempType = typeConversion(commaRemoved(0), argType, argName, 
returnType, isJava)
       val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
       (tempType, tempOptional)
     } else {
diff --git 
a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala 
b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
index 4404b08..4069bba 100644
--- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
+++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
@@ -36,14 +36,15 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
     )
     val output = List(
       ("org.apache.mxnet.Symbol", true),
-      ("java.lang.Integer", false),
+      ("Int", false),
       ("org.apache.mxnet.Shape", true),
       ("String", true),
       ("Any", false)
     )
 
     for (idx <- input.indices) {
-      val result = CToScalaUtils.argumentCleaner("Sample", input(idx), 
"org.apache.mxnet.Symbol")
+      val result = CToScalaUtils.argumentCleaner("Sample", input(idx),
+        "org.apache.mxnet.Symbol", false)
       assert(result._1 === output(idx)._1 && result._2 === output(idx)._2)
     }
   }

Reply via email to