nswamy closed pull request #13159: [MXNET-1202] Change Builder class into a better way
URL: https://github.com/apache/incubator-mxnet/pull/13159
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it after the merge:

diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index d4e67f73408..cdcc292ada6 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -385,17 +385,3 @@ class NDArray(val nd : org.apache.mxnet.NDArray ) {
   override def equals(obj: Any): Boolean = nd.equals(obj)
   override def hashCode(): Int = nd.hashCode
 }
-
-object NDArrayFuncReturn {
-  implicit def toNDFuncReturn(javaFunReturn : NDArrayFuncReturn)
-  : org.apache.mxnet.NDArrayFuncReturn = javaFunReturn.ndFuncReturn
-  implicit def toJavaNDFuncReturn(ndFuncReturn : 
org.apache.mxnet.NDArrayFuncReturn)
-  : NDArrayFuncReturn = new NDArrayFuncReturn(ndFuncReturn)
-}
-
-private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : 
org.apache.mxnet.NDArrayFuncReturn) {
-  def head : NDArray = ndFuncReturn.head
-  def get : NDArray = ndFuncReturn.get
-  def apply(i : Int) : NDArray = ndFuncReturn.apply(i)
-  // TODO: Add JavaNDArray operational stuff
-}
diff --git 
a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java 
b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
index a9bad83f62d..2659b7848bc 100644
--- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -19,9 +19,9 @@
 
 import org.junit.Test;
 
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import org.apache.mxnet.javaapi.NDArrayBase.*;
 
 import static org.junit.Assert.assertTrue;
 
@@ -71,7 +71,7 @@ public void testGenerated(){
         NDArray$ NDArray = NDArray$.MODULE$;
         float[] arr = new float[]{1.0f, 2.0f, 3.0f};
         NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new 
Context("cpu", 0));
-        float result = NDArray.norm(nd).invoke().get().toArray()[0];
+        float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0];
         float cal = 0.0f;
         for (float ele : arr) {
             cal += ele * ele;
@@ -79,7 +79,7 @@ public void testGenerated(){
         cal = (float) Math.sqrt(cal);
         assertTrue(Math.abs(result - cal) < 1e-5);
         NDArray dotResult = new NDArray(new float[]{0}, new Shape(new 
int[]{1}), new Context("cpu", 0));
-        NDArray.dot(nd, nd).setout(dotResult).invoke().get();
+        NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult));
         assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
     }
 }
diff --git 
a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala 
b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index 44d47a2099d..f2326868e8e 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -116,9 +116,7 @@ private[mxnet] object APIDocGenerator{
     val absFuncs = absClassFunctions.filterNot(_.name.startsWith("_"))
       .filterNot(ele => notGenerated.contains(ele.name))
       .map(absClassFunction => {
-        val scalaDoc = generateAPIDocFromBackend(absClassFunction)
-        val defBody = generateJavaAPISignature(absClassFunction)
-        s"$scalaDoc\n$defBody"
+        generateJavaAPISignature(absClassFunction)
       })
     val packageName = "NDArrayBase"
     val packageDef = "package org.apache.mxnet.javaapi"
@@ -203,27 +201,61 @@ private[mxnet] object APIDocGenerator{
   }
 
   def generateJavaAPISignature(func : absClassFunction) : String = {
+    val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
     var argDef = ListBuffer[String]()
     var classDef = ListBuffer[String]()
+    var requiredParam = ListBuffer[String]()
     func.listOfArgs.foreach(absClassArg => {
       val currArgName = safetyNameCheck(absClassArg.argName)
       // scalastyle:off
-      if (absClassArg.isOptional) {
-        classDef += s"def set${absClassArg.argName}(${absClassArg.argName} : 
${absClassArg.argType}) : ${func.name}BuilderBase"
+      if (absClassArg.isOptional && useParamObject) {
+        classDef +=
+          s"""private var $currArgName: ${absClassArg.argType} = null
+             |/**
+             | * @param $currArgName\t\t${absClassArg.argDesc}
+             | */
+             |def set${currArgName.capitalize}($currArgName : 
${absClassArg.argType}): ${func.name}Param = {
+             |  this.$currArgName = $currArgName
+             |  this
+             | }""".stripMargin
       }
       else {
+        requiredParam += s"  * @param $currArgName\t\t${absClassArg.argDesc}"
         argDef += s"$currArgName : ${absClassArg.argType}"
       }
+      classDef += s"def get${currArgName.capitalize}() = this.$currArgName"
       // scalastyle:on
     })
-    classDef += s"def setout(out : NDArray) : ${func.name}BuilderBase"
-    classDef += s"def invoke() : org.apache.mxnet.javaapi.NDArrayFuncReturn"
     val experimentalTag = "@Experimental"
-    // scalastyle:off
-    var finalStr = s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", 
")}) : ${func.name}BuilderBase\n"
-    // scalastyle:on
-    finalStr += s"abstract class ${func.name}BuilderBase {\n  
${classDef.mkString("\n  ")}\n}"
-    finalStr
+    val returnType = "Array[NDArray]"
+    val scalaDoc = generateAPIDocFromBackend(func)
+    val scalaDocNoParam = generateAPIDocFromBackend(func, false)
+    if(useParamObject) {
+      classDef +=
+        s"""private var out : org.apache.mxnet.NDArray = null
+           |def setOut(out : NDArray) : ${func.name}Param = {
+           |  this.out = out
+           |  this
+           | }
+           | def getOut() = this.out
+           | """.stripMargin
+      s"""$scalaDocNoParam
+          | $experimentalTag
+          | def ${func.name}(po: ${func.name}Param) : $returnType
+          | /**
+          | * This Param Object is specifically used for ${func.name}
+          | ${requiredParam.mkString("\n")}
+          | */
+          | class ${func.name}Param(${argDef.mkString(",")}) {
+          |  ${classDef.mkString("\n  ")}
+          | }""".stripMargin
+    } else {
+      argDef += "out : NDArray"
+      s"""$scalaDoc
+         |$experimentalTag
+         | def ${func.name}(${argDef.mkString(", ")}) : $returnType
+         | """.stripMargin
+    }
   }
 
 
diff --git 
a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
 
b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
index d5be97b501c..2d1827038af 100644
--- 
a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
+++ 
b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
@@ -68,18 +68,14 @@ private[mxnet] object JavaNDArrayMacro {
 
     newNDArrayFunctions.foreach { ndarrayfunction =>
 
+      val useParamObject = ndarrayfunction.listOfArgs.count(arg => 
arg.isOptional) >= 2
       // Construct argument field with all required args
       var argDef = ListBuffer[String]()
-      // Construct Optional Arg
-      var OptionArgDef = ListBuffer[String]()
       // Construct function Implementation field (e.g norm)
       var impl = ListBuffer[String]()
       impl += "val map = scala.collection.mutable.Map[String, Any]()"
-      // scalastyle:off
-      impl += "val args= 
scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]"
-      // scalastyle:on
-      // Construct Class Implementation (e.g normBuilder)
-      var classImpl = ListBuffer[String]()
+      impl +=
+        "val args= 
scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]"
       ndarrayfunction.listOfArgs.foreach({ ndarrayArg =>
         // var is a special word used to define variable in Scala,
         // need to changed to something else in order to make it work
@@ -88,55 +84,56 @@ private[mxnet] object JavaNDArrayMacro {
           case "type" => "typeOf"
           case _ => ndarrayArg.argName
         }
-        if (ndarrayArg.isOptional) {
-          OptionArgDef += s"private var $currArgName : ${ndarrayArg.argType} = 
null"
-          val tempDef = s"def set$currArgName($currArgName : 
${ndarrayArg.argType})"
-          val tempImpl = s"this.$currArgName = $currArgName\nthis"
-          classImpl += s"$tempDef = {$tempImpl}"
-        } else {
-          argDef += s"$currArgName : ${ndarrayArg.argType}"
-        }
+        if (useParamObject) currArgName = s"po.get${currArgName.capitalize}()"
+        argDef += s"$currArgName : ${ndarrayArg.argType}"
         // NDArray arg implementation
         val returnType = "org.apache.mxnet.javaapi.NDArray"
         val base =
           if (ndarrayArg.argType.equals(returnType)) {
-            s"args += this.$currArgName"
+            s"args += $currArgName"
           } else if (ndarrayArg.argType.equals(s"Array[$returnType]")){
-            s"this.$currArgName.foreach(args+=_)"
+            s"$currArgName.foreach(args+=_)"
           } else {
-            "map(\"" + ndarrayArg.argName + "\") = this." + currArgName
+            "map(\"" + ndarrayArg.argName + "\") = " + currArgName
           }
         impl.append(
-          if (ndarrayArg.isOptional) s"if (this.$currArgName != null) $base"
+          if (ndarrayArg.isOptional) s"if ($currArgName != null) $base"
           else base
         )
       })
       // add default out parameter
-      classImpl +=
-        "def setout(out : org.apache.mxnet.javaapi.NDArray) = {this.out = 
out\nthis}"
-      impl += "if (this.out != null) map(\"out\") = this.out"
-      OptionArgDef += "private var out : org.apache.mxnet.NDArray = null"
-      val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn"
+      argDef += s"out: org.apache.mxnet.javaapi.NDArray"
+      if (useParamObject) {
+        impl += "if (po.getOut() != null) map(\"out\") = po.getOut()"
+      } else {
+        impl += "if (out != null) map(\"out\") = out"
+      }
+      val returnType = "Array[org.apache.mxnet.javaapi.NDArray]"
       // scalastyle:off
       // Combine and build the function string
-      impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + 
ndarrayfunction.name + "\", args.toSeq, map.toMap)"
-      val classDef = s"class 
${ndarrayfunction.name}Builder(${argDef.mkString(",")}) extends 
${ndarrayfunction.name}BuilderBase"
-      val classBody = 
s"${OptionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : 
$returnType = {${impl.mkString("\n")}}"
-      val classFinal = s"$classDef {$classBody}"
-      val functionDef = s"def ${ndarrayfunction.name} 
(${argDef.mkString(",")})"
-      val functionBody = s"new 
${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})"
-      val functionFinal = s"$functionDef : ${ndarrayfunction.name}BuilderBase 
= $functionBody"
-      // scalastyle:on
-      functionDefs += c.parse(functionFinal).asInstanceOf[DefDef]
-      classDefs += c.parse(classFinal).asInstanceOf[ClassDef]
+      impl += "val finalArr = 
org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" +
+        ndarrayfunction.name + "\", args.toSeq, map.toMap).arr"
+      impl += "finalArr.map(ele => new NDArray(ele))"
+      if (useParamObject) {
+        val funcDef =
+          s"""def ${ndarrayfunction.name}(po: ${ndarrayfunction.name}Param): 
$returnType = {
+             |  ${impl.mkString("\n")}
+             | }""".stripMargin
+        functionDefs += c.parse(funcDef).asInstanceOf[DefDef]
+      } else {
+        val funcDef =
+          s"""def ${ndarrayfunction.name}(${argDef.mkString(",")}): 
$returnType = {
+             |  ${impl.mkString("\n")}
+             | }""".stripMargin
+        functionDefs += c.parse(funcDef).asInstanceOf[DefDef]
+      }
     }
 
-    structGeneration(c)(functionDefs.toList, classDefs.toList, annottees : _*)
+    structGeneration(c)(functionDefs.toList, annottees : _*)
   }
 
   private def structGeneration(c: blackbox.Context)
                               (funcDef : List[c.universe.DefDef],
-                               classDef : List[c.universe.ClassDef],
                                annottees: c.Expr[Any]*)
   : c.Expr[Any] = {
     import c.universe._
@@ -146,7 +143,7 @@ private[mxnet] object JavaNDArrayMacro {
       case ClassDef(mods, name, something, template) =>
         val q = template match {
           case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef)
+            Template(superMaybe, emptyValDef, defs ++ funcDef)
           case ex =>
             throw new IllegalArgumentException(s"Invalid template: $ex")
         }
@@ -154,7 +151,7 @@ private[mxnet] object JavaNDArrayMacro {
       case ModuleDef(mods, name, template) =>
         val q = template match {
           case Template(superMaybe, emptyValDef, defs) =>
-            Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef)
+            Template(superMaybe, emptyValDef, defs ++ funcDef)
           case ex =>
             throw new IllegalArgumentException(s"Invalid template: $ex")
         }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to