You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/05/31 15:19:41 UTC
[arrow-datafusion] branch master updated: Pass SessionState to TableProvider::scan (#2660)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 8f514f974 Pass SessionState to TableProvider::scan (#2660)
8f514f974 is described below
commit 8f514f9748a7526ad0d1e55c8fe3772c93c14726
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Tue May 31 16:19:36 2022 +0100
Pass SessionState to TableProvider::scan (#2660)
* Pass SessionState to TableProvider::scan
* Update ballista pin
---
benchmarks/src/bin/tpch.rs | 4 +--
datafusion-examples/examples/custom_datasource.rs | 3 +-
datafusion/core/benches/sort_limit_query_sql.rs | 4 +--
datafusion/core/src/dataframe.rs | 2 ++
datafusion/core/src/datasource/datasource.rs | 2 ++
datafusion/core/src/datasource/empty.rs | 2 ++
datafusion/core/src/datasource/listing/table.rs | 14 ++++++++--
datafusion/core/src/datasource/memory.rs | 34 ++++++++++++++++-------
datafusion/core/src/datasource/view.rs | 11 ++++----
datafusion/core/src/execution/context.rs | 11 +++++---
datafusion/core/src/physical_plan/planner.rs | 2 +-
datafusion/core/tests/custom_sources.rs | 3 +-
datafusion/core/tests/provider_filter_pushdown.rs | 3 +-
datafusion/core/tests/sql/information_schema.rs | 2 ++
datafusion/core/tests/sql/window.rs | 2 +-
datafusion/core/tests/statistics.rs | 3 +-
dev/build-arrow-ballista.sh | 2 +-
17 files changed, 70 insertions(+), 34 deletions(-)
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index c46badd64..4e49bff09 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -178,9 +178,9 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result<Vec<RecordB
if opt.mem_table {
println!("Loading table '{}' into memory", table);
let start = Instant::now();
- let task_ctx = ctx.task_ctx();
let memtable =
- MemTable::load(table_provider, Some(opt.partitions), task_ctx).await?;
+ MemTable::load(table_provider, Some(opt.partitions), &ctx.state())
+ .await?;
println!(
"Loaded table '{}' into memory in {} ms",
table,
diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs
index a814e585e..b68936a7c 100644
--- a/datafusion-examples/examples/custom_datasource.rs
+++ b/datafusion-examples/examples/custom_datasource.rs
@@ -22,7 +22,7 @@ use datafusion::arrow::record_batch::RecordBatch;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result;
-use datafusion::execution::context::TaskContext;
+use datafusion::execution::context::{SessionState, TaskContext};
use datafusion::logical_plan::{provider_as_source, Expr, LogicalPlanBuilder};
use datafusion::physical_plan::expressions::PhysicalSortExpr;
use datafusion::physical_plan::memory::MemoryStream;
@@ -175,6 +175,7 @@ impl TableProvider for CustomDataSource {
async fn scan(
&self,
+ _state: &SessionState,
projection: &Option<Vec<usize>>,
// filters and limit can be used here to inject some push-down operations if needed
_filters: &[Expr],
diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs
index d1f253a98..198eb941f 100644
--- a/datafusion/core/benches/sort_limit_query_sql.rs
+++ b/datafusion/core/benches/sort_limit_query_sql.rs
@@ -89,8 +89,8 @@ fn create_context() -> Arc<Mutex<SessionContext>> {
let ctx = SessionContext::new();
ctx.state.write().config.target_partitions = 1;
- let task_ctx = ctx.task_ctx();
- let mem_table = MemTable::load(Arc::new(csv.await), Some(partitions), task_ctx)
+ let table_provider = Arc::new(csv.await);
+ let mem_table = MemTable::load(table_provider, Some(partitions), &ctx.state())
.await
.unwrap();
ctx.register_table("aggregate_test_100", Arc::new(mem_table))
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index c8e0eef30..7692a187e 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -615,6 +615,7 @@ impl DataFrame {
}
}
+// TODO: This will introduce a ref cycle (#2659)
#[async_trait]
impl TableProvider for DataFrame {
fn as_any(&self) -> &dyn Any {
@@ -632,6 +633,7 @@ impl TableProvider for DataFrame {
async fn scan(
&self,
+ _ctx: &SessionState,
projection: &Option<Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs
index 8ab254525..17b288bef 100644
--- a/datafusion/core/src/datasource/datasource.rs
+++ b/datafusion/core/src/datasource/datasource.rs
@@ -25,6 +25,7 @@ pub use datafusion_expr::{TableProviderFilterPushDown, TableType};
use crate::arrow::datatypes::SchemaRef;
use crate::error::Result;
+use crate::execution::context::SessionState;
use crate::logical_plan::Expr;
use crate::physical_plan::ExecutionPlan;
@@ -47,6 +48,7 @@ pub trait TableProvider: Sync + Send {
/// parallelized or distributed.
async fn scan(
&self,
+ ctx: &SessionState,
projection: &Option<Vec<usize>>,
filters: &[Expr],
// limit can be used to reduce the amount scanned
diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs
index 837cd7704..3bc7a958c 100644
--- a/datafusion/core/src/datasource/empty.rs
+++ b/datafusion/core/src/datasource/empty.rs
@@ -25,6 +25,7 @@ use async_trait::async_trait;
use crate::datasource::{TableProvider, TableType};
use crate::error::Result;
+use crate::execution::context::SessionState;
use crate::logical_plan::Expr;
use crate::physical_plan::project_schema;
use crate::physical_plan::{empty::EmptyExec, ExecutionPlan};
@@ -57,6 +58,7 @@ impl TableProvider for EmptyTable {
async fn scan(
&self,
+ _ctx: &SessionState,
projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs
index 34e44971d..bc7ed2604 100644
--- a/datafusion/core/src/datasource/listing/table.rs
+++ b/datafusion/core/src/datasource/listing/table.rs
@@ -35,6 +35,7 @@ use crate::datasource::{
use crate::logical_expr::TableProviderFilterPushDown;
use crate::{
error::{DataFusionError, Result},
+ execution::context::SessionState,
logical_plan::Expr,
physical_plan::{
empty::EmptyExec,
@@ -302,6 +303,7 @@ impl TableProvider for ListingTable {
async fn scan(
&self,
+ _ctx: &SessionState,
projection: &Option<Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
@@ -405,6 +407,7 @@ impl ListingTable {
#[cfg(test)]
mod tests {
use crate::datasource::file_format::avro::DEFAULT_AVRO_EXTENSION;
+ use crate::prelude::SessionContext;
use crate::{
datafusion_data_access::object_store::local::LocalFileSystem,
datasource::file_format::{avro::AvroFormat, parquet::ParquetFormat},
@@ -417,10 +420,12 @@ mod tests {
#[tokio::test]
async fn read_single_file() -> Result<()> {
+ let ctx = SessionContext::new();
+
let table = load_table("alltypes_plain.parquet").await?;
let projection = None;
let exec = table
- .scan(&projection, &[], None)
+ .scan(&ctx.state(), &projection, &[], None)
.await
.expect("Scan table");
@@ -447,7 +452,9 @@ mod tests {
.with_listing_options(opt)
.with_schema(schema);
let table = ListingTable::try_new(config)?;
- let exec = table.scan(&None, &[], None).await?;
+
+ let ctx = SessionContext::new();
+ let exec = table.scan(&ctx.state(), &None, &[], None).await?;
assert_eq!(exec.statistics().num_rows, Some(8));
assert_eq!(exec.statistics().total_byte_size, Some(671));
@@ -483,8 +490,9 @@ mod tests {
// this will filter out the only file in the store
let filter = Expr::not_eq(col("p1"), lit("v1"));
+ let ctx = SessionContext::new();
let scan = table
- .scan(&None, &[filter], None)
+ .scan(&ctx.state(), &None, &[filter], None)
.await
.expect("Empty execution plan");
diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index adc26d2f4..62dca1ea0 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -29,7 +29,7 @@ use async_trait::async_trait;
use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
-use crate::execution::context::TaskContext;
+use crate::execution::context::{SessionState, TaskContext};
use crate::logical_plan::Expr;
use crate::physical_plan::common;
use crate::physical_plan::memory::MemoryExec;
@@ -65,18 +65,18 @@ impl MemTable {
pub async fn load(
t: Arc<dyn TableProvider>,
output_partitions: Option<usize>,
- context: Arc<TaskContext>,
+ ctx: &SessionState,
) -> Result<Self> {
let schema = t.schema();
- let exec = t.scan(&None, &[], None).await?;
+ let exec = t.scan(ctx, &None, &[], None).await?;
let partition_count = exec.output_partitioning().partition_count();
let tasks = (0..partition_count)
.map(|part_i| {
- let context1 = context.clone();
+ let task = Arc::new(TaskContext::from(ctx));
let exec = exec.clone();
tokio::spawn(async move {
- let stream = exec.execute(part_i, context1.clone())?;
+ let stream = exec.execute(part_i, task)?;
common::collect(stream).await
})
})
@@ -103,7 +103,8 @@ impl MemTable {
let mut output_partitions = vec![];
for i in 0..exec.output_partitioning().partition_count() {
// execute this *output* partition and collect all batches
- let mut stream = exec.execute(i, context.clone())?;
+ let task_ctx = Arc::new(TaskContext::from(ctx));
+ let mut stream = exec.execute(i, task_ctx)?;
let mut batches = vec![];
while let Some(result) = stream.next().await {
batches.push(result?);
@@ -133,6 +134,7 @@ impl TableProvider for MemTable {
async fn scan(
&self,
+ _ctx: &SessionState,
projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
@@ -180,7 +182,10 @@ mod tests {
let provider = MemTable::try_new(schema, vec![vec![batch]])?;
// scan with projection
- let exec = provider.scan(&Some(vec![2, 1]), &[], None).await?;
+ let exec = provider
+ .scan(&session_ctx.state(), &Some(vec![2, 1]), &[], None)
+ .await?;
+
let mut it = exec.execute(0, task_ctx)?;
let batch2 = it.next().await.unwrap()?;
assert_eq!(2, batch2.schema().fields().len());
@@ -212,7 +217,9 @@ mod tests {
let provider = MemTable::try_new(schema, vec![vec![batch]])?;
- let exec = provider.scan(&None, &[], None).await?;
+ let exec = provider
+ .scan(&session_ctx.state(), &None, &[], None)
+ .await?;
let mut it = exec.execute(0, task_ctx)?;
let batch1 = it.next().await.unwrap()?;
assert_eq!(3, batch1.schema().fields().len());
@@ -223,6 +230,8 @@ mod tests {
#[tokio::test]
async fn test_invalid_projection() -> Result<()> {
+ let session_ctx = SessionContext::new();
+
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
@@ -242,7 +251,10 @@ mod tests {
let projection: Vec<usize> = vec![0, 4];
- match provider.scan(&Some(projection), &[], None).await {
+ match provider
+ .scan(&session_ctx.state(), &Some(projection), &[], None)
+ .await
+ {
Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => {
assert_eq!(
"\"project index 4 out of bounds, max field 3\"",
@@ -368,7 +380,9 @@ mod tests {
let provider =
MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;
- let exec = provider.scan(&None, &[], None).await?;
+ let exec = provider
+ .scan(&session_ctx.state(), &None, &[], None)
+ .await?;
let mut it = exec.execute(0, task_ctx)?;
let batch1 = it.next().await.unwrap()?;
assert_eq!(3, batch1.schema().fields().len());
diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs
index 3db76cee1..18a43d3d4 100644
--- a/datafusion/core/src/datasource/view.rs
+++ b/datafusion/core/src/datasource/view.rs
@@ -24,17 +24,15 @@ use async_trait::async_trait;
use crate::{
error::Result,
- execution::context::SessionContext,
logical_plan::{Expr, LogicalPlan},
physical_plan::ExecutionPlan,
};
use crate::datasource::{TableProvider, TableType};
+use crate::execution::context::SessionState;
/// An implementation of `TableProvider` that uses another logical plan.
pub struct ViewTable {
- /// To create ExecutionPlan
- context: SessionContext,
/// LogicalPlan of the view
logical_plan: LogicalPlan,
/// File fields + partition columns
@@ -44,11 +42,10 @@ pub struct ViewTable {
impl ViewTable {
/// Create new view that is executed at query runtime.
/// Takes a `LogicalPlan` as input.
- pub fn try_new(context: SessionContext, logical_plan: LogicalPlan) -> Result<Self> {
+ pub fn try_new(logical_plan: LogicalPlan) -> Result<Self> {
let table_schema = logical_plan.schema().as_ref().to_owned().into();
let view = Self {
- context,
logical_plan,
table_schema,
};
@@ -73,16 +70,18 @@ impl TableProvider for ViewTable {
async fn scan(
&self,
+ ctx: &SessionState,
_projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
- self.context.create_physical_plan(&self.logical_plan).await
+ ctx.create_physical_plan(&self.logical_plan).await
}
}
#[cfg(test)]
mod tests {
+ use crate::prelude::SessionContext;
use crate::{assert_batches_eq, execution::context::SessionConfig};
use super::*;
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 4d579776e..ba3f86c69 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -349,16 +349,14 @@ impl SessionContext {
(true, Ok(_)) => {
self.deregister_table(name.as_str())?;
let plan = self.optimize(&input)?;
- let table =
- Arc::new(ViewTable::try_new(self.clone(), plan.clone())?);
+ let table = Arc::new(ViewTable::try_new(plan.clone())?);
self.register_table(name.as_str(), table)?;
Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
}
(_, Err(_)) => {
let plan = self.optimize(&input)?;
- let table =
- Arc::new(ViewTable::try_new(self.clone(), plan.clone())?);
+ let table = Arc::new(ViewTable::try_new(plan.clone())?);
self.register_table(name.as_str(), table)?;
Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
@@ -931,6 +929,11 @@ impl SessionContext {
pub fn task_ctx(&self) -> Arc<TaskContext> {
Arc::new(TaskContext::from(self))
}
+
+ /// Get a copy of the [`SessionState`] of this [`SessionContext`]
+ pub fn state(&self) -> SessionState {
+ self.state.read().clone()
+ }
}
impl FunctionRegistry for SessionContext {
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 39e5e0000..ad957409c 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -389,7 +389,7 @@ impl DefaultPhysicalPlanner {
// referred to in the query
let filters = unnormalize_cols(filters.iter().cloned());
let unaliased: Vec<Expr> = filters.into_iter().map(unalias).collect();
- source.scan(projection, &unaliased, *limit).await
+ source.scan(session_state, projection, &unaliased, *limit).await
}
LogicalPlan::Values(Values {
values,
diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs
index f1356f7d4..1e4ac6e51 100644
--- a/datafusion/core/tests/custom_sources.rs
+++ b/datafusion/core/tests/custom_sources.rs
@@ -30,7 +30,7 @@ use datafusion::{
};
use datafusion::{error::Result, physical_plan::DisplayFormatType};
-use datafusion::execution::context::{SessionContext, TaskContext};
+use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
use datafusion::logical_plan::{
col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE,
};
@@ -201,6 +201,7 @@ impl TableProvider for CustomTableProvider {
async fn scan(
&self,
+ _state: &SessionState,
projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs
index 79c71afb3..9b9ba84d3 100644
--- a/datafusion/core/tests/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/provider_filter_pushdown.rs
@@ -21,7 +21,7 @@ use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::datasource::datasource::{TableProvider, TableType};
use datafusion::error::Result;
-use datafusion::execution::context::{SessionContext, TaskContext};
+use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
use datafusion::physical_plan::common::SizedRecordBatchStream;
use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -138,6 +138,7 @@ impl TableProvider for CustomProvider {
async fn scan(
&self,
+ _state: &SessionState,
_: &Option<Vec<usize>>,
filters: &[Expr],
_: Option<usize>,
diff --git a/datafusion/core/tests/sql/information_schema.rs b/datafusion/core/tests/sql/information_schema.rs
index c6ba61644..4aef69782 100644
--- a/datafusion/core/tests/sql/information_schema.rs
+++ b/datafusion/core/tests/sql/information_schema.rs
@@ -16,6 +16,7 @@
// under the License.
use async_trait::async_trait;
+use datafusion::execution::context::SessionState;
use datafusion::{
catalog::{
catalog::{CatalogProvider, MemoryCatalogProvider},
@@ -175,6 +176,7 @@ async fn information_schema_tables_table_types() {
async fn scan(
&self,
+ _ctx: &SessionState,
_: &Option<Vec<usize>>,
_: &[Expr],
_: Option<usize>,
diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index bdbc77067..120028ac4 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -328,7 +328,7 @@ async fn window_expr_eliminate() -> Result<()> {
let plan = ctx
.create_logical_plan(&("explain ".to_owned() + sql))
.expect(&msg);
- let state = ctx.state.read().clone();
+ let state = ctx.state();
let plan = state.optimize(&plan)?;
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
diff --git a/datafusion/core/tests/statistics.rs b/datafusion/core/tests/statistics.rs
index 99b53a62d..95879ebaf 100644
--- a/datafusion/core/tests/statistics.rs
+++ b/datafusion/core/tests/statistics.rs
@@ -34,7 +34,7 @@ use datafusion::{
};
use async_trait::async_trait;
-use datafusion::execution::context::TaskContext;
+use datafusion::execution::context::{SessionState, TaskContext};
/// This is a testing structure for statistics
/// It will act both as a table provider and execution plan
@@ -74,6 +74,7 @@ impl TableProvider for StatisticsValidation {
async fn scan(
&self,
+ _ctx: &SessionState,
projection: &Option<Vec<usize>>,
filters: &[Expr],
// limit is ignored because it is not mandatory for a `TableProvider` to honor it
diff --git a/dev/build-arrow-ballista.sh b/dev/build-arrow-ballista.sh
index 12b6e9bc0..1d287c460 100755
--- a/dev/build-arrow-ballista.sh
+++ b/dev/build-arrow-ballista.sh
@@ -24,7 +24,7 @@ rm -rf arrow-ballista 2>/dev/null
# clone the repo
# TODO make repo/branch configurable
-git clone https://github.com/tustvold/arrow-ballista -b arrow-15
+git clone https://github.com/tustvold/arrow-ballista -b session-state-table-provider
# update dependencies to local crates
python ./dev/make-ballista-deps-local.py