lanking520 closed pull request #13039: [MXNET-918] Random module
URL: https://github.com/apache/incubator-mxnet/pull/13039
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a pull request from a fork once it
has been merged, the diff is reproduced below for the sake of provenance:
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
index b2a53fd9f2d..bb9518d51f1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
@@ -153,3 +153,21 @@ private[mxnet] object Base {
}
class MXNetError(val err: String) extends Exception(err)
+
+// Some type-classes to ease the work in Symbol.random and NDArray.random
modules
+
+class SymbolOrScalar[T](val isScalar: Boolean)
+object SymbolOrScalar {
+ def apply[T](implicit ev: SymbolOrScalar[T]): SymbolOrScalar[T] = ev
+ implicit object FloatWitness extends SymbolOrScalar[Float](true)
+ implicit object IntWitness extends SymbolOrScalar[Int](true)
+ implicit object SymbolWitness extends SymbolOrScalar[Symbol](false)
+}
+
+class NDArrayOrScalar[T](val isScalar: Boolean)
+object NDArrayOrScalar {
+ def apply[T](implicit ev: NDArrayOrScalar[T]): NDArrayOrScalar[T] = ev
+ implicit object FloatWitness extends NDArrayOrScalar[Float](true)
+ implicit object IntWitness extends NDArrayOrScalar[Int](true)
+ implicit object NDArrayWitness extends NDArrayOrScalar[NDArray](false)
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 3a0c3c11f16..125958150b7 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -40,6 +40,7 @@ object NDArray extends NDArrayBase {
private val functions: Map[String, NDArrayFunction] = initNDArrayModule()
val api = NDArrayAPI
+ val random = NDArrayRandomAPI
private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit
= {
froms.foreach { from =>
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
index 1d8551c1b1e..024fed1c4ba 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
@@ -15,11 +15,22 @@
* limitations under the License.
*/
package org.apache.mxnet
-@AddNDArrayAPIs(false)
+
/**
* typesafe NDArray API: NDArray.api._
* Main code will be generated during compile time through Macros
*/
+@AddNDArrayAPIs(false)
object NDArrayAPI extends NDArrayAPIBase {
// TODO: Implement CustomOp for NDArray
}
+
+/**
+ * typesafe NDArray random module: NDArray.random._
+ * Main code will be generated during compile time through Macros
+ */
+@AddNDArrayRandomAPIs(false)
+object NDArrayRandomAPI extends NDArrayRandomAPIBase {
+
+}
+
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 01349a689b6..29885fc723c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -842,6 +842,7 @@ object Symbol extends SymbolBase {
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)
val api = SymbolAPI
+ val random = SymbolRandomAPI
def pow(sym1: Symbol, sym2: Symbol): Symbol = {
Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
index 1bfb0559cf9..f166de11ea5 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
@@ -19,11 +19,11 @@ package org.apache.mxnet
import scala.collection.mutable
-@AddSymbolAPIs(false)
/**
* typesafe Symbol API: Symbol.api._
* Main code will be generated during compile time through Macros
*/
+@AddSymbolAPIs(false)
object SymbolAPI extends SymbolAPIBase {
def Custom (op_type : String, kwargs : mutable.Map[String, Any],
name : String = null, attr : Map[String, String] = null) : Symbol
= {
@@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase {
Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap)
}
}
+
+/**
+ * typesafe Symbol random module: Symbol.random._
+ * Main code will be generated during compile time through Macros
+ */
+@AddSymbolRandomAPIs(false)
+object SymbolRandomAPI extends SymbolRandomAPIBase {
+
+}
+
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 5d88bb39e50..7992a0ed867 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -576,4 +576,21 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(arr.internal.toDoubleArray === Array(2d, 2d))
assert(arr.internal.toByteArray === Array(2.toByte, 2.toByte))
}
+
+ test("NDArray random module is generated properly") {
+ val lam = NDArray.ones(1, 2)
+ val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3,
4)))
+ val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3,
4)))
+ assert(rnd.shape === Shape(1, 2, 3, 4))
+ assert(rnd2.shape === Shape(3, 4))
+ }
+
+ test("NDArray random module is generated properly - special case of
'normal'") {
+ val mu = NDArray.ones(1, 2)
+ val sigma = NDArray.ones(1, 2) * 2
+ val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape
= Some(Shape(3, 4)))
+ val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape =
Some(Shape(3, 4)))
+ assert(rnd.shape === Shape(1, 2, 3, 4))
+ assert(rnd2.shape === Shape(3, 4))
+ }
}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
index ebb61d7d4bf..d134c83ff7e 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
@@ -20,6 +20,7 @@ package org.apache.mxnet
import org.scalatest.{BeforeAndAfterAll, FunSuite}
class SymbolSuite extends FunSuite with BeforeAndAfterAll {
+
test("symbol compose") {
val data = Symbol.Variable("data")
@@ -71,4 +72,25 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
val data2 = data.clone()
assert(data.toJson === data2.toJson)
}
+
+ test("Symbol random module is generated properly") {
+ val lam = Symbol.Variable("lam")
+ val rnd = Symbol.random.poisson(lam = Some(lam), shape = Some(Shape(2, 2)))
+ val rnd2 = Symbol.random.poisson(lam = Some(1f), shape = Some(Shape(2, 2)))
+ // scalastyle:off println
+ println(s"Symbol.random.poisson debug info: ${rnd.debugStr}")
+ println(s"Symbol.random.poisson debug info: ${rnd2.debugStr}")
+ // scalastyle:on println
+ }
+
+ test("Symbol random module is generated properly - special case of
'normal'") {
+ val loc = Symbol.Variable("loc")
+ val scale = Symbol.Variable("scale")
+ val rnd = Symbol.random.normal(mu = Some(loc), sigma = Some(scale), shape
= Some(Shape(2, 2)))
+ val rnd2 = Symbol.random.normal(mu = Some(1f), sigma = Some(2f), shape =
Some(Shape(2, 2)))
+ // scalastyle:off println
+ println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}")
+ println(s"Symbol.random.random_normal debug info: ${rnd2.debugStr}")
+ // scalastyle:on println
+ }
}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index ce12dc7cd5a..97cd18a5b33 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -27,13 +27,15 @@ import scala.collection.mutable.ListBuffer
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* The code will be executed during Macros stage and file live in Core stage
*/
-private[mxnet] object APIDocGenerator extends GeneratorBase {
+private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers
{
def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
hashCollector += typeSafeClassGen(FILE_PATH, true)
hashCollector += typeSafeClassGen(FILE_PATH, false)
+ hashCollector += typeSafeRandomClassGen(FILE_PATH, true)
+ hashCollector += typeSafeRandomClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
hashCollector += javaClassGen(FILE_PATH)
@@ -57,8 +59,27 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
writeFile(
FILE_PATH,
+ "package org.apache.mxnet",
if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
+ "import org.apache.mxnet.annotation.Experimental",
+ generated)
+ }
+
+ def typeSafeRandomClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+ val generated = typeSafeRandomFunctionsToGenerate(isSymbol)
+ .map { func =>
+ val scalaDoc = generateAPIDocFromBackend(func)
+ val typeParameter = randomGenericTypeSpec(isSymbol, false)
+ val decl = generateAPISignature(func, isSymbol, typeParameter)
+ s"$scalaDoc\n$decl"
+ }
+
+ writeFile(
+ FILE_PATH,
"package org.apache.mxnet",
+ if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase",
+ """import org.apache.mxnet.annotation.Experimental
+ |import scala.reflect.ClassTag""".stripMargin,
generated)
}
@@ -85,8 +106,9 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
writeFile(
FILE_PATH,
- if (isSymbol) "SymbolBase" else "NDArrayBase",
"package org.apache.mxnet",
+ if (isSymbol) "SymbolBase" else "NDArrayBase",
+ "import org.apache.mxnet.annotation.Experimental",
absFuncs)
}
@@ -110,7 +132,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}).toSeq
val packageName = "NDArrayBase"
val packageDef = "package org.apache.mxnet.javaapi"
- writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
+ writeFile(
+ filePath + "javaapi/",
+ packageDef,
+ packageName,
+ "import org.apache.mxnet.annotation.Experimental",
+ absFuncs)
}
def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String
= {
@@ -146,7 +173,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}
}
- def generateAPISignature(func: Func, isSymbol: Boolean): String = {
+ def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter:
String = ""): String = {
val argDef = ListBuffer[String]()
argDef ++= typedFunctionCommonArgDef(func)
@@ -162,7 +189,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
val returnType = func.returnType
s"""@Experimental
- |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
+ |def ${func.name}$typeParameter (${argDef.mkString(", ")}):
$returnType""".stripMargin
}
def generateJavaAPISignature(func : Func) : String = {
@@ -223,8 +250,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}
}
- def writeFile(FILE_PATH: String, className: String, packageDef: String,
- absFuncs: Seq[String]): String = {
+ def writeFile(FILE_PATH: String, packageDef: String, className: String,
+ imports: String, absFuncs: Seq[String]): String = {
val finalStr =
s"""/*
@@ -246,7 +273,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
|
|$packageDef
|
- |import org.apache.mxnet.annotation.Experimental
+ |$imports
|
|// scalastyle:off
|abstract class $className {
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
index 9245ef1b437..1c2c4fd704b 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -23,7 +23,7 @@ import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
import scala.collection.mutable.ListBuffer
import scala.reflect.macros.blackbox
-abstract class GeneratorBase {
+private[mxnet] abstract class GeneratorBase {
type Handle = Long
case class Arg(argName: String, argType: String, argDesc: String,
isOptional: Boolean) {
@@ -46,7 +46,8 @@ abstract class GeneratorBase {
}
}
- def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean):
List[Func] = {
+ // filter the operators to generate in the type-safe Symbol.api and
NDArray.api
+ protected def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib:
Boolean): List[Func] = {
// Operators that should not be generated
val notGenerated = Set("Custom")
@@ -144,8 +145,8 @@ abstract class GeneratorBase {
result
}
+ // build function argument definition, with optionality, and safe names
protected def typedFunctionCommonArgDef(func: Func): List[String] = {
- // build function argument definition, with optionality, and safe names
func.listOfArgs.map(arg =>
if (arg.isOptional) {
// let's avoid a stupid Option[Array[...]]
@@ -161,3 +162,71 @@ abstract class GeneratorBase {
)
}
}
+
+// a mixin to ease generating the Random module
+private[mxnet] trait RandomHelpers {
+ self: GeneratorBase =>
+
+ // a generic type spec used in Symbol.random and NDArray.random modules
+ protected def randomGenericTypeSpec(isSymbol: Boolean, fullPackageSpec:
Boolean): String = {
+ val classTag = if (fullPackageSpec) "scala.reflect.ClassTag" else
"ClassTag"
+ if (isSymbol) s"[T: SymbolOrScalar : $classTag]"
+ else s"[T: NDArrayOrScalar : $classTag]"
+ }
+
+ // filter the operators to generate in the type-safe Symbol.random and
NDArray.random
+ protected def typeSafeRandomFunctionsToGenerate(isSymbol: Boolean):
List[Func] = {
+ getBackEndFunctions(isSymbol)
+ .filter(f => f.name.startsWith("_sample_") ||
f.name.startsWith("_random_"))
+ .map(f => f.copy(name = f.name.stripPrefix("_")))
+ // unify _random and _sample
+ .map(f => unifyRandom(f, isSymbol))
+ // deduplicate
+ .groupBy(_.name)
+ .mapValues(_.head)
+ .values
+ .toList
+ }
+
+ // unify call targets (random_xyz and sample_xyz) and unify their argument
types
+ private def unifyRandom(func: Func, isSymbol: Boolean): Func = {
+ var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol",
+ "java.lang.Float", "java.lang.Integer")
+
+ func.copy(
+ name = func.name.replaceAll("(random|sample)_", ""),
+ listOfArgs = func.listOfArgs
+ .map(hackNormalFunc)
+ .map(arg =>
+ if (typeConv(arg.argType)) arg.copy(argType = "T")
+ else arg
+ )
+ // TODO: some functions are non consistent in random_ vs sample_
regarding optionality
+ // we may try to unify that as well here.
+ )
+ }
+
+ // hacks to manage the fact that random_normal and sample_normal have
+ // non-consistent parameter naming in the back-end
+ // this first one, merge loc/scale and mu/sigma
+ protected def hackNormalFunc(arg: Arg): Arg = {
+ if (arg.argName == "loc") arg.copy(argName = "mu")
+ else if (arg.argName == "scale") arg.copy(argName = "sigma")
+ else arg
+ }
+
+ // this second one reverts this merge prior to back-end call
+ protected def unhackNormalFunc(func: Func): String = {
+ if (func.name.equals("normal")) {
+ s"""if(target.equals("random_normal")) {
+ | if(map.contains("mu")) { map("loc") = map("mu"); map.remove("mu")
}
+ | if(map.contains("sigma")) { map("scale") = map("sigma");
map.remove("sigma") }
+ |}
+ """.stripMargin
+ } else {
+ ""
+ }
+
+ }
+
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index d85abe1ecc4..c18694b59bf 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -18,7 +18,6 @@
package org.apache.mxnet
import scala.annotation.StaticAnnotation
-import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
@@ -30,6 +29,14 @@ private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation
private[mxnet] def macroTransform(annottees: Any*) = macro
TypedNDArrayAPIMacro.typeSafeAPIDefs
}
+private[mxnet] class AddNDArrayRandomAPIs(isContrib: Boolean) extends
StaticAnnotation {
+ private[mxnet] def macroTransform(annottees: Any*) =
+ macro TypedNDArrayRandomAPIMacro.typeSafeAPIDefs
+}
+
+/**
+ * For non-typed NDArray API
+ */
private[mxnet] object NDArrayMacro extends GeneratorBase {
def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
@@ -70,6 +77,9 @@ private[mxnet] object NDArrayMacro extends GeneratorBase {
}
}
+/**
+ * NDArray.api code generation
+ */
private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*):
c.Expr[Any] = {
@@ -78,9 +88,9 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
}
- val functions = typeSafeFunctionsToGenerate(isSymbol = false, isContrib)
+ val functionDefs = typeSafeFunctionsToGenerate(isSymbol = false, isContrib)
+ .map(f => buildTypedFunction(c)(f))
- val functionDefs = functions.map(f => buildTypedFunction(c)(f))
structGeneration(c)(functionDefs, annottees: _*)
}
@@ -89,49 +99,136 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
import c.universe._
val returnType = "org.apache.mxnet.NDArrayFuncReturn"
- val ndarrayType = "org.apache.mxnet.NDArray"
-
- // Construct argument field
- val argDef = ListBuffer[String]()
- argDef ++= typedFunctionCommonArgDef(function)
- argDef += "out : Option[NDArray] = None"
-
- // Construct Implementation field
- var impl = ListBuffer[String]()
- impl += "val map = scala.collection.mutable.Map[String, Any]()"
- impl += s"val args =
scala.collection.mutable.ArrayBuffer.empty[$ndarrayType]"
-
- // NDArray arg implementation
- impl ++= function.listOfArgs.map { arg =>
- if (arg.argType.equals(s"Array[$ndarrayType]")) {
- s"args ++= ${arg.safeArgName}"
- } else {
- val base =
- if (arg.argType.equals(ndarrayType)) {
- // ndarrays go to args
+
+ // Construct API arguments declaration
+ val argDecl = super.typedFunctionCommonArgDef(function) :+ "out :
Option[NDArray] = None"
+
+ // Map API input args to backend args
+ val backendArgsMapping =
+ function.listOfArgs.map { arg =>
+ // ndarrays go to args, other types go to kwargs
+ if (arg.argType.equals(s"Array[org.apache.mxnet.NDArray]")) {
+ s"args ++= ${arg.safeArgName}.toSeq"
+ } else {
+ val base = if (arg.argType.equals("org.apache.mxnet.NDArray")) {
s"args += ${arg.safeArgName}"
} else {
- // other types go to kwargs
s"""map("${arg.argName}") = ${arg.safeArgName}"""
}
- if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get"
- else base
+ if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get"
+ else base
+ }
}
- }
- impl +=
- s"""if (!out.isEmpty) map("out") = out.get
- |org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
- | "${function.name}", args.toSeq, map.toMap)
+ val impl =
+ s"""
+ |def ${function.name}
+ | (${argDecl.mkString(",")}): $returnType = {
+ |
+ | val map = scala.collection.mutable.Map[String, Any]()
+ | val args =
scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]
+ |
+ | if (!out.isEmpty) map("out") = out.get
+ |
+ | ${backendArgsMapping.mkString("\n")}
+ |
+ | org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
+ | "${function.name}", args.toSeq, map.toMap)
+ |}
""".stripMargin
- // Combine and build the function string
- val finalStr =
- s"""def ${function.name}
- | (${argDef.mkString(",")}) : $returnType
- | = {${impl.mkString("\n")}}
+ c.parse(impl).asInstanceOf[DefDef]
+ }
+}
+
+
+/**
+ * NDArray.random code generation
+ */
+private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase
+ with RandomHelpers {
+
+ def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*):
c.Expr[Any] = {
+ // Note: no contrib managed in this module
+
+ val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = false)
+ .map(f => buildTypedFunction(c)(f))
+
+ structGeneration(c)(functionDefs, annottees: _*)
+ }
+
+ protected def buildTypedFunction(c: blackbox.Context)
+ (function: Func): c.universe.DefDef = {
+ import c.universe._
+
+ val returnType = "org.apache.mxnet.NDArrayFuncReturn"
+
+ // Construct API arguments declaration
+ val argDecl = super.typedFunctionCommonArgDef(function) :+ "out :
Option[NDArray] = None"
+
+ // Map API input args to backend args
+ val backendArgsMapping =
+ function.listOfArgs.map { arg =>
+ // ndarrays go to args, other types go to kwargs
+ if (arg.argType.equals("Array[org.apache.mxnet.NDArray]")) {
+ s"args ++= ${arg.safeArgName}.toSeq"
+ } else {
+ if (arg.argType.equals("T")) {
+ if (arg.isOptional) {
+ s"""if(${arg.safeArgName}.isDefined) {
+ | if(isScalar) {
+ | map("${arg.argName}") = ${arg.safeArgName}.get
+ | } else {
+ | args +=
${arg.safeArgName}.get.asInstanceOf[org.apache.mxnet.NDArray]
+ | }
+ |}
+ """.stripMargin
+ } else {
+ s"""if(isScalar) {
+ | map("${arg.argName}") = ${arg.safeArgName}
+ |} else {
+ | args +=
${arg.safeArgName}.asInstanceOf[org.apache.mxnet.NDArray]
+ |}
+ """.stripMargin
+ }
+ } else {
+ if (arg.isOptional) {
+ s"""if (${arg.safeArgName}.isDefined)
map("${arg.argName}")=${arg.safeArgName}.get"""
+ } else {
+ s"""map("${arg.argName}") = ${arg.safeArgName}"""
+ }
+ }
+ }
+ }
+
+ val impl =
+ s"""
+ |def ${function.name}${randomGenericTypeSpec(false, true)}
+ | (${argDecl.mkString(",")}): $returnType = {
+ |
+ | val map = scala.collection.mutable.Map[String, Any]()
+ | val args =
scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]
+ | val isScalar = NDArrayOrScalar[T].isScalar
+ |
+ | if(out.isDefined) map("out") = out.get
+ |
+ | ${backendArgsMapping.mkString("\n")}
+ |
+ | val target = if(isScalar) {
+ | "random_${function.name}"
+ | } else {
+ | "sample_${function.name}"
+ | }
+ |
+ | ${unhackNormalFunc(function)}
+ |
+ | org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
+ | target, args.toSeq, map.toMap)
+ |}
""".stripMargin
- c.parse(finalStr).asInstanceOf[DefDef]
+ c.parse(impl).asInstanceOf[DefDef]
}
+
+
}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
index ab864e1ef19..7ec80b9c066 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
@@ -17,8 +17,8 @@
package org.apache.mxnet
+
import scala.annotation.StaticAnnotation
-import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
@@ -30,6 +30,14 @@ private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation
private[mxnet] def macroTransform(annottees: Any*) = macro
TypedSymbolAPIMacro.typeSafeAPIDefs
}
+private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends
StaticAnnotation {
+ private[mxnet] def macroTransform(annottees: Any*) =
+ macro TypedSymbolRandomAPIMacro.typeSafeAPIDefs
+}
+
+/**
+ * For non-typed Symbol API
+ */
private[mxnet] object SymbolMacro extends GeneratorBase {
def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
@@ -63,6 +71,9 @@ private[mxnet] object SymbolMacro extends GeneratorBase {
}
}
+/**
+ * Symbol.api code generation
+ */
private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase {
def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*):
c.Expr[Any] = {
@@ -71,9 +82,9 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase {
case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b))
}
- val functions = typeSafeFunctionsToGenerate(isSymbol = true, isContrib)
+ val functionDefs = typeSafeFunctionsToGenerate(isSymbol = true, isContrib)
+ .map(f => buildTypedFunction(c)(f))
- val functionDefs = functions.map(f => buildTypedFunction(c)(f))
structGeneration(c)(functionDefs, annottees: _*)
}
@@ -82,45 +93,111 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase {
import c.universe._
val returnType = "org.apache.mxnet.Symbol"
- val symbolType = "org.apache.mxnet.Symbol"
-
- // Construct argument field
- val argDef = ListBuffer[String]()
- argDef ++= typedFunctionCommonArgDef(function)
- argDef += "name : String = null"
- argDef += "attr : Map[String, String] = null"
-
- // Construct Implementation field
- val impl = ListBuffer[String]()
- impl += "val map = scala.collection.mutable.Map[String, Any]()"
- impl += s"var args = scala.collection.Seq[$symbolType]()"
-
- // Symbol arg implementation
- impl ++= function.listOfArgs.map { arg =>
- if (arg.argType.equals(s"Array[$symbolType]")) {
- s"if (!${arg.safeArgName}.isEmpty) args = ${arg.safeArgName}.toSeq"
- } else {
- // all go in kwargs
- if (arg.isOptional) {
- s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") =
${arg.safeArgName}.get"""
+
+ // Construct API arguments declaration
+ val argDecl = super.typedFunctionCommonArgDef(function) :+
+ "name : String = null" :+
+ "attr : Map[String, String] = null"
+
+ // Map API input args to backend args
+ val backendArgsMapping =
+ function.listOfArgs.map { arg =>
+ if (arg.argType.equals(s"Array[org.apache.mxnet.Symbol]")) {
+ s"args = ${arg.safeArgName}.toSeq"
} else {
- s"""map("${arg.argName}") = ${arg.safeArgName}"""
+ // all go in kwargs
+ if (arg.isOptional) {
+ s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") =
${arg.safeArgName}.get"""
+ } else {
+ s"""map("${arg.argName}") = ${arg.safeArgName}"""
+ }
}
}
- }
- impl +=
- s"""org.apache.mxnet.Symbol.createSymbolGeneral(
- | "${function.name}", name, attr, args, map.toMap)
+ val impl =
+ s"""
+ |def ${function.name}
+ | (${argDecl.mkString(",")}): $returnType = {
+ |
+ | val map = scala.collection.mutable.Map[String, Any]()
+ | var args = scala.collection.Seq[org.apache.mxnet.Symbol]()
+ |
+ | ${backendArgsMapping.mkString("\n")}
+ |
+ | org.apache.mxnet.Symbol.createSymbolGeneral(
+ | "${function.name}", name, attr, args, map.toMap)
+ |}
""".stripMargin
- // Combine and build the function string
- val finalStr =
- s"""def ${function.name}
- | (${argDef.mkString(",")}) : $returnType
- | = {${impl.mkString("\n")}}
+ c.parse(impl).asInstanceOf[DefDef]
+ }
+}
+
+
+/**
+ * Symbol.random code generation
+ */
+private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase
+ with RandomHelpers {
+
+ def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*):
c.Expr[Any] = {
+ val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = true)
+ .map(f => buildTypedFunction(c)(f))
+
+ structGeneration(c)(functionDefs, annottees: _*)
+ }
+
+ protected def buildTypedFunction(c: blackbox.Context)
+ (function: Func): c.universe.DefDef = {
+ import c.universe._
+
+ val returnType = "org.apache.mxnet.Symbol"
+
+ // Construct API arguments declaration
+ val argDecl = super.typedFunctionCommonArgDef(function) :+
+ "name : String = null" :+
+ "attr : Map[String, String] = null"
+
+ // Map API input args to backend args
+ val backendArgsMapping =
+ function.listOfArgs.map { arg =>
+ if (arg.argType.equals(s"Array[org.apache.mxnet.Symbol]")) {
+ s"args = ${arg.safeArgName}.toSeq"
+ } else {
+ // all go in kwargs
+ if (arg.isOptional) {
+ s"""if (${arg.safeArgName}.isDefined) map("${arg.argName}") =
${arg.safeArgName}.get"""
+ } else {
+ s"""map("${arg.argName}") = ${arg.safeArgName}"""
+ }
+ }
+ }
+
+ val impl =
+ s"""
+ |def ${function.name}${randomGenericTypeSpec(true, true)}
+ | (${argDecl.mkString(",")}): $returnType = {
+ |
+ | val map = scala.collection.mutable.Map[String, Any]()
+ | var args = scala.collection.Seq[org.apache.mxnet.Symbol]()
+ | val isScalar = SymbolOrScalar[T].isScalar
+ |
+ | ${backendArgsMapping.mkString("\n")}
+ |
+ | val target = if(isScalar) {
+ | "random_${function.name}"
+ | } else {
+ | "sample_${function.name}"
+ | }
+ |
+ | ${unhackNormalFunc(function)}
+ |
+ | org.apache.mxnet.Symbol.createSymbolGeneral(
+ | target, name, attr, args, map.toMap)
+ |}
""".stripMargin
- c.parse(finalStr).asInstanceOf[DefDef]
+ c.parse(impl).asInstanceOf[DefDef]
}
}
+
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services