First, the caveats:
- on large matrices,
- with few threads, and
- punching a hole directly to the BLAS library.

Sivan Toledo's recursive LU factorization ( 
[[http://dx.doi.org/10.1137/S0895479896297744]] ) isn't just good for 
optimizing cache behavior but also for avoiding top-level overheads.  A 
somewhat straightforward iterative implementation in Julia, attached, performs 
about as quickly as the current LAPACK call for square matrices of dimension 
3000 or above so long as you limit OpenBLAS to a single thread.  I haven't 
tested against MKL.

Unfortunately, I have to avoid the pretty BLAS interface and call xGEMM and 
xTRSM directly.  The allocation and construction of StridedMatrix views at 
every step destroy any performance.  The same happens with a light-weight 
wrapper attempt called LASubDomain below.  Also, I have to write loops rather 
than vector expressions.  That leads to ugly code compared to modern Fortran 
versions.

Current ratio of recursive LU time to the Julia LAPACK dispatch on a dual-core, 
Sandy Bridge server by dimension (vertical) and number of threads (horizontal):

|     N |   NT: 1 |       2 |       4 |      8 |      16 |      24 |      32 |
|-------+---------+---------+---------+--------+---------+---------+---------|
|    10 | 2301.01 | 2291.16 | 2335.39 | 2268.1 | 2532.12 | 2304.23 | 2320.84 |
|    30 |  661.25 |  560.34 |  443.07 | 514.52 |  392.04 |  425.56 |     472 |
|   100 |   98.62 |   90.38 |   96.69 |  94.31 |    74.6 |   73.61 |   76.33 |
|   300 |    8.12 |    8.14 |   12.19 |  16.15 |   14.55 |   14.16 |   15.82 |
|  1000 |     1.6 |    1.55 |    2.36 |   3.35 |     3.5 |    3.86 |    3.48 |
|  3000 |    1.04 |    1.13 |    1.36 |   1.99 |    1.62 |    1.36 |    1.39 |
| 10000 |    1.02 |    1.06 |    1.13 |   1.26 |    1.55 |    1.54 |    1.58 |

The best case for traditional LU factorization relying on xGEMM for the Schur 
complement is 7.7x slower than the LAPACK dispatch, and it scales far worse.

If there's interest, I'd love to work this into a faster generic 
base/linalg/lu.jl for use with other data types.  I know I need to consider if 
the la_swap!, lu_trisolve!, and lu_schur! functions are the right utility 
methods.  I'm interested in supporting doubled double, etc. with relatively 
high performance.  That implies building everything out of at least dot 
products.  For "multiplied" precisions, those can be optimized to avoid 
continual re-normalization (see Siegfried Rump, et al.'s work).

(I actually wrote the code last November, but this is the first chance I've 
had to describe it.  It might be a tad out of date style-wise with respect to 
current Julia.  Given that lag, I suppose I shouldn't promise to beat this into 
shape at any fast pace...)
module RecFactorization
# Beginnings of recursive factorization implementations for Julia.
# Currently just LU.  QR should be straight-forward.  Unfortunately,
# vector expressions must be written as loops to avoid needless memory
# allocation.  I hate writing loops.

export reclufact!, lufact_schur!, lufact_noblas!

import Base.LinAlg: BlasFloat, BlasChar, BlasInt, blas_int, DimensionMismatch, 
chksquare, axpy!

# Attempt at defining a pretty LAPACK-style sub-array...
## Using LASubDomain causes far too much allocation.  I suspect using
## ArrayViews would have the same problem.
##
## An LASubDomain describes a rectangular window into a column-major dense
## matrix without holding a reference to the matrix itself:
##   offset     - zero-based element offset of the window's (1,1) entry
##   leadingdim - row count of the parent matrix
##   nrow, ncol - extent of the window
immutable LASubDomain
    offset::BlasInt
    leadingdim::BlasInt
    nrow::BlasInt
    ncol::BlasInt

    # Trailing window: everything from entry (i,j) through the last row and
    # the last column of A.
    function LASubDomain{T}(A::DenseArray{T,2}, i, j)
        (M,N) = size(A)
        offset = (i-1) + (j-1)*M
        # Rows i:M and columns j:N, both inclusive.
        nrow = M-i+1
        ncol = N-j+1
        return new(offset, M, nrow, ncol)
    end
    # Explicit window: the nrow-by-ncol block whose top-left entry is (i,j).
    function LASubDomain{T}(A::DenseArray{T,2}, i, nrow, j, ncol)
        (M,N) = size(A)
        offset = (i-1) + (j-1)*M
        return new(offset, M, nrow, ncol)
    end
end

import Base.pointer
# Raw pointer to the first element of the window `dom` inside A.
function pointer{T}(A::DenseArray{T, 2}, dom::LASubDomain)
    nbytes = dom.offset*sizeof(T)
    return pointer(A) + nbytes
end
# Raw pointer to entry (i,j) of A, computed from the column-major element
# offset (i-1) + (j-1)*size(A,1).
function pointer{T}(A::DenseArray{T, 2}, i, j)
    nskip = (i-1) + (j-1)*size(A,1)
    return pointer(A) + nskip*sizeof(T)
end
import Base.size
# Extent of the window as a (rows, columns) tuple.
function size(dom::LASubDomain)
    return (dom.nrow, dom.ncol)
end
# Extent of the window along dimension k (1 = rows, 2 = columns).
function size(dom::LASubDomain, k::Int)
    extent = size(dom)
    return extent[k]
end

# Leading dimension of a dense matrix: its row count.
function leadingdim{T}(A::DenseArray{T, 2})
    return size(A, 1)
end
# Leading dimension recorded in a window descriptor.
function leadingdim(dom::LASubDomain)
    return dom.leadingdim
end
# Bundle the (pointer, leading dimension) pair for window `dom` of A, in the
# form the raw BLAS wrappers below consume.
function LAArg{T}(A::DenseArray{T,2}, dom::LASubDomain)
    ld = leadingdim(dom)
    ptr = pointer(A, dom)
    return (ptr, ld)
end

# Dig in and expose the raw BLAS calls.
# libblas is the library argument handed to every ccall in the painful_*
# wrappers below; it is taken from Base.libblas_name.
const libblas = Base.libblas_name

# Generate painful_gemm! for each BLAS element type.  Two flavours are defined
# per type: one taking matrices plus LASubDomain windows, and one taking raw
# pointers plus leading dimensions.  Both compute
#     C := alpha*op(A)*op(B) + beta*C
# through xGEMM, where C is m-by-n and k is the inner dimension.
for (gemm, elty) in
        ((:dgemm_,:Float64),
         (:sgemm_,:Float32),
         (:zgemm_,:Complex128),
         (:cgemm_,:Complex64))
    @eval begin
        # SUBROUTINE DGEMM(TRANSA,TRANSB,M,N,K,ALPHA,A,LDA,B,LDB,BETA,C,LDC)
        # *     .. Scalar Arguments ..
        #       DOUBLE PRECISION ALPHA,BETA
        #       INTEGER K,LDA,LDB,LDC,M,N
        #       CHARACTER TRANSA,TRANSB
        # *     .. Array Arguments ..
        #       DOUBLE PRECISION A(LDA,*),B(LDB,*),C(LDC,*)

        # Window flavour: resolve each (matrix, window) pair to a pointer and
        # leading dimension, then defer to the raw-pointer flavour.
        function painful_gemm!(transA::BlasChar, transB::BlasChar,
                               m::BlasInt, n::BlasInt, k::BlasInt,
                               alpha::($elty),
                               A::DenseArray{$elty,2}, domA::LASubDomain,
                               B::DenseArray{$elty,2}, domB::LASubDomain,
                               beta::($elty),
                               C::DenseArray{$elty,2}, domC::LASubDomain)
            (pA, lda) = LAArg(A, domA)
            (pB, ldb) = LAArg(B, domB)
            (pC, ldc) = LAArg(C, domC)
            painful_gemm!(transA, transB, m, n, k, alpha, pA, lda,
                          pB, ldb, beta, pC, ldc)
        end
        # Raw flavour: no checking at all; the caller guarantees that the
        # pointers and dimensions describe valid memory.
        function painful_gemm!(transA::BlasChar, transB::BlasChar,
                               m::BlasInt, n::BlasInt, k::BlasInt,
                               alpha::($elty),
                               A::Ptr{$elty}, lda::BlasInt,
                               B::Ptr{$elty}, ldb::BlasInt,
                               beta::($elty),
                               C::Ptr{$elty}, ldc::BlasInt)
            ccall(($(string(gemm)),libblas), Void,
                  (Ptr{Uint8}, Ptr{Uint8}, Ptr{BlasInt}, Ptr{BlasInt},
                   Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt},
                   Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
                   Ptr{BlasInt}),
                  &transA, &transB, &m, &n,
                  &k, &alpha, A, &lda,
                  B, &ldb, &beta, C,
                  &ldc)
        end
    end
end

## (TR) Triangular matrix and vector multiplication and solution
# Generate painful_trsm! for each BLAS element type.  Only the xTRSM name
# (smname) is wrapped here; the xTRMM name (mmname) is carried in the table
# but not used.  As with painful_gemm!, there is a window flavour and a
# raw-pointer flavour.
for (mmname, smname, elty) in
        ((:dtrmm_,:dtrsm_,:Float64),
         (:strmm_,:strsm_,:Float32),
         (:ztrmm_,:ztrsm_,:Complex128),
         (:ctrmm_,:ctrsm_,:Complex64))
    @eval begin
        #       SUBROUTINE DTRSM(SIDE,UPLO,TRANSA,DIAG,M,N,ALPHA,A,LDA,B,LDB)
        # *     .. Scalar Arguments ..
        #       DOUBLE PRECISION ALPHA
        #       INTEGER LDA,LDB,M,N
        #       CHARACTER DIAG,SIDE,TRANSA,UPLO
        # *     .. Array Arguments ..
        #       DOUBLE PRECISION A(LDA,*),B(LDB,*)

        # Window flavour: the triangular factor is window domA of A, the
        # right-hand side (overwritten with the solution) is window domB of B.
        # Throws DimensionMismatch when domA is not square or does not match
        # the side of domB it is applied to.
        function painful_trsm!(side::Char, uplo::Char, transa::Char, diag::Char,
                               alpha::$elty,
                               A::DenseArray{$elty,2}, domA::LASubDomain,
                               B::DenseArray{$elty,2}, domB::LASubDomain)
            m, n = size(domB)
            k = chksquare(domA)
            k==(side == 'L' ? m : n) || throw(DimensionMismatch("size of A is ($k,$k), size(B)=($m,$n) and side='$side'"))
            pA, lda = LAArg(A, domA)
            # The right-hand side lives in B, so its pointer must come from B.
            pB, ldb = LAArg(B, domB)
            painful_trsm!(side, uplo, transa, diag, m, n, alpha,
                          pA, lda, pB, ldb)
        end
        # Raw flavour: no checking at all; the caller guarantees that the
        # pointers and dimensions describe valid memory.
        function painful_trsm!(side::Char, uplo::Char, transa::Char, diag::Char,
                               m::BlasInt, n::BlasInt,
                               alpha::$elty, A::Ptr{$elty}, lda::BlasInt,
                               B::Ptr{$elty}, ldb::BlasInt)
            ccall(($(string(smname)), libblas), Void,
                  (Ptr{Uint8}, Ptr{Uint8}, Ptr{Uint8}, Ptr{Uint8},
                   Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
                   Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}),
                  &side, &uplo, &transa, &diag,
                  &m, &n, &alpha, A, &lda, B, &ldb)
        end
    end
end

# LAPACK helper translated to Julia
# Applies row interchanges restricted to a column window: for each row index
# r in k1:k2, in increasing order, rows r and ipiv[r] of A are exchanged
# across columns startcol:(startcol+ncol-1).  Entries with ipiv[r] == r are
# skipped.
function laswp!{T}(A::Matrix{T}, startcol, ncol, k1, k2, ipiv)
    lastcol = startcol + ncol - 1
    @inbounds for row = k1:k2
        target = ipiv[row]
        target == row && continue
        for col = startcol:lastcol
            held = A[row, col]
            A[row, col] = A[target, col]
            A[target, col] = held
        end
    end
end

# lu_trisolve! and lu_schur! wrap the bare BLAS calls when possible.
#
# BLAS path of lu_trisolve!: hands xTRSM ('L','L','N','U') the nrow-by-nrow
# block of A starting at (i,i) as the triangular factor and the nrow-by-ncol
# block starting at (i,j) as the right-hand side, which is overwritten.
function lu_trisolve!{T<:BlasFloat}(A::Matrix{T}, i, j, nrow, ncol)
    lda = leadingdim(A)
    tri = pointer(A, i, i)
    rhs = pointer(A, i, j)
    painful_trsm!('L', 'L', 'N', 'U', nrow, ncol, one(T), tri, lda, rhs, lda)
end
# Generic fallback for element types the BLAS cannot handle.  Solves L*X = B
# in place by forward substitution, where L is the unit lower-triangular
# nrow-by-nrow block of A starting at (i,i) (its diagonal is never read) and
# B is the nrow-by-ncol block of A starting at (i,j).  B is overwritten by X.
function lu_trisolve!{T}(A::AbstractMatrix{T}, i, j, nrow, ncol)
    for k = 1:nrow-1
        # Column of L, and matching row of B, eliminated at this step.
        lcol = i+k-1
        for jj = j:j+ncol-1
            # Row lcol of B is already final and is not written below.
            bkj = A[lcol,jj]
            @simd for ii = i+k:i+nrow-1
                A[ii,jj] -= A[ii,lcol]*bkj
            end
        end
    end
end

# BLAS path of lu_schur!: one xGEMM call computing C := C - L*U, where
#   L is the M-by-K block of A starting at (j,i),
#   U is the K-by-N block of A starting at (i,j),
#   C is the M-by-N block of A starting at (j,j).
function lu_schur!{T<:BlasFloat}(A::Matrix{T}, i, j, M, N, K)
    lda = leadingdim(A)
    lblock = pointer(A, j, i)
    ublock = pointer(A, i, j)
    cblock = pointer(A, j, j)
    painful_gemm!('N', 'N', M, N, K, -one(T),
                  lblock, lda,
                  ublock, lda,
                  one(T),
                  cblock, lda)
end
# Generic fallback of lu_schur! for element types the BLAS cannot handle.
# Updates the M-by-N block of A starting at (j,j) in place:
#     A[ii,jj] -= sum over kk in i:i+K-1 of A[ii,kk]*A[kk,jj]
# Each entry is formed as a single dot product accumulated in a scalar.
# Accepts any AbstractMatrix, matching the generic lu_trisolve!.
function lu_schur!{T}(A::AbstractMatrix{T}, i, j, M, N, K)
    for jj = j:j+N-1
        for ii = j:j+M-1
            zz = A[ii, jj]
            @simd for kk = i:i+K-1
                zz -= A[ii, kk] * A[kk, jj]
            end
            A[ii, jj] = zz
        end
    end
end

# Recursive LU factorization, iterative style.
## With appropriate bit-twiddling, recursive LU can be written as a plain loop.
## See VARIANTS/lu/REC in the LAPACK source distribution.  This style avoids
## much of the top-level overhead.
##
## Original citation:
## Toledo, Sivan. "Locality of Reference in LU Decomposition with Partial
## Pivoting."
## SIAM Journal on Matrix Analysis and Applications 1997 18:4, 1065-1081.
## http://dx.doi.org/10.1137/S0895479896297744
##
## Factors A in place with partial (row) pivoting and returns a
## Base.LinAlg.LU wrapping A, the pivot vector, and info.  info is the index
## of the first column whose pivot was zero or NaN, or 0 if there was none.
function reclufact!{T}(A :: Matrix{T})

    (M::BlasInt, N::BlasInt) = size(A)
    nstep::BlasInt = min(M, N)
    info::BlasInt = 0

    ## ipiv::Vector{BlasInt}
    ipiv = Array(BlasInt, nstep)

    ### XXX: check bounds first
    @inbounds begin
        for j::BlasInt = 1:nstep
            # kahead is the lowest set bit of j: the width of the column
            # block kstart:j that is fully factored once column j is done.
            kahead = j & -j
            kstart = j + 1 - kahead
            # Width of the block right of column j that is updated now.
            # Clamped by nstep-j so that j+kcols never exceeds min(M,N);
            # clamping by M-j alone would run past column N when M > N.
            kcols = min(kahead, nstep-j)

            # Find the pivot.
            jp = j
            jpval = abs(A[j,j])
            if isnan(jpval)
                # Start the search from the first non-NaN entry, if any.
                for jp2 = (jp+1):M
                    tstval = A[jp2,j]
                    if !isnan(tstval)
                        jp = jp2
                        jpval = abs(tstval)
                        break;
                    end
                end
            end
            for jj = (jp+1):M
                tstval = abs(A[jj,j])
                if tstval > jpval
                    # NaNs fail the test and will not be chosen.
                    jp = jj
                    jpval = tstval
                end
            end
            piv_ok = jpval != 0 && !isnan(jpval)
            if piv_ok
                ipiv[j] = jp
            else
                jp = j
                ipiv[j] = j
            end

            # Permute just this column
            if jp != j && piv_ok
                (A[j,j], A[jp,j]) = (A[jp,j], A[j,j])
            end

            # Apply pending permutations to L
            ntopiv = 1
            ipivstart = j
            jpivstart = j - ntopiv
            while (ntopiv < kahead)
                laswp!(A, jpivstart, ntopiv, ipivstart, j, ipiv)
                ipivstart -= ntopiv
                ntopiv *= 2
                jpivstart -= ntopiv
            end

            # Permute U block to match L
            laswp!(A, j+1, kcols, kstart, j, ipiv)

            # Factor current column
            if piv_ok
                pivent = A[j,j] # Pick up sign.
                inv_pivent = one(pivent)/pivent
                @simd for ii = (j+1):M
                    A[ii,j] *= inv_pivent
                end
            elseif info == 0
                info = j
            end

            ## Needed for BLAS.* versions below.
            ## B1 = kstart:kstart+kahead-1
            ## Br = j+1:M
            ## Bc = j+1:j+kcols

            # Solve for U block
            ##  Note that the BLAS.* calls are the most allocation expensive
            ##  from sub, and the LASubDomain are the next most expensive.

            ## BLAS.trsm!('L', 'L', 'N', 'U', one(T), sub(A, B1, B1),
            ##            sub(A, B1, Bc))

            ## painful_trsm!('L', 'L', 'N', 'U', one(T),
            ##               A, LASubDomain(A, kstart, kahead, kstart, kahead),
            ##               A, LASubDomain(A, kstart, kahead, j+1, kcols))

            lu_trisolve!(A, kstart, j+1, kahead, kcols)

            # Schur complement

            ## BLAS.gemm!('N', 'N', -one(T), sub(A, Br, B1), sub(A, B1, Bc),
            ##            one(T), sub(A, Br, Bc))

            ## painful_gemm!('N', 'N', M-j, kcols, kahead, -one(T),
            ##               A, LASubDomain(A, j+1, M-j-1, kstart, kahead),
            ##               A, LASubDomain(A, kstart, kahead, j+1, kcols),
            ##               one(T),
            ##               A, LASubDomain(A, j+1, M-j-1, j+1, kcols))

            lu_schur!(A, kstart, j+1, M-j, kcols, kahead)
        end

        # Handle pivot permutations out of recursion
        npived = nstep & -nstep
        j = nstep - npived
        while (j > 0)
            ntopiv = j & -j
            laswp!(A, j-ntopiv+1, ntopiv, j+1, nstep, ipiv)
            j = j - ntopiv
        end

        # If short and wide, handle the rest of the columns
        if M < N
            # The remaining columns start at M+1.  kcols belongs to the loop
            # above and is not visible here; its last value was 0 anyway.
            laswp!(A, M+1, N-M, 1, M, ipiv)
            ## BLAS.trsm!('L', 'L', 'N', 'U', one(T),
            ##            sub(A, :, 1:M), sub(A, :, (M+1):N))
            ## painful_trsm!('L', 'L', 'N', 'U', one(T),
            ##               A, LASubDomain(A, 1, M, 1, M),
            ##               A, LASubDomain(A, 1, M, M+1, N-M))
            ## painful_trsm!('L', 'L', 'N', 'U',
            ##               M, N-M,
            ##               one(T), pointer(A, 1, 1), leadingdim(A),
            ##               pointer(A, 1, M+1), leadingdim(A))
            lu_trisolve!(A, 1, M+1, M, N-M)
        end
    end
    return Base.LinAlg.LU{T,typeof(A)}(A, ipiv, convert(BlasInt, info))
end
# Non-mutating variant: runs reclufact! on a copy of A.
function reclufact(A::AbstractMatrix)
    Acopy = copy(A)
    return reclufact!(Acopy)
end


# For comparison, a direct translation of typical LU factorization.
# Factors A in place with partial (row) pivoting using plain loops only and
# returns a Base.LinAlg.LU wrapping A, the pivot vector, and info.  info is
# the index of the first column with a zero pivot, or 0 if there was none.
function lufact_noblas!{T}(A::AbstractMatrix{T})
    ##typeof(one(T)/one(T)) <: BlasFloat && return lufact!(float(A))
    m, n = size(A)
    minmn = min(m,n)
    info = 0
    ipiv = Array(BlasInt, minmn)
    @inbounds for k = 1:minmn
        # find index max
        # Default to the diagonal row so an all-zero column leaves the
        # already-factored rows above it untouched.
        kp = k
        # abs() of the element type, so complex T compares real magnitudes.
        amax = abs(zero(T))
        for i = k:m
            absi = abs(A[i,k])
            if absi > amax
                kp = i
                amax = absi
            end
        end
        ipiv[k] = kp
        if A[kp,k] != 0
            if kp != k
                # Interchange
                for i = 1:n
                    tmp = A[k,i]
                    A[k,i] = A[kp,i]
                    A[kp,i] = tmp
                end
            end
            # Scale first column
            pivent = A[k,k]
            inv_pivent = one(pivent) / pivent
            @simd for i = k+1:m
                A[i,k] *= inv_pivent
            end
        elseif info == 0
            info = k
        end
        # Update the rest
        for j = k+1:n
            @simd for i = k+1:m
                A[i,j] -= A[i,k]*A[k,j]
            end
        end
    end
    # A zero final pivot is already recorded by the loop; no separate check
    # here, so an earlier zero-pivot index is never overwritten.
    return Base.LinAlg.LU{T,typeof(A)}(A, ipiv, convert(BlasInt, info))
end

# For another comparison, a translation of typical LU factorization that calls
# the BLAS for the Schur complement.
# Factors A in place with partial (row) pivoting; the rank-1 trailing update
# of each step goes through lu_schur!.  Returns a Base.LinAlg.LU wrapping A,
# the pivot vector, and info.  info is the index of the first column with a
# zero pivot, or 0 if there was none.
function lufact_schur!{T}(A::AbstractMatrix{T})
    ##typeof(one(T)/one(T)) <: BlasFloat && return lufact!(float(A))
    m, n = size(A)
    minmn = min(m,n)
    info = 0
    ipiv = Array(BlasInt, minmn)
    @inbounds for k = 1:minmn
        # find index max
        # Default to the diagonal row so an all-zero column leaves the
        # already-factored rows above it untouched.
        kp = k
        # abs() of the element type, so complex T compares real magnitudes.
        amax = abs(zero(T))
        for i = k:m
            absi = abs(A[i,k])
            if absi > amax
                kp = i
                amax = absi
            end
        end
        ipiv[k] = kp
        if A[kp,k] != 0
            if kp != k
                # Interchange
                for i = 1:n
                    tmp = A[k,i]
                    A[k,i] = A[kp,i]
                    A[kp,i] = tmp
                end
            end
            # Scale first column
            pivent = A[k,k]
            inv_pivent = one(pivent) / pivent
            @simd for i = k+1:m
                A[i,k] *= inv_pivent
            end
        elseif info == 0
            info = k
        end
        # Update the rest
        # Rank-1 update A[k+1:m,k+1:n] -= A[k+1:m,k]*A[k,k+1:n].  lu_schur!
        # takes the factor column/row index first (k) and the corner of the
        # updated block second (k+1).  Extents are converted to BlasInt
        # because the BLAS method forwards them to a BlasInt-typed wrapper.
        lu_schur!(A, k, k+1,
                  convert(BlasInt, m-k), convert(BlasInt, n-k),
                  convert(BlasInt, 1))
    end
    # A zero final pivot is already recorded by the loop; no separate check
    # here, so an earlier zero-pivot index is never overwritten.
    return Base.LinAlg.LU{T,typeof(A)}(A, ipiv, convert(BlasInt, info))
end



end
include("reclu.jl")
using RecFactorization

# Defaults, overridable from the command line:
#   --N n | -N n           matrix dimension (positive integer)
#   --seed s | -seed s     integer seed, or the word "random"
#   --help | -h            print a hint and exit
seed = 836392
N = 800

nargs = length(ARGS)
k = 1
# <= so that the last argument (e.g. a lone --help) is examined too.
while k <= nargs
    if ARGS[k] == "--help" || ARGS[k] == "-h"
        println("See the source.")
        exit(-1)
    elseif ARGS[k] == "--N" || ARGS[k] == "-N"
        if k == nargs
            println("Option \"$(ARGS[k])\" needs a value.")
            exit(-1)
        end
        N = int(ARGS[k+1])
        if N <= 0
            # Report the offending value, not the flag or the whole ARGS.
            println("Size \"$(ARGS[k+1])\" is bogus.")
            exit(-1)
        end
        k += 2
    elseif ARGS[k] == "--seed" || ARGS[k] == "-seed"
        if k == nargs
            println("Option \"$(ARGS[k])\" needs a value.")
            exit(-1)
        end
        if ARGS[k+1] == "random"
            seed = rand(Int)
        else
            seed = int(ARGS[k+1])
        end
        k += 2
    else
        # Without this branch k never advances and the loop spins forever.
        println("Unrecognized argument \"$(ARGS[k])\".")
        exit(-1)
    end
end

# Seed the RNG so the reported seed actually reproduces the test matrix.
# The sign bit is masked off because `--seed random` draws from rand(Int),
# which may be negative, and srand rejects negative integers.
srand(seed & typemax(Int))

# Random square test matrix and a single right-hand side.
const A = rand(N, N)
const b = rand(N, 1)

# Compile every factorization up front so that compilation time stays out of
# the measurements below.
precompile(reclufact!, (typeof(A),))
precompile(lufact!, (typeof(A),))
precompile(lufact_schur!, (typeof(A),))
precompile(lufact_noblas!, (typeof(A),))

println("N: $N\nseed: $seed")

# The first one is charged some extra allocation, so run here to avoid
# it.
# (Warm-up run; its results are overwritten by the measured run below.)
Acopy = copy(A)
LU_orig, LU_orig_time, LU_orig_space = @timed lufact!(Acopy)

# Baseline: Julia's own lufact! on a fresh copy of A.  Every *_ratio printed
# below is relative to this time.
Acopy = copy(A)
LU_orig, LU_orig_time, LU_orig_space = @timed lufact!(Acopy)

println("LU_orig_time: ", LU_orig_time);
# Always 1; printed so the baseline has the same three fields as the others.
println("LU_orig_ratio: ", LU_orig_time / LU_orig_time);
println("LU_orig_space: ", LU_orig_space);

# Recursive LU.
Acopy = copy(A)
LU_rec, LU_rec_time, LU_rec_space = @timed reclufact!(Acopy)

println("LU_rec_time: ", LU_rec_time);
println("LU_rec_ratio: ", LU_rec_time / LU_orig_time);
println("LU_rec_space: ", LU_rec_space);

# Column-by-column LU with the trailing update done through lu_schur!.
Acopy = copy(A)
LU_schur, LU_schur_time, LU_schur_space = @timed lufact_schur!(Acopy)

println("LU_schur_time: ", LU_schur_time);
println("LU_schur_ratio: ", LU_schur_time / LU_orig_time);
println("LU_schur_space: ", LU_schur_space);

# Column-by-column LU written entirely as Julia loops.
Acopy = copy(A)
LU_noblas, LU_noblas_time, LU_noblas_space = @timed lufact_noblas!(Acopy)

println("LU_noblas_time: ", LU_noblas_time);
println("LU_noblas_ratio: ", LU_noblas_time / LU_orig_time);
println("LU_noblas_space: ", LU_noblas_space);

# Accuracy check of the recursive factorization against the baseline: first
# through the solutions of A*x = b, then through the stored factors.  The
# lufact_schur! and lufact_noblas! results are timed but not checked here.
err = norm(LU_orig\b - LU_rec\b, Inf)
println("LU_rec_solve_err: ", err)

fact_err = norm(LU_orig.factors - LU_rec.factors, Inf)
println("LU_rec_fact_err: ", fact_err)

Reply via email to