baumgold commented on a change in pull request #277:
URL: https://github.com/apache/arrow-julia/pull/277#discussion_r820287759



##########
File path: src/write.jl
##########
@@ -51,131 +51,213 @@ Supported keyword arguments to `Arrow.write` include:
 """
 function write end
 
-write(io_or_file; kw...) = x -> write(io_or_file, x; kw...)
+struct Message
+    msgflatbuf
+    columns
+    bodylen
+    isrecordbatch::Bool
+    blockmsg::Bool
+    headerType
+end
 
-function write(file_path, tbl; metadata=getmetadata(tbl), colmetadata=nothing, 
largelists::Bool=false, compress::Union{Nothing, Symbol, LZ4FrameCompressor, 
ZstdCompressor}=nothing, denseunions::Bool=true, dictencode::Bool=false, 
dictencodenested::Bool=false, alignment::Int=8, 
maxdepth::Int=DEFAULT_MAX_DEPTH, ntasks=Inf, file::Bool=true)
-    open(file_path, "w") do io
-        write(io, tbl, file, largelists, compress, denseunions, dictencode, 
dictencodenested, alignment, maxdepth, ntasks, metadata, colmetadata)
-    end
-    return file_path
+struct Block
+    offset::Int64
+    metaDataLength::Int32
+    bodyLength::Int64
 end
 
-function write(io::IO, tbl; metadata=getmetadata(tbl), colmetadata=nothing, 
largelists::Bool=false, compress::Union{Nothing, Symbol, LZ4FrameCompressor, 
ZstdCompressor}=nothing, denseunions::Bool=true, dictencode::Bool=false, 
dictencodenested::Bool=false, alignment::Int=8, 
maxdepth::Int=DEFAULT_MAX_DEPTH, ntasks=Inf, file::Bool=false)
-    return write(io, tbl, file, largelists, compress, denseunions, dictencode, 
dictencodenested, alignment, maxdepth, ntasks, metadata, colmetadata)
+mutable struct Writer{T<:IO}
+    io::T
+    closeio::Bool
+    
compress::Union{Nothing,LZ4FrameCompressor,Vector{LZ4FrameCompressor},ZstdCompressor,Vector{ZstdCompressor}}
+    writetofile::Bool
+    largelists::Bool
+    denseunions::Bool
+    dictencode::Bool
+    dictencodenested::Bool
+    threaded::Bool
+    alignment::Int32
+    maxdepth::Int64
+    meta::Union{Nothing,Base.ImmutableDict{String,String}}
+    
colmeta::Union{Nothing,Base.ImmutableDict{Symbol,Base.ImmutableDict{String,String}}}
+    msgs::OrderedChannel{Message}
+    schema::Ref{Tables.Schema}
+    firstcols::Ref{Any}
+    dictencodings::Dict{Int64, Any}
+    blocks::NTuple{2, Vector{Block}}
+    task::Task
+    anyerror::Threads.Atomic{Bool}
+    errorref::Ref{Any}
+    partition_count::Int32
+    isclosed::Bool
 end
 
-function write(io, source, writetofile, largelists, compress, denseunions, 
dictencode, dictencodenested, alignment, maxdepth, ntasks, meta, colmeta)
+function Base.open(Writer, io::T, 
compress::Union{Nothing,LZ4FrameCompressor,<:AbstractVector{LZ4FrameCompressor},ZstdCompressor,<:AbstractVector{ZstdCompressor}},
 writetofile::Bool, largelists::Bool, denseunions::Bool, dictencode::Bool, 
dictencodenested::Bool, alignment::Integer, maxdepth::Integer, ntasks::Integer, 
meta::Union{Nothing,Any}, colmeta::Union{Nothing,Any}, closeio::Bool) where 
{T<:IO}
     if ntasks < 1
         throw(ArgumentError("ntasks keyword argument must be > 0; pass 
`ntasks=1` to disable multithreaded writing"))
     end
-    if compress === :lz4
-        compress = LZ4_FRAME_COMPRESSOR
-    elseif compress === :zstd
-        compress = ZSTD_COMPRESSOR
-    elseif compress isa Symbol
-        throw(ArgumentError("unsupported compress keyword argument value: 
$compress. Valid values include `:lz4` or `:zstd`"))
-    end
-    # TODO: we're probably not threadsafe if user passes own single compressor 
instance + ntasks > 1
-    # if ntasks > 1 && compres !== nothing && !(compress isa Vector)
-    #     compress = Threads.resize_nthreads!([compress])
-    # end
-    if writetofile
-        @debug 1 "starting write of arrow formatted file"
-        Base.write(io, "ARROW1\0\0")
-    end
     msgs = OrderedChannel{Message}(ntasks)
-    # build messages
-    sch = Ref{Tables.Schema}()
+    schema = Ref{Tables.Schema}()
     firstcols = Ref{Any}()
     dictencodings = Dict{Int64, Any}() # Lockable{DictEncoding}
     blocks = (Block[], Block[])
     # start message writing from channel
     threaded = ntasks > 1
-    tsk = threaded ? (Threads.@spawn for msg in msgs
-        Base.write(io, msg, blocks, sch, alignment)
+    task = threaded ? (Threads.@spawn for msg in msgs
+        Base.write(io, msg, blocks, schema, alignment)
     end) : (@async for msg in msgs
-        Base.write(io, msg, blocks, sch, alignment)
+        Base.write(io, msg, blocks, schema, alignment)
     end)
     anyerror = Threads.Atomic{Bool}(false)
     errorref = Ref{Any}()
-    @sync for (i, tbl) in enumerate(Tables.partitions(source))
-        if anyerror[]
-            @error "error writing arrow data on partition = $(errorref[][3])" 
exception=(errorref[][1], errorref[][2])
-            error("fatal error writing arrow data")
-        end
-        @debug 1 "processing table partition i = $i"
+    meta = _normalizemeta(meta)
+    colmeta = _normalizecolmeta(colmeta)
+    Writer{T}(io, closeio, compress, writetofile, largelists, denseunions, 
dictencode, dictencodenested, threaded, alignment, maxdepth, meta, colmeta, 
msgs, schema, firstcols, dictencodings, blocks, task, anyerror, errorref, 1, 
false)
+end
+
+function Base.open(Writer, io::IO, compress::Symbol, args...)
+    compressor = if compress === :lz4
+        LZ4_FRAME_COMPRESSOR
+    elseif compress === :zstd
+        ZSTD_COMPRESSOR
+    else
+        throw(ArgumentError("unsupported compress keyword argument value: 
$compress. Valid values include `:lz4` or `:zstd`"))
+    end
+    open(Writer, io, compressor, args...)
+end
+
+function Base.open(::Type{Writer}, io::IO; 
compress::Union{Nothing,Symbol,LZ4FrameCompressor,<:AbstractVector{LZ4FrameCompressor},ZstdCompressor,<:AbstractVector{ZstdCompressor}}=nothing,
 file::Bool=true, largelists::Bool=false, denseunions::Bool=true, 
dictencode::Bool=false, dictencodenested::Bool=false, alignment::Integer=8, 
maxdepth::Integer=DEFAULT_MAX_DEPTH, ntasks::Integer=typemax(Int32), 
metadata::Union{Nothing,Any}=nothing, colmetadata::Union{Nothing,Any}=nothing, 
closeio::Bool = false)
+    open(Writer, io, compress, file, largelists, denseunions, dictencode, 
dictencodenested, alignment, maxdepth, ntasks, metadata, colmetadata, closeio)
+end
+
+Base.open(::Type{Writer}, file_path; kwargs...) = open(Writer, open(file_path, 
"w"); kwargs..., closeio=true)
+
+function check_errors(writer::Writer)
+    if writer.anyerror[]
+        errorref = writer.errorref[]
+        @error "error writing arrow data on partition = $(errorref[3])" 
exception=(errorref[1], errorref[2])
+        error("fatal error writing arrow data")
+    end
+end
+
+function write(writer::Writer, source)
+    @sync for tbl in Tables.partitions(source)
+        check_errors(writer)
+        @debug 1 "processing table partition $(writer.partition_count)"
         tblcols = Tables.columns(tbl)
-        if i == 1
-            cols = toarrowtable(tblcols, dictencodings, largelists, compress, 
denseunions, dictencode, dictencodenested, maxdepth, meta, colmeta)
-            sch[] = Tables.schema(cols)
-            firstcols[] = cols
-            put!(msgs, makeschemamsg(sch[], cols), i)
-            if !isempty(dictencodings)
-                des = sort!(collect(dictencodings); by=x->x.first, rev=true)
+        if !isassigned(writer.firstcols)
+            if writer.writetofile
+                @debug 1 "starting write of arrow formatted file"
+                Base.write(writer.io, "ARROW1\0\0")
+            end
+            meta = isnothing(writer.meta) ? getmetadata(source) : writer.meta
+            cols = toarrowtable(tblcols, writer.dictencodings, 
writer.largelists, writer.compress, writer.denseunions, writer.dictencode, 
writer.dictencodenested, writer.maxdepth, meta, writer.colmeta)
+            writer.schema[] = Tables.schema(cols)
+            writer.firstcols[] = cols
+            put!(writer.msgs, makeschemamsg(writer.schema[], cols), 
writer.partition_count)
+            if !isempty(writer.dictencodings)
+                des = sort!(collect(writer.dictencodings); by=x->x.first, 
rev=true)
                 for (id, delock) in des
                     # assign dict encoding ids
                     de = delock.x
                     dictsch = Tables.Schema((:col,), (eltype(de.data),))
-                    put!(msgs, makedictionarybatchmsg(dictsch, (col=de.data,), 
id, false, alignment), i)
+                    dictbatchmsg = makedictionarybatchmsg(dictsch, 
(col=de.data,), id, false, writer.alignment)
+                    put!(writer.msgs, dictbatchmsg, writer.partition_count)
                 end
             end
-            put!(msgs, makerecordbatchmsg(sch[], cols, alignment), i, true)
+            recbatchmsg = makerecordbatchmsg(writer.schema[], cols, 
writer.alignment)
+            put!(writer.msgs, recbatchmsg, writer.partition_count, true)
         else
-            if threaded
-                Threads.@spawn process_partition(tblcols, dictencodings, 
largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, 
msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
+            if writer.threaded
+                Threads.@spawn process_partition(tblcols, 
writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, 
writer.dictencode, writer.dictencodenested, writer.maxdepth, writer.msgs, 
writer.alignment, $(writer.partition_count), writer.schema, writer.errorref, 
writer.anyerror, writer.meta, writer.colmeta)
             else
-                @async process_partition(tblcols, dictencodings, largelists, 
compress, denseunions, dictencode, dictencodenested, maxdepth, msgs, alignment, 
i, sch, errorref, anyerror, meta, colmeta)
+                @async process_partition(tblcols, writer.dictencodings, 
writer.largelists, writer.compress, writer.denseunions, writer.dictencode, 
writer.dictencodenested, writer.maxdepth, writer.msgs, writer.alignment, 
$(writer.partition_count), writer.schema, writer.errorref, writer.anyerror, 
writer.meta, writer.colmeta)
             end
         end
+        writer.partition_count += 1
     end
-    if anyerror[]
-        @error "error writing arrow data on partition = $(errorref[][3])" 
exception=(errorref[][1], errorref[][2])
-        error("fatal error writing arrow data")
-    end
+    check_errors(writer)
+    nothing
+end
+
+function Base.close(writer::Writer)
+    writer.isclosed && return
     # close our message-writing channel, no further put!-ing is allowed
-    close(msgs)
+    close(writer.msgs)
     # now wait for our message-writing task to finish writing
-    wait(tsk)
+    !istaskfailed(writer.task) && wait(writer.task)
+    if (!isassigned(writer.schema) || !isassigned(writer.firstcols))
+        writer.closeio && close(writer.io)
+        writer.isclosed = true
+        return
+    end
     # write empty message
-    if !writetofile
-        Base.write(io, Message(UInt8[], nothing, 0, true, false, Meta.Schema), 
blocks, sch, alignment)
+    if !writer.writetofile
+        msg = Message(UInt8[], nothing, 0, true, false, Meta.Schema)
+        Base.write(writer.io, msg, writer.blocks, writer.schema, 
writer.alignment)
+        writer.closeio && close(writer.io)
+        writer.isclosed = true
+        return
     end
-    if writetofile
-        b = FlatBuffers.Builder(1024)
-        schfoot = makeschema(b, sch[], firstcols[])
-        if !isempty(blocks[1])
-            N = length(blocks[1])
-            Meta.footerStartRecordBatchesVector(b, N)
-            for blk in Iterators.reverse(blocks[1])
-                Meta.createBlock(b, blk.offset, blk.metaDataLength, 
blk.bodyLength)
-            end
-            recordbatches = FlatBuffers.endvector!(b, N)
-        else
-            recordbatches = FlatBuffers.UOffsetT(0)
+    b = FlatBuffers.Builder(1024)
+    schfoot = makeschema(b, writer.schema[], writer.firstcols[])
+    recordbatches = if !isempty(writer.blocks[1])
+        N = length(writer.blocks[1])
+        Meta.footerStartRecordBatchesVector(b, N)
+        for blk in Iterators.reverse(writer.blocks[1])
+            Meta.createBlock(b, blk.offset, blk.metaDataLength, blk.bodyLength)
         end
-        if !isempty(blocks[2])
-            N = length(blocks[2])
-            Meta.footerStartDictionariesVector(b, N)
-            for blk in Iterators.reverse(blocks[2])
-                Meta.createBlock(b, blk.offset, blk.metaDataLength, 
blk.bodyLength)
-            end
-            dicts = FlatBuffers.endvector!(b, N)
-        else
-            dicts = FlatBuffers.UOffsetT(0)
+        FlatBuffers.endvector!(b, N)
+    else
+        FlatBuffers.UOffsetT(0)
+    end
+    dicts = if !isempty(writer.blocks[2])
+        N = length(writer.blocks[2])
+        Meta.footerStartDictionariesVector(b, N)
+        for blk in Iterators.reverse(writer.blocks[2])
+            Meta.createBlock(b, blk.offset, blk.metaDataLength, blk.bodyLength)
         end
-        Meta.footerStart(b)
-        Meta.footerAddVersion(b, Meta.MetadataVersions.V4)
-        Meta.footerAddSchema(b, schfoot)
-        Meta.footerAddDictionaries(b, dicts)
-        Meta.footerAddRecordBatches(b, recordbatches)
-        foot = Meta.footerEnd(b)
-        FlatBuffers.finish!(b, foot)
-        footer = FlatBuffers.finishedbytes(b)
-        Base.write(io, footer)
-        Base.write(io, Int32(length(footer)))
-        Base.write(io, "ARROW1")
+        FlatBuffers.endvector!(b, N)
+    else
+        FlatBuffers.UOffsetT(0)
+    end
+    Meta.footerStart(b)
+    Meta.footerAddVersion(b, Meta.MetadataVersions.V4)
+    Meta.footerAddSchema(b, schfoot)
+    Meta.footerAddDictionaries(b, dicts)
+    Meta.footerAddRecordBatches(b, recordbatches)
+    foot = Meta.footerEnd(b)
+    FlatBuffers.finish!(b, foot)
+    footer = FlatBuffers.finishedbytes(b)
+    Base.write(writer.io, footer)
+    Base.write(writer.io, Int32(length(footer)))
+    Base.write(writer.io, "ARROW1")
+    writer.closeio && close(writer.io)
+    writer.isclosed = true
+    nothing
+end
+
+write(io_or_file; kw...) = x -> write(io_or_file, x; kw...)

Review comment:
       Sure, no problem.

##########
File path: src/write.jl
##########
@@ -51,131 +51,213 @@ Supported keyword arguments to `Arrow.write` include:
 """
 function write end
 
-write(io_or_file; kw...) = x -> write(io_or_file, x; kw...)
+struct Message
+    msgflatbuf
+    columns
+    bodylen
+    isrecordbatch::Bool
+    blockmsg::Bool
+    headerType
+end
 
-function write(file_path, tbl; metadata=getmetadata(tbl), colmetadata=nothing, 
largelists::Bool=false, compress::Union{Nothing, Symbol, LZ4FrameCompressor, 
ZstdCompressor}=nothing, denseunions::Bool=true, dictencode::Bool=false, 
dictencodenested::Bool=false, alignment::Int=8, 
maxdepth::Int=DEFAULT_MAX_DEPTH, ntasks=Inf, file::Bool=true)
-    open(file_path, "w") do io
-        write(io, tbl, file, largelists, compress, denseunions, dictencode, 
dictencodenested, alignment, maxdepth, ntasks, metadata, colmetadata)
-    end
-    return file_path
+struct Block
+    offset::Int64
+    metaDataLength::Int32
+    bodyLength::Int64
 end
 
-function write(io::IO, tbl; metadata=getmetadata(tbl), colmetadata=nothing, 
largelists::Bool=false, compress::Union{Nothing, Symbol, LZ4FrameCompressor, 
ZstdCompressor}=nothing, denseunions::Bool=true, dictencode::Bool=false, 
dictencodenested::Bool=false, alignment::Int=8, 
maxdepth::Int=DEFAULT_MAX_DEPTH, ntasks=Inf, file::Bool=false)
-    return write(io, tbl, file, largelists, compress, denseunions, dictencode, 
dictencodenested, alignment, maxdepth, ntasks, metadata, colmetadata)
+mutable struct Writer{T<:IO}
+    io::T
+    closeio::Bool
+    
compress::Union{Nothing,LZ4FrameCompressor,Vector{LZ4FrameCompressor},ZstdCompressor,Vector{ZstdCompressor}}
+    writetofile::Bool
+    largelists::Bool
+    denseunions::Bool
+    dictencode::Bool
+    dictencodenested::Bool
+    threaded::Bool
+    alignment::Int32
+    maxdepth::Int64
+    meta::Union{Nothing,Base.ImmutableDict{String,String}}
+    
colmeta::Union{Nothing,Base.ImmutableDict{Symbol,Base.ImmutableDict{String,String}}}
+    msgs::OrderedChannel{Message}
+    schema::Ref{Tables.Schema}
+    firstcols::Ref{Any}
+    dictencodings::Dict{Int64, Any}
+    blocks::NTuple{2, Vector{Block}}
+    task::Task
+    anyerror::Threads.Atomic{Bool}
+    errorref::Ref{Any}
+    partition_count::Int32
+    isclosed::Bool
 end
 
-function write(io, source, writetofile, largelists, compress, denseunions, 
dictencode, dictencodenested, alignment, maxdepth, ntasks, meta, colmeta)
+function Base.open(Writer, io::T, 
compress::Union{Nothing,LZ4FrameCompressor,<:AbstractVector{LZ4FrameCompressor},ZstdCompressor,<:AbstractVector{ZstdCompressor}},
 writetofile::Bool, largelists::Bool, denseunions::Bool, dictencode::Bool, 
dictencodenested::Bool, alignment::Integer, maxdepth::Integer, ntasks::Integer, 
meta::Union{Nothing,Any}, colmeta::Union{Nothing,Any}, closeio::Bool) where 
{T<:IO}
     if ntasks < 1
         throw(ArgumentError("ntasks keyword argument must be > 0; pass 
`ntasks=1` to disable multithreaded writing"))
     end
-    if compress === :lz4
-        compress = LZ4_FRAME_COMPRESSOR
-    elseif compress === :zstd
-        compress = ZSTD_COMPRESSOR
-    elseif compress isa Symbol
-        throw(ArgumentError("unsupported compress keyword argument value: 
$compress. Valid values include `:lz4` or `:zstd`"))
-    end
-    # TODO: we're probably not threadsafe if user passes own single compressor 
instance + ntasks > 1
-    # if ntasks > 1 && compres !== nothing && !(compress isa Vector)
-    #     compress = Threads.resize_nthreads!([compress])
-    # end
-    if writetofile
-        @debug 1 "starting write of arrow formatted file"
-        Base.write(io, "ARROW1\0\0")
-    end
     msgs = OrderedChannel{Message}(ntasks)
-    # build messages
-    sch = Ref{Tables.Schema}()
+    schema = Ref{Tables.Schema}()
     firstcols = Ref{Any}()
     dictencodings = Dict{Int64, Any}() # Lockable{DictEncoding}
     blocks = (Block[], Block[])
     # start message writing from channel
     threaded = ntasks > 1
-    tsk = threaded ? (Threads.@spawn for msg in msgs
-        Base.write(io, msg, blocks, sch, alignment)
+    task = threaded ? (Threads.@spawn for msg in msgs
+        Base.write(io, msg, blocks, schema, alignment)
     end) : (@async for msg in msgs
-        Base.write(io, msg, blocks, sch, alignment)
+        Base.write(io, msg, blocks, schema, alignment)
     end)
     anyerror = Threads.Atomic{Bool}(false)
     errorref = Ref{Any}()
-    @sync for (i, tbl) in enumerate(Tables.partitions(source))
-        if anyerror[]
-            @error "error writing arrow data on partition = $(errorref[][3])" 
exception=(errorref[][1], errorref[][2])
-            error("fatal error writing arrow data")
-        end
-        @debug 1 "processing table partition i = $i"
+    meta = _normalizemeta(meta)
+    colmeta = _normalizecolmeta(colmeta)
+    Writer{T}(io, closeio, compress, writetofile, largelists, denseunions, 
dictencode, dictencodenested, threaded, alignment, maxdepth, meta, colmeta, 
msgs, schema, firstcols, dictencodings, blocks, task, anyerror, errorref, 1, 
false)
+end
+
+function Base.open(Writer, io::IO, compress::Symbol, args...)

Review comment:
       Indeed, my mistake. Good catch — thanks.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to