GitHub user feynmanliang commented on a diff in the pull request:
https://github.com/apache/spark/pull/8648#discussion_r41700924
--- Diff: mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala ---
@@ -260,127 +263,126 @@ private[ann] trait ActivationFunction extends Serializable {
 }
 
 /**
- * Implements in-place application of functions
+ * Implements in-place application of functions.
  */
 private[ann] object ActivationFunction {
 
-  def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = {
-    var i = 0
-    while (i < x.rows) {
-      var j = 0
-      while (j < x.cols) {
-        y(i, j) = func(x(i, j))
-        j += 1
-      }
-      i += 1
-    }
+  def apply(x: BDM[Double], y: BDM[Double], func: UFunc with MappingUFunc)(
+      implicit impl: func.Impl[BDM[Double], BDM[Double]]): Unit = {
+    y := func(x)
   }
 
   def apply(
-      x1: BDM[Double],
-      x2: BDM[Double],
-      y: BDM[Double],
-      func: (Double, Double) => Double): Unit = {
-    var i = 0
-    while (i < x1.rows) {
-      var j = 0
-      while (j < x1.cols) {
-        y(i, j) = func(x1(i, j), x2(i, j))
-        j += 1
-      }
-      i += 1
-    }
+      x1: BDM[Double],
+      x2: BDM[Double],
+      y: BDM[Double],
+      func: UFunc with MappingUFunc)(
+      implicit impl: func.Impl2[BDM[Double], BDM[Double], BDM[Double]]): Unit = {
+    y := func(x1, x2)
   }
 }
 
 /**
- * Implements SoftMax activation function
+ * Implements Softmax activation function.
  */
 private[ann] class SoftmaxFunction extends ActivationFunction {
   override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
-    var j = 0
-    // find max value to make sure later that exponent is computable
-    while (j < x.cols) {
-      var i = 0
-      var max = Double.MinValue
-      while (i < x.rows) {
-        if (x(i, j) > max) {
-          max = x(i, j)
-        }
-        i += 1
-      }
-      var sum = 0.0
-      i = 0
-      while (i < x.rows) {
-        val res = Math.exp(x(i, j) - max)
-        y(i, j) = res
-        sum += res
-        i += 1
-      }
-      i = 0
-      while (i < x.rows) {
-        y(i, j) /= sum
-        i += 1
-      }
-      j += 1
+    (0 until x.cols).foreach { j =>
+      // subtract max value to prevent overflow during exp
+      // does not affect correctness since we normalize right after
+      val maxVal = Bmax(x(::, j))
+      y(::, j) := breeze.numerics.exp(x(::, j) - maxVal)
+      y(::, j) :/= Bsum(y(::, j))
--- End diff --
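For context, the new apply overloads delegate the element-wise loop to Breeze's UFunc machinery: the implicit Impl/Impl2 parameters are how Breeze maps any MappingUFunc over a DenseMatrix. A minimal standalone sketch of invoking such an overload, assuming a Breeze-equipped Scala REPL (the object name InPlaceApply is hypothetical; breeze.numerics.sigmoid is just one example of a MappingUFunc):

    import breeze.generic.{MappingUFunc, UFunc}
    import breeze.linalg.{DenseMatrix => BDM}
    import breeze.numerics.sigmoid

    object InPlaceApply {
      // Same shape as the new ActivationFunction.apply above: Breeze resolves
      // an implicit Impl for (func, DenseMatrix) and applies func element-wise,
      // and := writes the result into y in place.
      def apply(x: BDM[Double], y: BDM[Double], func: UFunc with MappingUFunc)(
          implicit impl: func.Impl[BDM[Double], BDM[Double]]): Unit = {
        y := func(x)
      }
    }

    val x = BDM((0.0, 1.0), (-2.0, 3.0))
    val y = BDM.zeros[Double](2, 2)
    InPlaceApply(x, y, sigmoid) // y(i, j) == 1.0 / (1.0 + math.exp(-x(i, j)))

This trades the hand-written while loops for Breeze's mapping implementations, which is the change being benchmarked below.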
@mengxr [Local benchmarks here](https://gist.github.com/feynmanliang/bc64b82a1258c4e86b9a). Performance of the two implementations is roughly equivalent.
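For readers skimming the thread, the stabilization trick in the new SoftmaxFunction.eval can be sketched standalone, assuming only Breeze on the classpath (softmaxColumns is a hypothetical helper, not part of the patch; like Layer.scala, it assumes one sample per column):

    import breeze.linalg.{max => Bmax, sum => Bsum, DenseMatrix => BDM}
    import breeze.numerics.exp

    def softmaxColumns(x: BDM[Double]): BDM[Double] = {
      val y = BDM.zeros[Double](x.rows, x.cols)
      (0 until x.cols).foreach { j =>
        val maxVal = Bmax(x(::, j))        // largest entry in column j
        y(::, j) := exp(x(::, j) - maxVal) // every exponent is <= 0, so exp cannot overflow
        y(::, j) :/= Bsum(y(::, j))        // normalize column j to sum to 1
      }
      y
    }

Subtracting the column maximum leaves the result unchanged, since exp(v - m) / sum(exp(v - m)) = exp(v) / sum(exp(v)); the shared exp(-m) factor cancels in the normalization.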