yzhliu commented on a change in pull request #10660: [MXNET-357] New Scala API Design (Symbol)
URL: https://github.com/apache/incubator-mxnet/pull/10660#discussion_r187469159
##########
File path:
scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
##########
@@ -21,109 +21,155 @@ import scala.annotation.StaticAnnotation
import scala.collection.mutable.ListBuffer
import scala.language.experimental.macros
import scala.reflect.macros.blackbox
-
import org.apache.mxnet.init.Base._
import org.apache.mxnet.utils.OperatorBuildUtils
private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends
StaticAnnotation {
private[mxnet] def macroTransform(annottees: Any*) = macro
SymbolImplMacros.addDefs
}
+private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends
StaticAnnotation {
+ private[mxnet] def macroTransform(annottees: Any*) = macro
SymbolImplMacros.addNewDefs
+}
+
private[mxnet] object SymbolImplMacros {
- case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String)
+ case class SymbolArg(argName: String, argType: String, isOptional : Boolean)
+ case class SymbolFunction(name: String, listOfArgs: List[SymbolArg])
// scalastyle:off havetype
def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
- impl(c)(false, annottees: _*)
+ impl(c)(annottees: _*)
}
- // scalastyle:off havetype
+ def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+ newAPIImpl(c)(annottees: _*)
+ }
+ // scalastyle:on havetype
- private val symbolFunctions: Map[String, SymbolFunction] = initSymbolModule()
+ private val symbolFunctions: List[SymbolFunction] = initSymbolModule()
- private def impl(c: blackbox.Context)(addSuper: Boolean, annottees:
c.Expr[Any]*): c.Expr[Any] = {
+ /**
+ * Implementation for fixed input API structure
+ */
+ private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any]
= {
import c.universe._
val isContrib: Boolean = c.prefix.tree match {
case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b))
}
val newSymbolFunctions = {
- if (isContrib) symbolFunctions.filter(_._1.startsWith("_contrib_"))
- else symbolFunctions.filter(!_._1.startsWith("_contrib_"))
+ if (isContrib) symbolFunctions.filter(
+ func => func.name.startsWith("_contrib_") ||
!func.name.startsWith("_"))
+ else symbolFunctions.filter(!_.name.startsWith("_"))
}
- val AST_TYPE_MAP_STRING_ANY = AppliedTypeTree(Ident(TypeName("Map")),
- List(Ident(TypeName("String")), Ident(TypeName("Any"))))
- val AST_TYPE_MAP_STRING_STRING = AppliedTypeTree(Ident(TypeName("Map")),
- List(Ident(TypeName("String")), Ident(TypeName("String"))))
- val AST_TYPE_SYMBOL_VARARG = AppliedTypeTree(
- Select(
- Select(Ident(termNames.ROOTPKG), TermName("scala")),
- TypeName("<repeated>")
- ),
- List(Select(Select(Select(
- Ident(TermName("org")), TermName("apache")), TermName("mxnet")),
TypeName("Symbol")))
- )
-
- val functionDefs = newSymbolFunctions map { case (funcName, funcProp) =>
- val functionScope = {
- if (isContrib) Modifiers()
- else {
- if (funcName.startsWith("_")) Modifiers(Flag.PRIVATE) else
Modifiers()
- }
- }
- val newName = {
- if (isContrib) funcName.substring(funcName.indexOf("_contrib_") +
"_contrib_".length())
- else funcName
+
+ val functionDefs = newSymbolFunctions map { symbolfunction =>
+ val funcName = symbolfunction.name
+ val tName = TermName(funcName)
+ q"""
+ def $tName(name : String = null, attr : Map[String, String] = null)
+ (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null)
+ : org.apache.mxnet.Symbol = {
+ createSymbolGeneral($funcName,name,attr,args,kwargs)
+ }
+ """.asInstanceOf[DefDef]
}
- // It will generate definition something like,
- // def Concat(name: String = null, attr: Map[String, String] = null)
- // (args: Symbol*)(kwargs: Map[String, Any] = null)
- DefDef(functionScope, TermName(newName), List(),
- List(
- List(
- ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("name"),
- Ident(TypeName("String")), Literal(Constant(null))),
- ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("attr"),
- AST_TYPE_MAP_STRING_STRING, Literal(Constant(null)))
- ),
- List(
- ValDef(Modifiers(), TermName("args"), AST_TYPE_SYMBOL_VARARG,
EmptyTree)
- ),
- List(
- ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM),
TermName("kwargs"),
- AST_TYPE_MAP_STRING_ANY, Literal(Constant(null)))
- )
- ), TypeTree(),
- Apply(
- Ident(TermName("createSymbolGeneral")),
- List(
- Literal(Constant(funcName)),
- Ident(TermName("name")),
- Ident(TermName("attr")),
- Ident(TermName("args")),
- Ident(TermName("kwargs"))
- )
- )
- )
+ structGeneration(c)(functionDefs, annottees : _*)
+ }
+
+ /**
+ * Implementation for Dynamic typed API Symbol.api.<functioname>
+ */
+ private def newAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) :
c.Expr[Any] = {
+ import c.universe._
+
+ val isContrib: Boolean = c.prefix.tree match {
+ case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b))
+ }
+
+ val newSymbolFunctions = {
+ if (isContrib) symbolFunctions.filter(
+ func => func.name.startsWith("_contrib_") ||
!func.name.startsWith("_"))
+ else symbolFunctions.filter(!_.name.startsWith("_"))
}
+ val functionDefs = newSymbolFunctions map { symbolfunction =>
+
+ // Construct argument field
+ var argDef = ListBuffer[String]()
+ symbolfunction.listOfArgs.foreach(symbolarg => {
Review comment:
Can we combine this `foreach` with the next one, i.e., the `foreach` at line 120 of the new version of this file?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services