Posted to commits@arrow.apache.org by nj...@apache.org on 2023/06/28 01:27:13 UTC
[arrow-ballista] branch main updated: Refactor the TaskDefinition by changing the encoded execution plan to the decoded one (#817)
This is an automated email from the ASF dual-hosted git repository.
nju_yaho pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new d7a808cc Refactor the TaskDefinition by changing the encoded execution plan to the decoded one (#817)
d7a808cc is described below
commit d7a808cc71053a4d0982eb0db17be89c895dac73
Author: yahoNanJing <90...@users.noreply.github.com>
AuthorDate: Wed Jun 28 09:27:09 2023 +0800
Refactor the TaskDefinition by changing the encoded execution plan to the decoded one (#817)
* Revert "Only decode plan in `LaunchMultiTaskParams` once (#743)"
This reverts commit 4e4842ce5221b8ce6ce39b82bb1346e337129b0d.
* Refactor the TaskDefinition by changing the encoded execution plan to the decoded one
* Refine the error handling of run_task in the executor_server
---------
Co-authored-by: yangzhong <ya...@ebay.com>
---
ballista/core/src/serde/scheduler/from_proto.rs | 194 ++++++++++++-----
ballista/core/src/serde/scheduler/mod.rs | 46 +++-
ballista/core/src/serde/scheduler/to_proto.rs | 33 +--
ballista/executor/src/execution_engine.rs | 7 -
ballista/executor/src/executor_server.rs | 272 ++++++++----------------
5 files changed, 272 insertions(+), 280 deletions(-)
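The heart of the change is the executor-side TaskDefinition struct: its plan field goes from encoded protobuf bytes, which every consumer had to decode for itself, to a physical plan decoded once when the gRPC request arrives. A condensed before/after sketch (field subset only, assuming the datafusion crate this commit builds against; the full definitions are in the diff below):

    use std::collections::HashMap;
    use std::sync::Arc;
    use datafusion::physical_plan::ExecutionPlan;

    // Stub for this sketch only; the real definition appears in the
    // mod.rs hunk further down.
    pub struct SimpleFunctionRegistry;

    // Before this commit: raw protobuf bytes, decoded again by each consumer.
    pub struct TaskDefinitionBefore {
        pub plan: Vec<u8>,
        pub props: HashMap<String, String>,
        // ids, attempt numbers, session_id and launch_time omitted (unchanged)
    }

    // After this commit: the plan is decoded once on receipt and shared,
    // together with the function registry that was needed to decode it.
    pub struct TaskDefinitionAfter {
        pub plan: Arc<dyn ExecutionPlan>,
        pub props: Arc<HashMap<String, String>>,
        pub function_registry: Arc<SimpleFunctionRegistry>,
        // ids, attempt numbers, session_id and launch_time omitted (unchanged)
    }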
diff --git a/ballista/core/src/serde/scheduler/from_proto.rs b/ballista/core/src/serde/scheduler/from_proto.rs
index 17875e2b..545896d8 100644
--- a/ballista/core/src/serde/scheduler/from_proto.rs
+++ b/ballista/core/src/serde/scheduler/from_proto.rs
@@ -16,10 +16,15 @@
// under the License.
use chrono::{TimeZone, Utc};
+use datafusion::common::tree_node::{Transformed, TreeNode};
+use datafusion::execution::runtime_env::RuntimeEnv;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
use datafusion::physical_plan::metrics::{
Count, Gauge, MetricValue, MetricsSet, Time, Timestamp,
};
-use datafusion::physical_plan::Metric;
+use datafusion::physical_plan::{ExecutionPlan, Metric};
+use datafusion_proto::logical_plan::AsLogicalPlan;
+use datafusion_proto::physical_plan::AsExecutionPlan;
use std::collections::HashMap;
use std::convert::TryInto;
use std::sync::Arc;
@@ -28,10 +33,10 @@ use std::time::Duration;
use crate::error::BallistaError;
use crate::serde::scheduler::{
Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
- PartitionLocation, PartitionStats, TaskDefinition,
+ PartitionLocation, PartitionStats, SimpleFunctionRegistry, TaskDefinition,
};
-use crate::serde::protobuf;
+use crate::serde::{protobuf, BallistaCodec};
use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};
impl TryInto<Action> for protobuf::Action {
@@ -269,67 +274,138 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
}
}
-impl TryInto<(TaskDefinition, Vec<u8>)> for protobuf::TaskDefinition {
- type Error = BallistaError;
-
- fn try_into(self) -> Result<(TaskDefinition, Vec<u8>), Self::Error> {
- let mut props = HashMap::new();
- for kv_pair in self.props {
- props.insert(kv_pair.key, kv_pair.value);
- }
+pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
+ task: protobuf::TaskDefinition,
+ runtime: Arc<RuntimeEnv>,
+ scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+ aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+ codec: BallistaCodec<T, U>,
+) -> Result<TaskDefinition, BallistaError> {
+ let mut props = HashMap::new();
+ for kv_pair in task.props {
+ props.insert(kv_pair.key, kv_pair.value);
+ }
+ let props = Arc::new(props);
- Ok((
- TaskDefinition {
- task_id: self.task_id as usize,
- task_attempt_num: self.task_attempt_num as usize,
- job_id: self.job_id,
- stage_id: self.stage_id as usize,
- stage_attempt_num: self.stage_attempt_num as usize,
- partition_id: self.partition_id as usize,
- plan: vec![],
- session_id: self.session_id,
- launch_time: self.launch_time,
- props,
- },
- self.plan,
- ))
+ let mut task_scalar_functions = HashMap::new();
+ let mut task_aggregate_functions = HashMap::new();
+ // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+ for scalar_func in scalar_functions {
+ task_scalar_functions.insert(scalar_func.0, scalar_func.1);
}
-}
+ for agg_func in aggregate_functions {
+ task_aggregate_functions.insert(agg_func.0, agg_func.1);
+ }
+ let function_registry = Arc::new(SimpleFunctionRegistry {
+ scalar_functions: task_scalar_functions,
+ aggregate_functions: task_aggregate_functions,
+ });
-impl TryInto<(Vec<TaskDefinition>, Vec<u8>)> for protobuf::MultiTaskDefinition {
- type Error = BallistaError;
+ let encoded_plan = task.plan.as_slice();
+ let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+ proto.try_into_physical_plan(
+ function_registry.as_ref(),
+ runtime.as_ref(),
+ codec.physical_extension_codec(),
+ )
+ })?;
- fn try_into(self) -> Result<(Vec<TaskDefinition>, Vec<u8>), Self::Error> {
- let mut props = HashMap::new();
- for kv_pair in self.props {
- props.insert(kv_pair.key, kv_pair.value);
- }
+ let job_id = task.job_id;
+ let stage_id = task.stage_id as usize;
+ let partition_id = task.partition_id as usize;
+ let task_attempt_num = task.task_attempt_num as usize;
+ let stage_attempt_num = task.stage_attempt_num as usize;
+ let launch_time = task.launch_time;
+ let task_id = task.task_id as usize;
+ let session_id = task.session_id;
- let plan = self.plan;
- let session_id = self.session_id;
- let job_id = self.job_id;
- let stage_id = self.stage_id as usize;
- let stage_attempt_num = self.stage_attempt_num as usize;
- let launch_time = self.launch_time;
- let task_ids = self.task_ids;
+ Ok(TaskDefinition {
+ task_id,
+ task_attempt_num,
+ job_id,
+ stage_id,
+ stage_attempt_num,
+ partition_id,
+ plan,
+ launch_time,
+ session_id,
+ props,
+ function_registry,
+ })
+}
- Ok((
- task_ids
- .iter()
- .map(|task_id| TaskDefinition {
- task_id: task_id.task_id as usize,
- task_attempt_num: task_id.task_attempt_num as usize,
- job_id: job_id.clone(),
- stage_id,
- stage_attempt_num,
- partition_id: task_id.partition_id as usize,
- plan: vec![],
- session_id: session_id.clone(),
- launch_time,
- props: props.clone(),
- })
- .collect(),
- plan,
- ))
+pub fn get_task_definition_vec<
+ T: 'static + AsLogicalPlan,
+ U: 'static + AsExecutionPlan,
+>(
+ multi_task: protobuf::MultiTaskDefinition,
+ runtime: Arc<RuntimeEnv>,
+ scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+ aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+ codec: BallistaCodec<T, U>,
+) -> Result<Vec<TaskDefinition>, BallistaError> {
+ let mut props = HashMap::new();
+ for kv_pair in multi_task.props {
+ props.insert(kv_pair.key, kv_pair.value);
}
+ let props = Arc::new(props);
+
+ let mut task_scalar_functions = HashMap::new();
+ let mut task_aggregate_functions = HashMap::new();
+ // TODO combine the functions from Executor's functions and TaskDefinition's function resources
+ for scalar_func in scalar_functions {
+ task_scalar_functions.insert(scalar_func.0, scalar_func.1);
+ }
+ for agg_func in aggregate_functions {
+ task_aggregate_functions.insert(agg_func.0, agg_func.1);
+ }
+ let function_registry = Arc::new(SimpleFunctionRegistry {
+ scalar_functions: task_scalar_functions,
+ aggregate_functions: task_aggregate_functions,
+ });
+
+ let encoded_plan = multi_task.plan.as_slice();
+ let plan: Arc<dyn ExecutionPlan> = U::try_decode(encoded_plan).and_then(|proto| {
+ proto.try_into_physical_plan(
+ function_registry.as_ref(),
+ runtime.as_ref(),
+ codec.physical_extension_codec(),
+ )
+ })?;
+
+ let job_id = multi_task.job_id;
+ let stage_id = multi_task.stage_id as usize;
+ let stage_attempt_num = multi_task.stage_attempt_num as usize;
+ let launch_time = multi_task.launch_time;
+ let task_ids = multi_task.task_ids;
+ let session_id = multi_task.session_id;
+
+ task_ids
+ .iter()
+ .map(|task_id| {
+ Ok(TaskDefinition {
+ task_id: task_id.task_id as usize,
+ task_attempt_num: task_id.task_attempt_num as usize,
+ job_id: job_id.clone(),
+ stage_id,
+ stage_attempt_num,
+ partition_id: task_id.partition_id as usize,
+ plan: reset_metrics_for_execution_plan(plan.clone())?,
+ launch_time,
+ session_id: session_id.clone(),
+ props: props.clone(),
+ function_registry: function_registry.clone(),
+ })
+ })
+ .collect()
+}
+
+fn reset_metrics_for_execution_plan(
+ plan: Arc<dyn ExecutionPlan>,
+) -> Result<Arc<dyn ExecutionPlan>, BallistaError> {
+ plan.transform(&|plan| {
+ let children = plan.children().clone();
+ plan.with_new_children(children).map(Transformed::Yes)
+ })
+ .map_err(BallistaError::DataFusionError)
}
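A note on reset_metrics_for_execution_plan above: with one decoded plan now shared by every task of a MultiTaskDefinition, a plain plan.clone() would also share each operator's metrics. Rebuilding every node through with_new_children constructs fresh operator instances, so (per the function's name) each task is meant to record its metrics independently. A minimal standalone sketch of the same TreeNode idiom, assuming the datafusion version used here (where transform takes a closure returning Transformed):

    use std::sync::Arc;
    use datafusion::common::tree_node::{Transformed, TreeNode};
    use datafusion::physical_plan::ExecutionPlan;

    // Rebuild every node of a physical plan. The tree shape is unchanged, but
    // each operator (and therefore its metrics set) is a newly constructed
    // instance rather than a shared clone.
    fn rebuild_plan(
        plan: Arc<dyn ExecutionPlan>,
    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
        plan.transform(&|node| {
            let children = node.children(); // children() returns an owned Vec
            node.with_new_children(children).map(Transformed::Yes)
        })
    }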
diff --git a/ballista/core/src/serde/scheduler/mod.rs b/ballista/core/src/serde/scheduler/mod.rs
index 6e9440a3..96c4e0fa 100644
--- a/ballista/core/src/serde/scheduler/mod.rs
+++ b/ballista/core/src/serde/scheduler/mod.rs
@@ -15,12 +15,17 @@
// specific language governing permissions and limitations
// under the License.
+use std::collections::HashSet;
+use std::fmt::Debug;
use std::{collections::HashMap, fmt, sync::Arc};
use datafusion::arrow::array::{
ArrayBuilder, StructArray, StructBuilder, UInt64Array, UInt64Builder,
};
use datafusion::arrow::datatypes::{DataType, Field};
+use datafusion::common::DataFusionError;
+use datafusion::execution::FunctionRegistry;
+use datafusion::logical_expr::{AggregateUDF, ScalarUDF};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::Partitioning;
use serde::Serialize;
@@ -271,7 +276,7 @@ impl ExecutePartitionResult {
}
}
-#[derive(Debug, Clone)]
+#[derive(Clone, Debug)]
pub struct TaskDefinition {
pub task_id: usize,
pub task_attempt_num: usize,
@@ -279,8 +284,41 @@ pub struct TaskDefinition {
pub stage_id: usize,
pub stage_attempt_num: usize,
pub partition_id: usize,
- pub plan: Vec<u8>,
- pub session_id: String,
+ pub plan: Arc<dyn ExecutionPlan>,
pub launch_time: u64,
- pub props: HashMap<String, String>,
+ pub session_id: String,
+ pub props: Arc<HashMap<String, String>>,
+ pub function_registry: Arc<SimpleFunctionRegistry>,
+}
+
+#[derive(Debug)]
+pub struct SimpleFunctionRegistry {
+ pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+ pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+}
+
+impl FunctionRegistry for SimpleFunctionRegistry {
+ fn udfs(&self) -> HashSet<String> {
+ self.scalar_functions.keys().cloned().collect()
+ }
+
+ fn udf(&self, name: &str) -> datafusion::common::Result<Arc<ScalarUDF>> {
+ let result = self.scalar_functions.get(name);
+
+ result.cloned().ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "There is no UDF named \"{name}\" in the TaskContext"
+ ))
+ })
+ }
+
+ fn udaf(&self, name: &str) -> datafusion::common::Result<Arc<AggregateUDF>> {
+ let result = self.aggregate_functions.get(name);
+
+ result.cloned().ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "There is no UDAF named \"{name}\" in the TaskContext"
+ ))
+ })
+ }
}
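SimpleFunctionRegistry gives the decode path in from_proto.rs a FunctionRegistry implementation that does not depend on a SessionContext or TaskContext. A minimal usage sketch (hypothetical registry contents, assuming the same crates as above):

    use std::collections::HashMap;
    use std::sync::Arc;
    use ballista_core::serde::scheduler::SimpleFunctionRegistry;
    use datafusion::execution::FunctionRegistry;
    use datafusion::logical_expr::ScalarUDF;

    fn demo(scalar_functions: HashMap<String, Arc<ScalarUDF>>) {
        let registry = SimpleFunctionRegistry {
            scalar_functions,
            aggregate_functions: HashMap::new(),
        };
        // Registered names resolve; a miss yields the Internal error above.
        for name in registry.udfs() {
            assert!(registry.udf(&name).is_ok());
        }
        assert!(registry.udf("no_such_udf").is_err());
    }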
diff --git a/ballista/core/src/serde/scheduler/to_proto.rs b/ballista/core/src/serde/scheduler/to_proto.rs
index ccb5ec42..6ceb1dd6 100644
--- a/ballista/core/src/serde/scheduler/to_proto.rs
+++ b/ballista/core/src/serde/scheduler/to_proto.rs
@@ -26,12 +26,10 @@ use datafusion_proto::protobuf as datafusion_protobuf;
use crate::serde::scheduler::{
Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
- PartitionLocation, PartitionStats, TaskDefinition,
+ PartitionLocation, PartitionStats,
};
use datafusion::physical_plan::Partitioning;
-use protobuf::{
- action::ActionType, operator_metric, KeyValuePair, NamedCount, NamedGauge, NamedTime,
-};
+use protobuf::{action::ActionType, operator_metric, NamedCount, NamedGauge, NamedTime};
impl TryInto<protobuf::Action> for Action {
type Error = BallistaError;
@@ -242,30 +240,3 @@ impl Into<protobuf::ExecutorData> for ExecutorData {
}
}
}
-
-#[allow(clippy::from_over_into)]
-impl Into<protobuf::TaskDefinition> for TaskDefinition {
- fn into(self) -> protobuf::TaskDefinition {
- let props = self
- .props
- .iter()
- .map(|(k, v)| KeyValuePair {
- key: k.to_owned(),
- value: v.to_owned(),
- })
- .collect::<Vec<_>>();
-
- protobuf::TaskDefinition {
- task_id: self.task_id as u32,
- task_attempt_num: self.task_attempt_num as u32,
- job_id: self.job_id,
- stage_id: self.stage_id as u32,
- stage_attempt_num: self.stage_attempt_num as u32,
- partition_id: self.partition_id as u32,
- plan: self.plan,
- session_id: self.session_id,
- launch_time: self.launch_time,
- props,
- }
- }
-}
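With TaskDefinition now holding a decoded Arc<dyn ExecutionPlan>, converting it back into protobuf::TaskDefinition would mean re-encoding the plan, and the refactor appears to leave no caller that needs to; the Into impl above is therefore dropped, along with the now-unused KeyValuePair import.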
diff --git a/ballista/executor/src/execution_engine.rs b/ballista/executor/src/execution_engine.rs
index 96515329..5121f016 100644
--- a/ballista/executor/src/execution_engine.rs
+++ b/ballista/executor/src/execution_engine.rs
@@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use ballista_core::execution_plans::ShuffleWriterExec;
use ballista_core::serde::protobuf::ShuffleWritePartition;
@@ -52,8 +51,6 @@ pub trait QueryStageExecutor: Sync + Send + Debug {
) -> Result<Vec<ShuffleWritePartition>>;
fn collect_plan_metrics(&self) -> Vec<MetricsSet>;
-
- fn schema(&self) -> SchemaRef;
}
pub struct DefaultExecutionEngine {}
@@ -111,10 +108,6 @@ impl QueryStageExecutor for DefaultQueryStageExec {
.await
}
- fn schema(&self) -> SchemaRef {
- self.shuffle_writer.schema()
- }
-
fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
utils::collect_plan_metrics(&self.shuffle_writer)
}
diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs
index 9102923f..2892cb0b 100644
--- a/ballista/executor/src/executor_server.rs
+++ b/ballista/executor/src/executor_server.rs
@@ -16,11 +16,8 @@
// under the License.
use ballista_core::BALLISTA_VERSION;
-use datafusion::config::ConfigOptions;
-use datafusion::prelude::SessionConfig;
use std::collections::HashMap;
use std::convert::TryInto;
-use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
@@ -41,18 +38,22 @@ use ballista_core::serde::protobuf::{
LaunchTaskResult, RegisterExecutorParams, RemoveJobDataParams, RemoveJobDataResult,
StopExecutorParams, StopExecutorResult, TaskStatus, UpdateTaskStatusParams,
};
+use ballista_core::serde::scheduler::from_proto::{
+ get_task_definition, get_task_definition_vec,
+};
use ballista_core::serde::scheduler::PartitionId;
use ballista_core::serde::scheduler::TaskDefinition;
use ballista_core::serde::BallistaCodec;
use ballista_core::utils::{create_grpc_client_connection, create_grpc_server};
use dashmap::DashMap;
-use datafusion::execution::context::TaskContext;
+use datafusion::config::ConfigOptions;
+use datafusion::execution::TaskContext;
+use datafusion::prelude::SessionConfig;
use datafusion_proto::{logical_plan::AsLogicalPlan, physical_plan::AsExecutionPlan};
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
use crate::cpu_bound_executor::DedicatedExecutor;
-use crate::execution_engine::QueryStageExecutor;
use crate::executor::Executor;
use crate::executor_process::ExecutorProcessConfig;
use crate::shutdown::ShutdownNotifier;
@@ -65,8 +66,7 @@ type SchedulerClients = Arc<DashMap<String, SchedulerGrpcClient<Channel>>>;
#[derive(Debug)]
struct CuratorTaskDefinition {
scheduler_id: String,
- plan: Vec<u8>,
- tasks: Vec<TaskDefinition>,
+ task: TaskDefinition,
}
/// Wrap TaskStatus with its curator scheduler id for task update to its specific curator scheduler later
@@ -298,100 +298,22 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
}
}
- async fn decode_task(
- &self,
- curator_task: TaskDefinition,
- plan: &[u8],
- ) -> Result<Arc<dyn QueryStageExecutor>, BallistaError> {
- let task = curator_task;
- let task_identity = task_identity(&task);
- let task_props = task.props;
- let mut config = ConfigOptions::new();
- for (k, v) in task_props {
- config.set(&k, &v)?;
- }
- let session_config = SessionConfig::from(config);
-
- let mut task_scalar_functions = HashMap::new();
- let mut task_aggregate_functions = HashMap::new();
- for scalar_func in self.executor.scalar_functions.clone() {
- task_scalar_functions.insert(scalar_func.0, scalar_func.1);
- }
- for agg_func in self.executor.aggregate_functions.clone() {
- task_aggregate_functions.insert(agg_func.0, agg_func.1);
- }
-
- let task_context = Arc::new(TaskContext::new(
- Some(task_identity),
- task.session_id.clone(),
- session_config,
- task_scalar_functions,
- task_aggregate_functions,
- self.executor.runtime.clone(),
- ));
-
- let plan = U::try_decode(plan).and_then(|proto| {
- proto.try_into_physical_plan(
- task_context.deref(),
- &self.executor.runtime,
- self.codec.physical_extension_codec(),
- )
- })?;
-
- Ok(self.executor.execution_engine.create_query_stage_exec(
- task.job_id,
- task.stage_id,
- plan,
- &self.executor.work_dir,
- )?)
- }
-
- async fn run_task(
- &self,
- task_identity: &str,
- scheduler_id: String,
- curator_task: TaskDefinition,
- query_stage_exec: Arc<dyn QueryStageExecutor>,
- ) -> Result<(), BallistaError> {
+ /// This method should not return Err. If task fails, a failure task status should be sent
+ /// to the channel to notify the scheduler.
+ async fn run_task(&self, task_identity: String, curator_task: CuratorTaskDefinition) {
let start_exec_time = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_millis() as u64;
info!("Start to run task {}", task_identity);
- let task = curator_task;
- let task_props = task.props;
- let mut config = ConfigOptions::new();
- for (k, v) in task_props {
- config.set(&k, &v)?;
- }
- let session_config = SessionConfig::from(config);
-
- let mut task_scalar_functions = HashMap::new();
- let mut task_aggregate_functions = HashMap::new();
- // TODO combine the functions from Executor's functions and TaskDefintion's function resources
- for scalar_func in self.executor.scalar_functions.clone() {
- task_scalar_functions.insert(scalar_func.0, scalar_func.1);
- }
- for agg_func in self.executor.aggregate_functions.clone() {
- task_aggregate_functions.insert(agg_func.0, agg_func.1);
- }
-
- let session_id = task.session_id;
- let runtime = self.executor.runtime.clone();
- let task_context = Arc::new(TaskContext::new(
- Some(task_identity.to_string()),
- session_id,
- session_config,
- task_scalar_functions,
- task_aggregate_functions,
- runtime.clone(),
- ));
+ let task = curator_task.task;
let task_id = task.task_id;
let job_id = task.job_id;
let stage_id = task.stage_id;
let stage_attempt_num = task.stage_attempt_num;
let partition_id = task.partition_id;
+ let plan = task.plan;
let part = PartitionId {
job_id: job_id.clone(),
@@ -399,6 +321,40 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
partition_id,
};
+ let query_stage_exec = self
+ .executor
+ .execution_engine
+ .create_query_stage_exec(
+ job_id.clone(),
+ stage_id,
+ plan,
+ &self.executor.work_dir,
+ )
+ .unwrap();
+
+ let task_context = {
+ let task_props = task.props;
+ let mut config = ConfigOptions::new();
+ for (k, v) in task_props.iter() {
+ if let Err(e) = config.set(k, v) {
+ debug!("Fail to set session config for ({},{}): {:?}", k, v, e);
+ }
+ }
+ let session_config = SessionConfig::from(config);
+
+ let function_registry = task.function_registry;
+ let runtime = self.executor.runtime.clone();
+
+ Arc::new(TaskContext::new(
+ Some(task_identity.clone()),
+ task.session_id,
+ session_config,
+ function_registry.scalar_functions.clone(),
+ function_registry.aggregate_functions.clone(),
+ runtime,
+ ))
+ };
+
info!("Start to execute shuffle write for task {}", task_identity);
let execution_result = self
@@ -414,10 +370,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
debug!("Statistics: {:?}", execution_result);
let plan_metrics = query_stage_exec.collect_plan_metrics();
- let operator_metrics = plan_metrics
+ let operator_metrics = match plan_metrics
.into_iter()
.map(|m| m.try_into())
- .collect::<Result<Vec<_>, BallistaError>>()?;
+ .collect::<Result<Vec<_>, BallistaError>>()
+ {
+ Ok(metrics) => Some(metrics),
+ Err(_) => None,
+ };
let executor_id = &self.executor.metadata.id;
let end_exec_time = SystemTime::now()
@@ -436,10 +396,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
task_id,
stage_attempt_num,
part,
- Some(operator_metrics),
+ operator_metrics,
task_execution_times,
);
+ let scheduler_id = curator_task.scheduler_id;
let task_status_sender = self.executor_env.tx_task_status.clone();
task_status_sender
.send(CuratorTaskStatus {
@@ -448,7 +409,6 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
})
.await
.unwrap();
- Ok(())
}
// TODO populate with real metrics
@@ -505,18 +465,6 @@ struct TaskRunnerPool<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
executor_server: Arc<ExecutorServer<T, U>>,
}
-fn task_identity(task: &TaskDefinition) -> String {
- format!(
- "TID {} {}/{}.{}/{}.{}",
- &task.task_id,
- &task.job_id,
- &task.stage_id,
- &task.stage_attempt_num,
- &task.partition_id,
- &task.task_attempt_num,
- )
-}
-
impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T, U> {
fn new(executor_server: Arc<ExecutorServer<T, U>>) -> Self {
Self { executor_server }
@@ -638,64 +586,22 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T,
return;
}
};
- if let Some(task) = maybe_task {
+ if let Some(curator_task) = maybe_task {
+ let task_identity = format!(
+ "TID {} {}/{}.{}/{}.{}",
+ &curator_task.task.task_id,
+ &curator_task.task.job_id,
+ &curator_task.task.stage_id,
+ &curator_task.task.stage_attempt_num,
+ &curator_task.task.partition_id,
+ &curator_task.task.task_attempt_num,
+ );
+ info!("Received task {:?}", &task_identity);
+
let server = executor_server.clone();
- let plan = task.plan;
- let curator_task = task.tasks[0].clone();
- let out: tokio::sync::oneshot::Receiver<
- Result<Arc<dyn QueryStageExecutor>, BallistaError>,
- > = dedicated_executor.spawn(async move {
- server.decode_task(curator_task, &plan).await
+ dedicated_executor.spawn(async move {
+ server.run_task(task_identity.clone(), curator_task).await;
});
-
- let plan = out.await;
-
- let plan = match plan {
- Ok(Ok(plan)) => plan,
- Ok(Err(e)) => {
- error!(
- "Failed to decode the plan of task {:?} due to {:?}",
- task_identity(&task.tasks[0]),
- e
- );
- return;
- }
- Err(e) => {
- error!(
- "Failed to receive error plan of task {:?} due to {:?}",
- task_identity(&task.tasks[0]),
- e
- );
- return;
- }
- };
- let scheduler_id = task.scheduler_id.clone();
-
- for curator_task in task.tasks {
- let plan = plan.clone();
- let scheduler_id = scheduler_id.clone();
-
- let task_identity = task_identity(&curator_task);
- info!("Received task {:?}", &task_identity);
-
- let server = executor_server.clone();
- dedicated_executor.spawn(async move {
- server
- .run_task(
- &task_identity,
- scheduler_id,
- curator_task,
- plan,
- )
- .await
- .unwrap_or_else(|e| {
- error!(
- "Fail to run the task {:?} due to {:?}",
- task_identity, e
- );
- });
- });
- }
} else {
info!("Channel is closed and will exit the task receive loop");
drop(task_runner_complete);
@@ -720,15 +626,17 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
} = request.into_inner();
let task_sender = self.executor_env.tx_task.clone();
for task in tasks {
- let (task_def, plan) = task
- .try_into()
- .map_err(|e| Status::invalid_argument(format!("{e}")))?;
-
task_sender
.send(CuratorTaskDefinition {
scheduler_id: scheduler_id.clone(),
- plan,
- tasks: vec![task_def],
+ task: get_task_definition(
+ task,
+ self.executor.runtime.clone(),
+ self.executor.scalar_functions.clone(),
+ self.executor.aggregate_functions.clone(),
+ self.codec.clone(),
+ )
+ .map_err(|e| Status::invalid_argument(format!("{e}")))?,
})
.await
.unwrap();
@@ -748,17 +656,23 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
} = request.into_inner();
let task_sender = self.executor_env.tx_task.clone();
for multi_task in multi_tasks {
- let (multi_task, plan): (Vec<TaskDefinition>, Vec<u8>) = multi_task
- .try_into()
- .map_err(|e| Status::invalid_argument(format!("{e}")))?;
- task_sender
- .send(CuratorTaskDefinition {
- scheduler_id: scheduler_id.clone(),
- plan,
- tasks: multi_task,
- })
- .await
- .unwrap();
+ let multi_task: Vec<TaskDefinition> = get_task_definition_vec(
+ multi_task,
+ self.executor.runtime.clone(),
+ self.executor.scalar_functions.clone(),
+ self.executor.aggregate_functions.clone(),
+ self.codec.clone(),
+ )
+ .map_err(|e| Status::invalid_argument(format!("{e}")))?;
+ for task in multi_task {
+ task_sender
+ .send(CuratorTaskDefinition {
+ scheduler_id: scheduler_id.clone(),
+ task,
+ })
+ .await
+ .unwrap();
+ }
}
Ok(Response::new(LaunchMultiTaskResult { success: true }))
}
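Net effect on error handling: plan decoding now happens eagerly in the LaunchTask and LaunchMultiTask handlers, where a failure is reported straight back to the scheduler as a gRPC invalid_argument status instead of surfacing later inside the task loop, and run_task itself no longer returns a Result. The status mapping follows the usual tonic pattern, sketched here with a hypothetical decode helper standing in for get_task_definition:

    use tonic::Status;

    // Hypothetical stand-in for get_task_definition / get_task_definition_vec.
    fn decode_task(bytes: &[u8]) -> Result<String, String> {
        std::str::from_utf8(bytes)
            .map(String::from)
            .map_err(|e| format!("{e}"))
    }

    fn handle(bytes: &[u8]) -> Result<String, Status> {
        // Same shape as the handlers above: decode errors become gRPC statuses.
        decode_task(bytes).map_err(|e| Status::invalid_argument(format!("{e}")))
    }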