This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 53c5a72 [MXNET-918] Introduce Random module / Refact code generation (#13038)
53c5a72 is described below
commit 53c5a72c1f28dad284b7f6d7699cca6f0eec776a
Author: mathieu <[email protected]>
AuthorDate: Mon Nov 5 18:55:45 2018 +0100
[MXNET-918] Introduce Random module / Refact code generation (#13038)
* refactor code gen
* remove xxxAPIMacroBase (overkill)
* CI errors / scala-style
* PR review comments
---
.../scala/org/apache/mxnet/APIDocGenerator.scala | 234 ++++++++----------
.../scala/org/apache/mxnet/GeneratorBase.scala | 157 ++++++++++++
.../main/scala/org/apache/mxnet/NDArrayMacro.scala | 263 +++++++--------------
.../main/scala/org/apache/mxnet/SymbolMacro.scala | 250 ++++++--------------
4 files changed, 411 insertions(+), 493 deletions(-)
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index b4efa65..bfa378e 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -17,178 +17,154 @@
package org.apache.mxnet
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.CToScalaUtils
import java.io._
import java.security.MessageDigest
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.collection.mutable.ListBuffer
/**
* This object will generate the Scala documentation of the new Scala API
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* The code will be executed during Macros stage and file live in Core stage
*/
-private[mxnet] object APIDocGenerator{
- case class absClassArg(argName : String, argType : String, argDesc : String,
isOptional : Boolean)
- case class absClassFunction(name : String, desc : String,
- listOfArgs: List[absClassArg], returnType : String)
+private[mxnet] object APIDocGenerator extends GeneratorBase {
-
- def main(args: Array[String]) : Unit = {
+ def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
- hashCollector += absClassGen(FILE_PATH, true)
- hashCollector += absClassGen(FILE_PATH, false)
+ hashCollector += typeSafeClassGen(FILE_PATH, true)
+ hashCollector += typeSafeClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
val finalHash = hashCollector.mkString("\n")
}
- def MD5Generator(input : String) : String = {
+ def MD5Generator(input: String): String = {
val md = MessageDigest.getInstance("MD5")
md.update(input.getBytes("UTF-8"))
val digest = md.digest()
org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
}
- def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
- // scalastyle:off
- val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
- // Defines Operators that should not generated
- val notGenerated = Set("Custom")
- // TODO: Add Filter to the same location in case of refactor
- val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
- .filterNot(ele => notGenerated.contains(ele.name))
- .map(absClassFunction => {
- val scalaDoc = generateAPIDocFromBackend(absClassFunction)
- val defBody = generateAPISignature(absClassFunction, isSymbol)
- s"$scalaDoc\n$defBody"
- })
- val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
- val apacheLicence = "/*\n* Licensed to the Apache Software Foundation
(ASF) under one or more\n* contributor license agreements. See the NOTICE file
distributed with\n* this work for additional information regarding copyright
ownership.\n* The ASF licenses this file to You under the Apache License,
Version 2.0\n* (the \"License\"); you may not use this file except in
compliance with\n* the License. You may obtain a copy of the License at\n*\n*
http://www.apache.org/licenses/LICE [...]
- val scalaStyle = "// scalastyle:off"
- val packageDef = "package org.apache.mxnet"
- val imports = "import org.apache.mxnet.annotation.Experimental"
- val absClassDef = s"abstract class $packageName"
- val finalStr =
s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef
{\n${absFuncs.mkString("\n")}\n}"
- val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
- pw.write(finalStr)
- pw.close()
- MD5Generator(finalStr)
+ def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+ val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
+ .map { func =>
+ val scalaDoc = generateAPIDocFromBackend(func)
+ val decl = generateAPISignature(func, isSymbol)
+ s"$scalaDoc\n$decl"
+ }
+
+ writeFile(
+ FILE_PATH,
+ if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
+ "package org.apache.mxnet",
+ generated)
}
- def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
- // scalastyle:off
- val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
- val absFuncs = absClassFunctions.map(absClassFunction => {
- val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
- if (isSymbol) {
- val defBody = s"def ${absClassFunction.name}(name : String = null,
attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs :
Map[String, Any] = null): org.apache.mxnet.Symbol"
- s"$scalaDoc\n$defBody"
- } else {
- val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs:
Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
- val defBody = s"def ${absClassFunction.name}(args: Any*) :
org.apache.mxnet.NDArrayFuncReturn"
- s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
+ def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+ val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
+ .map { func =>
+ val scalaDoc = generateAPIDocFromBackend(func, false)
+ if (isSymbol) {
+ s"""$scalaDoc
+ |def ${func.name}(name : String = null, attr : Map[String,
String] = null)
+ | (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any]
= null):
+ | org.apache.mxnet.Symbol
+ """.stripMargin
+ } else {
+ s"""$scalaDoc
+ |def ${func.name}(kwargs: Map[String, Any] = null)
+ | (args: Any*): org.apache.mxnet.NDArrayFuncReturn
+ |
+ |$scalaDoc
+ |def ${func.name}(args: Any*): org.apache.mxnet.NDArrayFuncReturn
+ """.stripMargin
+ }
}
- })
- val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
- val apacheLicence = "/*\n* Licensed to the Apache Software Foundation
(ASF) under one or more\n* contributor license agreements. See the NOTICE file
distributed with\n* this work for additional information regarding copyright
ownership.\n* The ASF licenses this file to You under the Apache License,
Version 2.0\n* (the \"License\"); you may not use this file except in
compliance with\n* the License. You may obtain a copy of the License at\n*\n*
http://www.apache.org/licenses/LICE [...]
- val scalaStyle = "// scalastyle:off"
- val packageDef = "package org.apache.mxnet"
- val imports = "import org.apache.mxnet.annotation.Experimental"
- val absClassDef = s"abstract class $packageName"
- val finalStr =
s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef
{\n${absFuncs.mkString("\n")}\n}"
- import java.io._
- val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
- pw.write(finalStr)
- pw.close()
- MD5Generator(finalStr)
+
+ writeFile(
+ FILE_PATH,
+ if (isSymbol) "SymbolBase" else "NDArrayBase",
+ "package org.apache.mxnet",
+ absFuncs)
}
- // Generate ScalaDoc type
- def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean =
true) : String = {
- val desc = ArrayBuffer[String]()
- desc += " * <pre>"
- func.desc.split("\n").foreach({ currStr =>
- desc += s" * $currStr"
- })
- desc += " * </pre>"
- val params = func.listOfArgs.map({ absClassArg =>
- val currArgName = absClassArg.argName match {
- case "var" => "vari"
- case "type" => "typeOf"
- case _ => absClassArg.argName
- }
- s" * @param $currArgName\t\t${absClassArg.argDesc}"
- })
+ def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String
= {
+ val desc = func.desc.split("\n")
+ .mkString(" * <pre>\n", "\n * ", " * </pre>\n")
+
+ val params = func.listOfArgs.map { absClassArg =>
+ s" * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}"
+ }
+
val returnType = s" * @return ${func.returnType}"
+
if (withParam) {
- s" /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n
*/"
+ s""" /**
+ |$desc
+ |${params.mkString("\n")}
+ |$returnType
+ | */""".stripMargin
} else {
- s" /**\n${desc.mkString("\n")}\n$returnType\n */"
+ s""" /**
+ |$desc
+ |$returnType
+ | */""".stripMargin
}
}
- def generateAPISignature(func : absClassFunction, isSymbol : Boolean) :
String = {
- var argDef = ListBuffer[String]()
- func.listOfArgs.foreach(absClassArg => {
- val currArgName = absClassArg.argName match {
- case "var" => "vari"
- case "type" => "typeOf"
- case _ => absClassArg.argName
- }
- if (absClassArg.isOptional) {
- argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
- }
- else {
- argDef += s"$currArgName : ${absClassArg.argType}"
- }
- })
- var returnType = func.returnType
+ def generateAPISignature(func: Func, isSymbol: Boolean): String = {
+ val argDef = ListBuffer[String]()
+
+ argDef ++= typedFunctionCommonArgDef(func)
+
if (isSymbol) {
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
} else {
argDef += "out : Option[NDArray] = None"
- returnType = "org.apache.mxnet.NDArrayFuncReturn"
}
- val experimentalTag = "@Experimental"
- s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) :
$returnType"
- }
+ val returnType = func.returnType
- // List and add all the atomic symbol functions to current module.
- private def getSymbolNDArrayMethods(isSymbol : Boolean):
List[absClassFunction] = {
- val opNames = ListBuffer.empty[String]
- val returnType = if (isSymbol) "Symbol" else "NDArray"
- _LIB.mxListAllOpNames(opNames)
- // TODO: Add '_linalg_', '_sparse_', '_image_' support
- // TODO: Add Filter to the same location in case of refactor
- opNames.map(opName => {
- val opHandle = new RefLong
- _LIB.nnGetOpHandle(opName, opHandle)
- makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." +
returnType)
- }).toList.filterNot(_.name.startsWith("_"))
+ s"""@Experimental
+ |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
}
- // Create an atomic symbol function by handle and function name.
- private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName:
String, returnType : String)
- : absClassFunction = {
- val name = new RefString
- val desc = new RefString
- val keyVarNumArgs = new RefString
- val numArgs = new RefInt
- val argNames = ListBuffer.empty[String]
- val argTypes = ListBuffer.empty[String]
- val argDescs = ListBuffer.empty[String]
-
- _LIB.mxSymbolGetAtomicSymbolInfo(
- handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
- val argList = argNames zip argTypes zip argDescs map { case ((argName,
argType), argDesc) =>
- val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType,
returnType)
- new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
- }
- new absClassFunction(aliasName, desc.value, argList.toList, returnType)
+ def writeFile(FILE_PATH: String, className: String, packageDef: String,
+ absFuncs: Seq[String]): String = {
+
+ val finalStr =
+ s"""/*
+ |* Licensed to the Apache Software Foundation (ASF) under one or more
+ |* contributor license agreements. See the NOTICE file distributed
with
+ |* this work for additional information regarding copyright ownership.
+ |* The ASF licenses this file to You under the Apache License,
Version 2.0
+ |* (the "License"); you may not use this file except in compliance
with
+ |* the License. You may obtain a copy of the License at
+ |*
+ |* http://www.apache.org/licenses/LICENSE-2.0
+ |*
+ |* Unless required by applicable law or agreed to in writing, software
+ |* distributed under the License is distributed on an "AS IS" BASIS,
+ |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied.
+ |* See the License for the specific language governing permissions and
+ |* limitations under the License.
+ |*/
+ |
+ |$packageDef
+ |
+ |import org.apache.mxnet.annotation.Experimental
+ |
+ |// scalastyle:off
+ |abstract class $className {
+ |${absFuncs.mkString("\n")}
+ |}""".stripMargin
+
+ val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
+ pw.write(finalStr)
+ pw.close()
+ MD5Generator(finalStr)
}
+
}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
new file mode 100644
index 0000000..f4c4a91
--- /dev/null
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.init.Base.{RefInt, RefLong, RefString, _LIB}
+import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
+
+import scala.collection.mutable.ListBuffer
+import scala.reflect.macros.blackbox
+
+abstract class GeneratorBase {
+ type Handle = Long
+
+ case class Arg(argName: String, argType: String, argDesc: String,
isOptional: Boolean) {
+ def safeArgName: String = argName match {
+ case "var" => "vari"
+ case "type" => "typeOf"
+ case _ => argName
+ }
+ }
+
+ case class Func(name: String, desc: String, listOfArgs: List[Arg],
returnType: String)
+
+ def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] =
{
+ val l = getBackEndFunctions(isSymbol)
+ if (isContrib) {
+ l.filter(func => func.name.startsWith("_contrib_") ||
!func.name.startsWith("_"))
+ } else {
+ l.filterNot(_.name.startsWith("_"))
+ }
+ }
+
+ def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean):
List[Func] = {
+ // Operators that should not be generated
+ val notGenerated = Set("Custom")
+
+ val l = getBackEndFunctions(isSymbol)
+ val res = if (isContrib) {
+ l.filter(func => func.name.startsWith("_contrib_") ||
!func.name.startsWith("_"))
+ } else {
+ l.filterNot(_.name.startsWith("_"))
+ }
+ res.filterNot(ele => notGenerated.contains(ele.name))
+ }
+
+ protected def getBackEndFunctions(isSymbol: Boolean): List[Func] = {
+ val opNames = ListBuffer.empty[String]
+ _LIB.mxListAllOpNames(opNames)
+ opNames.map(opName => {
+ val opHandle = new RefLong
+ _LIB.nnGetOpHandle(opName, opHandle)
+ makeAtomicFunction(opHandle.value, opName, isSymbol)
+ }).toList
+ }
+
+ private def makeAtomicFunction(handle: Handle, aliasName: String, isSymbol:
Boolean): Func = {
+ val name = new RefString
+ val desc = new RefString
+ val keyVarNumArgs = new RefString
+ val numArgs = new RefInt
+ val argNames = ListBuffer.empty[String]
+ val argTypes = ListBuffer.empty[String]
+ val argDescs = ListBuffer.empty[String]
+
+ _LIB.mxSymbolGetAtomicSymbolInfo(
+ handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
+ val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes,
argDescs)
+ val extraDoc: String = if (keyVarNumArgs.value != null &&
keyVarNumArgs.value.length > 0) {
+ s"This function support variable length of positional input
(${keyVarNumArgs.value})."
+ } else {
+ ""
+ }
+ val realName = if (aliasName == name.value) "" else s"(a.k.a.,
${name.value})"
+ val docStr = s"$aliasName
$realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
+
+ val argList = argNames zip argTypes zip argDescs map { case ((argName,
argType), argDesc) =>
+ val family = if (isSymbol) "org.apache.mxnet.Symbol" else
"org.apache.mxnet.NDArray"
+ val typeAndOption =
+ CToScalaUtils.argumentCleaner(argName, argType, family)
+ Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
+ }
+ val returnType =
+ if (isSymbol) "org.apache.mxnet.Symbol" else
"org.apache.mxnet.NDArrayFuncReturn"
+ Func(aliasName, desc.value, argList.toList, returnType)
+ }
+
+ /**
+ * Generate class structure for all function APIs
+ *
+ * @param c
+ * @param funcDef DefDef type of function definitions
+ * @param annottees
+ * @return
+ */
+ protected def structGeneration(c: blackbox.Context)
+ (funcDef: List[c.universe.DefDef], annottees:
c.Expr[Any]*)
+ : c.Expr[Any] = {
+ import c.universe._
+ val inputs = annottees.map(_.tree).toList
+ // pattern match on the inputs
+ val modDefs = inputs map {
+ case ClassDef(mods, name, something, template) =>
+ val q = template match {
+ case Template(superMaybe, emptyValDef, defs) =>
+ Template(superMaybe, emptyValDef, defs ++ funcDef)
+ case ex =>
+ throw new IllegalArgumentException(s"Invalid template: $ex")
+ }
+ ClassDef(mods, name, something, q)
+ case ModuleDef(mods, name, template) =>
+ val q = template match {
+ case Template(superMaybe, emptyValDef, defs) =>
+ Template(superMaybe, emptyValDef, defs ++ funcDef)
+ case ex =>
+ throw new IllegalArgumentException(s"Invalid template: $ex")
+ }
+ ModuleDef(mods, name, q)
+ case ex =>
+ throw new IllegalArgumentException(s"Invalid macro input: $ex")
+ }
+ // wrap the result up in an Expr, and return it
+ val result = c.Expr(Block(modDefs, Literal(Constant())))
+ result
+ }
+
+ protected def typedFunctionCommonArgDef(func: Func): List[String] = {
+ // build function argument definition, with optionality, and safe names
+ func.listOfArgs.map(arg =>
+ if (arg.isOptional) {
+ // let's avoid a stupid Option[Array[...]]
+ if (arg.argType.startsWith("Array[")) {
+ s"${arg.safeArgName} : ${arg.argType} = Array.empty"
+ } else {
+ s"${arg.safeArgName} : Option[${arg.argType}] = None"
+ }
+ }
+ else {
+ s"${arg.safeArgName} : ${arg.argType}"
+ }
+ )
+ }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index 2d3a1c7..d85abe1 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -17,11 +17,8 @@
package org.apache.mxnet
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
-
import scala.annotation.StaticAnnotation
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
@@ -30,207 +27,111 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot
}
private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends
StaticAnnotation {
- private[mxnet] def macroTransform(annottees: Any*) = macro
NDArrayMacro.typeSafeAPIDefs
+ private[mxnet] def macroTransform(annottees: Any*) = macro
TypedNDArrayAPIMacro.typeSafeAPIDefs
}
-private[mxnet] object NDArrayMacro {
- case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
- case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
-
- // scalastyle:off havetype
- def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
- impl(c)(annottees: _*)
- }
- def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
- typeSafeAPIImpl(c)(annottees: _*)
- }
- // scalastyle:off havetype
-
- private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
+private[mxnet] object NDArrayMacro extends GeneratorBase {
- private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any]
= {
+ def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._
-
val isContrib: Boolean = c.prefix.tree match {
case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b))
}
- val newNDArrayFunctions = {
- if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
- else ndarrayFunctions.filterNot(_.name.startsWith("_"))
- }
-
- val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
- val funcName = NDArrayfunction.name
- val termName = TermName(funcName)
- Seq(
- // scalastyle:off
- // (yizhi) We are investigating a way to make these functions
type-safe
- // and waiting to see the new approach is stable enough.
- // Thus these functions may be deprecated in the future.
- // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
- q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) =
{genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
- // e.g def transpose(args: Any*)
- q"def $termName(args: Any*) =
{genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
- // scalastyle:on
- )
- }
-
- structGeneration(c)(functionDefs, annottees : _*)
+ impl(c)(isContrib, annottees: _*)
}
- private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) :
c.Expr[Any] = {
+ private def impl(c: blackbox.Context)
+ (isContrib: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] =
{
import c.universe._
- val isContrib: Boolean = c.prefix.tree match {
- case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
- }
- // Defines Operators that should not generated
- val notGenerated = Set("Custom")
-
- val newNDArrayFunctions = {
- if (isContrib) ndarrayFunctions.filter(
- func => func.name.startsWith("_contrib_") ||
!func.name.startsWith("_"))
- else ndarrayFunctions.filterNot(_.name.startsWith("_"))
- }.filterNot(ele => notGenerated.contains(ele.name))
-
- val functionDefs = newNDArrayFunctions.map { ndarrayfunction =>
-
- // Construct argument field
- var argDef = ListBuffer[String]()
- // Construct Implementation field
- var impl = ListBuffer[String]()
- impl += "val map = scala.collection.mutable.Map[String, Any]()"
- impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
- ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
- // var is a special word used to define variable in Scala,
- // need to changed to something else in order to make it work
- val currArgName = ndarrayarg.argName match {
- case "var" => "vari"
- case "type" => "typeOf"
- case default => ndarrayarg.argName
- }
- if (ndarrayarg.isOptional) {
- argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
- }
- else {
- argDef += s"${currArgName} : ${ndarrayarg.argType}"
- }
- // NDArray arg implementation
- val returnType = "org.apache.mxnet.NDArray"
-
- // TODO: Currently we do not add place holder for NDArray
- // Example: an NDArray operator like the following format
- // nd.foo(arg1: NDArray(required), arg2: NDArray(Optional), arg3:
NDArray(Optional)
- // If we place nd.foo(arg1, arg3 = arg3), do we need to add place
holder for arg2?
- // What it should be?
- val base =
- if (ndarrayarg.argType.equals(returnType)) {
- s"args += $currArgName"
- } else if (ndarrayarg.argType.equals(s"Array[$returnType]")){
- s"args ++= $currArgName"
- } else {
- "map(\"" + ndarrayarg.argName + "\") = " + currArgName
- }
- impl.append(
- if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get"
- else base
- )
- })
- // add default out parameter
- argDef += "out : Option[NDArray] = None"
- impl += "if (!out.isEmpty) map(\"out\") = out.get"
- // scalastyle:off
- impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" +
ndarrayfunction.name + "\", args.toSeq, map.toMap)"
- // scalastyle:on
- // Combine and build the function string
- val returnType = "org.apache.mxnet.NDArrayFuncReturn"
- var finalStr = s"def ${ndarrayfunction.name}"
- finalStr += s" (${argDef.mkString(",")}) : $returnType"
- finalStr += s" = {${impl.mkString("\n")}}"
- c.parse(finalStr).asInstanceOf[DefDef]
+ val functions = functionsToGenerate(isSymbol = false, isContrib)
+
+ val functionDefs = functions.flatMap { NDArrayfunction =>
+ val funcName = NDArrayfunction.name
+ val termName = TermName(funcName)
+ Seq(
+ // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
+ q"""
+ def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {
+ genericNDArrayFunctionInvoke($funcName, args, kwargs)
+ }
+ """.asInstanceOf[DefDef],
+ // e.g def transpose(args: Any*)
+ q"""
+ def $termName(args: Any*) = {
+ genericNDArrayFunctionInvoke($funcName, args, null)
+ }
+ """.asInstanceOf[DefDef]
+ )
}
- structGeneration(c)(functionDefs, annottees : _*)
+ structGeneration(c)(functionDefs, annottees: _*)
}
+}
- private def structGeneration(c: blackbox.Context)
- (funcDef : List[c.universe.DefDef], annottees:
c.Expr[Any]*)
- : c.Expr[Any] = {
+private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
+
+ def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*):
c.Expr[Any] = {
import c.universe._
- val inputs = annottees.map(_.tree).toList
- // pattern match on the inputs
- val modDefs = inputs map {
- case ClassDef(mods, name, something, template) =>
- val q = template match {
- case Template(superMaybe, emptyValDef, defs) =>
- Template(superMaybe, emptyValDef, defs ++ funcDef)
- case ex =>
- throw new IllegalArgumentException(s"Invalid template: $ex")
- }
- ClassDef(mods, name, something, q)
- case ModuleDef(mods, name, template) =>
- val q = template match {
- case Template(superMaybe, emptyValDef, defs) =>
- Template(superMaybe, emptyValDef, defs ++ funcDef)
- case ex =>
- throw new IllegalArgumentException(s"Invalid template: $ex")
- }
- ModuleDef(mods, name, q)
- case ex =>
- throw new IllegalArgumentException(s"Invalid macro input: $ex")
+ val isContrib: Boolean = c.prefix.tree match {
+ case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
}
- // wrap the result up in an Expr, and return it
- val result = c.Expr(Block(modDefs, Literal(Constant())))
- result
+
+ val functions = typeSafeFunctionsToGenerate(isSymbol = false, isContrib)
+
+ val functionDefs = functions.map(f => buildTypedFunction(c)(f))
+ structGeneration(c)(functionDefs, annottees: _*)
}
+ protected def buildTypedFunction(c: blackbox.Context)
+ (function: Func): c.universe.DefDef = {
+ import c.universe._
+ val returnType = "org.apache.mxnet.NDArrayFuncReturn"
+ val ndarrayType = "org.apache.mxnet.NDArray"
+ // Construct argument field
+ val argDef = ListBuffer[String]()
+ argDef ++= typedFunctionCommonArgDef(function)
+ argDef += "out : Option[NDArray] = None"
- // List and add all the atomic symbol functions to current module.
- private def initNDArrayModule(): List[NDArrayFunction] = {
- val opNames = ListBuffer.empty[String]
- _LIB.mxListAllOpNames(opNames)
- opNames.map(opName => {
- val opHandle = new RefLong
- _LIB.nnGetOpHandle(opName, opHandle)
- makeNDArrayFunction(opHandle.value, opName)
- }).toList
- }
+ // Construct Implementation field
+ var impl = ListBuffer[String]()
+ impl += "val map = scala.collection.mutable.Map[String, Any]()"
+ impl += s"val args =
scala.collection.mutable.ArrayBuffer.empty[$ndarrayType]"
- // Create an atomic symbol function by handle and function name.
- private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
- : NDArrayFunction = {
- val name = new RefString
- val desc = new RefString
- val keyVarNumArgs = new RefString
- val numArgs = new RefInt
- val argNames = ListBuffer.empty[String]
- val argTypes = ListBuffer.empty[String]
- val argDescs = ListBuffer.empty[String]
-
- _LIB.mxSymbolGetAtomicSymbolInfo(
- handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
- val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes,
argDescs)
- val extraDoc: String = if (keyVarNumArgs.value != null &&
keyVarNumArgs.value.length > 0) {
- s"This function support variable length of positional input
(${keyVarNumArgs.value})."
- } else {
- ""
- }
- val realName = if (aliasName == name.value) "" else s"(a.k.a.,
${name.value})"
- val docStr = s"$aliasName
$realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
- // scalastyle:off println
- if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
- && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
- println("NDArray function definition:\n" + docStr)
- }
- // scalastyle:on println
- val argList = argNames zip argTypes map { case (argName, argType) =>
- val typeAndOption =
- CToScalaUtils.argumentCleaner(argName, argType,
"org.apache.mxnet.NDArray")
- new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
+ // NDArray arg implementation
+ impl ++= function.listOfArgs.map { arg =>
+ if (arg.argType.equals(s"Array[$ndarrayType]")) {
+ s"args ++= ${arg.safeArgName}"
+ } else {
+ val base =
+ if (arg.argType.equals(ndarrayType)) {
+ // ndarrays go to args
+ s"args += ${arg.safeArgName}"
+ } else {
+ // other types go to kwargs
+ s"""map("${arg.argName}") = ${arg.safeArgName}"""
+ }
+ if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get"
+ else base
+ }
}
- new NDArrayFunction(aliasName, argList.toList)
+
+ impl +=
+ s"""if (!out.isEmpty) map("out") = out.get
+ |org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
+ | "${function.name}", args.toSeq, map.toMap)
+ """.stripMargin
+
+ // Combine and build the function string
+ val finalStr =
+ s"""def ${function.name}
+ | (${argDef.mkString(",")}) : $returnType
+ | = {${impl.mkString("\n")}}
+ """.stripMargin
+
+ c.parse(finalStr).asInstanceOf[DefDef]
}
}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
index 42aa117..ab864e1 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
@@ -21,222 +21,106 @@ import scala.annotation.StaticAnnotation
import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends
StaticAnnotation {
- private[mxnet] def macroTransform(annottees: Any*) = macro
SymbolImplMacros.addDefs
+ private[mxnet] def macroTransform(annottees: Any*) = macro
SymbolMacro.addDefs
}
private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends
StaticAnnotation {
- private[mxnet] def macroTransform(annottees: Any*) = macro
SymbolImplMacros.typeSafeAPIDefs
+ private[mxnet] def macroTransform(annottees: Any*) = macro
TypedSymbolAPIMacro.typeSafeAPIDefs
}
-private[mxnet] object SymbolImplMacros {
- case class SymbolArg(argName: String, argType: String, isOptional : Boolean)
- case class SymbolFunction(name: String, listOfArgs: List[SymbolArg])
+private[mxnet] object SymbolMacro extends GeneratorBase {
- // scalastyle:off havetype
- def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
- impl(c)(annottees: _*)
- }
- def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
- typedAPIImpl(c)(annottees: _*)
- }
- // scalastyle:on havetype
-
- private val symbolFunctions: List[SymbolFunction] = initSymbolModule()
-
- /**
- * Implementation for fixed input API structure
- */
- private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any]
= {
+ def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
import c.universe._
-
val isContrib: Boolean = c.prefix.tree match {
case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b))
}
- val newSymbolFunctions = {
- if (isContrib) symbolFunctions.filter(
- func => func.name.startsWith("_contrib_") ||
!func.name.startsWith("_"))
- else symbolFunctions.filter(!_.name.startsWith("_"))
- }
+ impl(c)(isContrib, annottees: _*)
+ }
+
+ private def impl(c: blackbox.Context)
+ (isContrib: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] =
{
+ import c.universe._
+ val functions = functionsToGenerate(isSymbol = false, isContrib)
- val functionDefs = newSymbolFunctions map { symbolfunction =>
- val funcName = symbolfunction.name
- val tName = TermName(funcName)
- q"""
+ val functionDefs = functions.map { symbolfunction =>
+ val funcName = symbolfunction.name
+ val tName = TermName(funcName)
+ q"""
def $tName(name : String = null, attr : Map[String, String] = null)
- (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null)
- : org.apache.mxnet.Symbol = {
- createSymbolGeneral($funcName,name,attr,args,kwargs)
- }
+ (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] =
null)
+ : org.apache.mxnet.Symbol = {
+ createSymbolGeneral($funcName,name,attr,args,kwargs)
+ }
""".asInstanceOf[DefDef]
- }
+ }
- structGeneration(c)(functionDefs, annottees : _*)
+ structGeneration(c)(functionDefs, annottees: _*)
}
+}
- /**
- * Implementation for Dynamic typed API Symbol.api.<functioname>
- */
- private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) :
c.Expr[Any] = {
- import c.universe._
+private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase {
+ def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*):
c.Expr[Any] = {
+ import c.universe._
val isContrib: Boolean = c.prefix.tree match {
case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b))
}
- // Defines Operators that should not generated
- val notGenerated = Set("Custom")
-
- // TODO: Put Symbol.api.foo --> Stable APIs
- // Symbol.contrib.bar--> Contrib APIs
- val newSymbolFunctions = {
- if (isContrib) symbolFunctions.filter(
- func => func.name.startsWith("_contrib_") ||
!func.name.startsWith("_"))
- else symbolFunctions.filter(!_.name.startsWith("_"))
- }.filterNot(ele => notGenerated.contains(ele.name))
-
- val functionDefs = newSymbolFunctions map { symbolfunction =>
-
- // Construct argument field
- var argDef = ListBuffer[String]()
- // Construct Implementation field
- var impl = ListBuffer[String]()
- impl += "val map = scala.collection.mutable.Map[String, Any]()"
- impl += "var args = Seq[org.apache.mxnet.Symbol]()"
- symbolfunction.listOfArgs.foreach({ symbolarg =>
- // var is a special word used to define variable in Scala,
- // need to changed to something else in order to make it work
- val currArgName = symbolarg.argName match {
- case "var" => "vari"
- case "type" => "typeOf"
- case default => symbolarg.argName
- }
- if (symbolarg.isOptional) {
- argDef += s"${currArgName} : Option[${symbolarg.argType}] = None"
- }
- else {
- argDef += s"${currArgName} : ${symbolarg.argType}"
- }
- // Symbol arg implementation
- val returnType = "org.apache.mxnet.Symbol"
- val base =
- if (symbolarg.argType.equals(s"Array[$returnType]")) {
- if (symbolarg.isOptional) s"if (!$currArgName.isEmpty) args =
$currArgName.get.toSeq"
- else s"args = $currArgName.toSeq"
- } else {
- if (symbolarg.isOptional) {
- // scalastyle:off
- s"if (!$currArgName.isEmpty) map(" + "\"" + symbolarg.argName +
"\"" + s") = $currArgName.get"
- // scalastyle:on
- }
- else "map(\"" + symbolarg.argName + "\"" + s") = $currArgName"
- }
+ val functions = typeSafeFunctionsToGenerate(isSymbol = true, isContrib)
- impl += base
- })
- argDef += "name : String = null"
- argDef += "attr : Map[String, String] = null"
- // scalastyle:off
- // TODO: Seq() here allows user to place Symbols rather than normal
arguments to run, need to fix if old API deprecated
- impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" +
symbolfunction.name + "\", name, attr, args, map.toMap)"
- // scalastyle:on
- // Combine and build the function string
- val returnType = "org.apache.mxnet.Symbol"
- var finalStr = s"def ${symbolfunction.name}"
- finalStr += s" (${argDef.mkString(",")}) : $returnType"
- finalStr += s" = {${impl.mkString("\n")}}"
- c.parse(finalStr).asInstanceOf[DefDef]
- }
- structGeneration(c)(functionDefs, annottees : _*)
+ val functionDefs = functions.map(f => buildTypedFunction(c)(f))
+ structGeneration(c)(functionDefs, annottees: _*)
}
- /**
- * Generate class structure for all function APIs
- * @param c
- * @param funcDef DefDef type of function definitions
- * @param annottees
- * @return
- */
- private def structGeneration(c: blackbox.Context)
- (funcDef : List[c.universe.DefDef], annottees:
c.Expr[Any]*)
- : c.Expr[Any] = {
+ protected def buildTypedFunction(c: blackbox.Context)
+ (function: Func): c.universe.DefDef = {
import c.universe._
- val inputs = annottees.map(_.tree).toList
- // pattern match on the inputs
- val modDefs = inputs map {
- case ClassDef(mods, name, something, template) =>
- val q = template match {
- case Template(superMaybe, emptyValDef, defs) =>
- Template(superMaybe, emptyValDef, defs ++ funcDef)
- case ex =>
- throw new IllegalArgumentException(s"Invalid template: $ex")
- }
- ClassDef(mods, name, something, q)
- case ModuleDef(mods, name, template) =>
- val q = template match {
- case Template(superMaybe, emptyValDef, defs) =>
- Template(superMaybe, emptyValDef, defs ++ funcDef)
- case ex =>
- throw new IllegalArgumentException(s"Invalid template: $ex")
- }
- ModuleDef(mods, name, q)
- case ex =>
- throw new IllegalArgumentException(s"Invalid macro input: $ex")
- }
- // wrap the result up in an Expr, and return it
- val result = c.Expr(Block(modDefs, Literal(Constant())))
- result
- }
- // List and add all the atomic symbol functions to current module.
- private def initSymbolModule(): List[SymbolFunction] = {
- val opNames = ListBuffer.empty[String]
- _LIB.mxListAllOpNames(opNames)
- // TODO: Add '_linalg_', '_sparse_', '_image_' support
- opNames.map(opName => {
- val opHandle = new RefLong
- _LIB.nnGetOpHandle(opName, opHandle)
- makeAtomicSymbolFunction(opHandle.value, opName)
- }).toList
- }
+ val returnType = "org.apache.mxnet.Symbol"
+ val symbolType = "org.apache.mxnet.Symbol"
- // Create an atomic symbol function by handle and function name.
- private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String)
- : SymbolFunction = {
- val name = new RefString
- val desc = new RefString
- val keyVarNumArgs = new RefString
- val numArgs = new RefInt
- val argNames = ListBuffer.empty[String]
- val argTypes = ListBuffer.empty[String]
- val argDescs = ListBuffer.empty[String]
-
- _LIB.mxSymbolGetAtomicSymbolInfo(
- handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
- val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes,
argDescs)
- val extraDoc: String = if (keyVarNumArgs.value != null &&
keyVarNumArgs.value.length > 0) {
- s"This function support variable length of positional input
(${keyVarNumArgs.value})."
+ // Construct argument field
+ val argDef = ListBuffer[String]()
+ argDef ++= typedFunctionCommonArgDef(function)
+ argDef += "name : String = null"
+ argDef += "attr : Map[String, String] = null"
+
+ // Construct Implementation field
+ val impl = ListBuffer[String]()
+ impl += "val map = scala.collection.mutable.Map[String, Any]()"
+ impl += s"var args = scala.collection.Seq[$symbolType]()"
+
+ // Symbol arg implementation
+ impl ++= function.listOfArgs.map { arg =>
+ if (arg.argType.equals(s"Array[$symbolType]")) {
+ s"if (!${arg.safeArgName}.isEmpty) args = ${arg.safeArgName}.toSeq"
} else {
- ""
+ // all go in kwargs
+ if (arg.isOptional) {
+ s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") =
${arg.safeArgName}.get"""
+ } else {
+ s"""map("${arg.argName}") = ${arg.safeArgName}"""
+ }
}
- val realName = if (aliasName == name.value) "" else s"(a.k.a.,
${name.value})"
- val docStr = s"$aliasName
$realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
- // scalastyle:off println
- if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
- && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
- println("Symbol function definition:\n" + docStr)
}
- // scalastyle:on println
- val argList = argNames zip argTypes map { case (argName, argType) =>
- val typeAndOption =
- CToScalaUtils.argumentCleaner(argName, argType,
"org.apache.mxnet.Symbol")
- new SymbolArg(argName, typeAndOption._1, typeAndOption._2)
- }
- new SymbolFunction(aliasName, argList.toList)
+
+ impl +=
+ s"""org.apache.mxnet.Symbol.createSymbolGeneral(
+ | "${function.name}", name, attr, args, map.toMap)
+ """.stripMargin
+
+ // Combine and build the function string
+ val finalStr =
+ s"""def ${function.name}
+ | (${argDef.mkString(",")}) : $returnType
+ | = {${impl.mkString("\n")}}
+ """.stripMargin
+
+ c.parse(finalStr).asInstanceOf[DefDef]
}
}