zhengruifeng commented on code in PR #36469:
URL: https://github.com/apache/spark/pull/36469#discussion_r867285403
##########
mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/MultinomialLogisticBlockAggregator.scala:
##########
@@ -151,15 +158,39 @@ private[ml] class MultinomialLogisticBlockAggregator(
case dm: DenseMatrix =>
// gradientSumArray[0 : F X C] += mat.T X dm
BLAS.nativeBLAS.dgemm("T", "T", numClasses, numFeatures, size, 1.0,
- mat.values, size, dm.values, numFeatures, 1.0, gradientSumArray,
numClasses)
+ arr, size, dm.values, numFeatures, 1.0, gradientSumArray, numClasses)
case sm: SparseMatrix =>
- // TODO: convert Coefficients to row major order to simplify BLAS
operations?
- // linearGradSumMat = sm.T X mat
- // existing BLAS.gemm requires linearGradSumMat is NOT Transposed.
- val linearGradSumMat = DenseMatrix.zeros(numFeatures, numClasses)
- BLAS.gemm(1.0, sm.transpose, mat, 0.0, linearGradSumMat)
- linearGradSumMat.foreachActive { (i, j, v) => gradientSumArray(i *
numClasses + j) += v }
+ // dedicated sparse GEMM implementation for transposed C: C += A * B,
where:
Review Comment:
This implementation follows
https://github.com/apache/spark/blob/master/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala#L569-L586,
except for the indexing of C (see
https://github.com/apache/spark/blob/master/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala#L580).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]