Posted to commits@arrow.apache.org by qu...@apache.org on 2022/10/06 08:18:00 UTC

[arrow-julia] 01/01: Use OrderedSynchronizer instead of OrderedChannel

This is an automated email from the ASF dual-hosted git repository.

quinnj pushed a commit to branch jq/orderedsync
in repository https://gitbox.apache.org/repos/asf/arrow-julia.git

commit 9779ee9b8b765d223ee35c7e4a3a9102bd7bccef
Author: Jacob Quinn <qu...@gmail.com>
AuthorDate: Thu Oct 6 02:17:52 2022 -0600

    Use OrderedSynchronizer instead of OrderedChannel
---
 Project.toml                   | 18 +------------
 src/Arrow.jl                   |  2 +-
 src/append.jl                  |  7 ++---
 src/arraytypes/dictencoding.jl |  8 +++---
 src/table.jl                   | 17 +++++++-----
 src/utils.jl                   | 61 ------------------------------------------
 src/write.jl                   | 33 ++++++++++++++---------
 7 files changed, 41 insertions(+), 105 deletions(-)
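
For context on the new dependency: WorkerUtilities.OrderedSynchronizer lets concurrent tasks run a closure in strict index order, which is the mechanism every hunk below leans on. The sketch that follows assumes the put!(f, sync, i) signature seen at the call sites in this diff, and that each completed put! advances the expected turn by one:

    using WorkerUtilities

    # Tasks finish in a scrambled order, but their results are delivered
    # to the channel in index order.
    sync = OrderedSynchronizer()     # first expected turn is index 1
    results = Channel{Int}(Inf)

    @sync for i in 1:8
        Threads.@spawn begin
            sleep(rand() / 10)       # simulate out-of-order completion
            # the closure runs only after turns 1..i-1 have completed
            put!(() -> put!(results, i * i), sync, i)
        end
    end
    close(results)
    @assert collect(results) == [1, 4, 9, 16, 25, 36, 49, 64]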

diff --git a/Project.toml b/Project.toml
index 60676a2..66fb563 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,20 +1,3 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
 name = "Arrow"
 uuid = "69666777-d1a9-59fb-9406-91d4454c9d45"
 authors = ["quinnj <qu...@gmail.com>"]
@@ -33,6 +16,7 @@ SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c"
 Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 TimeZones = "f269a46b-ccf7-5d73-abea-4c690281aa53"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+WorkerUtilities = "76eceee3-57b5-4d4a-8e66-0e911cebbf60"
 
 [compat]
 ArrowTypes = "1.1"
diff --git a/src/Arrow.jl b/src/Arrow.jl
index a623031..6eaa08e 100644
--- a/src/Arrow.jl
+++ b/src/Arrow.jl
@@ -44,7 +44,7 @@ module Arrow
 using Base.Iterators
 using Mmap
 import Dates
-using DataAPI, Tables, SentinelArrays, PooledArrays, CodecLz4, CodecZstd, TimeZones, BitIntegers
+using DataAPI, Tables, SentinelArrays, PooledArrays, CodecLz4, CodecZstd, TimeZones, BitIntegers, WorkerUtilities
 
 export ArrowTypes
 
diff --git a/src/append.jl b/src/append.jl
index 1d1d8ef..f6716f3 100644
--- a/src/append.jl
+++ b/src/append.jl
@@ -107,7 +107,8 @@ function append(io::IO, source, arrow_schema, compress, largelists, denseunions,
     skip(io, -8) # overwrite last 8 bytes of last empty message footer
 
     sch = Ref{Tables.Schema}(arrow_schema)
-    msgs = OrderedChannel{Message}(ntasks)
+    sync = OrderedSynchronizer()
+    msgs = Channel{Message}(ntasks)
     dictencodings = Dict{Int64, Any}() # Lockable{DictEncoding}
     # build messages
     blocks = (Block[], Block[])
@@ -134,9 +135,9 @@ function append(io::IO, source, arrow_schema, compress, largelists, denseunions,
         end
 
         if threaded
-            Threads.@spawn process_partition(tbl_cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
+            Threads.@spawn process_partition(tbl_cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, sync, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
         else
-            @async process_partition(tbl_cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
+            @async process_partition(tbl_cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, sync, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
         end
     end
     if anyerror[]
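
The extra `sync` argument threaded through `process_partition` is what keeps appended record batches in partition order: each partition is encoded on its own task, but its finished messages are handed to the now-plain `msgs` channel only when its turn comes. A rough model of that flow (string payloads stand in for real `Message`s, and the real code also threads through schema refs and error flags):

    using WorkerUtilities

    sync = OrderedSynchronizer()
    msgs = Channel{String}(4)

    # single order-sensitive consumer: it writes messages out as they
    # arrive, so producers must be sequenced upstream
    writer = Threads.@spawn for msg in msgs
        println("writing ", msg)
    end

    @sync for (i, part) in enumerate(["a", "b", "c", "d"])
        Threads.@spawn begin
            encoded = "record batch for partition $part"  # expensive step, in parallel
            put!(() -> put!(msgs, encoded), sync, i)      # delivery, in order
        end
    end
    close(msgs)
    wait(writer)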
diff --git a/src/arraytypes/dictencoding.jl b/src/arraytypes/dictencoding.jl
index 505e275..e470d2e 100644
--- a/src/arraytypes/dictencoding.jl
+++ b/src/arraytypes/dictencoding.jl
@@ -129,7 +129,7 @@ function arrowvector(::DictEncodedKind, x::DictEncoded, i, nl, fi, de, ded, meta
     else
         encodinglockable = de[id]
         @lock encodinglockable begin
-            encoding = encodinglockable.x
+            encoding = encodinglockable.value
             # in this case, we just need to check if any values in our local pool need to be delta dictionary serialized
             deltas = setdiff(x.encoding, encoding)
             if !isempty(deltas)
@@ -144,7 +144,7 @@ function arrowvector(::DictEncodedKind, x::DictEncoded, i, nl, fi, de, ded, meta
                 else
                     data2 = ChainedVector([encoding.data, data])
                     encoding = DictEncoding{eltype(data2), ET, typeof(data2)}(id, data2, false, getmetadata(encoding))
-                    de[id].x = encoding
+                    de[id] = Lockable(encoding)
                 end
             end
         end
@@ -196,7 +196,7 @@ function arrowvector(::DictEncodedKind, x, i, nl, fi, de, ded, meta; dictencode:
           # also add to deltas updates
         encodinglockable = de[id]
         @lock encodinglockable begin
-            encoding = encodinglockable.x
+            encoding = encodinglockable.value
             len = length(x)
             ET = indextype(encoding)
             pool = Dict{Union{eltype(encoding), eltype(x)}, ET}(a => (b - 1) for (b, a) in enumerate(encoding))
@@ -223,7 +223,7 @@ function arrowvector(::DictEncodedKind, x, i, nl, fi, de, ded, meta; dictencode:
                 else
                     data2 = ChainedVector([encoding.data, data])
                     encoding = DictEncoding{eltype(data2), ET, typeof(data2)}(id, data2, false, getmetadata(encoding))
-                    de[id].x = encoding
+                    de[id] = Lockable(encoding)
                 end
             end
         end
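
The `.x` to `.value` rename tracks the move from the package-local `Lockable` (deleted from src/utils.jl below) to `WorkerUtilities.Lockable`. Note the second change in each hunk: rather than reassigning the wrapped value via `de[id].x = encoding`, the code now stores a fresh `Lockable` under the key, which suggests the WorkerUtilities struct is not meant to be mutated in place. A small sketch of the resulting pattern (a `Vector{Int}` stands in for `DictEncoding`):

    using WorkerUtilities

    de = Dict{Int64,Any}()           # id => Lockable, as in the real code
    de[1] = Lockable(Int[])

    encodinglockable = de[1]
    lock(encodinglockable)
    try
        encoding = encodinglockable.value   # read under the lock
        grown = vcat(encoding, [10, 20])    # build a replacement value
        de[1] = Lockable(grown)             # swap in a new Lockable, don't mutate
    finally
        unlock(encodinglockable)
    end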
diff --git a/src/table.jl b/src/table.jl
index 259bac2..7b6d8c8 100644
--- a/src/table.jl
+++ b/src/table.jl
@@ -275,11 +275,11 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true)
     sch = nothing
     dictencodings = Dict{Int64, DictEncoding}() # dictionary id => DictEncoding
     dictencoded = Dict{Int64, Meta.Field}() # dictionary id => field
-    tsks = Channel{Task}(Inf)
+    sync = OrderedSynchronizer()
+    tsks = Channel{Any}(Inf)
     tsk = Threads.@spawn begin
         i = 1
-        for tsk in tsks
-            cols = fetch(tsk)
+        for cols in tsks
             if i == 1
                 foreach(x -> push!(columns(t), x), cols)
             elseif i == 2
@@ -295,7 +295,8 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true)
         end
     end
     anyrecordbatches = false
-    for blob in blobs
+    rbi = 1
+    @sync for blob in blobs
         bytes, pos, len = blob.bytes, blob.pos, blob.len
         if len > 24 &&
             _startswith(bytes, pos, FILE_FORMAT_MAGIC_BYTES) &&
@@ -349,9 +350,11 @@ function Table(blobs::Vector{ArrowBlob}; convert::Bool=true)
             elseif header isa Meta.RecordBatch
                 anyrecordbatches = true
                 @debug 1 "parsing record batch message: compression = $(header.compression)"
-                put!(tsks, Threads.@spawn begin
-                    collect(VectorIterator(sch, batch, dictencodings, convert))
-                end)
+                Threads.@spawn begin
+                    cols = collect(VectorIterator(sch, $batch, dictencodings, convert))
+                    put!(() -> put!(tsks, cols), sync, $(rbi))
+                end
+                rbi += 1
             else
                 throw(ArgumentError("unsupported arrow message type: $(typeof(header))"))
             end
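
Two details are worth calling out in this hunk. First, each record batch is now parsed on its own task and its columns are released to `tsks` in batch order via the synchronizer, so the consumer task above can keep appending chunks in file order. Second, the `$batch` / `$(rbi)` interpolations matter: `Threads.@spawn` snapshots interpolated values when the task is created, and both variables are reused or mutated by the enclosing loop after the task is scheduled. A minimal illustration of that snapshot behavior:

    function demo()
        counter = 0
        seen = Channel{Int}(Inf)
        @sync for _ in 1:3
            counter += 1
            # `$` copies counter's current value into the task; without it,
            # late-running tasks could all observe the final value 3
            Threads.@spawn put!(seen, $counter)
        end
        close(seen)
        return sort(collect(seen))
    end
    @assert demo() == [1, 2, 3]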
diff --git a/src/utils.jl b/src/utils.jl
index 7151579..223c6c2 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -127,67 +127,6 @@ function readmessage(filebytes, off=9)
     FlatBuffers.getrootas(Meta.Message, filebytes, off + 8)
 end
 
-# a custom Channel type that only allows put!-ing objects in a specific, monotonically increasing order
-struct OrderedChannel{T}
-    chan::Channel{T}
-    cond::Threads.Condition
-    i::Ref{Int}
-end
-
-OrderedChannel{T}(sz) where {T} = OrderedChannel{T}(Channel{T}(sz), Threads.Condition(), Ref(1))
-Base.iterate(ch::OrderedChannel, st...) = iterate(ch.chan, st...)
-
-macro lock(obj, expr)
-    esc(quote
-        lock($obj)
-        try
-            $expr
-        finally
-            unlock($obj)
-        end
-    end)
-end
-
-# when put!-ing an object, operation may have to wait until other tasks have put their
-# objects to ensure the channel is ordered correctly
-function Base.put!(ch::OrderedChannel{T}, x::T, i::Integer, incr::Bool=false) where {T}
-    @lock ch.cond begin
-        while ch.i[] < i
-            # channel index too early, need to wait for other tasks to put their objects first
-            wait(ch.cond)
-        end
-        # now it's our turn
-        put!(ch.chan, x)
-        if incr
-            ch.i[] += 1
-        end
-        # wake up tasks that may be waiting to put their objects
-        notify(ch.cond)
-    end
-    return
-end
-
-function Base.close(ch::OrderedChannel)
-    @lock ch.cond begin
-        # just need to ensure any tasks waiting to put their tasks have had a chance to put
-        while !isempty(ch.cond)
-            wait(ch.cond)
-        end
-        close(ch.chan)
-    end
-    return
-end
-
-mutable struct Lockable
-    x
-    lock::ReentrantLock
-end
-
-Lockable(x) = Lockable(x, ReentrantLock())
-
-Base.lock(x::Lockable) = lock(x.lock)
-Base.unlock(x::Lockable) = unlock(x.lock)
-
 function tobuffer(data; kwargs...)
     io = IOBuffer()
     write(io, data; kwargs...)
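
Everything deleted here has a one-for-one replacement: the ordering logic of `OrderedChannel` moves into `WorkerUtilities.OrderedSynchronizer` (with "what to deliver" generalized from a channel element to an arbitrary closure), while the local `@lock` macro and `Lockable` wrapper are superseded by their WorkerUtilities counterparts. The old and new call shapes line up roughly as below; the removed `incr` flag's job of advancing the turn is presumably handled by the synchronizer itself, one turn per completed put!:

    using WorkerUtilities

    sync = OrderedSynchronizer()
    msgs = Channel{Symbol}(2)

    # before: put!(msgs::OrderedChannel, x, i[, incr]) coupled storage and order
    # after:  a plain Channel stores, the synchronizer sequences
    put!(() -> put!(msgs, :first), sync, 1)    # turn 1 runs immediately
    put!(() -> put!(msgs, :second), sync, 2)   # turn 2 runs after turn 1
    @assert take!(msgs) === :first && take!(msgs) === :second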
diff --git a/src/write.jl b/src/write.jl
index ae2da6d..9ee532b 100644
--- a/src/write.jl
+++ b/src/write.jl
@@ -122,7 +122,8 @@ mutable struct Writer{T<:IO}
     maxdepth::Int64
     meta::Union{Nothing,Base.ImmutableDict{String,String}}
     colmeta::Union{Nothing,Base.ImmutableDict{Symbol,Base.ImmutableDict{String,String}}}
-    msgs::OrderedChannel{Message}
+    sync::OrderedSynchronizer
+    msgs::Channel{Message}
     schema::Ref{Tables.Schema}
     firstcols::Ref{Any}
     dictencodings::Dict{Int64,Any}
@@ -135,7 +136,8 @@ mutable struct Writer{T<:IO}
 end
 
 function Base.open(::Type{Writer}, io::T, compress::Union{Nothing,LZ4FrameCompressor,<:AbstractVector{LZ4FrameCompressor},ZstdCompressor,<:AbstractVector{ZstdCompressor}}, writetofile::Bool, largelists::Bool, denseunions::Bool, dictencode::Bool, dictencodenested::Bool, alignment::Integer, maxdepth::Integer, ntasks::Integer, meta::Union{Nothing,Any}, colmeta::Union{Nothing,Any}, closeio::Bool) where {T<:IO}
-    msgs = OrderedChannel{Message}(ntasks)
+    sync = OrderedSynchronizer(2)
+    msgs = Channel{Message}(ntasks)
     schema = Ref{Tables.Schema}()
     firstcols = Ref{Any}()
     dictencodings = Dict{Int64,Any}() # Lockable{DictEncoding}
@@ -151,7 +153,7 @@ function Base.open(::Type{Writer}, io::T, compress::Union{Nothing,LZ4FrameCompre
     errorref = Ref{Any}()
     meta = _normalizemeta(meta)
     colmeta = _normalizecolmeta(colmeta)
-    return Writer{T}(io, closeio, compress, writetofile, largelists, denseunions, dictencode, dictencodenested, threaded, alignment, maxdepth, meta, colmeta, msgs, schema, firstcols, dictencodings, blocks, task, anyerror, errorref, 1, false)
+    return Writer{T}(io, closeio, compress, writetofile, largelists, denseunions, dictencode, dictencodenested, threaded, alignment, maxdepth, meta, colmeta, sync, msgs, schema, firstcols, dictencodings, blocks, task, anyerror, errorref, 1, false)
 end
 
 function Base.open(::Type{Writer}, io::IO, compress::Symbol, args...)
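
One subtlety in the constructor above: the synchronizer is seeded with 2, not 1. The first partition is handled synchronously in `write` below (its schema, dictionary, and record batch messages go straight into `msgs`), so the first turn anyone waits for belongs to partition 2. The snippet assumes the integer argument sets the first expected index, which is what that numbering implies:

    using WorkerUtilities

    sync = OrderedSynchronizer(2)   # partition 1 never goes through sync
    put!(() -> println("partition 2 delivered"), sync, 2)  # runs at once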
@@ -193,24 +195,24 @@ function write(writer::Writer, source)
             cols = toarrowtable(tblcols, writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, writer.dictencode, writer.dictencodenested, writer.maxdepth, meta, writer.colmeta)
             writer.schema[] = Tables.schema(cols)
             writer.firstcols[] = cols
-            put!(writer.msgs, makeschemamsg(writer.schema[], cols), writer.partition_count)
+            put!(writer.msgs, makeschemamsg(writer.schema[], cols))
             if !isempty(writer.dictencodings)
                 des = sort!(collect(writer.dictencodings); by=x -> x.first, rev=true)
                 for (id, delock) in des
                     # assign dict encoding ids
-                    de = delock.x
+                    de = delock.value
                     dictsch = Tables.Schema((:col,), (eltype(de.data),))
                     dictbatchmsg = makedictionarybatchmsg(dictsch, (col=de.data,), id, false, writer.alignment)
-                    put!(writer.msgs, dictbatchmsg, writer.partition_count)
+                    put!(writer.msgs, dictbatchmsg)
                 end
             end
             recbatchmsg = makerecordbatchmsg(writer.schema[], cols, writer.alignment)
-            put!(writer.msgs, recbatchmsg, writer.partition_count, true)
+            put!(writer.msgs, recbatchmsg)
         else
             if writer.threaded
-                Threads.@spawn process_partition(tblcols, writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, writer.dictencode, writer.dictencodenested, writer.maxdepth, writer.msgs, writer.alignment, $(writer.partition_count), writer.schema, writer.errorref, writer.anyerror, writer.meta, writer.colmeta)
+                Threads.@spawn process_partition(tblcols, writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, writer.dictencode, writer.dictencodenested, writer.maxdepth, writer.sync, writer.msgs, writer.alignment, $(writer.partition_count), writer.schema, writer.errorref, writer.anyerror, writer.meta, writer.colmeta)
             else
-                @async process_partition(tblcols, writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, writer.dictencode, writer.dictencodenested, writer.maxdepth, writer.msgs, writer.alignment, $(writer.partition_count), writer.schema, writer.errorref, writer.anyerror, writer.meta, writer.colmeta)
+                @async process_partition(tblcols, writer.dictencodings, writer.largelists, writer.compress, writer.denseunions, writer.dictencode, writer.dictencodenested, writer.maxdepth, writer.sync, writer.msgs, writer.alignment, $(writer.partition_count), writer.schema, writer.errorref, writer.anyerror, writer.meta, writer.colmeta)
             end
         end
         writer.partition_count += 1
@@ -290,16 +292,23 @@ function write(io, source, writetofile, largelists, compress, denseunions, dicte
     io
 end
 
-function process_partition(cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
+function process_partition(cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, sync, msgs, alignment, i, sch, errorref, anyerror, meta, colmeta)
     try
         cols = toarrowtable(cols, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth, meta, colmeta)
+        dictmsgs = nothing
         if !isempty(cols.dictencodingdeltas)
+            dictmsgs = []
             for de in cols.dictencodingdeltas
                 dictsch = Tables.Schema((:col,), (eltype(de.data),))
-                put!(msgs, makedictionarybatchmsg(dictsch, (col=de.data,), de.id, true, alignment), i)
+                push!(dictmsgs, makedictionarybatchmsg(dictsch, (col=de.data,), de.id, true, alignment))
             end
         end
-        put!(msgs, makerecordbatchmsg(sch[], cols, alignment), i, true)
+        put!(sync, i) do
+            if !isnothing(dictmsgs)
+                foreach(msg -> put!(msgs, msg), dictmsgs)
+            end
+            put!(msgs, makerecordbatchmsg(sch[], cols, alignment))
+        end
     catch e
         errorref[] = (e, catch_backtrace(), i)
         anyerror[] = true
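
The shape of this last change is the heart of the commit: previously each dictionary-delta message was put! into the ordered channel as it was built, with the record batch's final put! advancing the turn; now all of a partition's messages are collected first and emitted inside a single synchronizer turn, so they land in `msgs` as one contiguous, partition-ordered group. A self-contained model of that delivery step:

    using WorkerUtilities

    sync = OrderedSynchronizer()
    msgs = Channel{String}(Inf)

    # all of partition i's messages enter `msgs` in one turn: dictionary
    # deltas first, then the record batch, never interleaved with another
    # partition's messages
    function deliver(i, dictmsgs, recbatch)
        put!(sync, i) do
            foreach(m -> put!(msgs, m), dictmsgs)
            put!(msgs, recbatch)
        end
    end

    @sync begin
        Threads.@spawn deliver(2, String[], "recbatch 2")
        Threads.@spawn deliver(1, ["dict delta 1"], "recbatch 1")
    end
    close(msgs)
    @assert collect(msgs) == ["dict delta 1", "recbatch 1", "recbatch 2"]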