This is an automated email from the ASF dual-hosted git repository. quinnj pushed a commit to branch jq-table-partitions in repository https://gitbox.apache.org/repos/asf/arrow-julia.git
commit 66399b2fd9118bac3f204e5dbb5310800a2d6e0f Author: Jacob Quinn <[email protected]> AuthorDate: Tue May 23 21:54:54 2023 -0600 Add Tables.partitions definition for Arrow.Table We had this functionality w/ `Arrow.Stream`, but it's convenient and not that expensive to define it for `Arrow.Table` as well. Fixes #293. --- src/table.jl | 36 ++++++++++++++++++++++++++++++++++++ test/runtests.jl | 13 +++++++++++++ 2 files changed, 49 insertions(+) diff --git a/src/table.jl b/src/table.jl index 49b6153..9d7ddef 100644 --- a/src/table.jl +++ b/src/table.jl @@ -261,6 +261,7 @@ types(t::Table) = getfield(t, :types) columns(t::Table) = getfield(t, :columns) lookup(t::Table) = getfield(t, :lookup) schema(t::Table) = getfield(t, :schema) +metadata(t::Table) = getfield(t, :metadata) """ Arrow.getmetadata(x) @@ -286,6 +287,41 @@ Tables.columnnames(t::Table) = names(t) Tables.getcolumn(t::Table, i::Int) = columns(t)[i] Tables.getcolumn(t::Table, nm::Symbol) = lookup(t)[nm] +struct TablePartitions + table::Table + npartitions::Int +end + +function TablePartitions(table::Table) + cols = columns(table) + npartitions = if length(cols) == 0 + 0 + elseif cols[1] isa ChainedVector + length(cols[1].arrays) + else + 1 + end + return TablePartitions(table, npartitions) +end + +function Base.iterate(tp::TablePartitions, i=1) + i > tp.npartitions && return nothing + tp.npartitions == 1 && return tp.table, i + 1 + cols = columns(tp.table) + newcols = AbstractVector[cols[j].arrays[i] for j in 1:length(cols)] + nms = names(tp.table) + tbl = Table( + nms, + types(tp.table), + newcols, + Dict{Symbol, AbstractVector}(nms[i] => newcols[i] for i in 1:length(nms)), + schema(tp.table) + ) + return tbl, i + 1 +end + +Tables.partitions(t::Table) = TablePartitions(t) + # high-level user API functions Table(input, pos::Integer=1, len=nothing; kw...) = Table([ArrowBlob(tobytes(input), pos, len)]; kw...) Table(input::Vector{UInt8}, pos::Integer=1, len=nothing; kw...) = Table([ArrowBlob(tobytes(input), pos, len)]; kw...) diff --git a/test/runtests.jl b/test/runtests.jl index 47a137f..c477462 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -674,6 +674,19 @@ t = Arrow.Table(joinpath(dirname(pathof(Arrow)), "../test/java_compress_len_neg_ end +@testset "# 293" begin + +t = (a = [1, 2, 3], b = [1.0, 2.0, 3.0]) +buf = Arrow.tobuffer(t) +tbl = Arrow.Table(buf) +parts = Tables.partitioner((t, t)) +buf2 = Arrow.tobuffer(parts) +tbl2 = Arrow.Table(buf2) +for t in Tables.partitions(tbl2) + @test t.a == tbl.a + @test t.b == tbl.b +end + end # @testset "misc" end
