lanking520 closed pull request #11621: [MXNET-531] NeuralStyle Example for Scala
URL: https://github.com/apache/incubator-mxnet/pull/11621
This is a PR merged from a forked repository. Because GitHub hides the
original diff of a foreign (forked) pull request once it is merged, the
diff is reproduced below for the sake of provenance:
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/ModelVgg19.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/ModelVgg19.scala
index 4d9aa35d21f..ca4c242ab1c 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/ModelVgg19.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/ModelVgg19.scala
@@ -17,92 +17,73 @@
package org.apache.mxnetexamples.neuralstyle
-import org.apache.mxnet.Context
-import org.apache.mxnet.Executor
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.Shape
+import org.apache.mxnet.{Context, Executor, NDArray, Shape, Symbol}
/**
- * Definition for the neuralstyle network and initialize it with pretrained
weight
- * @author Depeng Liang
- */
+ * Definition for the neuralstyle network and initialize it with pretrained
weight
+ */
object ModelVgg19 {
case class ConvExecutor(executor: Executor, data: NDArray, dataGrad: NDArray,
- style: Array[NDArray], content: NDArray, argDict:
Map[String, NDArray])
+ style: Array[NDArray], content: NDArray, argDict:
Map[String, NDArray])
+
+ def ConvRelu(data : Symbol, convName : String, reluName : String,
+ numFilter : Int, kernel : (Int, Int) = (3, 3),
+ stride : (Int, Int) = (1, 1)) : Symbol = {
+ val conv = Symbol.api.Convolution(data = Some(data), num_filter =
numFilter,
+ pad = Some(Shape(1, 1)), kernel = Shape(kernel._1, kernel._2),
+ stride = Some(Shape(stride._1, stride._2)), no_bias = Some(false),
+ workspace = Some(1024), name = convName)
+ val relu = Symbol.api.relu(data = Some(conv), name = reluName)
+ conv.dispose()
+ relu
+ }
def getSymbol: (Symbol, Symbol) = {
+ getVggSymbol()
+ }
+
+ def getVggSymbol(prefix: String = "", contentOnly: Boolean = false):
(Symbol, Symbol) = {
// declare symbol
- val data = Symbol.Variable("data")
- val conv1_1 = Symbol.Convolution("conv1_1")()(Map("data" -> data ,
"num_filter" -> 64,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu1_1 = Symbol.Activation("relu1_1")()(Map("data" -> conv1_1 ,
"act_type" -> "relu"))
- val conv1_2 = Symbol.Convolution("conv1_2")()(Map("data" -> relu1_1 ,
"num_filter" -> 64,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu1_2 = Symbol.Activation("relu1_2")()(Map("data" -> conv1_2 ,
"act_type" -> "relu"))
- val pool1 = Symbol.Pooling("pool1")()(Map("data" -> relu1_2 , "pad" ->
"(0,0)",
- "kernel" -> "(2,2)", "stride" -> "(2,2)",
"pool_type" -> "avg"))
- val conv2_1 = Symbol.Convolution("conv2_1")()(Map("data" -> pool1 ,
"num_filter" -> 128,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu2_1 = Symbol.Activation("relu2_1")()(Map("data" -> conv2_1 ,
"act_type" -> "relu"))
- val conv2_2 = Symbol.Convolution("conv2_2")()(Map("data" -> relu2_1 ,
"num_filter" -> 128,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu2_2 = Symbol.Activation("relu2_2")()(Map("data" -> conv2_2 ,
"act_type" -> "relu"))
- val pool2 = Symbol.Pooling("pool2")()(Map("data" -> relu2_2 , "pad" ->
"(0,0)",
- "kernel" -> "(2,2)", "stride" -> "(2,2)",
"pool_type" -> "avg"))
- val conv3_1 = Symbol.Convolution("conv3_1")()(Map("data" -> pool2 ,
"num_filter" -> 256,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu3_1 = Symbol.Activation("relu3_1")()(Map("data" -> conv3_1 ,
"act_type" -> "relu"))
- val conv3_2 = Symbol.Convolution("conv3_2")()(Map("data" -> relu3_1 ,
"num_filter" -> 256,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu3_2 = Symbol.Activation("'relu3_2")()(Map("data" -> conv3_2 ,
"act_type" -> "relu"))
- val conv3_3 = Symbol.Convolution("conv3_3")()(Map("data" -> relu3_2 ,
"num_filter" -> 256,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu3_3 = Symbol.Activation("relu3_3")()(Map("data" -> conv3_3 ,
"act_type" -> "relu"))
- val conv3_4 = Symbol.Convolution("conv3_4")()(Map("data" -> relu3_3 ,
"num_filter" -> 256,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu3_4 = Symbol.Activation("relu3_4")()(Map("data" -> conv3_4 ,
"act_type" -> "relu"))
- val pool3 = Symbol.Pooling("pool3")()(Map("data" -> relu3_4 , "pad" ->
"(0,0)",
- "kernel" -> "(2,2)", "stride" -> "(2,2)",
"pool_type" -> "avg"))
- val conv4_1 = Symbol.Convolution("conv4_1")()(Map("data" -> pool3 ,
"num_filter" -> 512,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu4_1 = Symbol.Activation("relu4_1")()(Map("data" -> conv4_1 ,
"act_type" -> "relu"))
- val conv4_2 = Symbol.Convolution("conv4_2")()(Map("data" -> relu4_1 ,
"num_filter" -> 512,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu4_2 = Symbol.Activation("relu4_2")()(Map("data" -> conv4_2 ,
"act_type" -> "relu"))
- val conv4_3 = Symbol.Convolution("conv4_3")()(Map("data" -> relu4_2 ,
"num_filter" -> 512,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu4_3 = Symbol.Activation("relu4_3")()(Map("data" -> conv4_3 ,
"act_type" -> "relu"))
- val conv4_4 = Symbol.Convolution("conv4_4")()(Map("data" -> relu4_3 ,
"num_filter" -> 512,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu4_4 = Symbol.Activation("relu4_4")()(Map("data" -> conv4_4 ,
"act_type" -> "relu"))
- val pool4 = Symbol.Pooling("pool4")()(Map("data" -> relu4_4 , "pad" ->
"(0,0)",
- "kernel" -> "(2,2)", "stride" -> "(2,2)",
"pool_type" -> "avg"))
- val conv5_1 = Symbol.Convolution("conv5_1")()(Map("data" -> pool4 ,
"num_filter" -> 512,
- "pad" -> "(1,1)", "kernel" -> "(3,3)",
"stride" -> "(1,1)",
- "no_bias" -> false, "workspace" ->
1024))
- val relu5_1 = Symbol.Activation("relu5_1")()(Map("data" -> conv5_1 ,
"act_type" -> "relu"))
+ val data = Symbol.Variable(s"${prefix}data")
+
+ val relu1_1 = ConvRelu(data, s"${prefix}conv1_1", s"${prefix}relu1_1", 64)
+ val relu1_2 = ConvRelu(relu1_1, s"${prefix}conv1_2", s"${prefix}relu1_2",
64)
+ val pool1 = Symbol.api.Pooling(data = Some(relu1_2), pad = Some(Shape(0,
0)),
+ kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)), pool_type =
Some("avg"),
+ name = s"${prefix}pool1")
+
+ val relu2_1 = ConvRelu(pool1, s"${prefix}conv2_1", s"${prefix}relu2_1",
128)
+ val relu2_2 = ConvRelu(relu2_1, s"${prefix}conv2_2", s"${prefix}relu2_2",
128)
+ val pool2 = Symbol.api.Pooling(data = Some(relu2_2), pad = Some(Shape(0,
0)),
+ kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)), pool_type =
Some("avg"),
+ name = s"${prefix}pool2")
+
+ val relu3_1 = ConvRelu(pool2, s"${prefix}conv3_1", s"${prefix}relu3_1",
256)
+ val relu3_2 = ConvRelu(relu3_1, s"${prefix}conv3_2", s"${prefix}relu3_2",
256)
+ val relu3_3 = ConvRelu(relu3_2, s"${prefix}conv3_3", s"${prefix}relu3_3",
256)
+ val relu3_4 = ConvRelu(relu3_3, s"${prefix}conv3_4", s"${prefix}relu3_4",
256)
+ val pool3 = Symbol.api.Pooling(data = Some(relu3_4), pad = Some(Shape(0,
0)),
+ kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)), pool_type =
Some("avg"),
+ name = s"${prefix}pool3")
+
+ val relu4_1 = ConvRelu(pool3, s"${prefix}conv4_1", s"${prefix}relu4_1",
512)
+ val relu4_2 = ConvRelu(relu4_1, s"${prefix}conv4_2", s"${prefix}relu4_2",
512)
+ val relu4_3 = ConvRelu(relu4_2, s"${prefix}conv4_3", s"${prefix}relu4_3",
512)
+ val relu4_4 = ConvRelu(relu4_3, s"${prefix}conv4_4", s"${prefix}relu4_4",
512)
+ val pool4 = Symbol.api.Pooling(data = Some(relu4_4), pad = Some(Shape(0,
0)),
+ kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)), pool_type =
Some("avg"),
+ name = s"${prefix}pool4")
+
+ val relu5_1 = ConvRelu(pool4, s"${prefix}conv5_1", s"${prefix}relu5_1",
512)
// style and content layers
- val style = Symbol.Group(relu1_1, relu2_1, relu3_1, relu4_1, relu5_1)
+ val style = if (contentOnly) null else Symbol.Group(relu1_1, relu2_1,
relu3_1, relu4_1, relu5_1)
val content = Symbol.Group(relu4_2)
(style, content)
}
def getExecutor(style: Symbol, content: Symbol, modelPath: String,
- inputSize: (Int, Int), ctx: Context): ConvExecutor = {
+ inputSize: (Int, Int), ctx: Context): ConvExecutor = {
val out = Symbol.Group(style, content)
// make executor
val (argShapes, outputShapes, auxShapes) = out.inferShape(
@@ -116,15 +97,17 @@ object ModelVgg19 {
val key = s"arg:$name"
if (pretrained.contains(key)) argDict(name).set(pretrained(key))
}
+ pretrained.foreach(ele => ele._2.dispose())
val executor = out.bind(ctx, argDict, gradDict)
+ out.dispose()
val outArray = executor.outputs
ConvExecutor(executor = executor,
- data = argDict("data"),
- dataGrad = gradDict("data"),
- style = outArray.take(outArray.length - 1),
- content = outArray(outArray.length - 1),
- argDict = argDict)
- }
+ data = argDict("data"),
+ dataGrad = gradDict("data"),
+ style = outArray.take(outArray.length - 1),
+ content = outArray(outArray.length - 1),
+ argDict = argDict)
+ }
def getModel(modelPath: String, inputSize: (Int, Int), ctx: Context):
ConvExecutor = {
val (style, content) = getSymbol
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
index d99ea641b5d..f98d725c230 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyle.scala
@@ -17,22 +17,22 @@
package org.apache.mxnetexamples.neuralstyle
-import org.apache.mxnet._
-import org.kohsuke.args4j.{CmdLineParser, Option}
-import org.slf4j.LoggerFactory
-import scala.collection.JavaConverters._
-import com.sksamuel.scrimage.Image
import java.io.File
-import com.sksamuel.scrimage.Pixel
+
+import com.sksamuel.scrimage.{Image, Pixel}
import com.sksamuel.scrimage.filter.GaussianBlurFilter
import com.sksamuel.scrimage.nio.JpegWriter
+import org.apache.mxnet._
import org.apache.mxnet.optimizer.Adam
+import org.kohsuke.args4j.{CmdLineParser, Option}
+import org.slf4j.LoggerFactory
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
/**
- * An Implementation of the paper A Neural Algorithm of Artistic Style
- * by Leon A. Gatys, Alexander S. Ecker, and Matthias Bethge
- * @author Depeng Liang
- */
+ * An Implementation of the paper A Neural Algorithm of Artistic Style
+ */
object NeuralStyle {
case class NSExecutor(executor: Executor, data: NDArray, dataGrad: NDArray)
@@ -109,11 +109,11 @@ object NeuralStyle {
var gradScale = List[Int]()
for (i <- 0 until style.listOutputs().length) {
val shape = outputShape(i)
- val x = Symbol.Reshape()()(Map("data" -> style.get(i),
- "target_shape" -> Shape(shape(1), shape(2) * shape(3))))
- // use fully connected to quickly do dot(x, x^T)
- val gram = Symbol.FullyConnected()()(Map("data" -> x, "weight" -> x,
- "no_bias" -> true, "num_hidden" -> shape(1)))
+ val x = Symbol.api.Reshape(data = Some(style.get(i)),
+ target_shape = Some(Shape(shape(1), shape(2) * shape(3))))
+ val gram = Symbol.api.FullyConnected(data = Some(x), weight = Some(x),
+ no_bias = Some(true), num_hidden = shape(1))
+ x.dispose()
gramList = gramList :+ gram
gradScale = gradScale :+ (shape(1) * shape(2) * shape(3) * shape(1))
}
@@ -121,13 +121,20 @@ object NeuralStyle {
}
def getLoss(gram: Symbol, content: Symbol): (Symbol, Symbol) = {
- var gramLoss = List[Symbol]()
+ var gramLoss = ListBuffer[Symbol]()
for (i <- 0 until gram.listOutputs().length) {
val gvar = Symbol.Variable(s"target_gram_$i")
- gramLoss = gramLoss :+ Symbol.sum()(Symbol.square()(gvar -
gram.get(i))())()
+ Symbol.api.square(data = Some(gvar - gram.get(i)))
+ gramLoss += Symbol.api.sum(
+ Some(Symbol.api.square(data = Some(gvar - gram.get(i))))
+ )
+ gvar.dispose()
}
+ gram.dispose()
val cvar = Symbol.Variable("target_content")
- val contentLoss = Symbol.sum()(Symbol.square()(cvar - content)())()
+ val contentLoss = Symbol.api.sum(
+ Some(Symbol.api.square(Some(cvar - content)))
+ )
(Symbol.Group(gramLoss: _*), contentLoss)
}
@@ -138,12 +145,13 @@ object NeuralStyle {
val nChannel = img.shape(1)
val sImg = Symbol.Variable("img")
val sKernel = Symbol.Variable("kernel")
- val channels = Symbol.SliceChannel()(sImg)(Map("num_outputs" -> nChannel))
- val out = Symbol.Concat()((0 until nChannel).map { i =>
- Symbol.Convolution()()(Map("data" -> channels.get(i), "weight" ->
sKernel,
- "num_filter" -> 1, "kernel" -> "(3,3)", "pad" -> "(1,1)",
- "no_bias" -> true, "stride" -> "(1,1)"))
- }: _*)() * tvWeight
+ val channels = Symbol.api.SliceChannel(data = Some(sImg), num_outputs =
nChannel)
+ val result = (0 until nChannel).map { i =>
+ Symbol.api.Convolution(data = Some(channels.get(i)), weight =
Some(sKernel),
+ num_filter = 1, kernel = Shape(3, 3), pad = Some(Shape(1, 1)), no_bias
= Some(true),
+ stride = Some(Shape(1, 1)))
+ }.toArray
+ val out = Symbol.api.Concat(result, result.length) * tvWeight
val kernel = {
val tmp = NDArray.empty(Shape(1, 1, 3, 3), ctx)
tmp.set(Array[Float](0, -1, 0, -1, 4, -1, 0, -1, 0))
@@ -156,104 +164,123 @@ object NeuralStyle {
Math.sqrt(array.map(x => x * x).sum.toDouble).toFloat
}
- def main(args: Array[String]): Unit = {
- val alle = new NeuralStyle
- val parser: CmdLineParser = new CmdLineParser(alle)
- try {
- parser.parseArgument(args.toList.asJava)
- assert(alle.contentImage != null && alle.styleImage != null
- && alle.modelPath != null && alle.outputDir != null)
+ //scalastyle:off
+ def runTraining(model : String, contentImage : String, styleImage: String,
dev : Context,
+ modelPath : String, outputDir : String, styleWeight : Float,
+ contentWeight : Float, tvWeight : Float, gaussianRadius :
Int,
+ lr: Float, maxNumEpochs: Int, maxLongEdge: Int,
+ saveEpochs : Int, stopEps: Float) : Unit = {
- val dev = if (alle.gpu >= 0) Context.gpu(alle.gpu) else Context.cpu(0)
- val contentNp = preprocessContentImage(alle.contentImage,
alle.maxLongEdge, dev)
- val styleNp = preprocessStyleImage(alle.styleImage, contentNp.shape, dev)
- val size = (contentNp.shape(2), contentNp.shape(3))
+ val contentNp = preprocessContentImage(contentImage, maxLongEdge, dev)
+ val styleNp = preprocessStyleImage(styleImage, contentNp.shape, dev)
+ val size = (contentNp.shape(2), contentNp.shape(3))
- val (style, content) = ModelVgg19.getSymbol
- val (gram, gScale) = styleGramSymbol(size, style)
- var modelExecutor = ModelVgg19.getExecutor(gram, content,
alle.modelPath, size, dev)
+ val (style, content) = ModelVgg19.getSymbol
+ val (gram, gScale) = styleGramSymbol(size, style)
+ var modelExecutor = ModelVgg19.getExecutor(gram, content, modelPath, size,
dev)
- modelExecutor.data.set(styleNp)
- modelExecutor.executor.forward()
+ modelExecutor.data.set(styleNp)
+ modelExecutor.executor.forward()
- val styleArray = modelExecutor.style.map(_.copyTo(Context.cpu()))
- modelExecutor.data.set(contentNp)
- modelExecutor.executor.forward()
- val contentArray = modelExecutor.content.copyTo(Context.cpu())
+ val styleArray = modelExecutor.style.map(_.copyTo(Context.cpu()))
+ modelExecutor.data.set(contentNp)
+ modelExecutor.executor.forward()
+ val contentArray = modelExecutor.content.copyTo(Context.cpu())
- // delete the executor
- modelExecutor = null
+ // delete the executor
+ modelExecutor.argDict.foreach(ele => ele._2.dispose())
+ modelExecutor.content.dispose()
+ modelExecutor.data.dispose()
+ modelExecutor.dataGrad.dispose()
+ modelExecutor.style.foreach(_.dispose())
+ modelExecutor.executor.dispose()
+ modelExecutor = null
- val (styleLoss, contentLoss) = getLoss(gram, content)
- modelExecutor = ModelVgg19.getExecutor(
- styleLoss, contentLoss, alle.modelPath, size, dev)
+ val (styleLoss, contentLoss) = getLoss(gram, content)
+ modelExecutor = ModelVgg19.getExecutor(
+ styleLoss, contentLoss, modelPath, size, dev)
- val gradArray = {
- var tmpGA = Array[NDArray]()
- for (i <- 0 until styleArray.length) {
- modelExecutor.argDict(s"target_gram_$i").set(styleArray(i))
- tmpGA = tmpGA :+ NDArray.ones(Shape(1), dev) * (alle.styleWeight /
gScale(i))
- }
- tmpGA :+ NDArray.ones(Shape(1), dev) * alle.contentWeight
+ val gradArray = {
+ var tmpGA = Array[NDArray]()
+ for (i <- 0 until styleArray.length) {
+ modelExecutor.argDict(s"target_gram_$i").set(styleArray(i))
+ tmpGA = tmpGA :+ NDArray.ones(Shape(1), dev) * (styleWeight /
gScale(i))
}
+ tmpGA :+ NDArray.ones(Shape(1), dev) * contentWeight
+ }
- modelExecutor.argDict("target_content").set(contentArray)
-
- // train
- val img = Random.uniform(-0.1f, 0.1f, contentNp.shape, dev)
- val lr = new FactorScheduler(step = 10, factor = 0.9f)
-
- saveImage(contentNp, s"${alle.outputDir}/input.jpg", alle.guassianRadius)
- saveImage(styleNp, s"${alle.outputDir}/style.jpg", alle.guassianRadius)
-
- val optimizer = new Adam(
- learningRate = alle.lr,
- wd = 0.005f,
- lrScheduler = lr)
- val optimState = optimizer.createState(0, img)
-
- logger.info(s"start training arguments $alle")
-
- var oldImg = img.copyTo(dev)
- val clipNorm = img.shape.toVector.reduce(_ * _)
- val tvGradExecutor = getTvGradExecutor(img, dev, alle.tvWeight)
- var eps = 0f
- var trainingDone = false
- var e = 0
- while (e < alle.maxNumEpochs && !trainingDone) {
- modelExecutor.data.set(img)
- modelExecutor.executor.forward()
- modelExecutor.executor.backward(gradArray)
-
- val gNorm = NDArray.norm(modelExecutor.dataGrad).toScalar
- if (gNorm > clipNorm) {
- modelExecutor.dataGrad.set(modelExecutor.dataGrad * (clipNorm /
gNorm))
- }
- tvGradExecutor match {
- case Some(executor) => {
- executor.forward()
- optimizer.update(0, img,
- modelExecutor.dataGrad + executor.outputs(0),
- optimState)
- }
- case None =>
- optimizer.update(0, img, modelExecutor.dataGrad, optimState)
- }
- eps = (NDArray.norm(oldImg - img) / NDArray.norm(img)).toScalar
- oldImg.set(img)
- logger.info(s"epoch $e, relative change $eps")
+ modelExecutor.argDict("target_content").set(contentArray)
- if (eps < alle.stopEps) {
- logger.info("eps < args.stop_eps, training finished")
- trainingDone = true
- }
- if ((e + 1) % alle.saveEpochs == 0) {
- saveImage(img, s"${alle.outputDir}/tmp_${e + 1}.jpg",
alle.guassianRadius)
+ // train
+ val img = Random.uniform(-0.1f, 0.1f, contentNp.shape, dev)
+ val lrFS = new FactorScheduler(step = 10, factor = 0.9f)
+
+ saveImage(contentNp, s"${outputDir}/input.jpg", gaussianRadius)
+ saveImage(styleNp, s"${outputDir}/style.jpg", gaussianRadius)
+
+ val optimizer = new Adam(
+ learningRate = lr,
+ wd = 0.005f,
+ lrScheduler = lrFS)
+ val optimState = optimizer.createState(0, img)
+
+ logger.info(s"start training arguments")
+
+ var oldImg = img.copyTo(dev)
+ val clipNorm = img.shape.toVector.reduce(_ * _)
+ val tvGradExecutor = getTvGradExecutor(img, dev, tvWeight)
+ var eps = 0f
+ var trainingDone = false
+ var e = 0
+ while (e < maxNumEpochs && !trainingDone) {
+ modelExecutor.data.set(img)
+ modelExecutor.executor.forward()
+ modelExecutor.executor.backward(gradArray)
+
+ val gNorm = NDArray.norm(modelExecutor.dataGrad).toScalar
+ if (gNorm > clipNorm) {
+ modelExecutor.dataGrad.set(modelExecutor.dataGrad * (clipNorm / gNorm))
+ }
+ tvGradExecutor match {
+ case Some(executor) => {
+ executor.forward()
+ optimizer.update(0, img,
+ modelExecutor.dataGrad + executor.outputs(0),
+ optimState)
}
- e = e + 1
+ case None =>
+ optimizer.update(0, img, modelExecutor.dataGrad, optimState)
+ }
+ eps = (NDArray.norm(oldImg - img) / NDArray.norm(img)).toScalar
+ oldImg.set(img)
+ logger.info(s"epoch $e, relative change $eps")
+
+ if (eps < stopEps) {
+ logger.info("eps < args.stop_eps, training finished")
+ trainingDone = true
+ }
+ if ((e + 1) % saveEpochs == 0) {
+ saveImage(img, s"${outputDir}/tmp_${e + 1}.jpg", gaussianRadius)
}
- saveImage(img, s"${alle.outputDir}/out.jpg", alle.guassianRadius)
- logger.info("Finish fit ...")
+ e = e + 1
+ }
+ saveImage(img, s"${outputDir}/out.jpg", gaussianRadius)
+ logger.info("Finish fit ...")
+ }
+
+ def main(args: Array[String]): Unit = {
+ val alle = new NeuralStyle
+ val parser: CmdLineParser = new CmdLineParser(alle)
+ try {
+ parser.parseArgument(args.toList.asJava)
+ assert(alle.contentImage != null && alle.styleImage != null
+ && alle.modelPath != null && alle.outputDir != null)
+
+ val dev = if (alle.gpu >= 0) Context.gpu(alle.gpu) else Context.cpu(0)
+ runTraining(alle.model, alle.contentImage, alle.styleImage, dev,
alle.modelPath,
+ alle.outputDir, alle.styleWeight, alle.contentWeight, alle.tvWeight,
+ alle.gaussianRadius, alle.lr, alle.maxNumEpochs, alle.maxLongEdge,
+ alle.saveEpochs, alle.stopEps)
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
@@ -293,6 +320,6 @@ class NeuralStyle {
private val outputDir: String = null
@Option(name = "--save-epochs", usage = "save the output every n epochs")
private val saveEpochs: Int = 50
- @Option(name = "--guassian-radius", usage = "the gaussian blur filter
radius")
- private val guassianRadius: Int = 1
+ @Option(name = "--gaussian-radius", usage = "the gaussian blur filter
radius")
+ private val gaussianRadius: Int = 1
}
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/README.md
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/README.md
new file mode 100644
index 00000000000..fe849343c9d
--- /dev/null
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/README.md
@@ -0,0 +1,83 @@
+# Neural Style Example for Scala
+
+## Introduction
+This model contains three important components:
+- Boost Inference
+- Boost Training
+- Neural Style conversion
+
+You can use the prebuilt VGG model to do the conversion.
+By adding a style image, you can create several interesting images.
+
+Original Image | Style Image
+:-------------------------:|:-------------------------:
+
|

+
+Boost Inference Image (pretrained) | Epoch 150 Image
+:-------------------------:|:-------------------------:
+
|

+
+## Setup
+Please download the input image and style image following the links below:
+
+Input image
+```bash
+https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/IMG_4343.jpg
+```
+Style image
+```bash
+https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/starry_night.jpg
+```
+
+VGG model --Boost inference
+```bash
+https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/model.zip
+```
+
+VGG model --Boost Training
+```bash
+https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/vgg19.params
+```
+
+Please unzip the model before you use it.
+
+## Boost Inference Example
+
+Please provide the corresponding arguments before you execute the program
+```bash
+--input-image
+<path>/IMG_4343.jpg
+--model-path
+<path>/model
+--output-path
+<outputPath>
+```
+
+## Boost Training Example
+Please download your own training data for boost training.
+You can use 26k images sampled from [MIT Place
dataset](http://places.csail.mit.edu/).
+```bash
+--style-image
+<path>/starry_night.jpg
+--data-path
+<path>/images
+--vgg-model-path
+<path>/vgg19.params
+--save-model-path
+<path>
+```
+
+## NeuralStyle Example
+Please provide the corresponding arguments before you execute the program
+```bash
+--model-path
+<path>/vgg19.params
+--content-image
+<path>/IMG_4343.jpg
+--style-image
+<path>/starry_night.jpg
+--gpu
+<num_of_gpus>
+--output-dir
+<path>
+```
\ No newline at end of file
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/Basic.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/Basic.scala
index c604f842c4c..56303253f33 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/Basic.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/Basic.scala
@@ -17,16 +17,11 @@
package org.apache.mxnetexamples.neuralstyle.end2end
-import org.apache.mxnet.Shape
-import org.apache.mxnet.Context
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.Initializer
+import org.apache.mxnet.{Context, Initializer, NDArray, Shape, Symbol}
+import org.apache.mxnetexamples.neuralstyle.ModelVgg19
import org.slf4j.LoggerFactory
-/**
- * @author Depeng Liang
- */
+
object Basic {
class PretrainedInit(prefix: String, params: Map[String, NDArray],
@@ -61,7 +56,7 @@ object Basic {
def getStyleModule(prefix: String, dShape: Shape,
ctx: Context, params: Map[String, NDArray]): Module = {
val inputShape = Map(s"${prefix}_data" -> dShape)
- val (style, content) = ModelVgg19.getVggSymbol(prefix)
+ val (style, content) = ModelVgg19.getVggSymbol(prefix + "_")
val (gram, gScale) = styleGramSymbol(inputShape, style)
val init = new PretrainedInit(prefix, params, true)
new Module(symbol = gram, context = ctx,
@@ -75,11 +70,10 @@ object Basic {
var gradScale = List[Int]()
for (i <- 0 until style.listOutputs().length) {
val shape = outputShape(i)
- val x = Symbol.Reshape()()(Map("data" -> style.get(i),
- "shape" -> Shape(shape(1), shape(2) * shape(3))))
- // use fully connected to quickly do dot(x, x^T)
- val gram = Symbol.FullyConnected()()(Map("data" -> x, "weight" -> x,
- "no_bias" -> true, "num_hidden" -> shape(1)))
+ val x = Symbol.api.Reshape(data = Some(style.get(i)),
+ shape = Some(Shape(shape(1), shape(2) * shape(3))))
+ val gram = Symbol.api.FullyConnected(data = Some(x), weight = Some(x),
+ no_bias = Some(true), num_hidden = shape(1))
gramList = gramList :+ gram
gradScale = gradScale :+ (shape(1) * shape(2) * shape(3) * shape(1))
}
@@ -90,16 +84,18 @@ object Basic {
var gramLoss = List[Symbol]()
for (i <- 0 until gram.listOutputs().length) {
val gvar = Symbol.Variable(s"target_gram_$i")
- gramLoss = gramLoss :+ Symbol.sum()(Symbol.square()(gvar -
gram.get(i))())()
+ gramLoss = gramLoss :+ Symbol.api.sum(Some(
+ Symbol.api.square(Some(gvar - gram.get(i)))
+ ))
}
val cvar = Symbol.Variable("target_content")
- val contentLoss = Symbol.sum()(Symbol.square()(cvar - content)())()
+ val contentLoss = Symbol.api.sum(Some(Symbol.api.square(Some(cvar -
content))))
(Symbol.Group(gramLoss: _*), contentLoss)
}
def getContentModule(prefix: String, dShape: Shape,
ctx: Context, params: Map[String, NDArray]): Module = {
- val (_, sym) = ModelVgg19.getVggSymbol(prefix, true)
+ val (_, sym) = ModelVgg19.getVggSymbol(prefix + "_", true)
val init = new PretrainedInit(prefix, params)
new Module(symbol = sym, context = ctx,
dataShapes = Map(s"${prefix}_data" -> dShape),
@@ -109,7 +105,7 @@ object Basic {
def getLossModule(prefix: String, dShape: Shape,
ctx: Context, params: Map[String, NDArray]): (Module, List[Int]) = {
val inputShape = Map(s"${prefix}_data" -> dShape)
- val (style, content) = ModelVgg19.getVggSymbol(prefix)
+ val (style, content) = ModelVgg19.getVggSymbol(prefix + "_")
val (gram, gScale) = styleGramSymbol(inputShape, style)
val (styleLoss, contentLoss) = getLoss(gram, content)
val sym = Symbol.Group(styleLoss, contentLoss)
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala
index 0feb73d3036..5410fb9edc7 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostInference.scala
@@ -17,19 +17,43 @@
package org.apache.mxnetexamples.neuralstyle.end2end
-import org.slf4j.LoggerFactory
+import org.apache.mxnet.{Context, Shape}
import org.kohsuke.args4j.{CmdLineParser, Option}
+import org.slf4j.LoggerFactory
+
import scala.collection.JavaConverters._
-import org.apache.mxnet.Shape
-import org.apache.mxnet.Context
-/**
- * @author Depeng Liang
- */
object BoostInference {
private val logger = LoggerFactory.getLogger(classOf[BoostInference])
+ def runInference(modelPath: String, outputPath: String, guassianRadius : Int,
+ inputImage : String, ctx : Context): Unit = {
+ val dShape = Shape(1, 3, 480, 640)
+ val clipNorm = 1.0f * dShape.product
+ // generator
+ val gens = Array(
+ GenV4.getModule("g0", dShape, ctx, isTrain = false),
+ GenV3.getModule("g1", dShape, ctx, isTrain = false),
+ GenV3.getModule("g2", dShape, ctx, isTrain = false),
+ GenV4.getModule("g3", dShape, ctx, isTrain = false)
+ )
+ gens.zipWithIndex.foreach { case (gen, i) =>
+ gen.loadParams(s"$modelPath/$i/v3_0002-0026000.params")
+ }
+
+ val contentNp =
+ DataProcessing.preprocessContentImage(s"$inputImage", dShape, ctx)
+ var data = Array(contentNp)
+ for (i <- 0 until gens.length) {
+ gens(i).forward(data.takeRight(1))
+ val newImg = gens(i).getOutputs()(0)
+ data :+= newImg
+ DataProcessing.saveImage(newImg, s"$outputPath/out_$i.jpg",
guassianRadius)
+ logger.info(s"Converted image: $outputPath/out_$i.jpg")
+ }
+ }
+
def main(args: Array[String]): Unit = {
val stce = new BoostInference
val parser: CmdLineParser = new CmdLineParser(stce)
@@ -39,30 +63,10 @@ object BoostInference {
&& stce.inputImage != null
&& stce.outputPath != null)
- val dShape = Shape(1, 3, 480, 640)
- val clipNorm = 1.0f * dShape.product
val ctx = if (stce.gpu == -1) Context.cpu() else Context.gpu(stce.gpu)
- // generator
- val gens = Array(
- GenV4.getModule("g0", dShape, ctx, isTrain = false),
- GenV3.getModule("g1", dShape, ctx, isTrain = false),
- GenV3.getModule("g2", dShape, ctx, isTrain = false),
- GenV4.getModule("g3", dShape, ctx, isTrain = false)
- )
- gens.zipWithIndex.foreach { case (gen, i) =>
- gen.loadParams(s"${stce.modelPath}/$i/v3_0002-0026000.params")
- }
+ runInference(stce.modelPath, stce.outputPath, stce.guassianRadius,
stce.inputImage, ctx)
- val contentNp =
- DataProcessing.preprocessContentImage(s"${stce.inputImage}", dShape,
ctx)
- var data = Array(contentNp)
- for (i <- 0 until gens.length) {
- gens(i).forward(data.takeRight(1))
- val newImg = gens(i).getOutputs()(0)
- data :+= newImg
- DataProcessing.saveImage(newImg, s"${stce.outputPath}/out_${i}.jpg",
stce.guassianRadius)
- }
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
@@ -74,7 +78,7 @@ object BoostInference {
}
class BoostInference {
- @Option(name = "--model-path", usage = "the save model path")
+ @Option(name = "--model-path", usage = "the saved model path")
private val modelPath: String = null
@Option(name = "--input-image", usage = "the style image")
private val inputImage: String = null
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
index 8b5549de4af..08b4c85d2c5 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
@@ -17,24 +17,17 @@
package org.apache.mxnetexamples.neuralstyle.end2end
-import org.slf4j.LoggerFactory
+import java.io.File
+
+import org.apache.mxnet.{Context, Executor, NDArray, Shape, Symbol}
+import org.apache.mxnet.optimizer.SGD
import org.kohsuke.args4j.{CmdLineParser, Option}
+import org.slf4j.LoggerFactory
+
import scala.collection.JavaConverters._
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Shape
-import org.apache.mxnet.Context
-import org.apache.mxnet.DataBatch
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.Executor
-import org.apache.mxnet.optimizer.SGD
-import java.io.File
-import javax.imageio.ImageIO
import scala.util.Random
-import org.apache.mxnet.optimizer.Adam
-/**
- * @author Depeng Liang
- */
+
object BoostTrain {
private val logger = LoggerFactory.getLogger(classOf[BoostTrain])
@@ -46,12 +39,13 @@ object BoostTrain {
val nChannel = img.shape(1)
val sImg = Symbol.Variable("img")
val sKernel = Symbol.Variable("kernel")
- val channels = Symbol.SliceChannel()(sImg)(Map("num_outputs" -> nChannel))
- val out = Symbol.Concat()((0 until nChannel).map { i =>
- Symbol.Convolution()()(Map("data" -> channels.get(i), "weight" ->
sKernel,
- "num_filter" -> 1, "kernel" -> "(3,3)", "pad" -> "(1,1)",
- "no_bias" -> true, "stride" -> "(1,1)"))
- }.toArray: _*)() * tvWeight
+ val channels = Symbol.api.SliceChannel(data = Some(sImg), num_outputs =
nChannel)
+ val toConcat = (0 until nChannel).map( i =>
+ Symbol.api.Convolution(data = Some(channels.get(i)), weight =
Some(sKernel),
+ num_filter = 1, kernel = Shape(3, 3), pad = Some(Shape(1, 1)),
+ no_bias = Some(true), stride = Some(Shape(1, 1)))
+ ).toArray
+ val out = Symbol.api.Concat(data = toConcat, num_args = toConcat.length) *
tvWeight
val kernel = {
val tmp = NDArray.empty(Shape(1, 1, 3, 3), ctx)
tmp.set(Array[Float](0, -1, 0, -1, 4, -1, 0, -1, 0))
@@ -60,130 +54,135 @@ object BoostTrain {
out.bind(ctx, Map("img" -> img, "kernel" -> kernel))
}
- def main(args: Array[String]): Unit = {
- val stin = new BoostTrain
- val parser: CmdLineParser = new CmdLineParser(stin)
- try {
- parser.parseArgument(args.toList.asJava)
- assert(stin.dataPath != null
- && stin.vggModelPath != null
- && stin.saveModelPath != null
- && stin.styleImage != null)
- // params
- val vggParams = NDArray.load2Map(stin.vggModelPath)
- val styleWeight = 1.2f
- val contentWeight = 10f
- val dShape = Shape(1, 3, 384, 384)
- val clipNorm = 0.05f * dShape.product
- val modelPrefix = "v3"
- val ctx = if (stin.gpu == -1) Context.cpu() else Context.gpu(stin.gpu)
-
- // init style
- val styleNp = DataProcessing.preprocessStyleImage(stin.styleImage,
dShape, ctx)
- var styleMod = Basic.getStyleModule("style", dShape, ctx, vggParams)
- styleMod.forward(Array(styleNp))
- val styleArray = styleMod.getOutputs().map(_.copyTo(Context.cpu()))
- styleMod.dispose()
- styleMod = null
-
- // content
- val contentMod = Basic.getContentModule("content", dShape, ctx,
vggParams)
-
- // loss
- val (loss, gScale) = Basic.getLossModule("loss", dShape, ctx, vggParams)
- val extraArgs = (0 until styleArray.length)
- .map( i => s"target_gram_$i" ->
styleArray(i)).toMap
- loss.setParams(extraArgs)
- var gradArray = Array[NDArray]()
- for (i <- 0 until styleArray.length) {
- gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * (styleWeight /
gScale(i)))
- }
- gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * contentWeight)
-
- // generator
- val gens = Array(
- GenV4.getModule("g0", dShape, ctx),
- GenV3.getModule("g1", dShape, ctx),
- GenV3.getModule("g2", dShape, ctx),
- GenV4.getModule("g3", dShape, ctx)
- )
- gens.foreach { gen =>
- val opt = new SGD(learningRate = 1e-4f,
- momentum = 0.9f,
- wd = 5e-3f,
- clipGradient = 5f)
- gen.initOptimizer(opt)
- }
+ def runTraining(dataPath : String, vggModelPath: String, ctx : Context,
+ styleImage : String, saveModelPath : String) : Unit = {
+ // params
+ val vggParams = NDArray.load2Map(vggModelPath)
+ val styleWeight = 1.2f
+ val contentWeight = 10f
+ val dShape = Shape(1, 3, 384, 384)
+ val clipNorm = 0.05f * dShape.product
+ val modelPrefix = "v3"
+ // init style
+ val styleNp = DataProcessing.preprocessStyleImage(styleImage, dShape, ctx)
+ var styleMod = Basic.getStyleModule("style", dShape, ctx, vggParams)
+ styleMod.forward(Array(styleNp))
+ val styleArray = styleMod.getOutputs().map(_.copyTo(Context.cpu()))
+ styleMod.dispose()
+ styleMod = null
+
+ // content
+ val contentMod = Basic.getContentModule("content", dShape, ctx, vggParams)
+
+ // loss
+ val (loss, gScale) = Basic.getLossModule("loss", dShape, ctx, vggParams)
+ val extraArgs = (0 until styleArray.length)
+ .map( i => s"target_gram_$i" -> styleArray(i)).toMap
+ loss.setParams(extraArgs)
+ var gradArray = Array[NDArray]()
+ for (i <- 0 until styleArray.length) {
+ gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * (styleWeight /
gScale(i)))
+ }
+ gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * contentWeight)
+
+ // generator
+ val gens = Array(
+ GenV4.getModule("g0", dShape, ctx),
+ GenV3.getModule("g1", dShape, ctx),
+ GenV3.getModule("g2", dShape, ctx),
+ GenV4.getModule("g3", dShape, ctx)
+ )
+ gens.foreach { gen =>
+ val opt = new SGD(learningRate = 1e-4f,
+ momentum = 0.9f,
+ wd = 5e-3f,
+ clipGradient = 5f)
+ gen.initOptimizer(opt)
+ }
- var filelist = new File(stin.dataPath).list().toList
- val numImage = filelist.length
- logger.info(s"Dataset size: $numImage")
+ var filelist = new File(dataPath).list().toList
+ val numImage = filelist.length
+ logger.info(s"Dataset size: $numImage")
- val tvWeight = 1e-2f
+ val tvWeight = 1e-2f
- val startEpoch = 0
- val endEpoch = 3
+ val startEpoch = 0
+ val endEpoch = 3
- for (k <- 0 until gens.length) {
- val path = new File(s"${stin.saveModelPath}/$k")
- if (!path.exists()) path.mkdir()
- }
+ for (k <- 0 until gens.length) {
+ val path = new File(s"${saveModelPath}/$k")
+ if (!path.exists()) path.mkdir()
+ }
- // train
- for (i <- startEpoch until endEpoch) {
- filelist = Random.shuffle(filelist)
- for (idx <- filelist.indices) {
- var dataArray = Array[NDArray]()
- var lossGradArray = Array[NDArray]()
- val data =
-
DataProcessing.preprocessContentImage(s"${stin.dataPath}/${filelist(idx)}",
dShape, ctx)
- dataArray = dataArray :+ data
- // get content
- contentMod.forward(Array(data))
- // set target content
- loss.setParams(Map("target_content" -> contentMod.getOutputs()(0)))
- // gen_forward
- for (k <- 0 until gens.length) {
- gens(k).forward(dataArray.takeRight(1))
- dataArray = dataArray :+ gens(k).getOutputs()(0)
- // loss forward
- loss.forward(dataArray.takeRight(1))
- loss.backward(gradArray)
- lossGradArray = lossGradArray :+ loss.getInputGrads()(0)
- }
- val grad = NDArray.zeros(data.shape, ctx)
- for (k <- gens.length - 1 to 0 by -1) {
- val tvGradExecutor = getTvGradExecutor(gens(k).getOutputs()(0),
ctx, tvWeight)
- tvGradExecutor.forward()
- grad += lossGradArray(k) + tvGradExecutor.outputs(0)
- val gNorm = NDArray.norm(grad)
- if (gNorm.toScalar > clipNorm) {
- grad *= clipNorm / gNorm.toScalar
- }
- gens(k).backward(Array(grad))
- gens(k).update()
- gNorm.dispose()
- tvGradExecutor.dispose()
+ // train
+ for (i <- startEpoch until endEpoch) {
+ filelist = Random.shuffle(filelist)
+ for (idx <- filelist.indices) {
+ var dataArray = Array[NDArray]()
+ var lossGradArray = Array[NDArray]()
+ val data =
+
DataProcessing.preprocessContentImage(s"${dataPath}/${filelist(idx)}", dShape,
ctx)
+ dataArray = dataArray :+ data
+ // get content
+ contentMod.forward(Array(data))
+ // set target content
+ loss.setParams(Map("target_content" -> contentMod.getOutputs()(0)))
+ // gen_forward
+ for (k <- 0 until gens.length) {
+ gens(k).forward(dataArray.takeRight(1))
+ dataArray = dataArray :+ gens(k).getOutputs()(0)
+ // loss forward
+ loss.forward(dataArray.takeRight(1))
+ loss.backward(gradArray)
+ lossGradArray = lossGradArray :+ loss.getInputGrads()(0)
+ }
+ val grad = NDArray.zeros(data.shape, ctx)
+ for (k <- gens.length - 1 to 0 by -1) {
+ val tvGradExecutor = getTvGradExecutor(gens(k).getOutputs()(0), ctx,
tvWeight)
+ tvGradExecutor.forward()
+ grad += lossGradArray(k) + tvGradExecutor.outputs(0)
+ val gNorm = NDArray.norm(grad)
+ if (gNorm.toScalar > clipNorm) {
+ grad *= clipNorm / gNorm.toScalar
}
- grad.dispose()
- if (idx % 20 == 0) {
- logger.info(s"Epoch $i: Image $idx")
- for (k <- 0 until gens.length) {
- val n = NDArray.norm(gens(k).getInputGrads()(0))
- logger.info(s"Data Norm : ${n.toScalar / dShape.product}")
- n.dispose()
- }
+ gens(k).backward(Array(grad))
+ gens(k).update()
+ gNorm.dispose()
+ tvGradExecutor.dispose()
+ }
+ grad.dispose()
+ if (idx % 20 == 0) {
+ logger.info(s"Epoch $i: Image $idx")
+ for (k <- 0 until gens.length) {
+ val n = NDArray.norm(gens(k).getInputGrads()(0))
+ logger.info(s"Data Norm : ${n.toScalar / dShape.product}")
+ n.dispose()
}
- if (idx % 1000 == 0) {
- for (k <- 0 until gens.length) {
- gens(k).saveParams(
- s"${stin.saveModelPath}/$k/${modelPrefix}_" +
- s"${"%04d".format(i)}-${"%07d".format(idx)}.params")
- }
+ }
+ if (idx % 1000 == 0) {
+ for (k <- 0 until gens.length) {
+ gens(k).saveParams(
+ s"${saveModelPath}/$k/${modelPrefix}_" +
+ s"${"%04d".format(i)}-${"%07d".format(idx)}.params")
}
- data.dispose()
}
+ data.dispose()
}
+ }
+ }
+
+ def main(args: Array[String]): Unit = {
+ val stin = new BoostTrain
+ val parser: CmdLineParser = new CmdLineParser(stin)
+ try {
+ parser.parseArgument(args.toList.asJava)
+ assert(stin.dataPath != null
+ && stin.vggModelPath != null
+ && stin.saveModelPath != null
+ && stin.styleImage != null)
+
+ val ctx = if (stin.gpu == -1) Context.cpu() else Context.gpu(stin.gpu)
+ runTraining(stin.dataPath, stin.vggModelPath, ctx, stin.styleImage,
stin.saveModelPath)
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
@@ -197,9 +196,9 @@ object BoostTrain {
class BoostTrain {
@Option(name = "--data-path", usage = "the input train data path")
private val dataPath: String = null
- @Option(name = "--vgg--model-path", usage = "the pretrained model to use:
['vgg']")
+ @Option(name = "--vgg-model-path", usage = "the pretrained model to use:
['vgg']")
private val vggModelPath: String = null
- @Option(name = "--save--model-path", usage = "the save model path")
+ @Option(name = "--save-model-path", usage = "the save model path")
private val saveModelPath: String = null
@Option(name = "--style-image", usage = "the style image")
private val styleImage: String = null
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala
index 94d05bb7d57..80a009ea40c 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/DataProcessing.scala
@@ -17,19 +17,14 @@
package org.apache.mxnetexamples.neuralstyle.end2end
-import com.sksamuel.scrimage.Image
-import com.sksamuel.scrimage.Pixel
+import java.io.File
+
+import com.sksamuel.scrimage.{Image, Pixel}
import com.sksamuel.scrimage.filter.GaussianBlurFilter
import com.sksamuel.scrimage.nio.JpegWriter
-import org.apache.mxnet.Context
-import org.apache.mxnet.NDArray
-import java.io.File
-import org.apache.mxnet.Shape
-import scala.util.Random
+import org.apache.mxnet.{Context, NDArray, Shape}
+
-/**
- * @author Depeng Liang
- */
object DataProcessing {
def preprocessContentImage(path: String,
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/GenV3.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/GenV3.scala
index b90e9f0e317..d7ab59e2840 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/GenV3.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/GenV3.scala
@@ -17,34 +17,33 @@
package org.apache.mxnetexamples.neuralstyle.end2end
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.Shape
-import org.apache.mxnet.Context
-import org.apache.mxnet.Xavier
+import org.apache.mxnet.{Context, Shape, Symbol, Xavier}
+
-/**
- * @author Depeng Liang
- */
object GenV3 {
def Conv(data: Symbol, numFilter: Int, kernel: (Int, Int) = (5, 5),
- pad: (Int, Int) = (2, 2), stride: (Int, Int) = (2, 2)): Symbol = {
- var sym = Symbol.Convolution()()(Map("data" -> data, "num_filter" ->
numFilter,
- "kernel" -> s"$kernel", "stride" -> s"$stride", "pad" -> s"$pad",
"no_bias" -> false))
- sym = Symbol.BatchNorm()()(Map("data" -> sym, "fix_gamma" -> false))
- sym = Symbol.LeakyReLU()()(Map("data" -> sym, "act_type" -> "leaky"))
- sym
+ pad: (Int, Int) = (2, 2), stride: (Int, Int) = (2, 2)): Symbol = {
+ val sym1 = Symbol.api.Convolution(data = Some(data), num_filter =
numFilter,
+ kernel = Shape(kernel._1, kernel._2), stride = Some(Shape(stride._1,
stride._2)),
+ pad = Some(Shape(pad._1, pad._2)), no_bias = Some(false))
+ val sym2 = Symbol.api.BatchNorm(data = Some(sym1), fix_gamma = Some(false))
+ val sym3 = Symbol.api.LeakyReLU(data = Some(sym2), act_type =
Some("leaky"))
+ sym2.dispose()
+ sym1.dispose()
+ sym3
}
def Deconv(data: Symbol, numFilter: Int, imHw: (Int, Int),
- kernel: (Int, Int) = (7, 7), pad: (Int, Int) = (2, 2), stride: (Int,
Int) = (2, 2),
- crop: Boolean = true, out: Boolean = false): Symbol = {
- var sym = Symbol.Deconvolution()()(Map("data" -> data, "num_filter" ->
numFilter,
- "kernel" -> s"$kernel", "stride" -> s"$stride", "pad" -> s"$pad",
"no_bias" -> true))
- if (crop) sym = Symbol.Crop()(sym)(
- Map("offset" -> "(1, 1)", "h_w" -> s"$imHw", "num_args" -> 1))
- sym = Symbol.BatchNorm()()(Map("data" -> sym, "fix_gamma" -> false))
- if (out == false) Symbol.LeakyReLU()()(Map("data" -> sym, "act_type" ->
"leaky"))
- else Symbol.Activation()()(Map("data" -> sym, "act_type" -> "tanh"))
+ kernel: (Int, Int) = (7, 7), pad: (Int, Int) = (2, 2), stride:
(Int, Int) = (2, 2),
+ crop: Boolean = true, out: Boolean = false): Symbol = {
+ var sym = Symbol.api.Deconvolution(data = Some(data), num_filter =
numFilter,
+ kernel = Shape(kernel._1, kernel._2), stride = Some(Shape(stride._1,
stride._2)),
+ pad = Some(Shape(pad._1, pad._2)), no_bias = Some(true))
+ if (crop) sym = Symbol.api.Crop(data = Array(sym), offset = Some(Shape(1,
1)),
+ h_w = Some(Shape(imHw._1, imHw._2)), num_args = 1)
+ sym = Symbol.api.BatchNorm(data = Some(sym), fix_gamma = Some(false))
+ if (out == false) Symbol.api.LeakyReLU(data = Some(sym), act_type =
Some("leaky"))
+ else Symbol.api.Activation(data = Some(sym), act_type = "tanh")
}
def getGenerator(prefix: String, imHw: (Int, Int)): Symbol = {
@@ -61,12 +60,12 @@ object GenV3 {
val conv5_1 = Conv(deconv2, 96, kernel = (3, 3), pad = (1, 1), stride =
(1, 1))
val deconv3 = Deconv(conv5_1, 3, imHw, kernel = (8, 8), pad = (3, 3), out
= true, crop = false)
val rawOut = (deconv3 * 128) + 128
- val norm = Symbol.SliceChannel()(rawOut)(Map("num_outputs" -> 3))
+ val norm = Symbol.api.SliceChannel(data = Some(rawOut), num_outputs = 3)
val rCh = norm.get(0) - 123.68f
val gCh = norm.get(1) - 116.779f
val bCh = norm.get(2) - 103.939f
- val normOut = Symbol.Concat()(rCh, gCh, bCh)() * 0.4f + data * 0.6f
- normOut
+ val normOut = Symbol.api.Concat(data = Array(rCh, gCh, bCh), num_args = 3)
+ normOut * 0.4f + data * 0.6f
}
def getModule(prefix: String, dShape: Shape, ctx: Context, isTrain: Boolean
= true): Module = {
@@ -77,9 +76,9 @@ object GenV3 {
else (dataShape, false, false)
}
val mod = new Module(symbol = sym, context = ctx,
- dataShapes = dataShapes,
- initializer = new Xavier(magnitude = 2f),
- forTraining = forTraining, inputsNeedGrad =
inputsNeedGrad)
+ dataShapes = dataShapes,
+ initializer = new Xavier(magnitude = 2f),
+ forTraining = forTraining, inputsNeedGrad = inputsNeedGrad)
mod
}
}
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/GenV4.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/GenV4.scala
index 876a0529b69..82fc9b6ce10 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/GenV4.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/GenV4.scala
@@ -17,78 +17,43 @@
package org.apache.mxnetexamples.neuralstyle.end2end
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.Shape
-import org.apache.mxnet.Context
-import org.apache.mxnet.Xavier
+import org.apache.mxnet.{Context, Shape, Symbol, Xavier}
-/**
- * @author Depeng Liang
- */
-object GenV4 {
- def Conv(data: Symbol, numFilter: Int, kernel: (Int, Int) = (5, 5),
- pad: (Int, Int) = (2, 2), stride: (Int, Int) = (2, 2)): Symbol = {
- var sym = Symbol.Convolution()()(Map("data" -> data, "num_filter" ->
numFilter,
- "kernel" -> s"$kernel", "stride" -> s"$stride", "pad" -> s"$pad",
"no_bias" -> false))
- sym = Symbol.BatchNorm()()(Map("data" -> sym, "fix_gamma" -> false))
- sym = Symbol.LeakyReLU()()(Map("data" -> sym, "act_type" -> "leaky"))
- sym
- }
+object GenV4 {
- def Deconv(data: Symbol, numFilter: Int, imHw: (Int, Int), kernel: (Int,
Int) = (6, 6),
- pad: (Int, Int) = (2, 2), stride: (Int, Int) = (2, 2), out: Boolean =
false): Symbol = {
- var sym = Symbol.Deconvolution()()(Map("data" -> data, "num_filter" ->
numFilter,
- "kernel" -> s"$kernel", "stride" -> s"$stride", "pad" -> s"$pad",
"no_bias" -> true))
- sym = Symbol.BatchNorm()()(Map("data" -> sym, "fix_gamma" -> false))
- if (out == false) Symbol.LeakyReLU()()(Map("data" -> sym, "act_type" ->
"leaky"))
- else Symbol.Activation()()(Map("data" -> sym, "act_type" -> "tanh"))
+ def Conv(data: Symbol, numFilter: Int, workspace : Long, kernel: (Int, Int)
= (5, 5),
+ pad: (Int, Int) = (2, 2)): Symbol = {
+ val sym1 = Symbol.api.Convolution(data = Some(data), num_filter =
numFilter,
+ kernel = Shape(kernel._1, kernel._2), workspace = Some(workspace),
+ pad = Some(Shape(pad._1, pad._2)), no_bias = Some(false))
+ val sym2 = Symbol.api.BatchNorm(data = Some(sym1), fix_gamma = Some(false))
+ val sym3 = Symbol.api.LeakyReLU(data = Some(sym2), act_type =
Some("leaky"))
+ sym2.dispose()
+ sym1.dispose()
+ sym3
}
def getGenerator(prefix: String, imHw: (Int, Int)): Symbol = {
val data = Symbol.Variable(s"${prefix}_data")
- var conv1_1 = Symbol.Convolution()()(Map("data" -> data, "num_filter" ->
48,
- "kernel" -> "(5, 5)", "pad" -> "(2, 2)", "no_bias" -> false,
"workspace" -> 4096))
- conv1_1 = Symbol.BatchNorm()()(Map("data" -> conv1_1, "fix_gamma" ->
false))
- conv1_1 = Symbol.LeakyReLU()()(Map("data" -> conv1_1, "act_type" ->
"leaky"))
-
- var conv2_1 = Symbol.Convolution()()(Map("data" -> conv1_1, "num_filter"
-> 32,
- "kernel" -> "(5, 5)", "pad" -> "(2, 2)", "no_bias" -> false,
"workspace" -> 4096))
- conv2_1 = Symbol.BatchNorm()()(Map("data" -> conv2_1, "fix_gamma" ->
false))
- conv2_1 = Symbol.LeakyReLU()()(Map("data" -> conv2_1, "act_type" ->
"leaky"))
-
- var conv3_1 = Symbol.Convolution()()(Map("data" -> conv2_1, "num_filter"
-> 64,
- "kernel" -> "(3, 3)", "pad" -> "(1, 1)", "no_bias" -> false,
"workspace" -> 4096))
- conv3_1 = Symbol.BatchNorm()()(Map("data" -> conv3_1, "fix_gamma" ->
false))
- conv3_1 = Symbol.LeakyReLU()()(Map("data" -> conv3_1, "act_type" ->
"leaky"))
-
- var conv4_1 = Symbol.Convolution()()(Map("data" -> conv3_1, "num_filter"
-> 32,
- "kernel" -> "(5, 5)", "pad" -> "(2, 2)", "no_bias" -> false,
"workspace" -> 4096))
- conv4_1 = Symbol.BatchNorm()()(Map("data" -> conv4_1, "fix_gamma" ->
false))
- conv4_1 = Symbol.LeakyReLU()()(Map("data" -> conv4_1, "act_type" ->
"leaky"))
-
- var conv5_1 = Symbol.Convolution()()(Map("data" -> conv4_1, "num_filter"
-> 48,
- "kernel" -> "(5, 5)", "pad" -> "(2, 2)", "no_bias" -> false,
"workspace" -> 4096))
- conv5_1 = Symbol.BatchNorm()()(Map("data" -> conv5_1, "fix_gamma" ->
false))
- conv5_1 = Symbol.LeakyReLU()()(Map("data" -> conv5_1, "act_type" ->
"leaky"))
-
- var conv6_1 = Symbol.Convolution()()(Map("data" -> conv5_1, "num_filter"
-> 32,
- "kernel" -> "(5, 5)", "pad" -> "(2, 2)", "no_bias" -> true,
"workspace" -> 4096))
- conv6_1 = Symbol.BatchNorm()()(Map("data" -> conv6_1, "fix_gamma" ->
false))
- conv6_1 = Symbol.LeakyReLU()()(Map("data" -> conv6_1, "act_type" ->
"leaky"))
-
- var out = Symbol.Convolution()()(Map("data" -> conv6_1, "num_filter" -> 3,
"kernel" -> "(3, 3)",
- "pad" -> "(1, 1)", "no_bias" -> true, "workspace" -> 4096))
- out = Symbol.BatchNorm()()(Map("data" -> out, "fix_gamma" -> false))
- out = Symbol.Activation()()(Map("data" -> out, "act_type" -> "tanh"))
+ var conv1_1 = Conv(data, 48, 4096)
+ val conv2_1 = Conv(conv1_1, 32, 4096)
+ var conv3_1 = Conv(conv2_1, 64, 4096, (3, 3), (1, 1))
+ var conv4_1 = Conv(conv3_1, 32, 4096)
+ var conv5_1 = Conv(conv4_1, 48, 4096)
+ var conv6_1 = Conv(conv5_1, 32, 4096)
+ var out = Symbol.api.Convolution(data = Some(conv6_1), num_filter = 3,
kernel = Shape(3, 3),
+ pad = Some(Shape(1, 1)), no_bias = Some(true), workspace = Some(4096))
+ out = Symbol.api.BatchNorm(data = Some(out), fix_gamma = Some(false))
+ out = Symbol.api.Activation(data = Some(out), act_type = "tanh")
val rawOut = (out * 128) + 128
- val norm = Symbol.SliceChannel()(rawOut)(Map("num_outputs" -> 3))
+ val norm = Symbol.api.SliceChannel(data = Some(rawOut), num_outputs = 3)
val rCh = norm.get(0) - 123.68f
val gCh = norm.get(1) - 116.779f
val bCh = norm.get(2) - 103.939f
- val normOut = Symbol.Concat()(rCh, gCh, bCh)() * 0.4f + data * 0.6f
- normOut
+ val normOut = Symbol.api.Concat(data = Array(rCh, gCh, bCh), num_args = 3)
+ normOut * 0.4f + data * 0.6f
}
def getModule(prefix: String, dShape: Shape, ctx: Context, isTrain: Boolean
= true): Module = {
@@ -99,9 +64,9 @@ object GenV4 {
else (dataShape, false, false)
}
val mod = new Module(symbol = sym, context = ctx,
- dataShapes = dataShapes,
- initializer = new Xavier(magnitude = 2f),
- forTraining = forTraining, inputsNeedGrad =
inputsNeedGrad)
+ dataShapes = dataShapes,
+ initializer = new Xavier(magnitude = 2f),
+ forTraining = forTraining, inputsNeedGrad = inputsNeedGrad)
mod
}
}
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/ModelVgg19.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/ModelVgg19.scala
deleted file mode 100644
index 6044847be4a..00000000000
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/ModelVgg19.scala
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mxnetexamples.neuralstyle.end2end
-
-import org.apache.mxnet.Executor
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Symbol
-
-
-object ModelVgg19 {
-
- def getVggSymbol(prefix: String, contentOnly: Boolean = false): (Symbol,
Symbol) = {
- // declare symbol
- val data = Symbol.Variable(s"${prefix}_data")
- val conv1_1 = Symbol.Convolution(s"${prefix}_conv1_1")()(Map("data" ->
data,
- "num_filter" -> 64, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu1_1 = Symbol.Activation(s"${prefix}_relu1_1")()(Map("data" ->
conv1_1,
- "act_type" -> "relu"))
- val conv1_2 = Symbol.Convolution(s"${prefix}_conv1_2")()(Map("data" ->
relu1_1,
- "num_filter" -> 64, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu1_2 = Symbol.Activation(s"${prefix}_relu1_2")()(Map("data" ->
conv1_2,
- "act_type" -> "relu"))
- val pool1 = Symbol.Pooling(s"${prefix}_pool1")()(Map("data" -> relu1_2 ,
"pad" -> "(0,0)",
- "kernel" -> "(2,2)", "stride" -> "(2,2)",
"pool_type" -> "avg"))
- val conv2_1 = Symbol.Convolution(s"${prefix}_conv2_1")()(Map("data" ->
pool1,
- "num_filter" -> 128, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu2_1 = Symbol.Activation(s"${prefix}_relu2_1")()(Map("data" ->
conv2_1,
- "act_type" -> "relu"))
- val conv2_2 = Symbol.Convolution(s"${prefix}_conv2_2")()(Map("data" ->
relu2_1,
- "num_filter" -> 128, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu2_2 = Symbol.Activation(s"${prefix}_relu2_2")()(Map("data" ->
conv2_2,
- "act_type" -> "relu"))
- val pool2 = Symbol.Pooling("pool2")()(Map("data" -> relu2_2 , "pad" ->
"(0,0)",
- "kernel" -> "(2,2)", "stride" -> "(2,2)",
"pool_type" -> "avg"))
- val conv3_1 = Symbol.Convolution(s"${prefix}_conv3_1")()(Map("data" ->
pool2,
- "num_filter" -> 256, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu3_1 = Symbol.Activation(s"${prefix}_relu3_1")()(Map("data" ->
conv3_1,
- "act_type" -> "relu"))
- val conv3_2 = Symbol.Convolution(s"${prefix}_conv3_2")()(Map("data" ->
relu3_1,
- "num_filter" -> 256, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu3_2 = Symbol.Activation(s"${prefix}_relu3_2")()(Map("data" ->
conv3_2,
- "act_type" -> "relu"))
- val conv3_3 = Symbol.Convolution(s"${prefix}_conv3_3")()(Map("data" ->
relu3_2,
- "num_filter" -> 256, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu3_3 = Symbol.Activation(s"${prefix}_relu3_3")()(Map("data" ->
conv3_3,
- "act_type" -> "relu"))
- val conv3_4 = Symbol.Convolution(s"${prefix}_conv3_4")()(Map("data" ->
relu3_3,
- "num_filter" -> 256, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu3_4 = Symbol.Activation(s"${prefix}_relu3_4")()(Map("data" ->
conv3_4 ,
- "act_type" -> "relu"))
- val pool3 = Symbol.Pooling(s"${prefix}_pool3")()(Map("data" -> relu3_4,
- "pad" -> "(0,0)", "kernel" -> "(2,2)", "stride" ->
"(2,2)",
- "pool_type" -> "avg"))
- val conv4_1 = Symbol.Convolution(s"${prefix}_conv4_1")()(Map("data" ->
pool3,
- "num_filter" -> 512, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu4_1 = Symbol.Activation(s"${prefix}_relu4_1")()(Map("data" ->
conv4_1,
- "act_type" -> "relu"))
- val conv4_2 = Symbol.Convolution(s"${prefix}_conv4_2")()(Map("data" ->
relu4_1,
- "num_filter" -> 512, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu4_2 = Symbol.Activation(s"${prefix}_relu4_2")()(Map("data" ->
conv4_2,
- "act_type" -> "relu"))
- val conv4_3 = Symbol.Convolution(s"${prefix}_conv4_3")()(Map("data" ->
relu4_2,
- "num_filter" -> 512, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu4_3 = Symbol.Activation(s"${prefix}_relu4_3")()(Map("data" ->
conv4_3,
- "act_type" -> "relu"))
- val conv4_4 = Symbol.Convolution(s"${prefix}_conv4_4")()(Map("data" ->
relu4_3,
- "num_filter" -> 512, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu4_4 = Symbol.Activation(s"${prefix}_relu4_4")()(Map("data" ->
conv4_4,
- "act_type" -> "relu"))
- val pool4 = Symbol.Pooling(s"${prefix}_pool4")()(Map("data" -> relu4_4,
- "pad" -> "(0,0)", "kernel" -> "(2,2)", "stride" ->
"(2,2)",
- "pool_type" -> "avg"))
- val conv5_1 = Symbol.Convolution(s"${prefix}_conv5_1")()(Map("data" ->
pool4,
- "num_filter" -> 512, "pad" -> "(1,1)", "kernel" ->
"(3,3)",
- "stride" -> "(1,1)", "no_bias" -> false,
"workspace" -> 1024))
- val relu5_1 = Symbol.Activation(s"${prefix}_relu5_1")()(Map("data" ->
conv5_1,
- "act_type" -> "relu"))
-
- // style and content layers
- val style = if (contentOnly) null else Symbol.Group(relu1_1, relu2_1,
relu3_1, relu4_1, relu5_1)
- val content = Symbol.Group(relu4_2)
- (style, content)
- }
-}
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/Module.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/Module.scala
index d681b16c5af..1d11f886406 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/Module.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/Module.scala
@@ -17,20 +17,9 @@
package org.apache.mxnetexamples.neuralstyle.end2end
-import org.apache.mxnet.Context
+import org.apache.mxnet.{Context, Initializer, NDArray, Optimizer, Shape,
Symbol, Uniform}
import org.slf4j.LoggerFactory
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Optimizer
-import org.apache.mxnet.Executor
-import org.apache.mxnet.Shape
-import org.apache.mxnet.Uniform
-import org.apache.mxnet.Initializer
-import org.apache.mxnet.DataBatch
-
-/**
- * @author Depeng Liang
- */
+
class Module(symbol: Symbol,
context: Context,
dataShapes: Map[String, Shape],
diff --git
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
index 8ab3a4b364a..96820ce4e98 100644
---
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
+++
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala
@@ -18,41 +18,38 @@
package org.apache.mxnetexamples.gan
import java.io.File
-import java.net.URL
-
-import org.apache.commons.io.FileUtils
import org.apache.mxnet.Context
import org.apache.mxnetexamples.Util
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
import org.slf4j.LoggerFactory
import scala.sys.process.Process
+@Ignore
class GanExampleSuite extends FunSuite with BeforeAndAfterAll{
private val logger = LoggerFactory.getLogger(classOf[GanExampleSuite])
test("Example CI: Test GAN MNIST") {
- if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
- System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
- logger.info("Downloading mnist model")
- val baseUrl =
"https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
- val tempDirPath = System.getProperty("java.io.tmpdir")
- val modelDirPath = tempDirPath + File.separator + "mnist/"
- logger.info("tempDirPath: %s".format(tempDirPath))
- Util.downloadUrl(baseUrl + "/mnist/mnist.zip",
- tempDirPath + "/mnist/mnist.zip")
- // TODO: Need to confirm with Windows
- Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
- + tempDirPath + "/mnist/") !
-
- val context = Context.gpu()
-
- val output = GanMnist.runTraining(modelDirPath, context, modelDirPath, 5)
- Process("rm -rf " + modelDirPath) !
-
- assert(output >= 0.0f)
- } else {
- logger.info("GPU test only, skipped...")
- }
+ if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
+ System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
+ logger.info("Downloading mnist model")
+ val baseUrl =
"https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci"
+ val tempDirPath = System.getProperty("java.io.tmpdir")
+ val modelDirPath = tempDirPath + File.separator + "mnist/"
+ logger.info("tempDirPath: %s".format(tempDirPath))
+ Util.downloadUrl(baseUrl + "/mnist/mnist.zip", tempDirPath +
"/mnist/mnist.zip")
+ // TODO: Need to confirm with Windows
+ Process("unzip " + tempDirPath + "/mnist/mnist.zip -d "
+ + tempDirPath + "/mnist/") !
+
+ val context = Context.gpu()
+
+ val output = GanMnist.runTraining(modelDirPath, context, modelDirPath,
5)
+ Process("rm -rf " + modelDirPath) !
+
+ assert(output >= 0.0f)
+ } else {
+ logger.info("GPU test only, skipped...")
+ }
}
}
diff --git
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
index 7b1d6ddc38b..0fd3af02d9c 100644
---
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
+++
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala
@@ -29,8 +29,7 @@ import org.slf4j.LoggerFactory
import scala.sys.process.Process
/**
- * Integration test for imageClassifier example.
- * This will run as a part of "make scalatest"
+ * Integration test for MNIST example.
*/
class MNISTExampleSuite extends FunSuite with BeforeAndAfterAll {
private val logger = LoggerFactory.getLogger(classOf[MNISTExampleSuite])
diff --git
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
new file mode 100644
index 00000000000..dc8fc5b8c14
--- /dev/null
+++
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.neuralstyle
+
+import org.apache.mxnet.Context
+import org.apache.mxnetexamples.Util
+import org.apache.mxnetexamples.neuralstyle.end2end.{BoostInference,
BoostTrain}
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
+
+import scala.sys.process.Process
+
+/**
+ * Integration tests for the Neural Style example.
+ * Currently there is no plan to test accuracy;
+ * these tests only verify that the models are runnable.
+ */
/**
 * Integration tests for the Neural Style example.
 * There is currently no plan to test accuracy; these tests only verify
 * that the models are runnable. Fixtures are downloaded into
 * java.io.tmpdir/NS by beforeAll.
 */
class NeuralStyleSuite extends FunSuite with BeforeAndAfterAll {
  private val logger = LoggerFactory.getLogger(classOf[NeuralStyleSuite])

  // Root temp directory shared by all tests; fixtures live under <tmp>/NS.
  private def tempDirPath: String = System.getProperty("java.io.tmpdir")

  // True when CI requests GPU execution (SCALA_TEST_ON_GPU=1).
  // Two of the three tests below are GPU-only and are skipped otherwise.
  private def gpuTestEnabled: Boolean = {
    System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
      System.getenv("SCALA_TEST_ON_GPU").toInt == 1
  }

  /**
   * Downloads the content/style images, the pretrained boost model and the
   * VGG-19 weights, then replicates the content image 20 times so the
   * boost-training test has a small image "dataset" to iterate over.
   */
  override def beforeAll(): Unit = {
    logger.info("Downloading vgg model")
    logger.info("tempDirPath: %s".format(tempDirPath))
    val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci/NeuralStyle/"
    Util.downloadUrl(baseUrl + "IMG_4343.jpg", tempDirPath + "/NS/IMG_4343.jpg")
    Util.downloadUrl(baseUrl + "starry_night.jpg", tempDirPath + "/NS/starry_night.jpg")
    Util.downloadUrl(baseUrl + "model.zip", tempDirPath + "/NS/model.zip")
    Util.downloadUrl(baseUrl + "vgg19.params", tempDirPath + "/NS/vgg19.params")
    // TODO: Need to confirm with Windows
    Process(s"unzip $tempDirPath/NS/model.zip -d $tempDirPath/NS/") !

    Process(s"mkdir $tempDirPath/NS/images") !

    for (i <- 0 until 20) {
      Process(s"cp $tempDirPath/NS/IMG_4343.jpg $tempDirPath/NS/images/img$i.jpg") !
    }
  }

  test("Example CI: Test Boost Inference") {
    // Inference also runs on CPU; pick the device once as a val instead of
    // mutating a var after the fact.
    val ctx = if (gpuTestEnabled) Context.gpu() else Context.cpu()
    BoostInference.runInference(tempDirPath + "/NS/model", tempDirPath + "/NS", 2,
      tempDirPath + "/NS/IMG_4343.jpg", ctx)
  }

  test("Example CI: Test Boost Training") {
    if (gpuTestEnabled) {
      BoostTrain.runTraining(tempDirPath + "/NS/images", tempDirPath + "/NS/vgg19.params",
        Context.gpu(), tempDirPath + "/NS/starry_night.jpg", tempDirPath + "/NS")
    } else {
      logger.info("GPU test only, skip CPU...")
    }
  }

  test("Example CI: Test Neural Style") {
    if (gpuTestEnabled) {
      // Hyper-parameters mirror the example defaults:
      // contentWeight=1, styleWeight=20, lr=0.01, maxNumEpochs=1, tv=10,
      // saveEpochs=60, maxLongEdge=600, lrSchedEpoch=50, stopEps=0.0005
      NeuralStyle.runTraining("vgg19", tempDirPath + "/NS/IMG_4343.jpg",
        tempDirPath + "/NS/starry_night.jpg",
        Context.gpu(), tempDirPath + "/NS/vgg19.params", tempDirPath + "/NS",
        1f, 20f, 0.01f, 1, 10f, 60, 600, 50, 0.0005f)
    } else {
      logger.info("GPU test only, skip CPU")
    }
  }
}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services