nswamy closed pull request #13164: [MXNET-1207][Cherry-pick] use ResourceScope
in Model/Trainer/FeedForward.scala (#12882)
URL: https://github.com/apache/incubator-mxnet/pull/13164
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
index 00a1450089f..2ed9d8cfbb8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
@@ -17,9 +17,10 @@
package org.apache.mxnet
+import org.apache.mxnet.Base.CPtrAddress
import org.apache.mxnet.io.NDArrayIter
import org.apache.mxnet.optimizer.SGD
-import org.slf4j.{LoggerFactory, Logger}
+import org.slf4j.{Logger, LoggerFactory}
import scala.collection.mutable.ListBuffer
@@ -55,7 +56,7 @@ class FeedForward private(
argParams: Map[String, NDArray],
auxParams: Map[String, NDArray],
private val allowExtraParams: Boolean,
- val beginEpoch: Int) {
+ val beginEpoch: Int) extends NativeResource {
val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
private var argumentChecked = false
@@ -126,6 +127,8 @@ class FeedForward private(
}
// Initialize weight parameters and auxiliary states
+ // The NDArrays associated with _argParams and _auxParams are not
disposed; instead
+ // they are passed to an outer scope if available.
private def initParams(inputShapes: Map[String, Shape], overwrite: Boolean =
false)
: (IndexedSeq[String], IndexedSeq[String], IndexedSeq[String]) = {
val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
@@ -137,16 +140,26 @@ class FeedForward private(
val paramNameShapes = (argNames zip argShapes).filter { case (name, _) =>
paramNames.contains(name)
}
- val argParams = paramNameShapes.map { case (name, shape) =>
- (name, NDArray.zeros(shape))
+ val argParams = paramNameShapes.map { case (name, shape) => {
+ val param = NDArray.zeros(shape)
+ val curScope = ResourceScope.getCurrentScope()
+ if (curScope.isDefined) curScope.get.moveToOuterScope(param)
+ (name, param)
+ }
}.toMap
- val auxParams = (auxNames zip auxShapes).map { case (name, shape) =>
- (name, NDArray.zeros(shape))
+
+ val auxParams = (auxNames zip auxShapes).map { case (name, shape) => {
+ val param = NDArray.zeros(shape)
+ val curScope = ResourceScope.getCurrentScope()
+ if (curScope.isDefined) curScope.get.moveToOuterScope(param)
+ (name, param)
+ }
}.toMap
for ((k, v) <- argParams) {
if (_argParams != null && _argParams.contains(k) && (!overwrite)) {
argParams(k).set(_argParams(k))
+
} else {
initializer(k, v)
}
@@ -277,13 +290,15 @@ class FeedForward private(
def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
kvStoreType: String,
epochEndCallback: EpochEndCallback, batchEndCallback:
BatchEndCallback,
logger: Logger, workLoadList: Seq[Float]): Unit = {
- // init params first to allow kv store use _argParams to decide its type
- initSymbolParams(trainData)
- // create kvstore
- val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType,
ctx.length, _argParams)
- fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
- epochEndCallback, batchEndCallback, logger, workLoadList)
- kvStore.foreach(_.dispose())
+ ResourceScope.using() {
+ // init params first to allow kv store use _argParams to decide its type
+ initSymbolParams(trainData)
+ // create kvstore
+ val (kvStore, updateOnKVStore) = Model.createKVStore(kvStoreType,
ctx.length, _argParams)
+ fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
+ epochEndCallback, batchEndCallback, logger, workLoadList)
+// kvStore.foreach(_.dispose())
+ }
}
def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
@@ -313,11 +328,13 @@ class FeedForward private(
batchEndCallback: BatchEndCallback, logger: Logger,
workLoadList: Seq[Float]): Unit = {
// init params first to allow kv store use _argParams to decide its type
- initSymbolParams(trainData)
- // create kvstore
- val (kvStore, updateOnKVStore) = Model.createKVStore(kv)
- fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
- epochEndCallback, batchEndCallback, logger, workLoadList)
+ ResourceScope.using() {
+ initSymbolParams(trainData)
+ // create kvstore
+ val (kvStore, updateOnKVStore) = Model.createKVStore(kv)
+ fit(trainData, evalData, evalMetric, kvStore, updateOnKVStore,
+ epochEndCallback, batchEndCallback, logger, workLoadList)
+ }
}
def fit(trainData: DataIter, evalData: DataIter, evalMetric: EvalMetric,
@@ -352,44 +369,49 @@ class FeedForward private(
batchEndCallback: BatchEndCallback = null, logger: Logger =
FeedForward.logger,
workLoadList: Seq[Float] = null): Unit = {
require(evalMetric != null, "evalMetric cannot be null")
- val (argNames, paramNames, auxNames) = initSymbolParams(trainData)
-
- // init optimizer
- val batchSizeMultiplier = kvStore.map { kv =>
- if (kv.`type` == "dist_sync") {
- kv.numWorkers
- } else {
- 1
- }
- }
- val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
- this.optimizer.setArgNames(argNames)
- this.optimizer.setRescaleGrad(1f / batchSize)
- this.optimizer.setSymbol(this.symbol)
- val paramIdx2Name =
- if (updateOnKVStore) {
- paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
- } else {
- paramNames.zipWithIndex.flatMap { case (name, idx) =>
- (0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap
- }.toMap
+ // TODO: https://issues.apache.org/jira/browse/MXNET-1171
+ // this leaks memory, initSymbolParams->initParams is already called which
allocates
+ // NDArray in argParams, auxParams and here we are overwriting it by
calling again.
+ // PhantomRef should take care of releasing this when GC is called,
however we have to
+ // wait for the GC call to happen.
+ val (argNames, paramNames, auxNames) = initSymbolParams(trainData)
+
+ // init optimizer
+ val batchSizeMultiplier = kvStore.map { kv =>
+ if (kv.`type` == "dist_sync") {
+ kv.numWorkers
+ } else {
+ 1
+ }
}
- this.optimizer.setIdx2Name(paramIdx2Name)
-
- logger.debug("Start training on multi-device")
- Model.trainMultiDevice(
- symbol, ctx, argNames, paramNames, auxNames,
- _argParams, _auxParams,
- this.beginEpoch, this.numEpoch,
- this.epochSize, this.optimizer,
- kvStore, updateOnKVStore,
- trainData = trainData, evalData = Option(evalData),
- evalMetric = evalMetric,
- epochEndCallback = Option(epochEndCallback),
- batchEndCallback = Option(batchEndCallback),
- workLoadList = workLoadList,
- monitor = monitor,
- symGen = symGen)
+ val batchSize = trainData.batchSize * batchSizeMultiplier.getOrElse(1)
+ this.optimizer.setArgNames(argNames)
+ this.optimizer.setRescaleGrad(1f / batchSize)
+ this.optimizer.setSymbol(this.symbol)
+ val paramIdx2Name =
+ if (updateOnKVStore) {
+ paramNames.zipWithIndex.map { case (name, idx) => idx -> name }.toMap
+ } else {
+ paramNames.zipWithIndex.flatMap { case (name, idx) =>
+ (0 until ctx.length).map(k => (idx * ctx.length + k) -> name).toMap
+ }.toMap
+ }
+ this.optimizer.setIdx2Name(paramIdx2Name)
+
+ logger.debug("Start training on multi-device")
+ Model.trainMultiDevice(
+ symbol, ctx, argNames, paramNames, auxNames,
+ _argParams, _auxParams,
+ this.beginEpoch, this.numEpoch,
+ this.epochSize, this.optimizer,
+ kvStore, updateOnKVStore,
+ trainData = trainData, evalData = Option(evalData),
+ evalMetric = evalMetric,
+ epochEndCallback = Option(epochEndCallback),
+ batchEndCallback = Option(batchEndCallback),
+ workLoadList = workLoadList,
+ monitor = monitor,
+ symGen = symGen)
}
/**
@@ -416,9 +438,29 @@ class FeedForward private(
def serialize(): Array[Byte] = {
Model.serialize(this.symbol, getArgParams, getAuxParams)
}
+
+ // hack to make FeedForward.scala work with ResourceScope and
+ // automatically release _argParams and _auxParams
+ override def nativeAddress: CPtrAddress = hashCode()
+
+ override def nativeDeAllocator: CPtrAddress => Int =
FeedForward.doNothingDeAllocator
+
+ override val ref: NativeResourceRef = super.register()
+
+ override val bytesAllocated: Long = 0L
+
+ override def dispose(): Unit = {
+ if (!super.isDisposed) {
+ _argParams.foreach { case (_, param) => param.dispose() }
+ _auxParams.foreach { case (_, param) => param.dispose() }
+ }
+ }
}
object FeedForward {
+
+ private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0
+
private val logger: Logger = LoggerFactory.getLogger(classOf[FeedForward])
// Check if name is a data argument.
private def isDataArg(name: String): Boolean = {
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
index 48d4b0c193b..1806b865337 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
@@ -46,7 +46,8 @@ private[mxnet] trait NativeResource
*/
def nativeDeAllocator: (CPtrAddress => Int)
- /** Call NativeResource.register to get the reference
+ /**
+ * Call NativeResource.register to get the reference
*/
val ref: NativeResourceRef
@@ -56,6 +57,7 @@ private[mxnet] trait NativeResource
// intentionally making it a val, so it gets evaluated when defined
val bytesAllocated: Long
+ // this is set and unset by [[ResourceScope.add]] and
[[ResourceScope.remove]]
private[mxnet] var scope: Option[ResourceScope] = None
@volatile private var disposed = false
@@ -69,11 +71,11 @@ private[mxnet] trait NativeResource
* using PhantomReference
*/
def register(): NativeResourceRef = {
- scope = ResourceScope.getCurrentScope()
+ val scope = ResourceScope.getCurrentScope()
if (scope.isDefined) scope.get.add(this)
NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated)
- // register with PhantomRef tracking to release incase the objects go
+ // register with PhantomRef tracking to release in case the objects go
// out of reference within scope but are held for long time
NativeResourceRef.register(this, nativeDeAllocator)
}
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
index 1c5782d873a..30fe1473a2c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
@@ -58,6 +58,7 @@ class ResourceScope extends AutoCloseable {
*/
def add(resource: NativeResource): Unit = {
resourceQ.+=(resource)
+ resource.scope = Some(this)
}
/**
@@ -67,7 +68,21 @@ class ResourceScope extends AutoCloseable {
*/
def remove(resource: NativeResource): Unit = {
resourceQ.-=(resource)
+ resource.scope = None
}
+
+ /**
+ * Removes from current Scope and moves to outer scope if it exists
+ * @param resource Resource to be moved to an outer scope
+ */
+ def moveToOuterScope(resource: NativeResource): Unit = {
+ val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
+ if (prevScope.isDefined) {
+ this.remove(resource)
+ prevScope.get.add(resource)
+ } else this.remove(resource)
+ }
+
}
object ResourceScope {
@@ -92,32 +107,22 @@ object ResourceScope {
val curScope = if (scope != null) scope else new ResourceScope()
- val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
-
@inline def resourceInGeneric(g: scala.collection.Iterable[_]) = {
g.foreach( n =>
n match {
case nRes: NativeResource => {
- removeAndAddToPrevScope(nRes)
+ curScope.moveToOuterScope(nRes)
}
case kv: scala.Tuple2[_, _] => {
- if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+ if (kv._1.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
kv._1.asInstanceOf[NativeResource])
- if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+ if (kv._2.isInstanceOf[NativeResource]) curScope.moveToOuterScope(
kv._2.asInstanceOf[NativeResource])
}
}
)
}
- @inline def removeAndAddToPrevScope(r: NativeResource) = {
- curScope.remove(r)
- if (prevScope.isDefined) {
- prevScope.get.add(r)
- r.scope = prevScope
- }
- }
-
@inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit =
{
if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed)
}
@@ -129,8 +134,8 @@ object ResourceScope {
ret match {
// don't de-allocate if returning any collection that contains
NativeResource.
case resInGeneric: scala.collection.Iterable[_] =>
resourceInGeneric(resInGeneric)
- case nRes: NativeResource => removeAndAddToPrevScope(nRes)
- case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd =>
removeAndAddToPrevScope(nd) )
+ case nRes: NativeResource => curScope.moveToOuterScope(nRes)
+ case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd =>
curScope.moveToOuterScope(nd) )
case _ => // do nothing
}
ret
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
index 608e191e019..f6c283c3dfb 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
@@ -43,7 +43,7 @@ object TrainModel {
*/
def test(model: String, dataPath: String, numExamples: Int = 60000,
numEpochs: Int = 10, benchmark: Boolean = false): Float = {
- NDArrayCollector.auto().withScope {
+ ResourceScope.using() {
val devs = Array(Context.cpu(0))
val envs: mutable.Map[String, String] = mutable.HashMap.empty[String,
String]
val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath,
@@ -110,44 +110,46 @@ object TrainModel {
val inst = new TrainModel
val parser: CmdLineParser = new CmdLineParser(inst)
try {
- parser.parseArgument(args.toList.asJava)
-
- val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
- else inst.dataDir
-
- val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network,
dataPath,
- inst.numLayers, inst.numExamples, inst.benchmark)
-
- val devs =
- if (inst.gpus != null) inst.gpus.split(',').map(id =>
Context.gpu(id.trim.toInt))
- else if (inst.cpus != null) inst.cpus.split(',').map(id =>
Context.cpu(id.trim.toInt))
- else Array(Context.cpu(0))
-
- val envs: mutable.Map[String, String] = mutable.HashMap.empty[String,
String]
- envs.put("DMLC_ROLE", inst.role)
- if (inst.schedulerHost != null) {
- require(inst.schedulerPort > 0, "scheduler port not specified")
- envs.put("DMLC_PS_ROOT_URI", inst.schedulerHost)
- envs.put("DMLC_PS_ROOT_PORT", inst.schedulerPort.toString)
- require(inst.numWorker > 0, "Num of workers must > 0")
- envs.put("DMLC_NUM_WORKER", inst.numWorker.toString)
- require(inst.numServer > 0, "Num of servers must > 0")
- envs.put("DMLC_NUM_SERVER", inst.numServer.toString)
- logger.info("Init PS environments")
- KVStoreServer.init(envs.toMap)
- }
-
- if (inst.role != "worker") {
- logger.info("Start KVStoreServer for scheduler & servers")
- KVStoreServer.start()
- } else {
- Trainer.fit(batchSize = inst.batchSize, numExamples =
inst.numExamples, devs = devs,
- network = net, dataLoader = dataLoader,
- kvStore = inst.kvStore, numEpochs = inst.numEpochs,
- modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
- lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch =
inst.lrFactorEpoch,
- monitorSize = inst.monitor)
- logger.info("Finish fit ...")
+ ResourceScope.using() {
+ parser.parseArgument(args.toList.asJava)
+
+ val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
+ else inst.dataDir
+
+ val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network,
dataPath,
+ inst.numLayers, inst.numExamples, inst.benchmark)
+
+ val devs =
+ if (inst.gpus != null) inst.gpus.split(',').map(id =>
Context.gpu(id.trim.toInt))
+ else if (inst.cpus != null) inst.cpus.split(',').map(id =>
Context.cpu(id.trim.toInt))
+ else Array(Context.cpu(0))
+
+ val envs: mutable.Map[String, String] = mutable.HashMap.empty[String,
String]
+ envs.put("DMLC_ROLE", inst.role)
+ if (inst.schedulerHost != null) {
+ require(inst.schedulerPort > 0, "scheduler port not specified")
+ envs.put("DMLC_PS_ROOT_URI", inst.schedulerHost)
+ envs.put("DMLC_PS_ROOT_PORT", inst.schedulerPort.toString)
+ require(inst.numWorker > 0, "Num of workers must > 0")
+ envs.put("DMLC_NUM_WORKER", inst.numWorker.toString)
+ require(inst.numServer > 0, "Num of servers must > 0")
+ envs.put("DMLC_NUM_SERVER", inst.numServer.toString)
+ logger.info("Init PS environments")
+ KVStoreServer.init(envs.toMap)
+ }
+
+ if (inst.role != "worker") {
+ logger.info("Start KVStoreServer for scheduler & servers")
+ KVStoreServer.start()
+ } else {
+ Trainer.fit(batchSize = inst.batchSize, numExamples =
inst.numExamples, devs = devs,
+ network = net, dataLoader = dataLoader,
+ kvStore = inst.kvStore, numEpochs = inst.numEpochs,
+ modelPrefix = inst.modelPrefix, loadEpoch = inst.loadEpoch,
+ lr = inst.lr, lrFactor = inst.lrFactor, lrFactorEpoch =
inst.lrFactorEpoch,
+ monitorSize = inst.monitor)
+ logger.info("Finish fit ...")
+ }
}
} catch {
case ex: Exception => {
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
index 9a54e58b653..276816cf8c8 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/util/Trainer.scala
@@ -50,83 +50,84 @@ object Trainer {
lr: Float = 0.1f, lrFactor: Float = 1f, lrFactorEpoch: Float = 1f,
clipGradient: Float = 0f, monitorSize: Int = -1): Accuracy = {
// kvstore
- var kv = KVStore.create(kvStore)
+ ResourceScope.using() {
+ var kv = KVStore.create(kvStore)
- // load model
- val modelPrefixWithRank =
- if (modelPrefix == null) null
- else modelPrefix + s"-${kv.rank}"
+ // load model
+ val modelPrefixWithRank =
+ if (modelPrefix == null) null
+ else modelPrefix + s"-${kv.rank}"
- val (argParams, auxParams, beginEpoch) =
- if (loadEpoch >= 0) {
- require(modelPrefixWithRank != null)
- val tmp = FeedForward.load(modelPrefix, loadEpoch)
- (tmp.getArgParams, tmp.getAuxParams, loadEpoch)
- } else {
- (null, null, 0)
- }
+ val (argParams, auxParams, beginEpoch) =
+ if (loadEpoch >= 0) {
+ require(modelPrefixWithRank != null)
+ val tmp = FeedForward.load(modelPrefix, loadEpoch)
+ (tmp.getArgParams, tmp.getAuxParams, loadEpoch)
+ } else {
+ (null, null, 0)
+ }
- // save model
- val checkpoint: EpochEndCallback =
- if (modelPrefix == null) null
- else new EpochEndCallback {
- override def invoke(epoch: Int, symbol: Symbol,
- argParams: Map[String, NDArray],
- auxStates: Map[String, NDArray]): Unit = {
- Model.saveCheckpoint(modelPrefix, epoch + 1, symbol, argParams,
auxParams)
+ // save model
+ val checkpoint: EpochEndCallback =
+ if (modelPrefix == null) null
+ else new EpochEndCallback {
+ override def invoke(epoch: Int, symbol: Symbol,
+ argParams: Map[String, NDArray],
+ auxStates: Map[String, NDArray]): Unit = {
+ Model.saveCheckpoint(modelPrefix, epoch + 1, symbol, argParams,
auxParams)
+ }
}
- }
- // data
- val (train, validation) = dataLoader(batchSize, kv)
+ // data
+ val (train, validation) = dataLoader(batchSize, kv)
- // train
- val epochSize =
- if (kvStore == "dist_sync") numExamples / batchSize / kv.numWorkers
- else numExamples / batchSize
+ // train
+ val epochSize =
+ if (kvStore == "dist_sync") numExamples / batchSize / kv.numWorkers
+ else numExamples / batchSize
- val lrScheduler =
- if (lrFactor < 1f) {
- new FactorScheduler(step = Math.max((epochSize * lrFactorEpoch).toInt,
1),
- factor = lrFactor)
- } else {
- null
- }
- val optimizer: Optimizer = new SGD(learningRate = lr,
- lrScheduler = lrScheduler, clipGradient = clipGradient,
- momentum = 0.9f, wd = 0.00001f)
+ val lrScheduler =
+ if (lrFactor < 1f) {
+ new FactorScheduler(step = Math.max((epochSize *
lrFactorEpoch).toInt, 1),
+ factor = lrFactor)
+ } else {
+ null
+ }
+ val optimizer: Optimizer = new SGD(learningRate = lr,
+ lrScheduler = lrScheduler, clipGradient = clipGradient,
+ momentum = 0.9f, wd = 0.00001f)
- // disable kvstore for single device
- if (kv.`type`.contains("local") && (devs.length == 1 || devs(0).deviceType
!= "gpu")) {
- kv.dispose()
- kv = null
- }
+ // disable kvstore for single device
+ if (kv.`type`.contains("local") && (devs.length == 1 ||
devs(0).deviceType != "gpu")) {
+ kv.dispose()
+ kv = null
+ }
- val model = new FeedForward(ctx = devs,
- symbol = network,
- numEpoch = numEpochs,
- optimizer = optimizer,
- initializer = new Xavier(factorType = "in",
magnitude = 2.34f),
- argParams = argParams,
- auxParams = auxParams,
- beginEpoch = beginEpoch,
- epochSize = epochSize)
- if (monitorSize > 0) {
- model.setMonitor(new Monitor(monitorSize))
- }
- val acc = new Accuracy()
- model.fit(trainData = train,
- evalData = validation,
- evalMetric = acc,
- kvStore = kv,
- batchEndCallback = new Speedometer(batchSize, 50),
- epochEndCallback = checkpoint)
- if (kv != null) {
- kv.dispose()
+ val model = new FeedForward(ctx = devs,
+ symbol = network,
+ numEpoch = numEpochs,
+ optimizer = optimizer,
+ initializer = new Xavier(factorType = "in",
magnitude = 2.34f),
+ argParams = argParams,
+ auxParams = auxParams,
+ beginEpoch = beginEpoch,
+ epochSize = epochSize)
+ if (monitorSize > 0) {
+ model.setMonitor(new Monitor(monitorSize))
+ }
+ val acc = new Accuracy()
+ model.fit(trainData = train,
+ evalData = validation,
+ evalMetric = acc,
+ kvStore = kv,
+ batchEndCallback = new Speedometer(batchSize, 50),
+ epochEndCallback = checkpoint)
+ if (kv != null) {
+ kv.dispose()
+ }
+ acc
}
- acc
}
-
// scalastyle:on parameterNum
}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services