This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new fc5fce7  [MXNET-918] Random module (#13039)
fc5fce7 is described below

commit fc5fce769e8a6beea807e2a30b4930d055371971
Author: mathieu <[email protected]>
AuthorDate: Fri Dec 14 23:22:06 2018 +0100

    [MXNET-918] Random module (#13039)
    
    * introduce random API
    
    * revert useless changes
    
    * shorter types in APIDoc gen code
    
    * fix after merge from master
    
    * Trigger CI
    
    * temp code / diag on CI
    
    * cleanup type-class code
    
    * cleanup type-class code
    
    * fix scalastyle
---
 .../src/main/scala/org/apache/mxnet/Base.scala     |  18 +++
 .../src/main/scala/org/apache/mxnet/NDArray.scala  |   1 +
 .../main/scala/org/apache/mxnet/NDArrayAPI.scala   |  13 +-
 .../src/main/scala/org/apache/mxnet/Symbol.scala   |   1 +
 .../main/scala/org/apache/mxnet/SymbolAPI.scala    |  12 +-
 .../test/scala/org/apache/mxnet/NDArraySuite.scala |  17 ++
 .../test/scala/org/apache/mxnet/SymbolSuite.scala  |  22 +++
 .../scala/org/apache/mxnet/APIDocGenerator.scala   |  43 +++++-
 .../scala/org/apache/mxnet/GeneratorBase.scala     |  75 ++++++++-
 .../main/scala/org/apache/mxnet/NDArrayMacro.scala | 171 ++++++++++++++++-----
 .../main/scala/org/apache/mxnet/SymbolMacro.scala  | 147 +++++++++++++-----
 11 files changed, 435 insertions(+), 85 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
index b2a53fd..bb9518d 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
@@ -153,3 +153,21 @@ private[mxnet] object Base {
 }
 
 class MXNetError(val err: String) extends Exception(err)
+
+// Some type-classes to ease the work in Symbol.random and NDArray.random modules
+
+class SymbolOrScalar[T](val isScalar: Boolean)
+object SymbolOrScalar {
+  def apply[T](implicit ev: SymbolOrScalar[T]): SymbolOrScalar[T] = ev
+  implicit object FloatWitness extends SymbolOrScalar[Float](true)
+  implicit object IntWitness extends SymbolOrScalar[Int](true)
+  implicit object SymbolWitness extends SymbolOrScalar[Symbol](false)
+}
+
+class NDArrayOrScalar[T](val isScalar: Boolean)
+object NDArrayOrScalar {
+  def apply[T](implicit ev: NDArrayOrScalar[T]): NDArrayOrScalar[T] = ev
+  implicit object FloatWitness extends NDArrayOrScalar[Float](true)
+  implicit object IntWitness extends NDArrayOrScalar[Int](true)
+  implicit object NDArrayWitness extends NDArrayOrScalar[NDArray](false)
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 3a0c3c1..1259581 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -40,6 +40,7 @@ object NDArray extends NDArrayBase {
   private val functions: Map[String, NDArrayFunction] = initNDArrayModule()
 
   val api = NDArrayAPI
+  val random = NDArrayRandomAPI
 
   private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit 
= {
     froms.foreach { from =>
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
index 1d8551c..024fed1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayAPI.scala
@@ -15,11 +15,22 @@
  * limitations under the License.
  */
 package org.apache.mxnet
-@AddNDArrayAPIs(false)
+
 /**
   * typesafe NDArray API: NDArray.api._
   * Main code will be generated during compile time through Macros
   */
+@AddNDArrayAPIs(false)
 object NDArrayAPI extends NDArrayAPIBase {
   // TODO: Implement CustomOp for NDArray
 }
+
+/**
+  * typesafe NDArray random module: NDArray.random._
+  * Main code will be generated during compile time through Macros
+  */
+@AddNDArrayRandomAPIs(false)
+object NDArrayRandomAPI extends NDArrayRandomAPIBase {
+
+}
+
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 01349a6..29885fc 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -842,6 +842,7 @@ object Symbol extends SymbolBase {
   private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)
 
   val api = SymbolAPI
+  val random = SymbolRandomAPI
 
   def pow(sym1: Symbol, sym2: Symbol): Symbol = {
     Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
index 1bfb055..f166de1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/SymbolAPI.scala
@@ -19,11 +19,11 @@ package org.apache.mxnet
 import scala.collection.mutable
 
 
-@AddSymbolAPIs(false)
 /**
   * typesafe Symbol API: Symbol.api._
   * Main code will be generated during compile time through Macros
   */
+@AddSymbolAPIs(false)
 object SymbolAPI extends SymbolAPIBase {
   def Custom (op_type : String, kwargs : mutable.Map[String, Any],
              name : String = null, attr : Map[String, String] = null) : Symbol 
= {
@@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase {
     Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap)
   }
 }
+
+/**
+  * typesafe Symbol random module: Symbol.random._
+  * Main code will be generated during compile time through Macros
+  */
+@AddSymbolRandomAPIs(false)
+object SymbolRandomAPI extends SymbolRandomAPIBase {
+
+}
+
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 5d88bb3..7992a0e 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -576,4 +576,21 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     assert(arr.internal.toDoubleArray === Array(2d, 2d))
     assert(arr.internal.toByteArray === Array(2.toByte, 2.toByte))
   }
+
+  test("NDArray random module is generated properly") {
+    val lam = NDArray.ones(1, 2)
+    val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 
4)))
+    val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 
4)))
+    assert(rnd.shape === Shape(1, 2, 3, 4))
+    assert(rnd2.shape === Shape(3, 4))
+  }
+
+  test("NDArray random module is generated properly - special case of 
'normal'") {
+    val mu = NDArray.ones(1, 2)
+    val sigma = NDArray.ones(1, 2) * 2
+    val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape 
= Some(Shape(3, 4)))
+    val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = 
Some(Shape(3, 4)))
+    assert(rnd.shape === Shape(1, 2, 3, 4))
+    assert(rnd2.shape === Shape(3, 4))
+  }
 }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
index ebb61d7..d134c83 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/SymbolSuite.scala
@@ -20,6 +20,7 @@ package org.apache.mxnet
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 
 class SymbolSuite extends FunSuite with BeforeAndAfterAll {
+
   test("symbol compose") {
     val data = Symbol.Variable("data")
 
@@ -71,4 +72,25 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
     val data2 = data.clone()
     assert(data.toJson === data2.toJson)
   }
+
+  test("Symbol random module is generated properly") {
+    val lam = Symbol.Variable("lam")
+    val rnd = Symbol.random.poisson(lam = Some(lam), shape = Some(Shape(2, 2)))
+    val rnd2 = Symbol.random.poisson(lam = Some(1f), shape = Some(Shape(2, 2)))
+    // scalastyle:off println
+    println(s"Symbol.random.poisson debug info: ${rnd.debugStr}")
+    println(s"Symbol.random.poisson debug info: ${rnd2.debugStr}")
+    // scalastyle:on println
+  }
+
+  test("Symbol random module is generated properly - special case of 
'normal'") {
+    val loc = Symbol.Variable("loc")
+    val scale = Symbol.Variable("scale")
+    val rnd = Symbol.random.normal(mu = Some(loc), sigma = Some(scale), shape 
= Some(Shape(2, 2)))
+    val rnd2 = Symbol.random.normal(mu = Some(1f), sigma = Some(2f), shape = 
Some(Shape(2, 2)))
+    // scalastyle:off println
+    println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}")
+    println(s"Symbol.random.random_normal debug info: ${rnd2.debugStr}")
+    // scalastyle:on println
+  }
 }
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index ce12dc7..97cd18a 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -27,13 +27,15 @@ import scala.collection.mutable.ListBuffer
   * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
   * The code will be executed during Macros stage and file live in Core stage
   */
-private[mxnet] object APIDocGenerator extends GeneratorBase {
+private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers 
{
 
   def main(args: Array[String]): Unit = {
     val FILE_PATH = args(0)
     val hashCollector = ListBuffer[String]()
     hashCollector += typeSafeClassGen(FILE_PATH, true)
     hashCollector += typeSafeClassGen(FILE_PATH, false)
+    hashCollector += typeSafeRandomClassGen(FILE_PATH, true)
+    hashCollector += typeSafeRandomClassGen(FILE_PATH, false)
     hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
     hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
     hashCollector += javaClassGen(FILE_PATH)
@@ -57,8 +59,27 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
 
     writeFile(
       FILE_PATH,
+      "package org.apache.mxnet",
       if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
+      "import org.apache.mxnet.annotation.Experimental",
+      generated)
+  }
+
+  def typeSafeRandomClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+    val generated = typeSafeRandomFunctionsToGenerate(isSymbol)
+      .map { func =>
+        val scalaDoc = generateAPIDocFromBackend(func)
+        val typeParameter = randomGenericTypeSpec(isSymbol, false)
+        val decl = generateAPISignature(func, isSymbol, typeParameter)
+        s"$scalaDoc\n$decl"
+      }
+
+    writeFile(
+      FILE_PATH,
       "package org.apache.mxnet",
+      if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase",
+      """import org.apache.mxnet.annotation.Experimental
+        |import scala.reflect.ClassTag""".stripMargin,
       generated)
   }
 
@@ -85,8 +106,9 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
 
     writeFile(
       FILE_PATH,
-      if (isSymbol) "SymbolBase" else "NDArrayBase",
       "package org.apache.mxnet",
+      if (isSymbol) "SymbolBase" else "NDArrayBase",
+      "import org.apache.mxnet.annotation.Experimental",
       absFuncs)
   }
 
@@ -110,7 +132,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
       }).toSeq
     val packageName = "NDArrayBase"
     val packageDef = "package org.apache.mxnet.javaapi"
-    writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
+    writeFile(
+      filePath + "javaapi/",
+      packageDef,
+      packageName,
+      "import org.apache.mxnet.annotation.Experimental",
+      absFuncs)
   }
 
   def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String 
= {
@@ -146,7 +173,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
     }
   }
 
-  def generateAPISignature(func: Func, isSymbol: Boolean): String = {
+  def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: 
String = ""): String = {
     val argDef = ListBuffer[String]()
 
     argDef ++= typedFunctionCommonArgDef(func)
@@ -162,7 +189,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
     val returnType = func.returnType
 
     s"""@Experimental
-       |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
+       |def ${func.name}$typeParameter (${argDef.mkString(", ")}): 
$returnType""".stripMargin
   }
 
   def generateJavaAPISignature(func : Func) : String = {
@@ -223,8 +250,8 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
     }
   }
 
-  def writeFile(FILE_PATH: String, className: String, packageDef: String,
-                absFuncs: Seq[String]): String = {
+  def writeFile(FILE_PATH: String, packageDef: String, className: String,
+                imports: String, absFuncs: Seq[String]): String = {
 
     val finalStr =
       s"""/*
@@ -246,7 +273,7 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
          |
          |$packageDef
          |
-         |import org.apache.mxnet.annotation.Experimental
+         |$imports
          |
          |// scalastyle:off
          |abstract class $className {
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
index 9245ef1..1c2c4fd 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -23,7 +23,7 @@ import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
 import scala.collection.mutable.ListBuffer
 import scala.reflect.macros.blackbox
 
-abstract class GeneratorBase {
+private[mxnet] abstract class GeneratorBase {
   type Handle = Long
 
   case class Arg(argName: String, argType: String, argDesc: String, 
isOptional: Boolean) {
@@ -46,7 +46,8 @@ abstract class GeneratorBase {
     }
   }
 
-  def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): 
List[Func] = {
+  // filter the operators to generate in the type-safe Symbol.api and NDArray.api
+  protected def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: 
Boolean): List[Func] = {
     // Operators that should not be generated
     val notGenerated = Set("Custom")
 
@@ -144,8 +145,8 @@ abstract class GeneratorBase {
     result
   }
 
+  // build function argument definition, with optionality, and safe names
   protected def typedFunctionCommonArgDef(func: Func): List[String] = {
-    // build function argument definition, with optionality, and safe names
     func.listOfArgs.map(arg =>
       if (arg.isOptional) {
         // let's avoid a stupid Option[Array[...]]
@@ -161,3 +162,71 @@ abstract class GeneratorBase {
     )
   }
 }
+
+// a mixin to ease generating the Random module
+private[mxnet] trait RandomHelpers {
+  self: GeneratorBase =>
+
+  // a generic type spec used in Symbol.random and NDArray.random modules
+  protected def randomGenericTypeSpec(isSymbol: Boolean, fullPackageSpec: 
Boolean): String = {
+    val classTag = if (fullPackageSpec) "scala.reflect.ClassTag" else 
"ClassTag"
+    if (isSymbol) s"[T: SymbolOrScalar : $classTag]"
+    else s"[T: NDArrayOrScalar : $classTag]"
+  }
+
+  // filter the operators to generate in the type-safe Symbol.random and NDArray.random
+  protected def typeSafeRandomFunctionsToGenerate(isSymbol: Boolean): 
List[Func] = {
+    getBackEndFunctions(isSymbol)
+      .filter(f => f.name.startsWith("_sample_") || 
f.name.startsWith("_random_"))
+      .map(f => f.copy(name = f.name.stripPrefix("_")))
+      // unify _random and _sample
+      .map(f => unifyRandom(f, isSymbol))
+      // deduplicate
+      .groupBy(_.name)
+      .mapValues(_.head)
+      .values
+      .toList
+  }
+
+  // unify call targets (random_xyz and sample_xyz) and unify their argument types
+  private def unifyRandom(func: Func, isSymbol: Boolean): Func = {
+    var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol",
+      "java.lang.Float", "java.lang.Integer")
+
+    func.copy(
+      name = func.name.replaceAll("(random|sample)_", ""),
+      listOfArgs = func.listOfArgs
+        .map(hackNormalFunc)
+        .map(arg =>
+          if (typeConv(arg.argType)) arg.copy(argType = "T")
+          else arg
+        )
+      // TODO: some functions are non consistent in random_ vs sample_ regarding optionality
+      // we may try to unify that as well here.
+    )
+  }
+
+  // hacks to manage the fact that random_normal and sample_normal have
+  // non-consistent parameter naming in the back-end
+  // this first one, merge loc/scale and mu/sigma
+  protected def hackNormalFunc(arg: Arg): Arg = {
+    if (arg.argName == "loc") arg.copy(argName = "mu")
+    else if (arg.argName == "scale") arg.copy(argName = "sigma")
+    else arg
+  }
+
+  // this second one reverts this merge prior to back-end call
+  protected def unhackNormalFunc(func: Func): String = {
+    if (func.name.equals("normal")) {
+      s"""if(target.equals("random_normal")) {
+         |  if(map.contains("mu")) { map("loc") = map("mu"); map.remove("mu")  
}
+         |  if(map.contains("sigma")) { map("scale") = map("sigma"); 
map.remove("sigma") }
+         |}
+       """.stripMargin
+    } else {
+      ""
+    }
+
+  }
+
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index d85abe1..c18694b 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -18,7 +18,6 @@
 package org.apache.mxnet
 
 import scala.annotation.StaticAnnotation
-import scala.collection.mutable.ListBuffer
 import scala.language.experimental.macros
 import scala.reflect.macros.blackbox
 
@@ -30,6 +29,14 @@ private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation
   private[mxnet] def macroTransform(annottees: Any*) = macro 
TypedNDArrayAPIMacro.typeSafeAPIDefs
 }
 
+private[mxnet] class AddNDArrayRandomAPIs(isContrib: Boolean) extends 
StaticAnnotation {
+  private[mxnet] def macroTransform(annottees: Any*) =
+  macro TypedNDArrayRandomAPIMacro.typeSafeAPIDefs
+}
+
+/**
+  * For non-typed NDArray API
+  */
 private[mxnet] object NDArrayMacro extends GeneratorBase {
 
   def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
@@ -70,6 +77,9 @@ private[mxnet] object NDArrayMacro extends GeneratorBase {
   }
 }
 
+/**
+  * NDArray.api code generation
+  */
 private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
 
   def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): 
c.Expr[Any] = {
@@ -78,9 +88,9 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
       case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
     }
 
-    val functions = typeSafeFunctionsToGenerate(isSymbol = false, isContrib)
+    val functionDefs = typeSafeFunctionsToGenerate(isSymbol = false, isContrib)
+      .map(f => buildTypedFunction(c)(f))
 
-    val functionDefs = functions.map(f => buildTypedFunction(c)(f))
     structGeneration(c)(functionDefs, annottees: _*)
   }
 
@@ -89,49 +99,136 @@ private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
     import c.universe._
 
     val returnType = "org.apache.mxnet.NDArrayFuncReturn"
-    val ndarrayType = "org.apache.mxnet.NDArray"
-
-    // Construct argument field
-    val argDef = ListBuffer[String]()
-    argDef ++= typedFunctionCommonArgDef(function)
-    argDef += "out : Option[NDArray] = None"
-
-    // Construct Implementation field
-    var impl = ListBuffer[String]()
-    impl += "val map = scala.collection.mutable.Map[String, Any]()"
-    impl += s"val args = 
scala.collection.mutable.ArrayBuffer.empty[$ndarrayType]"
-
-    // NDArray arg implementation
-    impl ++= function.listOfArgs.map { arg =>
-      if (arg.argType.equals(s"Array[$ndarrayType]")) {
-        s"args ++= ${arg.safeArgName}"
-      } else {
-        val base =
-          if (arg.argType.equals(ndarrayType)) {
-            // ndarrays go to args
+
+    // Construct API arguments declaration
+    val argDecl = super.typedFunctionCommonArgDef(function) :+ "out : 
Option[NDArray] = None"
+
+    // Map API input args to backend args
+    val backendArgsMapping =
+      function.listOfArgs.map { arg =>
+        // ndarrays go to args, other types go to kwargs
+        if (arg.argType.equals(s"Array[org.apache.mxnet.NDArray]")) {
+          s"args ++= ${arg.safeArgName}.toSeq"
+        } else {
+          val base = if (arg.argType.equals("org.apache.mxnet.NDArray")) {
             s"args += ${arg.safeArgName}"
           } else {
-            // other types go to kwargs
             s"""map("${arg.argName}") = ${arg.safeArgName}"""
           }
-        if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get"
-        else base
+          if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get"
+          else base
+        }
       }
-    }
 
-    impl +=
-      s"""if (!out.isEmpty) map("out") = out.get
-         |org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
-         |  "${function.name}", args.toSeq, map.toMap)
+    val impl =
+      s"""
+         |def ${function.name}
+         |  (${argDecl.mkString(",")}): $returnType = {
+         |
+         |  val map = scala.collection.mutable.Map[String, Any]()
+         |  val args = 
scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]
+         |
+         |  if (!out.isEmpty) map("out") = out.get
+         |
+         |  ${backendArgsMapping.mkString("\n")}
+         |
+         |  org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
+         |    "${function.name}", args.toSeq, map.toMap)
+         |}
        """.stripMargin
 
-    // Combine and build the function string
-    val finalStr =
-      s"""def ${function.name}
-         |   (${argDef.mkString(",")}) : $returnType
-         | = {${impl.mkString("\n")}}
+    c.parse(impl).asInstanceOf[DefDef]
+  }
+}
+
+
+/**
+  * NDArray.random code generation
+  */
+private[mxnet] object TypedNDArrayRandomAPIMacro extends GeneratorBase
+  with RandomHelpers {
+
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): 
c.Expr[Any] = {
+    // Note: no contrib managed in this module
+
+    val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = false)
+      .map(f => buildTypedFunction(c)(f))
+
+    structGeneration(c)(functionDefs, annottees: _*)
+  }
+
+  protected def buildTypedFunction(c: blackbox.Context)
+                                  (function: Func): c.universe.DefDef = {
+    import c.universe._
+
+    val returnType = "org.apache.mxnet.NDArrayFuncReturn"
+
+    // Construct API arguments declaration
+    val argDecl = super.typedFunctionCommonArgDef(function) :+ "out : 
Option[NDArray] = None"
+
+    // Map API input args to backend args
+    val backendArgsMapping =
+      function.listOfArgs.map { arg =>
+        // ndarrays go to args, other types go to kwargs
+        if (arg.argType.equals("Array[org.apache.mxnet.NDArray]")) {
+          s"args ++= ${arg.safeArgName}.toSeq"
+        } else {
+          if (arg.argType.equals("T")) {
+            if (arg.isOptional) {
+              s"""if(${arg.safeArgName}.isDefined) {
+                 |  if(isScalar) {
+                 |    map("${arg.argName}") = ${arg.safeArgName}.get
+                 |  } else {
+                 |    args += 
${arg.safeArgName}.get.asInstanceOf[org.apache.mxnet.NDArray]
+                 |  }
+                 |}
+             """.stripMargin
+            } else {
+              s"""if(isScalar) {
+                 |  map("${arg.argName}") = ${arg.safeArgName}
+                 |} else {
+                 |  args += 
${arg.safeArgName}.asInstanceOf[org.apache.mxnet.NDArray]
+                 |}
+             """.stripMargin
+            }
+          } else {
+            if (arg.isOptional) {
+              s"""if (${arg.safeArgName}.isDefined) 
map("${arg.argName}")=${arg.safeArgName}.get"""
+            } else {
+              s"""map("${arg.argName}") = ${arg.safeArgName}"""
+            }
+          }
+        }
+      }
+
+    val impl =
+      s"""
+         |def ${function.name}${randomGenericTypeSpec(false, true)}
+         |  (${argDecl.mkString(",")}): $returnType = {
+         |
+         |  val map = scala.collection.mutable.Map[String, Any]()
+         |  val args = 
scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]
+         |  val isScalar = NDArrayOrScalar[T].isScalar
+         |
+         |  if(out.isDefined) map("out") = out.get
+         |
+         |  ${backendArgsMapping.mkString("\n")}
+         |
+         |  val target = if(isScalar) {
+         |    "random_${function.name}"
+         |  } else {
+         |    "sample_${function.name}"
+         |  }
+         |
+         |  ${unhackNormalFunc(function)}
+         |
+         |  org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
+         |    target, args.toSeq, map.toMap)
+         |}
        """.stripMargin
 
-    c.parse(finalStr).asInstanceOf[DefDef]
+    c.parse(impl).asInstanceOf[DefDef]
   }
+
+
 }
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
index ab864e1..7ec80b9 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
@@ -17,8 +17,8 @@
 
 package org.apache.mxnet
 
+
 import scala.annotation.StaticAnnotation
-import scala.collection.mutable.ListBuffer
 import scala.language.experimental.macros
 import scala.reflect.macros.blackbox
 
@@ -30,6 +30,14 @@ private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends StaticAnnotation
   private[mxnet] def macroTransform(annottees: Any*) = macro 
TypedSymbolAPIMacro.typeSafeAPIDefs
 }
 
+private[mxnet] class AddSymbolRandomAPIs(isContrib: Boolean) extends 
StaticAnnotation {
+  private[mxnet] def macroTransform(annottees: Any*) =
+  macro TypedSymbolRandomAPIMacro.typeSafeAPIDefs
+}
+
+/**
+  * For non-typed Symbol API
+  */
 private[mxnet] object SymbolMacro extends GeneratorBase {
 
   def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
@@ -63,6 +71,9 @@ private[mxnet] object SymbolMacro extends GeneratorBase {
   }
 }
 
+/**
+  * Symbol.api code generation
+  */
 private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase {
 
   def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): 
c.Expr[Any] = {
@@ -71,9 +82,9 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase {
       case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b))
     }
 
-    val functions = typeSafeFunctionsToGenerate(isSymbol = true, isContrib)
+    val functionDefs = typeSafeFunctionsToGenerate(isSymbol = true, isContrib)
+      .map(f => buildTypedFunction(c)(f))
 
-    val functionDefs = functions.map(f => buildTypedFunction(c)(f))
     structGeneration(c)(functionDefs, annottees: _*)
   }
 
@@ -82,45 +93,111 @@ private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase {
     import c.universe._
 
     val returnType = "org.apache.mxnet.Symbol"
-    val symbolType = "org.apache.mxnet.Symbol"
-
-    // Construct argument field
-    val argDef = ListBuffer[String]()
-    argDef ++= typedFunctionCommonArgDef(function)
-    argDef += "name : String = null"
-    argDef += "attr : Map[String, String] = null"
-
-    // Construct Implementation field
-    val impl = ListBuffer[String]()
-    impl += "val map = scala.collection.mutable.Map[String, Any]()"
-    impl += s"var args = scala.collection.Seq[$symbolType]()"
-
-    // Symbol arg implementation
-    impl ++= function.listOfArgs.map { arg =>
-      if (arg.argType.equals(s"Array[$symbolType]")) {
-        s"if (!${arg.safeArgName}.isEmpty) args = ${arg.safeArgName}.toSeq"
-      } else {
-        // all go in kwargs
-        if (arg.isOptional) {
-          s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = 
${arg.safeArgName}.get"""
+
+    // Construct API arguments declaration
+    val argDecl = super.typedFunctionCommonArgDef(function) :+
+      "name : String = null" :+
+      "attr : Map[String, String] = null"
+
+    // Map API input args to backend args
+    val backendArgsMapping =
+      function.listOfArgs.map { arg =>
+        if (arg.argType.equals(s"Array[org.apache.mxnet.Symbol]")) {
+          s"args = ${arg.safeArgName}.toSeq"
         } else {
-          s"""map("${arg.argName}") = ${arg.safeArgName}"""
+          // all go in kwargs
+          if (arg.isOptional) {
+            s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = 
${arg.safeArgName}.get"""
+          } else {
+            s"""map("${arg.argName}") = ${arg.safeArgName}"""
+          }
         }
       }
-    }
 
-    impl +=
-      s"""org.apache.mxnet.Symbol.createSymbolGeneral(
-         |  "${function.name}", name, attr, args, map.toMap)
+    val impl =
+      s"""
+         |def ${function.name}
+         |  (${argDecl.mkString(",")}): $returnType = {
+         |
+         |  val map = scala.collection.mutable.Map[String, Any]()
+         |  var args = scala.collection.Seq[org.apache.mxnet.Symbol]()
+         |
+         |  ${backendArgsMapping.mkString("\n")}
+         |
+         |  org.apache.mxnet.Symbol.createSymbolGeneral(
+         |    "${function.name}", name, attr, args, map.toMap)
+         |}
        """.stripMargin
 
-    // Combine and build the function string
-    val finalStr =
-      s"""def ${function.name}
-         |   (${argDef.mkString(",")}) : $returnType
-         | = {${impl.mkString("\n")}}
+    c.parse(impl).asInstanceOf[DefDef]
+  }
+}
+
+
+/**
+  * Symbol.random code generation
+  */
+private[mxnet] object TypedSymbolRandomAPIMacro extends GeneratorBase
+  with RandomHelpers {
+
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): 
c.Expr[Any] = {
+    val functionDefs = typeSafeRandomFunctionsToGenerate(isSymbol = true)
+      .map(f => buildTypedFunction(c)(f))
+
+    structGeneration(c)(functionDefs, annottees: _*)
+  }
+
+  protected def buildTypedFunction(c: blackbox.Context)
+                                  (function: Func): c.universe.DefDef = {
+    import c.universe._
+
+    val returnType = "org.apache.mxnet.Symbol"
+
+    // Construct API arguments declaration
+    val argDecl = super.typedFunctionCommonArgDef(function) :+
+      "name : String = null" :+
+      "attr : Map[String, String] = null"
+
+    // Map API input args to backend args
+    val backendArgsMapping =
+      function.listOfArgs.map { arg =>
+        if (arg.argType.equals(s"Array[org.apache.mxnet.Symbol]")) {
+          s"args = ${arg.safeArgName}.toSeq"
+        } else {
+          // all go in kwargs
+          if (arg.isOptional) {
+            s"""if (${arg.safeArgName}.isDefined) map("${arg.argName}") = 
${arg.safeArgName}.get"""
+          } else {
+            s"""map("${arg.argName}") = ${arg.safeArgName}"""
+          }
+        }
+      }
+
+    val impl =
+      s"""
+         |def ${function.name}${randomGenericTypeSpec(true, true)}
+         |  (${argDecl.mkString(",")}): $returnType = {
+         |
+         |  val map = scala.collection.mutable.Map[String, Any]()
+         |  var args = scala.collection.Seq[org.apache.mxnet.Symbol]()
+         |  val isScalar = SymbolOrScalar[T].isScalar
+         |
+         |  ${backendArgsMapping.mkString("\n")}
+         |
+         |  val target = if(isScalar) {
+         |    "random_${function.name}"
+         |  } else {
+         |    "sample_${function.name}"
+         |  }
+         |
+         |  ${unhackNormalFunc(function)}
+         |
+         |  org.apache.mxnet.Symbol.createSymbolGeneral(
+         |    target, name, attr, args, map.toMap)
+         |}
        """.stripMargin
 
-    c.parse(finalStr).asInstanceOf[DefDef]
+    c.parse(impl).asInstanceOf[DefDef]
   }
 }
+

Reply via email to