Github user mengxr commented on a diff in the pull request: https://github.com/apache/spark/pull/2294#discussion_r17255285 --- Diff: mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala --- @@ -197,4 +199,451 @@ private[mllib] object BLAS extends Serializable { throw new IllegalArgumentException(s"scal doesn't support vector type ${x.getClass}.") } } + + // For level-3 routines, we use the native BLAS. + private def nativeBLAS: NetlibBLAS = { + if (_nativeBLAS == null) { + _nativeBLAS = NativeBLAS + } + _nativeBLAS + } + + /** + * C := alpha * A * B + beta * C + * @param transA specify whether to use matrix A, or the transpose of matrix A. Should be "N" or + * "n" to use A, and "T" or "t" to use the transpose of A. + * @param transB specify whether to use matrix B, or the transpose of matrix B. Should be "N" or + * "n" to use B, and "T" or "t" to use the transpose of B. + * @param alpha a scalar to scale the multiplication A * B. + * @param A the matrix A that will be left multiplied to B. Size of m x k. + * @param B the matrix B that will be left multiplied by A. Size of k x n. + * @param beta a scalar that can be used to scale matrix C. + * @param C the resulting matrix C. Size of m x n. + */ + def gemm( + transA: String, + transB: String, + alpha: Double, + A: Matrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix) { + var mA: Int = A.numRows + var nB: Int = B.numCols + var kA: Int = A.numCols + var kB: Int = B.numRows + + if (transA == "T" || transA=="t"){ + mA = A.numCols + kA = A.numRows + } + require(transA == "T" || transA == "t" || transA == "N" || transA == "n", + s"""Invalid argument used for transA: $transA. Must be \"N\", \"n\", \"T\", or \"t\"""") + if (transB == "T" || transB=="t"){ + nB = B.numRows + kB = B.numCols + } + require(transB == "T" || transB == "t" || transB == "N" || transB == "n", + s"""Invalid argument used for transB: $transB. 
Must be \"N\", \"n\", \"T\", or \"t\"""") + + require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") + require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") + require(nB == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") + + A match { + case sparse: SparseMatrix => + gemm(transA, transB, alpha, sparse, B, beta, C) + case dense: DenseMatrix => + gemm(transA, transB, alpha, dense, B, beta, C) + case _ => + throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.") + } + } + + /** + * C := alpha * A * B + beta * C + * + * @param alpha a scalar to scale the multiplication A * B. + * @param A the matrix A that will be left multiplied to B. Size of m x k. + * @param B the matrix B that will be left multiplied by A. Size of k x n. + * @param beta a scalar that can be used to scale matrix C. + * @param C the resulting matrix C. Size of m x n. + */ + def gemm( + alpha: Double, + A: Matrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix) { + gemm("N", "N", alpha, A, B, beta, C) + } + + /** + * C := alpha * A * B + * + * @param transA specify whether to use matrix A, or the transpose of matrix A. Should be "N" or + * "n" to use A, and "T" or "t" to use the transpose of A. + * @param transB specify whether to use matrix B, or the transpose of matrix B. Should be "N" or + * "n" to use B, and "T" or "t" to use the transpose of B. + * @param alpha a scalar to scale the multiplication A * B. + * @param A the matrix A that will be left multiplied to B. Size of m x k. + * @param B the matrix B that will be left multiplied by A. Size of k x n. + * + * @return The resulting matrix C. Size of m x n. 
+ */ + def gemm( + transA: String, + transB: String, + alpha: Double, + A: Matrix, + B: DenseMatrix) : DenseMatrix = { + var mA: Int = A.numRows + var nB: Int = B.numCols + var kA: Int = A.numCols + var kB: Int = B.numRows + + if (transA == "T" || transA=="T"){ + mA = A.numCols + kA = A.numRows + } + if (transB == "T" || transB=="T"){ + nB = B.numRows + kB = B.numCols + } + + val C: DenseMatrix = DenseMatrix.zeros(mA, nB) + gemm(transA, transB, alpha, A, B, 0.0, C) + + C + } + + /** + * C := alpha * A * B + * + * @param alpha a scalar to scale the multiplication A * B. + * @param A the matrix A that will be left multiplied to B. Size of m x k. + * @param B the matrix B that will be left multiplied by A. Size of k x n. + * + * @return The resulting matrix C. Size of m x n. + */ + def gemm( + alpha: Double, + A: Matrix, + B: DenseMatrix) : DenseMatrix = { + gemm("N", "N", alpha, A, B) + } + + /** + * C := alpha * A * B + beta * C + * For `DenseMatrix` A. + */ + private def gemm( + transA: String, + transB: String, + alpha: Double, + A: DenseMatrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix) { + var mA: Int = A.numRows + var nB: Int = B.numCols + var kA: Int = A.numCols + var kB: Int = B.numRows + + if (transA == "T" || transA=="t"){ + mA = A.numCols + kA = A.numRows + } + if (transB == "T" || transB=="t"){ + nB = B.numRows + kB = B.numCols + } + + nativeBLAS.dgemm(transA,transB, mA, nB, kA, alpha, A.toArray, A.numRows, B.toArray, B.numRows, + beta, C.toArray, C.numRows) + } + + /** + * C := alpha * A * B + beta * C + * For `SparseMatrix` A. 
+ */ + private def gemm( + transA: String, + transB: String, + alpha: Double, + A: SparseMatrix, + B: DenseMatrix, + beta: Double, + C: DenseMatrix) { + var transposeA = false + var transposeB = false + var mA: Int = A.numRows + var nB: Int = B.numCols + var kA: Int = A.numCols + + if (transA == "T" || transA=="t"){ + mA = A.numCols + kA = A.numRows + transposeA = true + } + if (transB == "T" || transB=="t"){ + nB = B.numRows + transposeB = true + } + val Avals = A.toArray + val Arows = if (!transposeA) A.rowIndices else A.colPointers + val Acols = if (!transposeA) A.colPointers else A.rowIndices + + // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices + if (transposeA){ + var colCounter = 0 + while (colCounter < nB){ // Tests showed that the outer loop being columns was faster + var rowCounter = 0 + while (rowCounter < mA){ + val indStart = Arows(rowCounter) + val indEnd = Arows(rowCounter + 1) + var elementCount = 0 // Loop over non-zero entries in column (actually the row indices, + // since this is a transposed multiplication + var sum = 0.0 + while(indStart + elementCount < indEnd){ + val AcolIndex = Acols(indStart + elementCount) + val Bval = if (!transposeB) B(AcolIndex, colCounter) else B(colCounter, AcolIndex) --- End diff -- Do not put the `if` check inside the inner loop; evaluating it there on every iteration is expensive. Instead, make four outer blocks, one for each combination of transposeA and transposeB.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes so, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org