Posted to commits@arrow.apache.org by al...@apache.org on 2022/04/26 18:26:08 UTC

[arrow-datafusion] branch master updated: Fix HashJoin evaluating during plan (#2317)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 4fe59b983 Fix HashJoin evaluating during plan (#2317)
4fe59b983 is described below

commit 4fe59b9831bee756b9ae4589e53ab552e1a278f9
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Tue Apr 26 19:26:03 2022 +0100

    Fix HashJoin evaluating during plan (#2317)
    
    * Fix HashJoin evaluating during plan
    
    * Fix partitioned hash join
    
    * Update datafusion/core/src/physical_plan/hash_join.rs
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
    
    Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
 datafusion/core/src/physical_plan/cross_join.rs |  50 +---
 datafusion/core/src/physical_plan/hash_join.rs  | 382 ++++++++++++------------
 datafusion/core/src/physical_plan/join_utils.rs | 113 +++++++
 3 files changed, 315 insertions(+), 230 deletions(-)
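
The gist of the change: both join operators previously collected their
build side eagerly inside `ExecutionPlan::execute`, before any output
stream existed. After this patch the build side is wrapped in a shared
future (the new `OnceAsync`/`OnceFut` in `join_utils.rs`) that runs only
when an output stream is first polled, and runs at most once no matter
how many partitions poll it. A minimal sketch of that pattern, assuming
only the `futures` and `tokio` crates (`share_build` is an illustrative
name, not DataFusion API):

    // Build-side work wrapped in Shared<BoxFuture>: creating it does nothing,
    // polling it runs the work once, and every clone observes the same result.
    use futures::future::{BoxFuture, FutureExt, Shared};
    use std::future::Future;
    use std::sync::Arc;

    type SharedBuild<T> = Shared<BoxFuture<'static, Arc<T>>>;

    fn share_build<T, Fut>(fut: Fut) -> SharedBuild<T>
    where
        T: 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        fut.map(Arc::new).boxed().shared()
    }

    #[tokio::main]
    async fn main() {
        // No work happens here ("plan" time).
        let build = share_build(async { vec![1, 2, 3] });

        // Two "output partitions" await the same handle; the async block
        // itself executes only once ("execute" time).
        let (a, b) = futures::join!(build.clone(), build);
        assert_eq!(a, b);
    }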

diff --git a/datafusion/core/src/physical_plan/cross_join.rs b/datafusion/core/src/physical_plan/cross_join.rs
index 240c5dda4..2846af9c5 100644
--- a/datafusion/core/src/physical_plan/cross_join.rs
+++ b/datafusion/core/src/physical_plan/cross_join.rs
@@ -18,13 +18,12 @@
 //! Defines the cross join plan for loading the left side of the cross join
 //! and producing batches in parallel for the right partitions
 
-use futures::{ready, FutureExt, StreamExt};
+use futures::{ready, StreamExt};
 use futures::{Stream, TryStreamExt};
-use parking_lot::Mutex;
 use std::{any::Any, sync::Arc, task::Poll};
 
 use arrow::datatypes::{Schema, SchemaRef};
-use arrow::error::{ArrowError, Result as ArrowResult};
+use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 
 use super::expressions::PhysicalSortExpr;
@@ -34,7 +33,6 @@ use super::{
 };
 use crate::{error::Result, scalar::ScalarValue};
 use async_trait::async_trait;
-use futures::future::{BoxFuture, Shared};
 use std::time::Instant;
 
 use super::{
@@ -42,16 +40,12 @@ use super::{
     RecordBatchStream, SendableRecordBatchStream,
 };
 use crate::execution::context::TaskContext;
+use crate::physical_plan::join_utils::{OnceAsync, OnceFut};
 use log::debug;
 
 /// Data of the left side
 type JoinLeftData = RecordBatch;
 
-/// Type of future for collecting left data
-///
-/// [`Shared`] allows potentially multiple output streams to poll the same future to completion
-type JoinLeftFut = Shared<BoxFuture<'static, Arc<Result<RecordBatch>>>>;
-
 /// executes partitions in parallel and combines them into a set of
 /// partitions by combining all values from the left with all values on the right
 #[derive(Debug)]
@@ -63,11 +57,7 @@ pub struct CrossJoinExec {
     /// The schema once the join is applied
     schema: SchemaRef,
     /// Build-side data
-    ///
-    /// Ideally we would instantiate this in the constructor, avoiding the need for a
-    /// mutex and an option, but we need the [`TaskContext`] to evaluate the left
-    /// side data, which is only provided in [`ExecutionPlan::execute`]
-    left_fut: Mutex<Option<JoinLeftFut>>,
+    left_fut: OnceAsync<JoinLeftData>,
 }
 
 impl CrossJoinExec {
@@ -97,7 +87,7 @@ impl CrossJoinExec {
             left,
             right,
             schema,
-            left_fut: Mutex::new(None),
+            left_fut: Default::default(),
         })
     }
 
@@ -188,19 +178,11 @@ impl ExecutionPlan for CrossJoinExec {
 
         let left_fut = self
             .left_fut
-            .lock()
-            .get_or_insert_with(|| {
-                load_left_input(self.left.clone(), context)
-                    .map(Arc::new)
-                    .boxed()
-                    .shared()
-            })
-            .clone();
+            .once(|| load_left_input(self.left.clone(), context));
 
         Ok(Box::pin(CrossJoinStream {
             schema: self.schema.clone(),
             left_fut,
-            left_result: None,
             right: stream,
             right_batch: Arc::new(parking_lot::Mutex::new(None)),
             left_index: 0,
@@ -303,9 +285,7 @@ struct CrossJoinStream {
     /// Input schema
     schema: Arc<Schema>,
     /// future for data from left side
-    left_fut: JoinLeftFut,
-    /// data from the left side
-    left_result: Option<Arc<Result<RecordBatch>>>,
+    left_fut: OnceFut<JoinLeftData>,
     /// right
     right: SendableRecordBatchStream,
     /// Current value on the left
@@ -375,21 +355,9 @@ impl CrossJoinStream {
         &mut self,
         cx: &mut std::task::Context<'_>,
     ) -> std::task::Poll<Option<ArrowResult<RecordBatch>>> {
-        let left_result = match &self.left_result {
-            Some(data) => data,
-            None => {
-                let result = ready!(self.left_fut.poll_unpin(cx));
-                self.left_result.insert(result)
-            }
-        };
-
-        let left_data = match left_result.as_ref() {
+        let left_data = match ready!(self.left_fut.get(cx)) {
             Ok(left_data) => left_data,
-            Err(e) => {
-                return Poll::Ready(Some(Err(ArrowError::ExternalError(
-                    e.to_string().into(),
-                ))))
-            }
+            Err(e) => return Poll::Ready(Some(Err(e))),
         };
 
         if left_data.num_rows() == 0 {
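
The cross join stream above now drives the build side from inside
`poll_next` via `ready!(self.left_fut.get(cx))` instead of hand-caching
an `Option<Arc<Result<RecordBatch>>>`. A self-contained sketch of that
stream shape, assuming the `futures` and `tokio` crates
(`BuildThenProbe` is an illustrative type, not DataFusion code):

    use futures::future::{BoxFuture, FutureExt, Shared};
    use futures::{ready, Stream, StreamExt};
    use std::pin::Pin;
    use std::sync::Arc;
    use std::task::{Context, Poll};

    struct BuildThenProbe {
        left_fut: Shared<BoxFuture<'static, Arc<Vec<i32>>>>,
        right: std::vec::IntoIter<i32>, // stand-in for the probe-side stream
    }

    impl Stream for BuildThenProbe {
        type Item = i32;

        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<i32>> {
            // Suspend here until the build side is ready; `ready!` returns
            // Poll::Pending upward if the future is not yet complete.
            let left = ready!(self.left_fut.poll_unpin(cx));
            let sum: i32 = left.iter().sum();
            Poll::Ready(self.right.next().map(|r| r + sum))
        }
    }

    #[tokio::main]
    async fn main() {
        let left_fut = async { vec![1, 2] }.map(Arc::new).boxed().shared();
        let mut s = BuildThenProbe {
            left_fut,
            right: vec![10, 20].into_iter(),
        };
        assert_eq!(s.next().await, Some(13)); // 10 + (1 + 2)
        assert_eq!(s.next().await, Some(23));
        assert_eq!(s.next().await, None);
    }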
diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs
index ce371fecf..31882c63c 100644
--- a/datafusion/core/src/physical_plan/hash_join.rs
+++ b/datafusion/core/src/physical_plan/hash_join.rs
@@ -35,8 +35,7 @@ use std::{any::Any, usize};
 use std::{time::Instant, vec};
 
 use async_trait::async_trait;
-use futures::{Stream, StreamExt, TryStreamExt};
-use tokio::sync::Mutex;
+use futures::{ready, Stream, StreamExt, TryStreamExt};
 
 use arrow::array::{new_null_array, Array};
 use arrow::datatypes::DataType;
@@ -74,8 +73,11 @@ use crate::arrow::datatypes::TimeUnit;
 use crate::execution::context::TaskContext;
 use crate::physical_plan::coalesce_batches::concat_batches;
 use crate::physical_plan::PhysicalExpr;
+
+use crate::physical_plan::join_utils::{OnceAsync, OnceFut};
 use log::debug;
 use std::fmt;
+use std::task::Poll;
 
 // Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value.
 //
@@ -97,7 +99,7 @@ impl fmt::Debug for JoinHashMap {
     }
 }
 
-type JoinLeftData = Arc<(JoinHashMap, RecordBatch)>;
+type JoinLeftData = (JoinHashMap, RecordBatch);
 
 /// join execution plan executes partitions in parallel and combines them into a set of
 /// partitions.
@@ -113,8 +115,8 @@ pub struct HashJoinExec {
     join_type: JoinType,
     /// The schema once the join is applied
     schema: SchemaRef,
-    /// Build-side
-    build_side: Arc<Mutex<Option<JoinLeftData>>>,
+    /// Build-side data
+    left_fut: OnceAsync<JoinLeftData>,
     /// Shares the `RandomState` for the hashing algorithm
     random_state: RandomState,
     /// Partitioning mode to use
@@ -208,7 +210,7 @@ impl HashJoinExec {
             on,
             join_type: *join_type,
             schema: Arc::new(schema),
-            build_side: Arc::new(Mutex::new(None)),
+            left_fut: Default::default(),
             random_state,
             mode: partition_mode,
             metrics: ExecutionPlanMetricsSet::new(),
@@ -294,150 +296,44 @@ impl ExecutionPlan for HashJoinExec {
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
         let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
-        // we only want to compute the build side once for PartitionMode::CollectLeft
-        let left_data = {
-            match self.mode {
-                PartitionMode::CollectLeft => {
-                    let mut build_side = self.build_side.lock().await;
-
-                    match build_side.as_ref() {
-                        Some(stream) => stream.clone(),
-                        None => {
-                            let start = Instant::now();
-
-                            // merge all left parts into a single stream
-                            let merge = CoalescePartitionsExec::new(self.left.clone());
-                            let stream = merge.execute(0, context.clone()).await?;
-
-                            // This operation performs 2 steps at once:
-                            // 1. creates a [JoinHashMap] of all batches from the stream
-                            // 2. stores the batches in a vector.
-                            let initial = (0, Vec::new());
-                            let (num_rows, batches) = stream
-                                .try_fold(initial, |mut acc, batch| async {
-                                    acc.0 += batch.num_rows();
-                                    acc.1.push(batch);
-                                    Ok(acc)
-                                })
-                                .await?;
-                            let mut hashmap =
-                                JoinHashMap(RawTable::with_capacity(num_rows));
-                            let mut hashes_buffer = Vec::new();
-                            let mut offset = 0;
-                            for batch in batches.iter() {
-                                hashes_buffer.clear();
-                                hashes_buffer.resize(batch.num_rows(), 0);
-                                update_hash(
-                                    &on_left,
-                                    batch,
-                                    &mut hashmap,
-                                    offset,
-                                    &self.random_state,
-                                    &mut hashes_buffer,
-                                )?;
-                                offset += batch.num_rows();
-                            }
-                            // Merge all batches into a single batch, so we
-                            // can directly index into the arrays
-                            let single_batch =
-                                concat_batches(&self.left.schema(), &batches, num_rows)?;
-
-                            let left_side = Arc::new((hashmap, single_batch));
-
-                            *build_side = Some(left_side.clone());
-
-                            debug!(
-                                "Built build-side of hash join containing {} rows in {} ms",
-                                num_rows,
-                                start.elapsed().as_millis()
-                            );
-
-                            left_side
-                        }
-                    }
-                }
-                PartitionMode::Partitioned => {
-                    let start = Instant::now();
-
-                    // Load 1 partition of left side in memory
-                    let stream = self.left.execute(partition, context.clone()).await?;
-
-                    // This operation performs 2 steps at once:
-                    // 1. creates a [JoinHashMap] of all batches from the stream
-                    // 2. stores the batches in a vector.
-                    let initial = (0, Vec::new());
-                    let (num_rows, batches) = stream
-                        .try_fold(initial, |mut acc, batch| async {
-                            acc.0 += batch.num_rows();
-                            acc.1.push(batch);
-                            Ok(acc)
-                        })
-                        .await?;
-                    let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows));
-                    let mut hashes_buffer = Vec::new();
-                    let mut offset = 0;
-                    for batch in batches.iter() {
-                        hashes_buffer.clear();
-                        hashes_buffer.resize(batch.num_rows(), 0);
-                        update_hash(
-                            &on_left,
-                            batch,
-                            &mut hashmap,
-                            offset,
-                            &self.random_state,
-                            &mut hashes_buffer,
-                        )?;
-                        offset += batch.num_rows();
-                    }
-                    // Merge all batches into a single batch, so we
-                    // can directly index into the arrays
-                    let single_batch =
-                        concat_batches(&self.left.schema(), &batches, num_rows)?;
-
-                    let left_side = Arc::new((hashmap, single_batch));
-
-                    debug!(
-                        "Built build-side {} of hash join containing {} rows in {} ms",
-                        partition,
-                        num_rows,
-                        start.elapsed().as_millis()
-                    );
+        let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
 
-                    left_side
-                }
-            }
+        let left_fut = match self.mode {
+            PartitionMode::CollectLeft => self.left_fut.once(|| {
+                collect_left_input(
+                    self.random_state.clone(),
+                    self.left.clone(),
+                    on_left.clone(),
+                    context.clone(),
+                )
+            }),
+            PartitionMode::Partitioned => OnceFut::new(partitioned_left_input(
+                partition,
+                self.random_state.clone(),
+                self.left.clone(),
+                on_left.clone(),
+                context.clone(),
+            )),
         };
 
         // we have the batches and the hash map with their keys. We can now create a stream
         // over the right that uses this information to issue new batches.
+        let right_stream = self.right.execute(partition, context).await?;
 
-        let right_stream = self.right.execute(partition, context.clone()).await?;
-        let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
-
-        let num_rows = left_data.1.num_rows();
-        let visited_left_side = match self.join_type {
-            JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => {
-                let mut buffer = BooleanBufferBuilder::new(num_rows);
-
-                buffer.append_n(num_rows, false);
-
-                buffer
-            }
-            JoinType::Inner | JoinType::Right => BooleanBufferBuilder::new(0),
-        };
-        Ok(Box::pin(HashJoinStream::new(
-            self.schema.clone(),
+        Ok(Box::pin(HashJoinStream {
+            schema: self.schema(),
             on_left,
             on_right,
-            self.join_type,
-            left_data,
-            right_stream,
-            self.column_indices.clone(),
-            self.random_state.clone(),
-            visited_left_side,
-            HashJoinMetrics::new(partition, &self.metrics),
-            self.null_equals_null,
-        )))
+            join_type: self.join_type,
+            left_fut,
+            visited_left_side: None,
+            right: right_stream,
+            column_indices: self.column_indices.clone(),
+            random_state: self.random_state.clone(),
+            join_metrics: HashJoinMetrics::new(partition, &self.metrics),
+            null_equals_null: self.null_equals_null,
+            is_exhausted: false,
+        }))
     }
 
     fn fmt_as(
@@ -468,6 +364,116 @@ impl ExecutionPlan for HashJoinExec {
     }
 }
 
+async fn collect_left_input(
+    random_state: RandomState,
+    left: Arc<dyn ExecutionPlan>,
+    on_left: Vec<Column>,
+    context: Arc<TaskContext>,
+) -> Result<JoinLeftData> {
+    let schema = left.schema();
+    let start = Instant::now();
+
+    // merge all left parts into a single stream
+    let merge = CoalescePartitionsExec::new(left);
+    let stream = merge.execute(0, context).await?;
+
+    // This operation performs 2 steps at once:
+    // 1. creates a [JoinHashMap] of all batches from the stream
+    // 2. stores the batches in a vector.
+    let initial = (0, Vec::new());
+    let (num_rows, batches) = stream
+        .try_fold(initial, |mut acc, batch| async {
+            acc.0 += batch.num_rows();
+            acc.1.push(batch);
+            Ok(acc)
+        })
+        .await?;
+
+    let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows));
+    let mut hashes_buffer = Vec::new();
+    let mut offset = 0;
+    for batch in batches.iter() {
+        hashes_buffer.clear();
+        hashes_buffer.resize(batch.num_rows(), 0);
+        update_hash(
+            &on_left,
+            batch,
+            &mut hashmap,
+            offset,
+            &random_state,
+            &mut hashes_buffer,
+        )?;
+        offset += batch.num_rows();
+    }
+    // Merge all batches into a single batch, so we
+    // can directly index into the arrays
+    let single_batch = concat_batches(&schema, &batches, num_rows)?;
+
+    debug!(
+        "Built build-side of hash join containing {} rows in {} ms",
+        num_rows,
+        start.elapsed().as_millis()
+    );
+
+    Ok((hashmap, single_batch))
+}
+
+async fn partitioned_left_input(
+    partition: usize,
+    random_state: RandomState,
+    left: Arc<dyn ExecutionPlan>,
+    on_left: Vec<Column>,
+    context: Arc<TaskContext>,
+) -> Result<JoinLeftData> {
+    let schema = left.schema();
+
+    let start = Instant::now();
+
+    // Load 1 partition of left side in memory
+    let stream = left.execute(partition, context.clone()).await?;
+
+    // This operation performs 2 steps at once:
+    // 1. creates a [JoinHashMap] of all batches from the stream
+    // 2. stores the batches in a vector.
+    let initial = (0, Vec::new());
+    let (num_rows, batches) = stream
+        .try_fold(initial, |mut acc, batch| async {
+            acc.0 += batch.num_rows();
+            acc.1.push(batch);
+            Ok(acc)
+        })
+        .await?;
+
+    let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows));
+    let mut hashes_buffer = Vec::new();
+    let mut offset = 0;
+    for batch in batches.iter() {
+        hashes_buffer.clear();
+        hashes_buffer.resize(batch.num_rows(), 0);
+        update_hash(
+            &on_left,
+            batch,
+            &mut hashmap,
+            offset,
+            &random_state,
+            &mut hashes_buffer,
+        )?;
+        offset += batch.num_rows();
+    }
+    // Merge all batches into a single batch, so we
+    // can directly index into the arrays
+    let single_batch = concat_batches(&schema, &batches, num_rows)?;
+
+    debug!(
+        "Built build-side {} of hash join containing {} rows in {} ms",
+        partition,
+        num_rows,
+        start.elapsed().as_millis()
+    );
+
+    Ok((hashmap, single_batch))
+}
+
 /// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`,
 /// assuming that the [RecordBatch] corresponds to the `index`th
 fn update_hash(
@@ -515,14 +521,14 @@ struct HashJoinStream {
     on_right: Vec<Column>,
     /// type of the join
     join_type: JoinType,
-    /// information from the left
-    left_data: JoinLeftData,
+    /// future for data from left side
+    left_fut: OnceFut<JoinLeftData>,
+    /// Keeps track of which left side rows have been visited
+    visited_left_side: Option<BooleanBufferBuilder>,
     /// right
     right: SendableRecordBatchStream,
     /// Random state used for hashing initialization
     random_state: RandomState,
-    /// Keeps track of the left side rows whether they are visited
-    visited_left_side: BooleanBufferBuilder,
     /// There is nothing to process anymore and left side is processed in case of left join
     is_exhausted: bool,
     /// Metrics
@@ -533,38 +539,6 @@ struct HashJoinStream {
     null_equals_null: bool,
 }
 
-#[allow(clippy::too_many_arguments)]
-impl HashJoinStream {
-    fn new(
-        schema: Arc<Schema>,
-        on_left: Vec<Column>,
-        on_right: Vec<Column>,
-        join_type: JoinType,
-        left_data: JoinLeftData,
-        right: SendableRecordBatchStream,
-        column_indices: Vec<ColumnIndex>,
-        random_state: RandomState,
-        visited_left_side: BooleanBufferBuilder,
-        join_metrics: HashJoinMetrics,
-        null_equals_null: bool,
-    ) -> Self {
-        HashJoinStream {
-            schema,
-            on_left,
-            on_right,
-            join_type,
-            left_data,
-            right,
-            column_indices,
-            random_state,
-            visited_left_side,
-            is_exhausted: false,
-            join_metrics,
-            null_equals_null,
-        }
-    }
-}
-
 impl RecordBatchStream for HashJoinStream {
     fn schema(&self) -> SchemaRef {
         self.schema.clone()
@@ -979,13 +953,32 @@ fn produce_from_matched(
     RecordBatch::try_new(schema.clone(), columns)
 }
 
-impl Stream for HashJoinStream {
-    type Item = ArrowResult<RecordBatch>;
-
-    fn poll_next(
-        mut self: std::pin::Pin<&mut Self>,
+impl HashJoinStream {
+    /// Separate implementation function that unpins the [`HashJoinStream`] so
+    /// that partial borrows work correctly
+    fn poll_next_impl(
+        &mut self,
         cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Option<Self::Item>> {
+    ) -> std::task::Poll<Option<ArrowResult<RecordBatch>>> {
+        let left_data = match ready!(self.left_fut.get(cx)) {
+            Ok(left_data) => left_data,
+            Err(e) => return Poll::Ready(Some(Err(e))),
+        };
+
+        let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
+            let num_rows = left_data.1.num_rows();
+            match self.join_type {
+                JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => {
+                    let mut buffer = BooleanBufferBuilder::new(num_rows);
+
+                    buffer.append_n(num_rows, false);
+
+                    buffer
+                }
+                JoinType::Inner | JoinType::Right => BooleanBufferBuilder::new(0),
+            }
+        });
+
         self.right
             .poll_next_unpin(cx)
             .map(|maybe_batch| match maybe_batch {
@@ -993,7 +986,7 @@ impl Stream for HashJoinStream {
                     let timer = self.join_metrics.join_time.timer();
                     let result = build_batch(
                         &batch,
-                        &self.left_data,
+                        left_data,
                         &self.on_left,
                         &self.on_right,
                         self.join_type,
@@ -1015,7 +1008,7 @@ impl Stream for HashJoinStream {
                             | JoinType::Semi
                             | JoinType::Anti => {
                                 left_side.iter().flatten().for_each(|x| {
-                                    self.visited_left_side.set_bit(x as usize, true);
+                                    visited_left_side.set_bit(x as usize, true);
                                 });
                             }
                             JoinType::Inner | JoinType::Right => {}
@@ -1034,10 +1027,10 @@ impl Stream for HashJoinStream {
                             if !self.is_exhausted =>
                         {
                             let result = produce_from_matched(
-                                &self.visited_left_side,
+                                visited_left_side,
                                 &self.schema,
                                 &self.column_indices,
-                                &self.left_data,
+                                left_data,
                                 self.join_type != JoinType::Semi,
                             );
                             if let Ok(ref batch) = result {
@@ -1066,6 +1059,17 @@ impl Stream for HashJoinStream {
     }
 }
 
+impl Stream for HashJoinStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        self.poll_next_impl(cx)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use crate::{
@@ -1959,7 +1963,7 @@ mod tests {
             ("c", &vec![30, 40]),
         );
 
-        let left_data = JoinLeftData::new((JoinHashMap(hashmap_left), left));
+        let left_data = (JoinHashMap(hashmap_left), left);
         let (l, r) = build_join_indexes(
             &left_data,
             &right,
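
Both `collect_left_input` and `partitioned_left_input` above buffer the
build side with the same single-pass `try_fold`, counting rows while
collecting batches so the hash table can be sized up front. A standalone
sketch of that accumulator, assuming the `futures` and `tokio` crates,
with `Vec<i32>` as a stand-in for `RecordBatch`:

    use futures::stream::{self, TryStreamExt};

    #[tokio::main]
    async fn main() -> Result<(), String> {
        // Stand-in for the left child's record batch stream.
        let batches = stream::iter(vec![Ok::<_, String>(vec![1, 2]), Ok(vec![3])]);

        // One pass: tally row counts and keep the batches for concatenation.
        let (num_rows, collected) = batches
            .try_fold((0usize, Vec::new()), |mut acc, batch| async move {
                acc.0 += batch.len();
                acc.1.push(batch);
                Ok(acc)
            })
            .await?;

        assert_eq!(num_rows, 3);
        assert_eq!(collected, vec![vec![1, 2], vec![3]]);
        Ok(())
    }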
diff --git a/datafusion/core/src/physical_plan/join_utils.rs b/datafusion/core/src/physical_plan/join_utils.rs
index 8359bbc4e..71349ef14 100644
--- a/datafusion/core/src/physical_plan/join_utils.rs
+++ b/datafusion/core/src/physical_plan/join_utils.rs
@@ -21,7 +21,14 @@ use crate::error::{DataFusionError, Result};
 use crate::logical_plan::JoinType;
 use crate::physical_plan::expressions::Column;
 use arrow::datatypes::{Field, Schema};
+use arrow::error::ArrowError;
+use futures::future::{BoxFuture, Shared};
+use futures::{ready, FutureExt};
+use parking_lot::Mutex;
 use std::collections::HashSet;
+use std::future::Future;
+use std::sync::Arc;
+use std::task::{Context, Poll};
 
 /// The on clause of the join, as vector of (left, right) columns.
 pub type JoinOn = Vec<(Column, Column)>;
@@ -147,6 +154,112 @@ pub fn build_join_schema(
     (Schema::new(fields), column_indices)
 }
 
+/// A [`OnceAsync`] can be used to run an async closure once, with subsequent calls
+/// to [`OnceAsync::once`] returning a [`OnceFut`] to the same asynchronous computation
+///
+/// This is useful for joins where the results of one child are buffered in memory
+/// and shared across potentially multiple output partitions
+pub(crate) struct OnceAsync<T> {
+    fut: Mutex<Option<OnceFut<T>>>,
+}
+
+impl<T> Default for OnceAsync<T> {
+    fn default() -> Self {
+        Self {
+            fut: Mutex::new(None),
+        }
+    }
+}
+
+impl<T> std::fmt::Debug for OnceAsync<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "OnceAsync")
+    }
+}
+
+impl<T: 'static> OnceAsync<T> {
+    /// If this is the first call to this function on this object, will invoke
+    /// `f` to obtain a future and return a [`OnceFut`] referring to this
+    ///
+    /// If this is not the first call, will return a [`OnceFut`] referring
+    /// to the same future as was returned by the first call
+    pub(crate) fn once<F, Fut>(&self, f: F) -> OnceFut<T>
+    where
+        F: FnOnce() -> Fut,
+        Fut: Future<Output = Result<T>> + Send + 'static,
+    {
+        self.fut
+            .lock()
+            .get_or_insert_with(|| OnceFut::new(f()))
+            .clone()
+    }
+}
+
+/// The shared future type used internally within [`OnceAsync`]
+type OnceFutPending<T> = Shared<BoxFuture<'static, Arc<Result<T>>>>;
+
+/// A [`OnceFut`] represents a shared asynchronous computation that will be evaluated
+/// once for all of its [`Clone`]s, with [`OnceFut::get`] providing a non-consuming interface
+/// to drive the underlying [`Future`] to completion
+pub(crate) struct OnceFut<T> {
+    state: OnceFutState<T>,
+}
+
+impl<T> Clone for OnceFut<T> {
+    fn clone(&self) -> Self {
+        Self {
+            state: self.state.clone(),
+        }
+    }
+}
+
+enum OnceFutState<T> {
+    Pending(OnceFutPending<T>),
+    Ready(Arc<Result<T>>),
+}
+
+impl<T> Clone for OnceFutState<T> {
+    fn clone(&self) -> Self {
+        match self {
+            Self::Pending(p) => Self::Pending(p.clone()),
+            Self::Ready(r) => Self::Ready(r.clone()),
+        }
+    }
+}
+
+impl<T: 'static> OnceFut<T> {
+    /// Create a new [`OnceFut`] from a [`Future`]
+    pub(crate) fn new<Fut>(fut: Fut) -> Self
+    where
+        Fut: Future<Output = Result<T>> + Send + 'static,
+    {
+        Self {
+            state: OnceFutState::Pending(fut.map(Arc::new).boxed().shared()),
+        }
+    }
+
+    /// Get the result of the computation if it is ready, without consuming it
+    pub(crate) fn get(
+        &mut self,
+        cx: &mut Context<'_>,
+    ) -> Poll<std::result::Result<&T, ArrowError>> {
+        if let OnceFutState::Pending(fut) = &mut self.state {
+            let r = ready!(fut.poll_unpin(cx));
+            self.state = OnceFutState::Ready(r);
+        }
+
+        // Cannot use loop as this would trip up the borrow checker
+        match &self.state {
+            OnceFutState::Pending(_) => unreachable!(),
+            OnceFutState::Ready(r) => Poll::Ready(
+                r.as_ref()
+                    .as_ref()
+                    .map_err(|e| ArrowError::ExternalError(e.to_string().into())),
+            ),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
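
Since `OnceAsync`/`OnceFut` are the heart of the fix, a unit test along
the following lines could pin down the "evaluated once, shared by all
clones" contract. A hedged sketch, assuming it sits inside this
`join_utils` test module with tokio available as a dev-dependency
(`once_fut_shared` is an illustrative name, not part of the patch):

    use futures::future::poll_fn;

    #[tokio::test]
    async fn once_fut_shared() {
        // Two handles to the same underlying computation.
        let mut a: OnceFut<i32> = OnceFut::new(async { Ok(1) });
        let mut b = a.clone();

        // `get` is poll-based, so drive it with poll_fn; copy the &i32 out
        // so the closure does not return a borrow of `a`.
        let va = poll_fn(|cx| a.get(cx).map(|r| r.map(|v| *v))).await.unwrap();
        let vb = poll_fn(|cx| b.get(cx).map(|r| r.map(|v| *v))).await.unwrap();
        assert_eq!((va, vb), (1, 1));
    }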