Posted to commits@arrow.apache.org by nj...@apache.org on 2023/04/03 04:01:44 UTC

[arrow-ballista] branch main updated: Upgrade DataFusion to 21.0.0 (#727)

This is an automated email from the ASF dual-hosted git repository.

nju_yaho pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 32f97930 Upgrade DataFusion to 21.0.0 (#727)
32f97930 is described below

commit 32f979301a41abaceda67559a152a0ede0a09d51
Author: r.4ntix <r....@gmail.com>
AuthorDate: Mon Apr 3 12:01:40 2023 +0800

    Upgrade DataFusion to 21.0.0 (#727)
    
    * Upgrade DataFusion to 21.0.0
    
    * chore: separate the get and register logic of object store registry
---
 Cargo.toml                                         | 14 +---
 ballista-cli/Cargo.toml                            | 14 +---
 ballista/client/README.md                          |  2 +-
 ballista/client/src/context.rs                     | 13 +---
 ballista/core/src/utils.rs                         | 85 +++++++++++++++-------
 ballista/executor/src/execution_loop.rs            | 17 +++--
 ballista/executor/src/executor_server.rs           | 17 +++--
 .../src/state/execution_graph/execution_stage.rs   |  3 +-
 .../scheduler/src/state/execution_graph_dot.rs     |  4 +-
 ballista/scheduler/src/test_utils.rs               |  4 +-
 10 files changed, 97 insertions(+), 76 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 29e6bc98..c2646cd3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,22 +16,14 @@
 # under the License.
 
 [workspace]
-members = [
-    "ballista-cli",
-    "ballista/client",
-    "ballista/core",
-    "ballista/executor",
-    "ballista/scheduler",
-    "benchmarks",
-    "examples",
-]
+members = ["ballista-cli", "ballista/client", "ballista/core", "ballista/executor", "ballista/scheduler", "benchmarks", "examples"]
 exclude = ["python"]
 
 [workspace.dependencies]
 arrow = { version = "34.0.0" }
 arrow-flight = { version = "34.0.0", features = ["flight-sql-experimental"] }
-datafusion = "20.0.0"
-datafusion-proto = "20.0.0"
+datafusion = "21.0.0"
+datafusion-proto = "21.0.0"
 
 # cargo build --profile release-lto
 [profile.release-lto]
diff --git a/ballista-cli/Cargo.toml b/ballista-cli/Cargo.toml
index 695fba9a..075989f2 100644
--- a/ballista-cli/Cargo.toml
+++ b/ballista-cli/Cargo.toml
@@ -29,24 +29,16 @@ rust-version = "1.63"
 readme = "README.md"
 
 [dependencies]
-ballista = { path = "../ballista/client", version = "0.11.0", features = [
-    "standalone",
-] }
+ballista = { path = "../ballista/client", version = "0.11.0", features = ["standalone"] }
 clap = { version = "3", features = ["derive", "cargo"] }
 datafusion = { workspace = true }
-datafusion-cli = "20.0.0"
+datafusion-cli = "21.0.0"
 dirs = "4.0.0"
 env_logger = "0.10"
 mimalloc = { version = "0.1", default-features = false }
 num_cpus = "1.13.0"
 rustyline = "10.0"
-tokio = { version = "1.0", features = [
-    "macros",
-    "rt",
-    "rt-multi-thread",
-    "sync",
-    "parking_lot",
-] }
+tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] }
 
 [features]
 s3 = ["ballista/s3"]
diff --git a/ballista/client/README.md b/ballista/client/README.md
index 525d6987..4826e5ca 100644
--- a/ballista/client/README.md
+++ b/ballista/client/README.md
@@ -85,7 +85,7 @@ To build a simple ballista example, add the following dependencies to your `Carg
 ```toml
 [dependencies]
 ballista = "0.10"
-datafusion = "20.0.0"
+datafusion = "21.0.0"
 tokio = "1.0"
 ```
 
diff --git a/ballista/client/src/context.rs b/ballista/client/src/context.rs
index 37777b30..84bbaf02 100644
--- a/ballista/client/src/context.rs
+++ b/ballista/client/src/context.rs
@@ -400,7 +400,7 @@ impl BallistaContext {
                 ref if_not_exists,
                 ..
             }) => {
-                let table_exists = ctx.table_exist(name.as_table_reference())?;
+                let table_exists = ctx.table_exist(name)?;
                 let schema: SchemaRef = Arc::new(schema.as_ref().to_owned().into());
                 let table_partition_cols = table_partition_cols
                     .iter()
@@ -422,17 +422,12 @@ impl BallistaContext {
                             if !schema.fields().is_empty() {
                                 options = options.schema(&schema);
                             }
-                            self.register_csv(
-                                name.as_table_reference().table(),
-                                location,
-                                options,
-                            )
-                            .await?;
+                            self.register_csv(name.table(), location, options).await?;
                             Ok(DataFrame::new(ctx.state(), plan))
                         }
                         "parquet" => {
                             self.register_parquet(
-                                name.as_table_reference().table(),
+                                name.table(),
                                 location,
                                 ParquetReadOptions::default()
                                     .table_partition_cols(table_partition_cols),
@@ -442,7 +437,7 @@ impl BallistaContext {
                         }
                         "avro" => {
                             self.register_avro(
-                                name.as_table_reference().table(),
+                                name.table(),
                                 location,
                                 AvroReadOptions::default()
                                     .table_partition_cols(table_partition_cols),
diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs
index 1a0bd7de..6745b529 100644
--- a/ballista/core/src/utils.rs
+++ b/ballista/core/src/utils.rs
@@ -24,7 +24,9 @@ use crate::serde::scheduler::PartitionStats;
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::{ipc::writer::FileWriter, record_batch::RecordBatch};
-use datafusion::datasource::object_store::{ObjectStoreProvider, ObjectStoreRegistry};
+use datafusion::datasource::object_store::{
+    DefaultObjectStoreRegistry, ObjectStoreRegistry,
+};
 use datafusion::error::DataFusionError;
 use datafusion::execution::context::{
     QueryPlanner, SessionConfig, SessionContext, SessionState,
@@ -78,23 +80,29 @@ pub fn default_session_builder(config: SessionConfig) -> SessionState {
 
 /// Get a RuntimeConfig with specific ObjectStoreDetector in the ObjectStoreRegistry
 pub fn with_object_store_provider(config: RuntimeConfig) -> RuntimeConfig {
-    config.with_object_store_registry(Arc::new(ObjectStoreRegistry::new_with_provider(
-        Some(Arc::new(FeatureBasedObjectStoreProvider)),
-    )))
+    let object_store_registry = BallistaObjectStoreRegistry::new();
+    config.with_object_store_registry(Arc::new(object_store_registry))
 }
 
 /// An object store detector based on which features are enable for different kinds of object stores
-pub struct FeatureBasedObjectStoreProvider;
+#[derive(Debug, Default)]
+pub struct BallistaObjectStoreRegistry {
+    inner: DefaultObjectStoreRegistry,
+}
+
+impl BallistaObjectStoreRegistry {
+    pub fn new() -> Self {
+        Default::default()
+    }
 
-impl ObjectStoreProvider for FeatureBasedObjectStoreProvider {
-    /// Detector a suitable object store based on its url if possible
-    /// Return the key and object store
-    #[allow(unused_variables)]
-    fn get_by_url(&self, url: &Url) -> datafusion::error::Result<Arc<dyn ObjectStore>> {
+    /// Find a suitable object store based on its url and enabled features if possible
+    fn get_feature_store(
+        &self,
+        url: &Url,
+    ) -> datafusion::error::Result<Arc<dyn ObjectStore>> {
         #[cfg(any(feature = "hdfs", feature = "hdfs3"))]
         {
-            let store = HadoopFileSystem::new(url.as_str());
-            if let Some(store) = store {
+            if let Some(store) = HadoopFileSystem::new(url.as_str()) {
                 return Ok(Arc::new(store));
             }
         }
@@ -103,21 +111,25 @@ impl ObjectStoreProvider for FeatureBasedObjectStoreProvider {
         {
             if url.as_str().starts_with("s3://") {
                 if let Some(bucket_name) = url.host_str() {
-                    let store = AmazonS3Builder::from_env()
-                        .with_bucket_name(bucket_name)
-                        .build()?;
-                    return Ok(Arc::new(store));
+                    let store = Arc::new(
+                        AmazonS3Builder::from_env()
+                            .with_bucket_name(bucket_name)
+                            .build()?,
+                    );
+                    return Ok(store);
                 }
             // Support Alibaba Cloud OSS
             // Use S3 compatibility mode to access Alibaba Cloud OSS
             // The `AWS_ENDPOINT` should have bucket name included
             } else if url.as_str().starts_with("oss://") {
                 if let Some(bucket_name) = url.host_str() {
-                    let store = AmazonS3Builder::from_env()
-                        .with_virtual_hosted_style_request(true)
-                        .with_bucket_name(bucket_name)
-                        .build()?;
-                    return Ok(Arc::new(store));
+                    let store = Arc::new(
+                        AmazonS3Builder::from_env()
+                            .with_virtual_hosted_style_request(true)
+                            .with_bucket_name(bucket_name)
+                            .build()?,
+                    );
+                    return Ok(store);
                 }
             }
         }
@@ -126,20 +138,41 @@ impl ObjectStoreProvider for FeatureBasedObjectStoreProvider {
         {
             if url.to_string().starts_with("azure://") {
                 if let Some(bucket_name) = url.host_str() {
-                    let store = MicrosoftAzureBuilder::from_env()
-                        .with_container_name(bucket_name)
-                        .build()?;
-                    return Ok(Arc::new(store));
+                    let store = Arc::new(
+                        MicrosoftAzureBuilder::from_env()
+                            .with_container_name(bucket_name)
+                            .build()?,
+                    );
+                    return Ok(store);
                 }
             }
         }
 
         Err(DataFusionError::Execution(format!(
-            "No object store available for {url}"
+            "No object store available for: {url}"
         )))
     }
 }
 
+impl ObjectStoreRegistry for BallistaObjectStoreRegistry {
+    fn register_store(
+        &self,
+        url: &Url,
+        store: Arc<dyn ObjectStore>,
+    ) -> Option<Arc<dyn ObjectStore>> {
+        self.inner.register_store(url, store)
+    }
+
+    fn get_store(&self, url: &Url) -> datafusion::error::Result<Arc<dyn ObjectStore>> {
+        self.inner.get_store(url).or_else(|_| {
+            let store = self.get_feature_store(url)?;
+            self.inner.register_store(url, store.clone());
+
+            Ok(store)
+        })
+    }
+}
+
 /// Stream data to disk in Arrow IPC format
 pub async fn write_stream_to_disk(
     stream: &mut Pin<Box<dyn RecordBatchStream + Send>>,
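
The utils.rs change above is the "separate the get and register logic" part of this commit: DataFusion 21 drops the ObjectStoreProvider hook, so the feature-based URL detection now sits behind a registry that wraps DefaultObjectStoreRegistry, builds a store on the first get_store miss, and registers it so later lookups hit the cache. Below is a hedged sketch of how that registry is wired into a session through with_object_store_provider; the RuntimeEnv/SessionContext builder calls follow DataFusion 21 and are illustrative, not taken from this commit.

```rust
use std::sync::Arc;

use ballista_core::utils::with_object_store_provider;
use datafusion::execution::context::{SessionConfig, SessionContext};
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};

fn session_with_feature_stores() -> datafusion::error::Result<SessionContext> {
    // Attach the BallistaObjectStoreRegistry: stores for s3://, oss://, azure://
    // (and hdfs:// when that feature is enabled) are built lazily from the
    // environment on the first lookup and then cached by the inner
    // DefaultObjectStoreRegistry. For oss:// the S3 compatibility mode is used,
    // so AWS_ENDPOINT must include the bucket name (see the comments above).
    let runtime_config = with_object_store_provider(RuntimeConfig::new());
    let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);

    Ok(SessionContext::with_config_rt(SessionConfig::new(), runtime))
}
```
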
diff --git a/ballista/executor/src/execution_loop.rs b/ballista/executor/src/execution_loop.rs
index 6b7070b8..aaaa90a3 100644
--- a/ballista/executor/src/execution_loop.rs
+++ b/ballista/executor/src/execution_loop.rs
@@ -15,13 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use datafusion::config::Extensions;
+use datafusion::config::ConfigOptions;
 use datafusion::physical_plan::ExecutionPlan;
 
 use ballista_core::serde::protobuf::{
     scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult,
     TaskDefinition, TaskStatus,
 };
+use datafusion::prelude::SessionConfig;
 use tokio::sync::{OwnedSemaphorePermit, Semaphore};
 
 use crate::cpu_bound_executor::DedicatedExecutor;
@@ -173,6 +174,11 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
     for kv_pair in task.props {
         task_props.insert(kv_pair.key, kv_pair.value);
     }
+    let mut config = ConfigOptions::new();
+    for (k, v) in task_props {
+        config.set(&k, &v)?;
+    }
+    let session_config = SessionConfig::from(config);
 
     let mut task_scalar_functions = HashMap::new();
     let mut task_aggregate_functions = HashMap::new();
@@ -185,15 +191,14 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
     }
     let runtime = executor.runtime.clone();
     let session_id = task.session_id.clone();
-    let task_context = Arc::new(TaskContext::try_new(
-        task_identity.clone(),
+    let task_context = Arc::new(TaskContext::new(
+        Some(task_identity.clone()),
         session_id,
-        task_props,
+        session_config,
         task_scalar_functions,
         task_aggregate_functions,
         runtime.clone(),
-        Extensions::default(),
-    )?);
+    ));
 
     let plan: Arc<dyn ExecutionPlan> =
         U::try_decode(task.plan.as_slice()).and_then(|proto| {
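
The execution_loop.rs hunks above (and the matching executor_server.rs hunks below) show the main API migration DataFusion 21 forces on the executors: per-task key/value properties are folded into ConfigOptions and wrapped in a SessionConfig, and the fallible TaskContext::try_new(..., Extensions::default()) is replaced by TaskContext::new with an Option<String> task id. A condensed, hedged sketch of the new construction path follows; build_task_context is an illustrative helper rather than a function in this commit, import paths follow DataFusion 21, and the UDF maps are left empty for brevity.

```rust
use std::collections::HashMap;
use std::sync::Arc;

use datafusion::config::ConfigOptions;
use datafusion::execution::context::TaskContext;
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::prelude::SessionConfig;

fn build_task_context(
    task_identity: String,
    session_id: String,
    task_props: HashMap<String, String>,
    runtime: Arc<RuntimeEnv>,
) -> datafusion::error::Result<Arc<TaskContext>> {
    // DataFusion 21: per-task properties are applied to ConfigOptions and then
    // wrapped in a SessionConfig instead of being passed through as a raw map.
    let mut config = ConfigOptions::new();
    for (k, v) in task_props {
        config.set(&k, &v)?;
    }
    let session_config = SessionConfig::from(config);

    // TaskContext::new is infallible and takes the task id as Option<String>;
    // the Extensions argument of the old try_new constructor is gone.
    Ok(Arc::new(TaskContext::new(
        Some(task_identity),
        session_id,
        session_config,
        HashMap::new(), // scalar UDFs
        HashMap::new(), // aggregate UDFs
        runtime,
    )))
}
```
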
diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs
index 7256d8e2..26dec059 100644
--- a/ballista/executor/src/executor_server.rs
+++ b/ballista/executor/src/executor_server.rs
@@ -16,7 +16,8 @@
 // under the License.
 
 use ballista_core::BALLISTA_VERSION;
-use datafusion::config::Extensions;
+use datafusion::config::ConfigOptions;
+use datafusion::prelude::SessionConfig;
 use std::collections::HashMap;
 use std::convert::TryInto;
 use std::ops::Deref;
@@ -307,6 +308,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         info!("Start to run task {}", task_identity);
         let task = curator_task.task;
         let task_props = task.props;
+        let mut config = ConfigOptions::new();
+        for (k, v) in task_props {
+            config.set(&k, &v)?;
+        }
+        let session_config = SessionConfig::from(config);
 
         let mut task_scalar_functions = HashMap::new();
         let mut task_aggregate_functions = HashMap::new();
@@ -320,15 +326,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
 
         let session_id = task.session_id;
         let runtime = self.executor.runtime.clone();
-        let task_context = Arc::new(TaskContext::try_new(
-            task_identity.clone(),
+        let task_context = Arc::new(TaskContext::new(
+            Some(task_identity.clone()),
             session_id,
-            task_props,
+            session_config,
             task_scalar_functions,
             task_aggregate_functions,
             runtime.clone(),
-            Extensions::default(),
-        )?);
+        ));
 
         let encoded_plan = &task.plan.as_slice();
 
diff --git a/ballista/scheduler/src/state/execution_graph/execution_stage.rs b/ballista/scheduler/src/state/execution_graph/execution_stage.rs
index 542d5116..3cc3eb14 100644
--- a/ballista/scheduler/src/state/execution_graph/execution_stage.rs
+++ b/ballista/scheduler/src/state/execution_graph/execution_stage.rs
@@ -365,8 +365,7 @@ impl UnresolvedStage {
 
         // Optimize join order based on new resolved statistics
         let optimize_join = JoinSelection::new();
-        let plan =
-            optimize_join.optimize(plan, SessionConfig::default().config_options())?;
+        let plan = optimize_join.optimize(plan, SessionConfig::default().options())?;
 
         Ok(ResolvedStage::new(
             self.stage_id,
diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs b/ballista/scheduler/src/state/execution_graph_dot.rs
index 5b4f744d..d9df9488 100644
--- a/ballista/scheduler/src/state/execution_graph_dot.rs
+++ b/ballista/scheduler/src/state/execution_graph_dot.rs
@@ -627,7 +627,7 @@ filter_expr="]
             .with_target_partitions(48)
             .with_batch_size(4096);
         config
-            .config_options_mut()
+            .options_mut()
             .optimizer
             .enable_round_robin_repartition = false;
         let ctx = SessionContext::with_config(config);
@@ -654,7 +654,7 @@ filter_expr="]
             .with_target_partitions(48)
             .with_batch_size(4096);
         config
-            .config_options_mut()
+            .options_mut()
             .optimizer
             .enable_round_robin_repartition = false;
         let ctx = SessionContext::with_config(config);
diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs
index ee35958d..aceef06c 100644
--- a/ballista/scheduler/src/test_utils.rs
+++ b/ballista/scheduler/src/test_utils.rs
@@ -882,7 +882,7 @@ pub async fn test_coalesce_plan(partition: usize) -> ExecutionGraph {
 pub async fn test_join_plan(partition: usize) -> ExecutionGraph {
     let mut config = SessionConfig::new().with_target_partitions(partition);
     config
-        .config_options_mut()
+        .options_mut()
         .optimizer
         .enable_round_robin_repartition = false;
     let ctx = Arc::new(SessionContext::with_config(config));
@@ -905,7 +905,7 @@ pub async fn test_join_plan(partition: usize) -> ExecutionGraph {
     let logical_plan = left_plan
         .join(right_plan, JoinType::Inner, (vec!["id"], vec!["id"]), None)
         .unwrap()
-        .aggregate(vec![col("id")], vec![sum(col("gmv"))])
+        .aggregate(vec![col("left.id")], vec![sum(col("left.gmv"))])
         .unwrap()
         .sort(vec![sort_expr])
         .unwrap()