This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 53c5a72  [MXNET-918] Introduce Random module / Refact code generation (#13038)
53c5a72 is described below

commit 53c5a72c1f28dad284b7f6d7699cca6f0eec776a
Author: mathieu <[email protected]>
AuthorDate: Mon Nov 5 18:55:45 2018 +0100

    [MXNET-918] Introduce Random module / Refact code generation (#13038)
    
    * refactor code gen
    
    * remove xxxAPIMacroBase (overkill)
    
    * CI errors / scala-style
    
    * PR review comments
---
 .../scala/org/apache/mxnet/APIDocGenerator.scala   | 234 ++++++++----------
 .../scala/org/apache/mxnet/GeneratorBase.scala     | 157 ++++++++++++
 .../main/scala/org/apache/mxnet/NDArrayMacro.scala | 263 +++++++--------------
 .../main/scala/org/apache/mxnet/SymbolMacro.scala  | 250 ++++++--------------
 4 files changed, 411 insertions(+), 493 deletions(-)

diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index b4efa65..bfa378e 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -17,178 +17,154 @@
 
 package org.apache.mxnet
 
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.CToScalaUtils
 import java.io._
 import java.security.MessageDigest
 
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.collection.mutable.ListBuffer
 
 /**
   * This object will generate the Scala documentation of the new Scala API
   * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
   * The code will be executed during Macros stage and file live in Core stage
   */
-private[mxnet] object APIDocGenerator{
-  case class absClassArg(argName : String, argType : String, argDesc : String, 
isOptional : Boolean)
-  case class absClassFunction(name : String, desc : String,
-                           listOfArgs: List[absClassArg], returnType : String)
+private[mxnet] object APIDocGenerator extends GeneratorBase {
 
-
-  def main(args: Array[String]) : Unit = {
+  def main(args: Array[String]): Unit = {
     val FILE_PATH = args(0)
     val hashCollector = ListBuffer[String]()
-    hashCollector += absClassGen(FILE_PATH, true)
-    hashCollector += absClassGen(FILE_PATH, false)
+    hashCollector += typeSafeClassGen(FILE_PATH, true)
+    hashCollector += typeSafeClassGen(FILE_PATH, false)
     hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
     hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
     val finalHash = hashCollector.mkString("\n")
   }
 
-  def MD5Generator(input : String) : String = {
+  def MD5Generator(input: String): String = {
     val md = MessageDigest.getInstance("MD5")
     md.update(input.getBytes("UTF-8"))
     val digest = md.digest()
     org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
   }
 
-  def absClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
-    // scalastyle:off
-    val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
-    // Defines Operators that should not generated
-    val notGenerated = Set("Custom")
-    // TODO: Add Filter to the same location in case of refactor
-    val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
-      .filterNot(ele => notGenerated.contains(ele.name))
-      .map(absClassFunction => {
-      val scalaDoc = generateAPIDocFromBackend(absClassFunction)
-      val defBody = generateAPISignature(absClassFunction, isSymbol)
-      s"$scalaDoc\n$defBody"
-    })
-    val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
-    val apacheLicence = "/*\n* Licensed to the Apache Software Foundation 
(ASF) under one or more\n* contributor license agreements.  See the NOTICE file 
distributed with\n* this work for additional information regarding copyright 
ownership.\n* The ASF licenses this file to You under the Apache License, 
Version 2.0\n* (the \"License\"); you may not use this file except in 
compliance with\n* the License.  You may obtain a copy of the License at\n*\n*  
  http://www.apache.org/licenses/LICE [...]
-    val scalaStyle = "// scalastyle:off"
-    val packageDef = "package org.apache.mxnet"
-    val imports = "import org.apache.mxnet.annotation.Experimental"
-    val absClassDef = s"abstract class $packageName"
-    val finalStr = 
s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef 
{\n${absFuncs.mkString("\n")}\n}"
-    val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
-    pw.write(finalStr)
-    pw.close()
-    MD5Generator(finalStr)
+  def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+    val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
+      .map { func =>
+        val scalaDoc = generateAPIDocFromBackend(func)
+        val decl = generateAPISignature(func, isSymbol)
+        s"$scalaDoc\n$decl"
+      }
+
+    writeFile(
+      FILE_PATH,
+      if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
+      "package org.apache.mxnet",
+      generated)
   }
 
-  def nonTypeSafeClassGen(FILE_PATH : String, isSymbol : Boolean) : String = {
-    // scalastyle:off
-    val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
-    val absFuncs = absClassFunctions.map(absClassFunction => {
-      val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
-      if (isSymbol) {
-        val defBody = s"def ${absClassFunction.name}(name : String = null, 
attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : 
Map[String, Any] = null): org.apache.mxnet.Symbol"
-        s"$scalaDoc\n$defBody"
-      } else {
-        val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: 
Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
-        val defBody = s"def ${absClassFunction.name}(args: Any*) : 
org.apache.mxnet.NDArrayFuncReturn"
-        s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
+  def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+    val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
+      .map { func =>
+        val scalaDoc = generateAPIDocFromBackend(func, false)
+        if (isSymbol) {
+          s"""$scalaDoc
+             |def ${func.name}(name : String = null, attr : Map[String, 
String] = null)
+             |    (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] 
= null):
+             |    org.apache.mxnet.Symbol
+           """.stripMargin
+        } else {
+          s"""$scalaDoc
+             |def ${func.name}(kwargs: Map[String, Any] = null)
+             |    (args: Any*): org.apache.mxnet.NDArrayFuncReturn
+             |
+             |$scalaDoc
+             |def ${func.name}(args: Any*): org.apache.mxnet.NDArrayFuncReturn
+           """.stripMargin
+        }
       }
-    })
-    val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
-    val apacheLicence = "/*\n* Licensed to the Apache Software Foundation 
(ASF) under one or more\n* contributor license agreements.  See the NOTICE file 
distributed with\n* this work for additional information regarding copyright 
ownership.\n* The ASF licenses this file to You under the Apache License, 
Version 2.0\n* (the \"License\"); you may not use this file except in 
compliance with\n* the License.  You may obtain a copy of the License at\n*\n*  
  http://www.apache.org/licenses/LICE [...]
-    val scalaStyle = "// scalastyle:off"
-    val packageDef = "package org.apache.mxnet"
-    val imports = "import org.apache.mxnet.annotation.Experimental"
-    val absClassDef = s"abstract class $packageName"
-    val finalStr = 
s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef 
{\n${absFuncs.mkString("\n")}\n}"
-    import java.io._
-    val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
-    pw.write(finalStr)
-    pw.close()
-    MD5Generator(finalStr)
+
+    writeFile(
+      FILE_PATH,
+      if (isSymbol) "SymbolBase" else "NDArrayBase",
+      "package org.apache.mxnet",
+      absFuncs)
   }
 
-  // Generate ScalaDoc type
-  def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = 
true) : String = {
-    val desc = ArrayBuffer[String]()
-    desc += "  * <pre>"
-      func.desc.split("\n").foreach({ currStr =>
-      desc += s"  * $currStr"
-    })
-    desc += "  * </pre>"
-    val params = func.listOfArgs.map({ absClassArg =>
-      val currArgName = absClassArg.argName match {
-                case "var" => "vari"
-                case "type" => "typeOf"
-                case _ => absClassArg.argName
-      }
-      s"  * @param $currArgName\t\t${absClassArg.argDesc}"
-    })
+  def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String 
= {
+    val desc = func.desc.split("\n")
+      .mkString("  * <pre>\n", "\n  * ", "  * </pre>\n")
+
+    val params = func.listOfArgs.map { absClassArg =>
+      s"  * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}"
+    }
+
     val returnType = s"  * @return ${func.returnType}"
+
     if (withParam) {
-      s"  /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n 
 */"
+      s"""  /**
+         |$desc
+         |${params.mkString("\n")}
+         |$returnType
+         |  */""".stripMargin
     } else {
-      s"  /**\n${desc.mkString("\n")}\n$returnType\n  */"
+      s"""  /**
+         |$desc
+         |$returnType
+         |  */""".stripMargin
     }
   }
 
-  def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : 
String = {
-    var argDef = ListBuffer[String]()
-    func.listOfArgs.foreach(absClassArg => {
-      val currArgName = absClassArg.argName match {
-        case "var" => "vari"
-        case "type" => "typeOf"
-        case _ => absClassArg.argName
-      }
-      if (absClassArg.isOptional) {
-        argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
-      }
-      else {
-        argDef += s"$currArgName : ${absClassArg.argType}"
-      }
-    })
-    var returnType = func.returnType
+  def generateAPISignature(func: Func, isSymbol: Boolean): String = {
+    val argDef = ListBuffer[String]()
+
+    argDef ++= typedFunctionCommonArgDef(func)
+
     if (isSymbol) {
       argDef += "name : String = null"
       argDef += "attr : Map[String, String] = null"
     } else {
       argDef += "out : Option[NDArray] = None"
-      returnType = "org.apache.mxnet.NDArrayFuncReturn"
     }
-    val experimentalTag = "@Experimental"
-    s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : 
$returnType"
-  }
 
+    val returnType = func.returnType
 
-  // List and add all the atomic symbol functions to current module.
-  private def getSymbolNDArrayMethods(isSymbol : Boolean): 
List[absClassFunction] = {
-    val opNames = ListBuffer.empty[String]
-    val returnType = if (isSymbol) "Symbol" else "NDArray"
-    _LIB.mxListAllOpNames(opNames)
-    // TODO: Add '_linalg_', '_sparse_', '_image_' support
-    // TODO: Add Filter to the same location in case of refactor
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeAtomicSymbolFunction(opHandle.value, opName, "org.apache.mxnet." + 
returnType)
-    }).toList.filterNot(_.name.startsWith("_"))
+    s"""@Experimental
+       |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
   }
 
-  // Create an atomic symbol function by handle and function name.
-  private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: 
String, returnType : String)
-  : absClassFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
-
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val argList = argNames zip argTypes zip argDescs map { case ((argName, 
argType), argDesc) =>
-      val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, 
returnType)
-      new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
-    }
-    new absClassFunction(aliasName, desc.value, argList.toList, returnType)
+  def writeFile(FILE_PATH: String, className: String, packageDef: String,
+                absFuncs: Seq[String]): String = {
+
+    val finalStr =
+      s"""/*
+         |* Licensed to the Apache Software Foundation (ASF) under one or more
+         |* contributor license agreements.  See the NOTICE file distributed 
with
+         |* this work for additional information regarding copyright ownership.
+         |* The ASF licenses this file to You under the Apache License, 
Version 2.0
+         |* (the "License"); you may not use this file except in compliance 
with
+         |* the License.  You may obtain a copy of the License at
+         |*
+         |*    http://www.apache.org/licenses/LICENSE-2.0
+         |*
+         |* Unless required by applicable law or agreed to in writing, software
+         |* distributed under the License is distributed on an "AS IS" BASIS,
+         |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 
implied.
+         |* See the License for the specific language governing permissions and
+         |* limitations under the License.
+         |*/
+         |
+         |$packageDef
+         |
+         |import org.apache.mxnet.annotation.Experimental
+         |
+         |// scalastyle:off
+         |abstract class $className {
+         |${absFuncs.mkString("\n")}
+         |}""".stripMargin
+
+    val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
+    pw.write(finalStr)
+    pw.close()
+    MD5Generator(finalStr)
   }
+
 }
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
new file mode 100644
index 0000000..f4c4a91
--- /dev/null
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.init.Base.{RefInt, RefLong, RefString, _LIB}
+import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
+
+import scala.collection.mutable.ListBuffer
+import scala.reflect.macros.blackbox
+
+abstract class GeneratorBase {
+  type Handle = Long
+
+  case class Arg(argName: String, argType: String, argDesc: String, 
isOptional: Boolean) {
+    def safeArgName: String = argName match {
+      case "var" => "vari"
+      case "type" => "typeOf"
+      case _ => argName
+    }
+  }
+
+  case class Func(name: String, desc: String, listOfArgs: List[Arg], 
returnType: String)
+
+  def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = 
{
+    val l = getBackEndFunctions(isSymbol)
+    if (isContrib) {
+      l.filter(func => func.name.startsWith("_contrib_") || 
!func.name.startsWith("_"))
+    } else {
+      l.filterNot(_.name.startsWith("_"))
+    }
+  }
+
+  def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): 
List[Func] = {
+    // Operators that should not be generated
+    val notGenerated = Set("Custom")
+
+    val l = getBackEndFunctions(isSymbol)
+    val res = if (isContrib) {
+      l.filter(func => func.name.startsWith("_contrib_") || 
!func.name.startsWith("_"))
+    } else {
+      l.filterNot(_.name.startsWith("_"))
+    }
+    res.filterNot(ele => notGenerated.contains(ele.name))
+  }
+
+  protected def getBackEndFunctions(isSymbol: Boolean): List[Func] = {
+    val opNames = ListBuffer.empty[String]
+    _LIB.mxListAllOpNames(opNames)
+    opNames.map(opName => {
+      val opHandle = new RefLong
+      _LIB.nnGetOpHandle(opName, opHandle)
+      makeAtomicFunction(opHandle.value, opName, isSymbol)
+    }).toList
+  }
+
+  private def makeAtomicFunction(handle: Handle, aliasName: String, isSymbol: 
Boolean): Func = {
+    val name = new RefString
+    val desc = new RefString
+    val keyVarNumArgs = new RefString
+    val numArgs = new RefInt
+    val argNames = ListBuffer.empty[String]
+    val argTypes = ListBuffer.empty[String]
+    val argDescs = ListBuffer.empty[String]
+
+    _LIB.mxSymbolGetAtomicSymbolInfo(
+      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
+    val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, 
argDescs)
+    val extraDoc: String = if (keyVarNumArgs.value != null && 
keyVarNumArgs.value.length > 0) {
+      s"This function support variable length of positional input 
(${keyVarNumArgs.value})."
+    } else {
+      ""
+    }
+    val realName = if (aliasName == name.value) "" else s"(a.k.a., 
${name.value})"
+    val docStr = s"$aliasName 
$realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
+
+    val argList = argNames zip argTypes zip argDescs map { case ((argName, 
argType), argDesc) =>
+      val family = if (isSymbol) "org.apache.mxnet.Symbol" else 
"org.apache.mxnet.NDArray"
+      val typeAndOption =
+        CToScalaUtils.argumentCleaner(argName, argType, family)
+      Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
+    }
+    val returnType =
+      if (isSymbol) "org.apache.mxnet.Symbol" else 
"org.apache.mxnet.NDArrayFuncReturn"
+    Func(aliasName, desc.value, argList.toList, returnType)
+  }
+
+  /**
+    * Generate class structure for all function APIs
+    *
+    * @param c
+    * @param funcDef DefDef type of function definitions
+    * @param annottees
+    * @return
+    */
+  protected def structGeneration(c: blackbox.Context)
+                                (funcDef: List[c.universe.DefDef], annottees: 
c.Expr[Any]*)
+  : c.Expr[Any] = {
+    import c.universe._
+    val inputs = annottees.map(_.tree).toList
+    // pattern match on the inputs
+    val modDefs = inputs map {
+      case ClassDef(mods, name, something, template) =>
+        val q = template match {
+          case Template(superMaybe, emptyValDef, defs) =>
+            Template(superMaybe, emptyValDef, defs ++ funcDef)
+          case ex =>
+            throw new IllegalArgumentException(s"Invalid template: $ex")
+        }
+        ClassDef(mods, name, something, q)
+      case ModuleDef(mods, name, template) =>
+        val q = template match {
+          case Template(superMaybe, emptyValDef, defs) =>
+            Template(superMaybe, emptyValDef, defs ++ funcDef)
+          case ex =>
+            throw new IllegalArgumentException(s"Invalid template: $ex")
+        }
+        ModuleDef(mods, name, q)
+      case ex =>
+        throw new IllegalArgumentException(s"Invalid macro input: $ex")
+    }
+    // wrap the result up in an Expr, and return it
+    val result = c.Expr(Block(modDefs, Literal(Constant())))
+    result
+  }
+
+  protected def typedFunctionCommonArgDef(func: Func): List[String] = {
+    // build function argument definition, with optionality, and safe names
+    func.listOfArgs.map(arg =>
+      if (arg.isOptional) {
+        // let's avoid a stupid Option[Array[...]]
+        if (arg.argType.startsWith("Array[")) {
+          s"${arg.safeArgName} : ${arg.argType} = Array.empty"
+        } else {
+          s"${arg.safeArgName} : Option[${arg.argType}] = None"
+        }
+      }
+      else {
+        s"${arg.safeArgName} : ${arg.argType}"
+      }
+    )
+  }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index 2d3a1c7..d85abe1 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -17,11 +17,8 @@
 
 package org.apache.mxnet
 
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
-
 import scala.annotation.StaticAnnotation
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.collection.mutable.ListBuffer
 import scala.language.experimental.macros
 import scala.reflect.macros.blackbox
 
@@ -30,207 +27,111 @@ private[mxnet] class AddNDArrayFunctions(isContrib: Boolean) extends StaticAnnot
 }
 
 private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends 
StaticAnnotation {
-  private[mxnet] def macroTransform(annottees: Any*) = macro 
NDArrayMacro.typeSafeAPIDefs
+  private[mxnet] def macroTransform(annottees: Any*) = macro 
TypedNDArrayAPIMacro.typeSafeAPIDefs
 }
 
-private[mxnet] object NDArrayMacro {
-  case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
-  case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
-
-  // scalastyle:off havetype
-  def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    impl(c)(annottees: _*)
-  }
-  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    typeSafeAPIImpl(c)(annottees: _*)
-  }
-  // scalastyle:off havetype
-
-  private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
+private[mxnet] object NDArrayMacro extends GeneratorBase {
 
-  private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] 
= {
+  def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
-
     val isContrib: Boolean = c.prefix.tree match {
       case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b))
     }
 
-    val newNDArrayFunctions = {
-      if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
-      else ndarrayFunctions.filterNot(_.name.startsWith("_"))
-    }
-
-     val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
-        val funcName = NDArrayfunction.name
-        val termName = TermName(funcName)
-       Seq(
-            // scalastyle:off
-            // (yizhi) We are investigating a way to make these functions 
type-safe
-            // and waiting to see the new approach is stable enough.
-            // Thus these functions may be deprecated in the future.
-            // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
-            q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = 
{genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
-            // e.g def transpose(args: Any*)
-            q"def $termName(args: Any*) = 
{genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
-            // scalastyle:on
-          )
-        }
-
-    structGeneration(c)(functionDefs, annottees : _*)
+    impl(c)(isContrib, annottees: _*)
   }
 
-  private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : 
c.Expr[Any] = {
+  private def impl(c: blackbox.Context)
+                  (isContrib: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = 
{
     import c.universe._
 
-    val isContrib: Boolean = c.prefix.tree match {
-      case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
-    }
-    // Defines Operators that should not generated
-    val notGenerated = Set("Custom")
-
-    val newNDArrayFunctions = {
-      if (isContrib) ndarrayFunctions.filter(
-        func => func.name.startsWith("_contrib_") || 
!func.name.startsWith("_"))
-      else ndarrayFunctions.filterNot(_.name.startsWith("_"))
-    }.filterNot(ele => notGenerated.contains(ele.name))
-
-    val functionDefs = newNDArrayFunctions.map { ndarrayfunction =>
-
-      // Construct argument field
-      var argDef = ListBuffer[String]()
-      // Construct Implementation field
-      var impl = ListBuffer[String]()
-      impl += "val map = scala.collection.mutable.Map[String, Any]()"
-      impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
-      ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
-        // var is a special word used to define variable in Scala,
-        // need to changed to something else in order to make it work
-        val currArgName = ndarrayarg.argName match {
-          case "var" => "vari"
-          case "type" => "typeOf"
-          case default => ndarrayarg.argName
-        }
-        if (ndarrayarg.isOptional) {
-          argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
-        }
-        else {
-          argDef += s"${currArgName} : ${ndarrayarg.argType}"
-        }
-        // NDArray arg implementation
-        val returnType = "org.apache.mxnet.NDArray"
-
-        // TODO: Currently we do not add place holder for NDArray
-        // Example: an NDArray operator like the following format
-        // nd.foo(arg1: NDArray(required), arg2: NDArray(Optional), arg3: 
NDArray(Optional)
-        // If we place nd.foo(arg1, arg3 = arg3), do we need to add place 
holder for arg2?
-        // What it should be?
-        val base =
-          if (ndarrayarg.argType.equals(returnType)) {
-            s"args += $currArgName"
-          } else if (ndarrayarg.argType.equals(s"Array[$returnType]")){
-            s"args ++= $currArgName"
-          } else {
-            "map(\"" + ndarrayarg.argName + "\") = " + currArgName
-          }
-        impl.append(
-          if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get"
-          else base
-        )
-      })
-      // add default out parameter
-      argDef += "out : Option[NDArray] = None"
-      impl += "if (!out.isEmpty) map(\"out\") = out.get"
-      // scalastyle:off
-      impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + 
ndarrayfunction.name + "\", args.toSeq, map.toMap)"
-      // scalastyle:on
-      // Combine and build the function string
-      val returnType = "org.apache.mxnet.NDArrayFuncReturn"
-      var finalStr = s"def ${ndarrayfunction.name}"
-      finalStr += s" (${argDef.mkString(",")}) : $returnType"
-      finalStr += s" = {${impl.mkString("\n")}}"
-      c.parse(finalStr).asInstanceOf[DefDef]
+    val functions = functionsToGenerate(isSymbol = false, isContrib)
+
+    val functionDefs = functions.flatMap { NDArrayfunction =>
+      val funcName = NDArrayfunction.name
+      val termName = TermName(funcName)
+      Seq(
+        // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
+        q"""
+             def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {
+               genericNDArrayFunctionInvoke($funcName, args, kwargs)
+             }
+          """.asInstanceOf[DefDef],
+        // e.g def transpose(args: Any*)
+        q"""
+             def $termName(args: Any*) = {
+               genericNDArrayFunctionInvoke($funcName, args, null)
+             }
+          """.asInstanceOf[DefDef]
+      )
     }
 
-    structGeneration(c)(functionDefs, annottees : _*)
+    structGeneration(c)(functionDefs, annottees: _*)
   }
+}
 
-  private def structGeneration(c: blackbox.Context)
-                              (funcDef : List[c.universe.DefDef], annottees: 
c.Expr[Any]*)
-  : c.Expr[Any] = {
+private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
+
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): 
c.Expr[Any] = {
     import c.universe._
-    val inputs = annottees.map(_.tree).toList
-    // pattern match on the inputs
-    val modDefs = inputs map {
-      case ClassDef(mods, name, something, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ClassDef(mods, name, something, q)
-      case ModuleDef(mods, name, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ModuleDef(mods, name, q)
-      case ex =>
-        throw new IllegalArgumentException(s"Invalid macro input: $ex")
+    val isContrib: Boolean = c.prefix.tree match {
+      case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
     }
-    // wrap the result up in an Expr, and return it
-    val result = c.Expr(Block(modDefs, Literal(Constant())))
-    result
+
+    val functions = typeSafeFunctionsToGenerate(isSymbol = false, isContrib)
+
+    val functionDefs = functions.map(f => buildTypedFunction(c)(f))
+    structGeneration(c)(functionDefs, annottees: _*)
   }
 
+  protected def buildTypedFunction(c: blackbox.Context)
+                                  (function: Func): c.universe.DefDef = {
+    import c.universe._
 
+    val returnType = "org.apache.mxnet.NDArrayFuncReturn"
+    val ndarrayType = "org.apache.mxnet.NDArray"
 
+    // Construct argument field
+    val argDef = ListBuffer[String]()
+    argDef ++= typedFunctionCommonArgDef(function)
+    argDef += "out : Option[NDArray] = None"
 
-  // List and add all the atomic symbol functions to current module.
-  private def initNDArrayModule(): List[NDArrayFunction] = {
-    val opNames = ListBuffer.empty[String]
-    _LIB.mxListAllOpNames(opNames)
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeNDArrayFunction(opHandle.value, opName)
-    }).toList
-  }
+    // Construct Implementation field
+    var impl = ListBuffer[String]()
+    impl += "val map = scala.collection.mutable.Map[String, Any]()"
+    impl += s"val args = 
scala.collection.mutable.ArrayBuffer.empty[$ndarrayType]"
 
-  // Create an atomic symbol function by handle and function name.
-  private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
-  : NDArrayFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
-
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, 
argDescs)
-    val extraDoc: String = if (keyVarNumArgs.value != null && 
keyVarNumArgs.value.length > 0) {
-      s"This function support variable length of positional input 
(${keyVarNumArgs.value})."
-    } else {
-      ""
-    }
-    val realName = if (aliasName == name.value) "" else s"(a.k.a., 
${name.value})"
-    val docStr = s"$aliasName 
$realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
-    // scalastyle:off println
-    if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
-      && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
-      println("NDArray function definition:\n" + docStr)
-    }
-    // scalastyle:on println
-    val argList = argNames zip argTypes map { case (argName, argType) =>
-      val typeAndOption =
-        CToScalaUtils.argumentCleaner(argName, argType, 
"org.apache.mxnet.NDArray")
-      new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
+    // NDArray arg implementation
+    impl ++= function.listOfArgs.map { arg =>
+      if (arg.argType.equals(s"Array[$ndarrayType]")) {
+        s"args ++= ${arg.safeArgName}"
+      } else {
+        val base =
+          if (arg.argType.equals(ndarrayType)) {
+            // ndarrays go to args
+            s"args += ${arg.safeArgName}"
+          } else {
+            // other types go to kwargs
+            s"""map("${arg.argName}") = ${arg.safeArgName}"""
+          }
+        if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get"
+        else base
+      }
     }
-    new NDArrayFunction(aliasName, argList.toList)
+
+    impl +=
+      s"""if (!out.isEmpty) map("out") = out.get
+         |org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
+         |  "${function.name}", args.toSeq, map.toMap)
+       """.stripMargin
+
+    // Combine and build the function string
+    val finalStr =
+      s"""def ${function.name}
+         |   (${argDef.mkString(",")}) : $returnType
+         | = {${impl.mkString("\n")}}
+       """.stripMargin
+
+    c.parse(finalStr).asInstanceOf[DefDef]
   }
 }
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
index 42aa117..ab864e1 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
@@ -21,222 +21,106 @@ import scala.annotation.StaticAnnotation
 import scala.collection.mutable.ListBuffer
 import scala.language.experimental.macros
 import scala.reflect.macros.blackbox
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
 
 private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends 
StaticAnnotation {
-  private[mxnet] def macroTransform(annottees: Any*) = macro 
SymbolImplMacros.addDefs
+  private[mxnet] def macroTransform(annottees: Any*) = macro 
SymbolMacro.addDefs
 }
 
 private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends 
StaticAnnotation {
-  private[mxnet] def macroTransform(annottees: Any*) = macro 
SymbolImplMacros.typeSafeAPIDefs
+  private[mxnet] def macroTransform(annottees: Any*) = macro 
TypedSymbolAPIMacro.typeSafeAPIDefs
 }
 
-private[mxnet] object SymbolImplMacros {
-  case class SymbolArg(argName: String, argType: String, isOptional : Boolean)
-  case class SymbolFunction(name: String, listOfArgs: List[SymbolArg])
+private[mxnet] object SymbolMacro extends GeneratorBase {
 
-  // scalastyle:off havetype
-  def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    impl(c)(annottees: _*)
-  }
-  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    typedAPIImpl(c)(annottees: _*)
-  }
-  // scalastyle:on havetype
-
-  private val symbolFunctions: List[SymbolFunction] = initSymbolModule()
-
-  /**
-    * Implementation for fixed input API structure
-    */
-  private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] 
= {
+  def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
-
     val isContrib: Boolean = c.prefix.tree match {
       case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b))
     }
 
-    val newSymbolFunctions = {
-      if (isContrib) symbolFunctions.filter(
-        func => func.name.startsWith("_contrib_") || 
!func.name.startsWith("_"))
-      else symbolFunctions.filter(!_.name.startsWith("_"))
-    }
+    impl(c)(isContrib, annottees: _*)
+  }
+
+  private def impl(c: blackbox.Context)
+                  (isContrib: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = 
{
+    import c.universe._
 
+    val functions = functionsToGenerate(isSymbol = false, isContrib)
 
-    val functionDefs = newSymbolFunctions map { symbolfunction =>
-        val funcName = symbolfunction.name
-        val tName = TermName(funcName)
-        q"""
+    val functionDefs = functions.map { symbolfunction =>
+      val funcName = symbolfunction.name
+      val tName = TermName(funcName)
+      q"""
             def $tName(name : String = null, attr : Map[String, String] = null)
-            (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null)
-             : org.apache.mxnet.Symbol = {
-              createSymbolGeneral($funcName,name,attr,args,kwargs)
-              }
+              (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = 
null)
+              : org.apache.mxnet.Symbol = {
+                createSymbolGeneral($funcName,name,attr,args,kwargs)
+            }
          """.asInstanceOf[DefDef]
-      }
+    }
 
-    structGeneration(c)(functionDefs, annottees : _*)
+    structGeneration(c)(functionDefs, annottees: _*)
   }
+}
 
-  /**
-    * Implementation for Dynamic typed API Symbol.api.<functioname>
-    */
-  private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : 
c.Expr[Any] = {
-    import c.universe._
+private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase {
 
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): 
c.Expr[Any] = {
+    import c.universe._
     val isContrib: Boolean = c.prefix.tree match {
       case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b))
     }
 
-    // Defines Operators that should not generated
-    val notGenerated = Set("Custom")
-
-    // TODO: Put Symbol.api.foo --> Stable APIs
-    // Symbol.contrib.bar--> Contrib APIs
-    val newSymbolFunctions = {
-      if (isContrib) symbolFunctions.filter(
-        func => func.name.startsWith("_contrib_") || 
!func.name.startsWith("_"))
-      else symbolFunctions.filter(!_.name.startsWith("_"))
-    }.filterNot(ele => notGenerated.contains(ele.name))
-
-    val functionDefs = newSymbolFunctions map { symbolfunction =>
-
-      // Construct argument field
-      var argDef = ListBuffer[String]()
-      // Construct Implementation field
-      var impl = ListBuffer[String]()
-      impl += "val map = scala.collection.mutable.Map[String, Any]()"
-      impl += "var args = Seq[org.apache.mxnet.Symbol]()"
-      symbolfunction.listOfArgs.foreach({ symbolarg =>
-        // var is a special word used to define variable in Scala,
-        // need to changed to something else in order to make it work
-        val currArgName = symbolarg.argName match {
-          case "var" => "vari"
-          case "type" => "typeOf"
-          case default => symbolarg.argName
-        }
-        if (symbolarg.isOptional) {
-          argDef += s"${currArgName} : Option[${symbolarg.argType}] = None"
-        }
-        else {
-          argDef += s"${currArgName} : ${symbolarg.argType}"
-        }
-        // Symbol arg implementation
-        val returnType = "org.apache.mxnet.Symbol"
-        val base =
-        if (symbolarg.argType.equals(s"Array[$returnType]")) {
-          if (symbolarg.isOptional) s"if (!$currArgName.isEmpty) args = 
$currArgName.get.toSeq"
-          else s"args = $currArgName.toSeq"
-        } else {
-          if (symbolarg.isOptional) {
-            // scalastyle:off
-            s"if (!$currArgName.isEmpty) map(" + "\"" + symbolarg.argName + 
"\"" + s") = $currArgName.get"
-            // scalastyle:on
-          }
-          else "map(\"" + symbolarg.argName + "\"" + s") = $currArgName"
-        }
+    val functions = typeSafeFunctionsToGenerate(isSymbol = true, isContrib)
 
-        impl += base
-      })
-      argDef += "name : String = null"
-      argDef += "attr : Map[String, String] = null"
-      // scalastyle:off
-      // TODO: Seq() here allows user to place Symbols rather than normal 
arguments to run, need to fix if old API deprecated
-      impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + 
symbolfunction.name + "\", name, attr, args, map.toMap)"
-      // scalastyle:on
-      // Combine and build the function string
-      val returnType = "org.apache.mxnet.Symbol"
-      var finalStr = s"def ${symbolfunction.name}"
-      finalStr += s" (${argDef.mkString(",")}) : $returnType"
-      finalStr += s" = {${impl.mkString("\n")}}"
-      c.parse(finalStr).asInstanceOf[DefDef]
-    }
-    structGeneration(c)(functionDefs, annottees : _*)
+    val functionDefs = functions.map(f => buildTypedFunction(c)(f))
+    structGeneration(c)(functionDefs, annottees: _*)
   }
 
-  /**
-    * Generate class structure for all function APIs
-    * @param c
-    * @param funcDef DefDef type of function definitions
-    * @param annottees
-    * @return
-    */
-  private def structGeneration(c: blackbox.Context)
-                              (funcDef : List[c.universe.DefDef], annottees: 
c.Expr[Any]*)
-  : c.Expr[Any] = {
+  protected def buildTypedFunction(c: blackbox.Context)
+                                  (function: Func): c.universe.DefDef = {
     import c.universe._
-    val inputs = annottees.map(_.tree).toList
-    // pattern match on the inputs
-    val modDefs = inputs map {
-      case ClassDef(mods, name, something, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ClassDef(mods, name, something, q)
-      case ModuleDef(mods, name, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ModuleDef(mods, name, q)
-      case ex =>
-        throw new IllegalArgumentException(s"Invalid macro input: $ex")
-    }
-    // wrap the result up in an Expr, and return it
-    val result = c.Expr(Block(modDefs, Literal(Constant())))
-    result
-  }
 
-  // List and add all the atomic symbol functions to current module.
-  private def initSymbolModule(): List[SymbolFunction] = {
-    val opNames = ListBuffer.empty[String]
-    _LIB.mxListAllOpNames(opNames)
-    // TODO: Add '_linalg_', '_sparse_', '_image_' support
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeAtomicSymbolFunction(opHandle.value, opName)
-    }).toList
-  }
+    val returnType = "org.apache.mxnet.Symbol"
+    val symbolType = "org.apache.mxnet.Symbol"
 
-  // Create an atomic symbol function by handle and function name.
-  private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String)
-      : SymbolFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
-
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, 
argDescs)
-    val extraDoc: String = if (keyVarNumArgs.value != null && 
keyVarNumArgs.value.length > 0) {
-        s"This function support variable length of positional input 
(${keyVarNumArgs.value})."
+    // Construct argument field
+    val argDef = ListBuffer[String]()
+    argDef ++= typedFunctionCommonArgDef(function)
+    argDef += "name : String = null"
+    argDef += "attr : Map[String, String] = null"
+
+    // Construct Implementation field
+    val impl = ListBuffer[String]()
+    impl += "val map = scala.collection.mutable.Map[String, Any]()"
+    impl += s"var args = scala.collection.Seq[$symbolType]()"
+
+    // Symbol arg implementation
+    impl ++= function.listOfArgs.map { arg =>
+      if (arg.argType.equals(s"Array[$symbolType]")) {
+        s"if (!${arg.safeArgName}.isEmpty) args = ${arg.safeArgName}.toSeq"
       } else {
-        ""
+        // all go in kwargs
+        if (arg.isOptional) {
+          s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = 
${arg.safeArgName}.get"""
+        } else {
+          s"""map("${arg.argName}") = ${arg.safeArgName}"""
+        }
       }
-    val realName = if (aliasName == name.value) "" else s"(a.k.a., 
${name.value})"
-    val docStr = s"$aliasName 
$realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
-    // scalastyle:off println
-    if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
-          && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
-      println("Symbol function definition:\n" + docStr)
     }
-    // scalastyle:on println
-    val argList = argNames zip argTypes map { case (argName, argType) =>
-        val typeAndOption =
-          CToScalaUtils.argumentCleaner(argName, argType, 
"org.apache.mxnet.Symbol")
-        new SymbolArg(argName, typeAndOption._1, typeAndOption._2)
-    }
-    new SymbolFunction(aliasName, argList.toList)
+
+    impl +=
+      s"""org.apache.mxnet.Symbol.createSymbolGeneral(
+         |  "${function.name}", name, attr, args, map.toMap)
+       """.stripMargin
+
+    // Combine and build the function string
+    val finalStr =
+      s"""def ${function.name}
+         |   (${argDef.mkString(",")}) : $returnType
+         | = {${impl.mkString("\n")}}
+       """.stripMargin
+
+    c.parse(finalStr).asInstanceOf[DefDef]
   }
 }

Reply via email to