You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by qu...@apache.org on 2023/05/24 03:55:57 UTC

[arrow-julia] 01/01: Add Tables.partitions definition for Arrow.Table

This is an automated email from the ASF dual-hosted git repository.

quinnj pushed a commit to branch jq-table-partitions
in repository https://gitbox.apache.org/repos/asf/arrow-julia.git

commit 66399b2fd9118bac3f204e5dbb5310800a2d6e0f
Author: Jacob Quinn <qu...@gmail.com>
AuthorDate: Tue May 23 21:54:54 2023 -0600

    Add Tables.partitions definition for Arrow.Table
    
    We already had this functionality with `Arrow.Stream`, but it's convenient,
    and not very expensive, to define it for `Arrow.Table` as well.
    
    Fixes #293.
---
 src/table.jl     | 36 ++++++++++++++++++++++++++++++++++++
 test/runtests.jl | 13 +++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/src/table.jl b/src/table.jl
index 49b6153..9d7ddef 100644
--- a/src/table.jl
+++ b/src/table.jl
@@ -261,6 +261,7 @@ types(t::Table) = getfield(t, :types)
 columns(t::Table) = getfield(t, :columns)
 lookup(t::Table) = getfield(t, :lookup)
 schema(t::Table) = getfield(t, :schema)
+metadata(tbl::Table) = getfield(tbl, :metadata)  # raw field getter, as above
 
 """
     Arrow.getmetadata(x)
@@ -286,6 +287,61 @@ Tables.columnnames(t::Table) = names(t)
 Tables.getcolumn(t::Table, i::Int) = columns(t)[i]
 Tables.getcolumn(t::Table, nm::Symbol) = lookup(t)[nm]
 
+# Iterator returned by `Tables.partitions(::Table)`: yields `npartitions`
+# `Arrow.Table`s, where partition `i` holds the `i`-th array chained inside each
+# column (presumably one per record batch of the source data — verify).
+struct TablePartitions
+    table::Table
+    npartitions::Int
+end
+
+# Derive the partition count from the first column: a `ChainedVector` gives one
+# partition per chained array, any other column type means a single partition,
+# and a table with no columns has no partitions.
+# NOTE(review): assumes every column is chained exactly like the first one.
+function TablePartitions(table::Table)
+    cols = columns(table)
+    npartitions = if length(cols) == 0
+        0
+    elseif cols[1] isa ChainedVector
+        length(cols[1].arrays)
+    else
+        1
+    end
+    return TablePartitions(table, npartitions)
+end
+
+# The default `IteratorSize` is `HasLength()`, so `collect`/`map` call `length`;
+# define it (and the element type) so those work on the partition iterator.
+Base.length(tp::TablePartitions) = tp.npartitions
+Base.eltype(::Type{TablePartitions}) = Table
+
+# Yield partition `i` as a new `Table` that shares the parent's names, types,
+# schema and metadata but holds only the `i`-th chained array of each column.
+# A single-partition table is yielded as-is.
+function Base.iterate(tp::TablePartitions, i=1)
+    i > tp.npartitions && return nothing
+    tp.npartitions == 1 && return tp.table, i + 1
+    cols = columns(tp.table)
+    newcols = AbstractVector[col.arrays[i] for col in cols]
+    nms = names(tp.table)
+    # `k` indexes columns; `i` stays the partition index used for the next state
+    colmap = Dict{Symbol, AbstractVector}(nms[k] => newcols[k] for k in eachindex(nms))
+    tbl = Table(
+        nms,
+        types(tp.table),
+        newcols,
+        colmap,
+        schema(tp.table),
+        # pass the parent's `metadata` field through as well (presumably the last
+        # positional field of `Table` — confirm against the struct definition)
+        metadata(tp.table),
+    )
+    return tbl, i + 1
+end
+
+Tables.partitions(t::Table) = TablePartitions(t)
+
 # high-level user API functions
 Table(input, pos::Integer=1, len=nothing; kw...) = Table([ArrowBlob(tobytes(input), pos, len)]; kw...)
 Table(input::Vector{UInt8}, pos::Integer=1, len=nothing; kw...) = Table([ArrowBlob(tobytes(input), pos, len)]; kw...)
diff --git a/test/runtests.jl b/test/runtests.jl
index 47a137f..c477462 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -674,6 +674,27 @@ t = Arrow.Table(joinpath(dirname(pathof(Arrow)), "../test/java_compress_len_neg_
 
 end
 
+@testset "# 293" begin
+
+# Tables.partitions(::Arrow.Table) should split a multi-batch table back into
+# one table per written partition, and yield a single-batch table unchanged.
+t = (a = [1, 2, 3], b = [1.0, 2.0, 3.0])
+buf = Arrow.tobuffer(t)
+tbl = Arrow.Table(buf)
+@test count(_ -> true, Tables.partitions(tbl)) == 1
+@test first(Tables.partitions(tbl)) === tbl
+
+parts = Tables.partitioner((t, t))
+buf2 = Arrow.tobuffer(parts)
+tbl2 = Arrow.Table(buf2)
+@test count(_ -> true, Tables.partitions(tbl2)) == 2
+for part in Tables.partitions(tbl2)
+    @test part.a == tbl.a
+    @test part.b == tbl.b
+end
+
+end # @testset "# 293"
+
 end # @testset "misc"
 
 end