yzhliu closed pull request #13242: [MXNET-918] [Introduce Random module / Refact code generation (#13038)][Cherry pick]
URL: https://github.com/apache/incubator-mxnet/pull/13242
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise display it for a merged fork PR:

diff --git a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java
index 485e0afa3e4..257ea324162 100644
--- a/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java
+++ b/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/benchmark/ObjectDetectionBenchmark.java
@@ -57,7 +57,7 @@ public void runBatchInference() {
         List<NDArray> nd = new ArrayList<>();
         NDArray[] temp = new NDArray[batchSize];
         for (int i = 0; i < batchSize; i++) temp[i] = img.copy();
-        NDArray batched = NDArray.concat(temp, batchSize).setdim(0).invoke().get();
+        NDArray batched = NDArray.concat(temp, batchSize, 0, null)[0];
         nd.add(batched);
         objDet.objectDetectWithNDArray(nd, 3);
     }
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index f2326868e8e..0c12e1f1c67 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -17,196 +17,151 @@
 
 package org.apache.mxnet
 
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.CToScalaUtils
 import java.io._
 import java.security.MessageDigest
 
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.collection.mutable.ListBuffer
 
 /**
   * This object will generate the Scala documentation of the new Scala API
   * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
   * The code will be executed during Macros stage and file live in Core stage
   */
-private[mxnet] object APIDocGenerator{
-  case class absClassArg(argName : String, argType : String, argDesc : String, isOptional : Boolean)
-  case class absClassFunction(name : String, desc : String,
-                           listOfArgs: List[absClassArg], returnType : String)
+private[mxnet] object APIDocGenerator extends GeneratorBase {
 
-
-  def main(args: Array[String]) : Unit = {
+  def main(args: Array[String]): Unit = {
     val FILE_PATH = args(0)
     val hashCollector = ListBuffer[String]()
-    hashCollector += absClassGen(FILE_PATH, true)
-    hashCollector += absClassGen(FILE_PATH, false)
+    hashCollector += typeSafeClassGen(FILE_PATH, true)
+    hashCollector += typeSafeClassGen(FILE_PATH, false)
     hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
     hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
-    // Generate Java API documentation
-    hashCollector += javaClassGen(FILE_PATH + "javaapi/")
+    hashCollector += javaClassGen(FILE_PATH)
     val finalHash = hashCollector.mkString("\n")
   }
 
-  def MD5Generator(input : String) : String = {
+  def MD5Generator(input: String): String = {
     val md = MessageDigest.getInstance("MD5")
     md.update(input.getBytes("UTF-8"))
     val digest = md.digest()
     org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
   }
 
-  def fileGen(filePath : String, packageName : String, packageDef : String,
-              absFuncs : List[String]) : String = {
-    val apacheLicense =
-      """/*
-        |* Licensed to the Apache Software Foundation (ASF) under one or more
-        |* contributor license agreements.  See the NOTICE file distributed with
-        |* this work for additional information regarding copyright ownership.
-        |* The ASF licenses this file to You under the Apache License, Version 2.0
-        |* (the "License"); you may not use this file except in compliance with
-        |* the License.  You may obtain a copy of the License at
-        |*
-        |*    http://www.apache.org/licenses/LICENSE-2.0
-        |*
-        |* Unless required by applicable law or agreed to in writing, software
-        |* distributed under the License is distributed on an "AS IS" BASIS,
-        |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-        |* See the License for the specific language governing permissions and
-        |* limitations under the License.
-        |*/
-        |""".stripMargin
-    val scalaStyle = "// scalastyle:off"
-    val imports = "import org.apache.mxnet.annotation.Experimental"
-    val absClassDef = s"abstract class $packageName"
+  def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+    val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
+      .map { func =>
+        val scalaDoc = generateAPIDocFromBackend(func)
+        val decl = generateAPISignature(func, isSymbol)
+        s"$scalaDoc\n$decl"
+      }
 
-    val finalStr =
-      s"""$apacheLicense
-         |$scalaStyle
-         |$packageDef
-         |$imports
-         |$absClassDef {
-         |${absFuncs.mkString("\n")}
-         |}""".stripMargin
-    val pw = new PrintWriter(new File(filePath + s"$packageName.scala"))
-    pw.write(finalStr)
-    pw.close()
-    MD5Generator(finalStr)
+    writeFile(
+      FILE_PATH,
+      if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
+      "package org.apache.mxnet",
+      generated)
   }
 
-  def absClassGen(filePath : String, isSymbol : Boolean) : String = {
-    val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
-    // Defines Operators that should not generated
-    val notGenerated = Set("Custom")
-    // TODO: Add Filter to the same location in case of refactor
-    val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
-      .filterNot(ele => notGenerated.contains(ele.name))
-      .map(absClassFunction => {
-      val scalaDoc = generateAPIDocFromBackend(absClassFunction)
-      val defBody = generateAPISignature(absClassFunction, isSymbol)
-      s"$scalaDoc\n$defBody"
-    })
-    val packageName = if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase"
-    val packageDef = "package org.apache.mxnet"
-    fileGen(filePath, packageName, packageDef, absFuncs)
+  def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
+    val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
+      .map { func =>
+        val scalaDoc = generateAPIDocFromBackend(func, false)
+        if (isSymbol) {
+          s"""$scalaDoc
+             |def ${func.name}(name : String = null, attr : Map[String, String] = null)
+             |    (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null):
+             |    org.apache.mxnet.Symbol
+           """.stripMargin
+        } else {
+          s"""$scalaDoc
+             |def ${func.name}(kwargs: Map[String, Any] = null)
+             |    (args: Any*): org.apache.mxnet.NDArrayFuncReturn
+             |
+             |$scalaDoc
+             |def ${func.name}(args: Any*): org.apache.mxnet.NDArrayFuncReturn
+           """.stripMargin
+        }
+      }
+
+    writeFile(
+      FILE_PATH,
+      if (isSymbol) "SymbolBase" else "NDArrayBase",
+      "package org.apache.mxnet",
+      absFuncs)
   }
 
   def javaClassGen(filePath : String) : String = {
     val notGenerated = Set("Custom")
-    val absClassFunctions = getSymbolNDArrayMethods(false, true)
-    // TODO: Add Filter to the same location in case of refactor
-    val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
-      .filterNot(ele => notGenerated.contains(ele.name))
-      .map(absClassFunction => {
+    val absClassFunctions = functionsToGenerate(false, false, true)
+    val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
+      .groupBy(_.name.toLowerCase).map(ele => {
+      /* Pattern matching for not generating deprecated method
+       * Group all method name in lowercase
+       * Kill the capital lettered method such as Cast vs cast
+       * As it defined by default it deprecated
+       */
+      if (ele._2.length == 1) ele._2.head
+      else {
+        if (ele._2.head.name.head.isLower) ele._2.head
+        else ele._2.last
+      }
+    }).map(absClassFunction => {
         generateJavaAPISignature(absClassFunction)
-      })
+      }).toSeq
     val packageName = "NDArrayBase"
     val packageDef = "package org.apache.mxnet.javaapi"
-    fileGen(filePath, packageName, packageDef, absFuncs)
+    writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
   }
 
-  def nonTypeSafeClassGen(filePath : String, isSymbol : Boolean) : String = {
-    // scalastyle:off
-    val absClassFunctions = getSymbolNDArrayMethods(isSymbol)
-    val absFuncs = absClassFunctions.map(absClassFunction => {
-      val scalaDoc = generateAPIDocFromBackend(absClassFunction, false)
-      if (isSymbol) {
-        val defBody = s"def ${absClassFunction.name}(name : String = null, attr : Map[String, String] = null)(args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null): org.apache.mxnet.Symbol"
-        s"$scalaDoc\n$defBody"
-      } else {
-        val defBodyWithKwargs = s"def ${absClassFunction.name}(kwargs: Map[String, Any] = null)(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
-        val defBody = s"def ${absClassFunction.name}(args: Any*) : org.apache.mxnet.NDArrayFuncReturn"
-        s"$scalaDoc\n$defBodyWithKwargs\n$scalaDoc\n$defBody"
-      }
-    })
-    val packageName = if (isSymbol) "SymbolBase" else "NDArrayBase"
-    val packageDef = "package org.apache.mxnet"
-    fileGen(filePath, packageName, packageDef, absFuncs)
-  }
+  def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
+    val desc = func.desc.split("\n")
+      .mkString("  * <pre>\n", "\n  * ", "  * </pre>\n")
 
-  /**
-    * Some of the C++ type name is not valid in Scala
-    * such as var and type. This method is to convert
-    * them into other names to get it passed
-    * @param in the input String
-    * @return converted name string
-    */
-  def safetyNameCheck(in : String) : String = {
-    in match {
-      case "var" => "vari"
-      case "type" => "typeOf"
-      case _ => in
+    val params = func.listOfArgs.map { absClassArg =>
+      s"  * @param ${absClassArg.safeArgName}\t\t${absClassArg.argDesc}"
     }
-  }
 
-  // Generate ScalaDoc type
-  def generateAPIDocFromBackend(func : absClassFunction, withParam : Boolean = true) : String = {
-    val desc = ArrayBuffer[String]()
-    desc += "  * <pre>"
-      func.desc.split("\n").foreach({ currStr =>
-      desc += s"  * $currStr"
-    })
-    desc += "  * </pre>"
-    val params = func.listOfArgs.map({ absClassArg =>
-      val currArgName = safetyNameCheck(absClassArg.argName)
-      s"  * @param $currArgName\t\t${absClassArg.argDesc}"
-    })
     val returnType = s"  * @return ${func.returnType}"
+
     if (withParam) {
-      s"  /**\n${desc.mkString("\n")}\n${params.mkString("\n")}\n$returnType\n  */"
+      s"""  /**
+         |$desc
+         |${params.mkString("\n")}
+         |$returnType
+         |  */""".stripMargin
     } else {
-      s"  /**\n${desc.mkString("\n")}\n$returnType\n  */"
+      s"""  /**
+         |$desc
+         |$returnType
+         |  */""".stripMargin
     }
   }
 
-  def generateAPISignature(func : absClassFunction, isSymbol : Boolean) : String = {
-    var argDef = ListBuffer[String]()
-    func.listOfArgs.foreach(absClassArg => {
-      val currArgName = safetyNameCheck(absClassArg.argName)
-      if (absClassArg.isOptional) {
-        argDef += s"$currArgName : Option[${absClassArg.argType}] = None"
-      }
-      else {
-        argDef += s"$currArgName : ${absClassArg.argType}"
-      }
-    })
-    var returnType = func.returnType
+  def generateAPISignature(func: Func, isSymbol: Boolean): String = {
+    val argDef = ListBuffer[String]()
+
+    argDef ++= typedFunctionCommonArgDef(func)
+
     if (isSymbol) {
       argDef += "name : String = null"
       argDef += "attr : Map[String, String] = null"
     } else {
       argDef += "out : Option[NDArray] = None"
-      returnType = "org.apache.mxnet.NDArrayFuncReturn"
     }
-    val experimentalTag = "@Experimental"
-    s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : $returnType"
+
+    val returnType = func.returnType
+
+    s"""@Experimental
+       |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
   }
 
-  def generateJavaAPISignature(func : absClassFunction) : String = {
+  def generateJavaAPISignature(func : Func) : String = {
     val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
     var argDef = ListBuffer[String]()
     var classDef = ListBuffer[String]()
     var requiredParam = ListBuffer[String]()
     func.listOfArgs.foreach(absClassArg => {
-      val currArgName = safetyNameCheck(absClassArg.argName)
+      val currArgName = absClassArg.safeArgName
       // scalastyle:off
       if (absClassArg.isOptional && useParamObject) {
         classDef +=
@@ -240,15 +195,15 @@ private[mxnet] object APIDocGenerator{
            | def getOut() = this.out
            | """.stripMargin
       s"""$scalaDocNoParam
-          | $experimentalTag
-          | def ${func.name}(po: ${func.name}Param) : $returnType
-          | /**
-          | * This Param Object is specifically used for ${func.name}
-          | ${requiredParam.mkString("\n")}
-          | */
-          | class ${func.name}Param(${argDef.mkString(",")}) {
-          |  ${classDef.mkString("\n  ")}
-          | }""".stripMargin
+         | $experimentalTag
+         | def ${func.name}(po: ${func.name}Param) : $returnType
+         | /**
+         | * This Param Object is specifically used for ${func.name}
+         | ${requiredParam.mkString("\n")}
+         | */
+         | class ${func.name}Param(${argDef.mkString(",")}) {
+         |  ${classDef.mkString("\n  ")}
+         | }""".stripMargin
     } else {
       argDef += "out : NDArray"
       s"""$scalaDoc
@@ -258,48 +213,40 @@ private[mxnet] object APIDocGenerator{
     }
   }
 
+  def writeFile(FILE_PATH: String, className: String, packageDef: String,
+                absFuncs: Seq[String]): String = {
 
-  // List and add all the atomic symbol functions to current module.
-  private def getSymbolNDArrayMethods(isSymbol : Boolean,
-                                      isJava : Boolean = false): List[absClassFunction] = {
-    val opNames = ListBuffer.empty[String]
-    val returnType = if (isSymbol) "Symbol" else "NDArray"
-    val returnHeader = if (isJava) "org.apache.mxnet.javaapi." else "org.apache.mxnet."
-    _LIB.mxListAllOpNames(opNames)
-    // TODO: Add '_linalg_', '_sparse_', '_image_' support
-    // TODO: Add Filter to the same location in case of refactor
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeAtomicSymbolFunction(opHandle.value, opName, returnHeader + returnType)
-    }).filterNot(_.name.startsWith("_")).groupBy(_.name.toLowerCase).map(ele => {
-      // Pattern matching for not generating depreciated method
-      if (ele._2.length == 1) ele._2.head
-      else {
-        if (ele._2.head.name.head.isLower) ele._2.head
-        else ele._2.last
-      }
-    }).toList
-  }
-
-  // Create an atomic symbol function by handle and function name.
-  private def makeAtomicSymbolFunction(handle: SymbolHandle,
-                                       aliasName: String, returnType : String)
-  : absClassFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
+    val finalStr =
+      s"""/*
+         |* Licensed to the Apache Software Foundation (ASF) under one or more
+         |* contributor license agreements.  See the NOTICE file distributed with
+         |* this work for additional information regarding copyright ownership.
+         |* The ASF licenses this file to You under the Apache License, Version 2.0
+         |* (the "License"); you may not use this file except in compliance with
+         |* the License.  You may obtain a copy of the License at
+         |*
+         |*    http://www.apache.org/licenses/LICENSE-2.0
+         |*
+         |* Unless required by applicable law or agreed to in writing, software
+         |* distributed under the License is distributed on an "AS IS" BASIS,
+         |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+         |* See the License for the specific language governing permissions and
+         |* limitations under the License.
+         |*/
+         |
+         |$packageDef
+         |
+         |import org.apache.mxnet.annotation.Experimental
+         |
+         |// scalastyle:off
+         |abstract class $className {
+         |${absFuncs.mkString("\n")}
+         |}""".stripMargin
 
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
-      val typeAndOption = CToScalaUtils.argumentCleaner(argName, argType, returnType)
-      new absClassArg(argName, typeAndOption._1, argDesc, typeAndOption._2)
-    }
-    new absClassFunction(aliasName, desc.value, argList.toList, returnType)
+    val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))
+    pw.write(finalStr)
+    pw.close()
+    MD5Generator(finalStr)
   }
-}
+
+}
\ No newline at end of file
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
new file mode 100644
index 00000000000..9245ef1b437
--- /dev/null
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.init.Base.{RefInt, RefLong, RefString, _LIB}
+import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
+
+import scala.collection.mutable.ListBuffer
+import scala.reflect.macros.blackbox
+
+abstract class GeneratorBase {
+  type Handle = Long
+
+  case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) {
+    def safeArgName: String = argName match {
+      case "var" => "vari"
+      case "type" => "typeOf"
+      case _ => argName
+    }
+  }
+
+  case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String)
+
+  def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean,
+                          isJava: Boolean = false): List[Func] = {
+    val l = getBackEndFunctions(isSymbol, isJava)
+    if (isContrib) {
+      l.filter(func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
+    } else {
+      l.filterNot(_.name.startsWith("_"))
+    }
+  }
+
+  def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = {
+    // Operators that should not be generated
+    val notGenerated = Set("Custom")
+
+    val l = getBackEndFunctions(isSymbol)
+    val res = if (isContrib) {
+      l.filter(func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
+    } else {
+      l.filterNot(_.name.startsWith("_"))
+    }
+    res.filterNot(ele => notGenerated.contains(ele.name))
+  }
+
+  protected def getBackEndFunctions(isSymbol: Boolean, isJava: Boolean = false): List[Func] = {
+    val opNames = ListBuffer.empty[String]
+    _LIB.mxListAllOpNames(opNames)
+    opNames.map(opName => {
+      val opHandle = new RefLong
+      _LIB.nnGetOpHandle(opName, opHandle)
+      makeAtomicFunction(opHandle.value, opName, isSymbol, isJava)
+    }).toList
+  }
+
+  private def makeAtomicFunction(handle: Handle, aliasName: String,
+                                 isSymbol: Boolean, isJava: Boolean): Func = {
+    val name = new RefString
+    val desc = new RefString
+    val keyVarNumArgs = new RefString
+    val numArgs = new RefInt
+    val argNames = ListBuffer.empty[String]
+    val argTypes = ListBuffer.empty[String]
+    val argDescs = ListBuffer.empty[String]
+
+    _LIB.mxSymbolGetAtomicSymbolInfo(
+      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
+    val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, argDescs)
+    val extraDoc: String = if (keyVarNumArgs.value != null && keyVarNumArgs.value.length > 0) {
+      s"This function support variable length of positional input (${keyVarNumArgs.value})."
+    } else {
+      ""
+    }
+    val realName = if (aliasName == name.value) "" else s"(a.k.a., ${name.value})"
+    val docStr = s"$aliasName $realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
+
+    val argList = argNames zip argTypes zip argDescs map { case ((argName, argType), argDesc) =>
+      val family = if (isJava) "org.apache.mxnet.javaapi.NDArray"
+      else if (isSymbol) "org.apache.mxnet.Symbol"
+      else "org.apache.mxnet.NDArray"
+      val typeAndOption =
+        CToScalaUtils.argumentCleaner(argName, argType, family)
+      Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
+    }
+    val returnType =
+      if (isJava) "Array[org.apache.mxnet.javaapi.NDArray]"
+      else if (isSymbol) "org.apache.mxnet.Symbol"
+      else "org.apache.mxnet.NDArrayFuncReturn"
+    Func(aliasName, desc.value, argList.toList, returnType)
+  }
+
+  /**
+    * Generate class structure for all function APIs
+    *
+    * @param c
+    * @param funcDef DefDef type of function definitions
+    * @param annottees
+    * @return
+    */
+  protected def structGeneration(c: blackbox.Context)
+                                (funcDef: List[c.universe.DefDef], annottees: c.Expr[Any]*)
+  : c.Expr[Any] = {
+    import c.universe._
+    val inputs = annottees.map(_.tree).toList
+    // pattern match on the inputs
+    val modDefs = inputs map {
+      case ClassDef(mods, name, something, template) =>
+        val q = template match {
+          case Template(superMaybe, emptyValDef, defs) =>
+            Template(superMaybe, emptyValDef, defs ++ funcDef)
+          case ex =>
+            throw new IllegalArgumentException(s"Invalid template: $ex")
+        }
+        ClassDef(mods, name, something, q)
+      case ModuleDef(mods, name, template) =>
+        val q = template match {
+          case Template(superMaybe, emptyValDef, defs) =>
+            Template(superMaybe, emptyValDef, defs ++ funcDef)
+          case ex =>
+            throw new IllegalArgumentException(s"Invalid template: $ex")
+        }
+        ModuleDef(mods, name, q)
+      case ex =>
+        throw new IllegalArgumentException(s"Invalid macro input: $ex")
+    }
+    // wrap the result up in an Expr, and return it
+    val result = c.Expr(Block(modDefs, Literal(Constant())))
+    result
+  }
+
+  protected def typedFunctionCommonArgDef(func: Func): List[String] = {
+    // build function argument definition, with optionality, and safe names
+    func.listOfArgs.map(arg =>
+      if (arg.isOptional) {
+        // let's avoid a stupid Option[Array[...]]
+        if (arg.argType.startsWith("Array[")) {
+          s"${arg.safeArgName} : ${arg.argType} = Array.empty"
+        } else {
+          s"${arg.safeArgName} : Option[${arg.argType}] = None"
+        }
+      }
+      else {
+        s"${arg.safeArgName} : ${arg.argType}"
+      }
+    )
+  }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index 2d3a1c7ec5a..d85abe1ecc4 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -17,11 +17,8 @@
 
 package org.apache.mxnet
 
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
-
 import scala.annotation.StaticAnnotation
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.collection.mutable.ListBuffer
 import scala.language.experimental.macros
 import scala.reflect.macros.blackbox
 
@@ -30,207 +27,111 @@ private[mxnet] class AddNDArrayFunctions(isContrib: 
Boolean) extends StaticAnnot
 }
 
 private[mxnet] class AddNDArrayAPIs(isContrib: Boolean) extends 
StaticAnnotation {
-  private[mxnet] def macroTransform(annottees: Any*) = macro 
NDArrayMacro.typeSafeAPIDefs
+  // Expands the annotated class/object via TypedNDArrayAPIMacro.typeSafeAPIDefs, which
+  // reads isContrib back from this annotation's constructor argument.
+  private[mxnet] def macroTransform(annottees: Any*) = macro 
TypedNDArrayAPIMacro.typeSafeAPIDefs
 }
 
-private[mxnet] object NDArrayMacro {
-  case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
-  case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
-
-  // scalastyle:off havetype
-  def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    impl(c)(annottees: _*)
-  }
-  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    typeSafeAPIImpl(c)(annottees: _*)
-  }
-  // scalastyle:off havetype
-
-  private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
+/**
+ * Generates the untyped operator overloads added by the `@AddNDArrayFunctions` annotation.
+ */
+private[mxnet] object NDArrayMacro extends GeneratorBase {
 
-  private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] 
= {
+  /** Macro entry point: reads `isContrib` from the annotation and expands the annottee. */
+  def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
-
     val isContrib: Boolean = c.prefix.tree match {
       case q"new AddNDArrayFunctions($b)" => c.eval[Boolean](c.Expr(b))
     }
 
-    val newNDArrayFunctions = {
-      if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
-      else ndarrayFunctions.filterNot(_.name.startsWith("_"))
-    }
-
-     val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
-        val funcName = NDArrayfunction.name
-        val termName = TermName(funcName)
-       Seq(
-            // scalastyle:off
-            // (yizhi) We are investigating a way to make these functions 
type-safe
-            // and waiting to see the new approach is stable enough.
-            // Thus these functions may be deprecated in the future.
-            // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
-            q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = 
{genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
-            // e.g def transpose(args: Any*)
-            q"def $termName(args: Any*) = 
{genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
-            // scalastyle:on
-          )
-        }
-
-    structGeneration(c)(functionDefs, annottees : _*)
+    impl(c)(isContrib, annottees: _*)
   }
 
-  private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : 
c.Expr[Any] = {
+  /**
+   * Appends to the annottee, for each operator returned by functionsToGenerate, two
+   * overloads that forward to genericNDArrayFunctionInvoke:
+   * `op(kwargs: Map[String, Any] = null)(args: Any*)` and `op(args: Any*)`.
+   */
+  private def impl(c: blackbox.Context)
+                  (isContrib: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = 
{
     import c.universe._
 
-    val isContrib: Boolean = c.prefix.tree match {
-      case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
-    }
-    // Defines Operators that should not generated
-    val notGenerated = Set("Custom")
-
-    val newNDArrayFunctions = {
-      if (isContrib) ndarrayFunctions.filter(
-        func => func.name.startsWith("_contrib_") || 
!func.name.startsWith("_"))
-      else ndarrayFunctions.filterNot(_.name.startsWith("_"))
-    }.filterNot(ele => notGenerated.contains(ele.name))
-
-    val functionDefs = newNDArrayFunctions.map { ndarrayfunction =>
-
-      // Construct argument field
-      var argDef = ListBuffer[String]()
-      // Construct Implementation field
-      var impl = ListBuffer[String]()
-      impl += "val map = scala.collection.mutable.Map[String, Any]()"
-      impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
-      ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
-        // var is a special word used to define variable in Scala,
-        // need to changed to something else in order to make it work
-        val currArgName = ndarrayarg.argName match {
-          case "var" => "vari"
-          case "type" => "typeOf"
-          case default => ndarrayarg.argName
-        }
-        if (ndarrayarg.isOptional) {
-          argDef += s"${currArgName} : Option[${ndarrayarg.argType}] = None"
-        }
-        else {
-          argDef += s"${currArgName} : ${ndarrayarg.argType}"
-        }
-        // NDArray arg implementation
-        val returnType = "org.apache.mxnet.NDArray"
-
-        // TODO: Currently we do not add place holder for NDArray
-        // Example: an NDArray operator like the following format
-        // nd.foo(arg1: NDArray(required), arg2: NDArray(Optional), arg3: 
NDArray(Optional)
-        // If we place nd.foo(arg1, arg3 = arg3), do we need to add place 
holder for arg2?
-        // What it should be?
-        val base =
-          if (ndarrayarg.argType.equals(returnType)) {
-            s"args += $currArgName"
-          } else if (ndarrayarg.argType.equals(s"Array[$returnType]")){
-            s"args ++= $currArgName"
-          } else {
-            "map(\"" + ndarrayarg.argName + "\") = " + currArgName
-          }
-        impl.append(
-          if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get"
-          else base
-        )
-      })
-      // add default out parameter
-      argDef += "out : Option[NDArray] = None"
-      impl += "if (!out.isEmpty) map(\"out\") = out.get"
-      // scalastyle:off
-      impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + 
ndarrayfunction.name + "\", args.toSeq, map.toMap)"
-      // scalastyle:on
-      // Combine and build the function string
-      val returnType = "org.apache.mxnet.NDArrayFuncReturn"
-      var finalStr = s"def ${ndarrayfunction.name}"
-      finalStr += s" (${argDef.mkString(",")}) : $returnType"
-      finalStr += s" = {${impl.mkString("\n")}}"
-      c.parse(finalStr).asInstanceOf[DefDef]
+    val functions = functionsToGenerate(isSymbol = false, isContrib)
+
+    val functionDefs = functions.flatMap { ndarrayFunction =>
+      val funcName = ndarrayFunction.name
+      val termName = TermName(funcName)
+      Seq(
+        // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
+        q"""
+             def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {
+               genericNDArrayFunctionInvoke($funcName, args, kwargs)
+             }
+          """.asInstanceOf[DefDef],
+        // e.g def transpose(args: Any*)
+        q"""
+             def $termName(args: Any*) = {
+               genericNDArrayFunctionInvoke($funcName, args, null)
+             }
+          """.asInstanceOf[DefDef]
+      )
     }
 
-    structGeneration(c)(functionDefs, annottees : _*)
+    structGeneration(c)(functionDefs, annottees: _*)
   }
+}
 
-  private def structGeneration(c: blackbox.Context)
-                              (funcDef : List[c.universe.DefDef], annottees: 
c.Expr[Any]*)
-  : c.Expr[Any] = {
+/**
+ * Generates the type-safe NDArray operator methods added by the `@AddNDArrayAPIs` annotation.
+ */
+private[mxnet] object TypedNDArrayAPIMacro extends GeneratorBase {
+
+  /** Macro entry point: reads `isContrib` from the annotation and adds one typed def per op. */
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): 
c.Expr[Any] = {
     import c.universe._
-    val inputs = annottees.map(_.tree).toList
-    // pattern match on the inputs
-    val modDefs = inputs map {
-      case ClassDef(mods, name, something, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ClassDef(mods, name, something, q)
-      case ModuleDef(mods, name, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ModuleDef(mods, name, q)
-      case ex =>
-        throw new IllegalArgumentException(s"Invalid macro input: $ex")
+    val isContrib: Boolean = c.prefix.tree match {
+      case q"new AddNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
     }
-    // wrap the result up in an Expr, and return it
-    val result = c.Expr(Block(modDefs, Literal(Constant())))
-    result
+
+    val functions = typeSafeFunctionsToGenerate(isSymbol = false, isContrib)
+
+    val functionDefs = functions.map(f => buildTypedFunction(c)(f))
+    structGeneration(c)(functionDefs, annottees: _*)
   }
 
+  /**
+   * Builds one typed operator method from its source text, roughly
+   * `def op(<op args>, out : Option[NDArray] = None) : NDArrayFuncReturn = {...}`.
+   * NDArray and Array[NDArray] arguments are collected as positional inputs, all other
+   * arguments go into a kwargs map, and the call is delegated to
+   * NDArray.genericNDArrayFunctionInvoke.
+   */
+  protected def buildTypedFunction(c: blackbox.Context)
+                                  (function: Func): c.universe.DefDef = {
+    import c.universe._
 
+    val returnType = "org.apache.mxnet.NDArrayFuncReturn"
+    val ndarrayType = "org.apache.mxnet.NDArray"
 
+    // Construct argument field
+    val argDef = ListBuffer[String]()
+    argDef ++= typedFunctionCommonArgDef(function)
+    // Fully qualified, like every other type emitted into the generated source.
+    argDef += s"out : Option[$ndarrayType] = None"
 
-  // List and add all the atomic symbol functions to current module.
-  private def initNDArrayModule(): List[NDArrayFunction] = {
-    val opNames = ListBuffer.empty[String]
-    _LIB.mxListAllOpNames(opNames)
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeNDArrayFunction(opHandle.value, opName)
-    }).toList
-  }
+    // Construct Implementation field (a val: the buffer is only appended to)
+    val impl = ListBuffer[String]()
+    impl += "val map = scala.collection.mutable.Map[String, Any]()"
+    impl += s"val args = 
scala.collection.mutable.ArrayBuffer.empty[$ndarrayType]"
 
-  // Create an atomic symbol function by handle and function name.
-  private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
-  : NDArrayFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
-
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, 
argDescs)
-    val extraDoc: String = if (keyVarNumArgs.value != null && 
keyVarNumArgs.value.length > 0) {
-      s"This function support variable length of positional input 
(${keyVarNumArgs.value})."
-    } else {
-      ""
-    }
-    val realName = if (aliasName == name.value) "" else s"(a.k.a., 
${name.value})"
-    val docStr = s"$aliasName 
$realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
-    // scalastyle:off println
-    if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
-      && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
-      println("NDArray function definition:\n" + docStr)
-    }
-    // scalastyle:on println
-    val argList = argNames zip argTypes map { case (argName, argType) =>
-      val typeAndOption =
-        CToScalaUtils.argumentCleaner(argName, argType, 
"org.apache.mxnet.NDArray")
-      new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
+    // NDArray arg implementation
+    impl ++= function.listOfArgs.map { arg =>
+      if (arg.argType.equals(s"Array[$ndarrayType]")) {
+        s"args ++= ${arg.safeArgName}"
+      } else {
+        val base =
+          if (arg.argType.equals(ndarrayType)) {
+            // ndarrays go to args
+            s"args += ${arg.safeArgName}"
+          } else {
+            // other types go to kwargs
+            s"""map("${arg.argName}") = ${arg.safeArgName}"""
+          }
+        if (arg.isOptional) s"if (!${arg.safeArgName}.isEmpty) $base.get"
+        else base
+      }
     }
-    new NDArrayFunction(aliasName, argList.toList)
+
+    impl +=
+      s"""if (!out.isEmpty) map("out") = out.get
+         |org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(
+         |  "${function.name}", args.toSeq, map.toMap)
+       """.stripMargin
+
+    // Combine and build the function string
+    val finalStr =
+      s"""def ${function.name}
+         |   (${argDef.mkString(",")}) : $returnType
+         | = {${impl.mkString("\n")}}
+       """.stripMargin
+
+    c.parse(finalStr).asInstanceOf[DefDef]
   }
 }
diff --git 
a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala 
b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
index 42aa11781d8..ab864e1ef19 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
@@ -21,222 +21,106 @@ import scala.annotation.StaticAnnotation
 import scala.collection.mutable.ListBuffer
 import scala.language.experimental.macros
 import scala.reflect.macros.blackbox
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
 
 private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends 
StaticAnnotation {
-  private[mxnet] def macroTransform(annottees: Any*) = macro 
SymbolImplMacros.addDefs
+  // Expands the annotated class/object via SymbolMacro.addDefs, which reads
+  // isContrib back from this annotation's constructor argument.
+  private[mxnet] def macroTransform(annottees: Any*) = macro 
SymbolMacro.addDefs
 }
 
 private[mxnet] class AddSymbolAPIs(isContrib: Boolean) extends 
StaticAnnotation {
-  private[mxnet] def macroTransform(annottees: Any*) = macro 
SymbolImplMacros.typeSafeAPIDefs
+  // Expands the annotated class/object via TypedSymbolAPIMacro.typeSafeAPIDefs, which
+  // reads isContrib back from this annotation's constructor argument.
+  private[mxnet] def macroTransform(annottees: Any*) = macro 
TypedSymbolAPIMacro.typeSafeAPIDefs
 }
 
-private[mxnet] object SymbolImplMacros {
-  case class SymbolArg(argName: String, argType: String, isOptional : Boolean)
-  case class SymbolFunction(name: String, listOfArgs: List[SymbolArg])
+/**
+ * Generates the untyped Symbol operator methods added by the `@AddSymbolFunctions` annotation.
+ */
+private[mxnet] object SymbolMacro extends GeneratorBase {
 
-  // scalastyle:off havetype
-  def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    impl(c)(annottees: _*)
-  }
-  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    typedAPIImpl(c)(annottees: _*)
-  }
-  // scalastyle:on havetype
-
-  private val symbolFunctions: List[SymbolFunction] = initSymbolModule()
-
-  /**
-    * Implementation for fixed input API structure
-    */
-  private def impl(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] 
= {
+  /** Macro entry point: reads `isContrib` from the annotation and expands the annottee. */
+  def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
-
     val isContrib: Boolean = c.prefix.tree match {
       case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b))
     }
 
-    val newSymbolFunctions = {
-      if (isContrib) symbolFunctions.filter(
-        func => func.name.startsWith("_contrib_") || 
!func.name.startsWith("_"))
-      else symbolFunctions.filter(!_.name.startsWith("_"))
-    }
+    impl(c)(isContrib, annottees: _*)
+  }
+
+  /**
+   * Appends to the annottee, for each operator returned by functionsToGenerate, a method
+   * `op(name, attr)(args: Symbol*)(kwargs)` that forwards to createSymbolGeneral.
+   */
+  private def impl(c: blackbox.Context)
+                  (isContrib: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = 
{
+    import c.universe._
 
+    // These are Symbol operators: pass isSymbol = true, as TypedSymbolAPIMacro does
+    // (only `.name` is read below, so the generated signatures are unaffected).
+    val functions = functionsToGenerate(isSymbol = true, isContrib)
 
-    val functionDefs = newSymbolFunctions map { symbolfunction =>
-        val funcName = symbolfunction.name
-        val tName = TermName(funcName)
-        q"""
+    val functionDefs = functions.map { symbolFunction =>
+      val funcName = symbolFunction.name
+      val tName = TermName(funcName)
+      q"""
             def $tName(name : String = null, attr : Map[String, String] = null)
-            (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null)
-             : org.apache.mxnet.Symbol = {
-              createSymbolGeneral($funcName,name,attr,args,kwargs)
-              }
+              (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = 
null)
+              : org.apache.mxnet.Symbol = {
+                createSymbolGeneral($funcName,name,attr,args,kwargs)
+            }
          """.asInstanceOf[DefDef]
-      }
+    }
 
-    structGeneration(c)(functionDefs, annottees : _*)
+    structGeneration(c)(functionDefs, annottees: _*)
   }
+}
 
-  /**
-    * Implementation for Dynamic typed API Symbol.api.<functioname>
-    */
-  private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : 
c.Expr[Any] = {
-    import c.universe._
+/**
+ * Generates the type-safe Symbol operator methods added by the `@AddSymbolAPIs` annotation.
+ */
+private[mxnet] object TypedSymbolAPIMacro extends GeneratorBase {
 
+  /** Macro entry point: reads `isContrib` from the annotation and adds one typed def per op. */
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*): 
c.Expr[Any] = {
+    import c.universe._
     val isContrib: Boolean = c.prefix.tree match {
       case q"new AddSymbolAPIs($b)" => c.eval[Boolean](c.Expr(b))
     }
 
-    // Defines Operators that should not generated
-    val notGenerated = Set("Custom")
-
-    // TODO: Put Symbol.api.foo --> Stable APIs
-    // Symbol.contrib.bar--> Contrib APIs
-    val newSymbolFunctions = {
-      if (isContrib) symbolFunctions.filter(
-        func => func.name.startsWith("_contrib_") || 
!func.name.startsWith("_"))
-      else symbolFunctions.filter(!_.name.startsWith("_"))
-    }.filterNot(ele => notGenerated.contains(ele.name))
-
-    val functionDefs = newSymbolFunctions map { symbolfunction =>
-
-      // Construct argument field
-      var argDef = ListBuffer[String]()
-      // Construct Implementation field
-      var impl = ListBuffer[String]()
-      impl += "val map = scala.collection.mutable.Map[String, Any]()"
-      impl += "var args = Seq[org.apache.mxnet.Symbol]()"
-      symbolfunction.listOfArgs.foreach({ symbolarg =>
-        // var is a special word used to define variable in Scala,
-        // need to changed to something else in order to make it work
-        val currArgName = symbolarg.argName match {
-          case "var" => "vari"
-          case "type" => "typeOf"
-          case default => symbolarg.argName
-        }
-        if (symbolarg.isOptional) {
-          argDef += s"${currArgName} : Option[${symbolarg.argType}] = None"
-        }
-        else {
-          argDef += s"${currArgName} : ${symbolarg.argType}"
-        }
-        // Symbol arg implementation
-        val returnType = "org.apache.mxnet.Symbol"
-        val base =
-        if (symbolarg.argType.equals(s"Array[$returnType]")) {
-          if (symbolarg.isOptional) s"if (!$currArgName.isEmpty) args = 
$currArgName.get.toSeq"
-          else s"args = $currArgName.toSeq"
-        } else {
-          if (symbolarg.isOptional) {
-            // scalastyle:off
-            s"if (!$currArgName.isEmpty) map(" + "\"" + symbolarg.argName + 
"\"" + s") = $currArgName.get"
-            // scalastyle:on
-          }
-          else "map(\"" + symbolarg.argName + "\"" + s") = $currArgName"
-        }
+    val functions = typeSafeFunctionsToGenerate(isSymbol = true, isContrib)
 
-        impl += base
-      })
-      argDef += "name : String = null"
-      argDef += "attr : Map[String, String] = null"
-      // scalastyle:off
-      // TODO: Seq() here allows user to place Symbols rather than normal 
arguments to run, need to fix if old API deprecated
-      impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + 
symbolfunction.name + "\", name, attr, args, map.toMap)"
-      // scalastyle:on
-      // Combine and build the function string
-      val returnType = "org.apache.mxnet.Symbol"
-      var finalStr = s"def ${symbolfunction.name}"
-      finalStr += s" (${argDef.mkString(",")}) : $returnType"
-      finalStr += s" = {${impl.mkString("\n")}}"
-      c.parse(finalStr).asInstanceOf[DefDef]
-    }
-    structGeneration(c)(functionDefs, annottees : _*)
+    val functionDefs = functions.map(f => buildTypedFunction(c)(f))
+    structGeneration(c)(functionDefs, annottees: _*)
   }
 
-  /**
-    * Generate class structure for all function APIs
-    * @param c
-    * @param funcDef DefDef type of function definitions
-    * @param annottees
-    * @return
-    */
-  private def structGeneration(c: blackbox.Context)
-                              (funcDef : List[c.universe.DefDef], annottees: 
c.Expr[Any]*)
-  : c.Expr[Any] = {
+  /**
+   * Builds one typed operator method from its source text, roughly
+   * `def op(<op args>, name : String = null, attr : Map[String, String] = null) : Symbol`.
+   * Array[Symbol] arguments become the positional `args`; all other arguments go into a
+   * kwargs map; the call is delegated to Symbol.createSymbolGeneral.
+   */
+  protected def buildTypedFunction(c: blackbox.Context)
+                                  (function: Func): c.universe.DefDef = {
     import c.universe._
-    val inputs = annottees.map(_.tree).toList
-    // pattern match on the inputs
-    val modDefs = inputs map {
-      case ClassDef(mods, name, something, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ClassDef(mods, name, something, q)
-      case ModuleDef(mods, name, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ModuleDef(mods, name, q)
-      case ex =>
-        throw new IllegalArgumentException(s"Invalid macro input: $ex")
-    }
-    // wrap the result up in an Expr, and return it
-    val result = c.Expr(Block(modDefs, Literal(Constant())))
-    result
-  }
 
-  // List and add all the atomic symbol functions to current module.
-  private def initSymbolModule(): List[SymbolFunction] = {
-    val opNames = ListBuffer.empty[String]
-    _LIB.mxListAllOpNames(opNames)
-    // TODO: Add '_linalg_', '_sparse_', '_image_' support
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeAtomicSymbolFunction(opHandle.value, opName)
-    }).toList
-  }
+    val returnType = "org.apache.mxnet.Symbol"
+    val symbolType = "org.apache.mxnet.Symbol"
 
-  // Create an atomic symbol function by handle and function name.
-  private def makeAtomicSymbolFunction(handle: SymbolHandle, aliasName: String)
-      : SymbolFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
-
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val paramStr = OperatorBuildUtils.ctypes2docstring(argNames, argTypes, 
argDescs)
-    val extraDoc: String = if (keyVarNumArgs.value != null && 
keyVarNumArgs.value.length > 0) {
-        s"This function support variable length of positional input 
(${keyVarNumArgs.value})."
+    // Construct argument field
+    val argDef = ListBuffer[String]()
+    argDef ++= typedFunctionCommonArgDef(function)
+    argDef += "name : String = null"
+    argDef += "attr : Map[String, String] = null"
+
+    // Construct Implementation field
+    val impl = ListBuffer[String]()
+    impl += "val map = scala.collection.mutable.Map[String, Any]()"
+    impl += s"var args = scala.collection.Seq[$symbolType]()"
+
+    // Symbol arg implementation. An Array[Symbol] argument is *assigned* to `args`, so if an
+    // op had several of them, the last non-empty one would win.
+    impl ++= function.listOfArgs.map { arg =>
+      if (arg.argType.equals(s"Array[$symbolType]")) {
+        s"if (!${arg.safeArgName}.isEmpty) args = ${arg.safeArgName}.toSeq"
       } else {
-        ""
+        // every non-Array[Symbol] argument goes into the kwargs map
+        if (arg.isOptional) {
+          s"""if (!${arg.safeArgName}.isEmpty) map("${arg.argName}") = 
${arg.safeArgName}.get"""
+        } else {
+          s"""map("${arg.argName}") = ${arg.safeArgName}"""
+        }
       }
-    val realName = if (aliasName == name.value) "" else s"(a.k.a., 
${name.value})"
-    val docStr = s"$aliasName 
$realName\n${desc.value}\n\n$paramStr\n$extraDoc\n"
-    // scalastyle:off println
-    if (System.getenv("MXNET4J_PRINT_OP_DEF") != null
-          && System.getenv("MXNET4J_PRINT_OP_DEF").toLowerCase == "true") {
-      println("Symbol function definition:\n" + docStr)
     }
-    // scalastyle:on println
-    val argList = argNames zip argTypes map { case (argName, argType) =>
-        val typeAndOption =
-          CToScalaUtils.argumentCleaner(argName, argType, 
"org.apache.mxnet.Symbol")
-        new SymbolArg(argName, typeAndOption._1, typeAndOption._2)
-    }
-    new SymbolFunction(aliasName, argList.toList)
+
+    impl +=
+      s"""org.apache.mxnet.Symbol.createSymbolGeneral(
+         |  "${function.name}", name, attr, args, map.toMap)
+       """.stripMargin
+
+    // Combine and build the function string
+    val finalStr =
+      s"""def ${function.name}
+         |   (${argDef.mkString(",")}) : $returnType
+         | = {${impl.mkString("\n")}}
+       """.stripMargin
+
+    c.parse(finalStr).asInstanceOf[DefDef]
   }
 }
diff --git 
a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
 
b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
index 2d1827038af..4dfd6eb044a 100644
--- 
a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
+++ 
b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
@@ -17,8 +17,7 @@
 
 package org.apache.mxnet.javaapi
 
-import org.apache.mxnet.init.Base._
-import org.apache.mxnet.utils.CToScalaUtils
+import org.apache.mxnet.GeneratorBase
 
 import scala.annotation.StaticAnnotation
 import scala.collection.mutable.ListBuffer
@@ -29,9 +28,7 @@ private[mxnet] class AddJNDArrayAPIs(isContrib: Boolean) 
extends StaticAnnotatio
   private[mxnet] def macroTransform(annottees: Any*) = macro 
JavaNDArrayMacro.typeSafeAPIDefs
 }
 
-private[mxnet] object JavaNDArrayMacro {
-  case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
-  case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
+private[mxnet] object JavaNDArrayMacro extends GeneratorBase {
 
   // scalastyle:off havetype
   def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
@@ -39,8 +36,6 @@ private[mxnet] object JavaNDArrayMacro {
   }
   // scalastyle:off havetype
 
-  private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
-
   private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : 
c.Expr[Any] = {
     import c.universe._
 
@@ -50,12 +45,13 @@ private[mxnet] object JavaNDArrayMacro {
     // Defines Operators that should not generated
     val notGenerated = Set("Custom")
 
-    val newNDArrayFunctions = {
-      if (isContrib) ndarrayFunctions.filter(
-        func => func.name.startsWith("_contrib_") || 
!func.name.startsWith("_"))
-      else ndarrayFunctions.filterNot(_.name.startsWith("_"))
-    }.filterNot(ele => 
notGenerated.contains(ele.name)).groupBy(_.name.toLowerCase).map(ele => {
-      // Pattern matching for not generating depreciated method
+    val newNDArrayFunctions = functionsToGenerate(false, false, true)
+      .filterNot(ele => 
notGenerated.contains(ele.name)).groupBy(_.name.toLowerCase).map(ele => {
+      /* Pattern matching for not generating deprecated method
+       * Group all method name in lowercase
+       * Kill the capital lettered method such as Cast vs cast
+       * As it defined by default it deprecated
+       */
       if (ele._2.length == 1) ele._2.head
       else {
         if (ele._2.head.name.head.isLower) ele._2.head
@@ -79,11 +75,7 @@ private[mxnet] object JavaNDArrayMacro {
       ndarrayfunction.listOfArgs.foreach({ ndarrayArg =>
         // var is a special word used to define variable in Scala,
         // need to changed to something else in order to make it work
-        var currArgName = ndarrayArg.argName match {
-          case "var" => "vari"
-          case "type" => "typeOf"
-          case _ => ndarrayArg.argName
-        }
+        var currArgName = ndarrayArg.safeArgName
         if (useParamObject) currArgName = s"po.get${currArgName.capitalize}()"
         argDef += s"$currArgName : ${ndarrayArg.argType}"
         // NDArray arg implementation
@@ -128,73 +120,6 @@ private[mxnet] object JavaNDArrayMacro {
         functionDefs += c.parse(funcDef).asInstanceOf[DefDef]
       }
     }
-
     structGeneration(c)(functionDefs.toList, annottees : _*)
   }
-
-  private def structGeneration(c: blackbox.Context)
-                              (funcDef : List[c.universe.DefDef],
-                               annottees: c.Expr[Any]*)
-  : c.Expr[Any] = {
-    import c.universe._
-    val inputs = annottees.map(_.tree).toList
-    // pattern match on the inputs
-    var modDefs = inputs map {
-      case ClassDef(mods, name, something, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ClassDef(mods, name, something, q)
-      case ModuleDef(mods, name, template) =>
-        val q = template match {
-          case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef)
-          case ex =>
-            throw new IllegalArgumentException(s"Invalid template: $ex")
-        }
-        ModuleDef(mods, name, q)
-      case ex =>
-        throw new IllegalArgumentException(s"Invalid macro input: $ex")
-    }
-    //    modDefs ++= classDef
-    // wrap the result up in an Expr, and return it
-    val result = c.Expr(Block(modDefs, Literal(Constant())))
-    result
-  }
-
-  // List and add all the atomic symbol functions to current module.
-  private def initNDArrayModule(): List[NDArrayFunction] = {
-    val opNames = ListBuffer.empty[String]
-    _LIB.mxListAllOpNames(opNames)
-    opNames.map(opName => {
-      val opHandle = new RefLong
-      _LIB.nnGetOpHandle(opName, opHandle)
-      makeNDArrayFunction(opHandle.value, opName)
-    }).toList
-  }
-
-  // Create an atomic symbol function by handle and function name.
-  private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
-  : NDArrayFunction = {
-    val name = new RefString
-    val desc = new RefString
-    val keyVarNumArgs = new RefString
-    val numArgs = new RefInt
-    val argNames = ListBuffer.empty[String]
-    val argTypes = ListBuffer.empty[String]
-    val argDescs = ListBuffer.empty[String]
-
-    _LIB.mxSymbolGetAtomicSymbolInfo(
-      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
-    val argList = argNames zip argTypes map { case (argName, argType) =>
-      val typeAndOption =
-        CToScalaUtils.argumentCleaner(argName, argType,
-          "org.apache.mxnet.javaapi.NDArray")
-      new NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
-    }
-    new NDArrayFunction(aliasName, argList.toList)
-  }
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to