Posted to commits@arrow.apache.org by qu...@apache.org on 2022/10/06 22:25:06 UTC

[arrow-julia] branch main updated: Use OrderedSynchronizer instead of OrderedChannel (#339)

This is an automated email from the ASF dual-hosted git repository.

quinnj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-julia.git


The following commit(s) were added to refs/heads/main by this push:
     new d239c34  Use OrderedSynchronizer instead of OrderedChannel (#339)
d239c34 is described below

commit d239c34e1a0ba2770221f8a855d6ff851f44b85d
Author: Jacob Quinn <qu...@gmail.com>
AuthorDate: Thu Oct 6 16:24:59 2022 -0600

    Use OrderedSynchronizer instead of OrderedChannel (#339)
    
    * Use OrderedSynchronizer instead of OrderedChannel
    
    * fixes #295
---
 .github/workflows/ci.yml       |  5 ++--
 Project.toml                   |  4 ++-
 src/Arrow.jl                   |  2 +-
 src/append.jl                  |  7 ++---
 src/arraytypes/dictencoding.jl |  8 +++---
 src/table.jl                   | 17 +++++++-----
 src/utils.jl                   | 61 ------------------------------------------
 src/write.jl                   | 33 ++++++++++++++---------
 8 files changed, 45 insertions(+), 92 deletions(-)
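
The gist of the change: the old OrderedChannel bundled transport and ordering in a
single type, while the new code pairs a plain Channel (transport) with an
OrderedSynchronizer from the new WorkerUtilities dependency (ordering). A minimal
sketch of the pattern, assuming the put!(f, sync, i) behavior this diff relies on
(f runs only once turns 1 through i-1 have completed); the toy values are
stand-ins:

    using WorkerUtilities           # provides OrderedSynchronizer

    sync = OrderedSynchronizer()    # next turn defaults to index 1
    results = Channel{Int}(Inf)     # a plain Channel handles transport only
    @sync for i in 1:4
        Threads.@spawn begin
            x = 10 * $i                             # may finish in any order
            put!(() -> put!(results, x), sync, $i)  # closure runs only on turn i
        end
    end
    close(results)
    @assert collect(results) == [10, 20, 30, 40]    # delivered in index order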

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 1545e49..e65fd7e 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -72,8 +72,7 @@ jobs:
           - name: ArrowTypes.jl
             dir: './src/ArrowTypes'
         version:
-          - '1.0'
-          - '1.4'
+          - '1.6'
           - '1' # automatically expands to the latest stable 1.x release of Julia
           - 'nightly'
         os:
@@ -84,7 +83,7 @@ jobs:
           # Test Arrow.jl/ArrowTypes.jl on their oldest supported Julia versions
           - pkg:
               name: Arrow.jl
-            version: '1.0'
+            version: '1.6'
           - pkg:
               name: ArrowTypes.jl
             version: '1.4'
diff --git a/Project.toml b/Project.toml
index 60676a2..89369be 100644
--- a/Project.toml
+++ b/Project.toml
@@ -33,6 +33,7 @@ SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+WorkerUtilities = "76eceee3-57b5-4d4a-8e66-0e911cebbf60"
 
 [compat]
 ArrowTypes = "1.1"
@@ -45,7 +46,8 @@ PooledArrays = "0.5, 1.0"
 SentinelArrays = "1"
 Tables = "1.1"
 TimeZones = "1"
-julia = "1.4"
+WorkerUtilities = "1.1"
+julia = "1.6"
 
 [extras]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
diff --git a/src/Arrow.jl b/src/Arrow.jl
index a623031..6eaa08e 100644
--- a/src/Arrow.jl
+++ b/src/Arrow.jl
@@ -44,7 +44,7 @@ module Arrow
 using Base.Iterators
 using Mmap
 import Dates
-using DataAPI, Tables, SentinelArrays, PooledArrays, CodecLz4, CodecZstd, TimeZones, BitIntegers
+using DataAPI, Tables, SentinelArrays, PooledArrays, CodecLz4, CodecZstd, TimeZones, BitIntegers, WorkerUtilities
 
 export ArrowTypes
 
diff --git a/src/append.jl b/src/append.jl
index 1d1d8ef..f6716f3 100644
--- a/src/append.jl
+++ b/src/append.jl
@@ -107,7 +107,8 @@ function append(io::IO, source, arrow_schema, compress, largelists, denseunions,
     skip(io, -8) # overwrite last 8 bytes of last empty message footer
 
     sch = Ref{Tables.Schema}(arrow_schema)
-    msgs = OrderedChannel{Message}(ntasks)
+    sync = OrderedSynchronizer()
+    msgs = Channel{Message}(ntasks)
     dictencodings = Dict{Int64, Any}() # Lockable{DictEncoding}
     # build messages
     blocks = (Block[], Block[])
@@ -134,9 +135,9 @@ function append(io::IO, source, arrow_schema, compress, largelists, denseunions,
         end
 
         if threaded
-            Threads.@spawn process_partition(tbl_cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
+            Threads.@spawn process_partition(tbl_cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, sync, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
         else
-            @async process_partition(tbl_cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
+            @async process_partition(tbl_cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, sync, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
         end
     end
     if anyerror[]
diff --git a/src/arraytypes/dictencoding.jl b/src/arraytypes/dictencoding.jl
index 505e275..e470d2e 100644
--- a/src/arraytypes/dictencoding.jl
+++ b/src/arraytypes/dictencoding.jl
@@ -129,7 +129,7 @@ function arrowvector(::DictEncodedKind, x::DictEncoded, i, nl, fi, de, ded, meta
     else
         encodinglockable = de[id]
         @lock encodinglockable begin
-            encoding = encodinglockable.x
+            encoding = encodinglockable.value
             # in this case, we just need to check if any values in our local pool need to be delta dictionary serialized
             deltas = setdiff(x.encoding, encoding)
             if !isempty(deltas)
@@ -144,7 +144,7 @@ function arrowvector(::DictEncodedKind, x::DictEncoded, i, nl, fi, de, ded, meta
                 else
                     data2 = ChainedVector([encoding.data, data])
                     encoding = DictEncoding{eltype(data2), ET, typeof(data2)}(id, data2, false, getmetadata(encoding))
-                    de[id].x = encoding
+                    de[id] = Lockable(encoding)
                 end
             end
         end
@@ -196,7 +196,7 @@ function arrowvector(::DictEncodedKind, x, i, nl, fi, de, ded, meta; dictencode:
           # also add to deltas updates
         encodinglockable = de[id]
         @lock encodinglockable begin
-            encoding = encodinglockable.x
+            encoding = encodinglockable.value
             len = length(x)
             ET = indextype(encoding)
             pool = Dict{Union{eltype(encoding), eltype(x)}, ET}(a => (b - 1) for (b, a) in enumerate(encoding))
@@ -223,7 +223,7 @@ function arrowvector(::DictEncodedKind, x, i, nl, fi, de, ded, meta; dictencode:
                 else
                     data2 = ChainedVector([encoding.data, data])
                     encoding = DictEncoding{eltype(data2), ET, typeof(data2)}(id, data2, false, getmetadata(encoding))
-                    de[id].x = encoding
+                    de[id] = Lockable(encoding)
                 end
             end
         end
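
These hunks track the swap from the package-local Lockable (deleted from
src/utils.jl below), which exposed its contents as .x and was mutated in place,
to the WorkerUtilities Lockable, which exposes them as .value; rather than
mutating the field, the new code rebinds the dictionary slot to a fresh
Lockable. A minimal sketch of the new access pattern (the string payload stands
in for a DictEncoding, and @lock is assumed to be supplied by WorkerUtilities
now that the local macro is gone):

    using WorkerUtilities

    de = Dict{Int64,Any}(1 => Lockable("enc-v1"))  # id => Lockable(encoding)

    enclock = de[1]
    @lock enclock begin
        enc = enclock.value       # read the wrapped value while holding the lock
        newenc = enc * "+delta"   # stand-in for building a widened DictEncoding
        de[1] = Lockable(newenc)  # rebind the slot; the field itself isn't mutated
    end
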
diff --git a/src/table.jl b/src/table.jl
index 259bac2..7b6d8c8 100644
--- a/src/table.jl
+++ b/src/table.jl
@@ -275,11 +275,11 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true)
     sch = nothing
     dictencodings = Dict{Int64, DictEncoding}() # dictionary id => DictEncoding
     dictencoded = Dict{Int64, Meta.Field}() # dictionary id => field
-    tsks = Channel{Task}(Inf)
+    sync = OrderedSynchronizer()
+    tsks = Channel{Any}(Inf)
     tsk = Threads.@spawn begin
         i = 1
-        for tsk in tsks
-            cols = fetch(tsk)
+        for cols in tsks
             if i == 1
                 foreach(x -> push!(columns(t), x), cols)
             elseif i == 2
@@ -295,7 +295,8 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true)
         end
     end
     anyrecordbatches = false
-    for blob in blobs
+    rbi = 1
+    @sync for blob in blobs
         bytes, pos, len = blob.bytes, blob.pos, blob.len
         if len > 24 &&
             _startswith(bytes, pos, FILE_FORMAT_MAGIC_BYTES) &&
@@ -349,9 +350,11 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true)
             elseif header isa Meta.RecordBatch
                 anyrecordbatches = true
                 @debug 1 "parsing record batch message: compression = $(header.compression)"
-                put!(tsks, Threads.@spawn begin
-                    collect(VectorIterator(sch, batch, dictencodings, convert))
-                end)
+                Threads.@spawn begin
+                    cols = collect(VectorIterator(sch, $batch, dictencodings, convert))
+                    put!(() -> put!(tsks, cols), sync, $(rbi))
+                end
+                rbi += 1
             else
                 throw(ArgumentError("unsupported arrow message type: $(typeof(header))"))
             end
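
Two things change shape in table.jl: the consumer loop now reads finished column
sets straight off a Channel{Any} instead of fetch-ing Tasks in submission order,
and the new @sync block keeps the loop from outrunning its spawned parsers. The
$batch and $(rbi) interpolations copy the current loop values into each spawned
closure, so later iterations cannot overwrite them. Condensed, the shape is as
follows (parsebatch, handlecols, and batches are hypothetical stand-ins):

    sync = OrderedSynchronizer()
    tsks = Channel{Any}(Inf)
    consumer = Threads.@spawn for cols in tsks
        handlecols(cols)          # e.g. append the columns to the Table
    end
    rbi = 1                       # record-batch index doubles as the turn number
    @sync for batch in batches
        Threads.@spawn begin
            cols = parsebatch($batch)                 # runs concurrently
            put!(() -> put!(tsks, cols), sync, $rbi)  # released in batch order
        end
        rbi += 1
    end
    close(tsks)                   # safe: @sync guaranteed every delivery landed
    wait(consumer)
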
diff --git a/src/utils.jl b/src/utils.jl
index 7151579..223c6c2 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -127,67 +127,6 @@ function readmessage(filebytes, off=9)
     FlatBuffers.getrootas(Meta.Message, filebytes, off + 8)
 end
 
-# a custom Channel type that only allows put!-ing objects in a specific, monotonically increasing order
-struct OrderedChannel{T}
-    chan::Channel{T}
-    cond::Threads.Condition
-    i::Ref{Int}
-end
-
-OrderedChannel{T}(sz) where {T} = OrderedChannel{T}(Channel{T}(sz), Threads.Condition(), Ref(1))
-Base.iterate(ch::OrderedChannel, st...) = iterate(ch.chan, st...)
-
-macro lock(obj, expr)
-    esc(quote
-        lock($obj)
-        try
-            $expr
-        finally
-            unlock($obj)
-        end
-    end)
-end
-
-# when put!-ing an object, operation may have to wait until other tasks have put their
-# objects to ensure the channel is ordered correctly
-function Base.put!(ch::OrderedChannel{T}, x::T, i::Integer, incr::Bool=false) where {T}
-    @lock ch.cond begin
-        while ch.i[] < i
-            # channel index too early, need to wait for other tasks to put their objects first
-            wait(ch.cond)
-        end
-        # now it's our turn
-        put!(ch.chan, x)
-        if incr
-            ch.i[] += 1
-        end
-        # wake up tasks that may be waiting to put their objects
-        notify(ch.cond)
-    end
-    return
-end
-
-function Base.close(ch::OrderedChannel)
-    @lock ch.cond begin
-        # just need to ensure any tasks waiting to put their tasks have had a chance to put
-        while !isempty(ch.cond)
-            wait(ch.cond)
-        end
-        close(ch.chan)
-    end
-    return
-end
-
-mutable struct Lockable
-    x
-    lock::ReentrantLock
-end
-
-Lockable(x) = Lockable(x, ReentrantLock())
-
-Base.lock(x::Lockable) = lock(x.lock)
-Base.unlock(x::Lockable) = unlock(x.lock)
-
 function tobuffer(data; kwargs...)
     io = IOBuffer()
     write(io, data; kwargs...)
diff --git a/src/write.jl b/src/write.jl
index ae2da6d..9ee532b 100644
--- a/src/write.jl
+++ b/src/write.jl
@@ -122,7 +122,8 @@ mutable struct Writer{T<:IO}
     maxdepth::Int64
     meta::Union{Nothing,Base.ImmutableDict{String,String}}
     colmeta::Union{Nothing,Base.ImmutableDict{Symbol,Base.ImmutableDict{String,String}}}
-    msgs::OrderedChannel{Message}
+    sync::OrderedSynchronizer
+    msgs::Channel{Message}
     schema::Ref{Tables.Schema}
     firstcols::Ref{Any}
     dictencodings::Dict{Int64,Any}
@@ -135,7 +136,8 @@ mutable struct Writer{T<:IO}
 end
 
 function Base.open(::Type{Writer}, io::T, compress::Union{Nothing,LZ4FrameCompressor,<:AbstractVector{LZ4FrameCompressor},ZstdCompressor,<:AbstractVector{ZstdCompressor}}, writetofile::Bool, largelists::Bool, denseunions::Bool, dictencode::Bool, dictencodenested::Bool, alignment::Integer, maxdepth::Integer, ntasks::Integer, meta::Union{Nothing,Any}, colmeta::Union{Nothing,Any}, closeio::Bool) where {T<:IO}
-    msgs = OrderedChannel{Message}(ntasks)
+    sync = OrderedSynchronizer(2)
+    msgs = Channel{Message}(ntasks)
     schema = Ref{Tables.Schema}()
     firstcols = Ref{Any}()
     dictencodings = Dict{Int64,Any}() # Lockable{DictEncoding}
@@ -151,7 +153,7 @@ function Base.open(::Type{Writer}, io::T, compress::Union{Nothing,LZ4FrameCompre
     errorref = Ref{Any}()
     meta = _normalizemeta(meta)
     colmeta = _normalizecolmeta(colmeta)
-    return Writer{T}(io, closeio, compress, writetofile, largelists, denseunions, dictencode, dictencodenested, threaded, alignment, maxdepth, meta, colmeta, msgs, schema, firstcols, dictencodings, blocks, task, anyerror, errorref, 1, false)
+    return Writer{T}(io, closeio, compress, writetofile, largelists, denseunions, dictencode, dictencodenested, threaded, alignment, maxdepth, meta, colmeta, sync, msgs, schema, firstcols, dictencodings, blocks, task, anyerror, errorref, 1, false)
 end
 
 function Base.open(::Type{Writer}, io::IO, compress::Symbol, args...)
@@ -193,24 +195,24 @@ function write(writer::Writer, source)
             cols = toarrowtable(tblcols, writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, writer.dictencode, writer.dictencodenested, writer.maxdepth, meta, writer.colmeta)
             writer.schema[] = Tables.schema(cols)
             writer.firstcols[] = cols
-            put!(writer.msgs, makeschemamsg(writer.schema[], cols), writer.partition_count)
+            put!(writer.msgs, makeschemamsg(writer.schema[], cols))
             if !isempty(writer.dictencodings)
                 des = sort!(collect(writer.dictencodings); by=x -> x.first, rev=true)
                 for (id, delock) in des
                     # assign dict encoding ids
-                    de = delock.x
+                    de = delock.value
                     dictsch = Tables.Schema((:col,), (eltype(de.data),))
                     dictbatchmsg = makedictionarybatchmsg(dictsch, (col=de.data,), id, false, writer.alignment)
-                    put!(writer.msgs, dictbatchmsg, writer.partition_count)
+                    put!(writer.msgs, dictbatchmsg)
                 end
             end
             recbatchmsg = makerecordbatchmsg(writer.schema[], cols, writer.alignment)
-            put!(writer.msgs, recbatchmsg, writer.partition_count, true)
+            put!(writer.msgs, recbatchmsg)
         else
             if writer.threaded
-                Threads.@spawn process_partition(tblcols, writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, writer.dictencode, writer.dictencodenested, writer.maxdepth, writer.msgs, writer.alignment, $(writer.partition_count), writer.schema, writer.errorref, writer.anyerror, writer.meta, writer.colmeta)
+                Threads.@spawn process_partition(tblcols, writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, writer.dictencode, writer.dictencodenested, writer.maxdepth, writer.sync, writer.msgs, writer.alignment, $(writer.partition_count), writer.schema, writer.errorref, writer.anyerror, writer.meta, writer.colmeta)
             else
-                @async process_partition(tblcols, writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, writer.dictencode, writer.dictencodenested, writer.maxdepth, writer.msgs, writer.alignment, $(writer.partition_count), writer.schema, writer.errorref, writer.anyerror, writer.meta, writer.colmeta)
+                @async process_partition(tblcols, writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, writer.dictencode, writer.dictencodenested, writer.maxdepth, writer.sync, writer.msgs, writer.alignment, $(writer.partition_count), writer.schema, writer.errorref, writer.anyerror, writer.meta, writer.colmeta)
             end
         end
         writer.partition_count += 1
@@ -290,16 +292,23 @@ function write(io, source, writetofile, largelists, compress, denseunions, dicte
     io
 end
 
-function process_partition(cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
+function process_partition(cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, sync, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
     try
         cols = toarrowtable(cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, meta, colmeta)
+        dictmsgs = nothing
         if !isempty(cols.dictencodingdeltas)
+            dictmsgs = []
             for de in cols.dictencodingdeltas
                 dictsch = Tables.Schema((:col,), (eltype(de.data),))
-                put!(msgs, makedictionarybatchmsg(dictsch, (col=de.data,), de.id, true, alignment), i)
+                push!(dictmsgs, makedictionarybatchmsg(dictsch, (col=de.data,), de.id, true, alignment))
             end
         end
-        put!(msgs, makerecordbatchmsg(sch[], cols, alignment), i, true)
+        put!(sync, i) do
+            if !isnothing(dictmsgs)
+                foreach(msg -> put!(msgs, msg), dictmsgs)
+            end
+            put!(msgs, makerecordbatchmsg(sch[], cols, alignment))
+        end
     catch e
         errorref[] = (e, catch_backtrace(), i)
         anyerror[] = true
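
The old OrderedChannel allowed several put!-s at the same index i before a final
one bumped the counter (the incr flag); an OrderedSynchronizer runs one closure
per turn, so all of partition i's messages (the dictionary deltas plus the
record batch) are buffered locally and emitted inside a single closure. Note
also OrderedSynchronizer(2) in Base.open above: the first partition is written
synchronously in write, so ordered delivery begins at partition 2. A small
sketch of that turn-taking, again assuming the put!(f, sync, i) semantics this
diff relies on; the Symbol messages are stand-ins for Message values:

    using WorkerUtilities

    sync = OrderedSynchronizer(2)   # first turn is index 2
    msgs = Channel{Symbol}(Inf)     # stand-in for Channel{Message}
    @sync begin
        Threads.@spawn begin
            put!(sync, 3) do        # even if this runs first, it waits for turn 2
                put!(msgs, :dict3); put!(msgs, :batch3)
            end
        end
        Threads.@spawn begin
            put!(sync, 2) do        # releases turn 2, after which turn 3 runs
                put!(msgs, :batch2)
            end
        end
    end
    close(msgs)
    @assert collect(msgs) == [:batch2, :dict3, :batch3]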