You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by qu...@apache.org on 2023/05/19 14:16:05 UTC
[arrow-julia] branch main updated: Don't treat Vector{UInt8} as Arrow Binary type (#439)
This is an automated email from the ASF dual-hosted git repository.
quinnj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-julia.git
The following commit(s) were added to refs/heads/main by this push:
new 899ecb0 Don't treat Vector{UInt8} as Arrow Binary type (#439)
899ecb0 is described below
commit 899ecb0207bb6e4f1c50d9a17dc779d87f5031aa
Author: Jacob Quinn <qu...@gmail.com>
AuthorDate: Fri May 19 08:15:58 2023 -0600
Don't treat Vector{UInt8} as Arrow Binary type (#439)
Fixes #411. Alternative to #419.
This PR should be compatible with or without the ArrowTypes changes. I
think it's fine to add compatibility shims like this in Arrow as long as
they don't get out of hand and we can eventually remove them as we bump
the required ArrowTypes version.
The PR stops treating `Vector{UInt8}` as the Arrow Binary type, which is
meant for "binary strings". Julia has a pretty good match for that in
`Base.CodeUnits`, so we use `Base.CodeUnits` to write Binary instead, and
`Vector{UInt8}` is written as a regular Arrow List whose elements are the
primitive UInt8 type.
---
src/ArrowTypes/src/ArrowTypes.jl | 4 ++++
src/ArrowTypes/test/tests.jl | 3 +++
src/arraytypes/list.jl | 44 ++++++++++++++++++++++++++++++----------
src/arraytypes/map.jl | 4 ++--
src/eltypes.jl | 6 +++---
test/runtests.jl | 36 ++++++++++++++++++++++++++++++++
6 files changed, 81 insertions(+), 16 deletions(-)
diff --git a/src/ArrowTypes/src/ArrowTypes.jl b/src/ArrowTypes/src/ArrowTypes.jl
index 130e07e..1391029 100644
--- a/src/ArrowTypes/src/ArrowTypes.jl
+++ b/src/ArrowTypes/src/ArrowTypes.jl
@@ -208,8 +208,12 @@ isstringtype(::ListKind{stringtype}) where {stringtype} = stringtype
isstringtype(::Type{ListKind{stringtype}}) where {stringtype} = stringtype
ArrowKind(::Type{<:AbstractString}) = ListKind{true}()
+# Treat Base.CodeUnits as the Arrow Binary type
+ArrowKind(::Type{<:Base.CodeUnits}) = ListKind{true}()
fromarrow(::Type{T}, ptr::Ptr{UInt8}, len::Int) where {T} = fromarrow(T, unsafe_string(ptr, len))
+fromarrow(::Type{T}, x) where {T <: Base.CodeUnits} = Base.CodeUnits(x)
+fromarrow(::Type{Union{Missing, Base.CodeUnits}}, x) = x === missing ? missing : Base.CodeUnits(x)
ArrowType(::Type{Symbol}) = String
toarrow(x::Symbol) = String(x)
diff --git a/src/ArrowTypes/test/tests.jl b/src/ArrowTypes/test/tests.jl
index e035b48..26f3945 100644
--- a/src/ArrowTypes/test/tests.jl
+++ b/src/ArrowTypes/test/tests.jl
@@ -72,9 +72,12 @@ end
@test !ArrowTypes.isstringtype(ArrowTypes.ListKind())
@test !ArrowTypes.isstringtype(typeof(ArrowTypes.ListKind()))
@test ArrowTypes.ArrowKind(String) == ArrowTypes.ListKind{true}()
+@test ArrowTypes.ArrowKind(Base.CodeUnits) == ArrowTypes.ListKind{true}()
hey = collect(b"hey")
@test ArrowTypes.fromarrow(String, pointer(hey), 3) == "hey"
+@test ArrowTypes.fromarrow(Base.CodeUnits, pointer(hey), 3) == b"hey"
+@test ArrowTypes.fromarrow(Union{Base.CodeUnits, Missing}, pointer(hey), 3) == b"hey"
@test ArrowTypes.ArrowType(Symbol) == String
@test ArrowTypes.toarrow(:hey) == "hey"
diff --git a/src/arraytypes/list.jl b/src/arraytypes/list.jl
index 1525f38..2275f9f 100644
--- a/src/arraytypes/list.jl
+++ b/src/arraytypes/list.jl
@@ -49,11 +49,20 @@ Base.size(l::List) = (l.ℓ,)
@inbounds lo, hi = l.offsets[i]
S = Base.nonmissingtype(T)
K = ArrowTypes.ArrowKind(ArrowTypes.ArrowType(S))
- if ArrowTypes.isstringtype(K)
+ # special-case Base.CodeUnits for ArrowTypes compat
+ if ArrowTypes.isstringtype(K) || S <: Base.CodeUnits
if S !== T
- return l.validity[i] ? ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1) : missing
+ if S <: Base.CodeUnits
+ return l.validity[i] ? Base.CodeUnits(unsafe_string(pointer(l.data, lo), hi - lo + 1)) : missing
+ else
+ return l.validity[i] ? ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1) : missing
+ end
else
- return ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1)
+ if S <: Base.CodeUnits
+ return Base.CodeUnits(unsafe_string(pointer(l.data, lo), hi - lo + 1))
+ else
+ return ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1)
+ end
end
elseif S !== T
return l.validity[i] ? ArrowTypes.fromarrow(T, view(l.data, lo:hi)) : missing
@@ -66,6 +75,12 @@ end
# end
+# internal interface definitions to be able to treat AbstractString/CodeUnits similarly
+_ncodeunits(x::AbstractString) = ncodeunits(x)
+_codeunits(x::AbstractString) = codeunits(x)
+_ncodeunits(x::Base.CodeUnits) = length(x)
+_codeunits(x::Base.CodeUnits) = x
+
# an AbstractVector version of Iterators.flatten
# code based on SentinelArrays.ChainedVector
struct ToList{T, stringtype, A, I} <: AbstractVector{T}
@@ -74,14 +89,21 @@ struct ToList{T, stringtype, A, I} <: AbstractVector{T}
end
origtype(::ToList{T, S, A, I}) where {T, S, A, I} = A
+liststringtype(::Type{ToList{T, S, A, I}}) where {T, S, A, I} = S
+function liststringtype(::List{T, O, A}) where {T, O, A}
+ ST = Base.nonmissingtype(T)
+ K = ArrowTypes.ArrowKind(ST)
+ return liststringtype(A) || ArrowTypes.isstringtype(K) || ST <: Base.CodeUnits # add the CodeUnits check for ArrowTypes compat for now
+end
+liststringtype(T) = false
function ToList(input; largelists::Bool=false)
AT = eltype(input)
ST = Base.nonmissingtype(AT)
K = ArrowTypes.ArrowKind(ST)
- stringtype = ArrowTypes.isstringtype(K)
+ stringtype = ArrowTypes.isstringtype(K) || ST <: Base.CodeUnits # add the CodeUnits check for ArrowTypes compat for now
T = stringtype ? UInt8 : eltype(ST)
- len = stringtype ? ncodeunits : length
+ len = stringtype ? _ncodeunits : length
data = AT[]
I = largelists ? Int64 : Int32
inds = I[0]
@@ -122,7 +144,7 @@ Base.@propagate_inbounds function Base.getindex(A::ToList{T, stringtype}, i::Int
@boundscheck checkbounds(A, i)
chunk, ix = index(A, i)
@inbounds x = A.data[chunk]
- return @inbounds stringtype ? codeunits(x)[ix] : x[ix]
+ return @inbounds stringtype ? _codeunits(x)[ix] : x[ix]
end
Base.@propagate_inbounds function Base.setindex!(A::ToList{T, stringtype}, v, i::Integer) where {T, stringtype}
@@ -130,7 +152,7 @@ Base.@propagate_inbounds function Base.setindex!(A::ToList{T, stringtype}, v, i:
chunk, ix = index(A, i)
@inbounds x = A.data[chunk]
if stringtype
- codeunits(x)[ix] = v
+ _codeunits(x)[ix] = v
else
x[ix] = v
end
@@ -149,7 +171,7 @@ end
chunk_len = A.inds[chunk]
end
val = A.data[chunk - 1]
- x = stringtype ? codeunits(val)[1] : val[1]
+ x = stringtype ? _codeunits(val)[1] : val[1]
# find next valid index
i += 1
if i > chunk_len
@@ -168,7 +190,7 @@ end
@inline function Base.iterate(A::ToList{T, stringtype}, (i, chunk, chunk_i, chunk_len, len)) where {T, stringtype}
i > len && return nothing
@inbounds val = A.data[chunk - 1]
- @inbounds x = stringtype ? codeunits(val)[chunk_i] : val[chunk_i]
+ @inbounds x = stringtype ? _codeunits(val)[chunk_i] : val[chunk_i]
i += 1
if i > chunk_len
chunk_i = 1
@@ -191,7 +213,7 @@ function arrowvector(::ListKind, x, i, nl, fi, de, ded, meta; largelists::Bool=f
validity = ValidityBitmap(x)
flat = ToList(x; largelists=largelists)
offsets = Offsets(UInt8[], flat.inds)
- if eltype(flat) == UInt8 # binary or utf8string
+ if liststringtype(typeof(flat)) && eltype(flat) == UInt8 # binary or utf8string
data = flat
T = origtype(flat)
else
@@ -208,7 +230,7 @@ function compress(Z::Meta.CompressionType, comp, x::List{T, O, A}) where {T, O,
offsets = compress(Z, comp, x.offsets.offsets)
buffers = [validity, offsets]
children = Compressed[]
- if eltype(A) == UInt8
+ if liststringtype(x)
push!(buffers, compress(Z, comp, x.data))
else
push!(children, compress(Z, comp, x.data))
diff --git a/src/arraytypes/map.jl b/src/arraytypes/map.jl
index 67a791c..8140b25 100644
--- a/src/arraytypes/map.jl
+++ b/src/arraytypes/map.jl
@@ -90,7 +90,7 @@ function makenodesbuffers!(col::Union{Map{T, O, A}, List{T, O, A}}, fieldnodes,
push!(fieldbuffers, Buffer(bufferoffset, blen))
@debugv 1 "made field buffer: bufferidx = $(length(fieldbuffers)), offset = $(fieldbuffers[end].offset), len = $(fieldbuffers[end].length), padded = $(padding(fieldbuffers[end].length, alignment))"
bufferoffset += padding(blen, alignment)
- if eltype(A) == UInt8
+ if liststringtype(col)
blen = length(col.data)
push!(fieldbuffers, Buffer(bufferoffset, blen))
@debugv 1 "made field buffer: bufferidx = $(length(fieldbuffers)), offset = $(fieldbuffers[end].offset), len = $(fieldbuffers[end].length), padded = $(padding(fieldbuffers[end].length, alignment))"
@@ -110,7 +110,7 @@ function writebuffer(io, col::Union{Map{T, O, A}, List{T, O, A}}, alignment) whe
@debugv 1 "writing array: col = $(typeof(col.offsets.offsets)), n = $n, padded = $(padding(n, alignment))"
writezeros(io, paddinglength(n, alignment))
# write values array
- if eltype(A) == UInt8
+ if liststringtype(col)
n = writearray(io, UInt8, col.data)
@debugv 1 "writing array: col = $(typeof(col.data)), n = $n, padded = $(padding(n, alignment))"
writezeros(io, paddinglength(n, alignment))
diff --git a/src/eltypes.jl b/src/eltypes.jl
index 4bec444..11a845f 100644
--- a/src/eltypes.jl
+++ b/src/eltypes.jl
@@ -129,7 +129,7 @@ juliaeltype(f::Meta.Field, b::Union{Meta.Utf8, Meta.LargeUtf8}, convert) = Strin
datasizeof(x) = sizeof(x)
datasizeof(x::AbstractVector) = sum(datasizeof, x)
-juliaeltype(f::Meta.Field, b::Union{Meta.Binary, Meta.LargeBinary}, convert) = Vector{UInt8}
+juliaeltype(f::Meta.Field, b::Union{Meta.Binary, Meta.LargeBinary}, convert) = Base.CodeUnits
juliaeltype(f::Meta.Field, x::Meta.FixedSizeBinary, convert) = NTuple{Int(x.byteWidth), UInt8}
@@ -393,7 +393,7 @@ end
# arrowtype will call fieldoffset recursively for children
function arrowtype(b, x::List{T, O, A}) where {T, O, A}
- if eltype(A) == UInt8
+ if liststringtype(x)
if T <: AbstractString || T <: Union{AbstractString, Missing}
if O == Int32
Meta.utf8Start(b)
@@ -402,7 +402,7 @@ function arrowtype(b, x::List{T, O, A}) where {T, O, A}
Meta.largUtf8Start(b)
return Meta.LargeUtf8, Meta.largUtf8End(b), nothing
end
- else # if Vector{UInt8}
+ else # if Base.CodeUnits
if O == Int32
Meta.binaryStart(b)
return Meta.Binary, Meta.binaryEnd(b), nothing
diff --git a/test/runtests.jl b/test/runtests.jl
index 511be4e..3cdac88 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -629,6 +629,42 @@ for (col1, col2) in zip(Tables.columns(df), Tables.columns(df_load))
@test col1 == col2
end
+@testset "# 411" begin
+# Vector{UInt8} are written as List{UInt8} in Arrow
+# Base.CodeUnits are written as Binary
+t = (
+ a=[[0x00, 0x01], UInt8[], [0x03]],
+ am=[[0x00, 0x01], [0x03], missing],
+ b=[b"01", b"", b"3"],
+ bm=[b"01", b"3", missing],
+ c=["a", "b", "c"],
+ cm=["a", "c", missing]
+)
+buf = Arrow.tobuffer(t)
+tt = Arrow.Table(buf)
+@test t.a == tt.a
+@test isequal(t.am, tt.am)
+@test t.b == tt.b
+@test isequal(t.bm, tt.bm)
+@test t.c == tt.c
+@test isequal(t.cm, tt.cm)
+@test Arrow.schema(tt)[].fields[1].type isa Arrow.Flatbuf.List
+@test Arrow.schema(tt)[].fields[3].type isa Arrow.Flatbuf.Binary
+pos = position(buf)
+Arrow.append(buf, tt)
+seekstart(buf)
+buf1 = read(buf, pos)
+buf2 = read(buf)
+t1 = Arrow.Table(buf1)
+t2 = Arrow.Table(buf2)
+@test isequal(t1.a, t2.a)
+@test isequal(t1.am, t2.am)
+@test isequal(t1.b, t2.b)
+@test isequal(t1.bm, t2.bm)
+@test isequal(t1.c, t2.c)
+@test isequal(t1.cm, t2.cm)
+
+end
end # @testset "misc"