You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by nj...@apache.org on 2023/06/28 01:27:02 UTC

[arrow-ballista] branch main updated: Fix index out of bounds panic (#819)

This is an automated email from the ASF dual-hosted git repository.

nju_yaho pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 553b9a7d Fix index out of bounds panic (#819)
553b9a7d is described below

commit 553b9a7d74b60fb90b4af10c52d856ca207ad141
Author: yahoNanJing <90...@users.noreply.github.com>
AuthorDate: Wed Jun 28 09:26:57 2023 +0800

    Fix index out of bounds panic (#819)
    
    * Remove output_partitioning from TaskDescription
    
    * Remove output_partitioning from execution stages
    
    * Remove input_partition_count from UnresolvedShuffleExec
    
    * Add input stage id for ShuffleReaderExec
    
    * Correct the behavior of output_partitioning() for ShuffleWriterExec
    
    ---------
    
    Co-authored-by: yangzhong <ya...@ebay.com>
---
 ballista/core/proto/ballista.proto                 |  7 +-
 .../core/src/execution_plans/shuffle_reader.rs     | 18 +++--
 .../core/src/execution_plans/shuffle_writer.rs     | 18 +++--
 .../core/src/execution_plans/unresolved_shuffle.rs |  5 --
 ballista/core/src/serde/generated/ballista.rs      | 21 +-----
 ballista/core/src/serde/mod.rs                     |  8 +-
 ballista/scheduler/src/planner.rs                  | 17 +----
 ballista/scheduler/src/scheduler_server/mod.rs     |  5 +-
 ballista/scheduler/src/state/execution_graph.rs    | 23 ++++--
 .../src/state/execution_graph/execution_stage.rs   | 87 ++++------------------
 .../scheduler/src/state/execution_graph_dot.rs     |  2 +-
 ballista/scheduler/src/test_utils.rs               | 10 +--
 12 files changed, 70 insertions(+), 151 deletions(-)

diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto
index e596c1a7..74145fc2 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -49,13 +49,14 @@ message ShuffleWriterExecNode {
 message UnresolvedShuffleExecNode {
   uint32 stage_id = 1;
   datafusion.Schema schema = 2;
-  uint32 input_partition_count = 3;
   uint32 output_partition_count = 4;
 }
 
 message ShuffleReaderExecNode {
   repeated ShuffleReaderPartition partition = 1;
   datafusion.Schema schema = 2;
+  // The stage to read from
+  uint32 stage_id = 3;
 }
 
 message ShuffleReaderPartition {
@@ -98,7 +99,6 @@ message ExecutionGraphStage {
 
 message UnResolvedStage {
   uint32 stage_id = 1;
-  datafusion.PhysicalHashRepartition output_partitioning = 2;
   repeated uint32 output_links = 3;
   repeated  GraphStageInput inputs = 4;
   bytes plan = 5;
@@ -109,7 +109,6 @@ message UnResolvedStage {
 message ResolvedStage {
   uint32 stage_id = 1;
   uint32 partitions = 2;
-  datafusion.PhysicalHashRepartition output_partitioning = 3;
   repeated uint32 output_links = 4;
   repeated  GraphStageInput inputs = 5;
   bytes plan = 6;
@@ -120,7 +119,6 @@ message ResolvedStage {
 message SuccessfulStage {
   uint32 stage_id = 1;
   uint32 partitions = 2;
-  datafusion.PhysicalHashRepartition output_partitioning = 3;
   repeated uint32 output_links = 4;
   repeated  GraphStageInput inputs = 5;
   bytes plan = 6;
@@ -132,7 +130,6 @@ message SuccessfulStage {
 message FailedStage {
   uint32 stage_id = 1;
   uint32 partitions = 2;
-  datafusion.PhysicalHashRepartition output_partitioning = 3;
   repeated uint32 output_links = 4;
   bytes plan = 5;
   repeated TaskInfo task_infos = 6;
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs
index 3bab9266..5acde9ec 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -58,9 +58,11 @@ use tokio_stream::wrappers::ReceiverStream;
 /// being executed by an executor
 #[derive(Debug, Clone)]
 pub struct ShuffleReaderExec {
+    /// The query stage id to read from
+    pub stage_id: usize,
+    pub(crate) schema: SchemaRef,
     /// Each partition of a shuffle can read data from multiple locations
     pub partition: Vec<Vec<PartitionLocation>>,
-    pub(crate) schema: SchemaRef,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
 }
@@ -68,12 +70,14 @@ pub struct ShuffleReaderExec {
 impl ShuffleReaderExec {
     /// Create a new ShuffleReaderExec
     pub fn try_new(
+        stage_id: usize,
         partition: Vec<Vec<PartitionLocation>>,
         schema: SchemaRef,
     ) -> Result<Self> {
         Ok(Self {
-            partition,
+            stage_id,
             schema,
+            partition,
             metrics: ExecutionPlanMetricsSet::new(),
         })
     }
@@ -513,13 +517,14 @@ mod tests {
         ]);
 
         let job_id = "test_job_1";
+        let input_stage_id = 2;
         let mut partitions: Vec<PartitionLocation> = vec![];
         for partition_id in 0..4 {
             partitions.push(PartitionLocation {
                 map_partition_id: 0,
                 partition_id: PartitionId {
                     job_id: job_id.to_string(),
-                    stage_id: 2,
+                    stage_id: input_stage_id,
                     partition_id,
                 },
                 executor_meta: ExecutorMetadata {
@@ -534,8 +539,11 @@ mod tests {
             })
         }
 
-        let shuffle_reader_exec =
-            ShuffleReaderExec::try_new(vec![partitions], Arc::new(schema))?;
+        let shuffle_reader_exec = ShuffleReaderExec::try_new(
+            input_stage_id,
+            vec![partitions],
+            Arc::new(schema),
+        )?;
         let mut stream = shuffle_reader_exec.execute(0, task_ctx)?;
         let batches = utils::collect_stream(&mut stream).await;
 
diff --git a/ballista/core/src/execution_plans/shuffle_writer.rs b/ballista/core/src/execution_plans/shuffle_writer.rs
index 443f8517..742f0d4e 100644
--- a/ballista/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/core/src/execution_plans/shuffle_writer.rs
@@ -71,7 +71,8 @@ pub struct ShuffleWriterExec {
     plan: Arc<dyn ExecutionPlan>,
     /// Path to write output streams to
     work_dir: String,
-    /// Optional shuffle output partitioning
+    /// Optional shuffle output partitioning.
+    /// If it's none, it means there's no need to do repartitioning.
     shuffle_output_partitioning: Option<Partitioning>,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
@@ -134,6 +135,11 @@ impl ShuffleWriterExec {
         self.stage_id
     }
 
+    /// Get the input partition count
+    pub fn input_partition_count(&self) -> usize {
+        self.plan.output_partitioning().partition_count()
+    }
+
     /// Get the true output partitioning
     pub fn shuffle_output_partitioning(&self) -> Option<&Partitioning> {
         self.shuffle_output_partitioning.as_ref()
@@ -297,12 +303,12 @@ impl ExecutionPlan for ShuffleWriterExec {
         self.plan.schema()
     }
 
+    /// If [`shuffle_output_partitioning`] is none, then there's no need to do repartitioning.
+    /// Therefore, the partition is the same as its input plan's.
     fn output_partitioning(&self) -> Partitioning {
-        // This operator needs to be executed once for each *input* partition and there
-        // isn't really a mechanism yet in DataFusion to support this use case so we report
-        // the input partitioning as the output partitioning here. The executor reports
-        // output partition meta data back to the scheduler.
-        self.plan.output_partitioning()
+        self.shuffle_output_partitioning
+            .clone()
+            .unwrap_or_else(|| self.plan.output_partitioning())
     }
 
     fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
diff --git a/ballista/core/src/execution_plans/unresolved_shuffle.rs b/ballista/core/src/execution_plans/unresolved_shuffle.rs
index fe36134d..7c799741 100644
--- a/ballista/core/src/execution_plans/unresolved_shuffle.rs
+++ b/ballista/core/src/execution_plans/unresolved_shuffle.rs
@@ -38,9 +38,6 @@ pub struct UnresolvedShuffleExec {
     // The schema this node will have once it is replaced with a ShuffleReaderExec
     pub schema: SchemaRef,
 
-    // The number of shuffle writer partition tasks that will produce the partitions
-    pub input_partition_count: usize,
-
     // The partition count this node will have once it is replaced with a ShuffleReaderExec
     pub output_partition_count: usize,
 }
@@ -50,13 +47,11 @@ impl UnresolvedShuffleExec {
     pub fn new(
         stage_id: usize,
         schema: SchemaRef,
-        input_partition_count: usize,
         output_partition_count: usize,
     ) -> Self {
         Self {
             stage_id,
             schema,
-            input_partition_count,
             output_partition_count,
         }
     }
diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs
index d0c29e0b..3a98a28b 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -45,8 +45,6 @@ pub struct UnresolvedShuffleExecNode {
     pub stage_id: u32,
     #[prost(message, optional, tag = "2")]
     pub schema: ::core::option::Option<::datafusion_proto::protobuf::Schema>,
-    #[prost(uint32, tag = "3")]
-    pub input_partition_count: u32,
     #[prost(uint32, tag = "4")]
     pub output_partition_count: u32,
 }
@@ -57,6 +55,9 @@ pub struct ShuffleReaderExecNode {
     pub partition: ::prost::alloc::vec::Vec<ShuffleReaderPartition>,
     #[prost(message, optional, tag = "2")]
     pub schema: ::core::option::Option<::datafusion_proto::protobuf::Schema>,
+    /// The stage to read from
+    #[prost(uint32, tag = "3")]
+    pub stage_id: u32,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -132,10 +133,6 @@ pub mod execution_graph_stage {
 pub struct UnResolvedStage {
     #[prost(uint32, tag = "1")]
     pub stage_id: u32,
-    #[prost(message, optional, tag = "2")]
-    pub output_partitioning: ::core::option::Option<
-        ::datafusion_proto::protobuf::PhysicalHashRepartition,
-    >,
     #[prost(uint32, repeated, tag = "3")]
     pub output_links: ::prost::alloc::vec::Vec<u32>,
     #[prost(message, repeated, tag = "4")]
@@ -156,10 +153,6 @@ pub struct ResolvedStage {
     pub stage_id: u32,
     #[prost(uint32, tag = "2")]
     pub partitions: u32,
-    #[prost(message, optional, tag = "3")]
-    pub output_partitioning: ::core::option::Option<
-        ::datafusion_proto::protobuf::PhysicalHashRepartition,
-    >,
     #[prost(uint32, repeated, tag = "4")]
     pub output_links: ::prost::alloc::vec::Vec<u32>,
     #[prost(message, repeated, tag = "5")]
@@ -180,10 +173,6 @@ pub struct SuccessfulStage {
     pub stage_id: u32,
     #[prost(uint32, tag = "2")]
     pub partitions: u32,
-    #[prost(message, optional, tag = "3")]
-    pub output_partitioning: ::core::option::Option<
-        ::datafusion_proto::protobuf::PhysicalHashRepartition,
-    >,
     #[prost(uint32, repeated, tag = "4")]
     pub output_links: ::prost::alloc::vec::Vec<u32>,
     #[prost(message, repeated, tag = "5")]
@@ -204,10 +193,6 @@ pub struct FailedStage {
     pub stage_id: u32,
     #[prost(uint32, tag = "2")]
     pub partitions: u32,
-    #[prost(message, optional, tag = "3")]
-    pub output_partitioning: ::core::option::Option<
-        ::datafusion_proto::protobuf::PhysicalHashRepartition,
-    >,
     #[prost(uint32, repeated, tag = "4")]
     pub output_links: ::prost::alloc::vec::Vec<u32>,
     #[prost(bytes = "vec", tag = "5")]
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index dd4dc162..7f74e5ae 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -157,6 +157,7 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
                 )?))
             }
             PhysicalPlanType::ShuffleReader(shuffle_reader) => {
+                let stage_id = shuffle_reader.stage_id as usize;
                 let schema = Arc::new(convert_required!(shuffle_reader.schema)?);
                 let partition_location: Vec<Vec<PartitionLocation>> = shuffle_reader
                     .partition
@@ -175,7 +176,7 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
                     })
                     .collect::<Result<Vec<_>, DataFusionError>>()?;
                 let shuffle_reader =
-                    ShuffleReaderExec::try_new(partition_location, schema)?;
+                    ShuffleReaderExec::try_new(stage_id, partition_location, schema)?;
                 Ok(Arc::new(shuffle_reader))
             }
             PhysicalPlanType::UnresolvedShuffle(unresolved_shuffle) => {
@@ -183,8 +184,6 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
                 Ok(Arc::new(UnresolvedShuffleExec {
                     stage_id: unresolved_shuffle.stage_id as usize,
                     schema,
-                    input_partition_count: unresolved_shuffle.input_partition_count
-                        as usize,
                     output_partition_count: unresolved_shuffle.output_partition_count
                         as usize,
                 }))
@@ -237,6 +236,7 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
 
             Ok(())
         } else if let Some(exec) = node.as_any().downcast_ref::<ShuffleReaderExec>() {
+            let stage_id = exec.stage_id as u32;
             let mut partition = vec![];
             for location in &exec.partition {
                 partition.push(protobuf::ShuffleReaderPartition {
@@ -255,6 +255,7 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
             let proto = protobuf::BallistaPhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::ShuffleReader(
                     protobuf::ShuffleReaderExecNode {
+                        stage_id,
                         partition,
                         schema: Some(exec.schema().as_ref().try_into()?),
                     },
@@ -273,7 +274,6 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
                     protobuf::UnresolvedShuffleExecNode {
                         stage_id: exec.stage_id as u32,
                         schema: Some(exec.schema().as_ref().try_into()?),
-                        input_partition_count: exec.input_partition_count as u32,
                         output_partition_count: exec.output_partition_count as u32,
                     },
                 )),
diff --git a/ballista/scheduler/src/planner.rs b/ballista/scheduler/src/planner.rs
index 87c523cb..763a7ef2 100644
--- a/ballista/scheduler/src/planner.rs
+++ b/ballista/scheduler/src/planner.rs
@@ -178,10 +178,6 @@ fn create_unresolved_shuffle(
         shuffle_writer.stage_id(),
         shuffle_writer.schema(),
         shuffle_writer.output_partitioning().partition_count(),
-        shuffle_writer
-            .shuffle_output_partitioning()
-            .map(|p| p.partition_count())
-            .unwrap_or_else(|| shuffle_writer.output_partitioning().partition_count()),
     ))
 }
 
@@ -246,6 +242,7 @@ pub fn remove_unresolved_shuffles(
                     .join("\n")
             );
             new_children.push(Arc::new(ShuffleReaderExec::try_new(
+                unresolved_shuffle.stage_id,
                 relevant_locations,
                 unresolved_shuffle.schema().clone(),
             )?))
@@ -265,15 +262,13 @@ pub fn rollback_resolved_shuffles(
     let mut new_children: Vec<Arc<dyn ExecutionPlan>> = vec![];
     for child in stage.children() {
         if let Some(shuffle_reader) = child.as_any().downcast_ref::<ShuffleReaderExec>() {
-            let partition_locations = &shuffle_reader.partition;
-            let output_partition_count = partition_locations.len();
-            let input_partition_count = partition_locations[0].len();
-            let stage_id = partition_locations[0][0].partition_id.stage_id;
+            let output_partition_count =
+                shuffle_reader.output_partitioning().partition_count();
+            let stage_id = shuffle_reader.stage_id;
 
             let unresolved_shuffle = Arc::new(UnresolvedShuffleExec::new(
                 stage_id,
                 shuffle_reader.schema(),
-                input_partition_count,
                 output_partition_count,
             ));
             new_children.push(unresolved_shuffle);
@@ -392,7 +387,6 @@ mod test {
         let unresolved_shuffle =
             downcast_exec!(unresolved_shuffle, UnresolvedShuffleExec);
         assert_eq!(unresolved_shuffle.stage_id, 1);
-        assert_eq!(unresolved_shuffle.input_partition_count, 2);
         assert_eq!(unresolved_shuffle.output_partition_count, 2);
 
         // verify stage 2
@@ -402,7 +396,6 @@ mod test {
         let unresolved_shuffle =
             downcast_exec!(unresolved_shuffle, UnresolvedShuffleExec);
         assert_eq!(unresolved_shuffle.stage_id, 2);
-        assert_eq!(unresolved_shuffle.input_partition_count, 2);
         assert_eq!(unresolved_shuffle.output_partition_count, 2);
 
         Ok(())
@@ -555,7 +548,6 @@ order by
         let join_input_1 = join_input_1.children()[0].clone();
         let unresolved_shuffle_reader_1 =
             downcast_exec!(join_input_1, UnresolvedShuffleExec);
-        assert_eq!(unresolved_shuffle_reader_1.input_partition_count, 2); // lineitem
         assert_eq!(unresolved_shuffle_reader_1.output_partition_count, 2);
 
         let join_input_2 = join.children()[1].clone();
@@ -563,7 +555,6 @@ order by
         let join_input_2 = join_input_2.children()[0].clone();
         let unresolved_shuffle_reader_2 =
             downcast_exec!(join_input_2, UnresolvedShuffleExec);
-        assert_eq!(unresolved_shuffle_reader_2.input_partition_count, 1); // orders
         assert_eq!(unresolved_shuffle_reader_2.output_partition_count, 2);
 
         // final partitioned hash aggregate
diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs
index 86893d19..74bf3359 100644
--- a/ballista/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/scheduler/src/scheduler_server/mod.rs
@@ -460,10 +460,7 @@ mod test {
             if let Some(task) = task {
                 let mut partitions: Vec<ShuffleWritePartition> = vec![];
 
-                let num_partitions = task
-                    .output_partitioning
-                    .map(|p| p.partition_count())
-                    .unwrap_or(1);
+                let num_partitions = task.get_output_partition_number();
 
                 for partition_id in 0..num_partitions {
                     partitions.push(ShuffleWritePartition {
diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs
index 16fbffbf..c00363d7 100644
--- a/ballista/scheduler/src/state/execution_graph.rs
+++ b/ballista/scheduler/src/state/execution_graph.rs
@@ -23,9 +23,7 @@ use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
 
 use datafusion::physical_plan::display::DisplayableExecutionPlan;
-use datafusion::physical_plan::{
-    accept, ExecutionPlan, ExecutionPlanVisitor, Partitioning,
-};
+use datafusion::physical_plan::{accept, ExecutionPlan, ExecutionPlanVisitor};
 use datafusion::prelude::SessionContext;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use log::{error, info, warn};
@@ -913,7 +911,6 @@ impl ExecutionGraph {
                     task_id,
                     task_attempt,
                     plan: stage.plan.clone(),
-                    output_partitioning: stage.output_partitioning.clone(),
                 })
             } else {
                 Err(BallistaError::General(format!("Stage {stage_id} is not a running stage")))
@@ -1468,7 +1465,6 @@ impl ExecutionStageBuilder {
 
         // Now, create the execution stages
         for stage in stages {
-            let partitioning = stage.shuffle_output_partitioning().cloned();
             let stage_id = stage.stage_id();
             let output_links = self.output_links.remove(&stage_id).unwrap_or_default();
 
@@ -1482,7 +1478,6 @@ impl ExecutionStageBuilder {
                     stage_id,
                     0,
                     stage,
-                    partitioning,
                     output_links,
                     HashMap::new(),
                     HashSet::new(),
@@ -1491,7 +1486,6 @@ impl ExecutionStageBuilder {
                 ExecutionStage::UnResolved(UnresolvedStage::new(
                     stage_id,
                     stage,
-                    partitioning,
                     output_links,
                     child_stages,
                 ))
@@ -1549,7 +1543,6 @@ pub struct TaskDescription {
     pub task_id: usize,
     pub task_attempt: usize,
     pub plan: Arc<dyn ExecutionPlan>,
-    pub output_partitioning: Option<Partitioning>,
 }
 
 impl Debug for TaskDescription {
@@ -1570,6 +1563,20 @@ impl Debug for TaskDescription {
     }
 }
 
+impl TaskDescription {
+    pub fn get_output_partition_number(&self) -> usize {
+        let shuffle_writer = self
+            .plan
+            .as_any()
+            .downcast_ref::<ShuffleWriterExec>()
+            .unwrap();
+        shuffle_writer
+            .shuffle_output_partitioning()
+            .map(|partitioning| partitioning.partition_count())
+            .unwrap_or_else(|| 1)
+    }
+}
+
 fn partition_to_location(
     job_id: &str,
     map_partition_id: usize,
diff --git a/ballista/scheduler/src/state/execution_graph/execution_stage.rs b/ballista/scheduler/src/state/execution_graph/execution_stage.rs
index 3cc3eb14..a4cdf238 100644
--- a/ballista/scheduler/src/state/execution_graph/execution_stage.rs
+++ b/ballista/scheduler/src/state/execution_graph/execution_stage.rs
@@ -26,22 +26,21 @@ use datafusion::physical_optimizer::join_selection::JoinSelection;
 use datafusion::physical_optimizer::PhysicalOptimizerRule;
 use datafusion::physical_plan::display::DisplayableExecutionPlan;
 use datafusion::physical_plan::metrics::{MetricValue, MetricsSet};
-use datafusion::physical_plan::{ExecutionPlan, Metric, Partitioning};
+use datafusion::physical_plan::{ExecutionPlan, Metric};
 use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use log::{debug, warn};
 
 use ballista_core::error::{BallistaError, Result};
+use ballista_core::execution_plans::ShuffleWriterExec;
 use ballista_core::serde::protobuf::failed_task::FailedReason;
 use ballista_core::serde::protobuf::{
     self, task_info, FailedTask, GraphStageInput, OperatorMetricsSet, ResultLost,
     SuccessfulTask, TaskStatus,
 };
 use ballista_core::serde::protobuf::{task_status, RunningTask};
-use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto;
 use ballista_core::serde::scheduler::PartitionLocation;
 use ballista_core::serde::BallistaCodec;
-use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning;
 use datafusion_proto::physical_plan::AsExecutionPlan;
 
 use crate::display::DisplayableBallistaExecutionPlan;
@@ -107,8 +106,6 @@ pub(crate) struct UnresolvedStage {
     pub(crate) stage_id: usize,
     /// Stage Attempt number
     pub(crate) stage_attempt_num: usize,
-    /// Output partitioning for this stage.
-    pub(crate) output_partitioning: Option<Partitioning>,
     /// Stage ID of the stage that will take this stages outputs as inputs.
     /// If `output_links` is empty then this the final stage in the `ExecutionGraph`
     pub(crate) output_links: Vec<usize>,
@@ -132,8 +129,6 @@ pub(crate) struct ResolvedStage {
     /// Total number of partitions for this stage.
     /// This stage will produce on task for partition.
     pub(crate) partitions: usize,
-    /// Output partitioning for this stage.
-    pub(crate) output_partitioning: Option<Partitioning>,
     /// Stage ID of the stage that will take this stages outputs as inputs.
     /// If `output_links` is empty then this the final stage in the `ExecutionGraph`
     pub(crate) output_links: Vec<usize>,
@@ -159,8 +154,6 @@ pub(crate) struct RunningStage {
     /// Total number of partitions for this stage.
     /// This stage will produce on task for partition.
     pub(crate) partitions: usize,
-    /// Output partitioning for this stage.
-    pub(crate) output_partitioning: Option<Partitioning>,
     /// Stage ID of the stage that will take this stages outputs as inputs.
     /// If `output_links` is empty then this the final stage in the `ExecutionGraph`
     pub(crate) output_links: Vec<usize>,
@@ -188,8 +181,6 @@ pub(crate) struct SuccessfulStage {
     /// Total number of partitions for this stage.
     /// This stage will produce on task for partition.
     pub(crate) partitions: usize,
-    /// Output partitioning for this stage.
-    pub(crate) output_partitioning: Option<Partitioning>,
     /// Stage ID of the stage that will take this stages outputs as inputs.
     /// If `output_links` is empty then this the final stage in the `ExecutionGraph`
     pub(crate) output_links: Vec<usize>,
@@ -214,8 +205,6 @@ pub(crate) struct FailedStage {
     /// Total number of partitions for this stage.
     /// This stage will produce on task for partition.
     pub(crate) partitions: usize,
-    /// Output partitioning for this stage.
-    pub(crate) output_partitioning: Option<Partitioning>,
     /// Stage ID of the stage that will take this stages outputs as inputs.
     /// If `output_links` is empty then this the final stage in the `ExecutionGraph`
     pub(crate) output_links: Vec<usize>,
@@ -252,7 +241,6 @@ impl UnresolvedStage {
     pub(super) fn new(
         stage_id: usize,
         plan: Arc<dyn ExecutionPlan>,
-        output_partitioning: Option<Partitioning>,
         output_links: Vec<usize>,
         child_stage_ids: Vec<usize>,
     ) -> Self {
@@ -264,7 +252,6 @@ impl UnresolvedStage {
         Self {
             stage_id,
             stage_attempt_num: 0,
-            output_partitioning,
             output_links,
             inputs,
             plan,
@@ -276,7 +263,6 @@ impl UnresolvedStage {
         stage_id: usize,
         stage_attempt_num: usize,
         plan: Arc<dyn ExecutionPlan>,
-        output_partitioning: Option<Partitioning>,
         output_links: Vec<usize>,
         inputs: HashMap<usize, StageOutput>,
         last_attempt_failure_reasons: HashSet<String>,
@@ -284,7 +270,6 @@ impl UnresolvedStage {
         Self {
             stage_id,
             stage_attempt_num,
-            output_partitioning,
             output_links,
             inputs,
             plan,
@@ -371,7 +356,6 @@ impl UnresolvedStage {
             self.stage_id,
             self.stage_attempt_num,
             plan,
-            self.output_partitioning.clone(),
             self.output_links.clone(),
             self.inputs.clone(),
             self.last_attempt_failure_reasons.clone(),
@@ -390,18 +374,11 @@ impl UnresolvedStage {
             codec.physical_extension_codec(),
         )?;
 
-        let output_partitioning: Option<Partitioning> = parse_protobuf_hash_partitioning(
-            stage.output_partitioning.as_ref(),
-            session_ctx,
-            plan.schema().as_ref(),
-        )?;
-
         let inputs = decode_inputs(stage.inputs)?;
 
         Ok(UnresolvedStage {
             stage_id: stage.stage_id as usize,
             stage_attempt_num: stage.stage_attempt_num as usize,
-            output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as usize).collect(),
             plan,
             inputs,
@@ -421,13 +398,9 @@ impl UnresolvedStage {
 
         let inputs = encode_inputs(stage.inputs)?;
 
-        let output_partitioning =
-            hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
-
         Ok(protobuf::UnResolvedStage {
             stage_id: stage.stage_id as u32,
             stage_attempt_num: stage.stage_attempt_num as u32,
-            output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as u32).collect(),
             inputs,
             plan,
@@ -459,18 +432,16 @@ impl ResolvedStage {
         stage_id: usize,
         stage_attempt_num: usize,
         plan: Arc<dyn ExecutionPlan>,
-        output_partitioning: Option<Partitioning>,
         output_links: Vec<usize>,
         inputs: HashMap<usize, StageOutput>,
         last_attempt_failure_reasons: HashSet<String>,
     ) -> Self {
-        let partitions = plan.output_partitioning().partition_count();
+        let partitions = get_stage_partitions(plan.clone());
 
         Self {
             stage_id,
             stage_attempt_num,
             partitions,
-            output_partitioning,
             output_links,
             inputs,
             plan,
@@ -485,7 +456,6 @@ impl ResolvedStage {
             self.stage_attempt_num,
             self.plan.clone(),
             self.partitions,
-            self.output_partitioning.clone(),
             self.output_links.clone(),
             self.inputs.clone(),
         )
@@ -499,7 +469,6 @@ impl ResolvedStage {
             self.stage_id,
             self.stage_attempt_num,
             new_plan,
-            self.output_partitioning.clone(),
             self.output_links.clone(),
             self.inputs.clone(),
             self.last_attempt_failure_reasons.clone(),
@@ -519,19 +488,12 @@ impl ResolvedStage {
             codec.physical_extension_codec(),
         )?;
 
-        let output_partitioning: Option<Partitioning> = parse_protobuf_hash_partitioning(
-            stage.output_partitioning.as_ref(),
-            session_ctx,
-            plan.schema().as_ref(),
-        )?;
-
         let inputs = decode_inputs(stage.inputs)?;
 
         Ok(ResolvedStage {
             stage_id: stage.stage_id as usize,
             stage_attempt_num: stage.stage_attempt_num as usize,
             partitions: stage.partitions as usize,
-            output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as usize).collect(),
             inputs,
             plan,
@@ -549,16 +511,12 @@ impl ResolvedStage {
         U::try_from_physical_plan(stage.plan, codec.physical_extension_codec())
             .and_then(|proto| proto.try_encode(&mut plan))?;
 
-        let output_partitioning =
-            hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
-
         let inputs = encode_inputs(stage.inputs)?;
 
         Ok(protobuf::ResolvedStage {
             stage_id: stage.stage_id as u32,
             stage_attempt_num: stage.stage_attempt_num as u32,
             partitions: stage.partitions as u32,
-            output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as u32).collect(),
             inputs,
             plan,
@@ -587,7 +545,6 @@ impl RunningStage {
         stage_attempt_num: usize,
         plan: Arc<dyn ExecutionPlan>,
         partitions: usize,
-        output_partitioning: Option<Partitioning>,
         output_links: Vec<usize>,
         inputs: HashMap<usize, StageOutput>,
     ) -> Self {
@@ -595,7 +552,6 @@ impl RunningStage {
             stage_id,
             stage_attempt_num,
             partitions,
-            output_partitioning,
             output_links,
             inputs,
             plan,
@@ -627,7 +583,6 @@ impl RunningStage {
             stage_id: self.stage_id,
             stage_attempt_num: self.stage_attempt_num,
             partitions: self.partitions,
-            output_partitioning: self.output_partitioning.clone(),
             output_links: self.output_links.clone(),
             inputs: self.inputs.clone(),
             plan: self.plan.clone(),
@@ -641,7 +596,6 @@ impl RunningStage {
             stage_id: self.stage_id,
             stage_attempt_num: self.stage_attempt_num,
             partitions: self.partitions,
-            output_partitioning: self.output_partitioning.clone(),
             output_links: self.output_links.clone(),
             plan: self.plan.clone(),
             task_infos: self.task_infos.clone(),
@@ -656,7 +610,6 @@ impl RunningStage {
             self.stage_id,
             self.stage_attempt_num + 1,
             self.plan.clone(),
-            self.output_partitioning.clone(),
             self.output_links.clone(),
             self.inputs.clone(),
             HashSet::new(),
@@ -674,7 +627,6 @@ impl RunningStage {
             self.stage_id,
             self.stage_attempt_num + 1,
             new_plan,
-            self.output_partitioning.clone(),
             self.output_links.clone(),
             self.inputs.clone(),
             failure_reasons,
@@ -946,7 +898,6 @@ impl SuccessfulStage {
             stage_id: self.stage_id,
             stage_attempt_num: self.stage_attempt_num + 1,
             partitions: self.partitions,
-            output_partitioning: self.output_partitioning.clone(),
             output_links: self.output_links.clone(),
             inputs: self.inputs.clone(),
             plan: self.plan.clone(),
@@ -1007,12 +958,6 @@ impl SuccessfulStage {
             codec.physical_extension_codec(),
         )?;
 
-        let output_partitioning: Option<Partitioning> = parse_protobuf_hash_partitioning(
-            stage.output_partitioning.as_ref(),
-            session_ctx,
-            plan.schema().as_ref(),
-        )?;
-
         let inputs = decode_inputs(stage.inputs)?;
         assert_eq!(
             stage.task_infos.len(),
@@ -1030,7 +975,6 @@ impl SuccessfulStage {
             stage_id: stage.stage_id as usize,
             stage_attempt_num: stage.stage_attempt_num as usize,
             partitions: stage.partitions as usize,
-            output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as usize).collect(),
             inputs,
             plan,
@@ -1050,9 +994,6 @@ impl SuccessfulStage {
         U::try_from_physical_plan(stage.plan, codec.physical_extension_codec())
             .and_then(|proto| proto.try_encode(&mut plan))?;
 
-        let output_partitioning =
-            hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
-
         let inputs = encode_inputs(stage.inputs)?;
         let task_infos = stage
             .task_infos
@@ -1071,7 +1012,6 @@ impl SuccessfulStage {
             stage_id: stage_id as u32,
             stage_attempt_num: stage.stage_attempt_num as u32,
             partitions: stage.partitions as u32,
-            output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as u32).collect(),
             inputs,
             plan,
@@ -1137,12 +1077,6 @@ impl FailedStage {
             codec.physical_extension_codec(),
         )?;
 
-        let output_partitioning: Option<Partitioning> = parse_protobuf_hash_partitioning(
-            stage.output_partitioning.as_ref(),
-            session_ctx,
-            plan.schema().as_ref(),
-        )?;
-
         let mut task_infos: Vec<Option<TaskInfo>> = vec![None; stage.partitions as usize];
         for info in stage.task_infos {
             task_infos[info.partition_id as usize] = Some(decode_taskinfo(info.clone()));
@@ -1163,7 +1097,6 @@ impl FailedStage {
             stage_id: stage.stage_id as usize,
             stage_attempt_num: stage.stage_attempt_num as usize,
             partitions: stage.partitions as usize,
-            output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as usize).collect(),
             plan,
             task_infos,
@@ -1183,9 +1116,6 @@ impl FailedStage {
         U::try_from_physical_plan(stage.plan, codec.physical_extension_codec())
             .and_then(|proto| proto.try_encode(&mut plan))?;
 
-        let output_partitioning =
-            hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
-
         let task_infos: Vec<protobuf::TaskInfo> = stage
             .task_infos
             .into_iter()
@@ -1206,7 +1136,6 @@ impl FailedStage {
             stage_id: stage_id as u32,
             stage_attempt_num: stage.stage_attempt_num as u32,
             partitions: stage.partitions as u32,
-            output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as u32).collect(),
             plan,
             task_infos,
@@ -1235,6 +1164,16 @@ impl Debug for FailedStage {
     }
 }
 
+/// Returns the number of partitions for a stage that runs `plan`.
+/// If `plan` is a [`ShuffleWriterExec`], its input partition count is used (it can differ
+/// from the output partition count); otherwise the plan's output partition count is used.
+fn get_stage_partitions(plan: Arc<dyn ExecutionPlan>) -> usize {
+    plan.as_any()
+        .downcast_ref::<ShuffleWriterExec>()
+        .map(|shuffle_writer| shuffle_writer.input_partition_count())
+        .unwrap_or_else(|| plan.output_partitioning().partition_count())
+}
+
 /// This data structure collects the partition locations for an `ExecutionStage`.
 /// Each `ExecutionStage` will hold a `StageOutput`s for each of its child stages.
 /// When all tasks for the child stage are complete, it will mark the `StageOutput`
diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs b/ballista/scheduler/src/state/execution_graph_dot.rs
index d9df9488..254b6072 100644
--- a/ballista/scheduler/src/state/execution_graph_dot.rs
+++ b/ballista/scheduler/src/state/execution_graph_dot.rs
@@ -318,7 +318,7 @@ filter_expr={}",
     } else if let Some(exec) = plan.as_any().downcast_ref::<ShuffleWriterExec>() {
         format!(
             "ShuffleWriter [{} partitions]",
-            exec.output_partitioning().partition_count()
+            exec.input_partition_count()
         )
     } else if plan.as_any().downcast_ref::<MemoryExec>().is_some() {
         "MemoryExec".to_string()
diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs
index 06798903..1821c729 100644
--- a/ballista/scheduler/src/test_utils.rs
+++ b/ballista/scheduler/src/test_utils.rs
@@ -994,10 +994,7 @@ pub fn mock_executor(executor_id: String) -> ExecutorMetadata {
 pub fn mock_completed_task(task: TaskDescription, executor_id: &str) -> TaskStatus {
     let mut partitions: Vec<protobuf::ShuffleWritePartition> = vec![];
 
-    let num_partitions = task
-        .output_partitioning
-        .map(|p| p.partition_count())
-        .unwrap_or(1);
+    let num_partitions = task.get_output_partition_number();
 
     for partition_id in 0..num_partitions {
         partitions.push(protobuf::ShuffleWritePartition {
@@ -1035,10 +1032,7 @@ pub fn mock_completed_task(task: TaskDescription, executor_id: &str) -> TaskStat
 pub fn mock_failed_task(task: TaskDescription, failed_task: FailedTask) -> TaskStatus {
     let mut partitions: Vec<protobuf::ShuffleWritePartition> = vec![];
 
-    let num_partitions = task
-        .output_partitioning
-        .map(|p| p.partition_count())
-        .unwrap_or(1);
+    let num_partitions = task.get_output_partition_number();
 
     for partition_id in 0..num_partitions {
         partitions.push(protobuf::ShuffleWritePartition {