zhengruifeng commented on a change in pull request #32822:
URL: https://github.com/apache/spark/pull/32822#discussion_r649624193
##########
File path: mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
##########
@@ -94,4 +95,43 @@ private[spark] object Utils {
math.log1p(math.exp(x))
}
}
+
+ /**
+ * Perform in-place softmax conversion.
+ */
+ def softmax(values: Array[Double]): Unit = {
+ var maxValue = Double.MinValue
+ var i = 0
+ while (i < values.length) {
+ val value = values(i)
+ if (value.isPosInfinity) {
+ java.util.Arrays.fill(values, 0)
+ values(i) = 1.0
+ return
+ } else if (value > maxValue) {
+ maxValue = value
+ }
+ i += 1
+ }
+
+ var sum = 0.0
+ i = 0
+ if (maxValue > 0) {
Review comment:
sounds reasonable
##########
File path: mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
##########
@@ -94,4 +95,43 @@ private[spark] object Utils {
math.log1p(math.exp(x))
}
}
+
+ /**
+ * Perform in-place softmax conversion.
+ */
+ def softmax(values: Array[Double]): Unit = {
+ var maxValue = Double.MinValue
+ var i = 0
+ while (i < values.length) {
+ val value = values(i)
+ if (value.isPosInfinity) {
+ java.util.Arrays.fill(values, 0)
+ values(i) = 1.0
+ return
+ } else if (value > maxValue) {
+ maxValue = value
+ }
+ i += 1
+ }
+
+ var sum = 0.0
+ i = 0
+ if (maxValue > 0) {
Review comment:
```py
def softmax(X, copy=True):
    """
    Calculate the softmax function.

    The softmax function is calculated by
    np.exp(X) / np.sum(np.exp(X), axis=1)

    This will cause overflow when large values are exponentiated.
    Hence the largest value in each row is subtracted from each data
    point to prevent this.

    Parameters
    ----------
    X : array-like of float of shape (M, N)
        Argument to the logistic function.

    copy : bool, default=True
        Copy X or not.

    Returns
    -------
    out : ndarray of shape (M, N)
        Softmax function evaluated at every point in x.
    """
    if copy:
        X = np.copy(X)
    # The row maximum is subtracted unconditionally, whatever its sign.
    max_prob = np.max(X, axis=1).reshape((-1, 1))
    X -= max_prob
    # Second argument is the output array: exponentiate in place.
    np.exp(X, X)
    sum_prob = np.sum(X, axis=1).reshape((-1, 1))
    X /= sum_prob
    return X
```
The softmax in scikit-learn does not check whether the max value is positive
or not; it always subtracts the row maximum before exponentiating.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]