lanking520 closed pull request #12772: [MXNET-984] Add Java NDArray and
introduce Java Operator Builder class
URL: https://github.com/apache/incubator-mxnet/pull/12772
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub would not otherwise display it:
diff --git a/Makefile b/Makefile
index a4b41b8d837..fe2df2c20af 100644
--- a/Makefile
+++ b/Makefile
@@ -606,7 +606,7 @@ scalaclean:
scalapkg:
(cd $(ROOTDIR)/scala-package; \
- mvn package -P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE)
-Dcxx="$(CXX)" \
+ mvn package
-P$(SCALA_PKG_PROFILE),$(SCALA_VERSION_PROFILE),integrationtest -Dcxx="$(CXX)" \
-Dbuild.platform="$(SCALA_PKG_PROFILE)" \
-Dcflags="$(CFLAGS)" -Dldflags="$(LDFLAGS)" \
-Dcurrent_libdir="$(ROOTDIR)/lib" \
diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index ea3a2d68c9f..6e2d8d6e9cc 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -86,7 +86,10 @@
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.0</version>
<configuration>
- <skipTests>false</skipTests>
+ <argLine>
+
-Djava.library.path=${project.parent.basedir}/native/${platform}/target
+ </argLine>
+ <skipTests>${skipTests}</skipTests>
</configuration>
</plugin>
<plugin>
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
index 5f0caedcc40..2f4f3e6409e 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/Context.scala
@@ -42,6 +42,5 @@ object Context {
val gpu: Context = org.apache.mxnet.Context.gpu()
val devtype2str = org.apache.mxnet.Context.devstr2type.asJava
val devstr2type = org.apache.mxnet.Context.devstr2type.asJava
-
def defaultCtx: Context = org.apache.mxnet.Context.defaultCtx
}
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
new file mode 100644
index 00000000000..c77b440d880
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import org.apache.mxnet.javaapi.DType.DType
+
+import collection.JavaConverters._
+
+/**
+ * Static (Java-callable) helpers for the Java-friendly NDArray.
+ * Besides the hand-written delegates below, the AddJNDArrayAPIs macro appends
+ * generated operator builders (e.g. `norm`, `dot`) to this object.
+ */
+@AddJNDArrayAPIs(false)
+object NDArray {
+  /** Wrap a Scala NDArray in the Java-friendly NDArray. */
+  implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)
+
+  /** Unwrap a Java-friendly NDArray to the Scala NDArray it holds. */
+  implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd
+
+  def waitall(): Unit = org.apache.mxnet.NDArray.waitall()
+
+  def onehotEncode(indices: NDArray, out: NDArray): NDArray
+  = org.apache.mxnet.NDArray.onehotEncode(indices, out)
+
+  // Creation helpers. The (ctx, Array[Int]) and (ctx, java.util.List) overloads
+  // build a Shape and omit dtype, so the Scala API's default dtype applies.
+  def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+  = org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
+  def empty(ctx: Context, shape: Array[Int]): NDArray
+  = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+  def empty(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+  = org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+  def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+  = org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
+  def zeros(ctx: Context, shape: Array[Int]): NDArray
+  = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+  def zeros(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+  = org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+  def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray
+  = org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
+  def ones(ctx: Context, shape: Array[Int]): NDArray
+  = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+  def ones(ctx : Context, shape : java.util.List[java.lang.Integer]) : NDArray
+  = org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+  def full(shape: Shape, value: Float, ctx: Context): NDArray
+  = org.apache.mxnet.NDArray.full(shape, value, ctx)
+
+  // Element-wise binary operations; each overload delegates to the Scala API.
+  def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+  def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+  def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+
+  def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+  def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+  def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+
+  def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+  def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+  def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+
+  // Element-wise comparisons; each overload delegates to the Scala API.
+  def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+  def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+
+  def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+  def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+
+  def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+  def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+
+  def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray
+  = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+  def greaterEqual(lhs: NDArray, rhs: Float): NDArray
+  = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+
+  def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+  def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+
+  def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray
+  = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+  def lesserEqual(lhs: NDArray, rhs: Float): NDArray
+  = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+
+  /**
+   * Create an NDArray from a Java list of floats.
+   *
+   * @param sourceArr values to copy into the new array (elements must be non-null)
+   * @param shape     shape of the new array
+   * @param ctx       device context; null is forwarded unchanged to
+   *                  org.apache.mxnet.NDArray.array (presumably meaning its default
+   *                  context — confirm against the Scala API)
+   */
+  def array(sourceArr: java.util.List[java.lang.Float], shape: Shape,
+            ctx: Context = null): NDArray = {
+    // Convert explicitly: feeding a null javaapi Context through the implicit
+    // javaapi->Scala conversion would dereference it, which made the `null`
+    // default unusable.
+    val scalaCtx: org.apache.mxnet.Context = if (ctx == null) null else ctx
+    org.apache.mxnet.NDArray.array(
+      sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, scalaCtx)
+  }
+
+  /** Evenly spaced values in [start, stop) with the given step and repeat count. */
+  def arange(start: Float, stop: Float, step: Float, repeat: Int,
+             ctx: Context, dType: DType.DType): NDArray =
+    org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType)
+}
+
+/**
+ * Java-friendly wrapper around an org.apache.mxnet.NDArray (`nd`).
+ * Every method delegates to `nd`; Scala NDArray results are wrapped back into this
+ * class through the implicit conversions declared in the NDArray companion object.
+ */
+class NDArray(val nd : org.apache.mxnet.NDArray) {
+
+  /** Create an NDArray holding `arr` with the given shape on `ctx`. */
+  def this(arr : Array[Float], shape : Shape, ctx : Context) = {
+    this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
+  }
+
+  /** Create an NDArray from a Java list of floats with the given shape on `ctx`. */
+  def this(arr : java.util.List[java.lang.Float], shape : Shape, ctx : Context) = {
+    this(NDArray.array(arr, shape, ctx))
+  }
+
+  def serialize() : Array[Byte] = nd.serialize()
+
+  def dispose() : Unit = nd.dispose()
+  def disposeDeps() : NDArray = nd.disposeDepsExcept()
+  // TODO: expose disposeDepsExcept(arr : Array[NDArray]) forwarding the unwrapped arrays
+
+  def slice(start : Int, stop : Int) : NDArray = nd.slice(start, stop)
+
+  def slice (i : Int) : NDArray = nd.slice(i)
+
+  def at(idx : Int) : NDArray = nd.at(idx)
+
+  def T : NDArray = nd.T
+
+  def dtype : DType = nd.dtype
+
+  def asType(dtype : DType) : NDArray = nd.asType(dtype)
+
+  def reshape(dims : Array[Int]) : NDArray = nd.reshape(dims)
+
+  def waitToRead(): Unit = nd.waitToRead()
+
+  def context : Context = nd.context
+
+  def set(value : Float) : NDArray = nd.set(value)
+  def set(other : NDArray) : NDArray = nd.set(other)
+  def set(other : Array[Float]) : NDArray = nd.set(other)
+
+  // Arithmetic. The plain names delegate to the Scala binary operators
+  // (+, -, *, /, **, %); the `_`-prefixed variants delegate to the Scala
+  // compound-assignment operators (+=, -=, ...) on `nd` — presumably in-place
+  // updates; confirm against the Scala API.
+  def add(other : NDArray) : NDArray = this.nd + other.nd
+  def add(other : Float) : NDArray = this.nd + other
+  def _add(other : NDArray) : NDArray = this.nd += other
+  def _add(other : Float) : NDArray = this.nd += other
+  def subtract(other : NDArray) : NDArray = this.nd - other
+  def subtract(other : Float) : NDArray = this.nd - other
+  def _subtract(other : NDArray) : NDArray = this.nd -= other
+  def _subtract(other : Float) : NDArray = this.nd -= other
+  def multiply(other : NDArray) : NDArray = this.nd * other
+  def multiply(other : Float) : NDArray = this.nd * other
+  def _multiply(other : NDArray) : NDArray = this.nd *= other
+  def _multiply(other : Float) : NDArray = this.nd *= other
+  def div(other : NDArray) : NDArray = this.nd / other
+  def div(other : Float) : NDArray = this.nd / other
+  def _div(other : NDArray) : NDArray = this.nd /= other
+  def _div(other : Float) : NDArray = this.nd /= other
+  def pow(other : NDArray) : NDArray = this.nd ** other
+  def pow(other : Float) : NDArray = this.nd ** other
+  def _pow(other : NDArray) : NDArray = this.nd **= other
+  def _pow(other : Float) : NDArray = this.nd **= other
+  def mod(other : NDArray) : NDArray = this.nd % other
+  def mod(other : Float) : NDArray = this.nd % other
+  def _mod(other : NDArray) : NDArray = this.nd %= other
+  def _mod(other : Float) : NDArray = this.nd %= other
+
+  // Element-wise comparisons, delegating to the Scala >, >=, <, <= operators.
+  def greater(other : NDArray) : NDArray = this.nd > other
+  def greater(other : Float) : NDArray = this.nd > other
+  def greaterEqual(other : NDArray) : NDArray = this.nd >= other
+  def greaterEqual(other : Float) : NDArray = this.nd >= other
+  def lesser(other : NDArray) : NDArray = this.nd < other
+  def lesser(other : Float) : NDArray = this.nd < other
+  def lesserEqual(other : NDArray) : NDArray = this.nd <= other
+  def lesserEqual(other : Float) : NDArray = this.nd <= other
+
+  def toArray : Array[Float] = nd.toArray
+
+  def toScalar : Float = nd.toScalar
+
+  def copyTo(other : NDArray) : NDArray = nd.copyTo(other)
+
+  def copyTo(ctx : Context) : NDArray = nd.copyTo(ctx)
+
+  /** Copy of this array on its current context. */
+  def copy() : NDArray = copyTo(this.context)
+
+  def shape : Shape = nd.shape
+
+  /** Total number of elements (product of the shape's dimensions). */
+  def size : Int = shape.product
+
+  def asInContext(context: Context): NDArray = nd.asInContext(context)
+
+  /**
+   * Equality is decided by the wrapped Scala NDArrays. Another javaapi NDArray is
+   * unwrapped first (previously it was handed to nd.equals as-is, comparing a Scala
+   * NDArray with a wrapper); any other object is still passed to nd.equals.
+   */
+  override def equals(obj: Any): Boolean = obj match {
+    case other: NDArray => nd.equals(other.nd)
+    case _ => nd.equals(obj)
+  }
+
+  // Consistent with equals: equal wrappers have equal underlying arrays.
+  override def hashCode(): Int = nd.hashCode
+}
+
+/** Implicit conversions between the Java-facing and Scala NDArrayFuncReturn types. */
+object NDArrayFuncReturn {
+  /** Wrap a Scala operator result for Java callers. */
+  implicit def toJavaNDFuncReturn(
+      ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) : NDArrayFuncReturn =
+    new NDArrayFuncReturn(ndFuncReturn)
+
+  /** Unwrap to the underlying Scala operator result. */
+  implicit def toNDFuncReturn(
+      javaFunReturn : NDArrayFuncReturn) : org.apache.mxnet.NDArrayFuncReturn =
+    javaFunReturn.ndFuncReturn
+}
+
+/**
+ * Java-facing handle on the result of a generated operator's invoke().
+ * Each accessor forwards to the wrapped Scala NDArrayFuncReturn and wraps the
+ * Scala NDArray it yields with NDArray.fromNDArray.
+ */
+private[mxnet] class NDArrayFuncReturn(val ndFuncReturn : org.apache.mxnet.NDArrayFuncReturn) {
+  def head : NDArray = NDArray.fromNDArray(ndFuncReturn.head)
+  def get : NDArray = NDArray.fromNDArray(ndFuncReturn.get)
+  def apply(i : Int) : NDArray = NDArray.fromNDArray(ndFuncReturn(i))
+  // TODO: Add JavaNDArray operational stuff
+}
diff --git
a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
new file mode 100644
index 00000000000..e0b8179a236
--- /dev/null
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi;
+
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertTrue;
+
+public class NDArrayTest {
+    /** Both construction paths (float[] constructor and NDArray.array(List)) yield the requested shape. */
+    @Test
+    public void testCreateNDArray() {
+        NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f},
+                new Shape(new int[]{1, 3}),
+                new Context("cpu", 0));
+        int[] arr = new int[]{1, 3};
+        assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+        assertTrue(nd.at(0).at(0).toArray()[0] == 1.0f);
+        List<Float> list = new ArrayList<Float>();
+        list.add(1.0f);
+        list.add(2.0f);
+        list.add(3.0f);
+        nd.dispose();
+        // Second way of creating an NDArray
+        nd = NDArray.array(list,
+                new Shape(new int[]{1, 3}),
+                new Context("cpu", 0));
+        assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+    }
+
+    /** ones/zeros/empty produce the requested shape; ones and zeros also hold the expected values. */
+    @Test
+    public void testZeroOneEmpty() {
+        Context cpu = new Context("cpu", 0);
+        int[] arr = new int[]{100, 100};
+        NDArray ones = NDArray.ones(cpu, new int[]{100, 100});
+        NDArray zeros = NDArray.zeros(cpu, new int[]{100, 100});
+        // Must call empty(); only its shape is checked since its values are not initialized here.
+        NDArray empty = NDArray.empty(cpu, new int[]{100, 100});
+        assertTrue(Arrays.equals(ones.shape().toArray(), arr));
+        assertTrue(Arrays.equals(zeros.shape().toArray(), arr));
+        assertTrue(Arrays.equals(empty.shape().toArray(), arr));
+        for (float v : ones.toArray()) {
+            assertTrue(v == 1.0f);
+        }
+        for (float v : zeros.toArray()) {
+            assertTrue(v == 0.0f);
+        }
+    }
+
+    /** Element-wise greater/lesser after a sequence of add/subtract operations. */
+    @Test
+    public void testComparison() {
+        NDArray nd = new NDArray(new float[]{1.0f, 2.0f, 3.0f}, new Shape(new int[]{3}),
+                new Context("cpu", 0));
+        NDArray nd2 = new NDArray(new float[]{3.0f, 4.0f, 5.0f}, new Shape(new int[]{3}),
+                new Context("cpu", 0));
+        nd = nd.add(nd2);                       // {4, 6, 8}
+        float[] allTrue = new float[]{1, 1, 1};
+        assertTrue(Arrays.equals(nd.greater(nd2).toArray(), allTrue));
+        nd = nd.subtract(nd2);                  // {1, 2, 3}
+        nd = nd.subtract(nd2);                  // {-2, -2, -2}
+        float[] allFalse = new float[]{0, 0, 0};
+        assertTrue(Arrays.equals(nd.greater(nd2).toArray(), allFalse));
+        // Also exercise lesser(), which the original test never called.
+        assertTrue(Arrays.equals(nd.lesser(nd2).toArray(), allTrue));
+    }
+
+    /** Macro-generated operator builders (norm, dot) reachable through the Scala module instance. */
+    @Test
+    public void testGenerated() {
+        // Generated builders are reached via the Scala object's MODULE$ instance.
+        NDArray$ ndOps = NDArray$.MODULE$;
+        float[] arr = new float[]{1.0f, 2.0f, 3.0f};
+        NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
+        float result = ndOps.norm(nd).invoke().get().toArray()[0];
+        float cal = 0.0f;
+        for (float ele : arr) {
+            cal += ele * ele;
+        }
+        cal = (float) Math.sqrt(cal);
+        assertTrue(Math.abs(result - cal) < 1e-5);
+        NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}),
+                new Context("cpu", 0));
+        ndOps.dot(nd, nd).setout(dotResult).invoke().get();
+        assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
+    }
+}
diff --git
a/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
new file mode 100644
index 00000000000..de1c8058ab9
--- /dev/null
+++
b/scala-package/macros/src/main/scala/org/apache/mxnet/javaapi/JavaNDArrayMacro.scala
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import org.apache.mxnet.init.Base._
+import org.apache.mxnet.utils.CToScalaUtils
+
+import scala.annotation.StaticAnnotation
+import scala.collection.mutable.ListBuffer
+import scala.language.experimental.macros
+import scala.reflect.macros.blackbox
+
+/**
+ * Macro annotation that appends generated operator builder classes and their
+ * factory methods to the annotated object or class. Expansion is performed by
+ * JavaNDArrayMacro.typeSafeAPIDefs.
+ *
+ * @param isContrib when true, operators named `_contrib_*` are generated in addition
+ *                  to operators whose names do not start with `_`
+ */
+private[mxnet] class AddJNDArrayAPIs(isContrib: Boolean) extends StaticAnnotation {
+  private[mxnet] def macroTransform(annottees: Any*) = macro JavaNDArrayMacro.typeSafeAPIDefs
+}
+
+/**
+ * Generates Java-friendly builder APIs for every NDArray operator registered in the
+ * native library. For an operator `op` it emits a class `opBuilder` (constructor =
+ * required args, `setX` per optional arg, `setout`, `invoke()`) and a factory method
+ * `op(requiredArgs...)` returning a new builder.
+ */
+private[mxnet] object JavaNDArrayMacro {
+  // One operator argument: its C-API name, the Scala type string emitted into the
+  // generated code, and whether it is optional.
+  case class NDArrayArg(argName: String, argType: String, isOptional : Boolean)
+  // One operator: its registered name and its argument list.
+  case class NDArrayFunction(name: String, listOfArgs: List[NDArrayArg])
+
+  // scalastyle:off havetype
+  def typeSafeAPIDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+    typeSafeAPIImpl(c)(annottees: _*)
+  }
+  // scalastyle:on havetype
+
+  // Operator metadata queried from the native library when this object is initialized.
+  private val ndarrayFunctions: List[NDArrayFunction] = initNDArrayModule()
+
+  private def typeSafeAPIImpl(c: blackbox.Context)(annottees: c.Expr[Any]*) : c.Expr[Any] = {
+    import c.universe._
+
+    val isContrib: Boolean = c.prefix.tree match {
+      case q"new AddJNDArrayAPIs($b)" => c.eval[Boolean](c.Expr(b))
+      // Fail with a compile error instead of a raw MatchError for unexpected forms
+      // (e.g. named or missing arguments).
+      case other => c.abort(c.enclosingPosition,
+        s"AddJNDArrayAPIs expects a single positional Boolean argument, got: $other")
+    }
+    // Operators that should not be generated
+    val notGenerated = Set("Custom")
+
+    val candidates =
+      if (isContrib) {
+        ndarrayFunctions.filter(
+          func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
+      } else {
+        ndarrayFunctions.filterNot(_.name.startsWith("_"))
+      }
+    // Names that differ only by case collapse to one entry: keep the head if its name
+    // starts lower-case, otherwise the last one (skips deprecated capitalized aliases).
+    val newNDArrayFunctions = candidates
+      .filterNot(ele => notGenerated.contains(ele.name))
+      .groupBy(_.name.toLowerCase)
+      .map { case (_, sameName) =>
+        if (sameName.length == 1) sameName.head
+        else if (sameName.head.name.head.isLower) sameName.head
+        else sameName.last
+      }
+
+    val functionDefs = ListBuffer[DefDef]()
+    val classDefs = ListBuffer[ClassDef]()
+
+    newNDArrayFunctions.foreach { ndarrayfunction =>
+
+      // Required args, rendered as "name : Type" builder constructor / factory params
+      val argDef = ListBuffer[String]()
+      // Optional args, rendered as nullable private vars of the builder
+      val optionArgDef = ListBuffer[String]()
+      // Statements of the generated invoke() body
+      val impl = ListBuffer[String]()
+      impl += "val map = scala.collection.mutable.Map[String, Any]()"
+      // scalastyle:off
+      impl += "val args= scala.collection.mutable.ArrayBuffer.empty[org.apache.mxnet.NDArray]"
+      // scalastyle:on
+      // Builder members: one setter per optional arg (e.g. normBuilder.setaxis)
+      val classImpl = ListBuffer[String]()
+      ndarrayfunction.listOfArgs.foreach({ ndarrayarg =>
+        // `var` and `type` are Scala keywords and must be renamed in generated code;
+        // the original argName is still used as the keyword-argument key below.
+        val currArgName = ndarrayarg.argName match {
+          case "var" => "vari"
+          case "type" => "typeOf"
+          case _ => ndarrayarg.argName
+        }
+        if (ndarrayarg.isOptional) {
+          optionArgDef += s"private var $currArgName : ${ndarrayarg.argType} = null"
+          val tempDef = s"def set$currArgName($currArgName : ${ndarrayarg.argType})"
+          val tempImpl = s"this.$currArgName = $currArgName\nthis"
+          classImpl += s"$tempDef = {$tempImpl}"
+        } else {
+          argDef += s"$currArgName : ${ndarrayarg.argType}"
+        }
+        // NDArray and NDArray-array args go into the positional `args` buffer;
+        // every other arg becomes a keyword entry in `map`.
+        val ndarrayType = "org.apache.mxnet.javaapi.NDArray"
+        val base =
+          if (ndarrayarg.argType.equals(ndarrayType)) {
+            s"args += this.$currArgName"
+          } else if (ndarrayarg.argType.equals(s"Array[$ndarrayType]")) {
+            s"this.$currArgName.foreach(args+=_)"
+          } else {
+            "map(\"" + ndarrayarg.argName + "\") = this." + currArgName
+          }
+        // Optional args are only forwarded when they were set (non-null).
+        impl.append(
+          if (ndarrayarg.isOptional) s"if (this.$currArgName != null) $base"
+          else base
+        )
+      })
+      // Every builder gets an optional `out` target. The field is a Scala NDArray while
+      // setout takes a javaapi NDArray, so the generated code relies on the implicit
+      // toNDArray conversion where it is spliced (javaapi.NDArray's companion).
+      classImpl +=
+        "def setout(out : org.apache.mxnet.javaapi.NDArray) = {this.out = out\nthis}"
+      impl += "if (this.out != null) map(\"out\") = this.out"
+      optionArgDef += "private var out : org.apache.mxnet.NDArray = null"
+      val returnType = "org.apache.mxnet.javaapi.NDArrayFuncReturn"
+      // scalastyle:off
+      // Combine and build the function string
+      impl += "org.apache.mxnet.NDArray.genericNDArrayFunctionInvoke(\"" +
+        ndarrayfunction.name + "\", args.toSeq, map.toMap)"
+      val classDef = s"class ${ndarrayfunction.name}Builder(${argDef.mkString(",")})"
+      val classBody = s"${optionArgDef.mkString("\n")}\n${classImpl.mkString("\n")}\ndef invoke() : $returnType = {${impl.mkString("\n")}}"
+      val classFinal = s"$classDef {$classBody}"
+      val functionDef = s"def ${ndarrayfunction.name} (${argDef.mkString(",")})"
+      val functionBody = s"new ${ndarrayfunction.name}Builder(${argDef.map(_.split(":")(0)).mkString(",")})"
+      val functionFinal = s"$functionDef = $functionBody"
+      // scalastyle:on
+      functionDefs += c.parse(functionFinal).asInstanceOf[DefDef]
+      classDefs += c.parse(classFinal).asInstanceOf[ClassDef]
+    }
+
+    structGeneration(c)(functionDefs.toList, classDefs.toList, annottees : _*)
+  }
+
+  /**
+   * Appends the generated factory defs and builder classes to the body of each
+   * annotated class or object.
+   * @throws IllegalArgumentException for any other kind of annottee
+   */
+  private def structGeneration(c: blackbox.Context)
+                              (funcDef : List[c.universe.DefDef],
+                               classDef : List[c.universe.ClassDef],
+                               annottees: c.Expr[Any]*)
+  : c.Expr[Any] = {
+    import c.universe._
+
+    // Return `template` with the generated members added after its existing body.
+    def withGenerated(template: Template): Template = template match {
+      case Template(superMaybe, emptyValDef, defs) =>
+        Template(superMaybe, emptyValDef, defs ++ funcDef ++ classDef)
+      case ex =>
+        throw new IllegalArgumentException(s"Invalid template: $ex")
+    }
+
+    val inputs = annottees.map(_.tree).toList
+    val modDefs = inputs map {
+      case ClassDef(mods, name, something, template) =>
+        ClassDef(mods, name, something, withGenerated(template))
+      case ModuleDef(mods, name, template) =>
+        ModuleDef(mods, name, withGenerated(template))
+      case ex =>
+        throw new IllegalArgumentException(s"Invalid macro input: $ex")
+    }
+    // wrap the result up in an Expr, and return it
+    c.Expr(Block(modDefs, Literal(Constant())))
+  }
+
+  // Query the native library for every registered operator name and describe each one.
+  private def initNDArrayModule(): List[NDArrayFunction] = {
+    val opNames = ListBuffer.empty[String]
+    _LIB.mxListAllOpNames(opNames)
+    opNames.map(opName => {
+      val opHandle = new RefLong
+      _LIB.nnGetOpHandle(opName, opHandle)
+      makeNDArrayFunction(opHandle.value, opName)
+    }).toList
+  }
+
+  // Build an NDArrayFunction (name plus typed argument list) from an operator handle.
+  private def makeNDArrayFunction(handle: NDArrayHandle, aliasName: String)
+  : NDArrayFunction = {
+    val name = new RefString
+    val desc = new RefString
+    val keyVarNumArgs = new RefString
+    val numArgs = new RefInt
+    val argNames = ListBuffer.empty[String]
+    val argTypes = ListBuffer.empty[String]
+    val argDescs = ListBuffer.empty[String]
+
+    _LIB.mxSymbolGetAtomicSymbolInfo(
+      handle, name, desc, numArgs, argNames, argTypes, argDescs, keyVarNumArgs)
+    // NDArray args map to the javaapi NDArray; Shape args to javaapi.Shape
+    // (argumentCleaner prefixes the shape type with "org.apache.mxnet.").
+    val argList = argNames zip argTypes map { case (argName, argType) =>
+      val typeAndOption =
+        CToScalaUtils.argumentCleaner(argName, argType,
+          "org.apache.mxnet.javaapi.NDArray", "javaapi.Shape")
+      NDArrayArg(argName, typeAndOption._1, typeAndOption._2)
+    }
+    NDArrayFunction(aliasName, argList.toList)
+  }
+}
diff --git
a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index d0ebe5b1d2c..48d8fdf38bc 100644
---
a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++
b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -21,19 +21,19 @@ private[mxnet] object CToScalaUtils {
// Convert C++ Types to Scala Types
- def typeConversion(in : String, argType : String = "",
- argName : String, returnType : String) : String = {
+ def typeConversion(in : String, argType : String = "", argName : String,
+ returnType : String, shapeType : String = "Shape") :
String = {
in match {
- case "Shape(tuple)" | "ShapeorNone" => "org.apache.mxnet.Shape"
+ case "Shape(tuple)" | "ShapeorNone" => s"org.apache.mxnet.$shapeType"
case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" |
"SymbolorSymbol[]"
=> s"Array[$returnType]"
- case "float" | "real_t" | "floatorNone" =>
"org.apache.mxnet.Base.MXFloat"
- case "int" | "intorNone" | "int(non-negative)" => "Int"
- case "long" | "long(non-negative)" => "Long"
- case "double" | "doubleorNone" => "Double"
+ case "float" | "real_t" | "floatorNone" => "java.lang.Float"
+ case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer"
+ case "long" | "long(non-negative)" => "java.lang.Long"
+ case "double" | "doubleorNone" => "java.lang.Double"
case "string" => "String"
- case "boolean" | "booleanorNone" => "Boolean"
+ case "boolean" | "booleanorNone" => "java.lang.Boolean"
case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" =>
"Any"
case default => throw new IllegalArgumentException(
s"Invalid type for args: $default\nString argType: $argType\nargName:
$argName")
@@ -52,8 +52,8 @@ private[mxnet] object CToScalaUtils {
* @param argType Raw arguement Type description
* @return (Scala_Type, isOptional)
*/
- def argumentCleaner(argName: String,
- argType : String, returnType : String) : (String,
Boolean) = {
+ def argumentCleaner(argName: String, argType : String,
+ returnType : String, shapeType : String = "Shape") :
(String, Boolean) = {
val spaceRemoved = argType.replaceAll("\\s+", "")
var commaRemoved : Array[String] = new Array[String](0)
// Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
@@ -73,7 +73,7 @@ private[mxnet] object CToScalaUtils {
s"""expected "default=..." got ${commaRemoved(2)}""")
(typeConversion(commaRemoved(0), argType, argName, returnType), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
- val tempType = typeConversion(commaRemoved(0), argType, argName,
returnType)
+ val tempType = typeConversion(commaRemoved(0), argType, argName,
returnType, shapeType)
val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
(tempType, tempOptional)
} else {
diff --git
a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
index c3a7c58c1af..4404b0885d5 100644
--- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
+++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
@@ -36,7 +36,7 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
)
val output = List(
("org.apache.mxnet.Symbol", true),
- ("Int", false),
+ ("java.lang.Integer", false),
("org.apache.mxnet.Shape", true),
("String", true),
("Any", false)
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services