You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/04/26 18:26:08 UTC
[arrow-datafusion] branch master updated: Fix HashJoin evaluating during plan (#2317)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 4fe59b983 Fix HashJoin evaluating during plan (#2317)
4fe59b983 is described below
commit 4fe59b9831bee756b9ae4589e53ab552e1a278f9
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Tue Apr 26 19:26:03 2022 +0100
Fix HashJoin evaluating during plan (#2317)
* Fix HashJoin evaluating during plan
* Fix partitioned hash join
* Update datafusion/core/src/physical_plan/hash_join.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
datafusion/core/src/physical_plan/cross_join.rs | 50 +---
datafusion/core/src/physical_plan/hash_join.rs | 382 ++++++++++++------------
datafusion/core/src/physical_plan/join_utils.rs | 113 +++++++
3 files changed, 315 insertions(+), 230 deletions(-)
diff --git a/datafusion/core/src/physical_plan/cross_join.rs b/datafusion/core/src/physical_plan/cross_join.rs
index 240c5dda4..2846af9c5 100644
--- a/datafusion/core/src/physical_plan/cross_join.rs
+++ b/datafusion/core/src/physical_plan/cross_join.rs
@@ -18,13 +18,12 @@
//! Defines the cross join plan for loading the left side of the cross join
//! and producing batches in parallel for the right partitions
-use futures::{ready, FutureExt, StreamExt};
+use futures::{ready, StreamExt};
use futures::{Stream, TryStreamExt};
-use parking_lot::Mutex;
use std::{any::Any, sync::Arc, task::Poll};
use arrow::datatypes::{Schema, SchemaRef};
-use arrow::error::{ArrowError, Result as ArrowResult};
+use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use super::expressions::PhysicalSortExpr;
@@ -34,7 +33,6 @@ use super::{
};
use crate::{error::Result, scalar::ScalarValue};
use async_trait::async_trait;
-use futures::future::{BoxFuture, Shared};
use std::time::Instant;
use super::{
@@ -42,16 +40,12 @@ use super::{
RecordBatchStream, SendableRecordBatchStream,
};
use crate::execution::context::TaskContext;
+use crate::physical_plan::join_utils::{OnceAsync, OnceFut};
use log::debug;
/// Data of the left side
type JoinLeftData = RecordBatch;
-/// Type of future for collecting left data
-///
-/// [`Shared`] allows potentially multiple output streams to poll the same future to completion
-type JoinLeftFut = Shared<BoxFuture<'static, Arc<Result<RecordBatch>>>>;
-
/// executes partitions in parallel and combines them into a set of
/// partitions by combining all values from the left with all values on the right
#[derive(Debug)]
@@ -63,11 +57,7 @@ pub struct CrossJoinExec {
/// The schema once the join is applied
schema: SchemaRef,
/// Build-side data
- ///
- /// Ideally we would instantiate this in the constructor, avoiding the need for a
- /// mutex and an option, but we need the [`TaskContext`] to evaluate the left
- /// side data, which is only provided in [`ExecutionPlan::execute`]
- left_fut: Mutex<Option<JoinLeftFut>>,
+ left_fut: OnceAsync<JoinLeftData>,
}
impl CrossJoinExec {
@@ -97,7 +87,7 @@ impl CrossJoinExec {
left,
right,
schema,
- left_fut: Mutex::new(None),
+ left_fut: Default::default(),
})
}
@@ -188,19 +178,11 @@ impl ExecutionPlan for CrossJoinExec {
let left_fut = self
.left_fut
- .lock()
- .get_or_insert_with(|| {
- load_left_input(self.left.clone(), context)
- .map(Arc::new)
- .boxed()
- .shared()
- })
- .clone();
+ .once(|| load_left_input(self.left.clone(), context));
Ok(Box::pin(CrossJoinStream {
schema: self.schema.clone(),
left_fut,
- left_result: None,
right: stream,
right_batch: Arc::new(parking_lot::Mutex::new(None)),
left_index: 0,
@@ -303,9 +285,7 @@ struct CrossJoinStream {
/// Input schema
schema: Arc<Schema>,
/// future for data from left side
- left_fut: JoinLeftFut,
- /// data from the left side
- left_result: Option<Arc<Result<RecordBatch>>>,
+ left_fut: OnceFut<JoinLeftData>,
/// right
right: SendableRecordBatchStream,
/// Current value on the left
@@ -375,21 +355,9 @@ impl CrossJoinStream {
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<ArrowResult<RecordBatch>>> {
- let left_result = match &self.left_result {
- Some(data) => data,
- None => {
- let result = ready!(self.left_fut.poll_unpin(cx));
- self.left_result.insert(result)
- }
- };
-
- let left_data = match left_result.as_ref() {
+ let left_data = match ready!(self.left_fut.get(cx)) {
Ok(left_data) => left_data,
- Err(e) => {
- return Poll::Ready(Some(Err(ArrowError::ExternalError(
- e.to_string().into(),
- ))))
- }
+ Err(e) => return Poll::Ready(Some(Err(e))),
};
if left_data.num_rows() == 0 {
diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs
index ce371fecf..31882c63c 100644
--- a/datafusion/core/src/physical_plan/hash_join.rs
+++ b/datafusion/core/src/physical_plan/hash_join.rs
@@ -35,8 +35,7 @@ use std::{any::Any, usize};
use std::{time::Instant, vec};
use async_trait::async_trait;
-use futures::{Stream, StreamExt, TryStreamExt};
-use tokio::sync::Mutex;
+use futures::{ready, Stream, StreamExt, TryStreamExt};
use arrow::array::{new_null_array, Array};
use arrow::datatypes::DataType;
@@ -74,8 +73,11 @@ use crate::arrow::datatypes::TimeUnit;
use crate::execution::context::TaskContext;
use crate::physical_plan::coalesce_batches::concat_batches;
use crate::physical_plan::PhysicalExpr;
+
+use crate::physical_plan::join_utils::{OnceAsync, OnceFut};
use log::debug;
use std::fmt;
+use std::task::Poll;
// Maps a `u64` hash value based on the left ["on" values] to a list of indices with this key's value.
//
@@ -97,7 +99,7 @@ impl fmt::Debug for JoinHashMap {
}
}
-type JoinLeftData = Arc<(JoinHashMap, RecordBatch)>;
+type JoinLeftData = (JoinHashMap, RecordBatch);
/// join execution plan executes partitions in parallel and combines them into a set of
/// partitions.
@@ -113,8 +115,8 @@ pub struct HashJoinExec {
join_type: JoinType,
/// The schema once the join is applied
schema: SchemaRef,
- /// Build-side
- build_side: Arc<Mutex<Option<JoinLeftData>>>,
+ /// Build-side data
+ left_fut: OnceAsync<JoinLeftData>,
/// Shares the `RandomState` for the hashing algorithm
random_state: RandomState,
/// Partitioning mode to use
@@ -208,7 +210,7 @@ impl HashJoinExec {
on,
join_type: *join_type,
schema: Arc::new(schema),
- build_side: Arc::new(Mutex::new(None)),
+ left_fut: Default::default(),
random_state,
mode: partition_mode,
metrics: ExecutionPlanMetricsSet::new(),
@@ -294,150 +296,44 @@ impl ExecutionPlan for HashJoinExec {
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
let on_left = self.on.iter().map(|on| on.0.clone()).collect::<Vec<_>>();
- // we only want to compute the build side once for PartitionMode::CollectLeft
- let left_data = {
- match self.mode {
- PartitionMode::CollectLeft => {
- let mut build_side = self.build_side.lock().await;
-
- match build_side.as_ref() {
- Some(stream) => stream.clone(),
- None => {
- let start = Instant::now();
-
- // merge all left parts into a single stream
- let merge = CoalescePartitionsExec::new(self.left.clone());
- let stream = merge.execute(0, context.clone()).await?;
-
- // This operation performs 2 steps at once:
- // 1. creates a [JoinHashMap] of all batches from the stream
- // 2. stores the batches in a vector.
- let initial = (0, Vec::new());
- let (num_rows, batches) = stream
- .try_fold(initial, |mut acc, batch| async {
- acc.0 += batch.num_rows();
- acc.1.push(batch);
- Ok(acc)
- })
- .await?;
- let mut hashmap =
- JoinHashMap(RawTable::with_capacity(num_rows));
- let mut hashes_buffer = Vec::new();
- let mut offset = 0;
- for batch in batches.iter() {
- hashes_buffer.clear();
- hashes_buffer.resize(batch.num_rows(), 0);
- update_hash(
- &on_left,
- batch,
- &mut hashmap,
- offset,
- &self.random_state,
- &mut hashes_buffer,
- )?;
- offset += batch.num_rows();
- }
- // Merge all batches into a single batch, so we
- // can directly index into the arrays
- let single_batch =
- concat_batches(&self.left.schema(), &batches, num_rows)?;
-
- let left_side = Arc::new((hashmap, single_batch));
-
- *build_side = Some(left_side.clone());
-
- debug!(
- "Built build-side of hash join containing {} rows in {} ms",
- num_rows,
- start.elapsed().as_millis()
- );
-
- left_side
- }
- }
- }
- PartitionMode::Partitioned => {
- let start = Instant::now();
-
- // Load 1 partition of left side in memory
- let stream = self.left.execute(partition, context.clone()).await?;
-
- // This operation performs 2 steps at once:
- // 1. creates a [JoinHashMap] of all batches from the stream
- // 2. stores the batches in a vector.
- let initial = (0, Vec::new());
- let (num_rows, batches) = stream
- .try_fold(initial, |mut acc, batch| async {
- acc.0 += batch.num_rows();
- acc.1.push(batch);
- Ok(acc)
- })
- .await?;
- let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows));
- let mut hashes_buffer = Vec::new();
- let mut offset = 0;
- for batch in batches.iter() {
- hashes_buffer.clear();
- hashes_buffer.resize(batch.num_rows(), 0);
- update_hash(
- &on_left,
- batch,
- &mut hashmap,
- offset,
- &self.random_state,
- &mut hashes_buffer,
- )?;
- offset += batch.num_rows();
- }
- // Merge all batches into a single batch, so we
- // can directly index into the arrays
- let single_batch =
- concat_batches(&self.left.schema(), &batches, num_rows)?;
-
- let left_side = Arc::new((hashmap, single_batch));
-
- debug!(
- "Built build-side {} of hash join containing {} rows in {} ms",
- partition,
- num_rows,
- start.elapsed().as_millis()
- );
+ let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
- left_side
- }
- }
+ let left_fut = match self.mode {
+ PartitionMode::CollectLeft => self.left_fut.once(|| {
+ collect_left_input(
+ self.random_state.clone(),
+ self.left.clone(),
+ on_left.clone(),
+ context.clone(),
+ )
+ }),
+ PartitionMode::Partitioned => OnceFut::new(partitioned_left_input(
+ partition,
+ self.random_state.clone(),
+ self.left.clone(),
+ on_left.clone(),
+ context.clone(),
+ )),
};
// we have the batches and the hash map with their keys. We can how create a stream
// over the right that uses this information to issue new batches.
+ let right_stream = self.right.execute(partition, context).await?;
- let right_stream = self.right.execute(partition, context.clone()).await?;
- let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
-
- let num_rows = left_data.1.num_rows();
- let visited_left_side = match self.join_type {
- JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => {
- let mut buffer = BooleanBufferBuilder::new(num_rows);
-
- buffer.append_n(num_rows, false);
-
- buffer
- }
- JoinType::Inner | JoinType::Right => BooleanBufferBuilder::new(0),
- };
- Ok(Box::pin(HashJoinStream::new(
- self.schema.clone(),
+ Ok(Box::pin(HashJoinStream {
+ schema: self.schema(),
on_left,
on_right,
- self.join_type,
- left_data,
- right_stream,
- self.column_indices.clone(),
- self.random_state.clone(),
- visited_left_side,
- HashJoinMetrics::new(partition, &self.metrics),
- self.null_equals_null,
- )))
+ join_type: self.join_type,
+ left_fut,
+ visited_left_side: None,
+ right: right_stream,
+ column_indices: self.column_indices.clone(),
+ random_state: self.random_state.clone(),
+ join_metrics: HashJoinMetrics::new(partition, &self.metrics),
+ null_equals_null: self.null_equals_null,
+ is_exhausted: false,
+ }))
}
fn fmt_as(
@@ -468,6 +364,116 @@ impl ExecutionPlan for HashJoinExec {
}
}
+async fn collect_left_input(
+ random_state: RandomState,
+ left: Arc<dyn ExecutionPlan>,
+ on_left: Vec<Column>,
+ context: Arc<TaskContext>,
+) -> Result<JoinLeftData> {
+ let schema = left.schema();
+ let start = Instant::now();
+
+ // merge all left parts into a single stream
+ let merge = CoalescePartitionsExec::new(left);
+ let stream = merge.execute(0, context).await?;
+
+ // This operation performs 2 steps at once:
+ // 1. creates a [JoinHashMap] of all batches from the stream
+ // 2. stores the batches in a vector.
+ let initial = (0, Vec::new());
+ let (num_rows, batches) = stream
+ .try_fold(initial, |mut acc, batch| async {
+ acc.0 += batch.num_rows();
+ acc.1.push(batch);
+ Ok(acc)
+ })
+ .await?;
+
+ let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows));
+ let mut hashes_buffer = Vec::new();
+ let mut offset = 0;
+ for batch in batches.iter() {
+ hashes_buffer.clear();
+ hashes_buffer.resize(batch.num_rows(), 0);
+ update_hash(
+ &on_left,
+ batch,
+ &mut hashmap,
+ offset,
+ &random_state,
+ &mut hashes_buffer,
+ )?;
+ offset += batch.num_rows();
+ }
+ // Merge all batches into a single batch, so we
+ // can directly index into the arrays
+ let single_batch = concat_batches(&schema, &batches, num_rows)?;
+
+ debug!(
+ "Built build-side of hash join containing {} rows in {} ms",
+ num_rows,
+ start.elapsed().as_millis()
+ );
+
+ Ok((hashmap, single_batch))
+}
+
+async fn partitioned_left_input(
+ partition: usize,
+ random_state: RandomState,
+ left: Arc<dyn ExecutionPlan>,
+ on_left: Vec<Column>,
+ context: Arc<TaskContext>,
+) -> Result<JoinLeftData> {
+ let schema = left.schema();
+
+ let start = Instant::now();
+
+ // Load 1 partition of left side in memory
+ let stream = left.execute(partition, context.clone()).await?;
+
+ // This operation performs 2 steps at once:
+ // 1. creates a [JoinHashMap] of all batches from the stream
+ // 2. stores the batches in a vector.
+ let initial = (0, Vec::new());
+ let (num_rows, batches) = stream
+ .try_fold(initial, |mut acc, batch| async {
+ acc.0 += batch.num_rows();
+ acc.1.push(batch);
+ Ok(acc)
+ })
+ .await?;
+
+ let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows));
+ let mut hashes_buffer = Vec::new();
+ let mut offset = 0;
+ for batch in batches.iter() {
+ hashes_buffer.clear();
+ hashes_buffer.resize(batch.num_rows(), 0);
+ update_hash(
+ &on_left,
+ batch,
+ &mut hashmap,
+ offset,
+ &random_state,
+ &mut hashes_buffer,
+ )?;
+ offset += batch.num_rows();
+ }
+ // Merge all batches into a single batch, so we
+ // can directly index into the arrays
+ let single_batch = concat_batches(&schema, &batches, num_rows)?;
+
+ debug!(
+ "Built build-side {} of hash join containing {} rows in {} ms",
+ partition,
+ num_rows,
+ start.elapsed().as_millis()
+ );
+
+ Ok((hashmap, single_batch))
+}
+
/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`,
/// assuming that the [RecordBatch] corresponds to the `index`th
fn update_hash(
@@ -515,14 +521,14 @@ struct HashJoinStream {
on_right: Vec<Column>,
/// type of the join
join_type: JoinType,
- /// information from the left
- left_data: JoinLeftData,
+ /// future for data from left side
+ left_fut: OnceFut<JoinLeftData>,
+ /// Keeps track of the left side rows whether they are visited
+ visited_left_side: Option<BooleanBufferBuilder>,
/// right
right: SendableRecordBatchStream,
/// Random state used for hashing initialization
random_state: RandomState,
- /// Keeps track of the left side rows whether they are visited
- visited_left_side: BooleanBufferBuilder,
/// There is nothing to process anymore and left side is processed in case of left join
is_exhausted: bool,
/// Metrics
@@ -533,38 +539,6 @@ struct HashJoinStream {
null_equals_null: bool,
}
-#[allow(clippy::too_many_arguments)]
-impl HashJoinStream {
- fn new(
- schema: Arc<Schema>,
- on_left: Vec<Column>,
- on_right: Vec<Column>,
- join_type: JoinType,
- left_data: JoinLeftData,
- right: SendableRecordBatchStream,
- column_indices: Vec<ColumnIndex>,
- random_state: RandomState,
- visited_left_side: BooleanBufferBuilder,
- join_metrics: HashJoinMetrics,
- null_equals_null: bool,
- ) -> Self {
- HashJoinStream {
- schema,
- on_left,
- on_right,
- join_type,
- left_data,
- right,
- column_indices,
- random_state,
- visited_left_side,
- is_exhausted: false,
- join_metrics,
- null_equals_null,
- }
- }
-}
-
impl RecordBatchStream for HashJoinStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
@@ -979,13 +953,32 @@ fn produce_from_matched(
RecordBatch::try_new(schema.clone(), columns)
}
-impl Stream for HashJoinStream {
- type Item = ArrowResult<RecordBatch>;
-
- fn poll_next(
- mut self: std::pin::Pin<&mut Self>,
+impl HashJoinStream {
+ /// Separate implementation function that unpins the [`HashJoinStream`] so
+ /// that partial borrows work correctly
+ fn poll_next_impl(
+ &mut self,
cx: &mut std::task::Context<'_>,
- ) -> std::task::Poll<Option<Self::Item>> {
+ ) -> std::task::Poll<Option<ArrowResult<RecordBatch>>> {
+ let left_data = match ready!(self.left_fut.get(cx)) {
+ Ok(left_data) => left_data,
+ Err(e) => return Poll::Ready(Some(Err(e))),
+ };
+
+ let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
+ let num_rows = left_data.1.num_rows();
+ match self.join_type {
+ JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => {
+ let mut buffer = BooleanBufferBuilder::new(num_rows);
+
+ buffer.append_n(num_rows, false);
+
+ buffer
+ }
+ JoinType::Inner | JoinType::Right => BooleanBufferBuilder::new(0),
+ }
+ });
+
self.right
.poll_next_unpin(cx)
.map(|maybe_batch| match maybe_batch {
@@ -993,7 +986,7 @@ impl Stream for HashJoinStream {
let timer = self.join_metrics.join_time.timer();
let result = build_batch(
&batch,
- &self.left_data,
+ left_data,
&self.on_left,
&self.on_right,
self.join_type,
@@ -1015,7 +1008,7 @@ impl Stream for HashJoinStream {
| JoinType::Semi
| JoinType::Anti => {
left_side.iter().flatten().for_each(|x| {
- self.visited_left_side.set_bit(x as usize, true);
+ visited_left_side.set_bit(x as usize, true);
});
}
JoinType::Inner | JoinType::Right => {}
@@ -1034,10 +1027,10 @@ impl Stream for HashJoinStream {
if !self.is_exhausted =>
{
let result = produce_from_matched(
- &self.visited_left_side,
+ visited_left_side,
&self.schema,
&self.column_indices,
- &self.left_data,
+ left_data,
self.join_type != JoinType::Semi,
);
if let Ok(ref batch) = result {
@@ -1066,6 +1059,17 @@ impl Stream for HashJoinStream {
}
}
+impl Stream for HashJoinStream {
+ type Item = ArrowResult<RecordBatch>;
+
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ self.poll_next_impl(cx)
+ }
+}
+
#[cfg(test)]
mod tests {
use crate::{
@@ -1959,7 +1963,7 @@ mod tests {
("c", &vec![30, 40]),
);
- let left_data = JoinLeftData::new((JoinHashMap(hashmap_left), left));
+ let left_data = (JoinHashMap(hashmap_left), left);
let (l, r) = build_join_indexes(
&left_data,
&right,
diff --git a/datafusion/core/src/physical_plan/join_utils.rs b/datafusion/core/src/physical_plan/join_utils.rs
index 8359bbc4e..71349ef14 100644
--- a/datafusion/core/src/physical_plan/join_utils.rs
+++ b/datafusion/core/src/physical_plan/join_utils.rs
@@ -21,7 +21,14 @@ use crate::error::{DataFusionError, Result};
use crate::logical_plan::JoinType;
use crate::physical_plan::expressions::Column;
use arrow::datatypes::{Field, Schema};
+use arrow::error::ArrowError;
+use futures::future::{BoxFuture, Shared};
+use futures::{ready, FutureExt};
+use parking_lot::Mutex;
use std::collections::HashSet;
+use std::future::Future;
+use std::sync::Arc;
+use std::task::{Context, Poll};
/// The on clause of the join, as vector of (left, right) columns.
pub type JoinOn = Vec<(Column, Column)>;
@@ -147,6 +154,112 @@ pub fn build_join_schema(
(Schema::new(fields), column_indices)
}
+/// A [`OnceAsync`] can be used to run an async closure once, with subsequent calls
+/// to [`OnceAsync::once`] returning a [`OnceFut`] to the same asynchronous computation
+///
+/// This is useful for joins where the results of one child are buffered in memory
+/// and shared across potentially multiple output partitions
+pub(crate) struct OnceAsync<T> {
+ fut: Mutex<Option<OnceFut<T>>>,
+}
+
+impl<T> Default for OnceAsync<T> {
+ fn default() -> Self {
+ Self {
+ fut: Mutex::new(None),
+ }
+ }
+}
+
+impl<T> std::fmt::Debug for OnceAsync<T> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "OnceAsync")
+ }
+}
+
+impl<T: 'static> OnceAsync<T> {
+ /// If this is the first call to this function on this object, will invoke
+ /// `f` to obtain a future and return a [`OnceFut`] referring to this
+ ///
+ /// If this is not the first call, will return a [`OnceFut`] referring
+ /// to the same future as was returned by the first call
+ pub(crate) fn once<F, Fut>(&self, f: F) -> OnceFut<T>
+ where
+ F: FnOnce() -> Fut,
+ Fut: Future<Output = Result<T>> + Send + 'static,
+ {
+ self.fut
+ .lock()
+ .get_or_insert_with(|| OnceFut::new(f()))
+ .clone()
+ }
+}
+
+/// The shared future type used internally within [`OnceAsync`]
+type OnceFutPending<T> = Shared<BoxFuture<'static, Arc<Result<T>>>>;
+
+/// A [`OnceFut`] represents a shared asynchronous computation, that will be evaluated
+/// once for all [`Clone`]'s, with [`OnceFut::get`] providing a non-consuming interface
+/// to drive the underlying [`Future`] to completion
+pub(crate) struct OnceFut<T> {
+ state: OnceFutState<T>,
+}
+
+impl<T> Clone for OnceFut<T> {
+ fn clone(&self) -> Self {
+ Self {
+ state: self.state.clone(),
+ }
+ }
+}
+
+enum OnceFutState<T> {
+ Pending(OnceFutPending<T>),
+ Ready(Arc<Result<T>>),
+}
+
+impl<T> Clone for OnceFutState<T> {
+ fn clone(&self) -> Self {
+ match self {
+ Self::Pending(p) => Self::Pending(p.clone()),
+ Self::Ready(r) => Self::Ready(r.clone()),
+ }
+ }
+}
+
+impl<T: 'static> OnceFut<T> {
+ /// Create a new [`OnceFut`] from a [`Future`]
+ pub(crate) fn new<Fut>(fut: Fut) -> Self
+ where
+ Fut: Future<Output = Result<T>> + Send + 'static,
+ {
+ Self {
+ state: OnceFutState::Pending(fut.map(Arc::new).boxed().shared()),
+ }
+ }
+
+ /// Get the result of the computation if it is ready, without consuming it
+ pub(crate) fn get(
+ &mut self,
+ cx: &mut Context<'_>,
+ ) -> Poll<std::result::Result<&T, ArrowError>> {
+ if let OnceFutState::Pending(fut) = &mut self.state {
+ let r = ready!(fut.poll_unpin(cx));
+ self.state = OnceFutState::Ready(r);
+ }
+
+ // Cannot use loop as this would trip up the borrow checker
+ match &self.state {
+ OnceFutState::Pending(_) => unreachable!(),
+ OnceFutState::Ready(r) => Poll::Ready(
+ r.as_ref()
+ .as_ref()
+ .map_err(|e| ArrowError::ExternalError(e.to_string().into())),
+ ),
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;