lanking520 closed pull request #13619: [MXNET-1231] Allow not using Some in the Scala operators URL: https://github.com/apache/incubator-mxnet/pull/13619
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala new file mode 100644 index 00000000000..2cf453ac3d1 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mxnet.util + +object OptionConversion { + implicit def someWrapper[A](noSome : A) : Option[A] = Option(noSome) +} diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala index 7992a0ed867..2db9ff11b37 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala @@ -593,4 +593,17 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers { assert(rnd.shape === Shape(1, 2, 3, 4)) assert(rnd2.shape === Shape(3, 4)) } + + test("Generated api") { + // Without SomeConversion + val arr3 = NDArray.ones(Shape(1, 2), dtype = DType.Float64) + val arr4 = NDArray.ones(Shape(1), dtype = DType.Float64) + val arr5 = NDArray.api.norm(arr3, ord = Some(1), out = Some(arr4)) + // With SomeConversion + import org.apache.mxnet.util.OptionConversion._ + val arr = NDArray.ones(Shape(1, 2), dtype = DType.Float64) + val arr2 = NDArray.ones(Shape(1), dtype = DType.Float64) + NDArray.api.norm(arr, ord = 1, out = arr2) + val result = NDArray.api.dot(arr2, arr2) + } } diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala index cf55bc10d97..5208923275f 100644 --- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala +++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala @@ -126,7 +126,6 @@ class Classifier(modelPathPrefix: String, }) val predictResult = predictResultPar.toArray - var result: ListBuffer[IndexedSeq[(String, Float)]] = ListBuffer.empty[IndexedSeq[(String, Float)]] diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala index 1c2c4fd704b..498c4e94366 100644 --- 
a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala @@ -96,7 +96,7 @@ private[mxnet] abstract class GeneratorBase { else if (isSymbol) "org.apache.mxnet.Symbol" else "org.apache.mxnet.NDArray" val typeAndOption = - CToScalaUtils.argumentCleaner(argName, argType, family) + CToScalaUtils.argumentCleaner(argName, argType, family, isJava) Arg(argName, typeAndOption._1, argDesc, typeAndOption._2) } val returnType = @@ -191,7 +191,7 @@ private[mxnet] trait RandomHelpers { // unify call targets (random_xyz and sample_xyz) and unify their argument types private def unifyRandom(func: Func, isSymbol: Boolean): Func = { var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol", - "java.lang.Float", "java.lang.Integer") + "Float", "Int") func.copy( name = func.name.replaceAll("(random|sample)_", ""), diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala index 2fd8b2e73c7..57c4cfba10b 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala @@ -18,23 +18,35 @@ package org.apache.mxnet.utils private[mxnet] object CToScalaUtils { - + private val javaType = Map( + "float" -> "java.lang.Float", + "int" -> "java.lang.Integer", + "long" -> "java.lang.Long", + "double" -> "java.lang.Double", + "bool" -> "java.lang.Boolean") + private val scalaType = Map( + "float" -> "Float", + "int" -> "Int", + "long" -> "Long", + "double" -> "Double", + "bool" -> "Boolean") // Convert C++ Types to Scala Types def typeConversion(in : String, argType : String = "", argName : String, - returnType : String) : String = { + returnType : String, isJava : Boolean) : String = { val header = returnType.split("\\.").dropRight(1) + val types = if (isJava) 
javaType else scalaType in match { case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape" case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]" => s"Array[$returnType]" - case "float" | "real_t" | "floatorNone" => "java.lang.Float" - case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer" - case "long" | "long(non-negative)" => "java.lang.Long" - case "double" | "doubleorNone" => "java.lang.Double" + case "float" | "real_t" | "floatorNone" => types("float") + case "int" | "intorNone" | "int(non-negative)" => types("int") + case "long" | "long(non-negative)" => types("long") + case "double" | "doubleorNone" => types("double") case "string" => "String" - case "boolean" | "booleanorNone" => "java.lang.Boolean" + case "boolean" | "booleanorNone" => types("bool") case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" => "Any" case default => throw new IllegalArgumentException( s"Invalid type for args: $default\nString argType: $argType\nargName: $argName") @@ -54,7 +66,7 @@ private[mxnet] object CToScalaUtils { * @return (Scala_Type, isOptional) */ def argumentCleaner(argName: String, argType : String, - returnType : String) : (String, Boolean) = { + returnType : String, isJava : Boolean) : (String, Boolean) = { val spaceRemoved = argType.replaceAll("\\s+", "") var commaRemoved : Array[String] = new Array[String](0) // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'} @@ -72,9 +84,9 @@ private[mxnet] object CToScalaUtils { s"""expected "optional" got ${commaRemoved(1)}""") require(commaRemoved(2).startsWith("default="), s"""expected "default=..." 
got ${commaRemoved(2)}""") - (typeConversion(commaRemoved(0), argType, argName, returnType), true) + (typeConversion(commaRemoved(0), argType, argName, returnType, isJava), true) } else if (commaRemoved.length == 2 || commaRemoved.length == 1) { - val tempType = typeConversion(commaRemoved(0), argType, argName, returnType) + val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, isJava) val tempOptional = tempType.equals("org.apache.mxnet.Symbol") (tempType, tempOptional) } else { diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala index 4404b0885d5..4069bba2522 100644 --- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala +++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala @@ -36,14 +36,15 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll { ) val output = List( ("org.apache.mxnet.Symbol", true), - ("java.lang.Integer", false), + ("Int", false), ("org.apache.mxnet.Shape", true), ("String", true), ("Any", false) ) for (idx <- input.indices) { - val result = CToScalaUtils.argumentCleaner("Sample", input(idx), "org.apache.mxnet.Symbol") + val result = CToScalaUtils.argumentCleaner("Sample", input(idx), + "org.apache.mxnet.Symbol", false) assert(result._1 === output(idx)._1 && result._2 === output(idx)._2) } } ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org With regards, Apache Git Services
