Github user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/8648#discussion_r40146842
--- Diff: mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala ---
@@ -260,127 +263,126 @@ private[ann] trait ActivationFunction extends
Serializable {
}
/**
- * Implements in-place application of functions
+ * Implements in-place application of functions.
*/
private[ann] object ActivationFunction {
- def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit
= {
- var i = 0
- while (i < x.rows) {
- var j = 0
- while (j < x.cols) {
- y(i, j) = func(x(i, j))
- j += 1
- }
- i += 1
- }
+ def apply(x: BDM[Double], y: BDM[Double], func: UFunc with MappingUFunc)(
+ implicit impl: func.Impl[BDM[Double], BDM[Double]]): Unit = {
+ y := func(x)
}
def apply(
- x1: BDM[Double],
- x2: BDM[Double],
- y: BDM[Double],
- func: (Double, Double) => Double): Unit = {
- var i = 0
- while (i < x1.rows) {
- var j = 0
- while (j < x1.cols) {
- y(i, j) = func(x1(i, j), x2(i, j))
- j += 1
- }
- i += 1
- }
+ x1: BDM[Double],
+ x2: BDM[Double],
+ y: BDM[Double],
+ func: UFunc with MappingUFunc)(
+ implicit impl: func.Impl2[BDM[Double], BDM[Double], BDM[Double]]):
Unit = {
+ y := func(x1, x2)
}
}
/**
- * Implements SoftMax activation function
+ * Implements Softmax activation function.
*/
private[ann] class SoftmaxFunction extends ActivationFunction {
override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
- var j = 0
- // find max value to make sure later that exponent is computable
- while (j < x.cols) {
- var i = 0
- var max = Double.MinValue
- while (i < x.rows) {
- if (x(i, j) > max) {
- max = x(i, j)
- }
- i += 1
- }
- var sum = 0.0
- i = 0
- while (i < x.rows) {
- val res = Math.exp(x(i, j) - max)
- y(i, j) = res
- sum += res
- i += 1
- }
- i = 0
- while (i < x.rows) {
- y(i, j) /= sum
- i += 1
- }
- j += 1
+ (0 until x.cols).foreach { j =>
+ // subtract max value to prevent overflow during exp
+ // does not affect correctness since we normalize right after
+ val maxVal = Bmax(x(::, j))
+ y(::, j) := breeze.numerics.exp(x(::, j) - maxVal)
+ y(::, j) :/= Bsum(y(::, j))
--- End diff --
@feynmanliang Could you run some micro-benchmarks on this function? I think
this is the only place that might cause performance issues.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]