You are viewing a plain-text version of this content. The canonical link to it was not preserved in this plain-text version.
Posted to commits@arrow.apache.org by qu...@apache.org on 2023/05/19 14:16:05 UTC

[arrow-julia] branch main updated: Don't treat Vector{UInt8} as Arrow Binary type (#439)

This is an automated email from the ASF dual-hosted git repository.

quinnj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-julia.git


The following commit(s) were added to refs/heads/main by this push:
     new 899ecb0  Don't treat Vector{UInt8} as Arrow Binary type (#439)
899ecb0 is described below

commit 899ecb0207bb6e4f1c50d9a17dc779d87f5031aa
Author: Jacob Quinn <qu...@gmail.com>
AuthorDate: Fri May 19 08:15:58 2023 -0600

    Don't treat Vector{UInt8} as Arrow Binary type (#439)
    
    Fixes #411. Alternative to #419.
    
    This PR should be compatible with or without the ArrowTypes changes. I
    think it's fine to add compatibility shims like this in Arrow as long as
    they don't get out of hand; we can remove them later as we bump the
    required ArrowTypes version.
    
    This PR stops treating `Vector{UInt8}` as the Arrow Binary type, which
    is meant for "binary strings". Julia has a good match for that in
    `Base.CodeUnits`, so we now use it to write Binary, while
    `Vector{UInt8}` is treated as a regular List of primitive UInt8
    values.
---
 src/ArrowTypes/src/ArrowTypes.jl |  4 ++++
 src/ArrowTypes/test/tests.jl     |  3 +++
 src/arraytypes/list.jl           | 44 ++++++++++++++++++++++++++++++----------
 src/arraytypes/map.jl            |  4 ++--
 src/eltypes.jl                   |  6 +++---
 test/runtests.jl                 | 36 ++++++++++++++++++++++++++++++++
 6 files changed, 81 insertions(+), 16 deletions(-)

diff --git a/src/ArrowTypes/src/ArrowTypes.jl b/src/ArrowTypes/src/ArrowTypes.jl
index 130e07e..1391029 100644
--- a/src/ArrowTypes/src/ArrowTypes.jl
+++ b/src/ArrowTypes/src/ArrowTypes.jl
@@ -208,8 +208,12 @@ isstringtype(::ListKind{stringtype}) where {stringtype} = stringtype
 isstringtype(::Type{ListKind{stringtype}}) where {stringtype} = stringtype
 
 ArrowKind(::Type{<:AbstractString}) = ListKind{true}()
+# Treat Base.CodeUnits as the Arrow Binary type
+ArrowKind(::Type{<:Base.CodeUnits}) = ListKind{true}()
 
 fromarrow(::Type{T}, ptr::Ptr{UInt8}, len::Int) where {T} = fromarrow(T, unsafe_string(ptr, len))
+fromarrow(::Type{T}, x) where {T <: Base.CodeUnits} = Base.CodeUnits(x)
+fromarrow(::Type{Union{Missing, Base.CodeUnits}}, x) = x === missing ? missing : Base.CodeUnits(x)
 
 ArrowType(::Type{Symbol}) = String
 toarrow(x::Symbol) = String(x)
diff --git a/src/ArrowTypes/test/tests.jl b/src/ArrowTypes/test/tests.jl
index e035b48..26f3945 100644
--- a/src/ArrowTypes/test/tests.jl
+++ b/src/ArrowTypes/test/tests.jl
@@ -72,9 +72,12 @@ end
 @test !ArrowTypes.isstringtype(ArrowTypes.ListKind())
 @test !ArrowTypes.isstringtype(typeof(ArrowTypes.ListKind()))
 @test ArrowTypes.ArrowKind(String) == ArrowTypes.ListKind{true}()
+@test ArrowTypes.ArrowKind(Base.CodeUnits) == ArrowTypes.ListKind{true}()
 
 hey = collect(b"hey")
 @test ArrowTypes.fromarrow(String, pointer(hey), 3) == "hey"
+@test ArrowTypes.fromarrow(Base.CodeUnits, pointer(hey), 3) == b"hey"
+@test ArrowTypes.fromarrow(Union{Base.CodeUnits, Missing}, pointer(hey), 3) == b"hey"
 
 @test ArrowTypes.ArrowType(Symbol) == String
 @test ArrowTypes.toarrow(:hey) == "hey"
diff --git a/src/arraytypes/list.jl b/src/arraytypes/list.jl
index 1525f38..2275f9f 100644
--- a/src/arraytypes/list.jl
+++ b/src/arraytypes/list.jl
@@ -49,11 +49,20 @@ Base.size(l::List) = (l.ℓ,)
     @inbounds lo, hi = l.offsets[i]
     S = Base.nonmissingtype(T)
     K = ArrowTypes.ArrowKind(ArrowTypes.ArrowType(S))
-    if ArrowTypes.isstringtype(K)
+    # special-case Base.CodeUnits for ArrowTypes compat
+    if ArrowTypes.isstringtype(K) || S <: Base.CodeUnits
         if S !== T
-            return l.validity[i] ? ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1) : missing
+            if S <: Base.CodeUnits
+                return l.validity[i] ? Base.CodeUnits(unsafe_string(pointer(l.data, lo), hi - lo + 1)) : missing
+            else
+                return l.validity[i] ? ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1) : missing
+            end
         else
-            return ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1)
+            if S <: Base.CodeUnits
+                return Base.CodeUnits(unsafe_string(pointer(l.data, lo), hi - lo + 1))
+            else
+                return ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1)
+            end
         end
     elseif S !== T
         return l.validity[i] ? ArrowTypes.fromarrow(T, view(l.data, lo:hi)) : missing
@@ -66,6 +75,12 @@ end
 
 # end
 
+# internal interface definitions to be able to treat AbstractString/CodeUnits similarly
+_ncodeunits(x::AbstractString) = ncodeunits(x)
+_codeunits(x::AbstractString) = codeunits(x)
+_ncodeunits(x::Base.CodeUnits) = length(x)
+_codeunits(x::Base.CodeUnits) = x
+
 # an AbstractVector version of Iterators.flatten
 # code based on SentinelArrays.ChainedVector
 struct ToList{T, stringtype, A, I} <: AbstractVector{T}
@@ -74,14 +89,21 @@ struct ToList{T, stringtype, A, I} <: AbstractVector{T}
 end
 
 origtype(::ToList{T, S, A, I}) where {T, S, A, I} = A
+liststringtype(::Type{ToList{T, S, A, I}}) where {T, S, A, I} = S
+function liststringtype(::List{T, O, A}) where {T, O, A}
+    ST = Base.nonmissingtype(T)
+    K = ArrowTypes.ArrowKind(ST)
+    return liststringtype(A) || ArrowTypes.isstringtype(K) || ST <: Base.CodeUnits # add the CodeUnits check for ArrowTypes compat for now
+end
+liststringtype(T) = false
 
 function ToList(input; largelists::Bool=false)
     AT = eltype(input)
     ST = Base.nonmissingtype(AT)
     K = ArrowTypes.ArrowKind(ST)
-    stringtype = ArrowTypes.isstringtype(K)
+    stringtype = ArrowTypes.isstringtype(K) || ST <: Base.CodeUnits # add the CodeUnits check for ArrowTypes compat for now
     T = stringtype ? UInt8 : eltype(ST)
-    len = stringtype ? ncodeunits : length
+    len = stringtype ? _ncodeunits : length
     data = AT[]
     I = largelists ? Int64 : Int32
     inds = I[0]
@@ -122,7 +144,7 @@ Base.@propagate_inbounds function Base.getindex(A::ToList{T, stringtype}, i::Int
     @boundscheck checkbounds(A, i)
     chunk, ix = index(A, i)
     @inbounds x = A.data[chunk]
-    return @inbounds stringtype ? codeunits(x)[ix] : x[ix]
+    return @inbounds stringtype ? _codeunits(x)[ix] : x[ix]
 end
 
 Base.@propagate_inbounds function Base.setindex!(A::ToList{T, stringtype}, v, i::Integer) where {T, stringtype}
@@ -130,7 +152,7 @@ Base.@propagate_inbounds function Base.setindex!(A::ToList{T, stringtype}, v, i:
     chunk, ix = index(A, i)
     @inbounds x = A.data[chunk]
     if stringtype
-        codeunits(x)[ix] = v
+        _codeunits(x)[ix] = v
     else
         x[ix] = v
     end
@@ -149,7 +171,7 @@ end
         chunk_len = A.inds[chunk]
     end
     val = A.data[chunk - 1]
-    x = stringtype ? codeunits(val)[1] : val[1]
+    x = stringtype ? _codeunits(val)[1] : val[1]
     # find next valid index
     i += 1
     if i > chunk_len
@@ -168,7 +190,7 @@ end
 @inline function Base.iterate(A::ToList{T, stringtype}, (i, chunk, chunk_i, chunk_len, len)) where {T, stringtype}
     i > len && return nothing
     @inbounds val = A.data[chunk - 1]
-    @inbounds x = stringtype ? codeunits(val)[chunk_i] : val[chunk_i]
+    @inbounds x = stringtype ? _codeunits(val)[chunk_i] : val[chunk_i]
     i += 1
     if i > chunk_len
         chunk_i = 1
@@ -191,7 +213,7 @@ function arrowvector(::ListKind, x, i, nl, fi, de, ded, meta; largelists::Bool=f
     validity = ValidityBitmap(x)
     flat = ToList(x; largelists=largelists)
     offsets = Offsets(UInt8[], flat.inds)
-    if eltype(flat) == UInt8 # binary or utf8string
+    if liststringtype(typeof(flat)) && eltype(flat) == UInt8 # binary or utf8string
         data = flat
         T = origtype(flat)
     else
@@ -208,7 +230,7 @@ function compress(Z::Meta.CompressionType, comp, x::List{T, O, A}) where {T, O,
     offsets = compress(Z, comp, x.offsets.offsets)
     buffers = [validity, offsets]
     children = Compressed[]
-    if eltype(A) == UInt8
+    if liststringtype(x)
         push!(buffers, compress(Z, comp, x.data))
     else
         push!(children, compress(Z, comp, x.data))
diff --git a/src/arraytypes/map.jl b/src/arraytypes/map.jl
index 67a791c..8140b25 100644
--- a/src/arraytypes/map.jl
+++ b/src/arraytypes/map.jl
@@ -90,7 +90,7 @@ function makenodesbuffers!(col::Union{Map{T, O, A}, List{T, O, A}}, fieldnodes,
     push!(fieldbuffers, Buffer(bufferoffset, blen))
     @debugv 1 "made field buffer: bufferidx = $(length(fieldbuffers)), offset = $(fieldbuffers[end].offset), len = $(fieldbuffers[end].length), padded = $(padding(fieldbuffers[end].length, alignment))"
     bufferoffset += padding(blen, alignment)
-    if eltype(A) == UInt8
+    if liststringtype(col)
         blen = length(col.data)
         push!(fieldbuffers, Buffer(bufferoffset, blen))
         @debugv 1 "made field buffer: bufferidx = $(length(fieldbuffers)), offset = $(fieldbuffers[end].offset), len = $(fieldbuffers[end].length), padded = $(padding(fieldbuffers[end].length, alignment))"
@@ -110,7 +110,7 @@ function writebuffer(io, col::Union{Map{T, O, A}, List{T, O, A}}, alignment) whe
     @debugv 1 "writing array: col = $(typeof(col.offsets.offsets)), n = $n, padded = $(padding(n, alignment))"
     writezeros(io, paddinglength(n, alignment))
     # write values array
-    if eltype(A) == UInt8
+    if liststringtype(col)
         n = writearray(io, UInt8, col.data)
         @debugv 1 "writing array: col = $(typeof(col.data)), n = $n, padded = $(padding(n, alignment))"
         writezeros(io, paddinglength(n, alignment))
diff --git a/src/eltypes.jl b/src/eltypes.jl
index 4bec444..11a845f 100644
--- a/src/eltypes.jl
+++ b/src/eltypes.jl
@@ -129,7 +129,7 @@ juliaeltype(f::Meta.Field, b::Union{Meta.Utf8, Meta.LargeUtf8}, convert) = Strin
 datasizeof(x) = sizeof(x)
 datasizeof(x::AbstractVector) = sum(datasizeof, x)
 
-juliaeltype(f::Meta.Field, b::Union{Meta.Binary, Meta.LargeBinary}, convert) = Vector{UInt8}
+juliaeltype(f::Meta.Field, b::Union{Meta.Binary, Meta.LargeBinary}, convert) = Base.CodeUnits
 
 juliaeltype(f::Meta.Field, x::Meta.FixedSizeBinary, convert) = NTuple{Int(x.byteWidth), UInt8}
 
@@ -393,7 +393,7 @@ end
 
 # arrowtype will call fieldoffset recursively for children
 function arrowtype(b, x::List{T, O, A}) where {T, O, A}
-    if eltype(A) == UInt8
+    if liststringtype(x)
         if T <: AbstractString || T <: Union{AbstractString, Missing}
             if O == Int32
                 Meta.utf8Start(b)
@@ -402,7 +402,7 @@ function arrowtype(b, x::List{T, O, A}) where {T, O, A}
                 Meta.largUtf8Start(b)
                 return Meta.LargeUtf8, Meta.largUtf8End(b), nothing
             end
-        else # if Vector{UInt8}
+        else # if Base.CodeUnits
             if O == Int32
                 Meta.binaryStart(b)
                 return Meta.Binary, Meta.binaryEnd(b), nothing
diff --git a/test/runtests.jl b/test/runtests.jl
index 511be4e..3cdac88 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -629,6 +629,42 @@ for (col1, col2) in zip(Tables.columns(df), Tables.columns(df_load))
     @test col1 == col2
 end
 
+@testset "# 411" begin
+# Vector{UInt8} are written as List{UInt8} in Arrow
+# Base.CodeUnits are written as Binary
+t = (
+    a=[[0x00, 0x01], UInt8[], [0x03]],
+    am=[[0x00, 0x01], [0x03], missing],
+    b=[b"01", b"", b"3"],
+    bm=[b"01", b"3", missing],
+    c=["a", "b", "c"],
+    cm=["a", "c", missing]
+)
+buf = Arrow.tobuffer(t)
+tt = Arrow.Table(buf)
+@test t.a == tt.a
+@test isequal(t.am, tt.am)
+@test t.b == tt.b
+@test isequal(t.bm, tt.bm)
+@test t.c == tt.c
+@test isequal(t.cm, tt.cm)
+@test Arrow.schema(tt)[].fields[1].type isa Arrow.Flatbuf.List
+@test Arrow.schema(tt)[].fields[3].type isa Arrow.Flatbuf.Binary
+pos = position(buf)
+Arrow.append(buf, tt)
+seekstart(buf)
+buf1 = read(buf, pos)
+buf2 = read(buf)
+t1 = Arrow.Table(buf1)
+t2 = Arrow.Table(buf2)
+@test isequal(t1.a, t2.a)
+@test isequal(t1.am, t2.am)
+@test isequal(t1.b, t2.b)
+@test isequal(t1.bm, t2.bm)
+@test isequal(t1.c, t2.c)
+@test isequal(t1.cm, t2.cm)
+
+end
 
 end # @testset "misc"