yzhliu commented on a change in pull request #13162: [MXNET-1198] MXNet Java API
URL: https://github.com/apache/incubator-mxnet/pull/13162#discussion_r233611420
 
 

 ##########
 File path: 
scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
 ##########
 @@ -0,0 +1,387 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.javaapi
+
+import org.apache.mxnet.javaapi.DType.DType
+
+import collection.JavaConverters._
+
+@AddJNDArrayAPIs(false)
+object NDArray extends NDArrayBase {
+  /** Wraps a Scala-API NDArray in a Java-API NDArray without copying. */
+  implicit def fromNDArray(nd: org.apache.mxnet.NDArray): NDArray = new NDArray(nd)
+
+  /** Unwraps a Java-API NDArray to the Scala-API NDArray it holds. */
+  implicit def toNDArray(jnd: NDArray): org.apache.mxnet.NDArray = jnd.nd
+
+  /**
+    * Delegates to `org.apache.mxnet.NDArray.waitall()`; presumably blocks until all
+    * pending asynchronous NDArray operations have completed.
+    */
+  def waitall(): Unit = org.apache.mxnet.NDArray.waitall()
+
+  /**
+    * One-hot encode `indices` into the matrix `out`.
+    * @param indices An NDArray containing indices of the categorical features.
+    * @param out The result holder of the encoding.
+    * @return An NDArray wrapping the same result as `out`.
+    */
+  def onehotEncode(indices: NDArray, out: NDArray): NDArray =
+    org.apache.mxnet.NDArray.onehotEncode(indices, out)
+
+  /**
+    * Create an empty uninitialized new NDArray, with specified shape.
+    *
+    * @param shape shape of the NDArray.
+    * @param ctx The context of the NDArray.
+    * @param dtype The data type of the NDArray.
+    * @return The created NDArray.
+    */
+  def empty(shape: Shape, ctx: Context, dtype: DType.DType): NDArray =
+    org.apache.mxnet.NDArray.empty(shape, ctx, dtype)
+
+  /** Like `empty(shape, ctx, dtype)`, but with an int-array shape and the default dtype. */
+  def empty(ctx: Context, shape: Array[Int]): NDArray =
+    org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+
+  /** Like `empty(shape, ctx, dtype)`, but with a Java-list shape and the default dtype. */
+  def empty(ctx: Context, shape: java.util.List[java.lang.Integer]): NDArray =
+    org.apache.mxnet.NDArray.empty(new Shape(shape), ctx)
+
+  /**
+    * Create a new NDArray filled with 0, with specified shape.
+    *
+    * @param shape shape of the NDArray.
+    * @param ctx The context of the NDArray.
+    * @param dtype The data type of the NDArray.
+    * @return The created NDArray.
+    */
+  def zeros(shape: Shape, ctx: Context, dtype: DType.DType): NDArray =
+    org.apache.mxnet.NDArray.zeros(shape, ctx, dtype)
+
+  /** Like `zeros(shape, ctx, dtype)`, but with an int-array shape and the default dtype. */
+  def zeros(ctx: Context, shape: Array[Int]): NDArray =
+    org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+
+  /** Like `zeros(shape, ctx, dtype)`, but with a Java-list shape and the default dtype. */
+  def zeros(ctx: Context, shape: java.util.List[java.lang.Integer]): NDArray =
+    org.apache.mxnet.NDArray.zeros(new Shape(shape), ctx)
+
+  /**
+    * Create a new NDArray filled with 1, with specified shape.
+    * @param shape shape of the NDArray.
+    * @param ctx The context of the NDArray.
+    * @param dtype The data type of the NDArray.
+    * @return The created NDArray.
+    */
+  def ones(shape: Shape, ctx: Context, dtype: DType.DType): NDArray =
+    org.apache.mxnet.NDArray.ones(shape, ctx, dtype)
+
+  /** Like `ones(shape, ctx, dtype)`, but with an int-array shape and the default dtype. */
+  def ones(ctx: Context, shape: Array[Int]): NDArray =
+    org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+
+  /** Like `ones(shape, ctx, dtype)`, but with a Java-list shape and the default dtype. */
+  def ones(ctx: Context, shape: java.util.List[java.lang.Integer]): NDArray =
+    org.apache.mxnet.NDArray.ones(new Shape(shape), ctx)
+
+  /**
+    * Create a new NDArray filled with given value, with specified shape.
+    * @param shape shape of the NDArray.
+    * @param value value to be filled with
+    * @param ctx The context of the NDArray
+    * @return The created NDArray.
+    */
+  def full(shape: Shape, value: Float, ctx: Context): NDArray =
+    org.apache.mxnet.NDArray.full(shape, value, ctx)
+
+  /**
+    * Java-friendly overloads of `org.apache.mxnet.NDArray.power`;
+    * either operand may be a scalar Float.
+    */
+  def power(lhs: NDArray, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.power(lhs, rhs)
+  def power(lhs: NDArray, rhs: Float): NDArray =
+    org.apache.mxnet.NDArray.power(lhs, rhs)
+  def power(lhs: Float, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.power(lhs, rhs)
+
+  /**
+    * Java-friendly overloads of `org.apache.mxnet.NDArray.maximum`;
+    * either operand may be a scalar Float.
+    */
+  def maximum(lhs: NDArray, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.maximum(lhs, rhs)
+  def maximum(lhs: NDArray, rhs: Float): NDArray =
+    org.apache.mxnet.NDArray.maximum(lhs, rhs)
+  def maximum(lhs: Float, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.maximum(lhs, rhs)
+
+  /**
+    * Java-friendly overloads of `org.apache.mxnet.NDArray.minimum`;
+    * either operand may be a scalar Float.
+    */
+  def minimum(lhs: NDArray, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.minimum(lhs, rhs)
+  def minimum(lhs: NDArray, rhs: Float): NDArray =
+    org.apache.mxnet.NDArray.minimum(lhs, rhs)
+  def minimum(lhs: Float, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.minimum(lhs, rhs)
+
+  /**
+    * Returns the result of element-wise **equal to** (==) comparison operation
+    * with broadcasting.
+    * For each element in input arrays, return 1(true) if corresponding elements are same,
+    * otherwise return 0(false).
+    */
+  def equal(lhs: NDArray, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.equal(lhs, rhs)
+  def equal(lhs: NDArray, rhs: Float): NDArray =
+    org.apache.mxnet.NDArray.equal(lhs, rhs)
+
+  /**
+    * Returns the result of element-wise **not equal to** (!=) comparison operation
+    * with broadcasting.
+    * For each element in input arrays, return 1(true) if corresponding elements are
+    * different, otherwise return 0(false).
+    */
+  def notEqual(lhs: NDArray, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+  def notEqual(lhs: NDArray, rhs: Float): NDArray =
+    org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+
+  /**
+    * Returns the result of element-wise **greater than** (>) comparison operation
+    * with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are greater than rhs,
+    * otherwise return 0(false).
+    */
+  def greater(lhs: NDArray, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.greater(lhs, rhs)
+  def greater(lhs: NDArray, rhs: Float): NDArray =
+    org.apache.mxnet.NDArray.greater(lhs, rhs)
+
+  /**
+    * Returns the result of element-wise **greater than or equal to** (>=) comparison
+    * operation with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are greater than
+    * or equal to rhs, otherwise return 0(false).
+    */
+  def greaterEqual(lhs: NDArray, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+  def greaterEqual(lhs: NDArray, rhs: Float): NDArray =
+    org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+
+  /**
+    * Returns the result of element-wise **less than** (<) comparison operation
+    * with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are less than rhs,
+    * otherwise return 0(false).
+    */
+  def lesser(lhs: NDArray, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.lesser(lhs, rhs)
+  def lesser(lhs: NDArray, rhs: Float): NDArray =
+    org.apache.mxnet.NDArray.lesser(lhs, rhs)
+
+  /**
+    * Returns the result of element-wise **less than or equal to** (<=) comparison
+    * operation with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are
+    * less than or equal to rhs, otherwise return 0(false).
+    */
+  def lesserEqual(lhs: NDArray, rhs: NDArray): NDArray =
+    org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+  def lesserEqual(lhs: NDArray, rhs: Float): NDArray =
+    org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+
+  /**
+    * Create a new NDArray that copies content from sourceArr.
+    * @param sourceArr Source data to create NDArray from.
+    * @param shape shape of the NDArray
+    * @param ctx The context of the NDArray; `null` (the default) means the current
+    *            default context.
+    * @return The created NDArray.
+    */
+  def array(sourceArr: java.util.List[java.lang.Float], shape: Shape,
+            ctx: Context = null): NDArray = {
+    val data: Array[Float] = sourceArr.asScala.map(_.floatValue).toArray
+    // Forward a null Scala-API context directly rather than running a null Java-API
+    // Context through the implicit wrapper conversion; null means "default context".
+    val scalaCtx: org.apache.mxnet.Context = if (ctx == null) null else ctx
+    org.apache.mxnet.NDArray.array(data, shape, scalaCtx)
+  }
+
+  /**
+    * Returns evenly spaced values within a given interval.
+    * Values are generated within the half-open interval [`start`, `stop`). In other
+    * words, the interval includes `start` but excludes `stop`.
+    * @param start Start of interval.
+    * @param stop End of interval.
+    * @param step Spacing between values.
+    * @param repeat Number of times to repeat each element.
+    * @param ctx Device context.
+    * @param dType The data type of the `NDArray`.
+    * @return NDArray of evenly spaced values in the specified range.
+    */
+  def arange(start: Float, stop: Float, step: Float, repeat: Int,
+             ctx: Context, dType: DType.DType): NDArray =
+    org.apache.mxnet.NDArray.arange(start, Some(stop), step, repeat, ctx, dType)
+}
+
+/**
+  * NDArray object in mxnet.
+  * NDArray is basic ndarray/Tensor like data structure in mxnet. <br />
+  * <b>
+  * NOTE: NDArray is stored in native memory. Use NDArray in a 
try-with-resources() construct
+  * or a [[ResourceScope]] in a try-with-resource to have them automatically 
disposed. You can
+  * explicitly control the lifetime of NDArray by calling dispose manually. 
Failure to do this
+  * will result in leaking native memory.
+  * </b>
+  */
+class NDArray(val nd : org.apache.mxnet.NDArray ) {
+
+  /**
+    * Builds an NDArray from a flat float array, laid out with `shape` on `ctx`.
+    * Delegates to the Scala API's `NDArray.array`.
+    */
+  def this(arr: Array[Float], shape: Shape, ctx: Context) =
+    this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
+
+  /**
+    * Builds an NDArray from a Java list of floats, laid out with `shape` on `ctx`.
+    * Delegates to the companion's `NDArray.array` and unwraps its result.
+    */
+  def this(arr: java.util.List[java.lang.Float], shape: Shape, ctx: Context) =
+    this(NDArray.array(arr, shape, ctx))
+
+  /** Serializes the wrapped Scala-API NDArray to bytes (delegates to `nd.serialize()`). */
+  def serialize(): Array[Byte] = nd.serialize()
+
+  /**
+    * Frees the native memory backing this array. <br />
+    * Arrays that this one depends on are left alive. <br />
+    * This object must not be used again once it has been disposed.
+    */
+  def dispose(): Unit = nd.dispose()
+
+  /**
+    * Dispose all NDArrays who help to construct this array. <br />
+    * e.g. (a * b + c).disposeDeps() will dispose a, b, c (including their deps) and a * b
+    * @return this array
+    */
+  def disposeDeps(): NDArray = {
+    nd.disposeDepsExcept()
+    // Return the receiver itself (as documented) rather than a fresh wrapper.
+    this
+  }
+
+  /**
+    * Dispose all NDArrays who help to construct this array, except the ones in `arr`
+    * (see the Scala API's `disposeDepsExcept` for how their own deps are treated). <br />
+    * @param arr arrays to keep alive; `null` is treated as an empty array
+    * @return this array
+    */
+  def disposeDepsExcept(arr: Array[NDArray]): NDArray = {
+    // Java callers may pass null; treat it the same as "keep nothing".
+    val keep: Seq[org.apache.mxnet.NDArray] =
+      if (arr == null) Seq.empty else arr.toSeq.map(_.nd)
+    nd.disposeDepsExcept(keep: _*)
+    this
+  }
 
 Review comment:
   remove or uncomment

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to