You are viewing a plain text version of this content. (The canonical link to it was not preserved in this plain-text copy.)
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/05/31 15:19:41 UTC

[arrow-datafusion] branch master updated: Pass SessionState to TableProvider::scan (#2660)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 8f514f974 Pass SessionState to TableProvider::scan (#2660)
8f514f974 is described below

commit 8f514f9748a7526ad0d1e55c8fe3772c93c14726
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Tue May 31 16:19:36 2022 +0100

    Pass SessionState to TableProvider::scan (#2660)
    
    * Pass SessionState to TableProvider::scan
    
    * Update ballista pin
---
 benchmarks/src/bin/tpch.rs                        |  4 +--
 datafusion-examples/examples/custom_datasource.rs |  3 +-
 datafusion/core/benches/sort_limit_query_sql.rs   |  4 +--
 datafusion/core/src/dataframe.rs                  |  2 ++
 datafusion/core/src/datasource/datasource.rs      |  2 ++
 datafusion/core/src/datasource/empty.rs           |  2 ++
 datafusion/core/src/datasource/listing/table.rs   | 14 ++++++++--
 datafusion/core/src/datasource/memory.rs          | 34 ++++++++++++++++-------
 datafusion/core/src/datasource/view.rs            | 11 ++++----
 datafusion/core/src/execution/context.rs          | 11 +++++---
 datafusion/core/src/physical_plan/planner.rs      |  2 +-
 datafusion/core/tests/custom_sources.rs           |  3 +-
 datafusion/core/tests/provider_filter_pushdown.rs |  3 +-
 datafusion/core/tests/sql/information_schema.rs   |  2 ++
 datafusion/core/tests/sql/window.rs               |  2 +-
 datafusion/core/tests/statistics.rs               |  3 +-
 dev/build-arrow-ballista.sh                       |  2 +-
 17 files changed, 70 insertions(+), 34 deletions(-)

diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index c46badd64..4e49bff09 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -178,9 +178,9 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result<Vec<RecordB
         if opt.mem_table {
             println!("Loading table '{}' into memory", table);
             let start = Instant::now();
-            let task_ctx = ctx.task_ctx();
             let memtable =
-                MemTable::load(table_provider, Some(opt.partitions), task_ctx).await?;
+                MemTable::load(table_provider, Some(opt.partitions), &ctx.state())
+                    .await?;
             println!(
                 "Loaded table '{}' into memory in {} ms",
                 table,
diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs
index a814e585e..b68936a7c 100644
--- a/datafusion-examples/examples/custom_datasource.rs
+++ b/datafusion-examples/examples/custom_datasource.rs
@@ -22,7 +22,7 @@ use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::dataframe::DataFrame;
 use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::Result;
-use datafusion::execution::context::TaskContext;
+use datafusion::execution::context::{SessionState, TaskContext};
 use datafusion::logical_plan::{provider_as_source, Expr, LogicalPlanBuilder};
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::memory::MemoryStream;
@@ -175,6 +175,7 @@ impl TableProvider for CustomDataSource {
 
     async fn scan(
         &self,
+        _state: &SessionState,
         projection: &Option<Vec<usize>>,
         // filters and limit can be used here to inject some push-down operations if needed
         _filters: &[Expr],
diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs
index d1f253a98..198eb941f 100644
--- a/datafusion/core/benches/sort_limit_query_sql.rs
+++ b/datafusion/core/benches/sort_limit_query_sql.rs
@@ -89,8 +89,8 @@ fn create_context() -> Arc<Mutex<SessionContext>> {
         let ctx = SessionContext::new();
         ctx.state.write().config.target_partitions = 1;
 
-        let task_ctx = ctx.task_ctx();
-        let mem_table = MemTable::load(Arc::new(csv.await), Some(partitions), task_ctx)
+        let table_provider = Arc::new(csv.await);
+        let mem_table = MemTable::load(table_provider, Some(partitions), &ctx.state())
             .await
             .unwrap();
         ctx.register_table("aggregate_test_100", Arc::new(mem_table))
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index c8e0eef30..7692a187e 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -615,6 +615,7 @@ impl DataFrame {
     }
 }
 
+// TODO: This will introduce a ref cycle (#2659)
 #[async_trait]
 impl TableProvider for DataFrame {
     fn as_any(&self) -> &dyn Any {
@@ -632,6 +633,7 @@ impl TableProvider for DataFrame {
 
     async fn scan(
         &self,
+        _ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         filters: &[Expr],
         limit: Option<usize>,
diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs
index 8ab254525..17b288bef 100644
--- a/datafusion/core/src/datasource/datasource.rs
+++ b/datafusion/core/src/datasource/datasource.rs
@@ -25,6 +25,7 @@ pub use datafusion_expr::{TableProviderFilterPushDown, TableType};
 
 use crate::arrow::datatypes::SchemaRef;
 use crate::error::Result;
+use crate::execution::context::SessionState;
 use crate::logical_plan::Expr;
 use crate::physical_plan::ExecutionPlan;
 
@@ -47,6 +48,7 @@ pub trait TableProvider: Sync + Send {
     /// parallelized or distributed.
     async fn scan(
         &self,
+        ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         filters: &[Expr],
         // limit can be used to reduce the amount scanned
diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs
index 837cd7704..3bc7a958c 100644
--- a/datafusion/core/src/datasource/empty.rs
+++ b/datafusion/core/src/datasource/empty.rs
@@ -25,6 +25,7 @@ use async_trait::async_trait;
 
 use crate::datasource::{TableProvider, TableType};
 use crate::error::Result;
+use crate::execution::context::SessionState;
 use crate::logical_plan::Expr;
 use crate::physical_plan::project_schema;
 use crate::physical_plan::{empty::EmptyExec, ExecutionPlan};
@@ -57,6 +58,7 @@ impl TableProvider for EmptyTable {
 
     async fn scan(
         &self,
+        _ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         _filters: &[Expr],
         _limit: Option<usize>,
diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs
index 34e44971d..bc7ed2604 100644
--- a/datafusion/core/src/datasource/listing/table.rs
+++ b/datafusion/core/src/datasource/listing/table.rs
@@ -35,6 +35,7 @@ use crate::datasource::{
 use crate::logical_expr::TableProviderFilterPushDown;
 use crate::{
     error::{DataFusionError, Result},
+    execution::context::SessionState,
     logical_plan::Expr,
     physical_plan::{
         empty::EmptyExec,
@@ -302,6 +303,7 @@ impl TableProvider for ListingTable {
 
     async fn scan(
         &self,
+        _ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         filters: &[Expr],
         limit: Option<usize>,
@@ -405,6 +407,7 @@ impl ListingTable {
 #[cfg(test)]
 mod tests {
     use crate::datasource::file_format::avro::DEFAULT_AVRO_EXTENSION;
+    use crate::prelude::SessionContext;
     use crate::{
         datafusion_data_access::object_store::local::LocalFileSystem,
         datasource::file_format::{avro::AvroFormat, parquet::ParquetFormat},
@@ -417,10 +420,12 @@ mod tests {
 
     #[tokio::test]
     async fn read_single_file() -> Result<()> {
+        let ctx = SessionContext::new();
+
         let table = load_table("alltypes_plain.parquet").await?;
         let projection = None;
         let exec = table
-            .scan(&projection, &[], None)
+            .scan(&ctx.state(), &projection, &[], None)
             .await
             .expect("Scan table");
 
@@ -447,7 +452,9 @@ mod tests {
             .with_listing_options(opt)
             .with_schema(schema);
         let table = ListingTable::try_new(config)?;
-        let exec = table.scan(&None, &[], None).await?;
+
+        let ctx = SessionContext::new();
+        let exec = table.scan(&ctx.state(), &None, &[], None).await?;
         assert_eq!(exec.statistics().num_rows, Some(8));
         assert_eq!(exec.statistics().total_byte_size, Some(671));
 
@@ -483,8 +490,9 @@ mod tests {
         // this will filter out the only file in the store
         let filter = Expr::not_eq(col("p1"), lit("v1"));
 
+        let ctx = SessionContext::new();
         let scan = table
-            .scan(&None, &[filter], None)
+            .scan(&ctx.state(), &None, &[filter], None)
             .await
             .expect("Empty execution plan");
 
diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index adc26d2f4..62dca1ea0 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -29,7 +29,7 @@ use async_trait::async_trait;
 
 use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
-use crate::execution::context::TaskContext;
+use crate::execution::context::{SessionState, TaskContext};
 use crate::logical_plan::Expr;
 use crate::physical_plan::common;
 use crate::physical_plan::memory::MemoryExec;
@@ -65,18 +65,18 @@ impl MemTable {
     pub async fn load(
         t: Arc<dyn TableProvider>,
         output_partitions: Option<usize>,
-        context: Arc<TaskContext>,
+        ctx: &SessionState,
     ) -> Result<Self> {
         let schema = t.schema();
-        let exec = t.scan(&None, &[], None).await?;
+        let exec = t.scan(ctx, &None, &[], None).await?;
         let partition_count = exec.output_partitioning().partition_count();
 
         let tasks = (0..partition_count)
             .map(|part_i| {
-                let context1 = context.clone();
+                let task = Arc::new(TaskContext::from(ctx));
                 let exec = exec.clone();
                 tokio::spawn(async move {
-                    let stream = exec.execute(part_i, context1.clone())?;
+                    let stream = exec.execute(part_i, task)?;
                     common::collect(stream).await
                 })
             })
@@ -103,7 +103,8 @@ impl MemTable {
             let mut output_partitions = vec![];
             for i in 0..exec.output_partitioning().partition_count() {
                 // execute this *output* partition and collect all batches
-                let mut stream = exec.execute(i, context.clone())?;
+                let task_ctx = Arc::new(TaskContext::from(ctx));
+                let mut stream = exec.execute(i, task_ctx)?;
                 let mut batches = vec![];
                 while let Some(result) = stream.next().await {
                     batches.push(result?);
@@ -133,6 +134,7 @@ impl TableProvider for MemTable {
 
     async fn scan(
         &self,
+        _ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         _filters: &[Expr],
         _limit: Option<usize>,
@@ -180,7 +182,10 @@ mod tests {
         let provider = MemTable::try_new(schema, vec![vec![batch]])?;
 
         // scan with projection
-        let exec = provider.scan(&Some(vec![2, 1]), &[], None).await?;
+        let exec = provider
+            .scan(&session_ctx.state(), &Some(vec![2, 1]), &[], None)
+            .await?;
+
         let mut it = exec.execute(0, task_ctx)?;
         let batch2 = it.next().await.unwrap()?;
         assert_eq!(2, batch2.schema().fields().len());
@@ -212,7 +217,9 @@ mod tests {
 
         let provider = MemTable::try_new(schema, vec![vec![batch]])?;
 
-        let exec = provider.scan(&None, &[], None).await?;
+        let exec = provider
+            .scan(&session_ctx.state(), &None, &[], None)
+            .await?;
         let mut it = exec.execute(0, task_ctx)?;
         let batch1 = it.next().await.unwrap()?;
         assert_eq!(3, batch1.schema().fields().len());
@@ -223,6 +230,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_invalid_projection() -> Result<()> {
+        let session_ctx = SessionContext::new();
+
         let schema = Arc::new(Schema::new(vec![
             Field::new("a", DataType::Int32, false),
             Field::new("b", DataType::Int32, false),
@@ -242,7 +251,10 @@ mod tests {
 
         let projection: Vec<usize> = vec![0, 4];
 
-        match provider.scan(&Some(projection), &[], None).await {
+        match provider
+            .scan(&session_ctx.state(), &Some(projection), &[], None)
+            .await
+        {
             Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => {
                 assert_eq!(
                     "\"project index 4 out of bounds, max field 3\"",
@@ -368,7 +380,9 @@ mod tests {
         let provider =
             MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;
 
-        let exec = provider.scan(&None, &[], None).await?;
+        let exec = provider
+            .scan(&session_ctx.state(), &None, &[], None)
+            .await?;
         let mut it = exec.execute(0, task_ctx)?;
         let batch1 = it.next().await.unwrap()?;
         assert_eq!(3, batch1.schema().fields().len());
diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs
index 3db76cee1..18a43d3d4 100644
--- a/datafusion/core/src/datasource/view.rs
+++ b/datafusion/core/src/datasource/view.rs
@@ -24,17 +24,15 @@ use async_trait::async_trait;
 
 use crate::{
     error::Result,
-    execution::context::SessionContext,
     logical_plan::{Expr, LogicalPlan},
     physical_plan::ExecutionPlan,
 };
 
 use crate::datasource::{TableProvider, TableType};
+use crate::execution::context::SessionState;
 
 /// An implementation of `TableProvider` that uses another logical plan.
 pub struct ViewTable {
-    /// To create ExecutionPlan
-    context: SessionContext,
     /// LogicalPlan of the view
     logical_plan: LogicalPlan,
     /// File fields + partition columns
@@ -44,11 +42,10 @@ pub struct ViewTable {
 impl ViewTable {
     /// Create new view that is executed at query runtime.
     /// Takes a `LogicalPlan` as input.
-    pub fn try_new(context: SessionContext, logical_plan: LogicalPlan) -> Result<Self> {
+    pub fn try_new(logical_plan: LogicalPlan) -> Result<Self> {
         let table_schema = logical_plan.schema().as_ref().to_owned().into();
 
         let view = Self {
-            context,
             logical_plan,
             table_schema,
         };
@@ -73,16 +70,18 @@ impl TableProvider for ViewTable {
 
     async fn scan(
         &self,
+        ctx: &SessionState,
         _projection: &Option<Vec<usize>>,
         _filters: &[Expr],
         _limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        self.context.create_physical_plan(&self.logical_plan).await
+        ctx.create_physical_plan(&self.logical_plan).await
     }
 }
 
 #[cfg(test)]
 mod tests {
+    use crate::prelude::SessionContext;
     use crate::{assert_batches_eq, execution::context::SessionConfig};
 
     use super::*;
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 4d579776e..ba3f86c69 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -349,16 +349,14 @@ impl SessionContext {
                     (true, Ok(_)) => {
                         self.deregister_table(name.as_str())?;
                         let plan = self.optimize(&input)?;
-                        let table =
-                            Arc::new(ViewTable::try_new(self.clone(), plan.clone())?);
+                        let table = Arc::new(ViewTable::try_new(plan.clone())?);
 
                         self.register_table(name.as_str(), table)?;
                         Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
                     }
                     (_, Err(_)) => {
                         let plan = self.optimize(&input)?;
-                        let table =
-                            Arc::new(ViewTable::try_new(self.clone(), plan.clone())?);
+                        let table = Arc::new(ViewTable::try_new(plan.clone())?);
 
                         self.register_table(name.as_str(), table)?;
                         Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
@@ -931,6 +929,11 @@ impl SessionContext {
     pub fn task_ctx(&self) -> Arc<TaskContext> {
         Arc::new(TaskContext::from(self))
     }
+
+    /// Get a copy of the [`SessionState`] of this [`SessionContext`]
+    pub fn state(&self) -> SessionState {
+        self.state.read().clone()
+    }
 }
 
 impl FunctionRegistry for SessionContext {
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 39e5e0000..ad957409c 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -389,7 +389,7 @@ impl DefaultPhysicalPlanner {
                     // referred to in the query
                     let filters = unnormalize_cols(filters.iter().cloned());
                     let unaliased: Vec<Expr> = filters.into_iter().map(unalias).collect();
-                    source.scan(projection, &unaliased, *limit).await
+                    source.scan(session_state, projection, &unaliased, *limit).await
                 }
                 LogicalPlan::Values(Values {
                     values,
diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs
index f1356f7d4..1e4ac6e51 100644
--- a/datafusion/core/tests/custom_sources.rs
+++ b/datafusion/core/tests/custom_sources.rs
@@ -30,7 +30,7 @@ use datafusion::{
 };
 use datafusion::{error::Result, physical_plan::DisplayFormatType};
 
-use datafusion::execution::context::{SessionContext, TaskContext};
+use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
 use datafusion::logical_plan::{
     col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE,
 };
@@ -201,6 +201,7 @@ impl TableProvider for CustomTableProvider {
 
     async fn scan(
         &self,
+        _state: &SessionState,
         projection: &Option<Vec<usize>>,
         _filters: &[Expr],
         _limit: Option<usize>,
diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs
index 79c71afb3..9b9ba84d3 100644
--- a/datafusion/core/tests/provider_filter_pushdown.rs
+++ b/datafusion/core/tests/provider_filter_pushdown.rs
@@ -21,7 +21,7 @@ use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
 use datafusion::datasource::datasource::{TableProvider, TableType};
 use datafusion::error::Result;
-use datafusion::execution::context::{SessionContext, TaskContext};
+use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
 use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
 use datafusion::physical_plan::common::SizedRecordBatchStream;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -138,6 +138,7 @@ impl TableProvider for CustomProvider {
 
     async fn scan(
         &self,
+        _state: &SessionState,
         _: &Option<Vec<usize>>,
         filters: &[Expr],
         _: Option<usize>,
diff --git a/datafusion/core/tests/sql/information_schema.rs b/datafusion/core/tests/sql/information_schema.rs
index c6ba61644..4aef69782 100644
--- a/datafusion/core/tests/sql/information_schema.rs
+++ b/datafusion/core/tests/sql/information_schema.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use async_trait::async_trait;
+use datafusion::execution::context::SessionState;
 use datafusion::{
     catalog::{
         catalog::{CatalogProvider, MemoryCatalogProvider},
@@ -175,6 +176,7 @@ async fn information_schema_tables_table_types() {
 
         async fn scan(
             &self,
+            _ctx: &SessionState,
             _: &Option<Vec<usize>>,
             _: &[Expr],
             _: Option<usize>,
diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs
index bdbc77067..120028ac4 100644
--- a/datafusion/core/tests/sql/window.rs
+++ b/datafusion/core/tests/sql/window.rs
@@ -328,7 +328,7 @@ async fn window_expr_eliminate() -> Result<()> {
     let plan = ctx
         .create_logical_plan(&("explain ".to_owned() + sql))
         .expect(&msg);
-    let state = ctx.state.read().clone();
+    let state = ctx.state();
     let plan = state.optimize(&plan)?;
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
diff --git a/datafusion/core/tests/statistics.rs b/datafusion/core/tests/statistics.rs
index 99b53a62d..95879ebaf 100644
--- a/datafusion/core/tests/statistics.rs
+++ b/datafusion/core/tests/statistics.rs
@@ -34,7 +34,7 @@ use datafusion::{
 };
 
 use async_trait::async_trait;
-use datafusion::execution::context::TaskContext;
+use datafusion::execution::context::{SessionState, TaskContext};
 
 /// This is a testing structure for statistics
 /// It will act both as a table provider and execution plan
@@ -74,6 +74,7 @@ impl TableProvider for StatisticsValidation {
 
     async fn scan(
         &self,
+        _ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         filters: &[Expr],
         // limit is ignored because it is not mandatory for a `TableProvider` to honor it
diff --git a/dev/build-arrow-ballista.sh b/dev/build-arrow-ballista.sh
index 12b6e9bc0..1d287c460 100755
--- a/dev/build-arrow-ballista.sh
+++ b/dev/build-arrow-ballista.sh
@@ -24,7 +24,7 @@ rm -rf arrow-ballista 2>/dev/null
 
 # clone the repo
 # TODO make repo/branch configurable
-git clone https://github.com/tustvold/arrow-ballista -b arrow-15
+git clone https://github.com/tustvold/arrow-ballista -b session-state-table-provider
 
 # update dependencies to local crates
 python ./dev/make-ballista-deps-local.py