You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by qu...@apache.org on 2023/05/18 23:00:24 UTC
[arrow-julia] 01/01: Don't treat Vector{UInt8} as Arrow Binary type
This is an automated email from the ASF dual-hosted git repository.
quinnj pushed a commit to branch jq-byte-vector-not-binary
in repository https://gitbox.apache.org/repos/asf/arrow-julia.git
commit 23c13e1b87ae5b9b6e69935026ec863c3a52693d
Author: Jacob Quinn <qu...@gmail.com>
AuthorDate: Thu May 18 16:57:57 2023 -0600
Don't treat Vector{UInt8} as Arrow Binary type
Fixes #411. Alternative to #419.
This PR should be compatible with or without the ArrowTypes changes.
I think it's fine to do compat things in Arrow like this as long as
they don't get out of hand and we can eventually remove them as we
bump required ArrowTypes versions and such.
This PR stops treating `Vector{UInt8}` as the Arrow Binary type, which
is meant for binary strings. Julia has a pretty good match for that in
`Base.CodeUnits`, so we use that type to write Binary instead, and
`Vector{UInt8}` is now treated as a regular List whose elements are the
primitive UInt8 type.
---
src/ArrowTypes/src/ArrowTypes.jl | 3 +++
src/ArrowTypes/test/tests.jl | 3 +++
src/arraytypes/list.jl | 44 ++++++++++++++++++++++++++++++----------
src/arraytypes/map.jl | 4 ++--
src/eltypes.jl | 6 +++---
test/runtests.jl | 36 ++++++++++++++++++++++++++++++++
6 files changed, 80 insertions(+), 16 deletions(-)
diff --git a/src/ArrowTypes/src/ArrowTypes.jl b/src/ArrowTypes/src/ArrowTypes.jl
index 130e07e..2850eeb 100644
--- a/src/ArrowTypes/src/ArrowTypes.jl
+++ b/src/ArrowTypes/src/ArrowTypes.jl
@@ -208,8 +208,11 @@ isstringtype(::ListKind{stringtype}) where {stringtype} = stringtype
isstringtype(::Type{ListKind{stringtype}}) where {stringtype} = stringtype
ArrowKind(::Type{<:AbstractString}) = ListKind{true}()
+# Treat Base.CodeUnits as Binary arrow type
+ArrowKind(::Type{<:Base.CodeUnits}) = ListKind{true}()
fromarrow(::Type{T}, ptr::Ptr{UInt8}, len::Int) where {T} = fromarrow(T, unsafe_string(ptr, len))
+fromarrow(::Type{T}, x) where {T <: Base.CodeUnits} = Base.CodeUnits(x)
ArrowType(::Type{Symbol}) = String
toarrow(x::Symbol) = String(x)
diff --git a/src/ArrowTypes/test/tests.jl b/src/ArrowTypes/test/tests.jl
index e035b48..26f3945 100644
--- a/src/ArrowTypes/test/tests.jl
+++ b/src/ArrowTypes/test/tests.jl
@@ -72,9 +72,12 @@ end
@test !ArrowTypes.isstringtype(ArrowTypes.ListKind())
@test !ArrowTypes.isstringtype(typeof(ArrowTypes.ListKind()))
@test ArrowTypes.ArrowKind(String) == ArrowTypes.ListKind{true}()
+@test ArrowTypes.ArrowKind(Base.CodeUnits) == ArrowTypes.ListKind{true}()
hey = collect(b"hey")
@test ArrowTypes.fromarrow(String, pointer(hey), 3) == "hey"
+@test ArrowTypes.fromarrow(Base.CodeUnits, pointer(hey), 3) == b"hey"
+@test ArrowTypes.fromarrow(Union{Base.CodeUnits, Missing}, pointer(hey), 3) == b"hey"
@test ArrowTypes.ArrowType(Symbol) == String
@test ArrowTypes.toarrow(:hey) == "hey"
diff --git a/src/arraytypes/list.jl b/src/arraytypes/list.jl
index 1525f38..2275f9f 100644
--- a/src/arraytypes/list.jl
+++ b/src/arraytypes/list.jl
@@ -49,11 +49,20 @@ Base.size(l::List) = (l.ℓ,)
@inbounds lo, hi = l.offsets[i]
S = Base.nonmissingtype(T)
K = ArrowTypes.ArrowKind(ArrowTypes.ArrowType(S))
- if ArrowTypes.isstringtype(K)
+ # special-case Base.CodeUnits for ArrowTypes compat
+ if ArrowTypes.isstringtype(K) || S <: Base.CodeUnits
if S !== T
- return l.validity[i] ? ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1) : missing
+ if S <: Base.CodeUnits
+ return l.validity[i] ? Base.CodeUnits(unsafe_string(pointer(l.data, lo), hi - lo + 1)) : missing
+ else
+ return l.validity[i] ? ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1) : missing
+ end
else
- return ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1)
+ if S <: Base.CodeUnits
+ return Base.CodeUnits(unsafe_string(pointer(l.data, lo), hi - lo + 1))
+ else
+ return ArrowTypes.fromarrow(T, pointer(l.data, lo), hi - lo + 1)
+ end
end
elseif S !== T
return l.validity[i] ? ArrowTypes.fromarrow(T, view(l.data, lo:hi)) : missing
@@ -66,6 +75,12 @@ end
# end
+# internal interface definitions to be able to treat AbstractString/CodeUnits similarly
+_ncodeunits(x::AbstractString) = ncodeunits(x)
+_codeunits(x::AbstractString) = codeunits(x)
+_ncodeunits(x::Base.CodeUnits) = length(x)
+_codeunits(x::Base.CodeUnits) = x
+
# an AbstractVector version of Iterators.flatten
# code based on SentinelArrays.ChainedVector
struct ToList{T, stringtype, A, I} <: AbstractVector{T}
@@ -74,14 +89,21 @@ struct ToList{T, stringtype, A, I} <: AbstractVector{T}
end
origtype(::ToList{T, S, A, I}) where {T, S, A, I} = A
+liststringtype(::Type{ToList{T, S, A, I}}) where {T, S, A, I} = S
+function liststringtype(::List{T, O, A}) where {T, O, A}
+ ST = Base.nonmissingtype(T)
+ K = ArrowTypes.ArrowKind(ST)
+ return liststringtype(A) || ArrowTypes.isstringtype(K) || ST <: Base.CodeUnits # add the CodeUnits check for ArrowTypes compat for now
+end
+liststringtype(T) = false
function ToList(input; largelists::Bool=false)
AT = eltype(input)
ST = Base.nonmissingtype(AT)
K = ArrowTypes.ArrowKind(ST)
- stringtype = ArrowTypes.isstringtype(K)
+ stringtype = ArrowTypes.isstringtype(K) || ST <: Base.CodeUnits # add the CodeUnits check for ArrowTypes compat for now
T = stringtype ? UInt8 : eltype(ST)
- len = stringtype ? ncodeunits : length
+ len = stringtype ? _ncodeunits : length
data = AT[]
I = largelists ? Int64 : Int32
inds = I[0]
@@ -122,7 +144,7 @@ Base.@propagate_inbounds function Base.getindex(A::ToList{T, stringtype}, i::Int
@boundscheck checkbounds(A, i)
chunk, ix = index(A, i)
@inbounds x = A.data[chunk]
- return @inbounds stringtype ? codeunits(x)[ix] : x[ix]
+ return @inbounds stringtype ? _codeunits(x)[ix] : x[ix]
end
Base.@propagate_inbounds function Base.setindex!(A::ToList{T, stringtype}, v, i::Integer) where {T, stringtype}
@@ -130,7 +152,7 @@ Base.@propagate_inbounds function Base.setindex!(A::ToList{T, stringtype}, v, i:
chunk, ix = index(A, i)
@inbounds x = A.data[chunk]
if stringtype
- codeunits(x)[ix] = v
+ _codeunits(x)[ix] = v
else
x[ix] = v
end
@@ -149,7 +171,7 @@ end
chunk_len = A.inds[chunk]
end
val = A.data[chunk - 1]
- x = stringtype ? codeunits(val)[1] : val[1]
+ x = stringtype ? _codeunits(val)[1] : val[1]
# find next valid index
i += 1
if i > chunk_len
@@ -168,7 +190,7 @@ end
@inline function Base.iterate(A::ToList{T, stringtype}, (i, chunk, chunk_i, chunk_len, len)) where {T, stringtype}
i > len && return nothing
@inbounds val = A.data[chunk - 1]
- @inbounds x = stringtype ? codeunits(val)[chunk_i] : val[chunk_i]
+ @inbounds x = stringtype ? _codeunits(val)[chunk_i] : val[chunk_i]
i += 1
if i > chunk_len
chunk_i = 1
@@ -191,7 +213,7 @@ function arrowvector(::ListKind, x, i, nl, fi, de, ded, meta; largelists::Bool=f
validity = ValidityBitmap(x)
flat = ToList(x; largelists=largelists)
offsets = Offsets(UInt8[], flat.inds)
- if eltype(flat) == UInt8 # binary or utf8string
+ if liststringtype(typeof(flat)) && eltype(flat) == UInt8 # binary or utf8string
data = flat
T = origtype(flat)
else
@@ -208,7 +230,7 @@ function compress(Z::Meta.CompressionType, comp, x::List{T, O, A}) where {T, O,
offsets = compress(Z, comp, x.offsets.offsets)
buffers = [validity, offsets]
children = Compressed[]
- if eltype(A) == UInt8
+ if liststringtype(x)
push!(buffers, compress(Z, comp, x.data))
else
push!(children, compress(Z, comp, x.data))
diff --git a/src/arraytypes/map.jl b/src/arraytypes/map.jl
index 67a791c..8140b25 100644
--- a/src/arraytypes/map.jl
+++ b/src/arraytypes/map.jl
@@ -90,7 +90,7 @@ function makenodesbuffers!(col::Union{Map{T, O, A}, List{T, O, A}}, fieldnodes,
push!(fieldbuffers, Buffer(bufferoffset, blen))
@debugv 1 "made field buffer: bufferidx = $(length(fieldbuffers)), offset = $(fieldbuffers[end].offset), len = $(fieldbuffers[end].length), padded = $(padding(fieldbuffers[end].length, alignment))"
bufferoffset += padding(blen, alignment)
- if eltype(A) == UInt8
+ if liststringtype(col)
blen = length(col.data)
push!(fieldbuffers, Buffer(bufferoffset, blen))
@debugv 1 "made field buffer: bufferidx = $(length(fieldbuffers)), offset = $(fieldbuffers[end].offset), len = $(fieldbuffers[end].length), padded = $(padding(fieldbuffers[end].length, alignment))"
@@ -110,7 +110,7 @@ function writebuffer(io, col::Union{Map{T, O, A}, List{T, O, A}}, alignment) whe
@debugv 1 "writing array: col = $(typeof(col.offsets.offsets)), n = $n, padded = $(padding(n, alignment))"
writezeros(io, paddinglength(n, alignment))
# write values array
- if eltype(A) == UInt8
+ if liststringtype(col)
n = writearray(io, UInt8, col.data)
@debugv 1 "writing array: col = $(typeof(col.data)), n = $n, padded = $(padding(n, alignment))"
writezeros(io, paddinglength(n, alignment))
diff --git a/src/eltypes.jl b/src/eltypes.jl
index 4bec444..11a845f 100644
--- a/src/eltypes.jl
+++ b/src/eltypes.jl
@@ -129,7 +129,7 @@ juliaeltype(f::Meta.Field, b::Union{Meta.Utf8, Meta.LargeUtf8}, convert) = Strin
datasizeof(x) = sizeof(x)
datasizeof(x::AbstractVector) = sum(datasizeof, x)
-juliaeltype(f::Meta.Field, b::Union{Meta.Binary, Meta.LargeBinary}, convert) = Vector{UInt8}
+juliaeltype(f::Meta.Field, b::Union{Meta.Binary, Meta.LargeBinary}, convert) = Base.CodeUnits
juliaeltype(f::Meta.Field, x::Meta.FixedSizeBinary, convert) = NTuple{Int(x.byteWidth), UInt8}
@@ -393,7 +393,7 @@ end
# arrowtype will call fieldoffset recursively for children
function arrowtype(b, x::List{T, O, A}) where {T, O, A}
- if eltype(A) == UInt8
+ if liststringtype(x)
if T <: AbstractString || T <: Union{AbstractString, Missing}
if O == Int32
Meta.utf8Start(b)
@@ -402,7 +402,7 @@ function arrowtype(b, x::List{T, O, A}) where {T, O, A}
Meta.largUtf8Start(b)
return Meta.LargeUtf8, Meta.largUtf8End(b), nothing
end
- else # if Vector{UInt8}
+ else # if Base.CodeUnits
if O == Int32
Meta.binaryStart(b)
return Meta.Binary, Meta.binaryEnd(b), nothing
diff --git a/test/runtests.jl b/test/runtests.jl
index 511be4e..3cdac88 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -629,6 +629,42 @@ for (col1, col2) in zip(Tables.columns(df), Tables.columns(df_load))
@test col1 == col2
end
+@testset "# 411" begin
+# Vector{UInt8} are written as List{UInt8} in Arrow
+# Base.CodeUnits are written as Binary
+t = (
+ a=[[0x00, 0x01], UInt8[], [0x03]],
+ am=[[0x00, 0x01], [0x03], missing],
+ b=[b"01", b"", b"3"],
+ bm=[b"01", b"3", missing],
+ c=["a", "b", "c"],
+ cm=["a", "c", missing]
+)
+buf = Arrow.tobuffer(t)
+tt = Arrow.Table(buf)
+@test t.a == tt.a
+@test isequal(t.am, tt.am)
+@test t.b == tt.b
+@test isequal(t.bm, tt.bm)
+@test t.c == tt.c
+@test isequal(t.cm, tt.cm)
+@test Arrow.schema(tt)[].fields[1].type isa Arrow.Flatbuf.List
+@test Arrow.schema(tt)[].fields[3].type isa Arrow.Flatbuf.Binary
+pos = position(buf)
+Arrow.append(buf, tt)
+seekstart(buf)
+buf1 = read(buf, pos)
+buf2 = read(buf)
+t1 = Arrow.Table(buf1)
+t2 = Arrow.Table(buf2)
+@test isequal(t1.a, t2.a)
+@test isequal(t1.am, t2.am)
+@test isequal(t1.b, t2.b)
+@test isequal(t1.bm, t2.bm)
+@test isequal(t1.c, t2.c)
+@test isequal(t1.cm, t2.cm)
+
+end
end # @testset "misc"