This is an automated email from the ASF dual-hosted git repository.
cmeier pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 52a3553 [docstring] improve docstring and indentation in `module.clj` (#14705)
52a3553 is described below
commit 52a3553fe200214437c717e7b35e6ce39adb59d8
Author: Arthur Caillau <[email protected]>
AuthorDate: Tue Apr 16 15:31:01 2019 +0200
[docstring] improve docstring and indentation in `module.clj` (#14705)
---
.../src/org/apache/clojure_mxnet/module.clj | 544 +++++++++++++--------
.../src/org/apache/clojure_mxnet/util.clj | 2 +-
2 files changed, 345 insertions(+), 201 deletions(-)
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
index aa5ce39..09f17e5 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
@@ -16,6 +16,7 @@
;;
(ns org.apache.clojure-mxnet.module
+ "Module API for Clojure package."
(:refer-clojure :exclude [update symbol])
(:require [org.apache.clojure-mxnet.callback :as callback]
[org.apache.clojure-mxnet.context :as context]
@@ -31,18 +32,29 @@
(:import (org.apache.mxnet.module Module FitParams BaseModule)
(org.apache.mxnet.io MXDataIter NDArrayIter)
(org.apache.mxnet Initializer Optimizer NDArray DataBatch
- Context EvalMetric Monitor Callback$Speedometer
DataDesc)))
+ Context EvalMetric Monitor Callback$Speedometer
+ DataDesc)))
(defn module
- "Module is a basic module that wrap a symbol.
- sym : Symbol definition.
- map of options
- :data-names - Input data names.
- :label-names - Input label names
- :contexts - Default is cpu().
- :workload-list - Default nil, indicating uniform workload.
- :fixed-param-names Default nil, indicating no network parameters are fixed."
- ([sym {:keys [data-names label-names contexts workload-list fixed-param-names] :as opts
+ "Module is a basic module that wrap a `symbol`.
+ `sym`: Symbol definition.
+ `opts-map` {
+ `data-names`: vector of strings - Default is [\"data\"]
+ Input data names
+ `label-names`: vector of strings - Default is [\"softmax_label\"]
+ Input label names
+ `contexts`: Context - Default is `context/cpu`.
+ `workload-list`: Default nil
+ Indicating uniform workload.
+ `fixed-param-names`: Default nil
+ Indicating no network parameters are fixed.
+ }
+ Ex:
+ (module sym)
+ (module sym {:data-names [\"data\"]
+              :label-names [\"linear_regression_label\"]})"
+ ([sym {:keys [data-names label-names contexts
+ workload-list fixed-param-names] :as opts
:or {data-names ["data"]
label-names ["softmax_label"]
contexts [(context/default-context)]}}]
@@ -80,31 +92,41 @@
(s/def ::force-rebind boolean?)
(s/def ::shared-module #(instance? Module))
(s/def ::grad-req string?)
-(s/def ::bind-opts (s/keys :req-un [::data-shapes] :opt-un [::label-shapes ::for-training ::inputs-need-grad
- ::force-rebind ::shared-module ::grad-req]))
+(s/def ::bind-opts
+ (s/keys :req-un [::data-shapes]
+ :opt-un [::label-shapes ::for-training ::inputs-need-grad
+ ::force-rebind ::shared-module ::grad-req]))
(defn bind
"Bind the symbols to construct executors. This is necessary before one
can perform computation with the module.
- mod : module
- map of opts:
- :data-shapes Typically is (provide-data-desc data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout
- :label-shapes Typically is (provide-label-desc data-iter). map of :name :shape :dtype and :layout
- :for-training Default is `true`. Whether the executors should be bind for training.
- :inputs-need-grad Default is `false`.
- Whether the gradients to the input data need to be computed.
- Typically this is not needed.
- But this might be needed when implementing composition of modules.
- :force-rebind Default is `false`.
- This function does nothing if the executors are already binded.
- But with this `true`, the executors will be forced to rebind.
- :shared-module Default is nil. This is used in bucketing.
- When not `None`, the shared module essentially corresponds to
- a different bucket -- a module with different symbol
- but with the same sets of parameters
- (e.g. unrolled RNNs with different lengths). "
- [mod {:keys [data-shapes label-shapes for-training inputs-need-grad force-rebind
- shared-module grad-req] :as opts
+ `mod`: module
+ `opts-map` {
+ `data-shapes`: map of `:name`, `:shape`, `:dtype`, and `:layout`
+ Typically is `(provide-data-desc data-iter)`. Data shape must be in the
+ form of `io/data-desc`
+ `label-shapes`: map of `:name` `:shape` `:dtype` and `:layout`
+ Typically is `(provide-label-desc data-iter)`.
+ `for-training`: boolean - Default is `true`
+ Whether the executors should be bound for training.
+ `inputs-need-grad`: boolean - Default is `false`.
+ Whether the gradients to the input data need to be computed.
+ Typically this is not needed. But this might be needed when
+ implementing composition of modules.
+ `force-rebind`: boolean - Default is `false`.
+ This function does nothing if the executors are already bound. But
+ with this `true`, the executors will be forced to rebind.
+ `shared-module`: Default is nil.
+ This is used in bucketing. When not `nil`, the shared module
+ essentially corresponds to a different bucket -- a module with
+ different symbol but with the same sets of parameters (e.g. unrolled
+ RNNs with different lengths).
+ }
+ Ex:
+ (bind mod {:data-shapes (mx-io/provide-data train-iter)
+            :label-shapes (mx-io/provide-label test-iter)})"
+ [mod {:keys [data-shapes label-shapes for-training inputs-need-grad
+ force-rebind shared-module grad-req] :as opts
:or {for-training true
inputs-need-grad false
force-rebind false
@@ -129,24 +151,36 @@
(s/def ::aux-params map?)
(s/def ::force-init boolean?)
(s/def ::allow-extra boolean?)
-(s/def ::init-params-opts (s/keys :opt-un [::initializer ::arg-params ::aux-params
- ::force-init ::allow-extra]))
+(s/def ::init-params-opts
+ (s/keys :opt-un [::initializer ::arg-params ::aux-params
+ ::force-init ::allow-extra]))
(defn init-params
- " Initialize the parameters and auxiliary states.
- options map
- :initializer - Called to initialize parameters if needed.
- :arg-params - If not nil, should be a map of existing arg-params.
- Initialization will be copied from that.
- :auxParams - If not nil, should be a map of existing aux-params.
- Initialization will be copied from that.
- :allow-missing - If true, params could contain missing values,
- and the initializer will be called to fill those missing params.
- :force-init - If true, will force re-initialize even if already initialized.
- :allow-extra - Whether allow extra parameters that are not needed by symbol.
- If this is True, no error will be thrown when argParams or auxParams
- contain extra parameters that is not needed by the executor."
- ([mod {:keys [initializer arg-params aux-params allow-missing force-init allow-extra] :as opts
+ "Initialize the parameters and auxiliary states.
+ `opts-map` {
+ `initializer`: Initializer - Default is `uniform`
+ Called to initialize parameters if needed.
+ `arg-params`: map
+ If not nil, should be a map of existing arg-params. Initialization
+ will be copied from that.
+ `aux-params`: map
+ If not nil, should be a map of existing aux-params. Initialization
+ will be copied from that.
+ `allow-missing`: boolean - Default is `false`
+ If true, params could contain missing values, and the initializer will
+ be called to fill those missing params.
+ `force-init`: boolean - Default is `false`
+ If true, will force re-initialize even if already initialized.
+ `allow-extra`: boolean - Default is `false`
+ Whether to allow extra parameters that are not needed by symbol.
+ If this is `true`, no error will be thrown when `arg-params` or
+ `aux-params` contain extra parameters that are not needed by the
+ executor.
+ }
+ Ex:
+ (init-params mod {:initializer (initializer/xavier)})
+ (init-params mod {:force-init true :allow-extra true})"
+ ([mod {:keys [initializer arg-params aux-params allow-missing force-init
+ allow-extra] :as opts
:or {initializer (initializer/uniform 0.01)
allow-missing false
force-init false
@@ -167,17 +201,23 @@
(s/def ::kvstore string?)
(s/def ::reset-optimizer boolean?)
(s/def ::force-init boolean?)
-(s/def ::init-optimizer-opts (s/keys :opt-un [::optimizer ::kvstore ::reset-optimizer ::force-init]))
+(s/def ::init-optimizer-opts
+ (s/keys :opt-un [::optimizer ::kvstore ::reset-optimizer ::force-init]))
(defn init-optimizer
- " Install and initialize optimizers.
- - mod Module
- - options map of
- - kvstore
- - reset-optimizer Default `True`, indicating whether we should set
- `rescaleGrad` & `idx2name` for optimizer according to executorGroup
- - force-init Default `False`, indicating whether we should force
- re-initializing the optimizer in the case an optimizer is already installed."
+ "Install and initialize optimizers.
+ `mod`: Module
+ `opts-map` {
+ `kvstore`: string - Default is \"local\"
+ `optimizer`: Optimizer - Default is `sgd`
+ `reset-optimizer`: boolean - Default is `true`
+ Indicating whether we should set `rescaleGrad` & `idx2name` for
+ optimizer according to executorGroup.
+ `force-init`: boolean - Default is `false`
+ Indicating whether we should force re-initializing the optimizer
+ in the case an optimizer is already installed.
+ }
+ Ex:
+ (init-optimizer mod {:optimizer (optimizer/sgd {:learning-rate 0.1})})"
([mod {:keys [kvstore optimizer reset-optimizer force-init] :as opts
:or {kvstore "local"
optimizer (optimizer/sgd)
@@ -191,8 +231,10 @@
(defn forward
"Forward computation.
- data-batch - input data of form io/data-batch either map or DataBatch
- is-train - Default is nil, which means `is_train` takes the value of `for_training`."
+ `data-batch`: Either map or DataBatch
+ Input data of form `io/data-batch`.
+ `is-train`: Default is nil, which means `is_train` takes the value of
+ `for_training`.
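+ Ex (a sketch, assuming `data-batch` comes from the data iterator):
+ (forward mod data-batch)"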
([mod data-batch is-train]
(util/validate! ::mx-io/data-batch data-batch "Invalid data batch")
(doto mod
@@ -209,9 +251,9 @@
(defn backward
"Backward computation.
- out-grads - Gradient on the outputs to be propagated back.
- This parameter is only needed when bind is called
- on outputs that are not a loss function."
+ `out-grads`: collection of NDArrays
+ Gradient on the outputs to be propagated back. This parameter is only
+ needed when `bind` is called on outputs that are not a loss function.
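+ Ex (a sketch of one training step; assumes `batch` comes from the data
+ iterator, and that `forward`, `backward`, and `update` each return the
+ module so the calls thread):
+ (-> mod
+     (forward batch)
+     (backward)
+     (update))"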
([mod out-grads]
(util/validate! ::out-grads out-grads "Invalid out-grads")
(doto mod
@@ -227,50 +269,48 @@
(.forwardBackward data-batch)))
(defn outputs
- " Get outputs of the previous forward computation.
- In the case when data-parallelism is used,
- the outputs will be collected from multiple devices.
- The results will look like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`,
- those `NDArray` might live on different devices."
+ "Get outputs of the previous forward computation.
+ In the case when data-parallelism is used, the outputs will be collected from
+ multiple devices. The results will look like
+ `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`.
+ Those `NDArray`s might live on different devices.
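+ Ex (assuming a forward pass has already been run on `mod`):
+ (outputs mod)"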
[mod]
(->> (.getOutputs mod)
(util/scala-vector->vec)
(mapv util/scala-vector->vec)))
(defn update
- "Update parameters according to the installed optimizer and the gradients
computed
- in the previous forward-backward batch."
+ "Update parameters according to the installed optimizer and the gradients
+ computed in the previous forward-backward batch."
[mod]
(doto mod
(.update)))
(defn outputs-merged
- " Get outputs of the previous forward computation.
- return In the case when data-parallelism is used,
- the outputs will be merged from multiple devices,
- as they look like from a single executor.
- The results will look like `[out1, out2]`"
+ "Get outputs of the previous forward computation.
+ In the case when data-parallelism is used, the outputs will be merged from
+ multiple devices, as if they came from a single executor.
+ The results will look like `[out1, out2]`.
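+ Ex (assuming a forward pass has already been run on `mod`):
+ (outputs-merged mod)"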
[mod]
(->> (.getOutputsMerged mod)
(util/scala-vector->vec)))
(defn input-grads
- " Get the gradients to the inputs, computed in the previous backward
computation.
- In the case when data-parallelism is used,
- the outputs will be collected from multiple devices.
- The results will look like `[[grad1_dev1, grad1_dev2],
[grad2_dev1, grad2_dev2]]`
- those `NDArray` might live on different devices."
+ "Get the gradients to the inputs, computed in the previous backward
computation.
+ In the case when data-parallelism is used, the outputs will be collected
from
+ multiple devices. The results will look like
+ `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`.
+ Those `NDArray`s might live on different devices."
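+ Ex (assuming a backward pass has already been run on `mod`):
+ (input-grads mod)"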
[mod]
(->> (.getInputGrads mod)
(util/scala-vector->vec)
(mapv util/scala-vector->vec)))
(defn input-grads-merged
- " Get the gradients to the inputs, computed in the previous backward
computation.
- return In the case when data-parallelism is used,
- the outputs will be merged from multiple devices,
- as they look like from a single executor.
- The results will look like `[grad1, grad2]`"
+ "Get the gradients to the inputs, computed in the previous backward
computation.
+ In the case when data-parallelism is used, the outputs will be merged from
+ multiple devices, as they look like from a single executor.
+ The results will look like `[grad1, grad2]`."
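+ Ex (assuming a backward pass has already been run on `mod`):
+ (input-grads-merged mod)"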
[mod]
(->> (.getInputGradsMerged mod)
(util/scala-vector->vec)))
@@ -278,16 +318,25 @@
(s/def ::prefix string?)
(s/def ::epoch int?)
(s/def ::save-opt-states boolean?)
-(s/def ::save-checkpoint-opts (s/keys :req-un [::prefix ::epoch] :opt-un [::save-opt-states ::save-checkpoint]))
+(s/def ::save-checkpoint-opts
+ (s/keys :req-un [::prefix ::epoch]
+ :opt-un [::save-opt-states ::save-checkpoint]))
(defn save-checkpoint
- " Save current progress to checkpoint.
- Use mx.callback.module_checkpoint as epoch_end_callback to save during training.
- - mod Module
- - opt-map with
- :prefix The file prefix to checkpoint to
- :epoch The current epoch number
- :save-opt-states Whether to save optimizer states for continue training "
+ "Save current progress to checkpoint.
+ Use mx.callback.module_checkpoint as epoch_end_callback to save during
+ training.
+ `mod`: Module
+ `opts-map` {
+ `prefix`: string
+ The file prefix to checkpoint to
+ `epoch`: int
+ The current epoch number
+ `save-opt-states`: boolean - Default is `false`
+ Whether to save optimizer states to continue training
+ }
+ Ex:
+ (save-checkpoint mod {:prefix \"saved_model\" :epoch 0 :save-opt-states true})"
([mod {:keys [prefix epoch save-opt-states] :as opts
:or {save-opt-states false}}]
(util/validate! ::save-checkpoint-opts opts "Invalid save checkpoint opts")
@@ -303,24 +352,34 @@
(s/def ::contexts (s/coll-of ::context :kind vector?))
(s/def ::workload-list (s/coll-of number? :kind vector?))
(s/def ::fixed-params-names (s/coll-of string? :kind vector?))
-(s/def ::load-checkpoint-opts (s/keys :req-un [::prefix ::epoch]
- :opt-un [::load-optimizer-states ::data-names ::label-names
- ::contexts ::workload-list ::fixed-param-names]))
+(s/def ::load-checkpoint-opts
+ (s/keys :req-un [::prefix ::epoch]
+ :opt-un [::load-optimizer-states ::data-names ::label-names
+ ::contexts ::workload-list ::fixed-param-names]))
(defn load-checkpoint
"Create a model from previously saved checkpoint.
- - opts map of
- - prefix Path prefix of saved model files. You should have prefix-symbol.json,
- prefix-xxxx.params, and optionally prefix-xxxx.states,
- where xxxx is the epoch number.
- - epoch Epoch to load.
- - load-optimizer-states Whether to load optimizer states.
- Checkpoint needs to have been made with save-optimizer-states=True
- - dataNames Input data names.
- - labelNames Input label names
- - contexts Default is cpu().
- - workload-list Default nil, indicating uniform workload.
- - fixed-param-names Default nil, indicating no network parameters are fixed."
+ `opts-map` {
+ `prefix`: string
+ Path prefix of saved model files. You should have prefix-symbol.json,
+ prefix-xxxx.params, and optionally prefix-xxxx.states, where xxxx is
+ the epoch number.
+ `epoch`: int
+ Epoch to load.
+ `load-optimizer-states`: boolean - Default is false
+ Whether to load optimizer states. Checkpoint needs to have been made
+ with `save-optimizer-states` = `true`.
+ `data-names`: vector of strings - Default is [\"data\"]
+ Input data names.
+ `label-names`: vector of strings - Default is [\"softmax_label\"]
+ Input label names.
+ `contexts`: Context - Default is `context/cpu`
+ `workload-list`: Default nil
+ Indicating uniform workload.
+ `fixed-param-names`: Default nil
+ Indicating no network parameters are fixed.
+ }
+ Ex:
+ (load-checkpoint {:prefix \"my-model\" :epoch 1 :load-optimizer-states true})"
([{:keys [prefix epoch load-optimizer-states data-names label-names contexts
workload-list fixed-param-names] :as opts
:or {load-optimizer-states false
@@ -358,10 +417,10 @@
(util/scala-map->map (.auxParams mod)))
(defn reshape
- " Reshapes the module for new input shapes.
- - mod module
- - data-shapes Typically is `(provide-data data-iter)
- - param label-shapes Typically is `(provide-label data-tier)`. "
+ "Reshapes the module for new input shapes.
+ `mod`: Module
+ `data-shapes`: Typically is `(provide-data data-iter)`
+ `label-shapes`: Typically is `(provide-label data-tier)`"
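+ Ex (a sketch, assuming `data-iter` provides the new shapes):
+ (reshape mod (mx-io/provide-data data-iter) (mx-io/provide-label data-iter))"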
([mod data-shapes label-shapes]
(util/validate! ::data-shapes data-shapes "Invalid data-shapes")
(util/validate! (s/nilable ::label-shapes) label-shapes "Invalid label-shapes")
@@ -376,28 +435,35 @@
([mod data-shapes]
(reshape mod data-shapes nil)))
-(s/def ::set-param-opts (s/keys :opt-un [::arg-params ::aux-params ::allow-missing ::force-init ::allow-extra]))
+(s/def ::set-param-opts
+ (s/keys :opt-un [::arg-params ::aux-params ::allow-missing
+ ::force-init ::allow-extra]))
(defn get-params [mod]
(.getParams mod))
(defn set-params
- " Assign parameter and aux state values.
- - mod module
- - arg-params : map
- map of name to value (`NDArray`) mapping.
- - aux-params : map
- map of name to value (`NDArray`) mapping.
- - allow-missing : bool
- If true, params could contain missing values, and the initializer will be
- called to fill those missing params.
- - force-init : bool
- If true, will force re-initialize even if already initialized.
- - allow-extra : bool
- Whether allow extra parameters that are not needed by symbol.
- If this is True, no error will be thrown when arg-params or aux-params
- contain extra parameters that is not needed by the executor."
- [mod {:keys [arg-params aux-params allow-missing force-init allow-extra] :as opts
+ "Assign parameters and aux state values.
+ `mod`: Module
+ `opts-map` {
+ `arg-params`: map - map of name to value (`NDArray`) mapping.
+ `aux-params`: map - map of name to value (`NDArray`) mapping.
+ `allow-missing`: boolean - Default is `false`
+ If true, params could contain missing values, and the initializer will
+ be called to fill those missing params.
+ `force-init`: boolean - Default is `false`
+ If true, will force re-initialize even if already initialized.
+ `allow-extra`: boolean - Default is `false`
+ Whether to allow extra parameters that are not needed by symbol. If this
+ is `true`, no error will be thrown when arg-params or aux-params
+ contain extra parameters that are not needed by the executor.
+ }
+ Ex:
+ (set-params mod
+  {:arg-params {\"fc_0_weight\" (ndarray/array [0.15 0.2 0.25 0.3] [2 2])}
+   :allow-missing true})"
+ [mod {:keys [arg-params aux-params allow-missing force-init
+ allow-extra] :as opts
:or {allow-missing false force-init true allow-extra false}}]
(util/validate! ::set-param-opts opts "Invalid set-params")
(doto mod
@@ -409,33 +475,32 @@
allow-extra)))
(defn install-monitor
- "Install monitor on all executors"
+ "Install monitor on all executors."
[mod monitor]
(doto mod
(.installMonitor monitor)))
(defn borrow-optimizer
- "Borrow optimizer from a shared module. Used in bucketing, where exactly the
same
- optimizer (esp. kvstore) is used.
- - mod module
- - shared-module"
+ "Borrow optimizer from a shared module. Used in bucketing, where exactly the
+ same optimizer (esp. kvstore) is used.
+ `mod`: Module
+ `shared-module`"
[mod shared-module]
(doto mod
(.borrowOptimizer shared-module)))
(defn save-optimizer-states
- "Save optimizer (updater) state to file
- - mod module
- - fname Path to output states file."
+ "Save optimizer (updater) state to file.
+ `mod`: Module
+ `fname`: string - Path to output states file."
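+ Ex (with a hypothetical file path):
+ (save-optimizer-states mod \"model.states\")"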
[mod fname]
(doto mod
(.saveOptimizerStates mod fname)))
(defn load-optimizer-states
- "Load optimizer (updater) state from file
- - mod module
- - fname Path to input states file.
- "
+ "Load optimizer (updater) state from file.
+ `mod`: Module
+ `fname`: string - Path to input states file."
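+ Ex (with a hypothetical file path):
+ (load-optimizer-states mod \"model.states\")"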
[mod fname]
(doto mod
(.loadOptimzerStates fname)))
@@ -444,10 +509,13 @@
(s/def ::labels (s/coll-of ::ndarray :kind vector?))
(defn update-metric
- "Evaluate and accumulate evaluation metric on outputs of the last forward
computation.
- - mod module
- - eval-metric
- - labels"
+ "Evaluate and accumulate evaluation metric on outputs of the last forward
+ computation.
+ `mod`: module
+ `eval-metric`: EvalMetric
+ `labels`: collection of NDArrays
+ Ex:
+ (update-metric mod (eval-metric/mse) labels)"
[mod eval-metric labels]
(util/validate! ::eval-metric eval-metric "Invalid eval metric")
(util/validate! ::labels labels "Invalid labels")
@@ -458,18 +526,48 @@
(s/def ::validation-metric ::eval-metric)
(s/def ::monitor #(instance? Monitor %))
(s/def ::batch-end-callback #(instance? Callback$Speedometer %))
-(s/def ::fit-params-opts (s/keys :opt-un [::eval-metric ::kvstore ::optimizer ::initializer
- ::arg-params ::aux-params ::allow-missing ::force-rebind
- ::force-init ::begin-epoch ::validation-metric ::monitor
- ::batch-end-callback]))
+(s/def ::fit-params-opts
+ (s/keys :opt-un [::eval-metric ::kvstore ::optimizer ::initializer
+ ::arg-params ::aux-params ::allow-missing ::force-rebind
+ ::force-init ::begin-epoch ::validation-metric ::monitor
+ ::batch-end-callback]))
;; callbacks are not supported for now
(defn fit-params
- "Fit Params"
+ "Initialize FitParams with provided parameters.
+ `eval-metric`: EvalMetric - Default is `accuracy`
+ `kvstore`: String - Default is \"local\"
+ `optimizer`: Optimizer - Default is `sgd`
+ `initializer`: Initializer - Default is `uniform`
+ Called to initialize parameters if needed.
+ `arg-params`: map
+ If not nil, should be a map of existing `arg-params`. Initialization
+ will be copied from that.
+ `aux-params`: map
+ If not nil, should be a map of existing `aux-params`. Initialization
+ will be copied from that.
+ `allow-missing`: boolean - Default is `false`
+ If `true`, params could contain missing values, and the initializer will
+ be called to fill those missing params.
+ `force-rebind`: boolean - Default is `false`
+ This function does nothing if the executors are already bound. But with
+ this `true`, the executors will be forced to rebind.
+ `force-init`: boolean - Default is `false`
+ If `true`, will force re-initialize even if already initialized.
+ `begin-epoch`: int - Default is 0
+ `validation-metric`: EvalMetric
+ `monitor`: Monitor
+ Ex:
+ (fit-params {:force-init true :force-rebind true :allow-missing true})
+ (fit-params
+ {:batch-end-callback (callback/speedometer batch-size 100)
+ :initializer (initializer/xavier)
+ :optimizer (optimizer/sgd {:learning-rate 0.01})
+ :eval-metric (eval-metric/mse)})"
([{:keys [eval-metric kvstore optimizer
initializer arg-params aux-params
- allow-missing force-rebind force-init begin-epoch validation-metric monitor
- batch-end-callback] :as opts
+ allow-missing force-rebind force-init begin-epoch
+ validation-metric monitor batch-end-callback] :as opts
:or {eval-metric (eval-metric/accuracy)
kvstore "local"
optimizer (optimizer/sgd)
@@ -500,25 +598,36 @@
(s/def ::ndarray-iter #(instance? NDArrayIter %))
(s/def ::train-data (s/or :mx-iter ::mx-data-iter :ndarry-iter ::ndarray-iter))
(s/def ::eval-data ::train-data)
-(s/def ::num-epoch int?)
+(s/def ::num-epoch (s/and int? pos?))
(s/def ::fit-params #(instance? FitParams %))
-(s/def ::fit-options (s/keys :req-un [::train-data] :opt-un [::eval-data ::num-epoch ::fit-params]))
+(s/def ::fit-options
+ (s/keys :req-un [::train-data]
+ :opt-un [::eval-data ::num-epoch ::fit-params]))
;;; High Level API
(defn score
- " Run prediction on `eval-data` and evaluate the performance according to
`eval-metric`.
- - mod module
- - option map with
- :eval-data : DataIter
- :eval-metric : EvalMetric
- :num-batch Number of batches to run. Default is `Integer.MAX_VALUE`,
- indicating run until the `DataIter` finishes.
- :batch-end-callback -not supported yet
- :reset Default `True`,
- indicating whether we should reset `eval-data` before starting evaluating.
- :epoch Default 0. For compatibility, this will be passed to callbacks (if any).
- During training, this will correspond to the training epoch number."
+ "Run prediction on `eval-data` and evaluate the performance according to
+ `eval-metric`.
+ `mod`: module
+ `opts-map` {
+ `eval-data`: DataIter
+ `eval-metric`: EvalMetric
+ `num-batch`: int - Default is `Integer.MAX_VALUE`
+ Number of batches to run. The default indicates running until the
+ `DataIter` finishes.
+ `batch-end-callback`: not supported yet.
+ `reset`: boolean - Default is `true`
+ Indicating whether we should reset `eval-data` before starting
+ evaluating.
+ `epoch`: int - Default is 0
+ For compatibility, this will be passed to callbacks (if any). During
+ training, this will correspond to the training epoch number.
+ }
+ Ex:
+ (score mod {:eval-data data-iter :eval-metric (eval-metric/accuracy)})
+ (score mod {:eval-data data-iter
+ :eval-metric (eval-metric/mse) :num-batch 10})"
[mod {:keys [eval-data eval-metric num-batch reset epoch] :as opts
:or {num-batch Integer/MAX_VALUE
reset true
@@ -537,15 +646,30 @@
(defn fit
"Train the module parameters.
- - mod module
- - train-data (data-iterator)
- - eval-data (data-iterator)If not nil, will be used as validation set and evaluate
- the performance after each epoch.
- - num-epoch Number of epochs to run training.
- - f-params Extra parameters for training (See fit-params)."
+ `mod`: Module
+ `opts-map` {
+ `train-data`: DataIter
+ `eval-data`: DataIter
+ If not nil, will be used as validation set and evaluate the
+ performance after each epoch.
+ `num-epoch`: int
+ Number of epochs to run training.
+ `fit-params`: FitParams
+ Extra parameters for training (see fit-params).
+ }
+ Ex:
+ (fit mod {:train-data train-iter :eval-data test-iter :num-epoch 100})
+ (fit mod {:train-data train-iter
+ :eval-data test-iter
+ :num-epoch 5
+ :fit-params
+ (fit-params {:batch-end-callback (callback/speedometer 128 100)
+ :initializer (initializer/xavier)
+ :optimizer (optimizer/sgd {:learning-rate 0.01})
+ :eval-metric (eval-metric/mse)}))"
[mod {:keys [train-data eval-data num-epoch fit-params] :as opts
- `:or {num-epoch 1
- fit-params (new FitParams)}}]
+ :or {num-epoch 1
+ fit-params (new FitParams)}}]
(util/validate! ::fit-options opts "Invalid options for fit")
(doto mod
(.fit
@@ -557,12 +681,13 @@
(s/def ::eval-data ::train-data)
(s/def ::num-batch integer?)
(s/def ::reset boolean?)
-(s/def ::predict-opts (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset]))
+(s/def ::predict-opts
+ (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset]))
(defn predict-batch
- "Run the predication on a data batch
- - mod module
- - data-batch data-batch"
+ "Run the predication on a data batch.
+ `mod`: Module
+ `data-batch`: data-batch"
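+ Ex (assuming `batch` is an `io/data-batch` map or a `DataBatch`):
+ (predict-batch mod batch)"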
[mod data-batch]
(util/validate! ::mx-io/data-batch data-batch "Invalid data batch")
(util/coerce-return (.predict mod (if (map? data-batch)
@@ -571,41 +696,60 @@
(defn predict
"Run prediction and collect the outputs.
- - mod module
- - option map with
- - :eval-data
- - :num-batch Default is -1, indicating running all the batches in the data iterator.
- - :reset Default is `True`, indicating whether we should reset the data iter before start
- doing prediction.
- The return value will be a vector of NDArrays `[out1, out2, out3]`.
- Where each element is concatenation of the outputs for all the mini-batches."
+ `mod`: Module
+ `opts-map` {
+ `eval-data`: DataIter
+ `num-batch`: int - Default is `-1`
+ Indicating running all the batches in the data iterator.
+ `reset`: boolean - Default is `true`
+ Indicating whether we should reset the data iter before start doing
+ prediction.
+ }
+ returns: vector of NDArrays `[out1, out2, out3]` where each element is the
+ concatenation of the outputs for all the mini-batches.
+ Ex:
+ (predict mod {:eval-data test-iter})
+ (predict mod {:eval-data test-iter :num-batch 10 :reset false})"
[mod {:keys [eval-data num-batch reset] :as opts
:or {num-batch -1
reset true}}]
(util/validate! ::predict-opts opts "Invalid opts for predict")
(util/scala-vector->vec (.predict mod eval-data (int num-batch) reset)))
-(s/def ::predict-every-batch-opts (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset]))
+(s/def ::predict-every-batch-opts
+ (s/keys :req-un [::eval-data] :opt-un [::num-batch ::reset]))
(defn predict-every-batch
- " Run prediction and collect the outputs.
- - module
- - option map with
- :eval-data
- :num-batch Default is -1, indicating running all the batches in the data iterator.
- :reset Default is `True`, indicating whether we should reset the data iter before start
- doing prediction.
- The return value will be a nested list like
- [[out1_batch1, out2_batch1, ...], [out1_batch2, out2_batch2, ...]]`
- This mode is useful because in some cases (e.g. bucketing),
- the module does not necessarily produce the same number of outputs."
+ "Run prediction and collect the outputs.
+ `mod`: Module
+ `opts-map` {
+ `eval-data`: DataIter
+ `num-batch`: int - Default is `-1`
+ Indicating running all the batches in the data iterator.
+ `reset`: boolean - Default is `true`
+ Indicating whether we should reset the data iter before start doing
+ prediction.
+ }
+ returns: nested list like this
+ `[[out1_batch1, out2_batch1, ...], [out1_batch2, out2_batch2, ...]]`
+
+ Note: This mode is useful because in some cases (e.g. bucketing), the module
+ does not necessarily produce the same number of outputs.
+ Ex:
+ (predict-every-batch mod {:eval-data test-iter})"
[mod {:keys [eval-data num-batch reset] :as opts
:or {num-batch -1
reset true}}]
- (util/validate! ::predict-every-batch-opts opts "Invalid opts for predict-every-batch")
- (mapv util/scala-vector->vec (util/scala-vector->vec (.predictEveryBatch mod eval-data (int num-batch) reset))))
-
-(s/def ::score-opts (s/keys :req-un [::eval-data ::eval-metric] :opt-un [::num-batch ::reset ::epoch]))
+ (util/validate! ::predict-every-batch-opts
+ opts
+ "Invalid opts for predict-every-batch")
+ (mapv util/scala-vector->vec
+ (util/scala-vector->vec
+ (.predictEveryBatch mod eval-data (int num-batch) reset))))
+
+(s/def ::score-opts
+ (s/keys :req-un [::eval-data ::eval-metric]
+ :opt-un [::num-batch ::reset ::epoch]))
(defn exec-group [mod]
(.execGroup mod))
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 7ee25d4..9dc6c8f 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -250,7 +250,7 @@
shape)))
(defn map->scala-tuple-seq
- "* Convert a map to a scala-Seq of scala-Tubple.
+ "* Convert a map to a scala-Seq of scala-Tuple.
* Should also work if a seq of seq of 2 things passed.
* Otherwise passed through unchanged."
[map-or-tuple-seq]