This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 74479b8  [MXNET-386] ongoing maintenance on NDArray (#11126)
74479b8 is described below

commit 74479b89eaba8241573079aa5e32f0ba0f8dd00e
Author: Lanking <lanking...@live.com>
AuthorDate: Tue Jun 12 17:06:13 2018 -0700

    [MXNET-386] ongoing maintenance on NDArray (#11126)
    
    * Important ndarray feature
    
    * merge generic function Invoke
    
    * Pass the Scala Style test
    
    * add Experimental tags
    
    * Change with NDArgs addition
    
    * Move the Experimental tag into the annotation package
    
    * Retrigger CI
    
    * add Symbol Macros change
    
    * Add some workarounds for NDArray
    
    * Simplify the base part
    
    * Add changes to NDArray and Symbol
    
    * avoid vars
    
    * add Symbol Macros
    
    * Trigger the CI
    
    * Trigger CI
---
 .../src/main/scala/org/apache/mxnet/NDArray.scala  | 17 +++++---
 .../org/apache/mxnet/annotation/Experimental.scala | 25 ++++++++++++
 .../scala/org/apache/mxnet/APIDocGenerator.scala   |  7 +++-
 .../main/scala/org/apache/mxnet/NDArrayMacro.scala | 47 +++++++++++++---------
 .../main/scala/org/apache/mxnet/SymbolMacro.scala  | 24 ++++++++---
 5 files changed, 87 insertions(+), 33 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 469107a..49f4d35 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -65,12 +65,12 @@ object NDArray {
     val ndArgs = ArrayBuffer.empty[NDArray]
     val posArgs = ArrayBuffer.empty[String]
     args.foreach {
-      case arr: NDArray =>
-        ndArgs.append(arr)
-      case arrFunRet: NDArrayFuncReturn =>
-        arrFunRet.arr.foreach(ndArgs.append(_))
-      case arg =>
-        posArgs.append(arg.toString)
+        case arr: NDArray =>
+          ndArgs.append(arr)
+        case arrFunRet: NDArrayFuncReturn =>
+          arrFunRet.arr.foreach(ndArgs.append(_))
+        case arg =>
+          posArgs.append(arg.toString)
     }
 
     require(posArgs.length <= function.arguments.length,
@@ -81,6 +81,7 @@ object NDArray {
         ++ function.arguments.slice(0, posArgs.length).zip(posArgs) - "out"
       ).map { case (k, v) => k -> v.toString }
 
+
     val (oriOutputs, outputVars) =
       if (kwargs != null && kwargs.contains("out")) {
         val output = kwargs("out")
@@ -537,6 +538,10 @@ object NDArray {
     new NDArray(handleRef.value)
   }
 
+  private def _crop_assign(kwargs: Map[String, Any] = null)(args: Any*) : 
NDArrayFuncReturn = {
+    genericNDArrayFunctionInvoke("_crop_assign", args, kwargs)
+  }
+
   // TODO: imdecode
 }
 
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/annotation/Experimental.scala
 
b/scala-package/core/src/main/scala/org/apache/mxnet/annotation/Experimental.scala
new file mode 100644
index 0000000..33d1d33
--- /dev/null
+++ 
b/scala-package/core/src/main/scala/org/apache/mxnet/annotation/Experimental.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.annotation
+
+import java.lang.annotation.{ElementType, Retention, Target, _}
+
+@Retention(RetentionPolicy.RUNTIME)
+@Target(Array(ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, 
ElementType.PARAMETER,
+  ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE))
+class Experimental {}
diff --git 
a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala 
b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index 90fe260..3bbc7fd 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -52,8 +52,9 @@ private[mxnet] object APIDocGenerator{
     val apacheLicence = "/*\n* Licensed to the Apache Software Foundation 
(ASF) under one or more\n* contributor license agreements.  See the NOTICE file 
distributed with\n* this work for additional information regarding copyright 
ownership.\n* The ASF licenses this file to You under the Apache License, 
Version 2.0\n* (the \"License\"); you may not use this file except in 
compliance with\n* the License.  You may obtain a copy of the License at\n*\n*  
  http://www.apache.org/licenses/LICE [...]
     val scalaStyle = "// scalastyle:off"
     val packageDef = "package org.apache.mxnet"
+    val imports = "import org.apache.mxnet.annotation.Experimental"
     val absClassDef = s"abstract class $packageName"
-    val finalStr = s"$apacheLicence\n$scalaStyle\n$packageDef\n$absClassDef 
{\n${absFuncs.mkString("\n")}\n}"
+    val finalStr = 
s"$apacheLicence\n$scalaStyle\n$packageDef\n$imports\n$absClassDef 
{\n${absFuncs.mkString("\n")}\n}"
     import java.io._
     val pw = new PrintWriter(new File(FILE_PATH + s"$packageName.scala"))
     pw.write(finalStr)
@@ -97,9 +98,11 @@ private[mxnet] object APIDocGenerator{
       argDef += "name : String = null"
       argDef += "attr : Map[String, String] = null"
     } else {
+      argDef += "out : Option[NDArray] = None"
       returnType = "org.apache.mxnet.NDArrayFuncReturn"
     }
-    s"def ${func.name} (${argDef.mkString(", ")}) : ${returnType}"
+    val experimentalTag = "@Experimental"
+    s"$experimentalTag\ndef ${func.name} (${argDef.mkString(", ")}) : 
$returnType"
   }
 
 
diff --git 
a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala 
b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index ce5b532..082c64a 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -21,7 +21,7 @@ import org.apache.mxnet.init.Base._
 import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
 
 import scala.annotation.StaticAnnotation
-import scala.collection.mutable.ListBuffer
+import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 import scala.language.experimental.macros
 import scala.reflect.macros.blackbox
 
@@ -57,14 +57,13 @@ private[mxnet] object NDArrayMacro {
 
     val newNDArrayFunctions = {
       if (isContrib) ndarrayFunctions.filter(_.name.startsWith("_contrib_"))
-      else ndarrayFunctions.filter(!_.name.startsWith("_contrib_"))
+      else ndarrayFunctions.filterNot(_.name.startsWith("_"))
     }
 
      val functionDefs = newNDArrayFunctions flatMap { NDArrayfunction =>
         val funcName = NDArrayfunction.name
         val termName = TermName(funcName)
-        if (!NDArrayfunction.name.startsWith("_") || 
NDArrayfunction.name.startsWith("_contrib_")) {
-          Seq(
+       Seq(
             // scalastyle:off
             // (yizhi) We are investigating a way to make these functions 
type-safe
             // and waiting to see the new approach is stable enough.
@@ -75,16 +74,7 @@ private[mxnet] object NDArrayMacro {
             q"def $termName(args: Any*) = 
{genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
             // scalastyle:on
           )
-        } else {
-          // Default private
-          Seq(
-            // scalastyle:off
-            q"private def $termName(kwargs: Map[String, Any] = null)(args: 
Any*) = {genericNDArrayFunctionInvoke($funcName, args, 
kwargs)}".asInstanceOf[DefDef],
-            q"private def $termName(args: Any*) = 
{genericNDArrayFunctionInvoke($funcName, args, null)}".asInstanceOf[DefDef]
-            // scalastyle:on
-          )
         }
-      }
 
     structGeneration(c)(functionDefs, annottees : _*)
   }
@@ -109,6 +99,7 @@ private[mxnet] object NDArrayMacro {
       // Construct Implementation field
       var impl = ListBuffer[String]()
       impl += "val map = scala.collection.mutable.Map[String, Any]()"
+      impl += "val args = scala.collection.mutable.ArrayBuffer.empty[NDArray]"
       ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
         // var is a special word used to define variable in Scala,
         // need to changed to something else in order to make it work
@@ -123,14 +114,32 @@ private[mxnet] object NDArrayMacro {
         else {
           argDef += s"${currArgName} : ${ndarrayarg.argType}"
         }
-        var base = "map(\"" + ndarrayarg.argName + "\") = " + currArgName
-        if (ndarrayarg.isOptional) {
-          base = "if (!" + currArgName + ".isEmpty)" + base + ".get"
-        }
-        impl += base
+        // NDArray arg implementation
+        val returnType = "org.apache.mxnet.NDArray"
+
+        // TODO: Currently we do not add place holder for NDArray
+        // Example: an NDArray operator like the following format
+        // nd.foo(arg1: NDArray(required), arg2: NDArray(Optional), arg3: 
NDArray(Optional)
+        // If we place nd.foo(arg1, arg3 = arg3), do we need to add place 
holder for arg2?
+        // What it should be?
+        val base =
+          if (ndarrayarg.argType.equals(returnType)) {
+            s"args += $currArgName"
+          } else if (ndarrayarg.argType.equals(s"Array[$returnType]")){
+            s"args ++= $currArgName"
+          } else {
+            "map(\"" + ndarrayarg.argName + "\") = " + currArgName
+          }
+        impl.append(
+          if (ndarrayarg.isOptional) s"if (!$currArgName.isEmpty) $base.get"
+          else base
+        )
       })
+      // add default out parameter
+      argDef += "out : Option[NDArray] = None"
+      impl += "if (!out.isEmpty) map(\"out\") = out.get"
       // scalastyle:off
-      impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + 
ndarrayfunction.name + "\", null, map.toMap)"
+      impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" + 
ndarrayfunction.name + "\", args.toSeq, map.toMap)"
       // scalastyle:on
       // Combine and build the function string
       val returnType = "org.apache.mxnet.NDArrayFuncReturn"
diff --git 
a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala 
b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
index bacbdb2..81430c2 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
@@ -41,7 +41,7 @@ private[mxnet] object SymbolImplMacros {
     impl(c)(annottees: _*)
   }
   def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    newAPIImpl(c)(annottees: _*)
+    typedAPIImpl(c)(annottees: _*)
   }
   // scalastyle:on havetype
 
@@ -82,7 +82,7 @@ private[mxnet] object SymbolImplMacros {
   /**
     * Implementation for Dynamic typed API Symbol.api.<functioname>
     */
-  private def newAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : 
c.Expr[Any] = {
+  private def typedAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : 
c.Expr[Any] = {
     import c.universe._
 
     val isContrib: Boolean = c.prefix.tree match {
@@ -104,6 +104,7 @@ private[mxnet] object SymbolImplMacros {
       // Construct Implementation field
       var impl = ListBuffer[String]()
       impl += "val map = scala.collection.mutable.Map[String, Any]()"
+      impl += "var args = Seq[org.apache.mxnet.Symbol]()"
       symbolfunction.listOfArgs.foreach({ symbolarg =>
         // var is a special word used to define variable in Scala,
         // need to changed to something else in order to make it work
@@ -118,17 +119,28 @@ private[mxnet] object SymbolImplMacros {
         else {
           argDef += s"${currArgName} : ${symbolarg.argType}"
         }
-        var base = "map(\"" + symbolarg.argName + "\") = " + currArgName
-        if (symbolarg.isOptional) {
-          base = "if (!" + currArgName + ".isEmpty)" + base + ".get"
+        // Symbol arg implementation
+        val returnType = "org.apache.mxnet.Symbol"
+        val base =
+        if (symbolarg.argType.equals(s"Array[$returnType]")) {
+          if (symbolarg.isOptional) s"if (!$currArgName.isEmpty) args = 
$currArgName.get.toSeq"
+          else s"args = $currArgName.toSeq"
+        } else {
+          if (symbolarg.isOptional) {
+            // scalastyle:off
+            s"if (!$currArgName.isEmpty) map(" + "\"" + symbolarg.argName + 
"\"" + s") = $currArgName.get"
+            // scalastyle:on
+          }
+          else "map(\"" + symbolarg.argName + "\"" + s") = $currArgName"
         }
+
         impl += base
       })
       argDef += "name : String = null"
       argDef += "attr : Map[String, String] = null"
       // scalastyle:off
       // TODO: Seq() here allows user to place Symbols rather than normal 
arguments to run, need to fix if old API deprecated
-      impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + 
symbolfunction.name + "\", name, attr, Seq(), map.toMap)"
+      impl += "org.apache.mxnet.Symbol.createSymbolGeneral(\"" + 
symbolfunction.name + "\", name, attr, args, map.toMap)"
       // scalastyle:on
       // Combine and build the function string
       val returnType = "org.apache.mxnet.Symbol"

-- 
To stop receiving notification emails like this one, please contact
liuyi...@apache.org.

Reply via email to