You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text copy).
Posted to commits@arrow.apache.org by qu...@apache.org on 2023/05/24 03:55:57 UTC
[arrow-julia] 01/01: Add Tables.partitions definition for Arrow.Table
This is an automated email from the ASF dual-hosted git repository.
quinnj pushed a commit to branch jq-table-partitions
in repository https://gitbox.apache.org/repos/asf/arrow-julia.git
commit 66399b2fd9118bac3f204e5dbb5310800a2d6e0f
Author: Jacob Quinn <qu...@gmail.com>
AuthorDate: Tue May 23 21:54:54 2023 -0600
Add Tables.partitions definition for Arrow.Table
We had this functionality with `Arrow.Stream`, but it's convenient and not
that expensive to define it for `Arrow.Table` as well.
Fixes #293.
---
src/table.jl | 36 ++++++++++++++++++++++++++++++++++++
test/runtests.jl | 13 +++++++++++++
2 files changed, 49 insertions(+)
diff --git a/src/table.jl b/src/table.jl
index 49b6153..9d7ddef 100644
--- a/src/table.jl
+++ b/src/table.jl
@@ -261,6 +261,7 @@ types(t::Table) = getfield(t, :types)
columns(t::Table) = getfield(t, :columns)
lookup(t::Table) = getfield(t, :lookup)
schema(t::Table) = getfield(t, :schema)
+metadata(t::Table) = getfield(t, :metadata)  # getfield: `t.x` on a Table accesses column x
"""
Arrow.getmetadata(x)
@@ -286,6 +287,41 @@ Tables.columnnames(t::Table) = names(t)
Tables.getcolumn(t::Table, i::Int) = columns(t)[i]
Tables.getcolumn(t::Table, nm::Symbol) = lookup(t)[nm]
+struct TablePartitions  # iterator returned by `Tables.partitions(::Table)`
+    table::Table  # the table being split into partitions
+    npartitions::Int  # number of partitions `iterate` yields
+end
+
+function TablePartitions(table::Table)
+    # If the first column is a ChainedVector, each chunk in its `arrays` is one
+    # partition; otherwise the whole table is one partition (zero if no columns).
+    cols = columns(table)
+    isempty(cols) && return TablePartitions(table, 0)
+    first(cols) isa ChainedVector || return TablePartitions(table, 1)
+    return TablePartitions(table, length(first(cols).arrays))
+end
+
+# `collect` calls `length`, because `Base.IteratorSize` defaults to `HasLength()`.
+Base.length(tp::TablePartitions) = tp.npartitions
+
+function Base.iterate(tp::TablePartitions, i=1)
+    # `i` is the 1-based index of the next partition; stop once past the last one.
+    i > tp.npartitions && return nothing
+    # A single-partition table is yielded as-is rather than rebuilt.
+    tp.npartitions == 1 && return tp.table, i + 1
+    # Columns are assumed chained here (only the first is checked); take chunk i of each.
+    newcols = AbstractVector[col.arrays[i] for col in columns(tp.table)]
+    nms = names(tp.table)
+    colmap = Dict{Symbol, AbstractVector}(zip(nms, newcols))
+    # Reuse the parent's schema and table-level metadata so the partition keeps both.
+    return Table(
+        nms, types(tp.table), newcols, colmap,
+        schema(tp.table), metadata(tp.table),
+    ), i + 1
+end
+
+Tables.partitions(tbl::Table) = TablePartitions(tbl)  # each yielded partition is an Arrow.Table
+
# high-level user API functions
Table(input, pos::Integer=1, len=nothing; kw...) = Table([ArrowBlob(tobytes(input), pos, len)]; kw...)
Table(input::Vector{UInt8}, pos::Integer=1, len=nothing; kw...) = Table([ArrowBlob(tobytes(input), pos, len)]; kw...)
diff --git a/test/runtests.jl b/test/runtests.jl
index 47a137f..c477462 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -674,6 +674,19 @@ t = Arrow.Table(joinpath(dirname(pathof(Arrow)), "../test/java_compress_len_neg_
end
+@testset "# 293" begin
+# `Tables.partitioner((t, t))` has two partitions, so `tbl2` should iterate as two.
+t = (a = [1, 2, 3], b = [1.0, 2.0, 3.0])
+tbl = Arrow.Table(Arrow.tobuffer(t))
+tbl2 = Arrow.Table(Arrow.tobuffer(Tables.partitioner((t, t))))
+@test count(_ -> true, Tables.partitions(tbl2)) == 2
+for part in Tables.partitions(tbl2)
+    @test part isa Arrow.Table
+    @test part.a == tbl.a
+    @test part.b == tbl.b
+end
+
+end
end # @testset "misc"
end