This is an automated email from the ASF dual-hosted git repository.

quinnj pushed a commit to branch jq-compress-locks
in repository https://gitbox.apache.org/repos/asf/arrow-julia.git

commit bed76e089b2f6fed15cc1fb9c1b87ce4e45c6698
Author: Jacob Quinn <[email protected]>
AuthorDate: Thu May 25 16:01:30 2023 -0600

    Refactor compressors/decompressors for laziness + safety
    
    Fixes #396.
    
    As noted in the originally reported issue, enabling debug logging when
    writing arrow data with compression can result in segfaults because the
    underlying CodecX packages emit debug logs, which cause task
    switches/migration and thus make the pattern of using a single
    `X_COMPRESSOR` array indexed by `Threads.threadid()` unsafe, since
    multiple threads may try using the same compressor at the same time.
    
    We fix this by wrapping each compressor in a `Lockable` and ensuring the
    `compress` (or `uncompress`) operation holds the lock for the duration of
    the operation. We also:
    * Add a decompressor per thread to avoid recreating them over and over
      during reading
    * Lazily initialize compressors/decompressors in a way that is safe on
      Julia 1.9+ and only creates the object when needed by a specific thread
    * Switch from WorkerUtilities -> ConcurrentUtilities (the package was
      renamed)
    
    Co-authored-by: J S <[email protected]>
---
 Project.toml                 |  4 +--
 src/Arrow.jl                 | 65 +++++++++++++++++++++++++++++++++++++-------
 src/arraytypes/arraytypes.jl | 16 +++++++----
 src/table.jl                 | 10 +++++--
 src/write.jl                 | 20 ++++----------
 5 files changed, 82 insertions(+), 33 deletions(-)

diff --git a/Project.toml b/Project.toml
index 5f75563..ed76b2f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -24,6 +24,7 @@ ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
 BitIntegers = "c3b6d118-76ef-56ca-8cc7-ebb389d030a1"
 CodecLz4 = "5ba52731-8f18-5e0d-9241-30f10d1ec561"
 CodecZstd = "6b39b394-51ab-5f42-8807-6242bab2b4c2"
+ConcurrentUtilities = "f0e56b4a-5159-44fe-b623-3e5288b988bb"
 DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
 EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
@@ -35,13 +36,13 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53"
 TranscodingStreams = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
-WorkerUtilities = "76eceee3-57b5-4d4a-8e66-0e911cebbf60"
 
 [compat]
 ArrowTypes = "1.1,2"
 BitIntegers = "0.2, 0.3"
 CodecLz4 = "0.4"
 CodecZstd = "0.7"
+ConcurrentUtilities = "2"
 DataAPI = "1"
 EnumX = "1"
 FilePathsBase = "0.9"
@@ -51,7 +52,6 @@ SentinelArrays = "1"
 Tables = "1.1"
 TimeZones = "1"
 TranscodingStreams = "0.9.12"
-WorkerUtilities = "1.1"
 julia = "1.6"
 
 [extras]
diff --git a/src/Arrow.jl b/src/Arrow.jl
index 01c7993..9d5a7fb 100644
--- a/src/Arrow.jl
+++ b/src/Arrow.jl
@@ -45,7 +45,7 @@ using Base.Iterators
 using Mmap
 using LoggingExtras
 import Dates
-using DataAPI, Tables, SentinelArrays, PooledArrays, CodecLz4, CodecZstd, 
TimeZones, BitIntegers, WorkerUtilities
+using DataAPI, Tables, SentinelArrays, PooledArrays, CodecLz4, CodecZstd, 
TimeZones, BitIntegers, ConcurrentUtilities
 
 export ArrowTypes
 
@@ -71,18 +71,63 @@ include("write.jl")
 include("append.jl")
 include("show.jl")
 
-const LZ4_FRAME_COMPRESSOR = LZ4FrameCompressor[]
-const ZSTD_COMPRESSOR = ZstdCompressor[]
+const ZSTD_COMPRESSOR = Lockable{ZstdCompressor}[]
+const ZSTD_DECOMPRESSOR = Lockable{ZstdDecompressor}[]
+const LZ4_FRAME_COMPRESSOR = Lockable{LZ4FrameCompressor}[]
+const LZ4_FRAME_DECOMPRESSOR = Lockable{LZ4FrameDecompressor}[]
+
+function init_zstd_compressor()
+    zstd = ZstdCompressor(; level=3)
+    CodecZstd.TranscodingStreams.initialize(zstd)
+    return Lockable(zstd)
+end
+
+function init_zstd_decompressor()
+    zstd = ZstdDecompressor()
+    CodecZstd.TranscodingStreams.initialize(zstd)
+    return Lockable(zstd)
+end
+
+function init_lz4_frame_compressor()
+    lz4 = LZ4FrameCompressor(; compressionlevel=4)
+    CodecLz4.TranscodingStreams.initialize(lz4)
+    return Lockable(lz4)
+end
+
+function init_lz4_frame_decompressor()
+    lz4 = LZ4FrameDecompressor()
+    CodecLz4.TranscodingStreams.initialize(lz4)
+    return Lockable(lz4)
+end
+
+function access_threaded(f, v::Vector)
+    tid = Threads.threadid()
+    0 < tid <= length(v) || _length_assert()
+    if @inbounds isassigned(v, tid)
+        @inbounds x = v[tid]
+    else
+        x = f()
+        @inbounds v[tid] = x
+    end
+    return x
+end
+@noinline _length_assert() =  @assert false "0 < tid <= v"
+
+zstd_compressor() = access_threaded(init_zstd_compressor, ZSTD_COMPRESSOR)
+zstd_decompressor() = access_threaded(init_zstd_decompressor, 
ZSTD_DECOMPRESSOR)
+lz4_frame_compressor() = access_threaded(init_lz4_frame_compressor, 
LZ4_FRAME_COMPRESSOR)
+lz4_frame_decompressor() = access_threaded(init_lz4_frame_decompressor, 
LZ4_FRAME_DECOMPRESSOR)
 
 function __init__()
-    for _ = 1:Threads.nthreads()
-        zstd = ZstdCompressor(; level=3)
-        CodecZstd.TranscodingStreams.initialize(zstd)
-        push!(ZSTD_COMPRESSOR, zstd)
-        lz4 = LZ4FrameCompressor(; compressionlevel=4)
-        CodecLz4.TranscodingStreams.initialize(lz4)
-        push!(LZ4_FRAME_COMPRESSOR, lz4)
+    nt = @static if isdefined(Base.Threads, :maxthreadid)
+        Threads.maxthreadid()
+    else
+        Threads.nthreads()
     end
+    resize!(empty!(LZ4_FRAME_COMPRESSOR), nt)
+    resize!(empty!(ZSTD_COMPRESSOR), nt)
+    resize!(empty!(LZ4_FRAME_DECOMPRESSOR), nt)
+    resize!(empty!(ZSTD_DECOMPRESSOR), nt)
     return
 end
 
diff --git a/src/arraytypes/arraytypes.jl b/src/arraytypes/arraytypes.jl
index 3bbfd0e..a3449f1 100644
--- a/src/arraytypes/arraytypes.jl
+++ b/src/arraytypes/arraytypes.jl
@@ -31,18 +31,24 @@ nullcount(x::ArrowVector) = validitybitmap(x).nc
 getmetadata(x::ArrowVector) = x.metadata
 Base.deleteat!(x::T, inds) where {T <: ArrowVector} = 
throw(ArgumentError("`$T` does not support `deleteat!`; arrow data is by nature 
immutable"))
 
-function toarrowvector(x, i=1, de=Dict{Int64, Any}(), ded=DictEncoding[], 
meta=getmetadata(x); compression::Union{Nothing, Vector{LZ4FrameCompressor}, 
LZ4FrameCompressor, Vector{ZstdCompressor}, ZstdCompressor}=nothing, kw...)
+function toarrowvector(x, i=1, de=Dict{Int64, Any}(), ded=DictEncoding[], 
meta=getmetadata(x); compression::Union{Nothing, Symbol, LZ4FrameCompressor, 
ZstdCompressor}=nothing, kw...)
     @debugv 2 "converting top-level column to arrow format: col = 
$(typeof(x)), compression = $compression, kw = $(values(kw))"
     @debugv 3 x
     A = arrowvector(x, i, 0, 0, de, ded, meta; compression=compression, kw...)
     if compression isa LZ4FrameCompressor
         A = compress(Meta.CompressionType.LZ4_FRAME, compression, A)
-    elseif compression isa Vector{LZ4FrameCompressor}
-        A = compress(Meta.CompressionType.LZ4_FRAME, 
compression[Threads.threadid()], A)
     elseif compression isa ZstdCompressor
         A = compress(Meta.CompressionType.ZSTD, compression, A)
-    elseif compression isa Vector{ZstdCompressor}
-        A = compress(Meta.CompressionType.ZSTD, 
compression[Threads.threadid()], A)
+    elseif compression isa Symbol && compression == :lz4
+        comp = lz4_frame_compressor()
+        A = Base.@lock comp begin
+            compress(Meta.CompressionType.LZ4_FRAME, comp[], A)
+        end
+    elseif compression isa Symbol && compression == :zstd
+        comp = zstd_compressor()
+        A = Base.@lock comp begin
+            compress(Meta.CompressionType.ZSTD, comp[], A)
+        end
     end
     @debugv 2 "converted top-level column to arrow format: $(typeof(A))"
     @debugv 3 A
diff --git a/src/table.jl b/src/table.jl
index c32fe5a..479ccc3 100644
--- a/src/table.jl
+++ b/src/table.jl
@@ -531,9 +531,15 @@ function uncompress(ptr::Ptr{UInt8}, buffer, compression)
     end
     decodedbytes = Vector{UInt8}(undef, len)
     if compression.codec === Meta.CompressionType.LZ4_FRAME
-        transcode(LZ4FrameDecompressor, encodedbytes, decodedbytes)
+        comp = lz4_frame_decompressor()
+        Base.@lock comp begin
+            transcode(comp[], encodedbytes, decodedbytes)
+        end
     elseif compression.codec === Meta.CompressionType.ZSTD
-        transcode(ZstdDecompressor, encodedbytes, decodedbytes)
+        comp = zstd_decompressor()
+        Base.@lock comp begin
+            transcode(comp[], encodedbytes, decodedbytes)
+        end
     else
         error("unsupported compression type when reading arrow buffers: 
$(typeof(compression.codec))")
     end
diff --git a/src/write.jl b/src/write.jl
index 1c6975d..628f62e 100644
--- a/src/write.jl
+++ b/src/write.jl
@@ -111,7 +111,7 @@ julia> open(Arrow.Writer, tempname()) do writer
 mutable struct Writer{T<:IO}
     io::T
     closeio::Bool
-    
compress::Union{Nothing,LZ4FrameCompressor,Vector{LZ4FrameCompressor},ZstdCompressor,Vector{ZstdCompressor}}
+    compress::Union{Nothing,Symbol,LZ4FrameCompressor,ZstdCompressor}
     writetofile::Bool
     largelists::Bool
     denseunions::Bool
@@ -135,7 +135,10 @@ mutable struct Writer{T<:IO}
     isclosed::Bool
 end
 
-function Base.open(::Type{Writer}, io::T, 
compress::Union{Nothing,LZ4FrameCompressor,<:AbstractVector{LZ4FrameCompressor},ZstdCompressor,<:AbstractVector{ZstdCompressor}},
 writetofile::Bool, largelists::Bool, denseunions::Bool, dictencode::Bool, 
dictencodenested::Bool, alignment::Integer, maxdepth::Integer, ntasks::Integer, 
meta::Union{Nothing,Any}, colmeta::Union{Nothing,Any}, closeio::Bool) where 
{T<:IO}
+function Base.open(::Type{Writer}, io::T, 
compress::Union{Nothing,Symbol,LZ4FrameCompressor,ZstdCompressor}, 
writetofile::Bool, largelists::Bool, denseunions::Bool, dictencode::Bool, 
dictencodenested::Bool, alignment::Integer, maxdepth::Integer, ntasks::Integer, 
meta::Union{Nothing,Any}, colmeta::Union{Nothing,Any}, closeio::Bool) where 
{T<:IO}
+    if compress !== :lz4 && compress !== :zstd
+        throw(ArgumentError("unsupported compress keyword argument value: 
$compress. Valid values include `:lz4` or `:zstd`"))
+    end
     sync = OrderedSynchronizer(2)
     msgs = Channel{Message}(ntasks)
     schema = Ref{Tables.Schema}()
@@ -156,18 +159,7 @@ function Base.open(::Type{Writer}, io::T, 
compress::Union{Nothing,LZ4FrameCompre
     return Writer{T}(io, closeio, compress, writetofile, largelists, 
denseunions, dictencode, dictencodenested, threaded, alignment, maxdepth, meta, 
colmeta, sync, msgs, schema, firstcols, dictencodings, blocks, task, anyerror, 
errorref, 1, false)
 end
 
-function Base.open(::Type{Writer}, io::IO, compress::Symbol, args...)
-    compressor = if compress === :lz4
-        LZ4_FRAME_COMPRESSOR
-    elseif compress === :zstd
-        ZSTD_COMPRESSOR
-    else
-        throw(ArgumentError("unsupported compress keyword argument value: 
$compress. Valid values include `:lz4` or `:zstd`"))
-    end
-    open(Writer, io, compressor, args...)
-end
-
-function Base.open(::Type{Writer}, io::IO; 
compress::Union{Nothing,Symbol,LZ4FrameCompressor,<:AbstractVector{LZ4FrameCompressor},ZstdCompressor,<:AbstractVector{ZstdCompressor}}=nothing,
 file::Bool=true, largelists::Bool=false, denseunions::Bool=true, 
dictencode::Bool=false, dictencodenested::Bool=false, alignment::Integer=8, 
maxdepth::Integer=DEFAULT_MAX_DEPTH, ntasks::Integer=typemax(Int32), 
metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, 
closeio::Bool=false)
+function Base.open(::Type{Writer}, io::IO; 
compress::Union{Nothing,Symbol,LZ4FrameCompressor,ZstdCompressor}=nothing, 
file::Bool=true, largelists::Bool=false, denseunions::Bool=true, 
dictencode::Bool=false, dictencodenested::Bool=false, alignment::Integer=8, 
maxdepth::Integer=DEFAULT_MAX_DEPTH, ntasks::Integer=typemax(Int32), 
metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, 
closeio::Bool=false)
     open(Writer, io, compress, file, largelists, denseunions, dictencode, 
dictencodenested, alignment, maxdepth, ntasks, metadata, colmetadata, closeio)
 end
 

Reply via email to