nswamy closed pull request #12064: Allow stop of arange to be inferred from dims. URL: https://github.com/apache/incubator-mxnet/pull/12064
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/contrib/clojure-package/README.md b/contrib/clojure-package/README.md index 5e7356caf64..ea678ccf2db 100644 --- a/contrib/clojure-package/README.md +++ b/contrib/clojure-package/README.md @@ -107,7 +107,9 @@ The jars from maven with the needed MXNet native binaries in it. On startup, the ### Build from MXNET Source -Checkout the latest sha from the main package +First, ensure you have JDK 8 on your system. Later versions may produce cryptic build errors mentioning `scala.reflect.internal.MissingRequirementError`. + +Checkout the latest SHA from the main package: `git clone --recursive https://github.com/apache/incubator-mxnet.git ~/mxnet` `cd ~/mxnet` diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj index e37a8bc8c98..7ca4ede9733 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj @@ -89,7 +89,7 @@ (NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype)) ([start stop] (arange start stop {}))) - + (defn slice "Return a sliced NDArray that shares memory with current one." 
([ndarray i] diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj index 42ae034eb6d..12135fb75ca 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj @@ -135,10 +135,20 @@ ([start stop {:keys [step repeat dtype] :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} :as opts}] - (Symbol/arange (float start) ($/option (float stop)) step repeat nil dtype)) + (Symbol/arange (float start) ($/option (float stop)) step repeat false nil dtype)) ([start stop] (arange start stop {}))) +(defn arange-with-inference + "Behaves like arange operator, but infers the stop value from the output shape, + which must be known from the rest of the net." + ([start {:keys [step repeat dtype] + :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} + :as opts}] + (Symbol/arange (float start) ($/option nil) step repeat true nil dtype)) + ([start] + (arange-with-inference start {}))) + ;;; manually defined because of a conflicting arity of 2 with the auto-gen (defn min ([sym-name kwargs-map symbol-list kwargs-map-1] diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj index a71a312e1ae..1b4b2ea2fbe 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj @@ -222,6 +222,17 @@ (is (= 0 (count (executor/grad-arrays exec)))) (is (approx= 1e-4 result (-> (executor/outputs exec) (first)))))) +(deftest test-arange-with-inference + (let [arange (sym/arange-with-inference 0) + data (sym/variable "data") + added (sym/+ arange data) + result (range 0 4) + data-tmp (ndarray/zeros [4]) + exec (sym/bind added (context/default-context) {"data" data-tmp})] + (executor/forward exec) + (is (= 0 (count 
(executor/grad-arrays exec)))) + (is (approx= 1e-4 result (-> (executor/outputs exec) (first)))))) + (deftest test-scalar-pow (let [data (sym/variable "data") shape-vec [1 1] diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj index dcdbea64579..ecd54ca7277 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj @@ -22,6 +22,8 @@ (if (and (number? x) (number? y)) (let [diff (Math/abs (- x y))] (< diff tolerance)) - (reduce (fn [x y] (and x y)) - (map #(approx= tolerance %1 %2) x y)))) + (and + (= (count x) (count y)) + (reduce (fn [x y] (and x y)) + (map #(approx= tolerance %1 %2) x y))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj index 5551fab435f..de3480827ba 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj @@ -21,6 +21,7 @@ [org.apache.clojure-mxnet.util :as util] [org.apache.clojure-mxnet.ndarray :as ndarray] [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.test-util :as test-util] [clojure.spec.alpha :as s]) (:import (org.apache.mxnet Shape NDArrayFuncReturn NDArray) (scala.collection Map Set) @@ -183,3 +184,10 @@ (deftest test-validate (is (nil? (util/validate! string? "foo" "Not a string!"))) (is (thrown-with-msg? Exception #"Not a string!" (util/validate! 
::x 1 "Not a string!")))) + +(deftest test-approx= + (let [data1 [1 1 1 1] + data2 [1 1 1 1 9 9 9 9] + data3 [1 1 1 2]] + (is (not (test-util/approx= 1e-9 data1 data2))) + (is (test-util/approx= 2 data1 data3)))) \ No newline at end of file diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 46b21a90d4c..d6d619f30ca 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -2475,7 +2475,7 @@ def moveaxis(tensor, source, destination): # pylint: disable= no-member, protected-access, too-many-arguments, redefined-outer-name -def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t): +def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, ctx=None, dtype=mx_real_t): """Returns evenly spaced values within a given interval. Values are generated within the half-open interval [`start`, `stop`). In other @@ -2519,7 +2519,7 @@ def arange(start, stop=None, step=1.0, repeat=1, ctx=None, dtype=mx_real_t): if ctx is None: ctx = current_context() return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, - dtype=dtype, ctx=str(ctx)) + infer_range=infer_range, dtype=dtype, ctx=str(ctx)) # pylint: enable= no-member, protected-access, too-many-arguments diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 5f6cbd6b6e1..da5533f3666 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -2886,7 +2886,7 @@ def full(shape, val, dtype=None, **kwargs): return _internal._full(shape=shape, dtype=dtype, value=float(val), **kwargs) # pylint: disable=redefined-outer-name -def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None): +def arange(start, stop=None, step=1.0, repeat=1, infer_range=False, name=None, dtype=None): """Returns evenly spaced values within a given interval. 
Parameters @@ -2911,7 +2911,7 @@ def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None): if dtype is None: dtype = _numpy.float32 return _internal._arange(start=start, stop=stop, step=step, repeat=repeat, - name=name, dtype=dtype) + infer_range=infer_range, name=name, dtype=dtype) def histogram(a, bins=10, range=None, **kwargs): """Compute the histogram of the input data. diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 548c30b73a1..8b5e1e01095 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -407,11 +407,10 @@ object NDArray extends NDArrayBase { * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. * @return NDArray of evenly spaced values in the specified range. */ - def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, - repeat: Int = 1, ctx: Context = Context.defaultCtx, - dType: DType = Base.MX_REAL_TYPE): NDArray = { - val params = Map("start" -> start, "step" -> step, - "repeat" -> repeat, "ctx" -> ctx.toString, "dtype" -> dType.toString()) + def arange(start: Float, stop: Option[Float], step: Float, + repeat: Int, ctx: Context, dType: DType): NDArray = { + val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, + "infer_range" -> false, "ctx" -> ctx.toString, "dtype" -> dType.toString()) val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) NDArray.genericNDArrayFunctionInvoke("_arange", Seq(), fParams)(0) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 194d3681523..e3e1a320358 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -955,9 +955,28 @@ 
object Symbol extends SymbolBase { * @return Symbol The created Symbol. */ def arange(start: Float, stop: Option[Float] = None, step: Float = 1.0f, - repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = { - val params = Map("start" -> start, "step" -> step, - "repeat" -> repeat, "dtype" -> dType.toString()) + repeat: Int = 1, name: String = null, dType: DType = Base.MX_REAL_TYPE): Symbol = { + arange(start, stop, step, repeat, infer_range = false, name, dType) + } + + /** + * Returns evenly spaced values within a given interval. + * stop value can be infered from the output shape, + * which must be known from the rest of the net. + * @param start Start of interval. The default start value is 0. + * @param stop End of interval. + * @param step Spacing between values. The default step size is 1. + * @param repeat Number of times to repeat each element. The default repeat count is 1. + * @param infer_range Infer the stop value from output shape + * @param ctx Device context. Default context is the current default context. + * @param dType The data type of the `NDArray`. The default datatype is `DType.Float32`. + * @return NDArray of evenly spaced values in the specified range. 
+ */ + def arange(start: Float, stop: Option[Float], step: Float, + repeat: Int, infer_range: Boolean, name: String, + dType: DType): Symbol = { + val params = Map("start" -> start, "step" -> step, "repeat" -> repeat, + "infer_range" -> infer_range, "dtype" -> dType.toString()) val fParams = if (stop == None) params else params ++ Map("stop" -> stop.get) createSymbolGeneral("_arange", name, null, Array.empty[Symbol], fParams) } diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 4af3a40f42a..304911a02a7 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -123,6 +123,7 @@ struct RangeParam : public dmlc::Parameter<RangeParam> { dmlc::optional<double> stop; double step; int repeat; + bool infer_range; std::string ctx; int dtype; DMLC_DECLARE_PARAMETER(RangeParam) { @@ -140,6 +141,10 @@ struct RangeParam : public dmlc::Parameter<RangeParam> { .set_default(1) .describe("The repeating time of all elements." " E.g repeat=3, the element a will be repeated three times --> a, a, a."); + DMLC_DECLARE_FIELD(infer_range) + .set_default(false) + .describe("Whether to infer the stop position from the start, step, repeat, and output tensor" + "size."); DMLC_DECLARE_FIELD(ctx) .set_default("") .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." 
@@ -176,7 +181,7 @@ struct InitOpWithScalarParam : dmlc::Parameter<InitOpWithScalarParam> { inline void RangeParamParser(nnvm::NodeAttrs* attrs) { RangeParam param; param.Init(attrs->dict); - if (!static_cast<bool>(param.stop)) { + if (!static_cast<bool>(param.infer_range) && !static_cast<bool>(param.stop)) { param.stop = param.start; param.start = 0; } @@ -471,6 +476,9 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs, << "Range does not support step=0, received " << param.step; CHECK(param.repeat > 0) << "Range only supports repeat > 0, received " << param.repeat; + if (param.infer_range && !param.stop.has_value()) { + return false; + } if (param.step > 0) { CHECK(param.start < param.stop.value()) << "Invalid range (start, stop, step) = " diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 90e85d123d5..e0f7219ea76 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3555,10 +3555,18 @@ def test_arange(): nd_out = mx.nd.arange(*config, repeat=repeats, dtype=dtype) assert_almost_equal(np_out, nd_out.asnumpy()) + def test_arange_inferstop(): + s = mx.sym.arange(start=0, stop=None, infer_range=True) + s = mx.sym.elemwise_add(s, mx.sym.zeros(shape=[5])) + exe = s.bind(ctx=mx.cpu(), args={}) + exe.forward() + assert_almost_equal(exe.outputs[0].asnumpy(), np.array([0,1,2,3,4])) + test_basic_val_init(mx.sym.zeros, np.zeros, (3, 4), np.float32) test_basic_val_init(mx.sym.ones, np.ones, 3, np.int32) test_basic_val_init(mx.sym.ones, np.ones, (2, 2, 3), np.float16) test_arange() + test_arange_inferstop() @with_seed() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
