This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new fc5fce7 [MXNET-918] Random module (#13039)
fc5fce7 is described below
commit fc5fce769e8a6beea807e2a30b4930d055371971
Author: mathieu <[email protected]>
AuthorDate: Fri Dec 14 23:22:06 2018 +0100
[MXNET-918] Random module (#13039)
* introduce random API
* revert useless changes
* shorter types in APIDoc gen code
* fix after merge from master
* Trigger CI
* temp code / diag on CI
* cleanup type-class code
* cleanup type-class code
* fix scalastyle
---
.../src/main/scala/org/apache/mxnet/Base.scala | 18 +++
.../src/main/scala/org/apache/mxnet/NDArray.scala | 1 +
.../main/scala/org/apache/mxnet/NDArrayAPI.scala | 13 +-
.../src/main/scala/org/apache/mxnet/Symbol.scala | 1 +
.../main/scala/org/apache/mxnet/SymbolAPI.scala | 12 +-
.../test/scala/org/apache/mxnet/NDArraySuite.scala | 17 ++
.../test/scala/org/apache/mxnet/SymbolSuite.scala | 22 +++
.../scala/org/apache/mxnet/APIDocGenerator.scala | 43 +++++-
.../scala/org/apache/mxnet/GeneratorBase.scala | 75 ++++++++-
.../main/scala/org/apache/mxnet/NDArrayMacro.scala | 171 ++++++++++++++++-----
.../main/scala/org/apache/mxnet/SymbolMacro.scala | 147 +++++++++++++-----
11 files changed, 435 insertions(+), 85 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
index b2a53fd..bb9518d 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
@@ -153,3 +153,21 @@ private[mxnet] object Base {
}
class MXNetError(val err: String) extends Exception(err)
+
+// Some type-classes to ease the work in Symbol.random and NDArray.random
modules
+
+class SymbolOrScalar[T](val isScalar: Boolean)
+object SymbolOrScalar {
+ def apply[T](implicit ev: SymbolOrScalar[T]): SymbolOrScalar[T] = ev
+ implicit object FloatWitness extends SymbolOrScalar[Float](true)
+ implicit object IntWitness extends SymbolOrScalar[Int](true)
+ implicit object SymbolWitness extends SymbolOrScalar[Symbol](false)
+}
+
+class NDArrayOrScalar[T](val isScalar: Boolean)
+object NDArrayOrScalar {
+ def apply[T](implicit ev: NDArrayOrScalar[T]): NDArrayOrScalar[T] = ev
+ implicit object FloatWitness extends NDArrayOrScalar[Float](true)
+ implicit object IntWitness extends NDArrayOrScalar[Int](true)
+ implicit object NDArrayWitness extends NDArrayOrScalar[NDArray](false)
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 3a0c3c1..1259581 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -40,6 +40,7 @@ object NDArray extends NDArrayBase {
private val functions: Map[String, NDArrayFunction] = initNDArrayModule()
val api = NDArrayAPI
+ val random = NDArrayRandomAPI
private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit
= {
froms.foreach { from =>
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
index 1d8551c..024fed1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
@@ -15,11 +15,22 @@
* limitations under the License.
*/
package org.apache.mxnet
-@AddNDArrayAPIs(false)
+
/**
* typesafe NDArray API: NDArray.api._
* Main code will be generated during compile time through Macros
*/
+@AddNDArrayAPIs(false)
object NDArrayAPI extends NDArrayAPIBase {
// TODO: Implement CustomOp for NDArray
}
+
+/**
+ * typesafe NDArray random module: NDArray.random._
+ * Main code will be generated during compile time through Macros
+ */
+@AddNDArrayRandomAPIs(false)
+object NDArrayRandomAPI extends NDArrayRandomAPIBase {
+
+}
+
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 01349a6..29885fc 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -842,6 +842,7 @@ object Symbol extends SymbolBase {
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)
val api = SymbolAPI
+ val random = SymbolRandomAPI
def pow(sym1: Symbol, sym2: Symbol): Symbol = {
Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
index 1bfb055..f166de1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
@@ -19,11 +19,11 @@ package org.apache.mxnet
import scala.collection.mutable
-@AddSymbolAPIs(false)
/**
* typesafe Symbol API: Symbol.api._
* Main code will be generated during compile time through Macros
*/
+@AddSymbolAPIs(false)
object SymbolAPI extends SymbolAPIBase {
def Custom (op_type : String, kwargs : mutable.Map[String, Any],
name : String = null, attr : Map[String, String] = null) : Symbol
= {
@@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase {
Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap)
}
}
+
+/**
+ * typesafe Symbol random module: Symbol.random._
+ * Main code will be generated during compile time through Macros
+ */
+@AddSymbolRandomAPIs(false)
+object SymbolRandomAPI extends SymbolRandomAPIBase {
+
+}
+
diff --git
a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 5d88bb3..7992a0e 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -576,4 +576,21 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll
with Matchers {
assert(arr.internal.toDoubleArray === Array(2d, 2d))
assert(arr.internal.toByteArray === Array(2.toByte, 2.toByte))
}
+
+ test("NDArray random module is generated properly") {
+ val lam = NDArray.ones(1, 2)
+ val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3,
4)))
+ val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3,
4)))
+ assert(rnd.shape === Shape(1, 2, 3, 4))
+ assert(rnd2.shape === Shape(3, 4))
+ }
+
+ test("NDArray random module is generated properly - special case of
'normal'") {
+ val mu = NDArray.ones(1, 2)
+ val sigma = NDArray.ones(1, 2) * 2
+ val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape
= Some(Shape(3, 4)))
+ val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape =
Some(Shape(3, 4)))
+ assert(rnd.shape === Shape(1, 2, 3, 4))
+ assert(rnd2.shape === Shape(3, 4))
+ }
}
diff --git
a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
index ebb61d7..d134c83 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
@@ -20,6 +20,7 @@ package org.apache.mxnet
import org.scalatest.{BeforeAndAfterAll, FunSuite}
class SymbolSuite extends FunSuite with BeforeAndAfterAll {
+
test("symbol compose") {
val data = Symbol.Variable("data")
@@ -71,4 +72,25 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
val data2 = data.clone()
assert(data.toJson === data2.toJson)
}
+
+ test("Symbol random module is generated properly") {
+ val lam = Symbol.Variable("lam")
+ val rnd = Symbol.random.poisson(lam = Some(lam), shape = Some(Shape(2, 2)))
+ val rnd2 = Symbol.random.poisson(lam = Some(1f), shape = Some(Shape(2, 2)))
+ // scalastyle:off println
+ println(s"Symbol.random.poisson debug info: ${rnd.debugStr}")
+ println(s"Symbol.random.poisson debug info: ${rnd2.debugStr}")
+ // scalastyle:on println
+ }
+
+ test("Symbol random module is generated properly - special case of
'normal'") {
+ val loc = Symbol.Variable("loc")
+ val scale = Symbol.Variable("scale")
+ val rnd = Symbol.random.normal(mu = Some(loc), sigma = Some(scale), shape
= Some(Shape(2, 2)))
+ val rnd2 = Symbol.random.normal(mu = Some(1f), sigma = Some(2f), shape =
Some(Shape(2, 2)))
+ // scalastyle:off println
+ println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}")
+ println(s"Symbol.random.random_normal debug info: ${rnd2.debugStr}")
+ // scalastyle:on println
+ }
}
diff --git
a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index ce12dc7..97cd18a 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -27,13 +27,15 @@ import scala.collection.mutable.ListBuffer
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* The code will be executed during Macros stage and file live in Core stage
*/
-private[mxnet] object APIDocGenerator extends GeneratorBase {
+private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers
{
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
hashCollector += typeSafeClassGen(FILE_PATH, true)
hashCollector += typeSafeClassGen(FILE_PATH, false)
+ hashCollector += typeSafeRandomClassGen(FILE_PATH, true)
+ hashCollector += typeSafeRandomClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
hashCollector += javaClassGen(FILE_PATH)
@@ -57,8 +59,27 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
writeFile(
FILE_PATH,
+ "package org.apache.mxnet",
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
+ "import org.apache.mxnet.annotation.Experimental",
+ generated)
+ }
+
+ def typeSafeRandomClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+ val generated = typeSafeRandomFunctionsToGenerate(isSymbol)
+ .map { func =>
+ val scalaDoc = generateAPIDocFromBackend(func)
+ val typeParameter = randomGenericTypeSpec(isSymbol, false)
+ val decl = generateAPISignature(func, isSymbol, typeParameter)
+ s"$scalaDoc\n$decl"
+ }
+
+ writeFile(
+ FILE_PATH,
"package org.apache.mxnet",
+ if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase",
+ """import org.apache.mxnet.annotation.Experimental
+ |import scala.reflect.ClassTag""".stripMargin,
generated)
}
@@ -85,8 +106,9 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
writeFile(
FILE_PATH,
- if (isSymbol) "SymbolBase" else "NDArrayBase",
"package org.apache.mxnet",
+ if (isSymbol) "SymbolBase" else "NDArrayBase",
+ "import org.apache.mxnet.annotation.Experimental",
absFuncs)
}
@@ -110,7 +132,12 @@ private[mxnet] object APIDocGenerator extends
GeneratorBase {
}).toSeq
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
- writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
+ writeFile(
+ filePath + "javaapi/",
+ packageDef,
+ packageName,
+ "import org.apache.mxnet.annotation.Experimental",
+ absFuncs)
}
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String
= {
@@ -146,7 +173,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase
{
}
}
- def generateAPISignature(func: Func, isSymbol: Boolean): String = {
+ def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter:
String = ""): String = {
val argDef = ListBuffer[String]()
argDef ++= typedFunctionCommonArgDef(func)
@@ -162,7 +189,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase
{
val returnType = func.returnType
s"""@Experimental
- |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
+ |def ${func.name}$typeParameter (${argDef.mkString(", ")}):
$returnType""".stripMargin
}
def generateJavaAPISignature(func : Func) : String = {
@@ -223,8 +250,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase
{
}
}
- def writeFile(FILE_PATH: String, className: String, packageDef: String,
- absFuncs: Seq[String]): String = {
+ def writeFile(FILE_PATH: String, packageDef: String, className: String,
+ imports: String, absFuncs: Seq[String]): String = {
val finalStr =
s"""/*
@@ -246,7 +273,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase
{
|
|$packageDef
|
- |import org.apache.mxnet.annotation.Experimental
+ |$imports
|
|// scalastyle:off
|abstract class $className {
diff --git
a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
index 9245ef1..1c2c4fd 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -23,7 +23,7 @@ import org.apache.mxnet.utils.{CToScalaUtils,
OperatorBuildUtils}
import scala.collection.mutable.ListBuffer
import scala.reflect.macros.blackbox
-abstract class GeneratorBase {
+private[mxnet] abstract class GeneratorBase {
type Handle = Long
case class Arg(argName: String, argType: String, argDesc: String,
isOptional: Boolean) {
@@ -46,7 +46,8 @@ abstract class GeneratorBase {
}
}
- def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean):
List[Func] = {
+ // filter the operators to generate in the type-safe Symbol.api and
NDArray.api
+ protected def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib:
Boolean): List[Func] = {
// Operators that should not be generated
val notGenerated = Set("Custom")
@@ -144,8 +145,8 @@ abstract class GeneratorBase {
result
}
+ // build function argument definition, with optionality, and safe names
protected def typedFunctionCommonArgDef(func: Func): List[String] = {
- // build function argument definition, with optionality, and safe names
func.listOfArgs.map(arg =>
if (arg.isOptional) {
// let's avoid a stupid Option[Array[...]]
@@ -161,3 +162,71 @@ abstract class GeneratorBase {
)
}
}
+
+// a mixin to ease generating the Random module
+private[mxnet] trait RandomHelpers {
+ self: GeneratorBase =>
+
+ // a generic type spec used in Symbol.random and NDArray.random modules
+ protected def randomGenericTypeSpec(isSymbol: Boolean, fullPackageSpec:
Boolean): String = {
+ val classTag = if (fullPackageSpec) "scala.reflect.ClassTag" else
"ClassTag"
+ if (isSymbol) s"[T: SymbolOrScalar : $classTag]"
+ else s"[T: NDArrayOrScalar : $classTag]"
+ }
+
+ // filter the operators to generate in the type-safe Symbol.random and
NDArray.random
+ protected def typeSafeRandomFunctionsToGenerate(isSymbol: Boolean):
List[Func] = {
+ getBackEndFunctions(isSymbol)
+ .filter(f => f.name.startsWith("_sample_") ||
f.name.startsWith("_random_"))
+ .map(f => f.copy(name = f.name.stripPrefix("_")))
+ // unify _random and _sample
+ .map(f => unifyRandom(f, isSymbol))
+ // deduplicate
+ .groupBy(_.name)
+ .mapValues(_.head)
+ .values
+ .toList
+ }
+
+ // unify call targets (random_xyz and sample_xyz) and unify their argument
types
+ private def unifyRandom(func: Func, isSymbol: Boolean): Func = {
+ var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol",
+ "java.lang.Float", "java.lang.Integer")
+
+ func.copy(
+ name = func.name.replaceAll("(random|sample)_", ""),
+ listOfArgs = func.listOfArgs
+ .map(hackNormalFunc)
+ .map(arg =>
+ if (typeConv(arg.argType)) arg.copy(argType = "T")
+ else arg
+ )
+ // TODO: some functions are non consistent in random_ vs sample_
regarding optionality
+ // we may try to unify that as well here.
+ )
+ }
+
+ // hacks to manage the fact that random_normal and sample_normal have
+ // non-consistent parameter naming in the back-end
+ // this first one, merge loc/scale and mu/sigma
+ protected def hackNormalFunc(arg: Arg): Arg = {
+ if (arg.argName == "loc") arg.copy(argName = "mu")
+ else if (arg.argName == "scale") arg.copy(argName = "sigma")
+ else arg
+ }
+
+ // this second one reverts this merge prior to back-end call
+ protected def unhackNormalFunc(func: Func): String = {
+ if (func.name.equals("normal")) {
+ s"""if(target.equals("random_normal")) {
+ | if(map.contains("mu")) { map("loc") = map("mu"); map.remove("mu")
}
+ | if(map.contains("sigma")) { map("scale") = map("sigma");
map.remove("sigma") }
+ |}
+ """.stripMargin
+ } else {
+ ""
+ }
+
+ }
+
+}
diff --git
a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index d85abe1..c18694b 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -18,7 +18,6 @@
package org.apache.mxnet
import scala.annotation.StaticAnnotation
-import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
@@ -30,6 +29,14 @@ private[mxnet] class AddNDArrayAPIs(isContrib: Boolean)
extends StaticAnnotation
private[mxnet] def macroTransform(annottees: Any*) = macro
TypedNDArrayAPIMacro.typeSafeAPIDefs
}
+private[mxnet] class AddNDArrayRandomAPIs(isContrib: Boolean) extends
StaticAnnotation {
+ private[mxnet] def macroTransform(annottees: Any*) =
+ macro TypedNDArrayRandomAPIMacro.typeSafeAPIDefs
+}
+
+/**
+ * For non-typed NDArray API
+ */
private[mxnet] object NDArrayMacro extends GeneratorBase {
def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
@@ -70,6 +77,9 @@ private[mxnet] object NDArrayMacro extends GeneratorBase {
}
}
+/**
+ * NDArray.api code generation
+ */
private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*):
c.Expr[Any] = {
@@ -78,9 +88,9 @@ private[mxnet] object TypedNDArrayAPIMacro extends
GeneratorBase {
case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
}
- val functions = typeSafeFunctionsToGenerate(isSymbol = false, isContrib)
+ val functionDefs = typeSafeFunctionsToGenerate(isSymbol = false, isContrib)
+ .map(f => buildTypedFunction(c)(f))
- val functionDefs = functions.map(f => buildTypedFunction(c)(f))
structGeneration(c)(functionDefs, annottees: _*)
}
@@ -89,49 +99,136 @@ private[mxnet] object TypedNDArrayAPIMacro extends
GeneratorBase {
import c.universe._
val returnType = "org.apache.mxnet.NDArrayFuncReturn"
- val ndarrayType = "org.apache.mxnet.NDArray"
-
- // Construct argument field
- val argDef = ListBuffer[String]()
- argDef ++= typedFunctionCommonArgDef(function)
- argDef += "out : Option[NDArray] = None"
-
- // Construct Implementation field
- var impl = ListBuffer[String]()
- impl += "val map = scala.collection.mutable.Map[String, Any]()"
- impl += s"val args =
scala.collection.mutable.ArrayBuffer.empty[$ndarrayType]"
-
- // NDArray arg implementation
- impl ++= function.listOfArgs.map { arg =>
- if (arg.argType.equals(s"Array[$ndarrayType]")) {
- s"args ++= ${arg.safeArgName}"
- } else {
- val base =
- if (arg.argType.equals(ndarrayType)) {
- // ndarrays go to args
+
+ // Construct API arguments declaration
+ val argDecl = super.typedFunctionCommonArgDef(function) :+ "out :
Option[NDArray] = None"
+
+ // Map API input args to backend args
+ val backendArgsMapping =
+ function.listOfArgs.map { arg =>
+ // ndarrays go to args, other types go to kwargs
+ if (arg.argType.equals(s"Array[org.apache.mxnet.NDArray]")) {
+ s"args ++= ${arg.safeArgName}.toSeq"
+ } else {
+ val base = if (arg.argType.equals("org.apache.mxnet.NDArray")) {
s"args += ${arg.safeArgName}"
} else {
- // other types go to kwargs
s"""map("${arg.argName}") = ${arg.safeArgName}"""
}
- if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get"
- else base
+ if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get"
+ else base
+ }
}
- }
- impl +=
- s"""if (!out.isEmpty) map("out") = out.get
- |org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
- | "${function.name}", args.toSeq, map.toMap)
+ val impl =
+ s"""
+ |def ${function.name}
+ | (${argDecl.mkString(",")}): $returnType = {
+ |
+ | val map = scala.collection.mutable.Map[String, Any]()
+ | val args =
scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]
+ |
+ | if (!out.isEmpty) map("out") = out.get
+ |
+ | ${backendArgsMapping.mkString("\n")}
+ |
+ | org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
+ | "${function.name}", args.toSeq, map.toMap)
+ |}
""".stripMargin
- // Combine and build the function string
- val finalStr =
- s"""def ${function.name}
- | (${argDef.mkString(",")}) : $returnType
- | = {${impl.mkString("\n")}}
+ c.parse(impl).asInstanceOf[DefDef]
+ }
+}
+
+
+/**
+ * NDArray.random code generation
+ */
+private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase
+ with RandomHelpers {
+
+ def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*):
c.Expr[Any] = {
+ // Note: no contrib managed in this module
+
+ val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = false)
+ .map(f => buildTypedFunction(c)(f))
+
+ structGeneration(c)(functionDefs, annottees: _*)
+ }
+
+ protected def buildTypedFunction(c: blackbox.Context)
+ (function: Func): c.universe.DefDef = {
+ import c.universe._
+
+ val returnType = "org.apache.mxnet.NDArrayFuncReturn"
+
+ // Construct API arguments declaration
+ val argDecl = super.typedFunctionCommonArgDef(function) :+ "out :
Option[NDArray] = None"
+
+ // Map API input args to backend args
+ val backendArgsMapping =
+ function.listOfArgs.map { arg =>
+ // ndarrays go to args, other types go to kwargs
+ if (arg.argType.equals("Array[org.apache.mxnet.NDArray]")) {
+ s"args ++= ${arg.safeArgName}.toSeq"
+ } else {
+ if (arg.argType.equals("T")) {
+ if (arg.isOptional) {
+ s"""if(${arg.safeArgName}.isDefined) {
+ | if(isScalar) {
+ | map("${arg.argName}") = ${arg.safeArgName}.get
+ | } else {
+ | args +=
${arg.safeArgName}.get.asInstanceOf[org.apache.mxnet.NDArray]
+ | }
+ |}
+ """.stripMargin
+ } else {
+ s"""if(isScalar) {
+ | map("${arg.argName}") = ${arg.safeArgName}
+ |} else {
+ | args +=
${arg.safeArgName}.asInstanceOf[org.apache.mxnet.NDArray]
+ |}
+ """.stripMargin
+ }
+ } else {
+ if (arg.isOptional) {
+ s"""if (${arg.safeArgName}.isDefined)
map("${arg.argName}")=${arg.safeArgName}.get"""
+ } else {
+ s"""map("${arg.argName}") = ${arg.safeArgName}"""
+ }
+ }
+ }
+ }
+
+ val impl =
+ s"""
+ |def ${function.name}${randomGenericTypeSpec(false, true)}
+ | (${argDecl.mkString(",")}): $returnType = {
+ |
+ | val map = scala.collection.mutable.Map[String, Any]()
+ | val args =
scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]
+ | val isScalar = NDArrayOrScalar[T].isScalar
+ |
+ | if(out.isDefined) map("out") = out.get
+ |
+ | ${backendArgsMapping.mkString("\n")}
+ |
+ | val target = if(isScalar) {
+ | "random_${function.name}"
+ | } else {
+ | "sample_${function.name}"
+ | }
+ |
+ | ${unhackNormalFunc(function)}
+ |
+ | org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
+ | target, args.toSeq, map.toMap)
+ |}
""".stripMargin
- c.parse(finalStr).asInstanceOf[DefDef]
+ c.parse(impl).asInstanceOf[DefDef]
}
+
+
}
diff --git
a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
index ab864e1..7ec80b9 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
@@ -17,8 +17,8 @@
package org.apache.mxnet
+
import scala.annotation.StaticAnnotation
-import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
@@ -30,6 +30,14 @@ private[mxnet] class AddSymbolAPIs(isContrib: Boolean)
extends StaticAnnotation
private[mxnet] def macroTransform(annottees: Any*) = macro
TypedSymbolAPIMacro.typeSafeAPIDefs
}
+private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends
StaticAnnotation {
+ private[mxnet] def macroTransform(annottees: Any*) =
+ macro TypedSymbolRandomAPIMacro.typeSafeAPIDefs
+}
+
+/**
+ * For non-typed Symbol API
+ */
private[mxnet] object SymbolMacro extends GeneratorBase {
def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
@@ -63,6 +71,9 @@ private[mxnet] object SymbolMacro extends GeneratorBase {
}
}
+/**
+ * Symbol.api code generation
+ */
private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase {
def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*):
c.Expr[Any] = {
@@ -71,9 +82,9 @@ private[mxnet] object TypedSymbolAPIMacro extends
GeneratorBase {
case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b))
}
- val functions = typeSafeFunctionsToGenerate(isSymbol = true, isContrib)
+ val functionDefs = typeSafeFunctionsToGenerate(isSymbol = true, isContrib)
+ .map(f => buildTypedFunction(c)(f))
- val functionDefs = functions.map(f => buildTypedFunction(c)(f))
structGeneration(c)(functionDefs, annottees: _*)
}
@@ -82,45 +93,111 @@ private[mxnet] object TypedSymbolAPIMacro extends
GeneratorBase {
import c.universe._
val returnType = "org.apache.mxnet.Symbol"
- val symbolType = "org.apache.mxnet.Symbol"
-
- // Construct argument field
- val argDef = ListBuffer[String]()
- argDef ++= typedFunctionCommonArgDef(function)
- argDef += "name : String = null"
- argDef += "attr : Map[String, String] = null"
-
- // Construct Implementation field
- val impl = ListBuffer[String]()
- impl += "val map = scala.collection.mutable.Map[String, Any]()"
- impl += s"var args = scala.collection.Seq[$symbolType]()"
-
- // Symbol arg implementation
- impl ++= function.listOfArgs.map { arg =>
- if (arg.argType.equals(s"Array[$symbolType]")) {
- s"if (!${arg.safeArgName}.isEmpty) args = ${arg.safeArgName}.toSeq"
- } else {
- // all go in kwargs
- if (arg.isOptional) {
- s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") =
${arg.safeArgName}.get"""
+
+ // Construct API arguments declaration
+ val argDecl = super.typedFunctionCommonArgDef(function) :+
+ "name : String = null" :+
+ "attr : Map[String, String] = null"
+
+ // Map API input args to backend args
+ val backendArgsMapping =
+ function.listOfArgs.map { arg =>
+ if (arg.argType.equals(s"Array[org.apache.mxnet.Symbol]")) {
+ s"args = ${arg.safeArgName}.toSeq"
} else {
- s"""map("${arg.argName}") = ${arg.safeArgName}"""
+ // all go in kwargs
+ if (arg.isOptional) {
+ s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") =
${arg.safeArgName}.get"""
+ } else {
+ s"""map("${arg.argName}") = ${arg.safeArgName}"""
+ }
}
}
- }
- impl +=
- s"""org.apache.mxnet.Symbol.createSymbolGeneral(
- | "${function.name}", name, attr, args, map.toMap)
+ val impl =
+ s"""
+ |def ${function.name}
+ | (${argDecl.mkString(",")}): $returnType = {
+ |
+ | val map = scala.collection.mutable.Map[String, Any]()
+ | var args = scala.collection.Seq[org.apache.mxnet.Symbol]()
+ |
+ | ${backendArgsMapping.mkString("\n")}
+ |
+ | org.apache.mxnet.Symbol.createSymbolGeneral(
+ | "${function.name}", name, attr, args, map.toMap)
+ |}
""".stripMargin
- // Combine and build the function string
- val finalStr =
- s"""def ${function.name}
- | (${argDef.mkString(",")}) : $returnType
- | = {${impl.mkString("\n")}}
+ c.parse(impl).asInstanceOf[DefDef]
+ }
+}
+
+
+/**
+ * Symbol.random code generation
+ */
+private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase
+ with RandomHelpers {
+
+ def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*):
c.Expr[Any] = {
+ val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = true)
+ .map(f => buildTypedFunction(c)(f))
+
+ structGeneration(c)(functionDefs, annottees: _*)
+ }
+
+ protected def buildTypedFunction(c: blackbox.Context)
+ (function: Func): c.universe.DefDef = {
+ import c.universe._
+
+ val returnType = "org.apache.mxnet.Symbol"
+
+ // Construct API arguments declaration
+ val argDecl = super.typedFunctionCommonArgDef(function) :+
+ "name : String = null" :+
+ "attr : Map[String, String] = null"
+
+ // Map API input args to backend args
+ val backendArgsMapping =
+ function.listOfArgs.map { arg =>
+ if (arg.argType.equals(s"Array[org.apache.mxnet.Symbol]")) {
+ s"args = ${arg.safeArgName}.toSeq"
+ } else {
+ // all go in kwargs
+ if (arg.isOptional) {
+ s"""if (${arg.safeArgName}.isDefined) map("${arg.argName}") =
${arg.safeArgName}.get"""
+ } else {
+ s"""map("${arg.argName}") = ${arg.safeArgName}"""
+ }
+ }
+ }
+
+ val impl =
+ s"""
+ |def ${function.name}${randomGenericTypeSpec(true, true)}
+ | (${argDecl.mkString(",")}): $returnType = {
+ |
+ | val map = scala.collection.mutable.Map[String, Any]()
+ | var args = scala.collection.Seq[org.apache.mxnet.Symbol]()
+ | val isScalar = SymbolOrScalar[T].isScalar
+ |
+ | ${backendArgsMapping.mkString("\n")}
+ |
+ | val target = if(isScalar) {
+ | "random_${function.name}"
+ | } else {
+ | "sample_${function.name}"
+ | }
+ |
+ | ${unhackNormalFunc(function)}
+ |
+ | org.apache.mxnet.Symbol.createSymbolGeneral(
+ | target, name, attr, args, map.toMap)
+ |}
""".stripMargin
- c.parse(finalStr).asInstanceOf[DefDef]
+ c.parse(impl).asInstanceOf[DefDef]
}
}
+