You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original archive page and is not preserved in this plain-text rendering.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/07/16 15:47:46 UTC

[arrow-ballista] branch master updated: Support for multi-scheduler deployments (#59)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/master by this push:
     new a2c794e7 Support for multi-scheduler deployments (#59)
a2c794e7 is described below

commit a2c794e72f0b6e21b05bb8c05b38f68bba206a44
Author: Dan Harris <13...@users.noreply.github.com>
AuthorDate: Sat Jul 16 11:47:40 2022 -0400

    Support for multi-scheduler deployments (#59)
    
    * Initial design and implementation
    
    * ExecutorManager tests
    
    * Only consider alive executors
    
    * Use correct session ID provided in request
    
    * Fix bug in etcd key scan
    
    * Debugging
    
    * Drop for EtcdLock
    
    * Better locking
    
    * Debug for ExecutionGraph
    
    * Fix partition accounting in ExecutionGraph
    
    * Fix input partition accounting
    
    * Handle stages with multiple inputs better
    
    * Simplify output buffer
    
    * Cleanup
    
    * Cleanup
    
    * Linting
    
    * Linting and docs
    
    * Job queueing and general cleanup
    
    * Handle job queueing and failure
    
    * Tests
    
    * Fix doc comments
    
    * Tests
    
    * Add license header
    
    * Fix graph complete logic
    
    * Fix bug in partition mapping
    
    * Eagerly offer pending tasks
    
    * Tests for event loop
    
    * Merge upstream
    
    * Fix compiler error after rebase
    
    * Clippy fix
    
    * Merge pull request #4 from coralogix/scheduler-fix
    
    Scheduler fixes
    
    * Use correct bind address for executor registration
    
    * Use correct keyspace when initializing heartbeats
    
    * Fix after cherry-pick bugfixes
    
    * Fix conflicts after merge
    
    Co-authored-by: Martins Purins <ma...@coralogix.com>
---
 ballista/rust/core/proto/ballista.proto            |  35 +
 ballista/rust/core/src/config.rs                   |   2 +-
 ballista/rust/executor/src/execution_loop.rs       |  16 +-
 ballista/rust/executor/src/executor_server.rs      |   2 +-
 ballista/rust/scheduler/src/api/handlers.rs        |   3 +-
 ballista/rust/scheduler/src/main.rs                |   2 +-
 ballista/rust/scheduler/src/planner.rs             | 188 ++--
 .../rust/scheduler/src/scheduler_server/event.rs   |  25 +-
 .../scheduler/src/scheduler_server/event_loop.rs   | 453 +++++++---
 .../src/scheduler_server/external_scaler.rs        |   6 +-
 .../rust/scheduler/src/scheduler_server/grpc.rs    | 497 +++++------
 .../rust/scheduler/src/scheduler_server/mod.rs     | 910 +++++++++++--------
 .../src/scheduler_server/query_stage_scheduler.rs  | 484 ++--------
 ballista/rust/scheduler/src/state/backend/etcd.rs  | 183 +++-
 ballista/rust/scheduler/src/state/backend/mod.rs   |  60 +-
 .../rust/scheduler/src/state/backend/standalone.rs | 193 +++-
 .../rust/scheduler/src/state/execution_graph.rs    | 974 +++++++++++++++++++++
 .../rust/scheduler/src/state/executor_manager.rs   | 609 +++++++++++--
 ballista/rust/scheduler/src/state/mod.rs           | 283 ++----
 .../rust/scheduler/src/state/persistent_state.rs   | 525 -----------
 .../rust/scheduler/src/state/session_manager.rs    | 144 +++
 .../rust/scheduler/src/state/session_registry.rs   |  68 ++
 ballista/rust/scheduler/src/state/stage_manager.rs | 783 -----------------
 ballista/rust/scheduler/src/state/task_manager.rs  | 751 ++++++++++++++++
 .../rust/scheduler/src/state/task_scheduler.rs     | 212 -----
 ballista/rust/scheduler/src/test_utils.rs          | 111 ++-
 benchmarks/docker-compose.yaml                     |   7 +-
 27 files changed, 4439 insertions(+), 3087 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index f899a603..4e2c55f6 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -408,6 +408,37 @@ enum JoinSide{
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 // Ballista Scheduling
 ///////////////////////////////////////////////////////////////////////////////////////////////////
+message TaskInputPartitions {
+  uint32 partition = 1;
+  repeated PartitionLocation partition_location = 2;
+}
+
+message GraphStageInput {
+  uint32 stage_id = 1;
+  repeated TaskInputPartitions partition_locations = 2;
+  bool complete = 3;
+}
+
+
+message ExecutionGraphStage {
+  uint64 stage_id = 1;
+  uint32 partitions = 2;
+  PhysicalHashRepartition output_partitioning = 3;
+  repeated  GraphStageInput inputs = 4;
+  bytes plan = 5;
+  repeated TaskStatus task_statuses = 6;
+  uint32 output_link = 7;
+  bool resolved = 8;
+}
+
+message ExecutionGraph {
+  string job_id = 1;
+  string session_id = 2;
+  JobStatus status = 3;
+  repeated ExecutionGraphStage stages = 4;
+  uint64 output_partitions = 5;
+  repeated PartitionLocation output_locations = 6;
+}
 
 message KeyValuePair {
   string key = 1;
@@ -581,6 +612,10 @@ message TaskDefinition {
   repeated KeyValuePair props = 5;
 }
 
+message SessionSettings {
+  repeated KeyValuePair configs = 1;
+}
+
 message JobSessionConfig {
   string session_id = 1;
   repeated KeyValuePair configs = 2;
diff --git a/ballista/rust/core/src/config.rs b/ballista/rust/core/src/config.rs
index b423fe18..8975e8e3 100644
--- a/ballista/rust/core/src/config.rs
+++ b/ballista/rust/core/src/config.rs
@@ -92,7 +92,7 @@ impl BallistaConfigBuilder {
 }
 
 /// Ballista configuration
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq)]
 pub struct BallistaConfig {
     /// Settings stored in map for easy serde
     settings: HashMap<String, String>,
diff --git a/ballista/rust/executor/src/execution_loop.rs b/ballista/rust/executor/src/execution_loop.rs
index 0520ab65..c9408124 100644
--- a/ballista/rust/executor/src/execution_loop.rs
+++ b/ballista/rust/executor/src/execution_loop.rs
@@ -15,21 +15,23 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::as_task_status;
-use crate::executor::Executor;
-use ballista_core::error::BallistaError;
-use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning;
+use datafusion::physical_plan::ExecutionPlan;
+
 use ballista_core::serde::protobuf::{
     scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult,
     TaskDefinition, TaskStatus,
 };
+
+use crate::as_task_status;
+use crate::executor::Executor;
+use ballista_core::error::BallistaError;
+use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning;
 use ballista_core::serde::scheduler::ExecutorSpecification;
 use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
 use datafusion::execution::context::TaskContext;
-use datafusion::physical_plan::ExecutionPlan;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use futures::FutureExt;
-use log::{debug, error, info, warn};
+use log::{debug, error, info, trace, warn};
 use std::any::Any;
 use std::collections::HashMap;
 use std::error::Error;
@@ -57,7 +59,7 @@ pub async fn poll_loop<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
         std::sync::mpsc::channel::<TaskStatus>();
 
     loop {
-        debug!("Starting registration loop with scheduler");
+        trace!("Starting registration loop with scheduler");
 
         let task_status: Vec<TaskStatus> =
             sample_tasks_status(&mut task_status_receiver).await;
diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs
index 158c708a..64f2f4f7 100644
--- a/ballista/rust/executor/src/executor_server.rs
+++ b/ballista/rust/executor/src/executor_server.rs
@@ -71,7 +71,7 @@ pub async fn startup<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
                 .map(|h| match h {
                     OptionalHost::Host(host) => host,
                 })
-                .unwrap_or_else(|| String::from("127.0.0.1")),
+                .unwrap_or_else(|| String::from("0.0.0.0")),
             executor_meta.grpc_port
         );
         let addr = addr.parse().unwrap();
diff --git a/ballista/rust/scheduler/src/api/handlers.rs b/ballista/rust/scheduler/src/api/handlers.rs
index b6c322dc..92d2e087 100644
--- a/ballista/rust/scheduler/src/api/handlers.rs
+++ b/ballista/rust/scheduler/src/api/handlers.rs
@@ -37,7 +37,8 @@ pub(crate) async fn scheduler_state<T: AsLogicalPlan, U: AsExecutionPlan>(
     // TODO: Display last seen information in UI
     let executors: Vec<ExecutorMetaResponse> = data_server
         .state
-        .get_executors_metadata()
+        .executor_manager
+        .get_executor_state()
         .await
         .unwrap_or_default()
         .into_iter()
diff --git a/ballista/rust/scheduler/src/main.rs b/ballista/rust/scheduler/src/main.rs
index 39e893b9..2dd34c92 100644
--- a/ballista/rust/scheduler/src/main.rs
+++ b/ballista/rust/scheduler/src/main.rs
@@ -168,7 +168,7 @@ async fn main() -> Result<()> {
             let etcd = etcd_client::Client::connect(&[opt.etcd_urls], None)
                 .await
                 .context("Could not connect to etcd")?;
-            Arc::new(EtcdClient::new(etcd))
+            Arc::new(EtcdClient::new(namespace.clone(), etcd))
         }
         #[cfg(not(feature = "etcd"))]
         StateBackend::Etcd => {
diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs
index 33dc3ac0..1c946937 100644
--- a/ballista/rust/scheduler/src/planner.rs
+++ b/ballista/rust/scheduler/src/planner.rs
@@ -33,8 +33,7 @@ use datafusion::physical_plan::windows::WindowAggExec;
 use datafusion::physical_plan::{
     with_new_children_if_necessary, ExecutionPlan, Partitioning,
 };
-use futures::future::BoxFuture;
-use futures::FutureExt;
+
 use log::info;
 
 type PartialQueryStageResult = (Arc<dyn ExecutionPlan>, Vec<Arc<ShuffleWriterExec>>);
@@ -59,15 +58,14 @@ impl DistributedPlanner {
     /// Returns a vector of ExecutionPlans, where the root node is a [ShuffleWriterExec].
     /// Plans that depend on the input of other plans will have leaf nodes of type [UnresolvedShuffleExec].
     /// A [ShuffleWriterExec] is created whenever the partitioning changes.
-    pub async fn plan_query_stages<'a>(
+    pub fn plan_query_stages<'a>(
         &'a mut self,
         job_id: &'a str,
         execution_plan: Arc<dyn ExecutionPlan>,
     ) -> Result<Vec<Arc<ShuffleWriterExec>>> {
         info!("planning query stages");
-        let (new_plan, mut stages) = self
-            .plan_query_stages_internal(job_id, execution_plan)
-            .await?;
+        let (new_plan, mut stages) =
+            self.plan_query_stages_internal(job_id, execution_plan)?;
         stages.push(create_shuffle_writer(
             job_id,
             self.next_stage_id(),
@@ -84,97 +82,91 @@ impl DistributedPlanner {
         &'a mut self,
         job_id: &'a str,
         execution_plan: Arc<dyn ExecutionPlan>,
-    ) -> BoxFuture<'a, Result<PartialQueryStageResult>> {
-        async move {
-            // recurse down and replace children
-            if execution_plan.children().is_empty() {
-                return Ok((execution_plan, vec![]));
-            }
+    ) -> Result<PartialQueryStageResult> {
+        // async move {
+        // recurse down and replace children
+        if execution_plan.children().is_empty() {
+            return Ok((execution_plan, vec![]));
+        }
 
-            let mut stages = vec![];
-            let mut children = vec![];
-            for child in execution_plan.children() {
-                let (new_child, mut child_stages) = self
-                    .plan_query_stages_internal(job_id, child.clone())
-                    .await?;
-                children.push(new_child);
-                stages.append(&mut child_stages);
-            }
+        let mut stages = vec![];
+        let mut children = vec![];
+        for child in execution_plan.children() {
+            let (new_child, mut child_stages) =
+                self.plan_query_stages_internal(job_id, child.clone())?;
+            children.push(new_child);
+            stages.append(&mut child_stages);
+        }
 
-            if let Some(_coalesce) = execution_plan
-                .as_any()
-                .downcast_ref::<CoalescePartitionsExec>()
-            {
-                let shuffle_writer = create_shuffle_writer(
-                    job_id,
-                    self.next_stage_id(),
-                    children[0].clone(),
-                    None,
-                )?;
-                let unresolved_shuffle = Arc::new(UnresolvedShuffleExec::new(
-                    shuffle_writer.stage_id(),
-                    shuffle_writer.schema(),
-                    shuffle_writer.output_partitioning().partition_count(),
-                    shuffle_writer
-                        .shuffle_output_partitioning()
-                        .map(|p| p.partition_count())
-                        .unwrap_or_else(|| {
-                            shuffle_writer.output_partitioning().partition_count()
-                        }),
-                ));
-                stages.push(shuffle_writer);
-                Ok((
-                    with_new_children_if_necessary(
-                        execution_plan,
-                        vec![unresolved_shuffle],
-                    )?,
-                    stages,
-                ))
-            } else if let Some(repart) =
-                execution_plan.as_any().downcast_ref::<RepartitionExec>()
-            {
-                match repart.output_partitioning() {
-                    Partitioning::Hash(_, _) => {
-                        let shuffle_writer = create_shuffle_writer(
-                            job_id,
-                            self.next_stage_id(),
-                            children[0].clone(),
-                            Some(repart.partitioning().to_owned()),
-                        )?;
-                        let unresolved_shuffle = Arc::new(UnresolvedShuffleExec::new(
-                            shuffle_writer.stage_id(),
-                            shuffle_writer.schema(),
-                            shuffle_writer.output_partitioning().partition_count(),
-                            shuffle_writer
-                                .shuffle_output_partitioning()
-                                .map(|p| p.partition_count())
-                                .unwrap_or_else(|| {
-                                    shuffle_writer.output_partitioning().partition_count()
-                                }),
-                        ));
-                        stages.push(shuffle_writer);
-                        Ok((unresolved_shuffle, stages))
-                    }
-                    _ => {
-                        // remove any non-hash repartition from the distributed plan
-                        Ok((children[0].clone(), stages))
-                    }
+        if let Some(_coalesce) = execution_plan
+            .as_any()
+            .downcast_ref::<CoalescePartitionsExec>()
+        {
+            let shuffle_writer = create_shuffle_writer(
+                job_id,
+                self.next_stage_id(),
+                children[0].clone(),
+                None,
+            )?;
+            let unresolved_shuffle = Arc::new(UnresolvedShuffleExec::new(
+                shuffle_writer.stage_id(),
+                shuffle_writer.schema(),
+                shuffle_writer.output_partitioning().partition_count(),
+                shuffle_writer
+                    .shuffle_output_partitioning()
+                    .map(|p| p.partition_count())
+                    .unwrap_or_else(|| {
+                        shuffle_writer.output_partitioning().partition_count()
+                    }),
+            ));
+            stages.push(shuffle_writer);
+            Ok((
+                with_new_children_if_necessary(execution_plan, vec![unresolved_shuffle])?,
+                stages,
+            ))
+        } else if let Some(repart) =
+            execution_plan.as_any().downcast_ref::<RepartitionExec>()
+        {
+            match repart.output_partitioning() {
+                Partitioning::Hash(_, _) => {
+                    let shuffle_writer = create_shuffle_writer(
+                        job_id,
+                        self.next_stage_id(),
+                        children[0].clone(),
+                        Some(repart.partitioning().to_owned()),
+                    )?;
+                    let unresolved_shuffle = Arc::new(UnresolvedShuffleExec::new(
+                        shuffle_writer.stage_id(),
+                        shuffle_writer.schema(),
+                        shuffle_writer.output_partitioning().partition_count(),
+                        shuffle_writer
+                            .shuffle_output_partitioning()
+                            .map(|p| p.partition_count())
+                            .unwrap_or_else(|| {
+                                shuffle_writer.output_partitioning().partition_count()
+                            }),
+                    ));
+                    stages.push(shuffle_writer);
+                    Ok((unresolved_shuffle, stages))
+                }
+                _ => {
+                    // remove any non-hash repartition from the distributed plan
+                    Ok((children[0].clone(), stages))
                 }
-            } else if let Some(window) =
-                execution_plan.as_any().downcast_ref::<WindowAggExec>()
-            {
-                Err(BallistaError::NotImplemented(format!(
-                    "WindowAggExec with window {:?}",
-                    window
-                )))
-            } else {
-                Ok((
-                    with_new_children_if_necessary(execution_plan, children)?,
-                    stages,
-                ))
             }
+        } else if let Some(window) =
+            execution_plan.as_any().downcast_ref::<WindowAggExec>()
+        {
+            Err(BallistaError::NotImplemented(format!(
+                "WindowAggExec with window {:?}",
+                window
+            )))
+        } else {
+            Ok((
+                with_new_children_if_necessary(execution_plan, children)?,
+                stages,
+            ))
         }
-        .boxed()
     }
 
     /// Generate a new stage ID
@@ -318,9 +310,7 @@ mod test {
 
         let mut planner = DistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
-        let stages = planner
-            .plan_query_stages(&job_uuid.to_string(), plan)
-            .await?;
+        let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
         for stage in &stages {
             println!("{}", displayable(stage.as_ref()).indent());
         }
@@ -432,9 +422,7 @@ order by
 
         let mut planner = DistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
-        let stages = planner
-            .plan_query_stages(&job_uuid.to_string(), plan)
-            .await?;
+        let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
         for stage in &stages {
             println!("{}", displayable(stage.as_ref()).indent());
         }
@@ -580,9 +568,7 @@ order by
 
         let mut planner = DistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
-        let stages = planner
-            .plan_query_stages(&job_uuid.to_string(), plan)
-            .await?;
+        let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
 
         let partial_hash = stages[0].children()[0].clone();
         let partial_hash_serde = roundtrip_operator(partial_hash.clone())?;
diff --git a/ballista/rust/scheduler/src/scheduler_server/event.rs b/ballista/rust/scheduler/src/scheduler_server/event.rs
index 9252453e..458fb875 100644
--- a/ballista/rust/scheduler/src/scheduler_server/event.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/event.rs
@@ -15,19 +15,28 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use datafusion::physical_plan::ExecutionPlan;
+use crate::state::executor_manager::ExecutorReservation;
+
+use datafusion::logical_plan::LogicalPlan;
+
+use datafusion::prelude::SessionContext;
 use std::sync::Arc;
 
-#[derive(Clone)]
-pub(crate) enum SchedulerServerEvent {
-    // number of offer rounds
-    ReviveOffers(u32),
+#[derive(Clone, Debug)]
+pub enum SchedulerServerEvent {
+    /// Offer a list of executor reservations (representing executor task slots available for scheduling)
+    Offer(Vec<ExecutorReservation>),
 }
 
 #[derive(Clone)]
 pub enum QueryStageSchedulerEvent {
-    JobSubmitted(String, Arc<dyn ExecutionPlan>),
-    StageFinished(String, u32),
+    JobQueued {
+        job_id: String,
+        session_id: String,
+        session_ctx: Arc<SessionContext>,
+        plan: Box<LogicalPlan>,
+    },
+    JobSubmitted(String),
     JobFinished(String),
-    JobFailed(String, u32, String),
+    JobFailed(String, String),
 }
diff --git a/ballista/rust/scheduler/src/scheduler_server/event_loop.rs b/ballista/rust/scheduler/src/scheduler_server/event_loop.rs
index 5ba830e8..b0037ab0 100644
--- a/ballista/rust/scheduler/src/scheduler_server/event_loop.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/event_loop.rs
@@ -16,132 +16,109 @@
 // under the License.
 
 use std::sync::Arc;
-use std::time::Duration;
 
 use async_trait::async_trait;
-use log::{debug, warn};
+use log::{error, info};
 
 use crate::scheduler_server::event::SchedulerServerEvent;
-use crate::scheduler_server::ExecutorsClient;
-use crate::state::task_scheduler::TaskScheduler;
-use crate::state::SchedulerState;
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::event_loop::EventAction;
-use ballista_core::serde::protobuf::{LaunchTaskParams, TaskDefinition};
-use ballista_core::serde::scheduler::ExecutorDataChange;
+
 use ballista_core::serde::AsExecutionPlan;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 
+use crate::state::executor_manager::ExecutorReservation;
+use crate::state::SchedulerState;
+
+/// EventAction which will process `SchedulerServerEvent`s.
+/// In push-based scheduling, this is the primary mechanism for scheduling tasks
+/// on executors.
 pub(crate) struct SchedulerServerEventAction<
     T: 'static + AsLogicalPlan,
     U: 'static + AsExecutionPlan,
 > {
     state: Arc<SchedulerState<T, U>>,
-    executors_client: ExecutorsClient,
 }
 
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
     SchedulerServerEventAction<T, U>
 {
-    pub fn new(
-        state: Arc<SchedulerState<T, U>>,
-        executors_client: ExecutorsClient,
-    ) -> Self {
-        Self {
-            state,
-            executors_client,
-        }
+    pub fn new(state: Arc<SchedulerState<T, U>>) -> Self {
+        Self { state }
     }
 
-    #[allow(unused_variables)]
-    async fn offer_resources(&self, n: u32) -> Result<Option<SchedulerServerEvent>> {
-        let mut available_executors =
-            self.state.executor_manager.get_available_executors_data();
-        // In case of there's no enough resources, reschedule the tasks of the job
-        if available_executors.is_empty() {
-            // TODO Maybe it's better to use an exclusive runtime for this kind task scheduling
-            warn!("Not enough available executors for task running");
-            tokio::time::sleep(Duration::from_millis(100)).await;
-            return Ok(Some(SchedulerServerEvent::ReviveOffers(1)));
-        }
-
-        let mut executors_data_change: Vec<ExecutorDataChange> = available_executors
-            .iter()
-            .map(|executor_data| ExecutorDataChange {
-                executor_id: executor_data.executor_id.clone(),
-                task_slots: executor_data.available_task_slots as i32,
-            })
-            .collect();
-
-        let (tasks_assigment, num_tasks) = self
+    /// Process reservations which are offered. The basic process is
+    /// 1. Attempt to fill the offered reservations with available tasks
+    /// 2. For any reservation that filled, launch the assigned task on the executor.
+    /// 3. For any reservations that could not be filled, cancel the reservation (i.e. return the
+    ///    task slot back to the pool of available task slots).
+    ///
+    /// NOTE Error handling in this method is very important. No matter what we need to ensure
+    /// that unfilled reservations are cancelled or else they could become permanently "invisible"
+    /// to the scheduler.
+    async fn offer_reservation(
+        &self,
+        reservations: Vec<ExecutorReservation>,
+    ) -> Result<Option<SchedulerServerEvent>> {
+        let (free_list, pending_tasks) = match self
             .state
-            .fetch_schedulable_tasks(&mut available_executors, n)
-            .await?;
-        for (data_change, data) in executors_data_change
-            .iter_mut()
-            .zip(available_executors.iter())
+            .task_manager
+            .fill_reservations(&reservations)
+            .await
         {
-            data_change.task_slots =
-                data.available_task_slots as i32 - data_change.task_slots;
-        }
-
-        #[cfg(not(test))]
-        if num_tasks > 0 {
-            self.launch_tasks(&executors_data_change, tasks_assigment)
-                .await?;
-        }
-
-        Ok(None)
-    }
-
-    #[allow(dead_code)]
-    async fn launch_tasks(
-        &self,
-        executors: &[ExecutorDataChange],
-        tasks_assigment: Vec<Vec<TaskDefinition>>,
-    ) -> Result<()> {
-        for (idx_executor, tasks) in tasks_assigment.into_iter().enumerate() {
-            if !tasks.is_empty() {
-                let executor_data_change = &executors[idx_executor];
-                debug!(
-                    "Start to launch tasks {:?} to executor {:?}",
-                    tasks
-                        .iter()
-                        .map(|task| {
-                            if let Some(task_id) = task.task_id.as_ref() {
-                                format!(
-                                    "{}/{}/{}",
-                                    task_id.job_id,
-                                    task_id.stage_id,
-                                    task_id.partition_id
-                                )
-                            } else {
-                                "".to_string()
+            Ok((assignments, mut unassigned_reservations, pending_tasks)) => {
+                for (executor_id, task) in assignments.into_iter() {
+                    match self
+                        .state
+                        .executor_manager
+                        .get_executor_metadata(&executor_id)
+                        .await
+                    {
+                        Ok(executor) => {
+                            if let Err(e) =
+                                self.state.task_manager.launch_task(&executor, task).await
+                            {
+                                error!("Failed to launch new task: {:?}", e);
+                                unassigned_reservations.push(
+                                    ExecutorReservation::new_free(executor_id.clone()),
+                                );
                             }
-                        })
-                        .collect::<Vec<String>>(),
-                    executor_data_change.executor_id
-                );
-                let mut client = {
-                    let clients = self.executors_client.read().await;
-                    clients
-                        .get(&executor_data_change.executor_id)
-                        .unwrap()
-                        .clone()
-                };
-                // TODO check whether launching task is successful or not
-                client.launch_task(LaunchTaskParams { task: tasks }).await?;
-                self.state
-                    .executor_manager
-                    .update_executor_data(executor_data_change);
-            } else {
-                // Since the task assignment policy is round robin,
-                // if find tasks for one executor is empty, just break fast
-                break;
+                        }
+                        Err(e) => {
+                            error!("Failed to launch new task, could not get executor metadata: {:?}", e);
+                            unassigned_reservations
+                                .push(ExecutorReservation::new_free(executor_id.clone()));
+                        }
+                    }
+                }
+                (unassigned_reservations, pending_tasks)
             }
-        }
+            Err(e) => {
+                error!("Error filling reservations: {:?}", e);
+                (reservations, 0)
+            }
+        };
 
-        Ok(())
+        dbg!(free_list.clone());
+        dbg!(pending_tasks);
+        // If any reserved slots remain, return them to the pool
+        if !free_list.is_empty() {
+            self.state
+                .executor_manager
+                .cancel_reservations(free_list)
+                .await?;
+            Ok(None)
+        } else if pending_tasks > 0 {
+            // If there are pending tasks available, try and schedule them
+            let new_reservations = self
+                .state
+                .executor_manager
+                .reserve_slots(pending_tasks as u32)
+                .await?;
+            Ok(Some(SchedulerServerEvent::Offer(new_reservations)))
+        } else {
+            Ok(None)
+        }
     }
 }
 
@@ -149,21 +126,283 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
     EventAction<SchedulerServerEvent> for SchedulerServerEventAction<T, U>
 {
-    // TODO
-    fn on_start(&self) {}
+    /// Start hook of the `EventAction` trait; it only logs at `info` level.
+    fn on_start(&self) {
+        info!("Starting SchedulerServerEvent handler")
+    }
 
-    // TODO
-    fn on_stop(&self) {}
+    /// Stop hook of the `EventAction` trait; it only logs at `info` level.
+    fn on_stop(&self) {
+        info!("Stopping SchedulerServerEvent handler")
+    }
 
+    /// Handles one scheduler event.
+    ///
+    /// `Offer` events are delegated to `offer_reservation`, and its result is returned
+    /// unchanged. `Ok(Some(Offer(..)))` carries fresh reservations for tasks that are
+    /// still pending. `Ok(None)` means leftover reservations were cancelled or nothing
+    /// remains to schedule.
+    // NOTE(review): presumably the event loop re-posts a returned event; confirm.
     async fn on_receive(
         &self,
         event: SchedulerServerEvent,
     ) -> Result<Option<SchedulerServerEvent>> {
         match event {
-            SchedulerServerEvent::ReviveOffers(n) => self.offer_resources(n).await,
+            SchedulerServerEvent::Offer(reservations) => {
+                self.offer_reservation(reservations).await
+            }
+        }
+    }
+
+    /// Error hook: logs the error at `error` level and does nothing else.
+    fn on_error(&self, error: BallistaError) {
+        error!("Error in SchedulerServerEvent handler: {:?}", error);
+    }
+}
+
+// Unit tests for `SchedulerServerEventAction::offer_reservation`. They cover freeing
+// unused reservations, filling reservations with pending tasks, and emitting a
+// follow-up `Offer` when more tasks are pending than were offered.
+#[cfg(test)]
+mod test {
+    use crate::scheduler_server::event::SchedulerServerEvent;
+    use crate::scheduler_server::event_loop::SchedulerServerEventAction;
+    use crate::state::backend::standalone::StandaloneClient;
+    use crate::state::SchedulerState;
+    use ballista_core::config::{BallistaConfig, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS};
+    use ballista_core::error::Result;
+    use ballista_core::serde::protobuf::{
+        task_status, CompletedTask, PartitionId, PhysicalPlanNode, ShuffleWritePartition,
+        TaskStatus,
+    };
+    use ballista_core::serde::scheduler::{
+        ExecutorData, ExecutorMetadata, ExecutorSpecification,
+    };
+    use ballista_core::serde::BallistaCodec;
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::execution::context::default_session_builder;
+    use datafusion::logical_expr::{col, sum};
+    use datafusion::physical_plan::ExecutionPlan;
+    use datafusion::prelude::SessionContext;
+    use datafusion::test_util::scan_empty;
+    use datafusion_proto::protobuf::LogicalPlanNode;
+    use std::sync::Arc;
+
+    /// Reservations that are not assigned a task must be freed again.
+    ///
+    /// Registers one executor with 4 task slots and passes `true` so the registration
+    /// returns reservations for its slots. No job has been submitted when those
+    /// reservations are offered. Expects no follow-up event, and expects that all 4
+    /// slots can be reserved again afterwards.
+    #[tokio::test]
+    async fn test_offer_free_reservations() -> Result<()> {
+        let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
+        let state: Arc<SchedulerState<LogicalPlanNode, PhysicalPlanNode>> =
+            Arc::new(SchedulerState::new(
+                state_storage,
+                "default".to_string(),
+                default_session_builder,
+                BallistaCodec::default(),
+            ));
+
+        let executors = test_executors(1, 4);
+
+        let (executor_metadata, executor_data) = executors[0].clone();
+
+        let reservations = state
+            .executor_manager
+            .register_executor(executor_metadata, executor_data, true)
+            .await?;
+
+        let event_action = Arc::new(SchedulerServerEventAction::new(state.clone()));
+
+        let result = event_action.offer_reservation(reservations).await?;
+
+        assert!(result.is_none());
+
+        // All reservations should have been cancelled so we should be able to reserve them now
+        let reservations = state.executor_manager.reserve_slots(4).await?;
+
+        assert_eq!(reservations.len(), 4);
+
+        Ok(())
+    }
+
+    /// Unbound reservations must be filled with whatever tasks are pending.
+    ///
+    /// Submits four jobs built from the same plan, in a session configured with
+    /// `BALLISTA_DEFAULT_SHUFFLE_PARTITIONS` = 4. Then registers one 4-slot executor with
+    /// its slots reserved and offers those reservations. Expects no follow-up event and
+    /// no free slots afterwards, meaning every reservation was used.
+    #[tokio::test]
+    async fn test_offer_fill_reservations() -> Result<()> {
+        let config = BallistaConfig::builder()
+            .set(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, "4")
+            .build()?;
+        let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
+        let state: Arc<SchedulerState<LogicalPlanNode, PhysicalPlanNode>> =
+            Arc::new(SchedulerState::new(
+                state_storage,
+                "default".to_string(),
+                default_session_builder,
+                BallistaCodec::default(),
+            ));
+
+        let session_ctx = state.session_manager.create_session(&config).await?;
+
+        let plan = test_graph(session_ctx.clone()).await;
+
+        // Create 4 jobs so we have four pending tasks
+        state
+            .task_manager
+            .submit_job("job-1", session_ctx.session_id().as_str(), plan.clone())
+            .await?;
+        state
+            .task_manager
+            .submit_job("job-2", session_ctx.session_id().as_str(), plan.clone())
+            .await?;
+        state
+            .task_manager
+            .submit_job("job-3", session_ctx.session_id().as_str(), plan.clone())
+            .await?;
+        state
+            .task_manager
+            .submit_job("job-4", session_ctx.session_id().as_str(), plan.clone())
+            .await?;
+
+        let executors = test_executors(1, 4);
+
+        let (executor_metadata, executor_data) = executors[0].clone();
+
+        let reservations = state
+            .executor_manager
+            .register_executor(executor_metadata, executor_data, true)
+            .await?;
+
+        let event_action = Arc::new(SchedulerServerEventAction::new(state.clone()));
+
+        let result = event_action.offer_reservation(reservations).await?;
+
+        assert!(result.is_none());
+
+        // All task slots should be assigned so we should not be able to reserve more tasks
+        let reservations = state.executor_manager.reserve_slots(4).await?;
+
+        assert_eq!(reservations.len(), 0);
+
+        Ok(())
+    }
+
+    /// When more tasks are pending than the offered reservations can hold, the offer
+    /// must return a follow-up `Offer` event that reserves slots for the rest.
+    ///
+    /// Steps:
+    /// 1. Reports partition 0 of stage 1 of a single job as completed, with 4 shuffle
+    ///    output partitions.
+    /// 2. Registers a 4-slot executor without reserving its slots.
+    /// 3. Reserves a single slot and offers it.
+    ///
+    /// Expects a follow-up `Offer` holding 3 reservations and no free slots left.
+    #[tokio::test]
+    async fn test_offer_resubmit_pending() -> Result<()> {
+        let config = BallistaConfig::builder()
+            .set(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, "4")
+            .build()?;
+        let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
+        let state: Arc<SchedulerState<LogicalPlanNode, PhysicalPlanNode>> =
+            Arc::new(SchedulerState::new(
+                state_storage,
+                "default".to_string(),
+                default_session_builder,
+                BallistaCodec::default(),
+            ));
+
+        let session_ctx = state.session_manager.create_session(&config).await?;
+
+        let plan = test_graph(session_ctx.clone()).await;
+
+        // Create a job
+        state
+            .task_manager
+            .submit_job("job-1", session_ctx.session_id().as_str(), plan.clone())
+            .await?;
+
+        let executors = test_executors(1, 4);
+
+        let (executor_metadata, executor_data) = executors[0].clone();
+
+        // Complete the first stage, so stage 2 of this job should now have 4 pending tasks
+        // (one per shuffle output partition built below).
+        let mut partitions: Vec<ShuffleWritePartition> = vec![];
+
+        for partition_id in 0..4 {
+            partitions.push(ShuffleWritePartition {
+                partition_id: partition_id as u64,
+                path: "some/path".to_string(),
+                num_batches: 1,
+                num_rows: 1,
+                num_bytes: 1,
+            })
+        }
+
+        state
+            .task_manager
+            .update_task_statuses(
+                &executor_metadata,
+                vec![TaskStatus {
+                    task_id: Some(PartitionId {
+                        job_id: "job-1".to_string(),
+                        stage_id: 1,
+                        partition_id: 0,
+                    }),
+                    status: Some(task_status::Status::Completed(CompletedTask {
+                        // NOTE(review): test_executors(1, 4) names its executor "executor-0",
+                        // but this status reports "executor-1"; confirm the mismatch is intended.
+                        executor_id: "executor-1".to_string(),
+                        partitions,
+                    })),
+                }],
+            )
+            .await?;
+
+        state
+            .executor_manager
+            .register_executor(executor_metadata, executor_data, false)
+            .await?;
+
+        let reservation = state.executor_manager.reserve_slots(1).await?;
+
+        assert_eq!(reservation.len(), 1);
+
+        let event_action = Arc::new(SchedulerServerEventAction::new(state.clone()));
+
+        // Offer the single reservation. It should be filled with one of the 4 pending tasks.
+        // The executor's remaining 3 slots should then be reserved for the other 3 pending
+        // tasks, which emits another Offer event.
+        let result = event_action.offer_reservation(reservation).await?;
+
+        assert!(result.is_some());
+
+        match result {
+            Some(SchedulerServerEvent::Offer(reservations)) => {
+                assert_eq!(reservations.len(), 3)
+            }
+            _ => panic!("Expected 3 new reservations offered"),
         }
+
+        // The remaining 3 task slots should be reserved for pending tasks, so no free slot is left
+        let reservations = state.executor_manager.reserve_slots(4).await?;
+
+        assert_eq!(reservations.len(), 0);
+
+        Ok(())
+    }
+
+    /// Builds `total_executors` fake executors with `slots_per_executor` task slots
+    /// each, all of them available.
+    ///
+    /// Executor `i` gets id `executor-{i}`, host `host-{i}`, port 8080 and gRPC port 9090.
+    fn test_executors(
+        total_executors: usize,
+        slots_per_executor: u32,
+    ) -> Vec<(ExecutorMetadata, ExecutorData)> {
+        let mut result: Vec<(ExecutorMetadata, ExecutorData)> = vec![];
+
+        for i in 0..total_executors {
+            result.push((
+                ExecutorMetadata {
+                    id: format!("executor-{}", i),
+                    host: format!("host-{}", i),
+                    port: 8080,
+                    grpc_port: 9090,
+                    specification: ExecutorSpecification {
+                        task_slots: slots_per_executor,
+                    },
+                },
+                ExecutorData {
+                    executor_id: format!("executor-{}", i),
+                    total_task_slots: slots_per_executor,
+                    available_task_slots: slots_per_executor,
+                },
+            ));
+        }
+
+        result
     }
 
-    // TODO
-    fn on_error(&self, _error: BallistaError) {}
+    /// Builds a physical plan that groups by `id` and computes `SUM(gmv)`.
+    ///
+    /// The input is created with `scan_empty` over the schema `(id: Utf8, gmv: UInt64)`,
+    /// projecting both columns.
+    ///
+    /// Panics if any planning step fails.
+    async fn test_graph(ctx: Arc<SessionContext>) -> Arc<dyn ExecutionPlan> {
+        let schema = Schema::new(vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new("gmv", DataType::UInt64, false),
+        ]);
+
+        let plan = scan_empty(None, &schema, Some(vec![0, 1]))
+            .unwrap()
+            .aggregate(vec![col("id")], vec![sum(col("gmv"))])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        ctx.create_physical_plan(&plan).await.unwrap()
+    }
 }
diff --git a/ballista/rust/scheduler/src/scheduler_server/external_scaler.rs b/ballista/rust/scheduler/src/scheduler_server/external_scaler.rs
index 1b8d42c2..d40c7b9b 100644
--- a/ballista/rust/scheduler/src/scheduler_server/external_scaler.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/external_scaler.rs
@@ -22,7 +22,7 @@ use crate::scheduler_server::externalscaler::{
 use crate::scheduler_server::SchedulerServer;
 use ballista_core::serde::AsExecutionPlan;
 use datafusion_proto::logical_plan::AsLogicalPlan;
-use log::debug;
+
 use tonic::{Request, Response};
 
 const INFLIGHT_TASKS_METRIC_NAME: &str = "inflight_tasks";
@@ -35,9 +35,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExternalScaler
         &self,
         _request: Request<ScaledObjectRef>,
     ) -> Result<Response<IsActiveResponse>, tonic::Status> {
-        let result = self.state.stage_manager.has_running_tasks();
-        debug!("Are there active tasks? {}", result);
-        Ok(Response::new(IsActiveResponse { result }))
+        Ok(Response::new(IsActiveResponse { result: true }))
     }
 
     async fn get_metric_spec(
diff --git a/ballista/rust/scheduler/src/scheduler_server/grpc.rs b/ballista/rust/scheduler/src/scheduler_server/grpc.rs
index 0cea1be4..f80da4fa 100644
--- a/ballista/rust/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/grpc.rs
@@ -15,48 +15,41 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::convert::TryInto;
-use std::ops::Deref;
-use std::sync::Arc;
-use std::time::Instant;
-use std::time::{SystemTime, UNIX_EPOCH};
+use ballista_core::config::{BallistaConfig, TaskSchedulingPolicy};
+
+use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId, Query};
+
+use ballista_core::serde::protobuf::executor_registration::OptionalHost;
+use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc;
+use ballista_core::serde::protobuf::{
+    ExecuteQueryParams, ExecuteQueryResult, ExecutorHeartbeat, GetFileMetadataParams,
+    GetFileMetadataResult, GetJobStatusParams, GetJobStatusResult, HeartBeatParams,
+    HeartBeatResult, PollWorkParams, PollWorkResult, RegisterExecutorParams,
+    RegisterExecutorResult, UpdateTaskStatusParams, UpdateTaskStatusResult,
+};
+use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
+use ballista_core::serde::AsExecutionPlan;
+
+use object_store::{local::LocalFileSystem, path::Path, ObjectStore};
 
-use anyhow::Context;
 use datafusion::datasource::file_format::parquet::ParquetFormat;
 use datafusion::datasource::file_format::FileFormat;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use datafusion_proto::protobuf::FileType;
 use futures::TryStreamExt;
 use log::{debug, error, info, trace, warn};
-use object_store::local::LocalFileSystem;
-use object_store::path::Path;
-use object_store::ObjectStore;
-use rand::{distributions::Alphanumeric, thread_rng, Rng};
-use tonic::{Request, Response, Status};
 
-use ballista_core::config::{BallistaConfig, TaskSchedulingPolicy};
-use ballista_core::error::BallistaError;
-use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId, Query};
-use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient;
-use ballista_core::serde::protobuf::executor_registration::OptionalHost;
-use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc;
-use ballista_core::serde::protobuf::{
-    job_status, ExecuteQueryParams, ExecuteQueryResult, ExecutorHeartbeat, FailedJob,
-    GetFileMetadataParams, GetFileMetadataResult, GetJobStatusParams, GetJobStatusResult,
-    HeartBeatParams, HeartBeatResult, JobStatus, PollWorkParams, PollWorkResult,
-    QueuedJob, RegisterExecutorParams, RegisterExecutorResult, UpdateTaskStatusParams,
-    UpdateTaskStatusResult,
-};
-use ballista_core::serde::scheduler::{
-    ExecutorData, ExecutorDataChange, ExecutorMetadata,
-};
-use ballista_core::serde::AsExecutionPlan;
+// use http_body::Body;
+use std::convert::TryInto;
+use std::ops::Deref;
+use std::sync::Arc;
 
-use crate::scheduler_server::event::QueryStageSchedulerEvent;
-use crate::scheduler_server::{
-    create_datafusion_context, update_datafusion_context, SchedulerServer,
-};
-use crate::state::task_scheduler::TaskScheduler;
+use std::time::{SystemTime, UNIX_EPOCH};
+use tonic::{Request, Response, Status};
+
+use crate::scheduler_server::event::{QueryStageSchedulerEvent, SchedulerServerEvent};
+use crate::scheduler_server::SchedulerServer;
+use crate::state::executor_manager::ExecutorReservation;
 
 #[tonic::async_trait]
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
@@ -65,7 +58,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
     async fn poll_work(
         &self,
         request: Request<PollWorkParams>,
-    ) -> std::result::Result<Response<PollWorkResult>, tonic::Status> {
+    ) -> Result<Response<PollWorkResult>, Status> {
         if let TaskSchedulingPolicy::PushStaged = self.policy {
             error!("Poll work interface is not supported for push-based task scheduling");
             return Err(tonic::Status::failed_precondition(
@@ -100,61 +93,70 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
                     .as_secs(),
                 state: None,
             };
-            // In case that it's the first time to poll work, do registration
-            if self.state.get_executor_metadata(&metadata.id).is_none() {
-                self.state
-                    .save_executor_metadata(metadata.clone())
-                    .await
-                    .map_err(|e| {
-                        let msg = format!("Could not save executor metadata: {}", e);
-                        error!("{}", msg);
-                        tonic::Status::internal(msg)
-                    })?;
-            }
+
             self.state
                 .executor_manager
-                .save_executor_heartbeat(executor_heartbeat);
-            self.update_task_status(task_status).await.map_err(|e| {
-                let msg = format!(
-                    "Fail to update tasks status from executor {:?} due to {:?}",
-                    &metadata.id, e
-                );
-                error!("{}", msg);
-                tonic::Status::internal(msg)
-            })?;
-            let task: Result<Option<_>, Status> = if can_accept_task {
-                let mut executors_data = vec![ExecutorData {
-                    executor_id: metadata.id.clone(),
-                    total_task_slots: 1,
-                    available_task_slots: 1,
-                }];
-                let (mut tasks, num_tasks) = self
+                .save_executor_metadata(metadata.clone())
+                .await
+                .map_err(|e| {
+                    let msg = format!("Could not save executor metadata: {}", e);
+                    error!("{}", msg);
+                    Status::internal(msg)
+                })?;
+
+            self.state
+                .executor_manager
+                .save_executor_heartbeat(executor_heartbeat)
+                .await
+                .map_err(|e| {
+                    let msg = format!("Could not save executor heartbeat: {}", e);
+                    error!("{}", msg);
+                    Status::internal(msg)
+                })?;
+
+            self.update_task_status(&metadata.id, task_status)
+                .await
+                .map_err(|e| {
+                    let msg = format!(
+                        "Fail to update tasks status from executor {:?} due to {:?}",
+                        &metadata.id, e
+                    );
+                    error!("{}", msg);
+                    Status::internal(msg)
+                })?;
+
+            // If executor can accept another task, try and find one.
+            let next_task = if can_accept_task {
+                let reservations =
+                    vec![ExecutorReservation::new_free(metadata.id.clone())];
+                if let Ok((mut assignments, _, _)) = self
                     .state
-                    .fetch_schedulable_tasks(&mut executors_data, 1)
+                    .task_manager
+                    .fill_reservations(&reservations)
                     .await
-                    .map_err(|e| {
-                        let msg = format!("Error finding next assignable task: {}", e);
-                        error!("{}", msg);
-                        tonic::Status::internal(msg)
-                    })?;
-                if num_tasks == 0 {
-                    Ok(None)
+                {
+                    if let Some((_, task)) = assignments.pop() {
+                        match self.state.task_manager.prepare_task_definition(task) {
+                            Ok(task_definition) => Some(task_definition),
+                            Err(e) => {
+                                error!("Error preparing task definition: {:?}", e);
+                                None
+                            }
+                        }
+                    } else {
+                        None
+                    }
                 } else {
-                    assert_eq!(tasks.len(), 1);
-                    let mut task = tasks.pop().unwrap();
-                    assert_eq!(task.len(), 1);
-                    let task = task.pop().unwrap();
-                    Ok(Some(task))
+                    None
                 }
             } else {
-                Ok(None)
+                None
             };
-            Ok(Response::new(PollWorkResult { task: task? }))
+
+            Ok(Response::new(PollWorkResult { task: next_task }))
         } else {
             warn!("Received invalid executor poll_work request");
-            Err(tonic::Status::invalid_argument(
-                "Missing metadata in request",
-            ))
+            Err(Status::invalid_argument("Missing metadata in request"))
         }
     }
 
@@ -180,42 +182,41 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
                 grpc_port: metadata.grpc_port as u16,
                 specification: metadata.specification.unwrap().into(),
             };
-            // Check whether the executor starts the grpc service
-            {
-                let executor_url =
-                    format!("http://{}:{}", metadata.host, metadata.grpc_port);
-                info!("Connect to executor {:?}", executor_url);
-                let executor_client = ExecutorGrpcClient::connect(executor_url)
-                    .await
-                    .context("Could not connect to executor")
-                    .map_err(|e| tonic::Status::internal(format!("{:?}", e)))?;
-                let mut clients = self.executors_client.as_ref().unwrap().write().await;
-                // TODO check duplicated registration
-                clients.insert(metadata.id.clone(), executor_client);
-                info!("Size of executor clients: {:?}", clients.len());
-            }
-            self.state
-                .save_executor_metadata(metadata.clone())
-                .await
-                .map_err(|e| {
-                    let msg = format!("Could not save executor metadata: {}", e);
-                    error!("{}", msg);
-                    tonic::Status::internal(msg)
-                })?;
             let executor_data = ExecutorData {
                 executor_id: metadata.id.clone(),
                 total_task_slots: metadata.specification.task_slots,
                 available_task_slots: metadata.specification.task_slots,
             };
-            self.state
-                .executor_manager
-                .save_executor_data(executor_data);
+
+            if let Ok(Some(sender)) =
+                self.event_loop.as_ref().map(|e| e.get_sender()).transpose()
+            {
+                // If we are using push-based scheduling then reserve this executors slots and send
+                // them for scheduling tasks.
+                let reservations = self
+                    .state
+                    .executor_manager
+                    .register_executor(metadata, executor_data, true)
+                    .await
+                    .unwrap();
+
+                sender
+                    .post_event(SchedulerServerEvent::Offer(reservations))
+                    .await
+                    .unwrap();
+            } else {
+                // Otherwise just save the executor to state
+                self.state
+                    .executor_manager
+                    .register_executor(metadata, executor_data, false)
+                    .await
+                    .unwrap();
+            }
+
             Ok(Response::new(RegisterExecutorResult { success: true }))
         } else {
             warn!("Received invalid register executor request");
-            Err(tonic::Status::invalid_argument(
-                "Missing metadata in request",
-            ))
+            Err(Status::invalid_argument("Missing metadata in request"))
         }
     }
 
@@ -237,7 +238,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
         };
         self.state
             .executor_manager
-            .save_executor_heartbeat(executor_heartbeat);
+            .save_executor_heartbeat(executor_heartbeat)
+            .await
+            .map_err(|e| {
+                let msg = format!("Could not save executor heartbeat: {}", e);
+                error!("{}", msg);
+                Status::internal(msg)
+            })?;
         Ok(Response::new(HeartBeatResult { reregister: false }))
     }
 
@@ -254,28 +261,17 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
             "Received task status update request for executor {:?}",
             executor_id
         );
-        let num_tasks = task_status.len();
-        if let Some(executor_data) =
-            self.state.executor_manager.get_executor_data(&executor_id)
-        {
-            self.state
-                .executor_manager
-                .update_executor_data(&ExecutorDataChange {
-                    executor_id: executor_data.executor_id,
-                    task_slots: num_tasks as i32,
-                });
-        } else {
-            error!("Fail to get executor data for {:?}", &executor_id);
-        }
 
-        self.update_task_status(task_status).await.map_err(|e| {
-            let msg = format!(
-                "Fail to update tasks status from executor {:?} due to {:?}",
-                &executor_id, e
-            );
-            error!("{}", msg);
-            tonic::Status::internal(msg)
-        })?;
+        self.update_task_status(&executor_id, task_status)
+            .await
+            .map_err(|e| {
+                let msg = format!(
+                    "Fail to update tasks status from executor {:?} due to {:?}",
+                    &executor_id, e
+                );
+                error!("{}", msg);
+                Status::internal(msg)
+            })?;
 
         Ok(Response::new(UpdateTaskStatusResult { success: true }))
     }
@@ -283,7 +279,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
     async fn get_file_metadata(
         &self,
         request: Request<GetFileMetadataParams>,
-    ) -> std::result::Result<Response<GetFileMetadataResult>, tonic::Status> {
+    ) -> Result<Response<GetFileMetadataResult>, Status> {
         // TODO support multiple object stores
         let obj_store: Arc<dyn ObjectStore> = Arc::new(LocalFileSystem::new());
         // TODO shouldn't this take a ListingOption object as input?
@@ -338,7 +334,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
     async fn execute_query(
         &self,
         request: Request<ExecuteQueryParams>,
-    ) -> std::result::Result<Response<ExecuteQueryResult>, tonic::Status> {
+    ) -> Result<Response<ExecuteQueryResult>, Status> {
         let query_params = request.into_inner();
         if let ExecuteQueryParams {
             query: Some(query),
@@ -354,32 +350,38 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
             let config = config_builder.build().map_err(|e| {
                 let msg = format!("Could not parse configs: {}", e);
                 error!("{}", msg);
-                tonic::Status::internal(msg)
+                Status::internal(msg)
             })?;
 
-            let df_session = match optional_session_id {
+            let (session_id, session_ctx) = match optional_session_id {
                 Some(OptionalSessionId::SessionId(session_id)) => {
-                    let session_ctx = self
+                    let ctx = self
                         .state
-                        .session_registry()
-                        .lookup_session(session_id.as_str())
+                        .session_manager
+                        .update_session(&session_id, &config)
                         .await
-                        .ok_or_else(|| {
-                            Status::invalid_argument(format!(
-                                "SessionContext not found for session ID {}",
-                                session_id
+                        .map_err(|e| {
+                            Status::internal(format!(
+                                "Failed to load SessionContext for session ID {}: {:?}",
+                                session_id, e
                             ))
                         })?;
-                    update_datafusion_context(session_ctx, &config)
+                    (session_id, ctx)
                 }
                 _ => {
-                    let df_session =
-                        create_datafusion_context(&config, self.session_builder);
-                    self.state
-                        .session_registry()
-                        .register_session(df_session.clone())
-                        .await;
-                    df_session
+                    let ctx = self
+                        .state
+                        .session_manager
+                        .create_session(&config)
+                        .await
+                        .map_err(|e| {
+                            Status::internal(format!(
+                                "Failed to create SessionContext: {:?}",
+                                e
+                            ))
+                        })?;
+
+                    (ctx.session_id(), ctx)
                 }
             };
 
@@ -387,128 +389,65 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
                 Query::LogicalPlan(message) => T::try_decode(message.as_slice())
                     .and_then(|m| {
                         m.try_into_logical_plan(
-                            df_session.deref(),
+                            session_ctx.deref(),
                             self.codec.logical_extension_codec(),
                         )
                     })
                     .map_err(|e| {
                         let msg = format!("Could not parse logical plan protobuf: {}", e);
                         error!("{}", msg);
-                        tonic::Status::internal(msg)
+                        Status::internal(msg)
                     })?,
-                Query::Sql(sql) => df_session
+                Query::Sql(sql) => session_ctx
                     .sql(&sql)
                     .await
                     .and_then(|df| df.to_logical_plan())
                     .map_err(|e| {
                         let msg = format!("Error parsing SQL: {}", e);
                         error!("{}", msg);
-                        tonic::Status::internal(msg)
+                        Status::internal(msg)
                     })?,
             };
+
             debug!("Received plan for execution: {:?}", plan);
 
-            // Generate job id.
-            // TODO Maybe the format will be changed in the future
-            let job_id = generate_job_id();
-            let session_id = df_session.session_id();
-            let state = self.state.clone();
+            let job_id = self.state.task_manager.generate_job_id();
+
+            self.state
+                .task_manager
+                .queue_job(&job_id)
+                .await
+                .map_err(|e| {
+                    let msg = format!("Failed to queue job {}: {:?}", job_id, e);
+                    error!("{}", msg);
+
+                    Status::internal(msg)
+                })?;
+
             let query_stage_event_sender =
                 self.query_stage_event_loop.get_sender().map_err(|e| {
-                    tonic::Status::internal(format!(
+                    Status::internal(format!(
                         "Could not get query stage event sender due to: {}",
                         e
                     ))
                 })?;
 
-            // Save placeholder job metadata
-            state
-                .save_job_metadata(
-                    &job_id,
-                    &JobStatus {
-                        status: Some(job_status::Status::Queued(QueuedJob {})),
-                    },
-                )
+            query_stage_event_sender
+                .post_event(QueryStageSchedulerEvent::JobQueued {
+                    job_id: job_id.clone(),
+                    session_id: session_id.clone(),
+                    session_ctx,
+                    plan: Box::new(plan),
+                })
                 .await
                 .map_err(|e| {
-                    tonic::Status::internal(format!("Could not save job metadata: {}", e))
-                })?;
+                    let msg =
+                        format!("Failed to send JobQueued event for {}: {:?}", job_id, e);
+                    error!("{}", msg);
 
-            state
-                .save_job_session(&job_id, &session_id, settings)
-                .await
-                .map_err(|e| {
-                    tonic::Status::internal(format!(
-                        "Could not save job session mapping: {}",
-                        e
-                    ))
+                    Status::internal(msg)
                 })?;
 
-            let job_id_spawn = job_id.clone();
-            let ctx = df_session.clone();
-            tokio::spawn(async move {
-                if let Err(e) = async {
-                    // create physical plan
-                    let start = Instant::now();
-                    let plan = async {
-                        let optimized_plan = ctx.optimize(&plan).map_err(|e| {
-                            let msg =
-                                format!("Could not create optimized logical plan: {}", e);
-                            error!("{}", msg);
-
-                            BallistaError::General(msg)
-                        })?;
-
-                        debug!("Calculated optimized plan: {:?}", optimized_plan);
-
-                        ctx.create_physical_plan(&optimized_plan)
-                            .await
-                            .map_err(|e| {
-                                let msg =
-                                    format!("Could not create physical plan: {}", e);
-                                error!("{}", msg);
-
-                                BallistaError::General(msg)
-                            })
-                    }
-                    .await?;
-                    info!(
-                        "DataFusion created physical plan in {} milliseconds",
-                        start.elapsed().as_millis()
-                    );
-
-                    query_stage_event_sender
-                        .post_event(QueryStageSchedulerEvent::JobSubmitted(
-                            job_id_spawn.clone(),
-                            plan,
-                        ))
-                        .await?;
-
-                    Ok::<(), BallistaError>(())
-                }
-                .await
-                {
-                    let msg = format!("Job {} failed due to {}", job_id_spawn, e);
-                    warn!("{}", msg);
-                    state
-                        .save_job_metadata(
-                            &job_id_spawn,
-                            &JobStatus {
-                                status: Some(job_status::Status::Failed(FailedJob {
-                                    error: msg.to_string(),
-                                })),
-                            },
-                        )
-                        .await
-                        .unwrap_or_else(|_| {
-                            panic!(
-                                "Fail to update job status to failed for {}",
-                                job_id_spawn
-                            )
-                        });
-                }
-            });
-
             Ok(Response::new(ExecuteQueryResult { job_id, session_id }))
         } else if let ExecuteQueryParams {
             query: None,
@@ -524,44 +463,46 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
             let config = config_builder.build().map_err(|e| {
                 let msg = format!("Could not parse configs: {}", e);
                 error!("{}", msg);
-                tonic::Status::internal(msg)
+                Status::internal(msg)
             })?;
-            let df_session = create_datafusion_context(&config, self.session_builder);
-            self.state
-                .session_registry()
-                .register_session(df_session.clone())
-                .await;
+            let session = self
+                .state
+                .session_manager
+                .create_session(&config)
+                .await
+                .map_err(|e| {
+                    Status::internal(format!(
+                        "Failed to create new SessionContext: {:?}",
+                        e
+                    ))
+                })?;
+
             Ok(Response::new(ExecuteQueryResult {
                 job_id: "NA".to_owned(),
-                session_id: df_session.session_id(),
+                session_id: session.session_id(),
             }))
         } else {
-            Err(tonic::Status::internal("Error parsing request"))
+            Err(Status::internal("Error parsing request"))
         }
     }
 
     async fn get_job_status(
         &self,
         request: Request<GetJobStatusParams>,
-    ) -> std::result::Result<Response<GetJobStatusResult>, tonic::Status> {
+    ) -> Result<Response<GetJobStatusResult>, Status> {
         let job_id = request.into_inner().job_id;
         debug!("Received get_job_status request for job {}", job_id);
-        let job_meta = self.state.get_job_metadata(&job_id).unwrap();
-        Ok(Response::new(GetJobStatusResult {
-            status: Some(job_meta),
-        }))
+        match self.state.task_manager.get_job_status(&job_id).await {
+            Ok(status) => Ok(Response::new(GetJobStatusResult { status })),
+            Err(e) => { // log server-side, then return it to the caller as Status::internal
+                let msg = format!("Error getting status for job {}: {:?}", job_id, e);
+                error!("{}", msg);
+                Err(Status::internal(msg))
+            }
+        }
     }
 }
 
-fn generate_job_id() -> String {
-    let mut rng = thread_rng();
-    std::iter::repeat(())
-        .map(|()| rng.sample(Alphanumeric))
-        .map(char::from)
-        .take(7)
-        .collect()
-}
-
 #[cfg(all(test, feature = "sled"))]
 mod test {
     use std::sync::Arc;
@@ -594,7 +535,7 @@ mod test {
             );
         let exec_meta = ExecutorRegistration {
             id: "abc".to_owned(),
-            optional_host: Some(OptionalHost::Host("".to_owned())),
+            optional_host: Some(OptionalHost::Host("http://host:8080".to_owned())),
             port: 0,
             grpc_port: 0,
             specification: Some(ExecutorSpecification { task_slots: 2 }.into()),
@@ -619,8 +560,18 @@ mod test {
                 BallistaCodec::default(),
             );
         state.init().await?;
+
         // executor should be registered
-        assert_eq!(state.get_executors_metadata().await.unwrap().len(), 1);
+        let stored_executor = state
+            .executor_manager
+            .get_executor_metadata("abc")
+            .await
+            .expect("getting executor");
+
+        assert_eq!(stored_executor.grpc_port, 0);
+        assert_eq!(stored_executor.port, 0);
+        assert_eq!(stored_executor.specification.task_slots, 2);
+        assert_eq!(stored_executor.host, "http://host:8080".to_owned());
 
         let request: Request<PollWorkParams> = Request::new(PollWorkParams {
             metadata: Some(exec_meta.clone()),
@@ -632,7 +583,8 @@ mod test {
             .await
             .expect("Received error response")
             .into_inner();
-        // still no response task since there are no tasks in the scheduelr
+
+        // still no response task since there are no tasks in the scheduler
         assert!(response.task.is_none());
         let state: SchedulerState<LogicalPlanNode, PhysicalPlanNode> =
             SchedulerState::new(
@@ -642,8 +594,19 @@ mod test {
                 BallistaCodec::default(),
             );
         state.init().await?;
+
         // executor should be registered
-        assert_eq!(state.get_executors_metadata().await.unwrap().len(), 1);
+        let stored_executor = state
+            .executor_manager
+            .get_executor_metadata("abc")
+            .await
+            .expect("getting executor");
+
+        assert_eq!(stored_executor.grpc_port, 0);
+        assert_eq!(stored_executor.port, 0);
+        assert_eq!(stored_executor.specification.task_slots, 2);
+        assert_eq!(stored_executor.host, "http://host:8080".to_owned());
+
         Ok(())
     }
 }
diff --git a/ballista/rust/scheduler/src/scheduler_server/mod.rs b/ballista/rust/scheduler/src/scheduler_server/mod.rs
index 10f26b1d..a6a26d80 100644
--- a/ballista/rust/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/mod.rs
@@ -15,22 +15,20 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
 
-use datafusion::execution::context::{default_session_builder, SessionState};
-use datafusion::prelude::{SessionConfig, SessionContext};
-use datafusion_proto::logical_plan::AsLogicalPlan;
-use tokio::sync::RwLock;
-use tonic::transport::Channel;
-
-use ballista_core::config::{BallistaConfig, TaskSchedulingPolicy};
+use ballista_core::config::TaskSchedulingPolicy;
 use ballista_core::error::Result;
-use ballista_core::event_loop::EventLoop;
-use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient;
+use ballista_core::event_loop::{EventAction, EventLoop};
 use ballista_core::serde::protobuf::TaskStatus;
 use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
+use datafusion::execution::context::{default_session_builder, SessionState};
+
+use datafusion::prelude::SessionConfig;
+use datafusion_proto::logical_plan::AsLogicalPlan;
+
+use log::error;
 
 use crate::scheduler_server::event::{QueryStageSchedulerEvent, SchedulerServerEvent};
 use crate::scheduler_server::event_loop::SchedulerServerEventAction;
@@ -50,7 +48,6 @@ mod external_scaler;
 mod grpc;
 mod query_stage_scheduler;
 
-type ExecutorsClient = Arc<RwLock<HashMap<String, ExecutorGrpcClient<Channel>>>>;
 pub(crate) type SessionBuilder = fn(SessionConfig) -> SessionState;
 
 #[derive(Clone)]
@@ -58,12 +55,9 @@ pub struct SchedulerServer<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
     pub(crate) state: Arc<SchedulerState<T, U>>,
     pub start_time: u128,
     policy: TaskSchedulingPolicy,
-    executors_client: Option<ExecutorsClient>,
     event_loop: Option<EventLoop<SchedulerServerEvent>>,
-    query_stage_event_loop: EventLoop<QueryStageSchedulerEvent>,
+    pub(crate) query_stage_event_loop: EventLoop<QueryStageSchedulerEvent>,
     codec: BallistaCodec<T, U>,
-    /// SessionState Builder
-    session_builder: SessionBuilder,
 }
 
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T, U> {
@@ -110,20 +104,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T
             codec.clone(),
         ));
 
-        let (executors_client, event_loop) =
-            if matches!(policy, TaskSchedulingPolicy::PushStaged) {
-                let executors_client = Arc::new(RwLock::new(HashMap::new()));
-                let event_action: Arc<SchedulerServerEventAction<T, U>> =
-                    Arc::new(SchedulerServerEventAction::new(
-                        state.clone(),
-                        executors_client.clone(),
-                    ));
-                let event_loop =
-                    EventLoop::new("scheduler".to_owned(), 10000, event_action);
-                (Some(executors_client), Some(event_loop))
-            } else {
-                (None, None)
-            };
+        let event_loop = if matches!(policy, TaskSchedulingPolicy::PushStaged) {
+            let event_action: Arc<SchedulerServerEventAction<T, U>> =
+                Arc::new(SchedulerServerEventAction::new(state.clone()));
+            let event_loop = EventLoop::new("scheduler".to_owned(), 10000, event_action);
+            Some(event_loop)
+        } else {
+            None
+        };
         let query_stage_scheduler =
             Arc::new(QueryStageScheduler::new(state.clone(), None));
         let query_stage_event_loop =
@@ -135,11 +123,42 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T
                 .unwrap()
                 .as_millis(),
             policy,
-            executors_client,
             event_loop,
             query_stage_event_loop,
             codec,
+            // session_builder,
+        }
+    }
+    /// Build a `PushStaged` scheduler whose event loop runs the supplied `event_action`.
+    pub fn new_with_event_action(
+        config: Arc<dyn StateBackendClient>,
+        namespace: String,
+        codec: BallistaCodec<T, U>,
+        session_builder: SessionBuilder,
+        event_action: Arc<dyn EventAction<SchedulerServerEvent>>,
+    ) -> Self {
+        let state = Arc::new(SchedulerState::new(
+            config,
+            namespace,
             session_builder,
+            codec.clone(),
+        ));
+
+        let event_loop = EventLoop::new("scheduler".to_owned(), 10000, event_action);
+        let query_stage_scheduler =
+            Arc::new(QueryStageScheduler::new(state.clone(), None));
+        let query_stage_event_loop =
+            EventLoop::new("query_stage".to_owned(), 10000, query_stage_scheduler);
+        Self {
+            state,
+            start_time: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap()
+                .as_millis(),
+            policy: TaskSchedulingPolicy::PushStaged, // `new` only builds event_loop for PushStaged
+            event_loop: Some(event_loop),
+            query_stage_event_loop,
+            codec,
+        }
     }
 
@@ -173,20 +192,40 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T
 
     pub(crate) async fn update_task_status(
         &self,
+        executor_id: &str,
         tasks_status: Vec<TaskStatus>,
     ) -> Result<()> {
-        let num_tasks_status = tasks_status.len() as u32;
-        let stage_events = self.state.stage_manager.update_tasks_status(tasks_status);
-        if stage_events.is_empty() {
-            if let Some(event_loop) = self.event_loop.as_ref() {
-                event_loop
-                    .get_sender()?
-                    .post_event(SchedulerServerEvent::ReviveOffers(num_tasks_status))
-                    .await?;
+        let num_status = tasks_status.len();
+        let executor = self
+            .state
+            .executor_manager
+            .get_executor_metadata(executor_id)
+            .await?;
+
+        match self
+            .state
+            .task_manager
+            .update_task_statuses(&executor, tasks_status)
+            .await
+        {
+            Ok((stage_events, offers)) => {
+                if let Some(event_loop) = self.event_loop.as_ref() {
+                    event_loop
+                        .get_sender()?
+                        .post_event(SchedulerServerEvent::Offer(offers))
+                        .await?;
+                }
+
+                for stage_event in stage_events {
+                    self.post_stage_event(stage_event).await?;
+                }
             }
-        } else {
-            for stage_event in stage_events {
-                self.post_stage_event(stage_event).await?;
+            Err(e) => {
+                error!(
+                    "Failed to update {} task statuses for executor {}: {:?}",
+                    num_status, executor_id, e
+                );
+                // TODO: decide how to recover here; for now the failure is only logged.
             }
         }
 
@@ -201,158 +240,70 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T
     }
 }
 
-/// Create a DataFusion session context that is compatible with Ballista Configuration
-pub fn create_datafusion_context(
-    config: &BallistaConfig,
-    session_builder: SessionBuilder,
-) -> Arc<SessionContext> {
-    let config = SessionConfig::new()
-        .with_target_partitions(config.default_shuffle_partitions())
-        .with_batch_size(config.default_batch_size())
-        .with_repartition_joins(config.repartition_joins())
-        .with_repartition_aggregations(config.repartition_aggregations())
-        .with_repartition_windows(config.repartition_windows())
-        .with_parquet_pruning(config.parquet_pruning());
-    let session_state = session_builder(config);
-    Arc::new(SessionContext::with_state(session_state))
-}
-
-/// Update the existing DataFusion session context with Ballista Configuration
-pub fn update_datafusion_context(
-    session_ctx: Arc<SessionContext>,
-    config: &BallistaConfig,
-) -> Arc<SessionContext> {
-    {
-        let mut mut_state = session_ctx.state.write();
-        // TODO Currently we have to start from default session config due to the interface not support update
-        mut_state.config = SessionConfig::default()
-            .with_target_partitions(config.default_shuffle_partitions())
-            .with_batch_size(config.default_batch_size())
-            .with_repartition_joins(config.repartition_joins())
-            .with_repartition_aggregations(config.repartition_aggregations())
-            .with_repartition_windows(config.repartition_windows())
-            .with_parquet_pruning(config.parquet_pruning());
-    }
-    session_ctx
-}
-
-/// A Registry holds all the datafusion session contexts
-pub struct SessionContextRegistry {
-    /// A map from session_id to SessionContext
-    pub running_sessions: RwLock<HashMap<String, Arc<SessionContext>>>,
-}
-
-impl Default for SessionContextRegistry {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl SessionContextRegistry {
-    /// Create the registry that session contexts can registered into.
-    /// ['LocalFileSystem'] store is registered in by default to support read local files natively.
-    pub fn new() -> Self {
-        Self {
-            running_sessions: RwLock::new(HashMap::new()),
-        }
-    }
-
-    /// Adds a new session to this registry.
-    pub async fn register_session(
-        &self,
-        session_ctx: Arc<SessionContext>,
-    ) -> Option<Arc<SessionContext>> {
-        let session_id = session_ctx.session_id();
-        let mut sessions = self.running_sessions.write().await;
-        sessions.insert(session_id, session_ctx)
-    }
-
-    /// Lookup the session context registered
-    pub async fn lookup_session(&self, session_id: &str) -> Option<Arc<SessionContext>> {
-        let sessions = self.running_sessions.read().await;
-        sessions.get(session_id).cloned()
-    }
-
-    /// Remove a session from this registry.
-    pub async fn unregister_session(
-        &self,
-        session_id: &str,
-    ) -> Option<Arc<SessionContext>> {
-        let mut sessions = self.running_sessions.write().await;
-        sessions.remove(session_id)
-    }
-}
-
 #[cfg(all(test, feature = "sled"))]
 mod test {
     use std::sync::Arc;
-    use std::time::{Duration, Instant};
+    use std::time::Duration;
 
     use datafusion::arrow::datatypes::{DataType, Field, Schema};
     use datafusion::execution::context::default_session_builder;
     use datafusion::logical_plan::{col, sum, LogicalPlan};
-    use datafusion::prelude::{SessionConfig, SessionContext};
+
     use datafusion::test_util::scan_empty;
     use datafusion_proto::protobuf::LogicalPlanNode;
 
-    use ballista_core::config::TaskSchedulingPolicy;
+    use ballista_core::config::{
+        BallistaConfig, TaskSchedulingPolicy, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS,
+    };
     use ballista_core::error::{BallistaError, Result};
-    use ballista_core::execution_plans::ShuffleWriterExec;
+    use ballista_core::event_loop::EventAction;
+
     use ballista_core::serde::protobuf::{
-        job_status, task_status, CompletedTask, PartitionId, PhysicalPlanNode, TaskStatus,
+        job_status, task_status, CompletedTask, FailedTask, JobStatus, PartitionId,
+        PhysicalPlanNode, ShuffleWritePartition, TaskStatus,
+    };
+    use ballista_core::serde::scheduler::{
+        ExecutorData, ExecutorMetadata, ExecutorSpecification,
     };
-    use ballista_core::serde::scheduler::ExecutorData;
     use ballista_core::serde::BallistaCodec;
 
-    use crate::scheduler_server::event::QueryStageSchedulerEvent;
+    use crate::scheduler_server::event::{
+        QueryStageSchedulerEvent, SchedulerServerEvent,
+    };
     use crate::scheduler_server::SchedulerServer;
     use crate::state::backend::standalone::StandaloneClient;
-    use crate::state::task_scheduler::TaskScheduler;
 
-    #[tokio::test]
-    async fn test_pull_based_task_scheduling() -> Result<()> {
-        let now = Instant::now();
-        test_task_scheduling(TaskSchedulingPolicy::PullStaged, test_plan(), 4).await?;
-        println!(
-            "pull-based task scheduling cost {}ms",
-            now.elapsed().as_millis()
-        );
-
-        Ok(())
-    }
+    use crate::state::executor_manager::ExecutorReservation;
+    use crate::test_utils::{
+        await_condition, ExplodingTableProvider, SchedulerEventObserver,
+    };
 
     #[tokio::test]
-    async fn test_push_based_task_scheduling() -> Result<()> {
-        let now = Instant::now();
-        test_task_scheduling(TaskSchedulingPolicy::PushStaged, test_plan(), 4).await?;
-        println!(
-            "push-based task scheduling cost {}ms",
-            now.elapsed().as_millis()
-        );
+    async fn test_pull_scheduling() -> Result<()> {
+        let plan = test_plan();
+        let task_slots = 4;
 
-        Ok(())
-    }
+        let scheduler = test_scheduler(TaskSchedulingPolicy::PullStaged).await?;
 
-    async fn test_task_scheduling(
-        policy: TaskSchedulingPolicy,
-        plan_of_linear_stages: LogicalPlan,
-        total_available_task_slots: usize,
-    ) -> Result<()> {
-        let scheduler = test_scheduler(policy).await?;
-        if matches!(policy, TaskSchedulingPolicy::PushStaged) {
-            let executors = test_executors(total_available_task_slots);
-            for executor_data in executors {
-                scheduler
-                    .state
-                    .executor_manager
-                    .save_executor_data(executor_data);
-            }
+        let executors = test_executors(task_slots);
+        for (executor_metadata, executor_data) in executors {
+            scheduler
+                .state
+                .executor_manager
+                .register_executor(executor_metadata, executor_data, false)
+                .await?;
         }
-        let config =
-            SessionConfig::new().with_target_partitions(total_available_task_slots);
-        let ctx = Arc::new(SessionContext::with_config(config));
+
+        let config = test_session(task_slots);
+
+        let ctx = scheduler
+            .state
+            .session_manager
+            .create_session(&config)
+            .await?;
+
         let plan = async {
-            let optimized_plan = ctx.optimize(&plan_of_linear_stages).map_err(|e| {
+            let optimized_plan = ctx.optimize(&plan).map_err(|e| {
                 BallistaError::General(format!(
                     "Could not create optimized logical plan: {}",
                     e
@@ -371,185 +322,422 @@ mod test {
         .await?;
 
         let job_id = "job";
+        let session_id = ctx.session_id();
+
+        // Submit job
         scheduler
             .state
-            .session_registry()
-            .register_session(ctx.clone())
-            .await;
-        scheduler
+            .task_manager
+            .submit_job(job_id, &session_id, plan)
+            .await
+            .expect("submitting plan");
+
+        loop {
+            // Refresh the ExecutionGraph
+            let mut graph = scheduler
+                .state
+                .task_manager
+                .get_execution_graph(job_id)
+                .await?;
+
+            if let Some(task) = graph.pop_next_task("executor-1")? {
+                let mut partitions: Vec<ShuffleWritePartition> = vec![];
+
+                let num_partitions = task
+                    .output_partitioning
+                    .map(|p| p.partition_count())
+                    .unwrap_or(1);
+
+                for partition_id in 0..num_partitions {
+                    partitions.push(ShuffleWritePartition {
+                        partition_id: partition_id as u64,
+                        path: "some/path".to_string(),
+                        num_batches: 1,
+                        num_rows: 1,
+                        num_bytes: 1,
+                    })
+                }
+
+                // Complete the task
+                let task_status = TaskStatus {
+                    status: Some(task_status::Status::Completed(CompletedTask {
+                        executor_id: "executor-1".to_owned(),
+                        partitions,
+                    })),
+                    task_id: Some(PartitionId {
+                        job_id: job_id.to_owned(),
+                        stage_id: task.partition.stage_id as u32,
+                        partition_id: task.partition.partition_id as u32,
+                    }),
+                };
+
+                scheduler
+                    .update_task_status("executor-1", vec![task_status])
+                    .await?;
+            } else {
+                break;
+            }
+        }
+
+        let final_graph = scheduler
             .state
-            .save_job_session(job_id, ctx.session_id().as_str(), vec![])
+            .task_manager
+            .get_execution_graph(job_id)
             .await?;
-        {
-            // verify job submit
+
+        assert!(final_graph.complete());
+        assert_eq!(final_graph.output_locations().len(), 4);
+
+        for output_location in final_graph.output_locations() {
+            assert_eq!(output_location.path, "some/path".to_owned());
+            assert_eq!(output_location.executor_meta.host, "localhost1".to_owned())
+        }
+
+        Ok(())
+    }
+
+    /// This test will exercise the push-based scheduling. We setup our scheduler server
+    /// with `SchedulerEventObserver` to listen to `SchedulerServerEvents` and then just immediately
+    /// complete the tasks.
+    #[tokio::test]
+    async fn test_push_scheduling() -> Result<()> {
+        let plan = test_plan();
+        let task_slots = 4;
+
+        let (sender, mut event_receiver) =
+            tokio::sync::mpsc::channel::<SchedulerServerEvent>(1000);
+        let (error_sender, _) = tokio::sync::mpsc::channel::<BallistaError>(1000);
+
+        let event_action = SchedulerEventObserver::new(sender, error_sender);
+
+        let scheduler = test_scheduler_with_event_action(Arc::new(event_action)).await?;
+
+        let executors = test_executors(task_slots);
+        for (executor_metadata, executor_data) in executors {
             scheduler
-                .post_stage_event(QueryStageSchedulerEvent::JobSubmitted(
-                    job_id.to_owned(),
-                    plan,
-                ))
+                .state
+                .executor_manager
+                .register_executor(executor_metadata, executor_data, false)
                 .await?;
-
-            let waiting_time_ms =
-                test_waiting_async(|| scheduler.state.get_job_metadata(job_id).is_some())
-                    .await;
-            let job_status = scheduler.state.get_job_metadata(job_id);
-            assert!(
-                job_status.is_some(),
-                "Fail to receive JobSubmitted event within {}ms",
-                waiting_time_ms
-            );
         }
 
-        let stage_task_num = test_get_job_stage_task_num(&scheduler, job_id);
-        let first_stage_id = 1u32;
-        let final_stage_id = stage_task_num.len() as u32 - 1;
-        assert!(scheduler
+        let config = test_session(task_slots);
+
+        let ctx = scheduler
             .state
-            .stage_manager
-            .is_final_stage(job_id, final_stage_id));
+            .session_manager
+            .create_session(&config)
+            .await?;
 
-        if matches!(policy, TaskSchedulingPolicy::PullStaged) {
-            assert!(!scheduler.state.stage_manager.has_running_tasks());
-            assert!(scheduler
+        let job_id = "job";
+        let session_id = ctx.session_id();
+
+        // Send JobQueued event to kick off the event loop
+        scheduler
+            .query_stage_event_loop
+            .get_sender()?
+            .post_event(QueryStageSchedulerEvent::JobQueued {
+                job_id: job_id.to_owned(),
+                session_id,
+                session_ctx: ctx,
+                plan: Box::new(plan),
+            })
+            .await?;
+
+        // Complete tasks that are offered through scheduler events
+        while let Some(SchedulerServerEvent::Offer(reservations)) =
+            event_receiver.recv().await
+        {
+            let free_list = match scheduler
                 .state
-                .stage_manager
-                .is_running_stage(job_id, first_stage_id));
-            if first_stage_id != final_stage_id {
-                assert!(scheduler
-                    .state
-                    .stage_manager
-                    .is_pending_stage(job_id, final_stage_id));
-            }
-        }
+                .task_manager
+                .fill_reservations(&reservations)
+                .await
+            {
+                Ok((assignments, mut unassigned_reservations, _)) => {
+                    // Break when we are no longer assigning tasks
+                    if unassigned_reservations.len() == reservations.len() {
+                        break;
+                    }
+
+                    for (executor_id, task) in assignments.into_iter() {
+                        match scheduler
+                            .state
+                            .executor_manager
+                            .get_executor_metadata(&executor_id)
+                            .await
+                        {
+                            Ok(executor) => {
+                                let mut partitions: Vec<ShuffleWritePartition> = vec![];
+
+                                let num_partitions = task
+                                    .output_partitioning
+                                    .map(|p| p.partition_count())
+                                    .unwrap_or(1);
+
+                                for partition_id in 0..num_partitions {
+                                    partitions.push(ShuffleWritePartition {
+                                        partition_id: partition_id as u64,
+                                        path: "some/path".to_string(),
+                                        num_batches: 1,
+                                        num_rows: 1,
+                                        num_bytes: 1,
+                                    })
+                                }
+
+                                // Complete the task
+                                let task_status = TaskStatus {
+                                    status: Some(task_status::Status::Completed(
+                                        CompletedTask {
+                                            executor_id: executor.id.clone(),
+                                            partitions,
+                                        },
+                                    )),
+                                    task_id: Some(PartitionId {
+                                        job_id: job_id.to_owned(),
+                                        stage_id: task.partition.stage_id as u32,
+                                        partition_id: task.partition.partition_id as u32,
+                                    }),
+                                };
+
+                                scheduler
+                                    .update_task_status(&executor.id, vec![task_status])
+                                    .await?;
+                            }
+                            Err(_e) => {
+                                unassigned_reservations.push(
+                                    ExecutorReservation::new_free(executor_id.clone()),
+                                );
+                            }
+                        }
+                    }
+                    unassigned_reservations
+                }
+                Err(_e) => reservations,
+            };
 
-        // complete stage one by one
-        for stage_id in first_stage_id..final_stage_id {
-            let next_stage_id = stage_id + 1;
-            let num_tasks = stage_task_num[stage_id as usize] as usize;
-            if matches!(policy, TaskSchedulingPolicy::PullStaged) {
-                let mut executors = test_executors(total_available_task_slots);
-                let _fet_tasks = scheduler
+            // If any reserved slots remain, return them to the pool
+            if !free_list.is_empty() {
+                scheduler
                     .state
-                    .fetch_schedulable_tasks(&mut executors, 1)
+                    .executor_manager
+                    .cancel_reservations(free_list)
                     .await?;
             }
-            assert!(scheduler.state.stage_manager.has_running_tasks());
-            assert!(scheduler
-                .state
-                .stage_manager
-                .is_running_stage(job_id, stage_id));
-            assert!(scheduler
-                .state
-                .stage_manager
-                .is_pending_stage(job_id, next_stage_id));
+        }
 
-            test_complete_stage(&scheduler, job_id, 1, num_tasks).await?;
-            assert!(!scheduler.state.stage_manager.has_running_tasks());
-            assert!(!scheduler
-                .state
-                .stage_manager
-                .is_running_stage(job_id, stage_id));
-            assert!(scheduler
-                .state
-                .stage_manager
-                .is_completed_stage(job_id, stage_id));
-            let waiting_time_ms = test_waiting_async(|| {
-                !scheduler
-                    .state
-                    .stage_manager
-                    .is_pending_stage(job_id, next_stage_id)
-            })
-            .await;
-            assert!(
-                !scheduler
-                    .state
-                    .stage_manager
-                    .is_pending_stage(job_id, next_stage_id),
-                "Fail to update stage state machine within {}ms",
-                waiting_time_ms
-            );
-            assert!(scheduler
+        let final_graph = scheduler
+            .state
+            .task_manager
+            .get_execution_graph(job_id)
+            .await?;
+
+        assert!(final_graph.complete());
+        assert_eq!(final_graph.output_locations().len(), 4);
+
+        Ok(())
+    }
+
+    // Simulate a task failure and ensure the job status is updated correctly
+    #[tokio::test]
+    async fn test_job_failure() -> Result<()> {
+        let plan = test_plan();
+        let task_slots = 4;
+
+        let (sender, mut event_receiver) =
+            tokio::sync::mpsc::channel::<SchedulerServerEvent>(1000);
+        let (error_sender, _) = tokio::sync::mpsc::channel::<BallistaError>(1000);
+
+        let event_action = SchedulerEventObserver::new(sender, error_sender);
+
+        let scheduler = test_scheduler_with_event_action(Arc::new(event_action)).await?;
+
+        let executors = test_executors(task_slots);
+        for (executor_metadata, executor_data) in executors {
+            scheduler
                 .state
-                .stage_manager
-                .is_running_stage(job_id, next_stage_id));
+                .executor_manager
+                .register_executor(executor_metadata, executor_data, false)
+                .await?;
         }
 
-        // complete the final stage
+        let config = test_session(task_slots);
+
+        let ctx = scheduler
+            .state
+            .session_manager
+            .create_session(&config)
+            .await?;
+
+        let job_id = "job";
+        let session_id = ctx.session_id();
+
+        // Send JobQueued event to kick off the event loop
+        scheduler
+            .query_stage_event_loop
+            .get_sender()?
+            .post_event(QueryStageSchedulerEvent::JobQueued {
+                job_id: job_id.to_owned(),
+                session_id,
+                session_ctx: ctx,
+                plan: Box::new(plan),
+            })
+            .await?;
+
+        // Complete tasks that are offered through scheduler events
+        if let Some(SchedulerServerEvent::Offer(reservations)) =
+            event_receiver.recv().await
         {
-            let num_tasks = stage_task_num[final_stage_id as usize] as usize;
-            if matches!(policy, TaskSchedulingPolicy::PullStaged) {
-                let mut executors = test_executors(total_available_task_slots);
-                let _fet_tasks = scheduler
+            let free_list = match scheduler
+                .state
+                .task_manager
+                .fill_reservations(&reservations)
+                .await
+            {
+                Ok((assignments, mut unassigned_reservations, _)) => {
+                    for (executor_id, task) in assignments.into_iter() {
+                        match scheduler
+                            .state
+                            .executor_manager
+                            .get_executor_metadata(&executor_id)
+                            .await
+                        {
+                            Ok(executor) => {
+                                let mut partitions: Vec<ShuffleWritePartition> = vec![];
+
+                                let num_partitions = task
+                                    .output_partitioning
+                                    .map(|p| p.partition_count())
+                                    .unwrap_or(1);
+
+                                for partition_id in 0..num_partitions {
+                                    partitions.push(ShuffleWritePartition {
+                                        partition_id: partition_id as u64,
+                                        path: "some/path".to_string(),
+                                        num_batches: 1,
+                                        num_rows: 1,
+                                        num_bytes: 1,
+                                    })
+                                }
+
+                                // Complete the task
+                                let task_status = TaskStatus {
+                                    status: Some(task_status::Status::Failed(
+                                        FailedTask {
+                                            error: "".to_string(),
+                                        },
+                                    )),
+                                    task_id: Some(PartitionId {
+                                        job_id: job_id.to_owned(),
+                                        stage_id: task.partition.stage_id as u32,
+                                        partition_id: task.partition.partition_id as u32,
+                                    }),
+                                };
+
+                                scheduler
+                                    .update_task_status(&executor.id, vec![task_status])
+                                    .await?;
+                            }
+                            Err(_e) => {
+                                unassigned_reservations.push(
+                                    ExecutorReservation::new_free(executor_id.clone()),
+                                );
+                            }
+                        }
+                    }
+                    unassigned_reservations
+                }
+                Err(_e) => reservations,
+            };
+
+            // If any reserved slots remain, return them to the pool
+            if !free_list.is_empty() {
+                scheduler
                     .state
-                    .fetch_schedulable_tasks(&mut executors, 1)
+                    .executor_manager
+                    .cancel_reservations(free_list)
                     .await?;
             }
-            assert!(scheduler.state.stage_manager.has_running_tasks());
+        } else {
+            panic!("No reservations offered");
+        }
 
-            test_complete_stage(&scheduler, job_id, final_stage_id, num_tasks).await?;
-            assert!(!scheduler.state.stage_manager.has_running_tasks());
-            assert!(!scheduler
-                .state
-                .stage_manager
-                .is_running_stage(job_id, final_stage_id));
-            assert!(scheduler
-                .state
-                .stage_manager
-                .is_completed_stage(job_id, final_stage_id));
-            let waiting_time_ms = test_waiting_async(|| {
-                let job_status = scheduler.state.get_job_metadata(job_id).unwrap();
-                matches!(job_status.status, Some(job_status::Status::Completed(_)))
-            })
-            .await;
+        let status = scheduler.state.task_manager.get_job_status(job_id).await?;
 
-            let job_status = scheduler.state.get_job_metadata(job_id).unwrap();
-            assert!(
-                matches!(job_status.status, Some(job_status::Status::Completed(_))),
-                "Fail to update job state machine within {}ms",
-                waiting_time_ms
-            );
-        }
+        assert!(
+            matches!(
+                status,
+                Some(JobStatus {
+                    status: Some(job_status::Status::Failed(_))
+                })
+            ),
+            "Expected job status to be failed"
+        );
 
         Ok(())
     }
 
-    async fn test_waiting_async<F>(cond: F) -> u64
-    where
-        F: Fn() -> bool,
-    {
-        let round_waiting_time = 10;
-        let num_round = 5;
-        for _i in 0..num_round {
-            if cond() {
-                break;
-            }
-            tokio::time::sleep(Duration::from_millis(round_waiting_time)).await;
-        }
+    // If the physical planning fails, the job should be marked as failed.
+    // Here we simulate a planning failure using ExplodingTableProvider to test this.
+    #[tokio::test]
+    async fn test_planning_failure() -> Result<()> {
+        let task_slots = 4;
 
-        round_waiting_time * num_round
-    }
+        let (sender, _event_receiver) =
+            tokio::sync::mpsc::channel::<SchedulerServerEvent>(1000);
+        let (error_sender, _) = tokio::sync::mpsc::channel::<BallistaError>(1000);
 
-    async fn test_complete_stage(
-        scheduler: &SchedulerServer<LogicalPlanNode, PhysicalPlanNode>,
-        job_id: &str,
-        stage_id: u32,
-        num_tasks: usize,
-    ) -> Result<()> {
-        let tasks_status: Vec<TaskStatus> = (0..num_tasks as u32)
-            .into_iter()
-            .map(|task_id| TaskStatus {
-                status: Some(task_status::Status::Completed(CompletedTask {
-                    executor_id: "localhost".to_owned(),
-                    partitions: Vec::new(),
-                })),
-                task_id: Some(PartitionId {
-                    job_id: job_id.to_owned(),
-                    stage_id,
-                    partition_id: task_id,
-                }),
+        let event_action = SchedulerEventObserver::new(sender, error_sender);
+
+        let scheduler = test_scheduler_with_event_action(Arc::new(event_action)).await?;
+
+        let config = test_session(task_slots);
+
+        let ctx = scheduler
+            .state
+            .session_manager
+            .create_session(&config)
+            .await?;
+
+        ctx.register_table("explode", Arc::new(ExplodingTableProvider))?;
+
+        let plan = ctx.sql("SELECT * FROM explode").await?.to_logical_plan()?;
+
+        let job_id = "job";
+        let session_id = ctx.session_id();
+
+        // Send JobQueued event to kick off the event loop
+        // This should fail when we try and create the physical plan
+        scheduler
+            .query_stage_event_loop
+            .get_sender()?
+            .post_event(QueryStageSchedulerEvent::JobQueued {
+                job_id: job_id.to_owned(),
+                session_id,
+                session_ctx: ctx,
+                plan: Box::new(plan),
             })
-            .collect();
-        scheduler.update_task_status(tasks_status).await
+            .await?;
+
+        let scheduler = scheduler.clone();
+
+        let check = || async {
+            let status = scheduler.state.task_manager.get_job_status(job_id).await?;
+
+            Ok(matches!(
+                status,
+                Some(JobStatus {
+                    status: Some(job_status::Status::Failed(_))
+                })
+            ))
+        };
+
+        // Sine this happens in an event loop, we need to check a few times.
+        let job_failed = await_condition(Duration::from_millis(100), 10, check).await?;
+
+        assert!(job_failed, "Job status not failed after 1 second");
+
+        Ok(())
     }
 
     async fn test_scheduler(
@@ -569,41 +757,59 @@ mod test {
         Ok(scheduler)
     }
 
-    fn test_executors(num_partitions: usize) -> Vec<ExecutorData> {
-        let task_slots = (num_partitions as u32 + 1) / 2;
+    async fn test_scheduler_with_event_action(
+        event_action: Arc<dyn EventAction<SchedulerServerEvent>>,
+    ) -> Result<SchedulerServer<LogicalPlanNode, PhysicalPlanNode>> {
+        let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
 
-        vec![
-            ExecutorData {
-                executor_id: "localhost1".to_owned(),
-                total_task_slots: task_slots,
-                available_task_slots: task_slots,
-            },
-            ExecutorData {
-                executor_id: "localhost2".to_owned(),
-                total_task_slots: num_partitions as u32 - task_slots,
-                available_task_slots: num_partitions as u32 - task_slots,
-            },
-        ]
+        let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
+            SchedulerServer::new_with_event_action(
+                state_storage.clone(),
+                "default".to_owned(),
+                BallistaCodec::default(),
+                default_session_builder,
+                event_action,
+            );
+        scheduler.init().await?;
+
+        Ok(scheduler)
     }
 
-    fn test_get_job_stage_task_num(
-        scheduler: &SchedulerServer<LogicalPlanNode, PhysicalPlanNode>,
-        job_id: &str,
-    ) -> Vec<u32> {
-        let mut ret = vec![0, 1];
-        let mut stage_id = 1;
-        while let Some(stage_plan) = scheduler.state.get_stage_plan(job_id, stage_id) {
-            if let Some(shuffle_writer) =
-                stage_plan.as_any().downcast_ref::<ShuffleWriterExec>()
-            {
-                if let Some(partitions) = shuffle_writer.shuffle_output_partitioning() {
-                    ret.push(partitions.partition_count() as u32)
-                }
-            }
-            stage_id += 1;
-        }
+    fn test_executors(num_partitions: usize) -> Vec<(ExecutorMetadata, ExecutorData)> {
+        let task_slots = (num_partitions as u32 + 1) / 2;
 
-        ret
+        vec![
+            (
+                ExecutorMetadata {
+                    id: "executor-1".to_string(),
+                    host: "localhost1".to_string(),
+                    port: 8080,
+                    grpc_port: 9090,
+                    specification: ExecutorSpecification { task_slots },
+                },
+                ExecutorData {
+                    executor_id: "executor-1".to_owned(),
+                    total_task_slots: task_slots,
+                    available_task_slots: task_slots,
+                },
+            ),
+            (
+                ExecutorMetadata {
+                    id: "executor-2".to_string(),
+                    host: "localhost2".to_string(),
+                    port: 8080,
+                    grpc_port: 9090,
+                    specification: ExecutorSpecification {
+                        task_slots: num_partitions as u32 - task_slots,
+                    },
+                },
+                ExecutorData {
+                    executor_id: "executor-2".to_owned(),
+                    total_task_slots: num_partitions as u32 - task_slots,
+                    available_task_slots: num_partitions as u32 - task_slots,
+                },
+            ),
+        ]
     }
 
     fn test_plan() -> LogicalPlan {
@@ -619,4 +825,14 @@ mod test {
             .build()
             .unwrap()
     }
+
+    fn test_session(partitions: usize) -> BallistaConfig {
+        BallistaConfig::builder()
+            .set(
+                BALLISTA_DEFAULT_SHUFFLE_PARTITIONS,
+                format!("{}", partitions).as_str(),
+            )
+            .build()
+            .expect("creating BallistaConfig")
+    }
 }
diff --git a/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs b/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
index fe1c2d5d..8596aa60 100644
--- a/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
@@ -15,27 +15,24 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::planner::{
-    find_unresolved_shuffles, remove_unresolved_shuffles, DistributedPlanner,
-};
-use crate::scheduler_server::event::{QueryStageSchedulerEvent, SchedulerServerEvent};
-use crate::state::SchedulerState;
-use async_recursion::async_recursion;
+use std::sync::Arc;
+use std::time::Instant;
+
 use async_trait::async_trait;
+use datafusion::logical_plan::LogicalPlan;
+use datafusion::prelude::SessionContext;
+use log::{debug, error, info};
+
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::event_loop::{EventAction, EventSender};
-use ballista_core::execution_plans::UnresolvedShuffleExec;
-use ballista_core::serde::protobuf::{
-    job_status, task_status, CompletedJob, CompletedTask, FailedJob, FailedTask,
-    JobStatus, RunningJob, TaskStatus,
-};
-use ballista_core::serde::scheduler::{ExecutorMetadata, PartitionStats};
-use ballista_core::serde::{protobuf, AsExecutionPlan};
-use datafusion::physical_plan::ExecutionPlan;
+
+use ballista_core::serde::AsExecutionPlan;
 use datafusion_proto::logical_plan::AsLogicalPlan;
-use log::{debug, error, info, warn};
-use std::collections::{HashMap, HashSet};
-use std::sync::Arc;
+
+use crate::scheduler_server::event::{QueryStageSchedulerEvent, SchedulerServerEvent};
+
+use crate::state::executor_manager::ExecutorReservation;
+use crate::state::SchedulerState;
 
 pub(crate) struct QueryStageScheduler<
     T: 'static + AsLogicalPlan,
@@ -56,418 +53,121 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> QueryStageSchedul
         }
     }
 
-    async fn generate_stages(
+    /// Plans a queued job and registers it with the task manager.
+    ///
+    /// The logical `plan` is optimized and turned into a physical plan using `session_ctx`,
+    /// then handed to `task_manager.submit_job` under `job_id` and `session_id`. The total
+    /// planning time is logged at `info` level.
+    ///
+    /// # Errors
+    /// Returns any error raised by optimization, physical planning, or job submission.
+    async fn submit_job(
         &self,
-        job_id: &str,
-        plan: Arc<dyn ExecutionPlan>,
+        job_id: String,
+        session_id: String,
+        session_ctx: Arc<SessionContext>,
+        plan: &LogicalPlan,
     ) -> Result<()> {
-        let mut planner = DistributedPlanner::new();
-        // The last one is the final stage
-        let stages = planner.plan_query_stages(job_id, plan).await.map_err(|e| {
-            let msg = format!("Could not plan query stages: {}", e);
-            error!("{}", msg);
-            BallistaError::General(msg)
-        })?;
+        let start = Instant::now();
+        let optimized_plan = session_ctx.optimize(plan)?;
 
-        let mut stages_dependency: HashMap<u32, HashSet<u32>> = HashMap::new();
-        // save stages into state
-        for shuffle_writer in stages.iter() {
-            let stage_id = shuffle_writer.stage_id();
-            let stage_plan: Arc<dyn ExecutionPlan> = shuffle_writer.clone();
-            self.state
-                .save_stage_plan(job_id, stage_id, stage_plan.clone())
-                .await
-                .map_err(|e| {
-                    let msg = format!("Could not save stage plan: {}", e);
-                    error!("{}", msg);
-                    BallistaError::General(msg)
-                })?;
+        debug!("Calculated optimized plan: {:?}", optimized_plan);
 
-            for child in find_unresolved_shuffles(&stage_plan)? {
-                stages_dependency
-                    .entry(child.stage_id as u32)
-                    .or_insert_with(HashSet::new)
-                    .insert(stage_id as u32);
-            }
-        }
+        let plan = session_ctx.create_physical_plan(&optimized_plan).await?;
 
+        // NOTE(review): the physical `plan` is not used after this call, so the `clone()`
+        // below only bumps a reference count and could be dropped.
         self.state
-            .stage_manager
-            .add_stages_dependency(job_id, stages_dependency);
-
-        let final_stage_id = stages.last().as_ref().unwrap().stage_id();
-        self.state
-            .stage_manager
-            .add_final_stage(job_id, final_stage_id as u32);
-        self.submit_stage(job_id, final_stage_id).await?;
-
-        Ok(())
-    }
+            .task_manager
+            .submit_job(&job_id, &session_id, plan.clone())
+            .await?;
 
-    async fn submit_pending_stages(&self, job_id: &str, stage_id: usize) -> Result<()> {
-        if let Some(parent_stages) = self
-            .state
-            .stage_manager
-            .get_parent_stages(job_id, stage_id as u32)
-        {
-            self.state
-                .stage_manager
-                .remove_pending_stage(job_id, &parent_stages);
-            for parent_stage in parent_stages {
-                self.submit_stage(job_id, parent_stage as usize).await?;
-            }
-        }
+        let elapsed = start.elapsed();
 
-        Ok(())
-    }
+        info!("Planned job {} in {:?}", job_id, elapsed);
 
-    #[async_recursion]
-    async fn submit_stage(&self, job_id: &str, stage_id: usize) -> Result<()> {
-        {
-            if self
-                .state
-                .stage_manager
-                .is_running_stage(job_id, stage_id as u32)
-            {
-                debug!("stage {}/{} has already been submitted", job_id, stage_id);
-                return Ok(());
-            }
-            if self
-                .state
-                .stage_manager
-                .is_pending_stage(job_id, stage_id as u32)
-            {
-                debug!(
-                    "stage {}/{} has already been added to the pending list",
-                    job_id, stage_id
-                );
-                return Ok(());
-            }
-        }
-        if let Some(stage_plan) = self.state.get_stage_plan(job_id, stage_id) {
-            if let Some(incomplete_unresolved_shuffles) = self
-                .try_resolve_stage(job_id, stage_id, stage_plan.clone())
-                .await?
-            {
-                assert!(
-                    !incomplete_unresolved_shuffles.is_empty(),
-                    "there are no incomplete unresolved shuffles"
-                );
-                for incomplete_unresolved_shuffle in incomplete_unresolved_shuffles {
-                    self.submit_stage(job_id, incomplete_unresolved_shuffle.stage_id)
-                        .await?;
-                }
-                self.state
-                    .stage_manager
-                    .add_pending_stage(job_id, stage_id as u32);
-            } else {
-                self.state.stage_manager.add_running_stage(
-                    job_id,
-                    stage_id as u32,
-                    stage_plan.output_partitioning().partition_count() as u32,
-                );
-            }
-        } else {
-            return Err(BallistaError::General(format!(
-                "Fail to find stage plan for {}/{}",
-                job_id, stage_id
-            )));
-        }
         Ok(())
     }
-
-    /// Try to resolve a stage if all of the unresolved shuffles are completed.
-    /// Return the unresolved shuffles which are incomplete
-    async fn try_resolve_stage(
-        &self,
-        job_id: &str,
-        stage_id: usize,
-        stage_plan: Arc<dyn ExecutionPlan>,
-    ) -> Result<Option<Vec<UnresolvedShuffleExec>>> {
-        // Find all of the unresolved shuffles
-        let unresolved_shuffles = find_unresolved_shuffles(&stage_plan)?;
-
-        // If no dependent shuffles
-        if unresolved_shuffles.is_empty() {
-            return Ok(None);
-        }
-
-        // Find all of the incomplete unresolved shuffles
-        let (incomplete_unresolved_shuffles, unresolved_shuffles): (
-            Vec<UnresolvedShuffleExec>,
-            Vec<UnresolvedShuffleExec>,
-        ) = unresolved_shuffles.into_iter().partition(|s| {
-            !self
-                .state
-                .stage_manager
-                .is_completed_stage(job_id, s.stage_id as u32)
-        });
-
-        if !incomplete_unresolved_shuffles.is_empty() {
-            return Ok(Some(incomplete_unresolved_shuffles));
-        }
-
-        // All of the unresolved shuffles are completed, update the stage plan
-        {
-            let mut partition_locations: HashMap<
-                usize, // input stage id
-                HashMap<
-                    usize,                                                   // task id of this stage
-                    Vec<ballista_core::serde::scheduler::PartitionLocation>, // shuffle partitions
-                >,
-            > = HashMap::new();
-            for unresolved_shuffle in unresolved_shuffles.iter() {
-                let input_stage_id = unresolved_shuffle.stage_id;
-                let stage_shuffle_partition_locations = partition_locations
-                    .entry(input_stage_id)
-                    .or_insert_with(HashMap::new);
-                if let Some(input_stage_tasks) = self
-                    .state
-                    .stage_manager
-                    .get_stage_tasks(job_id, input_stage_id as u32)
-                {
-                    // each input partition can produce multiple output partitions
-                    for (shuffle_input_partition_id, task_status) in
-                        input_stage_tasks.iter().enumerate()
-                    {
-                        match &task_status.status {
-                            Some(task_status::Status::Completed(CompletedTask {
-                                executor_id,
-                                partitions,
-                            })) => {
-                                debug!(
-                                    "Task for unresolved shuffle input partition {} completed and produced these shuffle partitions:\n\t{}",
-                                    shuffle_input_partition_id,
-                                    partitions.iter().map(|p| format!("{}={}", p.partition_id, &p.path)).collect::<Vec<_>>().join("\n\t")
-                                );
-
-                                for shuffle_write_partition in partitions {
-                                    let temp = stage_shuffle_partition_locations
-                                        .entry(
-                                            shuffle_write_partition.partition_id as usize,
-                                        )
-                                        .or_insert(Vec::new());
-                                    let executor_meta = self
-                                        .state
-                                        .get_executor_metadata(executor_id)
-                                        .ok_or_else(|| {
-                                            BallistaError::General(format!(
-                                                "Fail to find executor metadata for {}",
-                                                &executor_id
-                                            ))
-                                        })?;
-                                    let partition_location =
-                                        ballista_core::serde::scheduler::PartitionLocation {
-                                            partition_id:
-                                            ballista_core::serde::scheduler::PartitionId {
-                                                job_id: job_id.to_owned(),
-                                                stage_id: unresolved_shuffle.stage_id,
-                                                partition_id: shuffle_write_partition
-                                                    .partition_id
-                                                    as usize,
-                                            },
-                                            executor_meta,
-                                            partition_stats: PartitionStats::new(
-                                                Some(shuffle_write_partition.num_rows),
-                                                Some(shuffle_write_partition.num_batches),
-                                                Some(shuffle_write_partition.num_bytes),
-                                            ),
-                                            path: shuffle_write_partition.path.clone(),
-                                        };
-                                    debug!(
-                                            "Scheduler storing stage {} output partition {} path: {}",
-                                            unresolved_shuffle.stage_id,
-                                            partition_location.partition_id.partition_id,
-                                            partition_location.path
-                                        );
-                                    temp.push(partition_location);
-                                }
-                            }
-                            _ => {
-                                debug!(
-                                    "Stage {} input partition {} has not completed yet",
-                                    unresolved_shuffle.stage_id,
-                                    shuffle_input_partition_id
-                                );
-                                // TODO task error handling
-                            }
-                        }
-                    }
-                } else {
-                    return Err(BallistaError::General(format!(
-                        "Fail to find completed stage for {}/{}",
-                        job_id, stage_id
-                    )));
-                }
-            }
-
-            let plan = remove_unresolved_shuffles(stage_plan, &partition_locations)?;
-            self.state.save_stage_plan(job_id, stage_id, plan).await?;
-        }
-
-        Ok(None)
-    }
 }
 
 #[async_trait]
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
     EventAction<QueryStageSchedulerEvent> for QueryStageScheduler<T, U>
 {
-    // TODO
-    fn on_start(&self) {}
+    fn on_start(&self) {
+        info!("Starting QueryStageScheduler");
+    }
 
-    // TODO
-    fn on_stop(&self) {}
+    fn on_stop(&self) {
+        info!("Stopping QueryStageScheduler")
+    }
 
+    /// Handles one job lifecycle event and optionally returns a follow-up event.
+    ///
+    /// - `JobQueued`: plans the job via `submit_job`. Returns `JobSubmitted` on success, or
+    ///   `JobFailed` carrying the planning error.
+    /// - `JobSubmitted`: when an event sender is configured, reserves executor slots for the
+    ///   job's available tasks and posts them as a `SchedulerServerEvent::Offer`.
+    /// - `JobFinished`: calls `task_manager.complete_job`.
+    /// - `JobFailed`: calls `task_manager.fail_job` with the failure message.
     async fn on_receive(
         &self,
         event: QueryStageSchedulerEvent,
     ) -> Result<Option<QueryStageSchedulerEvent>> {
         match event {
-            QueryStageSchedulerEvent::JobSubmitted(job_id, plan) => {
+            QueryStageSchedulerEvent::JobQueued {
+                job_id,
+                session_id,
+                session_ctx,
+                plan,
+            } => {
+                info!("Job {} queued", job_id);
+                // Planning errors are converted into a `JobFailed` follow-up event instead of
+                // being propagated as an `Err` from this handler.
+                return if let Err(e) = self
+                    .submit_job(job_id.clone(), session_id, session_ctx, &plan)
+                    .await
+                {
+                    let msg = format!("Error planning job {}: {:?}", job_id, e);
+                    error!("{}", msg);
+                    Ok(Some(QueryStageSchedulerEvent::JobFailed(job_id, msg)))
+                } else {
+                    Ok(Some(QueryStageSchedulerEvent::JobSubmitted(job_id)))
+                };
+            }
+            QueryStageSchedulerEvent::JobSubmitted(job_id) => {
                 info!("Job {} submitted", job_id);
-                match self.generate_stages(&job_id, plan).await {
-                    Err(e) => {
-                        let msg = format!("Job {} failed due to {}", job_id, e);
-                        warn!("{}", msg);
+                // NOTE(review): without an event sender, no slots are reserved or offered for
+                // this job here.
+                if let Some(sender) = &self.event_sender {
+                    let available_tasks = self
+                        .state
+                        .task_manager
+                        .get_available_task_count(&job_id)
+                        .await?;
+
+                    // Request one executor slot per available task and tag each reservation
+                    // with this job (the executor manager may presumably grant fewer slots).
+                    let reservations: Vec<ExecutorReservation> = self
+                        .state
+                        .executor_manager
+                        .reserve_slots(available_tasks as u32)
+                        .await?
+                        .into_iter()
+                        .map(|res| res.assign(job_id.clone()))
+                        .collect();
+
+                    debug!(
+                        "Reserved {} task slots for submitted job {}",
+                        reservations.len(),
+                        job_id
+                    );
+
+                    if let Err(e) = sender
+                        .post_event(SchedulerServerEvent::Offer(reservations.clone()))
+                        .await
+                    {
+                        error!("Error posting offer: {:?}", e);
+                        // The offer was never delivered, so hand the reserved slots back to
+                        // the executor manager.
                         self.state
-                            .save_job_metadata(
-                                &job_id,
-                                &JobStatus {
-                                    status: Some(job_status::Status::Failed(FailedJob {
-                                        error: msg.to_string(),
-                                    })),
-                                },
-                            )
+                            .executor_manager
+                            .cancel_reservations(reservations)
                             .await?;
                     }
-                    Ok(()) => {
-                        if let Err(e) = self
-                            .state
-                            .save_job_metadata(
-                                &job_id,
-                                &JobStatus {
-                                    status: Some(job_status::Status::Running(
-                                        RunningJob {},
-                                    )),
-                                },
-                            )
-                            .await
-                        {
-                            warn!(
-                                "Could not update job {} status to running: {}",
-                                job_id, e
-                            );
-                        }
-                    }
                 }
             }
-            QueryStageSchedulerEvent::StageFinished(job_id, stage_id) => {
-                info!("Job stage {}/{} finished", job_id, stage_id);
-                self.submit_pending_stages(&job_id, stage_id as usize)
-                    .await?;
-            }
             QueryStageSchedulerEvent::JobFinished(job_id) => {
-                info!("Job {} finished", job_id);
-                let tasks_for_complete_final_stage = self
-                    .state
-                    .stage_manager
-                    .get_tasks_for_complete_final_stage(&job_id)?;
-                let executors: HashMap<String, ExecutorMetadata> = self
-                    .state
-                    .get_executors_metadata()
-                    .await?
-                    .into_iter()
-                    .map(|(meta, _)| (meta.id.to_string(), meta))
-                    .collect();
-                let job_status = get_job_status_from_tasks(
-                    &tasks_for_complete_final_stage,
-                    &executors,
-                );
-                self.state.save_job_metadata(&job_id, &job_status).await?;
+                info!("Job {} complete", job_id);
+                self.state.task_manager.complete_job(&job_id).await?;
             }
-            QueryStageSchedulerEvent::JobFailed(job_id, stage_id, fail_message) => {
-                error!(
-                    "Job stage {}/{} failed due to {}",
-                    &job_id, stage_id, fail_message
-                );
-                let job_status = JobStatus {
-                    status: Some(job_status::Status::Failed(FailedJob {
-                        error: fail_message,
-                    })),
-                };
-                self.state.save_job_metadata(&job_id, &job_status).await?;
+            QueryStageSchedulerEvent::JobFailed(job_id, fail_message) => {
+                error!("Job {} failed: {}", job_id, fail_message);
+                self.state
+                    .task_manager
+                    .fail_job(&job_id, fail_message)
+                    .await?;
             }
         }
 
-        if let Some(event_sender) = self.event_sender.as_ref() {
-            // The stage event must triggerred with releasing some resources. Therefore, revive offers for the scheduler
-            event_sender
-                .post_event(SchedulerServerEvent::ReviveOffers(1))
-                .await?;
-        };
         Ok(None)
     }
 
-    // TODO
-    fn on_error(&self, _error: BallistaError) {}
-}
-
-fn get_job_status_from_tasks(
-    tasks: &[Arc<TaskStatus>],
-    executors: &HashMap<String, ExecutorMetadata>,
-) -> JobStatus {
-    let mut job_status = tasks
-        .iter()
-        .map(|task| match &task.status {
-            Some(task_status::Status::Completed(CompletedTask {
-                executor_id,
-                partitions,
-            })) => Ok((task, executor_id, partitions)),
-            _ => Err(BallistaError::General("Task not completed".to_string())),
-        })
-        .collect::<Result<Vec<_>>>()
-        .ok()
-        .map(|info| {
-            let mut partition_location = vec![];
-            for (status, executor_id, partitions) in info {
-                let input_partition_id = status.task_id.as_ref().unwrap(); // TODO unwrap
-                let executor_meta = executors.get(executor_id).map(|e| e.clone().into());
-                for shuffle_write_partition in partitions {
-                    let shuffle_input_partition_id = Some(protobuf::PartitionId {
-                        job_id: input_partition_id.job_id.clone(),
-                        stage_id: input_partition_id.stage_id,
-                        partition_id: input_partition_id.partition_id,
-                    });
-                    partition_location.push(protobuf::PartitionLocation {
-                        partition_id: shuffle_input_partition_id.clone(),
-                        executor_meta: executor_meta.clone(),
-                        partition_stats: Some(protobuf::PartitionStats {
-                            num_batches: shuffle_write_partition.num_batches as i64,
-                            num_rows: shuffle_write_partition.num_rows as i64,
-                            num_bytes: shuffle_write_partition.num_bytes as i64,
-                            column_stats: vec![],
-                        }),
-                        path: shuffle_write_partition.path.clone(),
-                    });
-                }
-            }
-            job_status::Status::Completed(CompletedJob { partition_location })
-        });
-
-    if job_status.is_none() {
-        // Update other statuses
-        for task in tasks.iter() {
-            if let Some(task_status::Status::Failed(FailedTask { error })) = &task.status
-            {
-                let error = error.clone();
-                job_status = Some(job_status::Status::Failed(FailedJob { error }));
-                break;
-            }
-        }
-    }
-
-    JobStatus {
-        status: Some(job_status.unwrap_or(job_status::Status::Running(RunningJob {}))),
+    /// Logs the error reported to this handler; no other action is taken.
+    fn on_error(&self, error: BallistaError) {
+        error!("Error received by QueryStageScheduler: {:?}", error);
     }
 }
diff --git a/ballista/rust/scheduler/src/state/backend/etcd.rs b/ballista/rust/scheduler/src/state/backend/etcd.rs
index fa85e54d..4b24b7aa 100644
--- a/ballista/rust/scheduler/src/state/backend/etcd.rs
+++ b/ballista/rust/scheduler/src/state/backend/etcd.rs
@@ -17,31 +17,39 @@
 
 //! Etcd config backend.
 
+use std::collections::HashSet;
+
 use std::task::Poll;
 
 use ballista_core::error::{ballista_error, Result};
+use std::time::Instant;
 
-use etcd_client::{GetOptions, LockResponse, WatchOptions, WatchStream, Watcher};
+use etcd_client::{
+    Compare, CompareOp, GetOptions, LockOptions, LockResponse, Txn, TxnOp, WatchOptions,
+    WatchStream, Watcher,
+};
 use futures::{Stream, StreamExt};
-use log::warn;
+use log::{debug, error, warn};
 
-use crate::state::backend::{Lock, StateBackendClient, Watch, WatchEvent};
+use crate::state::backend::{Keyspace, Lock, StateBackendClient, Watch, WatchEvent};
 
 /// A [`StateBackendClient`] implementation that uses etcd to save cluster configuration.
 #[derive(Clone)]
 pub struct EtcdClient {
+    /// Namespace prepended to every key and lock name this client uses (`/{namespace}/...`).
+    /// Presumably lets several deployments share one etcd cluster — confirm intended usage.
+    namespace: String,
+    /// Underlying etcd connection; each request clones it into a mutable local before use.
     etcd: etcd_client::Client,
 }
 
 impl EtcdClient {
-    pub fn new(etcd: etcd_client::Client) -> Self {
-        Self { etcd }
+    /// Creates a client whose keys (`/{namespace}/{keyspace:?}/...`) and lock names
+    /// (`/{namespace}/mutex/...`) are all scoped under `namespace`.
+    pub fn new(namespace: String, etcd: etcd_client::Client) -> Self {
+        Self { namespace, etcd }
     }
 }
 
 #[tonic::async_trait]
 impl StateBackendClient for EtcdClient {
-    async fn get(&self, key: &str) -> Result<Vec<u8>> {
+    async fn get(&self, keyspace: Keyspace, key: &str) -> Result<Vec<u8>> {
+        let key = format!("/{}/{:?}/{}", self.namespace, keyspace, key);
+
         Ok(self
             .etcd
             .clone()
@@ -54,7 +62,13 @@ impl StateBackendClient for EtcdClient {
             .unwrap_or_default())
     }
 
-    async fn get_from_prefix(&self, prefix: &str) -> Result<Vec<(String, Vec<u8>)>> {
+    async fn get_from_prefix(
+        &self,
+        keyspace: Keyspace,
+        prefix: &str,
+    ) -> Result<Vec<(String, Vec<u8>)>> {
+        let prefix = format!("/{}/{:?}/{}", self.namespace, keyspace, prefix);
+
         Ok(self
             .etcd
             .clone()
@@ -67,9 +81,59 @@ impl StateBackendClient for EtcdClient {
             .collect())
     }
 
-    async fn put(&self, key: String, value: Vec<u8>) -> Result<()> {
+    /// Retrieves every key/value pair stored under `keyspace`, returning at most `limit`
+    /// pairs when a limit is given.
+    ///
+    /// Returned keys are the full etcd keys, i.e. they still carry the
+    /// `/{namespace}/{keyspace:?}/` prefix (unlike `scan_keys`, which strips it).
+    ///
+    /// # Errors
+    /// Fails if the etcd request fails or if a stored key is not valid UTF-8. A bad key used
+    /// to panic the scheduler.
+    async fn scan(
+        &self,
+        keyspace: Keyspace,
+        limit: Option<usize>,
+    ) -> Result<Vec<(String, Vec<u8>)>> {
+        let prefix = format!("/{}/{:?}/", self.namespace, keyspace);
+
+        let options = if let Some(limit) = limit {
+            GetOptions::new().with_prefix().with_limit(limit as i64)
+        } else {
+            GetOptions::new().with_prefix()
+        };
+
+        self.etcd
+            .clone()
+            .get(prefix, Some(options))
+            .await
+            .map_err(|e| ballista_error(&format!("etcd error {:?}", e)))?
+            .kvs()
+            .iter()
+            .map(|kv| {
+                kv.key_str()
+                    .map(|key| (key.to_owned(), kv.value().to_owned()))
+                    .map_err(|e| {
+                        ballista_error(&format!("etcd key is not valid UTF-8: {:?}", e))
+                    })
+            })
+            .collect()
+    }
+
+    /// Retrieves every key stored under `keyspace`, without values.
+    ///
+    /// The `/{namespace}/{keyspace:?}/` prefix added by this client is stripped, so the
+    /// returned keys match the `key` arguments given to `put`.
+    ///
+    /// # Errors
+    /// Fails if the etcd request fails, if a key is not valid UTF-8, or if a returned key
+    /// unexpectedly lacks the scanned prefix. These cases used to panic.
+    async fn scan_keys(&self, keyspace: Keyspace) -> Result<HashSet<String>> {
+        let prefix = format!("/{}/{:?}/", self.namespace, keyspace);
+
+        let options = GetOptions::new().with_prefix().with_keys_only();
+
+        self.etcd
+            .clone()
+            .get(prefix.clone(), Some(options))
+            .await
+            .map_err(|e| ballista_error(&format!("etcd error {:?}", e)))?
+            .kvs()
+            .iter()
+            .map(|kv| -> Result<String> {
+                let key = kv.key_str().map_err(|e| {
+                    ballista_error(&format!("etcd key is not valid UTF-8: {:?}", e))
+                })?;
+                // A prefix query should only return keys under `prefix`; report a violation
+                // as an error rather than panicking the scheduler.
+                key.strip_prefix(prefix.as_str())
+                    .map(str::to_owned)
+                    .ok_or_else(|| {
+                        ballista_error(&format!(
+                            "etcd key {} does not start with expected prefix {}",
+                            key, prefix
+                        ))
+                    })
+            })
+            .collect()
+    }
+
+    async fn put(&self, keyspace: Keyspace, key: String, value: Vec<u8>) -> Result<()> {
+        let key = format!("/{}/{:?}/{}", self.namespace, keyspace, key);
+
         let mut etcd = self.etcd.clone();
-        etcd.put(key.clone(), value.clone(), None)
+        etcd.put(key, value.clone(), None)
             .await
             .map_err(|e| {
                 warn!("etcd put failed: {}", e);
@@ -78,20 +142,100 @@ impl StateBackendClient for EtcdClient {
             .map(|_| ())
     }
 
-    async fn lock(&self) -> Result<Box<dyn Lock>> {
+    /// Writes every `(keyspace, key, value)` triple in `ops` inside a single etcd
+    /// transaction, so either all of the puts are applied or none are.
+    ///
+    /// NOTE(review): etcd caps the number of operations per transaction (`--max-txn-ops`,
+    /// 128 by default), so very large batches will be rejected — confirm expected batch sizes.
+    async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec<u8>)>) -> Result<()> {
+        let mut puts = Vec::with_capacity(ops.len());
+        for (keyspace, key, value) in ops {
+            let namespaced_key = format!("/{}/{:?}/{}", self.namespace, keyspace, key);
+            puts.push(TxnOp::put(namespaced_key, value, None));
+        }
+
+        let mut etcd = self.etcd.clone();
+        match etcd.txn(Txn::new().and_then(puts)).await {
+            Ok(_) => Ok(()),
+            Err(e) => {
+                error!("etcd put failed: {}", e);
+                Err(ballista_error("etcd transaction put failed"))
+            }
+        }
+    }
+
+    /// Moves the value stored at `key` from `from_keyspace` to `to_keyspace`.
+    ///
+    /// The delete and put run in one etcd transaction guarded by a compare on the source key's
+    /// `mod_revision`. The move therefore only happens if the value read here is still
+    /// current, which makes the operation atomic as the trait documents. If the source key
+    /// does not exist, a warning is logged and `Ok(())` is returned.
+    ///
+    /// # Errors
+    /// Fails if etcd cannot be reached, or if the source key was modified or deleted
+    /// concurrently (the transaction's compare did not hold).
+    async fn mv(
+        &self,
+        from_keyspace: Keyspace,
+        to_keyspace: Keyspace,
+        key: &str,
+    ) -> Result<()> {
         let mut etcd = self.etcd.clone();
-        // TODO: make this a namespaced-lock
+        let from_key = format!("/{}/{:?}/{}", self.namespace, from_keyspace, key);
+        let to_key = format!("/{}/{:?}/{}", self.namespace, to_keyspace, key);
+
+        // Remember the revision we read so the transaction below can detect a concurrent
+        // update of the source key and refuse to apply a stale move.
+        let current = etcd
+            .get(from_key.as_str(), None)
+            .await
+            .map_err(|e| ballista_error(&format!("etcd error {:?}", e)))?
+            .kvs()
+            .get(0)
+            .map(|kv| (kv.value().to_owned(), kv.mod_revision()));
+
+        if let Some((value, mod_revision)) = current {
+            let txn = Txn::new()
+                .when(vec![Compare::mod_revision(
+                    from_key.as_str(),
+                    CompareOp::Equal,
+                    mod_revision,
+                )])
+                .and_then(vec![
+                    TxnOp::delete(from_key.as_str(), None),
+                    TxnOp::put(to_key.as_str(), value, None),
+                ]);
+            let response = etcd.txn(txn).await.map_err(|e| {
+                error!("etcd put failed: {}", e);
+                ballista_error("etcd move failed")
+            })?;
+            if !response.succeeded() {
+                return Err(ballista_error(&format!(
+                    "etcd move of {} failed: key was modified concurrently",
+                    from_key
+                )));
+            }
+        } else {
+            warn!("Cannot move value at {}, does not exist", from_key);
+        }
+
+        Ok(())
+    }
+
+    /// Acquires the etcd mutex named `/{namespace}/mutex/{keyspace:?}/{key}` and returns a
+    /// guard wrapping it. Acquisition time is logged at `debug` level.
+    ///
+    /// The mutex is bound to a freshly granted 30-second lease, so it is released if this
+    /// scheduler dies while holding it. NOTE(review): this function issues no keep-alive for
+    /// that lease, so a lock held longer than 30s may expire while the holder still believes
+    /// it owns it. Confirm whether `EtcdLockGuard` or the callers renew the lease.
+    async fn lock(&self, keyspace: Keyspace, key: &str) -> Result<Box<dyn Lock>> {
+        let start = Instant::now();
+        let mut etcd = self.etcd.clone();
+
+        let lock_id = format!("/{}/mutex/{:?}/{}", self.namespace, keyspace, key);
+
+        // Create a lease which expires after 30 seconds. We then associate this lease with the lock
+        // acquired below. This protects against a scheduler dying unexpectedly while holding locks
+        // on shared resources. In that case, those locks would expire once the lease expires.
+        // TODO This is not great to do for every lock. We should have a single lease per scheduler instance
+        let lease_id = etcd
+            .lease_client()
+            .grant(30, None)
+            .await
+            .map_err(|e| {
+                warn!("etcd lease failed: {}", e);
+                ballista_error("etcd lease failed")
+            })?
+            .id();
+
+        let lock_options = LockOptions::new().with_lease(lease_id);
+
         let lock = etcd
-            .lock("/ballista_global_lock", None)
+            .lock(lock_id.as_str(), Some(lock_options))
             .await
             .map_err(|e| {
                 warn!("etcd lock failed: {}", e);
                 ballista_error("etcd lock failed")
             })?;
+
+        let elapsed = start.elapsed();
+        debug!("Acquired lock {} in {:?}", lock_id, elapsed);
         Ok(Box::new(EtcdLockGuard { etcd, lock }))
     }
 
-    async fn watch(&self, prefix: String) -> Result<Box<dyn Watch>> {
+    async fn watch(&self, keyspace: Keyspace, prefix: String) -> Result<Box<dyn Watch>> {
+        let prefix = format!("/{}/{:?}/{}", self.namespace, keyspace, prefix);
+
         let mut etcd = self.etcd.clone();
         let options = WatchOptions::new().with_prefix();
         let (watcher, stream) = etcd.watch(prefix, Some(options)).await.map_err(|e| {
@@ -104,6 +248,19 @@ impl StateBackendClient for EtcdClient {
             buffered_events: Vec::new(),
         }))
     }
+
+    /// Removes `key` from `keyspace` (scoped under this client's namespace). The etcd
+    /// delete response is discarded, so whether the key existed is not reported.
+    async fn delete(&self, keyspace: Keyspace, key: &str) -> Result<()> {
+        let namespaced_key = format!("/{}/{:?}/{}", self.namespace, keyspace, key);
+        self.etcd
+            .clone()
+            .delete(namespaced_key, None)
+            .await
+            .map(|_| ())
+            .map_err(|e| {
+                warn!("etcd delete failed: {:?}", e);
+                ballista_error("etcd delete failed")
+            })
+    }
 }
 
 struct EtcdWatch {
diff --git a/ballista/rust/scheduler/src/state/backend/mod.rs b/ballista/rust/scheduler/src/state/backend/mod.rs
index 15f244b6..4a6334ab 100644
--- a/ballista/rust/scheduler/src/state/backend/mod.rs
+++ b/ballista/rust/scheduler/src/state/backend/mod.rs
@@ -18,6 +18,7 @@
 use ballista_core::error::Result;
 use clap::ArgEnum;
 use futures::Stream;
+use std::collections::HashSet;
 use std::fmt;
 use tokio::sync::OwnedMutexGuard;
 
@@ -48,24 +49,67 @@ impl parse_arg::ParseArgFromStr for StateBackend {
     }
 }
 
+/// Logical partition of the scheduler state held by a [`StateBackendClient`].
+///
+/// Backends embed the variant's `Debug` name in storage keys (etcd uses
+/// `/{namespace}/{keyspace:?}/{key}`, the standalone backend `/{keyspace:?}/{key}`), so
+/// renaming a variant changes where its data is stored.
+///
+/// `Clone`/`Copy` are derived because this fieldless enum is passed by value to every backend
+/// method; callers can reuse a keyspace value without reconstructing it.
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
+pub enum Keyspace {
+    Executors,
+    ActiveJobs,
+    CompletedJobs,
+    QueuedJobs,
+    FailedJobs,
+    Slots,
+    Sessions,
+    Heartbeats,
+}
+
 /// A trait that contains the necessary methods to save and retrieve the state and configuration of a cluster.
 #[tonic::async_trait]
 pub trait StateBackendClient: Send + Sync {
-    /// Retrieve the data associated with a specific key.
+    /// Retrieve the data associated with a specific key in a given keyspace.
     ///
     /// An empty vec is returned if the key does not exist.
-    async fn get(&self, key: &str) -> Result<Vec<u8>>;
-
-    /// Retrieve all data associated with a specific key.
-    async fn get_from_prefix(&self, prefix: &str) -> Result<Vec<(String, Vec<u8>)>>;
+    async fn get(&self, keyspace: Keyspace, key: &str) -> Result<Vec<u8>>;
+
+    /// Retrieve all key/value pairs in given keyspace matching a given key prefix.
+    async fn get_from_prefix(
+        &self,
+        keyspace: Keyspace,
+        prefix: &str,
+    ) -> Result<Vec<(String, Vec<u8>)>>;
+
+    /// Retrieve all key/value pairs in a given keyspace. If a limit is specified, will return at
+    /// most `limit` key-value pairs.
+    ///
+    /// NOTE(review): the etcd backend returns full storage keys here (still prefixed with
+    /// `/{namespace}/{keyspace:?}/`), whereas `scan_keys` strips its prefix. Confirm that
+    /// callers expect this asymmetry.
+    async fn scan(
+        &self,
+        keyspace: Keyspace,
+        limit: Option<usize>,
+    ) -> Result<Vec<(String, Vec<u8>)>>;
+
+    /// Retrieve all keys from a given keyspace (without their values). The implementations
+    /// should handle stripping any prefixes it may add.
+    async fn scan_keys(&self, keyspace: Keyspace) -> Result<HashSet<String>>;
 
     /// Saves the value into the provided key, overriding any previous data that might have been associated to that key.
-    async fn put(&self, key: String, value: Vec<u8>) -> Result<()>;
+    ///
+    /// `key` is interpreted relative to `keyspace`.
+    async fn put(&self, keyspace: Keyspace, key: String, value: Vec<u8>) -> Result<()>;
 
-    async fn lock(&self) -> Result<Box<dyn Lock>>;
+    /// Save multiple values in a single transaction. Either all values should be saved, or all should fail
+    async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec<u8>)>) -> Result<()>;
+
+    /// Atomically move the given key from one keyspace to another
+    ///
+    /// If the source key does not exist, the etcd backend logs a warning and returns `Ok(())`.
+    async fn mv(
+        &self,
+        from_keyspace: Keyspace,
+        to_keyspace: Keyspace,
+        key: &str,
+    ) -> Result<()>;
+
+    /// Acquire mutex with specified ID.
+    ///
+    /// The mutex is identified by the `(keyspace, key)` pair. For example, the etcd backend
+    /// names it `/{namespace}/mutex/{keyspace:?}/{key}`, so locks on different keys are
+    /// independent.
+    async fn lock(&self, keyspace: Keyspace, key: &str) -> Result<Box<dyn Lock>>;
 
     /// Watch all events that happen on a specific prefix.
-    async fn watch(&self, prefix: String) -> Result<Box<dyn Watch>>;
+    ///
+    /// `prefix` is interpreted relative to `keyspace`.
+    async fn watch(&self, keyspace: Keyspace, prefix: String) -> Result<Box<dyn Watch>>;
+
+    /// Permanently delete a key from state
+    async fn delete(&self, keyspace: Keyspace, key: &str) -> Result<()>;
 }
 
 /// A Watch is a cancelable stream of put or delete events in the [StateBackendClient]
diff --git a/ballista/rust/scheduler/src/state/backend/standalone.rs b/ballista/rust/scheduler/src/state/backend/standalone.rs
index 5bb4e384..4e5dc063 100644
--- a/ballista/rust/scheduler/src/state/backend/standalone.rs
+++ b/ballista/rust/scheduler/src/state/backend/standalone.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::collections::{HashMap, HashSet};
 use std::{sync::Arc, task::Poll};
 
 use ballista_core::error::{ballista_error, BallistaError, Result};
@@ -24,13 +25,13 @@ use log::warn;
 use sled_package as sled;
 use tokio::sync::Mutex;
 
-use crate::state::backend::{Lock, StateBackendClient, Watch, WatchEvent};
+use crate::state::backend::{Keyspace, Lock, StateBackendClient, Watch, WatchEvent};
 
 /// A [`StateBackendClient`] implementation that uses file-based storage to save cluster configuration.
 #[derive(Clone)]
 pub struct StandaloneClient {
     db: sled::Db,
-    lock: Arc<Mutex<()>>,
+    locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
 }
 
 impl StandaloneClient {
@@ -38,7 +39,7 @@ impl StandaloneClient {
     pub fn try_new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
         Ok(Self {
             db: sled::open(path).map_err(sled_to_ballista_error)?,
-            lock: Arc::new(Mutex::new(())),
+            locks: Arc::new(Mutex::new(HashMap::new())),
         })
     }
 
@@ -49,7 +50,7 @@ impl StandaloneClient {
                 .temporary(true)
                 .open()
                 .map_err(sled_to_ballista_error)?,
-            lock: Arc::new(Mutex::new(())),
+            locks: Arc::new(Mutex::new(HashMap::new())),
         })
     }
 }
@@ -63,7 +64,8 @@ fn sled_to_ballista_error(e: sled::Error) -> BallistaError {
 
 #[tonic::async_trait]
 impl StateBackendClient for StandaloneClient {
-    async fn get(&self, key: &str) -> Result<Vec<u8>> {
+    async fn get(&self, keyspace: Keyspace, key: &str) -> Result<Vec<u8>> {
+        let key = format!("/{:?}/{}", keyspace, key);
         Ok(self
             .db
             .get(key)
@@ -72,7 +74,12 @@ impl StateBackendClient for StandaloneClient {
             .unwrap_or_default())
     }
 
-    async fn get_from_prefix(&self, prefix: &str) -> Result<Vec<(String, Vec<u8>)>> {
+    async fn get_from_prefix(
+        &self,
+        keyspace: Keyspace,
+        prefix: &str,
+    ) -> Result<Vec<(String, Vec<u8>)>> {
+        let prefix = format!("/{:?}/{}", keyspace, prefix);
         Ok(self
             .db
             .scan_prefix(prefix)
@@ -88,7 +95,64 @@ impl StateBackendClient for StandaloneClient {
             .map_err(|e| ballista_error(&format!("sled error {:?}", e)))?)
     }
 
-    async fn put(&self, key: String, value: Vec<u8>) -> Result<()> {
+    /// Scan `keyspace`, returning at most `limit` entries (all entries when `limit` is `None`).
+    ///
+    /// The returned keys are the full stored keys, including the `/{keyspace:?}/` prefix.
+    /// Returns an error if sled fails or a stored key is not valid UTF-8.
+    async fn scan(
+        &self,
+        keyspace: Keyspace,
+        limit: Option<usize>,
+    ) -> Result<Vec<(String, Vec<u8>)>> {
+        let prefix = format!("/{:?}/", keyspace);
+        // `take(usize::MAX)` never truncates, so a single pipeline serves both the
+        // limited and the unlimited case.
+        self.db
+            .scan_prefix(prefix)
+            .take(limit.unwrap_or(usize::MAX))
+            .map(|entry| -> Result<(String, Vec<u8>)> {
+                let (key, value) =
+                    entry.map_err(|e| ballista_error(&format!("sled error {:?}", e)))?;
+                // Surface a non UTF-8 key as an error instead of panicking the scheduler.
+                let key = std::str::from_utf8(&key).map_err(|e| {
+                    ballista_error(&format!(
+                        "non UTF-8 key in keyspace {:?}: {:?}",
+                        keyspace, e
+                    ))
+                })?;
+                Ok((key.to_owned(), value.to_vec()))
+            })
+            .collect()
+    }
+
+    /// List every key stored under `keyspace`, with the `/{keyspace:?}/` prefix removed.
+    async fn scan_keys(&self, keyspace: Keyspace) -> Result<HashSet<String>> {
+        let prefix = format!("/{:?}/", keyspace);
+        let mut keys = HashSet::new();
+
+        for entry in self.db.scan_prefix(prefix.as_str()) {
+            let (raw_key, _value) =
+                entry.map_err(|e| ballista_error(&format!("sled error {:?}", e)))?;
+            // `scan_prefix` only yields keys that start with `prefix`, so stripping succeeds.
+            let key = std::str::from_utf8(&raw_key)
+                .unwrap()
+                .strip_prefix(prefix.as_str())
+                .unwrap();
+            keys.insert(key.to_owned());
+        }
+
+        Ok(keys)
+    }
+
+    async fn put(&self, keyspace: Keyspace, key: String, value: Vec<u8>) -> Result<()> {
+        let key = format!("/{:?}/{}", keyspace, key);
         self.db
             .insert(key, value)
             .map_err(|e| {
@@ -98,15 +162,80 @@ impl StateBackendClient for StandaloneClient {
             .map(|_| ())
     }
 
-    async fn lock(&self) -> Result<Box<dyn Lock>> {
-        Ok(Box::new(self.lock.clone().lock_owned().await))
+    /// Write all `ops` through one sled batch so they are applied together.
+    async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec<u8>)>) -> Result<()> {
+        let batch = ops.into_iter().fold(
+            sled::Batch::default(),
+            |mut staged, (keyspace, key, value)| {
+                staged.insert(format!("/{:?}/{}", keyspace, key).as_str(), value);
+                staged
+            },
+        );
+
+        self.db.apply_batch(batch).map_err(|err| {
+            warn!("sled transaction insert failed: {}", err);
+            ballista_error("sled insert failed")
+        })
     }
 
-    async fn watch(&self, prefix: String) -> Result<Box<dyn Watch>> {
+    /// Move the value stored at `key` from `from_keyspace` to `to_keyspace`.
+    ///
+    /// If nothing is stored at the source key, a warning is logged and `Ok(())` is returned.
+    /// Returns an error if the sled transaction fails.
+    async fn mv(
+        &self,
+        from_keyspace: Keyspace,
+        to_keyspace: Keyspace,
+        key: &str,
+    ) -> Result<()> {
+        let from_key = format!("/{:?}/{}", from_keyspace, key);
+        let to_key = format!("/{:?}/{}", to_keyspace, key);
+
+        // Read, remove and re-insert inside one sled transaction so the move is atomic.
+        // Reading with `get` and then applying a separate batch would let a concurrent write
+        // to `from_key` land in between and be lost. sled may re-run the closure on conflict,
+        // so it only touches the transactional tree.
+        let moved: std::result::Result<bool, sled::transaction::TransactionError<()>> =
+            self.db.transaction(|tx| match tx.get(from_key.as_str())? {
+                Some(value) => {
+                    tx.remove(from_key.as_str())?;
+                    tx.insert(to_key.as_str(), value)?;
+                    Ok(true)
+                }
+                None => Ok(false),
+            });
+
+        let moved = moved.map_err(|e| {
+            warn!("sled transaction move failed: {:?}", e);
+            ballista_error("sled insert failed")
+        })?;
+
+        if !moved {
+            // TODO should this return an error?
+            warn!("Cannot move value at {}, does not exist", from_key);
+        }
+        Ok(())
+    }
+
+    /// Acquire the in-process mutex for `key` in `keyspace`, creating it on first use.
+    async fn lock(&self, keyspace: Keyspace, key: &str) -> Result<Box<dyn Lock>> {
+        let lock_key = format!("/{:?}/{}", keyspace, key);
+        // Fetch (or lazily create) the per-key mutex under the map guard, but release the
+        // guard BEFORE awaiting the per-key mutex. Awaiting while holding the map guard would
+        // block every other `lock` call behind a contended key, and deadlock if the current
+        // holder of that key tries to acquire any other key.
+        let key_lock = {
+            let mut locks = self.locks.lock().await;
+            locks
+                .entry(lock_key)
+                .or_insert_with(|| Arc::new(Mutex::new(())))
+                .clone()
+        };
+        Ok(Box::new(key_lock.lock_owned().await))
+    }
+
+    /// Subscribe to put/delete events for keys under `prefix` within `keyspace`.
+    ///
+    /// NOTE(review): the emitted events report the full stored key, including the
+    /// `/{keyspace:?}/` prefix that this client adds; callers presumably need to strip it.
+    async fn watch(&self, keyspace: Keyspace, prefix: String) -> Result<Box<dyn Watch>> {
+        let prefix = format!("/{:?}/{}", keyspace, prefix);
+
         Ok(Box::new(SledWatch {
             subscriber: self.db.watch_prefix(prefix),
         }))
     }
+
+    async fn delete(&self, keyspace: Keyspace, key: &str) -> Result<()> {
+        let key = format!("/{:?}/{}", keyspace, key);
+        self.db.remove(key).map_err(|e| {
+            warn!("sled delete failed: {:?}", e);
+            ballista_error("sled delete failed")
+        })?;
+        Ok(())
+    }
 }
 
 struct SledWatch {
@@ -150,6 +279,7 @@ impl Stream for SledWatch {
 mod tests {
     use super::{StandaloneClient, StateBackendClient, Watch, WatchEvent};
 
+    use crate::state::backend::Keyspace;
     use futures::StreamExt;
     use std::result::Result;
 
@@ -162,8 +292,10 @@ mod tests {
         let client = create_instance()?;
         let key = "key";
         let value = "value".as_bytes();
-        client.put(key.to_owned(), value.to_vec()).await?;
-        assert_eq!(client.get(key).await?, value);
+        client
+            .put(Keyspace::Slots, key.to_owned(), value.to_vec())
+            .await?;
+        assert_eq!(client.get(Keyspace::Slots, key).await?, value);
         Ok(())
     }
 
@@ -172,7 +304,7 @@ mod tests {
         let client = create_instance()?;
         let key = "key";
         let empty: &[u8] = &[];
-        assert_eq!(client.get(key).await?, empty);
+        assert_eq!(client.get(Keyspace::Slots, key).await?, empty);
         Ok(())
     }
 
@@ -181,13 +313,17 @@ mod tests {
         let client = create_instance()?;
         let key = "key";
         let value = "value".as_bytes();
-        client.put(format!("{}/1", key), value.to_vec()).await?;
-        client.put(format!("{}/2", key), value.to_vec()).await?;
+        client
+            .put(Keyspace::Slots, format!("{}/1", key), value.to_vec())
+            .await?;
+        client
+            .put(Keyspace::Slots, format!("{}/2", key), value.to_vec())
+            .await?;
         assert_eq!(
-            client.get_from_prefix(key).await?,
+            client.get_from_prefix(Keyspace::Slots, key).await?,
             vec![
-                ("key/1".to_owned(), value.to_vec()),
-                ("key/2".to_owned(), value.to_vec())
+                ("/Slots/key/1".to_owned(), value.to_vec()),
+                ("/Slots/key/2".to_owned(), value.to_vec())
             ]
         );
         Ok(())
@@ -198,17 +334,28 @@ mod tests {
         let client = create_instance()?;
         let key = "key";
         let value = "value".as_bytes();
-        let mut watch: Box<dyn Watch> = client.watch(key.to_owned()).await?;
-        client.put(key.to_owned(), value.to_vec()).await?;
+        let mut watch: Box<dyn Watch> =
+            client.watch(Keyspace::Slots, key.to_owned()).await?;
+        client
+            .put(Keyspace::Slots, key.to_owned(), value.to_vec())
+            .await?;
         assert_eq!(
             watch.next().await,
-            Some(WatchEvent::Put(key.to_owned(), value.to_owned()))
+            Some(WatchEvent::Put(
+                format!("/{:?}/{}", Keyspace::Slots, key.to_owned()),
+                value.to_owned()
+            ))
         );
         let value2 = "value2".as_bytes();
-        client.put(key.to_owned(), value2.to_vec()).await?;
+        client
+            .put(Keyspace::Slots, key.to_owned(), value2.to_vec())
+            .await?;
         assert_eq!(
             watch.next().await,
-            Some(WatchEvent::Put(key.to_owned(), value2.to_owned()))
+            Some(WatchEvent::Put(
+                format!("/{:?}/{}", Keyspace::Slots, key.to_owned()),
+                value2.to_owned()
+            ))
         );
         watch.cancel().await?;
         Ok(())
diff --git a/ballista/rust/scheduler/src/state/execution_graph.rs b/ballista/rust/scheduler/src/state/execution_graph.rs
new file mode 100644
index 00000000..1412f7e0
--- /dev/null
+++ b/ballista/rust/scheduler/src/state/execution_graph.rs
@@ -0,0 +1,974 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::planner::DistributedPlanner;
+use ballista_core::error::{BallistaError, Result};
+use ballista_core::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec};
+
+use ballista_core::serde::protobuf::{
+    self, CompletedJob, JobStatus, QueuedJob, TaskStatus,
+};
+use ballista_core::serde::protobuf::{job_status, FailedJob, ShuffleWritePartition};
+use ballista_core::serde::protobuf::{task_status, RunningTask};
+use ballista_core::serde::scheduler::{
+    ExecutorMetadata, PartitionId, PartitionLocation, PartitionStats,
+};
+use datafusion::physical_plan::{
+    accept, ExecutionPlan, ExecutionPlanVisitor, Partitioning,
+};
+use log::debug;
+use std::collections::HashMap;
+use std::convert::TryInto;
+use std::fmt::{Debug, Formatter};
+
+use datafusion::physical_plan::display::DisplayableExecutionPlan;
+use std::sync::Arc;
+
+/// This data structure collects the partition locations for an `ExecutionStage`.
+/// Each `ExecutionStage` will hold a `StageOutput` for each of its child stages.
+/// When all tasks for the child stage are complete, it will mark the `StageOutput`
+/// as complete.
+#[derive(Clone, Debug, Default)]
+pub struct StageOutput {
+    /// Map from partition -> partition locations
+    pub(crate) partition_locations: HashMap<usize, Vec<PartitionLocation>>,
+    /// Flag indicating whether all tasks of the child stage are complete
+    pub(crate) complete: bool,
+}
+
+impl StageOutput {
+    /// Create an empty `StageOutput` that is not yet complete.
+    pub fn new() -> Self {
+        Self {
+            partition_locations: HashMap::new(),
+            complete: false,
+        }
+    }
+
+    /// Add a `PartitionLocation` to the `StageOutput`, grouping it with any locations
+    /// already recorded for the same partition ID.
+    pub fn add_partition(&mut self, partition_location: PartitionLocation) {
+        // One hash lookup via the entry API instead of `get_mut` followed by `insert`.
+        self.partition_locations
+            .entry(partition_location.partition_id.partition_id)
+            .or_default()
+            .push(partition_location);
+    }
+
+    /// Returns the `complete` flag, which is set once every task of the producing
+    /// stage has completed.
+    pub fn is_complete(&self) -> bool {
+        self.complete
+    }
+}
+
+/// A stage in the ExecutionGraph.
+///
+/// This represents a set of tasks (one per each `partition`) which can
+/// be executed concurrently.
+#[derive(Clone)]
+pub struct ExecutionStage {
+    /// Stage ID
+    pub(crate) stage_id: usize,
+    /// Total number of output partitions for this stage.
+    /// This stage will produce one task per partition.
+    pub(crate) partitions: usize,
+    /// Output partitioning for this stage.
+    pub(crate) output_partitioning: Option<Partitioning>,
+    /// Represents the outputs from this stage's child stages, keyed by child stage ID.
+    /// This stage can only be resolved and executed once all child stages are completed.
+    pub(crate) inputs: HashMap<usize, StageOutput>,
+    /// `ExecutionPlan` for this stage
+    pub(crate) plan: Arc<dyn ExecutionPlan>,
+    /// Status of each already scheduled task, indexed by partition.
+    /// If status is None, the partition has not yet been scheduled
+    pub(crate) task_statuses: Vec<Option<task_status::Status>>,
+    /// Stage ID of the stage that will take this stage's outputs as inputs.
+    /// If `output_link` is `None` then this is the final stage in the `ExecutionGraph`
+    pub(crate) output_link: Option<usize>,
+    /// Flag indicating whether all input partitions have been resolved and the plan
+    /// has UnresolvedShuffleExec operators resolved to ShuffleReadExec operators.
+    pub(crate) resolved: bool,
+}
+
+impl Debug for ExecutionStage {
+    /// One summary line of stage counters, then the per-child input state, then the
+    /// indented stage plan.
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        // `flatten` over `&Option<_>` keeps only the `Some` entries, i.e. scheduled tasks.
+        let scheduled_tasks = self.task_statuses.iter().flatten().count();
+        let children = self.inputs.len();
+        let displayable = DisplayableExecutionPlan::new(self.plan.as_ref());
+
+        write!(
+            f,
+            "Stage[id={}, partitions={:?}, children={}, completed_tasks={}, resolved={}, scheduled_tasks={}, available_tasks={}]\nInputs{:?}\n\n{}",
+            self.stage_id,
+            self.partitions,
+            children,
+            self.completed_tasks(),
+            self.resolved,
+            scheduled_tasks,
+            self.available_tasks(),
+            self.inputs,
+            displayable.indent()
+        )
+    }
+}
+
+impl ExecutionStage {
+    /// Create a stage for `plan`.
+    ///
+    /// One task slot is allocated per output partition of `plan`. A stage without
+    /// `child_stages` has no shuffles to resolve and starts out resolved; otherwise an
+    /// empty `StageOutput` is registered for every child stage ID.
+    pub fn new(
+        stage_id: usize,
+        plan: Arc<dyn ExecutionPlan>,
+        output_partitioning: Option<Partitioning>,
+        output_link: Option<usize>,
+        child_stages: Vec<usize>,
+    ) -> Self {
+        let num_tasks = plan.output_partitioning().partition_count();
+
+        let resolved = child_stages.is_empty();
+
+        let mut inputs: HashMap<usize, StageOutput> = HashMap::new();
+
+        for input_stage_id in &child_stages {
+            inputs.insert(*input_stage_id, StageOutput::new());
+        }
+
+        Self {
+            stage_id,
+            partitions: num_tasks,
+            output_partitioning,
+            inputs,
+            plan,
+            task_statuses: vec![None; num_tasks],
+            output_link,
+            resolved,
+        }
+    }
+
+    /// Returns true if all inputs are complete and we can resolve all
+    /// UnresolvedShuffleExec operators to ShuffleReadExec.
+    /// A stage with no inputs is trivially resolvable.
+    pub fn resolvable(&self) -> bool {
+        self.inputs.iter().all(|(_, outputs)| outputs.is_complete())
+    }
+
+    /// Returns `true` if all tasks for this stage are complete
+    /// (vacuously `true` for a stage with zero partitions).
+    pub fn complete(&self) -> bool {
+        self.task_statuses
+            .iter()
+            .all(|status| matches!(status, Some(task_status::Status::Completed(_))))
+    }
+
+    /// Returns the number of tasks in this stage whose status is `Completed`
+    pub fn completed_tasks(&self) -> usize {
+        self.task_statuses
+            .iter()
+            .filter(|status| matches!(status, Some(task_status::Status::Completed(_))))
+            .count()
+    }
+
+    /// Marks the input stage ID as complete. Stage IDs that are not inputs of this
+    /// stage are ignored.
+    pub fn complete_input(&mut self, stage_id: usize) {
+        if let Some(input) = self.inputs.get_mut(&stage_id) {
+            input.complete = true;
+        }
+    }
+
+    /// Returns true if the stage plan has all UnresolvedShuffleExec operators resolved to
+    /// ShuffleReadExec
+    pub fn resolved(&self) -> bool {
+        self.resolved
+    }
+
+    /// Returns the number of tasks in this stage which are available for scheduling.
+    /// If the stage is not yet resolved, then this will return `0`, otherwise it will
+    /// return the number of tasks where the task status is not yet set.
+    pub fn available_tasks(&self) -> usize {
+        if self.resolved {
+            self.task_statuses.iter().filter(|s| s.is_none()).count()
+        } else {
+            0
+        }
+    }
+
+    /// Resolve any UnresolvedShuffleExec operators within this stage's plan, using the
+    /// partition locations collected from every input stage.
+    ///
+    /// This is a no-op if the stage is already resolved. Returns an error if rewriting
+    /// the plan fails.
+    pub fn resolve_shuffles(&mut self) -> Result<()> {
+        // Go through the `log` facade instead of unconditionally printing the whole
+        // stage (including its plan) to stdout on every resolution.
+        debug!("Resolving shuffles\n{:?}", self);
+        if self.resolved {
+            // If this stage has no input shuffles, then it is already resolved
+            Ok(())
+        } else {
+            let input_locations = self
+                .inputs
+                .iter()
+                .map(|(stage, outputs)| (*stage, outputs.partition_locations.clone()))
+                .collect();
+            // Otherwise, rewrite the plan to replace UnresolvedShuffleExec with ShuffleReadExec
+            let new_plan = crate::planner::remove_unresolved_shuffles(
+                self.plan.clone(),
+                &input_locations,
+            )?;
+            self.plan = new_plan;
+            self.resolved = true;
+            Ok(())
+        }
+    }
+
+    /// Update the status for task partition
+    ///
+    /// # Panics
+    ///
+    /// Panics if `partition` is not less than the number of partitions in this stage.
+    pub fn update_task_status(&mut self, partition: usize, status: task_status::Status) {
+        debug!("Updating task status for partition {}", partition);
+        self.task_statuses[partition] = Some(status);
+    }
+
+    /// Add input partitions published from an input stage.
+    ///
+    /// `_partition_id` is currently unused: each location is grouped under its own
+    /// partition ID. Returns an error if `stage_id` is not a child stage of this stage.
+    pub fn add_input_partitions(
+        &mut self,
+        stage_id: usize,
+        _partition_id: usize,
+        locations: Vec<PartitionLocation>,
+    ) -> Result<()> {
+        if let Some(stage_inputs) = self.inputs.get_mut(&stage_id) {
+            for partition in locations {
+                stage_inputs.add_partition(partition);
+            }
+        } else {
+            return Err(BallistaError::Internal(format!("Error adding input partitions to stage {}, {} is not a valid child stage ID", self.stage_id, stage_id)));
+        }
+
+        Ok(())
+    }
+}
+
+/// Turns a list of `ShuffleWriterExec` stage plans into `ExecutionStage`s.
+///
+/// Walking each stage plan reveals which stages it reads from (its
+/// `UnresolvedShuffleExec` leaves), from which the parent/child links of the
+/// stage DAG are derived.
+struct ExecutionStageBuilder {
+    /// ID of the stage whose plan is currently being walked
+    current_stage_id: usize,
+    /// Stage ID -> IDs of the stages it consumes
+    stage_dependencies: HashMap<usize, Vec<usize>>,
+    /// Stage ID -> ID of the stage consuming its output
+    output_links: HashMap<usize, usize>,
+}
+
+impl ExecutionStageBuilder {
+    pub fn new() -> Self {
+        Self {
+            current_stage_id: 0,
+            stage_dependencies: HashMap::new(),
+            output_links: HashMap::new(),
+        }
+    }
+
+    /// Build one `ExecutionStage` per input plan, keyed by stage ID.
+    ///
+    /// Returns an error if walking any of the stage plans fails.
+    pub fn build(
+        mut self,
+        stages: Vec<Arc<ShuffleWriterExec>>,
+    ) -> Result<HashMap<usize, ExecutionStage>> {
+        // Pass 1: visit every plan so all links are known before any stage is built.
+        stages
+            .iter()
+            .try_for_each(|stage| accept(stage.as_ref(), &mut self))?;
+
+        // Pass 2: hand each stage the links recorded for it.
+        let execution_stages: HashMap<usize, ExecutionStage> = stages
+            .into_iter()
+            .map(|stage| {
+                let id = stage.stage_id();
+                let partitioning = stage.shuffle_output_partitioning().cloned();
+                let link = self.output_links.remove(&id);
+                let children = self.stage_dependencies.remove(&id).unwrap_or_default();
+                (
+                    id,
+                    ExecutionStage::new(id, stage, partitioning, link, children),
+                )
+            })
+            .collect();
+
+        Ok(execution_stages)
+    }
+}
+
+impl ExecutionPlanVisitor for ExecutionStageBuilder {
+    type Error = BallistaError;
+
+    /// Record stage links while walking a stage plan top-down.
+    ///
+    /// A `ShuffleWriterExec` sets the stage currently being visited; every
+    /// `UnresolvedShuffleExec` found under it names a child stage whose output feeds
+    /// that current stage. Always continues into children.
+    fn pre_visit(
+        &mut self,
+        plan: &dyn ExecutionPlan,
+    ) -> std::result::Result<bool, Self::Error> {
+        if let Some(shuffle_write) = plan.as_any().downcast_ref::<ShuffleWriterExec>() {
+            self.current_stage_id = shuffle_write.stage_id();
+        } else if let Some(unresolved_shuffle) =
+            plan.as_any().downcast_ref::<UnresolvedShuffleExec>()
+        {
+            // The child stage publishes its output to the stage being visited.
+            self.output_links
+                .insert(unresolved_shuffle.stage_id, self.current_stage_id);
+
+            // Single lookup via the entry API instead of `get_mut` followed by `insert`.
+            self.stage_dependencies
+                .entry(self.current_stage_id)
+                .or_default()
+                .push(unresolved_shuffle.stage_id);
+        }
+        Ok(true)
+    }
+}
+
+/// The basic unit of work for the Ballista executor: one partition of one stage,
+/// executed on one task slot.
+#[derive(Clone)]
+pub struct Task {
+    pub session_id: String,
+    pub partition: PartitionId,
+    pub plan: Arc<dyn ExecutionPlan>,
+    pub output_partitioning: Option<Partitioning>,
+}
+
+impl Debug for Task {
+    /// Task identity on the first line, followed by the indented plan.
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let Task {
+            session_id,
+            partition,
+            plan,
+            ..
+        } = self;
+
+        write!(
+            f,
+            "Task[session_id: {}, job: {}, stage: {}, partition: {}]\n{}",
+            session_id,
+            partition.job_id,
+            partition.stage_id,
+            partition.partition_id,
+            DisplayableExecutionPlan::new(&**plan).indent()
+        )
+    }
+}
+
+/// Represents the DAG for a distributed query plan.
+///
+/// A distributed query plan consists of a set of stages. A stage only becomes
+/// available for scheduling once all of the stages it depends on have completed.
+///
+/// Each stage consists of a set of partitions which can be executed in parallel, where each partition
+/// represents a `Task`, which is the basic unit of scheduling in Ballista.
+///
+/// As an example, consider a SQL query which performs a simple aggregation:
+///
+/// `SELECT id, SUM(gmv) FROM some_table GROUP BY id`
+///
+/// This will produce a DataFusion execution plan that looks something like
+///
+/// ```text
+/// CoalesceBatchesExec: target_batch_size=4096
+///   RepartitionExec: partitioning=Hash([Column { name: "id", index: 0 }], 4)
+///     AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[SUM(some_table.gmv)]
+///       TableScan: some_table
+/// ```
+///
+/// The Ballista `DistributedPlanner` will turn this into a distributed plan by creating a shuffle
+/// boundary (called a "Stage") whenever the underlying plan needs to perform a repartition.
+/// In this case we end up with a distributed plan with two stages:
+///
+/// ```text
+/// ExecutionGraph[job_id=job, session_id=session, available_tasks=1, complete=false]
+/// Stage[id=2, partitions=4, children=1, completed_tasks=0, resolved=false, scheduled_tasks=0, available_tasks=0]
+/// Inputs{1: StageOutput { partition_locations: {}, complete: false }}
+///
+/// ShuffleWriterExec: None
+///   AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[SUM(?table?.gmv)]
+///     CoalesceBatchesExec: target_batch_size=4096
+///       UnresolvedShuffleExec
+///
+/// Stage[id=1, partitions=1, children=0, completed_tasks=0, resolved=true, scheduled_tasks=0, available_tasks=1]
+/// Inputs{}
+///
+/// ShuffleWriterExec: Some(Hash([Column { name: "id", index: 0 }], 4))
+///   AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[SUM(?table?.gmv)]
+///     TableScan: some_table
+/// ```
+///
+/// The DAG structure of this `ExecutionGraph` is encoded in the stages. Each stage's `inputs` field
+/// will indicate which stages it depends on, and each stage's `output_link` will indicate which
+/// stage it needs to publish its output to.
+///
+/// If a stage has `output_link == None` then it is the final stage in this query, and it should
+/// publish its outputs to the `ExecutionGraph`'s `output_locations` representing the final query results.
+#[derive(Clone)]
+pub struct ExecutionGraph {
+    /// ID for this job
+    pub(crate) job_id: String,
+    /// Session ID for this job
+    pub(crate) session_id: String,
+    /// Status of this job
+    pub(crate) status: JobStatus,
+    /// Map from Stage ID -> ExecutionStage
+    pub(crate) stages: HashMap<usize, ExecutionStage>,
+    /// Total number of output partitions
+    pub(crate) output_partitions: usize,
+    /// Locations of this `ExecutionGraph` final output locations
+    pub(crate) output_locations: Vec<PartitionLocation>,
+}
+
+impl ExecutionGraph {
+    pub fn new(
+        job_id: &str,
+        session_id: &str,
+        plan: Arc<dyn ExecutionPlan>,
+    ) -> Result<Self> {
+        let mut planner = DistributedPlanner::new();
+
+        let output_partitions = plan.output_partitioning().partition_count();
+
+        let shuffle_stages = planner.plan_query_stages(job_id, plan)?;
+
+        let builder = ExecutionStageBuilder::new();
+        let stages = builder.build(shuffle_stages)?;
+
+        Ok(Self {
+            job_id: job_id.to_string(),
+            session_id: session_id.to_string(),
+            status: JobStatus {
+                status: Some(job_status::Status::Queued(QueuedJob {})),
+            },
+            stages,
+            output_partitions,
+            output_locations: vec![],
+        })
+    }
+
+    pub fn job_id(&self) -> &str {
+        self.job_id.as_str()
+    }
+
+    pub fn session_id(&self) -> &str {
+        self.session_id.as_str()
+    }
+
+    pub fn status(&self) -> JobStatus {
+        self.status.clone()
+    }
+
+    /// An ExecutionGraph is complete if all its stages are complete
+    pub fn complete(&self) -> bool {
+        self.stages.values().all(|s| s.complete())
+    }
+
+    /// Update task statuses in the graph. This will push shuffle partitions to their
+    /// respective shuffle read stages.
+    pub fn update_task_status(
+        &mut self,
+        executor: &ExecutorMetadata,
+        statuses: Vec<TaskStatus>,
+    ) -> Result<()> {
+        for status in statuses.into_iter() {
+            if let TaskStatus {
+                task_id:
+                    Some(protobuf::PartitionId {
+                        job_id,
+                        stage_id,
+                        partition_id,
+                    }),
+                status: Some(task_status),
+            } = status
+            {
+                if job_id != self.job_id() {
+                    return Err(BallistaError::Internal(format!(
+                        "Error updating job {}: Invalid task status job ID {}",
+                        self.job_id(),
+                        job_id
+                    )));
+                }
+
+                let stage_id = stage_id as usize;
+                let partition = partition_id as usize;
+                if let Some(stage) = self.stages.get_mut(&stage_id) {
+                    stage.update_task_status(partition, task_status.clone());
+                    let stage_complete = stage.complete();
+
+                    // TODO Should be able to reschedule this task.
+                    if let task_status::Status::Failed(failed_task) = task_status {
+                        self.status = JobStatus {
+                            status: Some(job_status::Status::Failed(FailedJob {
+                                error: format!(
+                                    "Task {}/{}/{} failed: {}",
+                                    job_id, stage_id, partition_id, failed_task.error
+                                ),
+                            })),
+                        };
+                        return Ok(());
+                    } else if let task_status::Status::Completed(completed_task) =
+                        task_status
+                    {
+                        let locations = partition_to_location(
+                            self.job_id.as_str(),
+                            stage_id,
+                            executor,
+                            completed_task.partitions,
+                        );
+
+                        if let Some(link) = stage.output_link {
+                            // If this is an intermediate stage, we need to push its `PartitionLocation`s to the parent stage
+                            if let Some(linked_stage) = self.stages.get_mut(&link) {
+                                linked_stage.add_input_partitions(
+                                    stage_id, partition, locations,
+                                )?;
+
+                                // If all tasks for this stage are complete, mark the input complete in the parent stage
+                                if stage_complete {
+                                    linked_stage.complete_input(stage_id);
+                                }
+
+                                // If all input partitions are ready, we can resolve any UnresolvedShuffleExec in the parent stage plan
+                                if linked_stage.resolvable() {
+                                    linked_stage.resolve_shuffles()?;
+                                }
+                            } else {
+                                return Err(BallistaError::Internal(format!("Error updating job {}: Invalid output link {} for stage {}", job_id, stage_id, link)));
+                            }
+                        } else {
+                            // If `output_link` is `None`, then this is a final stage
+                            self.output_locations.extend(locations);
+                        }
+                    }
+                } else {
+                    return Err(BallistaError::Internal(format!(
+                        "Invalid stage ID {} for job {}",
+                        stage_id,
+                        self.job_id()
+                    )));
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Count the tasks, across every stage of this graph, that are currently
+    /// ready to be handed to an executor.
+    pub fn available_tasks(&self) -> usize {
+        let mut total = 0;
+        for stage in self.stages.values() {
+            total += stage.available_tasks();
+        }
+        total
+    }
+
+    /// Claim the next schedulable task for `executor_id`.
+    ///
+    /// The first resolved stage that still has available tasks is selected, and the
+    /// first of its partitions without a status is marked `Running` on `executor_id`
+    /// before the task is returned. Because the task is no longer visible to the
+    /// scheduler afterwards, only call this when the task is about to be launched;
+    /// if the launch does not happen, hand the task back with `reset_task_status`
+    /// so it can be scheduled elsewhere.
+    ///
+    /// Returns `Ok(None)` when no stage has a task to offer, and an error when a
+    /// stage reports available tasks but none of its partitions is pending.
+    pub fn pop_next_task(&mut self, executor_id: &str) -> Result<Option<Task>> {
+        let job_id = self.job_id.clone();
+        let session_id = self.session_id.clone();
+
+        let candidate = self
+            .stages
+            .iter_mut()
+            .find(|(_, stage)| stage.resolved() && stage.available_tasks() > 0);
+
+        let (stage_id, stage) = match candidate {
+            Some(entry) => entry,
+            None => return Ok(None),
+        };
+
+        let partition_id = match stage
+            .task_statuses
+            .iter()
+            .position(|status| status.is_none())
+        {
+            Some(index) => index,
+            None => {
+                return Err(BallistaError::Internal(format!("Error getting next task for job {}: Stage {} is ready but has no pending tasks", job_id, stage_id)));
+            }
+        };
+
+        // Claim the partition for this executor so it is not offered again.
+        stage.task_statuses[partition_id] =
+            Some(task_status::Status::Running(RunningTask {
+                executor_id: executor_id.to_owned(),
+            }));
+
+        Ok(Some(Task {
+            session_id,
+            partition: PartitionId {
+                job_id,
+                stage_id: *stage_id,
+                partition_id,
+            },
+            plan: stage.plan.clone(),
+            output_partitioning: stage.output_partitioning.clone(),
+        }))
+    }
+
+    /// Mark the job as completed, recording the locations of its final output
+    /// partitions in the job status.
+    ///
+    /// Returns an error if the graph is not complete yet, or if any output
+    /// location cannot be converted into its protobuf representation.
+    pub fn finalize(&mut self) -> Result<()> {
+        if self.complete() {
+            let partition_location = self
+                .output_locations
+                .iter()
+                .cloned()
+                .map(|location| location.try_into())
+                .collect::<Result<Vec<_>>>()?;
+
+            self.status = JobStatus {
+                status: Some(job_status::Status::Completed(CompletedJob {
+                    partition_location,
+                })),
+            };
+            Ok(())
+        } else {
+            Err(BallistaError::Internal(format!(
+                "Attempt to finalize an incomplete job {}",
+                self.job_id()
+            )))
+        }
+    }
+
+    /// Replace the job's current status with `status`, discarding the old one.
+    pub fn update_status(&mut self, status: JobStatus) {
+        let _previous = std::mem::replace(&mut self.status, status);
+    }
+
+    /// Reset the status for the given task. This should be called if a task failed to
+    /// launch and it needs to be returned to the set of available tasks so it can be
+    /// re-scheduled.
+    ///
+    /// A task whose stage ID is unknown, or whose partition ID is out of range for
+    /// that stage, is ignored instead of causing a panic.
+    pub fn reset_task_status(&mut self, task: Task) {
+        let stage_id = task.partition.stage_id;
+        let partition = task.partition.partition_id;
+
+        // `get_mut` on both levels: an invalid partition is treated the same way as an
+        // invalid stage (ignored) rather than panicking on an out-of-bounds index.
+        if let Some(status) = self
+            .stages
+            .get_mut(&stage_id)
+            .and_then(|stage| stage.task_statuses.get_mut(partition))
+        {
+            *status = None;
+        }
+    }
+
+    /// Return an owned copy of the output partition locations recorded so far for
+    /// the final stage.
+    pub fn output_locations(&self) -> Vec<PartitionLocation> {
+        self.output_locations.to_vec()
+    }
+}
+
+impl Debug for ExecutionGraph {
+    /// Render a one-line summary of the graph, followed by one line per stage.
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        let stage_lines: Vec<String> = self
+            .stages
+            .values()
+            .map(|stage| format!("{:?}", stage))
+            .collect();
+        write!(
+            f,
+            "ExecutionGraph[job_id={}, session_id={}, available_tasks={}, complete={}]\n{}",
+            self.job_id,
+            self.session_id,
+            self.available_tasks(),
+            self.complete(),
+            stage_lines.join("\n")
+        )
+    }
+}
+
+/// Convert the shuffle outputs reported by a completed task of `stage_id` into
+/// `PartitionLocation`s that point at `executor`.
+fn partition_to_location(
+    job_id: &str,
+    stage_id: usize,
+    executor: &ExecutorMetadata,
+    shuffles: Vec<ShuffleWritePartition>,
+) -> Vec<PartitionLocation> {
+    let mut locations = Vec::with_capacity(shuffles.len());
+    for shuffle in shuffles {
+        let partition_stats = PartitionStats::new(
+            Some(shuffle.num_rows),
+            Some(shuffle.num_batches),
+            Some(shuffle.num_bytes),
+        );
+        locations.push(PartitionLocation {
+            partition_id: PartitionId {
+                job_id: job_id.to_owned(),
+                stage_id,
+                partition_id: shuffle.partition_id as usize,
+            },
+            executor_meta: executor.clone(),
+            partition_stats,
+            path: shuffle.path,
+        });
+    }
+    locations
+}
+
+#[cfg(test)]
+mod test {
+    use crate::state::execution_graph::ExecutionGraph;
+    use ballista_core::error::Result;
+    use ballista_core::serde::protobuf::{self, job_status, task_status};
+    use ballista_core::serde::scheduler::{ExecutorMetadata, ExecutorSpecification};
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::logical_expr::{col, sum, Expr};
+
+    use datafusion::logical_plan::JoinType;
+    use datafusion::physical_plan::display::DisplayableExecutionPlan;
+    use datafusion::prelude::{SessionConfig, SessionContext};
+    use datafusion::test_util::scan_empty;
+
+    use std::sync::Arc;
+
+    /// Every test plan must reach the `complete` state when each task is reported
+    /// as completed as soon as it is popped.
+    #[tokio::test]
+    async fn test_drain_tasks() -> Result<()> {
+        let mut agg_graph = test_aggregation_plan(4).await;
+
+        println!("Graph: {:?}", agg_graph);
+
+        drain_tasks(&mut agg_graph)?;
+
+        assert!(agg_graph.complete(), "Failed to complete aggregation plan");
+
+        let mut coalesce_graph = test_coalesce_plan(4).await;
+
+        drain_tasks(&mut coalesce_graph)?;
+
+        assert!(
+            coalesce_graph.complete(),
+            "Failed to complete coalesce plan"
+        );
+
+        let mut join_graph = test_join_plan(4).await;
+
+        drain_tasks(&mut join_graph)?;
+
+        println!("{:?}", join_graph);
+
+        assert!(join_graph.complete(), "Failed to complete join plan");
+
+        let mut union_all_graph = test_union_all_plan(4).await;
+
+        drain_tasks(&mut union_all_graph)?;
+
+        println!("{:?}", union_all_graph);
+
+        assert!(
+            union_all_graph.complete(),
+            "Failed to complete union all plan"
+        );
+
+        let mut union_graph = test_union_plan(4).await;
+
+        drain_tasks(&mut union_graph)?;
+
+        println!("{:?}", union_graph);
+
+        assert!(union_graph.complete(), "Failed to complete union plan");
+
+        Ok(())
+    }
+
+    /// After draining, `finalize` marks the job completed and exposes one output
+    /// location per output partition, all on the test executor's host.
+    #[tokio::test]
+    async fn test_finalize() -> Result<()> {
+        let mut agg_graph = test_aggregation_plan(4).await;
+
+        drain_tasks(&mut agg_graph)?;
+        agg_graph.finalize()?;
+
+        let status = agg_graph.status();
+
+        assert!(matches!(
+            status,
+            protobuf::JobStatus {
+                status: Some(job_status::Status::Completed(_))
+            }
+        ));
+
+        let outputs = agg_graph.output_locations();
+
+        assert_eq!(outputs.len(), agg_graph.output_partitions);
+
+        for location in outputs {
+            assert_eq!(location.executor_meta.host, "localhost2".to_owned());
+        }
+
+        Ok(())
+    }
+
+    /// Pop tasks from `graph` until none are left, reporting each one as completed
+    /// by the test executor.
+    ///
+    /// Each completed task reports one shuffle output per output partition, or a
+    /// single output when the task has no output partitioning.
+    fn drain_tasks(graph: &mut ExecutionGraph) -> Result<()> {
+        let executor = test_executor();
+        let job_id = graph.job_id().to_owned();
+        // Use the test executor's ID everywhere so the task assignment, the
+        // completion report and the executor metadata all agree.
+        while let Some(task) = graph.pop_next_task(&executor.id)? {
+            let mut partitions: Vec<protobuf::ShuffleWritePartition> = vec![];
+
+            let num_partitions = task
+                .output_partitioning
+                .map(|p| p.partition_count())
+                .unwrap_or(1);
+
+            for partition_id in 0..num_partitions {
+                partitions.push(protobuf::ShuffleWritePartition {
+                    partition_id: partition_id as u64,
+                    path: format!(
+                        "/{}/{}/{}",
+                        task.partition.job_id,
+                        task.partition.stage_id,
+                        task.partition.partition_id
+                    ),
+                    num_batches: 1,
+                    num_rows: 1,
+                    num_bytes: 1,
+                })
+            }
+
+            // Complete the task
+            let task_status = protobuf::TaskStatus {
+                status: Some(task_status::Status::Completed(protobuf::CompletedTask {
+                    executor_id: executor.id.clone(),
+                    partitions,
+                })),
+                task_id: Some(protobuf::PartitionId {
+                    job_id: job_id.clone(),
+                    stage_id: task.partition.stage_id as u32,
+                    partition_id: task.partition.partition_id as u32,
+                }),
+            };
+
+            graph.update_task_status(&executor, vec![task_status])?;
+        }
+
+        Ok(())
+    }
+
+    /// Build a graph for `SELECT id, SUM(gmv) ... GROUP BY id` over an empty table.
+    async fn test_aggregation_plan(partition: usize) -> ExecutionGraph {
+        let config = SessionConfig::new().with_target_partitions(partition);
+        let ctx = Arc::new(SessionContext::with_config(config));
+
+        let schema = Schema::new(vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new("gmv", DataType::UInt64, false),
+        ]);
+
+        let logical_plan = scan_empty(None, &schema, Some(vec![0, 1]))
+            .unwrap()
+            .aggregate(vec![col("id")], vec![sum(col("gmv"))])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let optimized_plan = ctx.optimize(&logical_plan).unwrap();
+
+        let plan = ctx.create_physical_plan(&optimized_plan).await.unwrap();
+
+        println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent());
+
+        ExecutionGraph::new("job", "session", plan).unwrap()
+    }
+
+    /// Build a graph for a `LIMIT 1` query over an empty table.
+    async fn test_coalesce_plan(partition: usize) -> ExecutionGraph {
+        let config = SessionConfig::new().with_target_partitions(partition);
+        let ctx = Arc::new(SessionContext::with_config(config));
+
+        let schema = Schema::new(vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new("gmv", DataType::UInt64, false),
+        ]);
+
+        let logical_plan = scan_empty(None, &schema, Some(vec![0, 1]))
+            .unwrap()
+            .limit(None, Some(1))
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let optimized_plan = ctx.optimize(&logical_plan).unwrap();
+
+        let plan = ctx.create_physical_plan(&optimized_plan).await.unwrap();
+
+        ExecutionGraph::new("job", "session", plan).unwrap()
+    }
+
+    /// Build a graph for an inner join on `id`, followed by an aggregation and a sort.
+    async fn test_join_plan(partition: usize) -> ExecutionGraph {
+        let config = SessionConfig::new().with_target_partitions(partition);
+        let ctx = Arc::new(SessionContext::with_config(config));
+
+        let schema = Schema::new(vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new("gmv", DataType::UInt64, false),
+        ]);
+
+        let left_plan = scan_empty(Some("left"), &schema, None).unwrap();
+
+        let right_plan = scan_empty(Some("right"), &schema, None)
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let sort_expr = Expr::Sort {
+            expr: Box::new(col("id")),
+            asc: false,
+            nulls_first: false,
+        };
+
+        let logical_plan = left_plan
+            .join(&right_plan, JoinType::Inner, (vec!["id"], vec!["id"]), None)
+            .unwrap()
+            .aggregate(vec![col("id")], vec![sum(col("gmv"))])
+            .unwrap()
+            .sort(vec![sort_expr])
+            .unwrap()
+            .build()
+            .unwrap();
+
+        let optimized_plan = ctx.optimize(&logical_plan).unwrap();
+
+        let plan = ctx.create_physical_plan(&optimized_plan).await.unwrap();
+
+        println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent());
+
+        let graph = ExecutionGraph::new("job", "session", plan).unwrap();
+
+        println!("{:?}", graph);
+
+        graph
+    }
+
+    /// Build a graph for a `UNION ALL` of two constant selects.
+    async fn test_union_all_plan(partition: usize) -> ExecutionGraph {
+        let config = SessionConfig::new().with_target_partitions(partition);
+        let ctx = Arc::new(SessionContext::with_config(config));
+
+        let logical_plan = ctx
+            .sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;")
+            .await
+            .unwrap()
+            .to_logical_plan()
+            .unwrap();
+
+        let optimized_plan = ctx.optimize(&logical_plan).unwrap();
+
+        let plan = ctx.create_physical_plan(&optimized_plan).await.unwrap();
+
+        println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent());
+
+        let graph = ExecutionGraph::new("job", "session", plan).unwrap();
+
+        println!("{:?}", graph);
+
+        graph
+    }
+
+    /// Build a graph for a (deduplicating) `UNION` of two constant selects.
+    async fn test_union_plan(partition: usize) -> ExecutionGraph {
+        let config = SessionConfig::new().with_target_partitions(partition);
+        let ctx = Arc::new(SessionContext::with_config(config));
+
+        let logical_plan = ctx
+            .sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;")
+            .await
+            .unwrap()
+            .to_logical_plan()
+            .unwrap();
+
+        let optimized_plan = ctx.optimize(&logical_plan).unwrap();
+
+        let plan = ctx.create_physical_plan(&optimized_plan).await.unwrap();
+
+        println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent());
+
+        let graph = ExecutionGraph::new("job", "session", plan).unwrap();
+
+        println!("{:?}", graph);
+
+        graph
+    }
+
+    /// Metadata for the single executor that runs every task in these tests.
+    fn test_executor() -> ExecutorMetadata {
+        ExecutorMetadata {
+            id: "executor-2".to_string(),
+            host: "localhost2".to_string(),
+            port: 8080,
+            grpc_port: 9090,
+            specification: ExecutorSpecification { task_slots: 1 },
+        }
+    }
+}
diff --git a/ballista/rust/scheduler/src/state/executor_manager.rs b/ballista/rust/scheduler/src/state/executor_manager.rs
index 40821bea..ad2a4389 100644
--- a/ballista/rust/scheduler/src/state/executor_manager.rs
+++ b/ballista/rust/scheduler/src/state/executor_manager.rs
@@ -15,43 +15,354 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::time::{Duration, SystemTime, UNIX_EPOCH};
+use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
 
-use ballista_core::serde::protobuf::ExecutorHeartbeat;
-use ballista_core::serde::scheduler::{ExecutorData, ExecutorDataChange};
-use log::{error, info, warn};
+use crate::state::backend::{Keyspace, StateBackendClient, WatchEvent};
+
+use crate::state::{decode_into, decode_protobuf, encode_protobuf, with_lock};
+use ballista_core::error::{BallistaError, Result};
+use ballista_core::serde::protobuf;
+
+use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
+use futures::StreamExt;
+use log::{debug, info};
 use parking_lot::RwLock;
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 
+/// A task slot on an executor that has been reserved: it is available for scheduling
+/// but is not visible to the rest of the system.
+///
+/// Because tasks from the job that just freed a slot are preferred, a reservation can
+/// already be bound to a job ID. In that case the scheduler tries to place available
+/// tasks from that job into the reserved slot.
+#[derive(Clone, Debug)]
+pub struct ExecutorReservation {
+    pub executor_id: String,
+    pub job_id: Option<String>,
+}
+
+impl ExecutorReservation {
+    /// Reserve a slot on `executor_id` without binding it to any job.
+    pub fn new_free(executor_id: String) -> Self {
+        Self {
+            job_id: None,
+            executor_id,
+        }
+    }
+
+    /// Reserve a slot on `executor_id` that is already bound to `job_id`.
+    pub fn new_assigned(executor_id: String, job_id: String) -> Self {
+        Self::new_free(executor_id).assign(job_id)
+    }
+
+    /// Bind this reservation to `job_id`, replacing any previous binding.
+    pub fn assign(self, job_id: String) -> Self {
+        Self {
+            job_id: Some(job_id),
+            ..self
+        }
+    }
+
+    /// Whether this reservation is bound to a job.
+    pub fn assigned(&self) -> bool {
+        matches!(self.job_id, Some(_))
+    }
+}
+
+/// Tracks the executors known to the scheduler: their metadata, their most recent
+/// heartbeats and, through the state backend, their available task slots.
 #[derive(Clone)]
 pub(crate) struct ExecutorManager {
-    executors_heartbeat: Arc<RwLock<HashMap<String, ExecutorHeartbeat>>>,
-    executors_data: Arc<RwLock<HashMap<String, ExecutorData>>>,
+    /// Backend used to persist executor metadata (`Keyspace::Executors`), heartbeats
+    /// (`Keyspace::Heartbeats`) and task slots (`Keyspace::Slots`).
+    state: Arc<dyn StateBackendClient>,
+    /// In-memory cache of executor metadata, keyed by executor ID.
+    executor_metadata: Arc<RwLock<HashMap<String, ExecutorMetadata>>>,
+    /// In-memory copy of the latest heartbeat per executor ID.
+    executors_heartbeat: Arc<RwLock<HashMap<String, protobuf::ExecutorHeartbeat>>>,
 }
 
 impl ExecutorManager {
-    pub(crate) fn new() -> Self {
+    /// Create an `ExecutorManager` backed by `state`. Both in-memory maps start empty;
+    /// `init` loads the persisted heartbeats.
+    pub(crate) fn new(state: Arc<dyn StateBackendClient>) -> Self {
         Self {
+            state,
+            executor_metadata: Arc::new(RwLock::new(HashMap::new())),
             executors_heartbeat: Arc::new(RwLock::new(HashMap::new())),
-            executors_data: Arc::new(RwLock::new(HashMap::new())),
         }
     }
 
-    pub(crate) fn save_executor_heartbeat(&self, heartbeat: ExecutorHeartbeat) {
+    /// Initialize the `ExecutorManager` state.
+    ///
+    /// Heartbeats already persisted in the state backend are first loaded into the
+    /// in-memory cache; an `ExecutorHeartbeatListener` sharing that cache is then
+    /// started so that later heartbeat updates are consumed as well.
+    pub async fn init(&self) -> Result<()> {
+        self.init_executor_heartbeats().await?;
+
+        ExecutorHeartbeatListener::new(
+            self.state.clone(),
+            self.executors_heartbeat.clone(),
+        )
+        .start()
+        .await
+    }
+
+    /// Reserve up to `n` executor task slots. Once reserved these slots will not be
+    /// available for scheduling.
+    ///
+    /// Slots are drawn from the executors returned by
+    /// `get_alive_executors_within_one_minute`. Fewer than `n` reservations are
+    /// returned when not enough free slots exist.
+    ///
+    /// Only executors whose slot count actually changes are written back. All of
+    /// those writes go into a single `put_txn` call made while holding the global
+    /// `Slots` lock. This operation is atomic, so if this method returns an `Err`,
+    /// no slots have been reserved.
+    pub async fn reserve_slots(&self, n: u32) -> Result<Vec<ExecutorReservation>> {
+        let lock = self.state.lock(Keyspace::Slots, "global").await?;
+
+        with_lock(lock, async {
+            debug!("Attempting to reserve {} executor slots", n);
+            let start = Instant::now();
+            let mut reservations: Vec<ExecutorReservation> = vec![];
+            let mut desired: u32 = n;
+
+            let alive_executors = self.get_alive_executors_within_one_minute();
+
+            let mut txn_ops: Vec<(Keyspace, String, Vec<u8>)> = vec![];
+
+            for executor_id in alive_executors {
+                // Stop reading from the backend as soon as the request is satisfied
+                // (this also makes `n == 0` a no-op).
+                if desired == 0 {
+                    break;
+                }
+
+                let value = self.state.get(Keyspace::Slots, &executor_id).await?;
+                let mut data =
+                    decode_into::<protobuf::ExecutorData, ExecutorData>(&value)?;
+                let take = std::cmp::min(data.available_task_slots, desired);
+
+                // An executor without free slots is unchanged, so don't rewrite it.
+                if take == 0 {
+                    continue;
+                }
+
+                reservations.extend(
+                    (0..take).map(|_| ExecutorReservation::new_free(executor_id.clone())),
+                );
+                data.available_task_slots -= take;
+                desired -= take;
+
+                let proto: protobuf::ExecutorData = data.into();
+                let new_data = encode_protobuf(&proto)?;
+                txn_ops.push((Keyspace::Slots, executor_id, new_data));
+            }
+
+            // Skip the backend round trip entirely when nothing was reserved.
+            if !txn_ops.is_empty() {
+                self.state.put_txn(txn_ops).await?;
+            }
+
+            let elapsed = start.elapsed();
+            info!(
+                "Reserved {} executor slots in {:?}",
+                reservations.len(),
+                elapsed
+            );
+
+            Ok(reservations)
+        })
+        .await
+    }
+
+    /// Return reserved task slots to the pool of available slots.
+    ///
+    /// The updated slot counts of every affected executor are written back in a single
+    /// `put_txn` call made while holding the global `Slots` lock. This operation is
+    /// atomic, so either every reservation is returned or none are.
+    pub async fn cancel_reservations(
+        &self,
+        reservations: Vec<ExecutorReservation>,
+    ) -> Result<()> {
+        let lock = self.state.lock(Keyspace::Slots, "global").await?;
+
+        with_lock(lock, async {
+            let num_reservations = reservations.len();
+            debug!("Cancelling {} reservations", num_reservations);
+            let start = Instant::now();
+
+            // Slot data per executor, loaded from the backend the first time an
+            // executor is seen and then incremented once per returned reservation.
+            let mut slots_by_executor: HashMap<String, ExecutorData> = HashMap::new();
+
+            for ExecutorReservation { executor_id, .. } in reservations {
+                match slots_by_executor.get_mut(&executor_id) {
+                    Some(data) => data.available_task_slots += 1,
+                    None => {
+                        let value = self.state.get(Keyspace::Slots, &executor_id).await?;
+                        let mut data =
+                            decode_into::<protobuf::ExecutorData, ExecutorData>(&value)?;
+                        data.available_task_slots += 1;
+                        slots_by_executor.insert(executor_id, data);
+                    }
+                }
+            }
+
+            let mut txn_ops: Vec<(Keyspace, String, Vec<u8>)> =
+                Vec::with_capacity(slots_by_executor.len());
+            for (executor_id, data) in slots_by_executor {
+                let proto: protobuf::ExecutorData = data.into();
+                txn_ops.push((Keyspace::Slots, executor_id, encode_protobuf(&proto)?));
+            }
+
+            self.state.put_txn(txn_ops).await?;
+
+            let elapsed = start.elapsed();
+            info!(
+                "Cancelled {} reservations in {:?}",
+                num_reservations, elapsed
+            );
+
+            Ok(())
+        })
+        .await
+    }
+
+    /// List every executor that has a cached heartbeat, paired with the timestamp of
+    /// that heartbeat converted to a `Duration` (whole seconds; `register_executor`
+    /// writes it as seconds since the Unix epoch).
+    pub async fn get_executor_state(&self) -> Result<Vec<(ExecutorMetadata, Duration)>> {
+        // Snapshot the timestamps in one statement so the read guard is released
+        // before any `.await` below.
+        let snapshot: Vec<(String, u64)> = self
+            .executors_heartbeat
+            .read()
+            .iter()
+            .map(|(id, heartbeat)| (id.clone(), heartbeat.timestamp))
+            .collect();
+
+        let mut state = Vec::with_capacity(snapshot.len());
+        for (executor_id, seconds) in snapshot {
+            let metadata = self.get_executor_metadata(&executor_id).await?;
+            state.push((metadata, Duration::from_secs(seconds)));
+        }
+
+        Ok(state)
+    }
+
+    /// Look up the metadata for `executor_id`.
+    ///
+    /// The in-memory cache is consulted first. On a miss the metadata is read from the
+    /// `Executors` keyspace of the state backend and stored in the cache, so later
+    /// lookups for the same executor avoid the backend round trip.
+    pub async fn get_executor_metadata(
+        &self,
+        executor_id: &str,
+    ) -> Result<ExecutorMetadata> {
+        {
+            let metadata_cache = self.executor_metadata.read();
+            if let Some(cached) = metadata_cache.get(executor_id) {
+                return Ok(cached.clone());
+            }
+        }
+
+        let value = self.state.get(Keyspace::Executors, executor_id).await?;
+
+        let decoded =
+            decode_into::<protobuf::ExecutorMetadata, ExecutorMetadata>(&value)?;
+
+        // Populate the cache; previously it was read but never filled.
+        self.executor_metadata
+            .write()
+            .insert(executor_id.to_owned(), decoded.clone());
+
+        Ok(decoded)
+    }
+
+    /// Persist `metadata` to the `Executors` keyspace and refresh the in-memory cache,
+    /// so that `get_executor_metadata` never serves a value older than the last save.
+    pub async fn save_executor_metadata(&self, metadata: ExecutorMetadata) -> Result<()> {
+        let executor_id = metadata.id.clone();
+        let proto: protobuf::ExecutorMetadata = metadata.clone().into();
+        let value = encode_protobuf(&proto)?;
+
+        self.state
+            .put(Keyspace::Executors, executor_id.clone(), value)
+            .await?;
+
+        // Only cache once the backend write has succeeded.
+        self.executor_metadata.write().insert(executor_id, metadata);
+
+        Ok(())
+    }
+
+    /// Register the executor with the scheduler, persisting its metadata, an initial
+    /// heartbeat and its task-slot data.
+    ///
+    /// The scheduler's connectivity to the executor is checked first; registration
+    /// fails if that check fails.
+    ///
+    /// If `reserve` is true, every available task slot is reserved and returned so it
+    /// can be dispatched for scheduling, and the executor is stored with zero available
+    /// slots. If `reserve` is false, the executor data is saved as is and no
+    /// reservations are returned.
+    ///
+    /// In general, `reserve` should be true if the scheduler is using push-based
+    /// scheduling and false if it is using pull-based scheduling.
+    pub async fn register_executor(
+        &self,
+        metadata: ExecutorMetadata,
+        specification: ExecutorData,
+        reserve: bool,
+    ) -> Result<Vec<ExecutorReservation>> {
+        self.test_scheduler_connectivity(&metadata).await?;
+
+        let executor_id = metadata.id.clone();
+
+        let current_ts = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .map_err(|e| {
+                BallistaError::Internal(format!(
+                    "Error getting current timestamp: {:?}",
+                    e
+                ))
+            })?
+            .as_secs();
+
+        //TODO this should be in a transaction
+        // Now that we know we can connect, save the metadata and slots
+        self.save_executor_metadata(metadata).await?;
+        self.save_executor_heartbeat(protobuf::ExecutorHeartbeat {
+            executor_id: executor_id.clone(),
+            timestamp: current_ts,
+            state: None,
+        })
+        .await?;
+
+        let mut specification = specification;
+        let reservations: Vec<ExecutorReservation> = if reserve {
+            // Hand every free slot out as a reservation and record none as available.
+            let slot_count = specification.available_task_slots as usize;
+            specification.available_task_slots = 0;
+            (0..slot_count)
+                .map(|_| ExecutorReservation::new_free(executor_id.clone()))
+                .collect()
+        } else {
+            vec![]
+        };
+
+        let proto: protobuf::ExecutorData = specification.into();
+        let value = encode_protobuf(&proto)?;
+        self.state.put(Keyspace::Slots, executor_id, value).await?;
+        Ok(reservations)
+    }
+
+    /// Check that the scheduler can open a gRPC connection to the executor described
+    /// by `metadata`, returning an error if it cannot.
+    #[cfg(not(test))]
+    async fn test_scheduler_connectivity(
+        &self,
+        metadata: &ExecutorMetadata,
+    ) -> Result<()> {
+        let executor_url = format!("http://{}:{}", metadata.host, metadata.grpc_port);
+        debug!("Connecting to executor {:?}", executor_url);
+        match protobuf::executor_grpc_client::ExecutorGrpcClient::connect(executor_url)
+            .await
+        {
+            Ok(_) => Ok(()),
+            Err(e) => Err(BallistaError::Internal(format!(
+                "Failed to register executor at {}:{}, could not connect: {:?}",
+                metadata.host, metadata.grpc_port, e
+            ))),
+        }
+    }
+
+    /// Test builds skip the network check and always succeed.
+    #[cfg(test)]
+    async fn test_scheduler_connectivity(
+        &self,
+        _metadata: &ExecutorMetadata,
+    ) -> Result<()> {
+        Ok(())
+    }
+
+    /// Persist `heartbeat` to the `Heartbeats` keyspace, then record it in the
+    /// in-memory heartbeat cache under its executor ID.
+    ///
+    /// If the backend write fails, the error is returned and the cache is not updated.
+    pub(crate) async fn save_executor_heartbeat(
+        &self,
+        heartbeat: protobuf::ExecutorHeartbeat,
+    ) -> Result<()> {
+        let executor_id = heartbeat.executor_id.clone();
+        let value = encode_protobuf(&heartbeat)?;
+        self.state
+            .put(Keyspace::Heartbeats, executor_id, value)
+            .await?;
+
+        // Reached only after the persistent write succeeded.
         let mut executors_heartbeat = self.executors_heartbeat.write();
         executors_heartbeat.insert(heartbeat.executor_id.clone(), heartbeat);
+
+        Ok(())
     }
 
-    pub(crate) fn get_executors_heartbeat(&self) -> Vec<ExecutorHeartbeat> {
-        let executors_heartbeat = self.executors_heartbeat.read();
-        executors_heartbeat
-            .iter()
-            .map(|(_exec, heartbeat)| heartbeat.clone())
-            .collect()
+    /// Initialize the set of executor heartbeats from storage.
+    ///
+    /// Every entry in the `Heartbeats` keyspace is decoded and inserted into the
+    /// in-memory heartbeat cache, keyed by executor ID. All entries are decoded before
+    /// the cache is locked, so a decoding error leaves the cache untouched and the
+    /// write lock is held only for the inserts.
+    pub(crate) async fn init_executor_heartbeats(&self) -> Result<()> {
+        let heartbeats = self.state.scan(Keyspace::Heartbeats, None).await?;
+
+        let mut decoded: Vec<protobuf::ExecutorHeartbeat> = Vec::new();
+        for (_, value) in heartbeats {
+            decoded.push(decode_protobuf(&value)?);
+        }
+
+        let mut cache = self.executors_heartbeat.write();
+        for heartbeat in decoded {
+            cache.insert(heartbeat.executor_id.clone(), heartbeat);
+        }
+        Ok(())
     }
 
-    /// last_seen_ts_threshold is in seconds
+    /// Retrieve the set of all executor IDs where the executor has been observed in the last
+    /// `last_seen_ts_threshold` seconds.
     pub(crate) fn get_alive_executors(
         &self,
         last_seen_ts_threshold: u64,
@@ -75,71 +386,227 @@ impl ExecutorManager {
             .unwrap_or_else(|| Duration::from_secs(0));
         self.get_alive_executors(last_seen_threshold.as_secs())
     }
+}
 
-    pub(crate) fn save_executor_data(&self, executor_data: ExecutorData) {
-        let mut executors_data = self.executors_data.write();
-        executors_data.insert(executor_data.executor_id.clone(), executor_data);
+/// Rather than doing a scan across persistent state to find alive executors every time
+/// we need to find the set of alive executors, we start a watch on the `Heartbeats` keyspace
+/// and maintain an in-memory copy of the executor heartbeats.
+struct ExecutorHeartbeatListener {
+    /// Backend whose `Heartbeats` keyspace is watched.
+    state: Arc<dyn StateBackendClient>,
+    /// Shared cache of the latest heartbeat per executor ID.
+    executors_heartbeat: Arc<RwLock<HashMap<String, protobuf::ExecutorHeartbeat>>>,
+}
+
+impl ExecutorHeartbeatListener {
+    /// Create a listener over `state` that writes into the shared `executors_heartbeat` map.
+    pub fn new(
+        state: Arc<dyn StateBackendClient>,
+        executors_heartbeat: Arc<RwLock<HashMap<String, protobuf::ExecutorHeartbeat>>>,
+    ) -> Self {
+        Self {
+            state,
+            executors_heartbeat,
+        }
     }
 
-    pub(crate) fn update_executor_data(&self, executor_data_change: &ExecutorDataChange) {
-        let mut executors_data = self.executors_data.write();
-        if let Some(executor_data) =
-            executors_data.get_mut(&executor_data_change.executor_id)
-        {
-            let available_task_slots = executor_data.available_task_slots as i32
-                + executor_data_change.task_slots;
-            if available_task_slots < 0 {
-                error!(
-                    "Available task slots {} for executor {} is less than 0",
-                    available_task_slots, executor_data.executor_id
-                );
-            } else {
-                info!(
-                    "available_task_slots for executor {} becomes {}",
-                    executor_data.executor_id, available_task_slots
-                );
-                executor_data.available_task_slots = available_task_slots as u32;
+    /// Spawn an async task which watches the `Heartbeats` keyspace and inserts
+    /// every heartbeat it observes into the shared `executors_heartbeat` cache.
+    ///
+    /// Only `WatchEvent::Put` events are applied; other events and values that fail
+    /// to decode as `ExecutorHeartbeat` are ignored. The spawned task is detached and
+    /// runs until the watch stream ends. Returns an error only if the watch cannot be
+    /// established.
+    pub async fn start(&self) -> Result<()> {
+        let mut watch = self
+            .state
+            .watch(Keyspace::Heartbeats, "".to_owned())
+            .await?;
+        let heartbeats = self.executors_heartbeat.clone();
+        tokio::task::spawn(async move {
+            while let Some(event) = watch.next().await {
+                if let WatchEvent::Put(_, value) = event {
+                    // NOTE(review): values that fail to decode are dropped without logging.
+                    if let Ok(data) =
+                        decode_protobuf::<protobuf::ExecutorHeartbeat>(&value)
+                    {
+                        let executor_id = data.executor_id.clone();
+                        let mut heartbeats = heartbeats.write();
+
+                        heartbeats.insert(executor_id, data);
+                    }
+                }
             }
-        } else {
-            warn!(
-                "Could not find executor data for {}",
-                executor_data_change.executor_id
-            );
+        });
+
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::state::backend::standalone::StandaloneClient;
+    use crate::state::executor_manager::{ExecutorManager, ExecutorReservation};
+    use ballista_core::error::Result;
+    use ballista_core::serde::scheduler::{
+        ExecutorData, ExecutorMetadata, ExecutorSpecification,
+    };
+    use std::sync::Arc;
+
+    #[tokio::test]
+    async fn test_reserve_and_cancel() -> Result<()> {
+        let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
+
+        let executor_manager = ExecutorManager::new(state_storage);
+
+        let executors = test_executors(10, 4);
+
+        for (executor_metadata, executor_data) in executors {
+            executor_manager
+                .register_executor(executor_metadata, executor_data, false)
+                .await?;
         }
+
+        // Reserve all the slots
+        let reservations = executor_manager.reserve_slots(40).await?;
+
+        assert_eq!(reservations.len(), 40);
+
+        // Now cancel them
+        executor_manager.cancel_reservations(reservations).await?;
+
+        // Now reserve again
+        let reservations = executor_manager.reserve_slots(40).await?;
+
+        assert_eq!(reservations.len(), 40);
+
+        Ok(())
     }
 
-    pub(crate) fn get_executor_data(&self, executor_id: &str) -> Option<ExecutorData> {
-        let executors_data = self.executors_data.read();
-        executors_data.get(executor_id).cloned()
+    #[tokio::test]
+    async fn test_reserve_partial() -> Result<()> {
+        let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
+
+        let executor_manager = ExecutorManager::new(state_storage);
+
+        let executors = test_executors(10, 4);
+
+        for (executor_metadata, executor_data) in executors {
+            executor_manager
+                .register_executor(executor_metadata, executor_data, false)
+                .await?;
+        }
+
+        // Reserve 30 of the 40 available slots
+        let reservations = executor_manager.reserve_slots(30).await?;
+
+        assert_eq!(reservations.len(), 30);
+
+        // Try to reserve 30 more. Only ten are available though so we should only get 10
+        let more_reservations = executor_manager.reserve_slots(30).await?;
+
+        assert_eq!(more_reservations.len(), 10);
+
+        // Now cancel them
+        executor_manager.cancel_reservations(reservations).await?;
+        executor_manager
+            .cancel_reservations(more_reservations)
+            .await?;
+
+        // Now reserve again
+        let reservations = executor_manager.reserve_slots(40).await?;
+
+        assert_eq!(reservations.len(), 40);
+
+        let more_reservations = executor_manager.reserve_slots(30).await?;
+
+        assert_eq!(more_reservations.len(), 0);
+
+        Ok(())
     }
 
-    /// There are two checks:
-    /// 1. firstly alive
-    /// 2. secondly available task slots > 0
-    #[cfg(not(test))]
-    #[allow(dead_code)]
-    pub(crate) fn get_available_executors_data(&self) -> Vec<ExecutorData> {
-        let mut res = {
-            let alive_executors = self.get_alive_executors_within_one_minute();
-            let executors_data = self.executors_data.read();
-            executors_data
-                .iter()
-                .filter_map(|(exec, data)| {
-                    (data.available_task_slots > 0 && alive_executors.contains(exec))
-                        .then(|| data.clone())
-                })
-                .collect::<Vec<ExecutorData>>()
-        };
-        res.sort_by(|a, b| Ord::cmp(&b.available_task_slots, &a.available_task_slots));
-        res
+    #[tokio::test]
+    async fn test_reserve_concurrent() -> Result<()> {
+        let (sender, mut receiver) =
+            tokio::sync::mpsc::channel::<Result<Vec<ExecutorReservation>>>(1000);
+
+        let executors = test_executors(10, 4);
+
+        let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
+
+        let executor_manager = ExecutorManager::new(state_storage);
+
+        for (executor_metadata, executor_data) in executors {
+            executor_manager
+                .register_executor(executor_metadata, executor_data, false)
+                .await?;
+        }
+
+        {
+            // Move the original sender into this scope so it is dropped when the
+            // scope ends; the receive loop below then terminates once every
+            // spawned task has sent its result and dropped its clone.
+            let sender = sender;
+            // Spawn 20 async tasks to each try and reserve all 40 slots
+            for _ in 0..20 {
+                let executor_manager = executor_manager.clone();
+                let sender = sender.clone();
+                tokio::task::spawn(async move {
+                    let reservations = executor_manager.reserve_slots(40).await;
+                    sender.send(reservations).await.unwrap();
+                });
+            }
+        }
+
+        let mut total_reservations: Vec<ExecutorReservation> = vec![];
+
+        // Propagate any reservation error instead of silently ending the loop early.
+        while let Some(reservations) = receiver.recv().await {
+            total_reservations.extend(reservations?);
+        }
+
+        // Together the tasks must claim exactly the 40 available slots, never more
+        assert_eq!(total_reservations.len(), 40);
+
+        Ok(())
     }
 
-    #[cfg(test)]
-    #[allow(dead_code)]
-    pub(crate) fn get_available_executors_data(&self) -> Vec<ExecutorData> {
-        let mut res: Vec<ExecutorData> =
-            self.executors_data.read().values().cloned().collect();
-        res.sort_by(|a, b| Ord::cmp(&b.available_task_slots, &a.available_task_slots));
-        res
+    #[tokio::test]
+    async fn test_register_reserve() -> Result<()> {
+        let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
+
+        let executor_manager = ExecutorManager::new(state_storage);
+
+        let executors = test_executors(10, 4);
+
+        for (executor_metadata, executor_data) in executors {
+            let reservations = executor_manager
+                .register_executor(executor_metadata, executor_data, true)
+                .await?;
+
+            assert_eq!(reservations.len(), 4);
+        }
+
+        // All slots should be reserved
+        let reservations = executor_manager.reserve_slots(1).await?;
+
+        assert_eq!(reservations.len(), 0);
+
+        Ok(())
+    }
+
+    /// Build `total_executors` test executors with IDs `executor-0`, `executor-1`, ...,
+    /// each with `slots_per_executor` total and available task slots.
+    fn test_executors(
+        total_executors: usize,
+        slots_per_executor: u32,
+    ) -> Vec<(ExecutorMetadata, ExecutorData)> {
+        let mut result: Vec<(ExecutorMetadata, ExecutorData)> = vec![];
+
+        for i in 0..total_executors {
+            result.push((
+                ExecutorMetadata {
+                    id: format!("executor-{}", i),
+                    host: format!("host-{}", i),
+                    port: 8080,
+                    grpc_port: 9090,
+                    specification: ExecutorSpecification {
+                        task_slots: slots_per_executor,
+                    },
+                },
+                ExecutorData {
+                    executor_id: format!("executor-{}", i),
+                    total_task_slots: slots_per_executor,
+                    available_task_slots: slots_per_executor,
+                },
+            ));
+        }
+
+        result
     }
 }
diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs
index de780fab..1083665f 100644
--- a/ballista/rust/scheduler/src/state/mod.rs
+++ b/ballista/rust/scheduler/src/state/mod.rs
@@ -15,242 +15,107 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::scheduler_server::{SessionBuilder, SessionContextRegistry};
-use crate::state::backend::StateBackendClient;
-use crate::state::executor_manager::ExecutorManager;
-use crate::state::persistent_state::PersistentSchedulerState;
-use crate::state::stage_manager::StageManager;
-use ballista_core::error::Result;
-use ballista_core::serde::protobuf::{ExecutorHeartbeat, JobStatus, KeyValuePair};
-use ballista_core::serde::scheduler::ExecutorMetadata;
+use std::any::type_name;
+use std::future::Future;
+
+use std::sync::Arc;
+
+use prost::Message;
+
+use ballista_core::error::{BallistaError, Result};
+
+use crate::scheduler_server::SessionBuilder;
+
 use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
-use datafusion::physical_plan::ExecutionPlan;
 use datafusion_proto::logical_plan::AsLogicalPlan;
-use std::collections::HashMap;
-use std::sync::Arc;
-use std::time::{Duration, SystemTime, UNIX_EPOCH};
+
+use crate::state::backend::{Lock, StateBackendClient};
+
+use crate::state::executor_manager::ExecutorManager;
+use crate::state::session_manager::SessionManager;
+use crate::state::task_manager::TaskManager;
 
 pub mod backend;
-mod executor_manager;
-mod persistent_state;
-mod stage_manager;
-pub mod task_scheduler;
+pub mod execution_graph;
+pub mod executor_manager;
+pub mod session_manager;
+pub mod session_registry;
+mod task_manager;
+
+/// Decode a protobuf message of type `T` from `bytes`.
+///
+/// A prost decode error is wrapped in `BallistaError::Internal`, with a message
+/// naming the Rust type that failed to deserialize.
+pub fn decode_protobuf<T: Message + Default>(bytes: &[u8]) -> Result<T> {
+    match T::decode(bytes) {
+        Ok(message) => Ok(message),
+        Err(err) => Err(BallistaError::Internal(format!(
+            "Could not deserialize {}: {}",
+            type_name::<T>(),
+            err
+        ))),
+    }
+}
+
+/// Decode a protobuf message of type `T` from `bytes` and convert it into `U`.
+///
+/// Delegates to `decode_protobuf`, so a decode failure produces the same
+/// `BallistaError::Internal("Could not deserialize <T>: ...")` error.
+pub fn decode_into<T: Message + Default, U: From<T>>(bytes: &[u8]) -> Result<U> {
+    decode_protobuf::<T>(bytes).map(U::from)
+}
+
+/// Encode `msg` into a freshly allocated byte vector.
+///
+/// The buffer is pre-sized with `encoded_len`. An encode error is wrapped in
+/// `BallistaError::Internal` naming the message type (with a growable `Vec`
+/// buffer prost should not run out of capacity in practice).
+///
+/// Only `Message` is required: encoding never constructs a `T`, so no `Default`
+/// bound is needed (callers that satisfied the old bound still compile).
+pub fn encode_protobuf<T: Message>(msg: &T) -> Result<Vec<u8>> {
+    let mut value: Vec<u8> = Vec::with_capacity(msg.encoded_len());
+    msg.encode(&mut value).map_err(|e| {
+        BallistaError::Internal(format!(
+            "Could not serialize {}: {}",
+            type_name::<T>(),
+            e
+        ))
+    })?;
+    Ok(value)
+}
 
 #[derive(Clone)]
+/// Scheduler state bundling the executor, task and session managers.
+/// `SchedulerState::new` backs all three with the same `StateBackendClient`.
 pub(super) struct SchedulerState<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
 {
-    persistent_state: PersistentSchedulerState<T, U>,
     pub executor_manager: ExecutorManager,
-    pub stage_manager: StageManager,
+    pub task_manager: TaskManager<T, U>,
+    pub session_manager: SessionManager,
+    // Codec retained for the generic parameters; not read within this module.
+    _codec: BallistaCodec<T, U>,
 }
 
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, U> {
+    /// Build scheduler state on top of `config_client`.
+    ///
+    /// The executor, task and session managers all share `config_client` as their
+    /// state backend. `_namespace` is not used by this constructor (NOTE(review):
+    /// presumably kept for call-site compatibility — confirm).
     pub fn new(
         config_client: Arc<dyn StateBackendClient>,
-        namespace: String,
+        _namespace: String,
         session_builder: SessionBuilder,
         codec: BallistaCodec<T, U>,
     ) -> Self {
         Self {
-            persistent_state: PersistentSchedulerState::new(
-                config_client,
-                namespace,
+            executor_manager: ExecutorManager::new(config_client.clone()),
+            task_manager: TaskManager::new(
+                config_client.clone(),
                 session_builder,
-                codec,
+                codec.clone(),
             ),
-            executor_manager: ExecutorManager::new(),
-            stage_manager: StageManager::new(),
+            session_manager: SessionManager::new(config_client, session_builder),
+            _codec: codec,
         }
     }
 
+    /// Initialize scheduler state by delegating to `ExecutorManager::init`; no other
+    /// manager is initialized here.
     pub async fn init(&self) -> Result<()> {
-        self.persistent_state.init().await?;
-
-        Ok(())
-    }
-
-    pub fn get_codec(&self) -> &BallistaCodec<T, U> {
-        &self.persistent_state.codec
-    }
-
-    pub async fn get_executors_metadata(
-        &self,
-    ) -> Result<Vec<(ExecutorMetadata, Duration)>> {
-        let mut result = vec![];
-
-        let executors_heartbeat = self
-            .executor_manager
-            .get_executors_heartbeat()
-            .into_iter()
-            .map(|heartbeat| (heartbeat.executor_id.clone(), heartbeat))
-            .collect::<HashMap<String, ExecutorHeartbeat>>();
-
-        let executors_metadata = self.persistent_state.get_executors_metadata();
-
-        let now_epoch_ts = SystemTime::now()
-            .duration_since(UNIX_EPOCH)
-            .expect("Time went backwards");
-
-        for meta in executors_metadata.into_iter() {
-            // If there's no heartbeat info for an executor, regard its heartbeat timestamp as 0
-            // so that it will always be excluded when requesting alive executors
-            let ts = executors_heartbeat
-                .get(&meta.id)
-                .map(|heartbeat| Duration::from_secs(heartbeat.timestamp))
-                .unwrap_or_else(|| Duration::from_secs(0));
-            let time_since_last_seen = now_epoch_ts
-                .checked_sub(ts)
-                .unwrap_or_else(|| Duration::from_secs(0));
-            result.push((meta, time_since_last_seen));
-        }
-        Ok(result)
-    }
-
-    pub fn get_executor_metadata(&self, executor_id: &str) -> Option<ExecutorMetadata> {
-        self.persistent_state.get_executor_metadata(executor_id)
-    }
-
-    pub async fn save_executor_metadata(
-        &self,
-        executor_meta: ExecutorMetadata,
-    ) -> Result<()> {
-        self.persistent_state
-            .save_executor_metadata(executor_meta)
-            .await
-    }
-
-    pub async fn save_job_session(
-        &self,
-        job_id: &str,
-        session_id: &str,
-        configs: Vec<KeyValuePair>,
-    ) -> Result<()> {
-        self.persistent_state
-            .save_job_session(job_id, session_id, configs)
-            .await
-    }
-
-    pub fn get_session_from_job(&self, job_id: &str) -> Option<String> {
-        self.persistent_state.get_session_from_job(job_id)
-    }
-
-    pub async fn save_job_metadata(
-        &self,
-        job_id: &str,
-        status: &JobStatus,
-    ) -> Result<()> {
-        self.persistent_state
-            .save_job_metadata(job_id, status)
-            .await
-    }
-
-    pub fn get_job_metadata(&self, job_id: &str) -> Option<JobStatus> {
-        self.persistent_state.get_job_metadata(job_id)
-    }
-
-    pub async fn save_stage_plan(
-        &self,
-        job_id: &str,
-        stage_id: usize,
-        plan: Arc<dyn ExecutionPlan>,
-    ) -> Result<()> {
-        self.persistent_state
-            .save_stage_plan(job_id, stage_id, plan)
-            .await
+        self.executor_manager.init().await
     }
+}
 
-    pub fn get_stage_plan(
-        &self,
-        job_id: &str,
-        stage_id: usize,
-    ) -> Option<Arc<dyn ExecutionPlan>> {
-        self.persistent_state.get_stage_plan(job_id, stage_id)
-    }
+/// Await `op`, then release `lock` and return `op`'s output.
+///
+/// NOTE(review): `lock` is presumably already held by the caller. If `op` panics,
+/// or the future returned here is dropped before completing, `unlock` is not
+/// called by this function.
+pub async fn with_lock<Out, F: Future<Output = Out>>(lock: Box<dyn Lock>, op: F) -> Out {
+    let mut lock = lock;
+    let result = op.await;
+    lock.unlock().await;
 
-    pub fn session_registry(&self) -> Arc<SessionContextRegistry> {
-        self.persistent_state.session_registry()
-    }
+    result
 }
 
 #[cfg(all(test, feature = "sled"))]
-mod test {
-    use std::sync::Arc;
-
-    use ballista_core::error::BallistaError;
-    use ballista_core::serde::protobuf::{
-        job_status, JobStatus, PhysicalPlanNode, QueuedJob,
-    };
-    use ballista_core::serde::scheduler::{ExecutorMetadata, ExecutorSpecification};
-    use ballista_core::serde::BallistaCodec;
-    use datafusion::execution::context::default_session_builder;
-    use datafusion_proto::protobuf::LogicalPlanNode;
-
-    use super::{backend::standalone::StandaloneClient, SchedulerState};
-
-    #[tokio::test]
-    async fn executor_metadata() -> Result<(), BallistaError> {
-        let state: SchedulerState<LogicalPlanNode, PhysicalPlanNode> =
-            SchedulerState::new(
-                Arc::new(StandaloneClient::try_new_temporary()?),
-                "test".to_string(),
-                default_session_builder,
-                BallistaCodec::default(),
-            );
-        let meta = ExecutorMetadata {
-            id: "123".to_owned(),
-            host: "localhost".to_owned(),
-            port: 123,
-            grpc_port: 124,
-            specification: ExecutorSpecification { task_slots: 2 },
-        };
-        state.save_executor_metadata(meta.clone()).await?;
-        let result: Vec<_> = state
-            .get_executors_metadata()
-            .await?
-            .into_iter()
-            .map(|(meta, _)| meta)
-            .collect();
-        assert_eq!(vec![meta], result);
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn job_metadata() -> Result<(), BallistaError> {
-        let state: SchedulerState<LogicalPlanNode, PhysicalPlanNode> =
-            SchedulerState::new(
-                Arc::new(StandaloneClient::try_new_temporary()?),
-                "test".to_string(),
-                default_session_builder,
-                BallistaCodec::default(),
-            );
-        let meta = JobStatus {
-            status: Some(job_status::Status::Queued(QueuedJob {})),
-        };
-        state.save_job_metadata("job", &meta).await?;
-        let result = state.get_job_metadata("job").unwrap();
-        assert!(result.status.is_some());
-        match result.status.unwrap() {
-            job_status::Status::Queued(_) => (),
-            _ => panic!("Unexpected status"),
-        }
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn job_metadata_non_existant() -> Result<(), BallistaError> {
-        let state: SchedulerState<LogicalPlanNode, PhysicalPlanNode> =
-            SchedulerState::new(
-                Arc::new(StandaloneClient::try_new_temporary()?),
-                "test".to_string(),
-                default_session_builder,
-                BallistaCodec::default(),
-            );
-        let meta = JobStatus {
-            status: Some(job_status::Status::Queued(QueuedJob {})),
-        };
-        state.save_job_metadata("job", &meta).await?;
-        let result = state.get_job_metadata("job2");
-        assert!(result.is_none());
-        Ok(())
-    }
-}
+// NOTE(review): the tests previously in this module exercised `SchedulerState`
+// methods removed by this change; the module is now an empty placeholder.
+mod test {}
diff --git a/ballista/rust/scheduler/src/state/persistent_state.rs b/ballista/rust/scheduler/src/state/persistent_state.rs
deleted file mode 100644
index e9692482..00000000
--- a/ballista/rust/scheduler/src/state/persistent_state.rs
+++ /dev/null
@@ -1,525 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use crate::scheduler_server::{
-    create_datafusion_context, SessionBuilder, SessionContextRegistry,
-};
-use crate::state::backend::StateBackendClient;
-use crate::state::stage_manager::StageKey;
-use ballista_core::config::BallistaConfig;
-use ballista_core::error::{BallistaError, Result};
-use ballista_core::serde::protobuf::{JobSessionConfig, JobStatus, KeyValuePair};
-use ballista_core::serde::scheduler::ExecutorMetadata;
-use ballista_core::serde::{protobuf, AsExecutionPlan, BallistaCodec};
-use datafusion::physical_plan::ExecutionPlan;
-use datafusion_proto::logical_plan::AsLogicalPlan;
-use log::{debug, error};
-use parking_lot::RwLock;
-use prost::Message;
-use std::any::type_name;
-use std::collections::HashMap;
-use std::ops::Deref;
-use std::sync::Arc;
-
-#[derive(Clone)]
-pub(crate) struct PersistentSchedulerState<
-    T: 'static + AsLogicalPlan,
-    U: 'static + AsExecutionPlan,
-> {
-    // for db
-    config_client: Arc<dyn StateBackendClient>,
-    namespace: String,
-    pub(crate) codec: BallistaCodec<T, U>,
-
-    // for in-memory cache
-    executors_metadata: Arc<RwLock<HashMap<String, ExecutorMetadata>>>,
-
-    // TODO add remove logic
-    jobs: Arc<RwLock<HashMap<String, JobStatus>>>,
-    stages: Arc<RwLock<HashMap<StageKey, Arc<dyn ExecutionPlan>>>>,
-    job2session: Arc<RwLock<HashMap<String, String>>>,
-
-    /// DataFusion session contexts that are registered within the Scheduler
-    session_context_registry: Arc<SessionContextRegistry>,
-
-    session_builder: SessionBuilder,
-}
-
-impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
-    PersistentSchedulerState<T, U>
-{
-    pub(crate) fn new(
-        config_client: Arc<dyn StateBackendClient>,
-        namespace: String,
-        session_builder: SessionBuilder,
-        codec: BallistaCodec<T, U>,
-    ) -> Self {
-        Self {
-            config_client,
-            namespace,
-            codec,
-            executors_metadata: Arc::new(RwLock::new(HashMap::new())),
-            jobs: Arc::new(RwLock::new(HashMap::new())),
-            stages: Arc::new(RwLock::new(HashMap::new())),
-            job2session: Arc::new(RwLock::new(HashMap::new())),
-            session_context_registry: Arc::new(SessionContextRegistry::default()),
-            session_builder,
-        }
-    }
-
-    /// Load the state stored in storage into memory
-    pub(crate) async fn init(&self) -> Result<()> {
-        self.init_executors_metadata_from_storage().await?;
-        self.init_jobs_from_storage().await?;
-        self.init_stages_from_storage().await?;
-
-        Ok(())
-    }
-
-    async fn init_executors_metadata_from_storage(&self) -> Result<()> {
-        let entries = self
-            .config_client
-            .get_from_prefix(&get_executors_metadata_prefix(&self.namespace))
-            .await?;
-
-        let mut executors_metadata = self.executors_metadata.write();
-        for (_key, entry) in entries {
-            let meta: protobuf::ExecutorMetadata = decode_protobuf(&entry)?;
-            executors_metadata.insert(meta.id.clone(), meta.into());
-        }
-
-        Ok(())
-    }
-
-    async fn init_jobs_from_storage(&self) -> Result<()> {
-        let entries = self
-            .config_client
-            .get_from_prefix(&get_job_prefix(&self.namespace))
-            .await?;
-
-        let mut jobs = self.jobs.write();
-        for (key, entry) in entries {
-            let job: JobStatus = decode_protobuf(&entry)?;
-            let job_id = extract_job_id_from_job_key(&key)
-                .map(|job_id| job_id.to_string())
-                .unwrap();
-            jobs.insert(job_id, job);
-        }
-
-        Ok(())
-    }
-
-    async fn init_stages_from_storage(&self) -> Result<()> {
-        let entries = self
-            .config_client
-            .get_from_prefix(&get_stage_prefix(&self.namespace))
-            .await?;
-
-        let mut tmp_stages: HashMap<StageKey, Arc<dyn ExecutionPlan>> = HashMap::new();
-        {
-            for (key, entry) in entries {
-                let (job_id, stage_id) = extract_stage_id_from_stage_key(&key).unwrap();
-                let job_session = self
-                    .config_client
-                    .get(&get_job_config_key(&self.namespace, &job_id))
-                    .await?;
-                let job_session: JobSessionConfig = decode_protobuf(&job_session)?;
-
-                // Rebuild SessionContext from serialized settings
-                let mut config_builder = BallistaConfig::builder();
-                for kv_pair in &job_session.configs {
-                    config_builder = config_builder.set(&kv_pair.key, &kv_pair.value);
-                }
-                let config = config_builder.build().map_err(|e| {
-                    let msg = format!("Could not parse configs: {}", e);
-                    error!("{}", msg);
-                    BallistaError::Internal(format!(
-                        "Error building configs for job ID {}",
-                        job_id
-                    ))
-                })?;
-
-                let session_ctx =
-                    create_datafusion_context(&config, self.session_builder);
-                self.session_registry()
-                    .register_session(session_ctx.clone())
-                    .await;
-
-                let value = U::try_decode(&entry)?;
-                let runtime = session_ctx.runtime_env();
-                let plan = value.try_into_physical_plan(
-                    session_ctx.deref(),
-                    runtime.deref(),
-                    self.codec.physical_extension_codec(),
-                )?;
-
-                let mut job2_sess = self.job2session.write();
-                job2_sess.insert(job_id.clone(), job_session.session_id);
-
-                tmp_stages.insert((job_id, stage_id), plan);
-            }
-        }
-        let mut stages = self.stages.write();
-        for tmp_stage in tmp_stages {
-            stages.insert(tmp_stage.0, tmp_stage.1);
-        }
-        Ok(())
-    }
-
-    pub(crate) async fn save_executor_metadata(
-        &self,
-        executor_meta: ExecutorMetadata,
-    ) -> Result<()> {
-        {
-            // Save in db
-            let key = get_executor_metadata_key(&self.namespace, &executor_meta.id);
-            let value = {
-                let executor_meta: protobuf::ExecutorMetadata =
-                    executor_meta.clone().into();
-                encode_protobuf(&executor_meta)?
-            };
-            self.synchronize_save(key, value).await?;
-        }
-
-        {
-            // Save in memory
-            let mut executors_metadata = self.executors_metadata.write();
-            executors_metadata.insert(executor_meta.id.clone(), executor_meta);
-        }
-
-        Ok(())
-    }
-
-    pub(crate) fn get_executor_metadata(
-        &self,
-        executor_id: &str,
-    ) -> Option<ExecutorMetadata> {
-        let executors_metadata = self.executors_metadata.read();
-        executors_metadata.get(executor_id).cloned()
-    }
-
-    pub(crate) fn get_executors_metadata(&self) -> Vec<ExecutorMetadata> {
-        let executors_metadata = self.executors_metadata.read();
-        executors_metadata.values().cloned().collect()
-    }
-
-    pub(crate) async fn save_job_session(
-        &self,
-        job_id: &str,
-        session_id: &str,
-        configs: Vec<KeyValuePair>,
-    ) -> Result<()> {
-        let key = get_job_config_key(&self.namespace, job_id);
-        let value = encode_protobuf(&protobuf::JobSessionConfig {
-            session_id: session_id.to_string(),
-            configs,
-        })?;
-
-        self.synchronize_save(key, value).await?;
-
-        let mut job2_sess = self.job2session.write();
-        job2_sess.insert(job_id.to_string(), session_id.to_string());
-
-        Ok(())
-    }
-
-    pub(crate) fn get_session_from_job(&self, job_id: &str) -> Option<String> {
-        let job_session = self.job2session.read();
-        job_session.get(job_id).cloned()
-    }
-
-    pub(crate) async fn save_job_metadata(
-        &self,
-        job_id: &str,
-        status: &JobStatus,
-    ) -> Result<()> {
-        debug!("Saving job metadata: {:?}", status);
-        {
-            // Save in db
-            let key = get_job_key(&self.namespace, job_id);
-            let value = encode_protobuf(status)?;
-            self.synchronize_save(key, value).await?;
-        }
-
-        {
-            // Save in memory
-            let mut jobs = self.jobs.write();
-            jobs.insert(job_id.to_string(), status.clone());
-        }
-
-        Ok(())
-    }
-
-    pub(crate) fn get_job_metadata(&self, job_id: &str) -> Option<JobStatus> {
-        let jobs = self.jobs.read();
-        jobs.get(job_id).cloned()
-    }
-
-    pub(crate) async fn save_stage_plan(
-        &self,
-        job_id: &str,
-        stage_id: usize,
-        plan: Arc<dyn ExecutionPlan>,
-    ) -> Result<()> {
-        {
-            // Save in db
-            let key = get_stage_plan_key(&self.namespace, job_id, stage_id as u32);
-            let value = {
-                let mut buf: Vec<u8> = vec![];
-                let proto = U::try_from_physical_plan(
-                    plan.clone(),
-                    self.codec.physical_extension_codec(),
-                )?;
-                proto.try_encode(&mut buf)?;
-
-                buf
-            };
-            self.synchronize_save(key, value).await?;
-        }
-
-        {
-            // Save in memory
-            let mut stages = self.stages.write();
-            stages.insert((job_id.to_string(), stage_id as u32), plan);
-        }
-
-        Ok(())
-    }
-
-    pub(crate) fn get_stage_plan(
-        &self,
-        job_id: &str,
-        stage_id: usize,
-    ) -> Option<Arc<dyn ExecutionPlan>> {
-        let stages = self.stages.read();
-        let key = (job_id.to_string(), stage_id as u32);
-        stages.get(&key).cloned()
-    }
-
-    async fn synchronize_save(&self, key: String, value: Vec<u8>) -> Result<()> {
-        let mut lock = self.config_client.lock().await?;
-        self.config_client.put(key, value).await?;
-        lock.unlock().await;
-
-        Ok(())
-    }
-
-    pub fn session_registry(&self) -> Arc<SessionContextRegistry> {
-        self.session_context_registry.clone()
-    }
-}
-
-fn get_executors_metadata_prefix(namespace: &str) -> String {
-    format!("/ballista/{}/executor_metadata", namespace)
-}
-
-fn get_executor_metadata_key(namespace: &str, id: &str) -> String {
-    format!("{}/{}", get_executors_metadata_prefix(namespace), id)
-}
-
-fn get_job_prefix(namespace: &str) -> String {
-    format!("/ballista/{}/jobs", namespace)
-}
-
-fn get_job_key(namespace: &str, id: &str) -> String {
-    format!("{}/{}", get_job_prefix(namespace), id)
-}
-
-fn get_job_config_key(namespace: &str, id: &str) -> String {
-    format!("config/{}/{}", get_job_prefix(namespace), id)
-}
-
-fn get_stage_prefix(namespace: &str) -> String {
-    format!("/ballista/{}/stages", namespace,)
-}
-
-fn get_stage_plan_key(namespace: &str, job_id: &str, stage_id: u32) -> String {
-    format!("{}/{}/{}", get_stage_prefix(namespace), job_id, stage_id,)
-}
-fn extract_job_id_from_job_key(job_key: &str) -> Result<&str> {
-    job_key.split('/').nth(2).ok_or_else(|| {
-        BallistaError::Internal(format!("Unexpected task key: {}", job_key))
-    })
-}
-
-fn extract_stage_id_from_stage_key(stage_key: &str) -> Result<StageKey> {
-    let splits: Vec<&str> = stage_key.split('/').collect();
-    if splits.len() > 4 {
-        Ok((
-            splits[splits.len() - 2].to_string(),
-            splits[splits.len() - 1].parse::<u32>().map_err(|e| {
-                BallistaError::Internal(format!(
-                    "Invalid stage ID in stage key: {}, {:?}",
-                    stage_key, e
-                ))
-            })?,
-        ))
-    } else {
-        Err(BallistaError::Internal(format!(
-            "Unexpected stage key: {}",
-            stage_key
-        )))
-    }
-}
-
-fn decode_protobuf<T: Message + Default>(bytes: &[u8]) -> Result<T> {
-    T::decode(bytes).map_err(|e| {
-        BallistaError::Internal(format!(
-            "Could not deserialize {}: {}",
-            type_name::<T>(),
-            e
-        ))
-    })
-}
-
-fn encode_protobuf<T: Message + Default>(msg: &T) -> Result<Vec<u8>> {
-    let mut value: Vec<u8> = Vec::with_capacity(msg.encoded_len());
-    msg.encode(&mut value).map_err(|e| {
-        BallistaError::Internal(format!(
-            "Could not serialize {}: {}",
-            type_name::<T>(),
-            e
-        ))
-    })?;
-    Ok(value)
-}
-
-#[cfg(test)]
-mod test {
-    use super::extract_stage_id_from_stage_key;
-    use crate::state::backend::standalone::StandaloneClient;
-
-    use crate::state::persistent_state::PersistentSchedulerState;
-
-    use ballista_core::serde::protobuf::job_status::Status;
-    use ballista_core::serde::protobuf::{JobStatus, PhysicalPlanNode, QueuedJob};
-    use ballista_core::serde::BallistaCodec;
-    use datafusion::execution::context::default_session_builder;
-    use datafusion::logical_plan::LogicalPlanBuilder;
-    use datafusion::prelude::SessionContext;
-    use datafusion_proto::protobuf::LogicalPlanNode;
-
-    use std::sync::Arc;
-
-    #[test]
-    fn test_extract_stage_id_from_stage_key() {
-        let (job_id, stage_id) =
-            extract_stage_id_from_stage_key("/ballista/default/stages/2Yoyba8/1")
-                .expect("extracting stage key");
-
-        assert_eq!(job_id.as_str(), "2Yoyba8");
-        assert_eq!(stage_id, 1);
-
-        let (job_id, stage_id) =
-            extract_stage_id_from_stage_key("ballista/default/stages/2Yoyba8/1")
-                .expect("extracting stage key");
-
-        assert_eq!(job_id.as_str(), "2Yoyba8");
-        assert_eq!(stage_id, 1);
-
-        let (job_id, stage_id) =
-            extract_stage_id_from_stage_key("ballista//stages/2Yoyba8/1")
-                .expect("extracting stage key");
-
-        assert_eq!(job_id.as_str(), "2Yoyba8");
-        assert_eq!(stage_id, 1);
-    }
-
-    #[tokio::test]
-    async fn test_init_from_storage() {
-        let ctx = SessionContext::new();
-
-        let plan = LogicalPlanBuilder::empty(true)
-            .build()
-            .expect("create empty logical plan");
-        let plan = ctx
-            .create_physical_plan(&plan)
-            .await
-            .expect("create physical plan");
-
-        let expected_plan = format!("{:?}", plan);
-
-        let job_id = "job-id".to_string();
-        let session_id = "session-id".to_string();
-
-        let config_client = Arc::new(
-            StandaloneClient::try_new_temporary().expect("creating config client"),
-        );
-
-        let persistent_state: PersistentSchedulerState<
-            LogicalPlanNode,
-            PhysicalPlanNode,
-        > = PersistentSchedulerState::new(
-            config_client.clone(),
-            "default".to_string(),
-            default_session_builder,
-            BallistaCodec::default(),
-        );
-
-        persistent_state
-            .save_job_session(&job_id, &session_id, vec![])
-            .await
-            .expect("saving session");
-        persistent_state
-            .save_job_metadata(
-                &job_id,
-                &JobStatus {
-                    status: Some(Status::Queued(QueuedJob {})),
-                },
-            )
-            .await
-            .expect("saving job metadata");
-        persistent_state
-            .save_stage_plan(&job_id, 1, plan)
-            .await
-            .expect("saving stage plan");
-
-        assert_eq!(
-            persistent_state
-                .get_stage_plan(&job_id, 1)
-                .map(|plan| format!("{:?}", plan)),
-            Some(expected_plan.clone())
-        );
-        assert_eq!(
-            persistent_state.get_session_from_job(&job_id),
-            Some("session-id".to_string())
-        );
-
-        let persistent_state: PersistentSchedulerState<
-            LogicalPlanNode,
-            PhysicalPlanNode,
-        > = PersistentSchedulerState::new(
-            config_client.clone(),
-            "default".to_string(),
-            default_session_builder,
-            BallistaCodec::default(),
-        );
-
-        persistent_state.init().await.expect("initializing state");
-
-        assert_eq!(
-            persistent_state
-                .get_stage_plan(&job_id, 1)
-                .map(|plan| format!("{:?}", plan)),
-            Some(expected_plan.clone())
-        );
-        assert_eq!(
-            persistent_state.get_session_from_job(&job_id),
-            Some("session-id".to_string())
-        );
-    }
-}
diff --git a/ballista/rust/scheduler/src/state/session_manager.rs b/ballista/rust/scheduler/src/state/session_manager.rs
new file mode 100644
index 00000000..0c67204d
--- /dev/null
+++ b/ballista/rust/scheduler/src/state/session_manager.rs
@@ -0,0 +1,178 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::scheduler_server::SessionBuilder;
+use crate::state::backend::{Keyspace, StateBackendClient};
+use crate::state::{decode_protobuf, encode_protobuf};
+use ballista_core::config::BallistaConfig;
+use ballista_core::error::Result;
+use ballista_core::serde::protobuf::{self, KeyValuePair};
+use datafusion::prelude::{SessionConfig, SessionContext};
+
+use std::sync::Arc;
+
+/// Creates, updates and looks up Ballista sessions.
+///
+/// Session settings are persisted in [`Keyspace::Sessions`] of the state
+/// backend, keyed by session ID, as an encoded `protobuf::SessionSettings`.
+/// Every method that returns a context builds a fresh [`SessionContext`];
+/// contexts themselves are not cached by this type.
+#[derive(Clone)]
+pub struct SessionManager {
+    /// Backend in which session settings are persisted.
+    state: Arc<dyn StateBackendClient>,
+    /// Builds the DataFusion session state of every context returned.
+    session_builder: SessionBuilder,
+}
+
+impl SessionManager {
+    /// Create a session manager that persists settings in `state` and uses
+    /// `session_builder` to build the state of every context it returns.
+    pub fn new(
+        state: Arc<dyn StateBackendClient>,
+        session_builder: SessionBuilder,
+    ) -> Self {
+        Self {
+            state,
+            session_builder,
+        }
+    }
+
+    /// Store the settings of `config` under `session_id` and return a new
+    /// context built from `config`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the settings cannot be encoded or written to the
+    /// state backend.
+    pub async fn update_session(
+        &self,
+        session_id: &str,
+        config: &BallistaConfig,
+    ) -> Result<Arc<SessionContext>> {
+        let settings = settings_from_config(config);
+        let value = encode_protobuf(&protobuf::SessionSettings { configs: settings })?;
+        self.state
+            .put(Keyspace::Sessions, session_id.to_owned(), value)
+            .await?;
+
+        Ok(create_datafusion_context(config, self.session_builder))
+    }
+
+    /// Create a new session from `config` and store its settings under the
+    /// session ID of the newly created context.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the settings do not form a valid [`BallistaConfig`],
+    /// or cannot be encoded or written to the state backend.
+    pub async fn create_session(
+        &self,
+        config: &BallistaConfig,
+    ) -> Result<Arc<SessionContext>> {
+        let settings = settings_from_config(config);
+
+        // Build the context from the settings being persisted (the same
+        // conversion `get_session` applies), not from `config` directly.
+        let config = config_from_settings(&settings)?;
+        let ctx = create_datafusion_context(&config, self.session_builder);
+
+        let value = encode_protobuf(&protobuf::SessionSettings { configs: settings })?;
+        self.state
+            .put(Keyspace::Sessions, ctx.session_id(), value)
+            .await?;
+
+        Ok(ctx)
+    }
+
+    /// Rebuild a context for `session_id` from its persisted settings.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the backend read fails, the stored bytes cannot be
+    /// decoded, or the decoded settings do not form a valid [`BallistaConfig`].
+    pub async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>> {
+        let value = self.state.get(Keyspace::Sessions, session_id).await?;
+
+        // NOTE(review): if a backend returns empty bytes for an unknown key
+        // instead of an error, this decodes to empty settings and yields a
+        // context with the default configuration -- confirm per backend.
+        let settings: protobuf::SessionSettings = decode_protobuf(&value)?;
+        let config = config_from_settings(&settings.configs)?;
+
+        Ok(create_datafusion_context(&config, self.session_builder))
+    }
+}
+
+/// Convert the settings of `config` into key/value pairs for persistence.
+fn settings_from_config(config: &BallistaConfig) -> Vec<KeyValuePair> {
+    config
+        .settings()
+        .iter()
+        .map(|(key, value)| KeyValuePair {
+            key: key.clone(),
+            value: value.clone(),
+        })
+        .collect()
+}
+
+/// Build a [`BallistaConfig`] from persisted key/value pairs.
+fn config_from_settings(settings: &[KeyValuePair]) -> Result<BallistaConfig> {
+    let mut builder = BallistaConfig::builder();
+    for kv_pair in settings {
+        builder = builder.set(&kv_pair.key, &kv_pair.value);
+    }
+    let config = builder.build()?;
+    Ok(config)
+}
+
+/// Map the Ballista settings that DataFusion understands onto a fresh
+/// [`SessionConfig`].
+fn ballista_session_config(config: &BallistaConfig) -> SessionConfig {
+    SessionConfig::new()
+        .with_target_partitions(config.default_shuffle_partitions())
+        .with_batch_size(config.default_batch_size())
+        .with_repartition_joins(config.repartition_joins())
+        .with_repartition_aggregations(config.repartition_aggregations())
+        .with_repartition_windows(config.repartition_windows())
+        .with_parquet_pruning(config.parquet_pruning())
+}
+
+/// Create a DataFusion session context that is compatible with Ballista Configuration
+pub fn create_datafusion_context(
+    config: &BallistaConfig,
+    session_builder: SessionBuilder,
+) -> Arc<SessionContext> {
+    let session_state = session_builder(ballista_session_config(config));
+    Arc::new(SessionContext::with_state(session_state))
+}
+
+/// Update the existing DataFusion session context with Ballista Configuration
+///
+/// The context's whole `SessionConfig` is replaced, so any setting not derived
+/// from `config` is reset to its default.
+pub fn update_datafusion_context(
+    session_ctx: Arc<SessionContext>,
+    config: &BallistaConfig,
+) -> Arc<SessionContext> {
+    {
+        let mut mut_state = session_ctx.state.write();
+        // TODO Currently we have to start from default session config due to the interface not support update
+        mut_state.config = ballista_session_config(config);
+    }
+    session_ctx
+}
diff --git a/ballista/rust/scheduler/src/state/session_registry.rs b/ballista/rust/scheduler/src/state/session_registry.rs
new file mode 100644
index 00000000..1281449b
--- /dev/null
+++ b/ballista/rust/scheduler/src/state/session_registry.rs
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::prelude::SessionContext;
+use std::collections::HashMap;
+use std::sync::Arc;
+use tokio::sync::RwLock;
+
+/// A registry of DataFusion session contexts, keyed by session ID.
+pub struct SessionContextRegistry {
+    /// A map from session_id to SessionContext
+    pub running_sessions: RwLock<HashMap<String, Arc<SessionContext>>>,
+}
+
+impl Default for SessionContextRegistry {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SessionContextRegistry {
+    /// Create an empty registry into which session contexts can be registered.
+    pub fn new() -> Self {
+        Self {
+            running_sessions: RwLock::new(HashMap::new()),
+        }
+    }
+
+    /// Add `session_ctx` to this registry under its own session ID.
+    ///
+    /// Returns the context previously registered under that ID, if any.
+    pub async fn register_session(
+        &self,
+        session_ctx: Arc<SessionContext>,
+    ) -> Option<Arc<SessionContext>> {
+        let session_id = session_ctx.session_id();
+        let mut sessions = self.running_sessions.write().await;
+        sessions.insert(session_id, session_ctx)
+    }
+
+    /// Look up the context registered under `session_id`, if any.
+    pub async fn lookup_session(&self, session_id: &str) -> Option<Arc<SessionContext>> {
+        let sessions = self.running_sessions.read().await;
+        sessions.get(session_id).cloned()
+    }
+
+    /// Remove the context registered under `session_id`, returning it if it
+    /// was present.
+    pub async fn unregister_session(
+        &self,
+        session_id: &str,
+    ) -> Option<Arc<SessionContext>> {
+        let mut sessions = self.running_sessions.write().await;
+        sessions.remove(session_id)
+    }
+}
diff --git a/ballista/rust/scheduler/src/state/stage_manager.rs b/ballista/rust/scheduler/src/state/stage_manager.rs
deleted file mode 100644
index e926c1db..00000000
--- a/ballista/rust/scheduler/src/state/stage_manager.rs
+++ /dev/null
@@ -1,783 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::collections::{HashMap, HashSet};
-use std::sync::Arc;
-
-use log::{debug, error, warn};
-use parking_lot::RwLock;
-use rand::Rng;
-
-use crate::scheduler_server::event::QueryStageSchedulerEvent;
-use crate::state::task_scheduler::StageScheduler;
-use ballista_core::error::{BallistaError, Result};
-use ballista_core::serde::protobuf;
-use ballista_core::serde::protobuf::{task_status, FailedTask, TaskStatus};
-
-/// job_id + stage_id
-pub type StageKey = (String, u32);
-
-#[derive(Clone)]
-pub struct StageManager {
-    stage_distribution: Arc<RwLock<StageDistribution>>,
-
-    // The final stage id for jobs
-    final_stages: Arc<RwLock<HashMap<String, u32>>>,
-
-    // (job_id, stage_id) -> stage set in which each one depends on (job_id, stage_id)
-    stages_dependency: Arc<RwLock<HashMap<StageKey, HashSet<u32>>>>,
-
-    // job_id -> pending stages
-    pending_stages: Arc<RwLock<HashMap<String, HashSet<u32>>>>,
-}
-
-impl StageManager {
-    pub fn new() -> Self {
-        Self {
-            stage_distribution: Arc::new(RwLock::new(StageDistribution::new())),
-            final_stages: Arc::new(RwLock::new(HashMap::new())),
-            stages_dependency: Arc::new(RwLock::new(HashMap::new())),
-            pending_stages: Arc::new(RwLock::new(HashMap::new())),
-        }
-    }
-
-    pub fn add_final_stage(&self, job_id: &str, stage_id: u32) {
-        let mut final_stages = self.final_stages.write();
-        final_stages.insert(job_id.to_owned(), stage_id);
-    }
-
-    pub fn is_final_stage(&self, job_id: &str, stage_id: u32) -> bool {
-        self.get_final_stage_id(job_id)
-            .map(|final_stage_id| final_stage_id == stage_id)
-            .unwrap_or(false)
-    }
-
-    fn get_final_stage_id(&self, job_id: &str) -> Option<u32> {
-        let final_stages = self.final_stages.read();
-        final_stages.get(job_id).cloned()
-    }
-
-    pub fn get_tasks_for_complete_final_stage(
-        &self,
-        job_id: &str,
-    ) -> Result<Vec<Arc<TaskStatus>>> {
-        let final_stage_id = self.get_final_stage_id(job_id).ok_or_else(|| {
-            BallistaError::General(format!(
-                "Fail to find final stage id for job {}",
-                job_id
-            ))
-        })?;
-
-        let stage_key = (job_id.to_owned(), final_stage_id);
-        let stage_distribution = self.stage_distribution.read();
-
-        if let Some(stage) = stage_distribution.stages_completed.get(&stage_key) {
-            Ok(stage.tasks.clone())
-        } else {
-            Err(BallistaError::General(format!(
-                "The final stage id {} has not been completed yet",
-                final_stage_id
-            )))
-        }
-    }
-
-    pub fn add_pending_stage(&self, job_id: &str, stage_id: u32) {
-        let mut pending_stages = self.pending_stages.write();
-        pending_stages
-            .entry(job_id.to_owned())
-            .or_insert_with(HashSet::new)
-            .insert(stage_id);
-    }
-
-    pub fn is_pending_stage(&self, job_id: &str, stage_id: u32) -> bool {
-        let pending_stages = self.pending_stages.read();
-        if let Some(pending_stages) = pending_stages.get(job_id) {
-            pending_stages.contains(&stage_id)
-        } else {
-            false
-        }
-    }
-
-    pub fn remove_pending_stage(
-        &self,
-        job_id: &str,
-        stages_remove: &HashSet<u32>,
-    ) -> bool {
-        let mut pending_stages = self.pending_stages.write();
-        let mut is_stages_empty = false;
-        let ret = if let Some(stages) = pending_stages.get_mut(job_id) {
-            let len_before_remove = stages.len();
-            for stage_id in stages_remove {
-                stages.remove(stage_id);
-            }
-            is_stages_empty = stages.is_empty();
-            stages.len() != len_before_remove
-        } else {
-            false
-        };
-
-        if is_stages_empty {
-            pending_stages.remove(job_id);
-        }
-
-        ret
-    }
-
-    pub fn add_stages_dependency(
-        &self,
-        job_id: &str,
-        dependencies: HashMap<u32, HashSet<u32>>,
-    ) {
-        let mut stages_dependency = self.stages_dependency.write();
-        for (stage_id, parent_stages) in dependencies.into_iter() {
-            stages_dependency.insert((job_id.to_owned(), stage_id), parent_stages);
-        }
-    }
-
-    pub fn get_parent_stages(&self, job_id: &str, stage_id: u32) -> Option<HashSet<u32>> {
-        let stage_key = (job_id.to_owned(), stage_id);
-        let stages_dependency = self.stages_dependency.read();
-        stages_dependency.get(&stage_key).cloned()
-    }
-
-    pub fn add_running_stage(&self, job_id: &str, stage_id: u32, num_partitions: u32) {
-        let stage = Stage::new(job_id, stage_id, num_partitions);
-
-        let mut stage_distribution = self.stage_distribution.write();
-        stage_distribution
-            .stages_running
-            .insert((job_id.to_string(), stage_id), stage);
-    }
-
-    pub fn is_running_stage(&self, job_id: &str, stage_id: u32) -> bool {
-        let stage_key = (job_id.to_owned(), stage_id);
-        let stage_distribution = self.stage_distribution.read();
-        stage_distribution.stages_running.get(&stage_key).is_some()
-    }
-
-    pub fn is_completed_stage(&self, job_id: &str, stage_id: u32) -> bool {
-        let stage_key = (job_id.to_owned(), stage_id);
-        let stage_distribution = self.stage_distribution.read();
-        stage_distribution
-            .stages_completed
-            .get(&stage_key)
-            .is_some()
-    }
-
-    pub(crate) fn get_stage_tasks(
-        &self,
-        job_id: &str,
-        stage_id: u32,
-    ) -> Option<Vec<Arc<TaskStatus>>> {
-        let stage_key = (job_id.to_owned(), stage_id);
-        let stage_distribution = self.stage_distribution.read();
-        if let Some(stage) = stage_distribution.stages_running.get(&stage_key) {
-            Some(stage.tasks.clone())
-        } else {
-            stage_distribution
-                .stages_completed
-                .get(&stage_key)
-                .map(|task| task.tasks.clone())
-        }
-    }
-
-    pub(crate) fn update_tasks_status(
-        &self,
-        tasks_status: Vec<TaskStatus>,
-    ) -> Vec<QueryStageSchedulerEvent> {
-        let mut all_tasks_status: HashMap<StageKey, Vec<TaskStatus>> = HashMap::new();
-        for task_status in tasks_status {
-            if let Some(task_id) = task_status.task_id.as_ref() {
-                let stage_tasks_status = all_tasks_status
-                    .entry((task_id.job_id.clone(), task_id.stage_id))
-                    .or_insert_with(Vec::new);
-                stage_tasks_status.push(task_status);
-            } else {
-                error!("There's no task id when updating status");
-            }
-        }
-
-        let mut ret = vec![];
-        let mut stage_distribution = self.stage_distribution.write();
-        for (stage_key, stage_tasks_status) in all_tasks_status.into_iter() {
-            if let Some(stage) = stage_distribution.stages_running.get_mut(&stage_key) {
-                for task_status in &stage_tasks_status {
-                    stage.update_task_status(task_status);
-                }
-                if let Some(fail_message) = stage.get_fail_message() {
-                    ret.push(QueryStageSchedulerEvent::JobFailed(
-                        stage_key.0.clone(),
-                        stage_key.1,
-                        fail_message,
-                    ));
-                } else if stage.is_completed() {
-                    stage_distribution.complete_stage(stage_key.clone());
-                    if self.is_final_stage(&stage_key.0, stage_key.1) {
-                        ret.push(QueryStageSchedulerEvent::JobFinished(
-                            stage_key.0.clone(),
-                        ));
-                    } else {
-                        ret.push(QueryStageSchedulerEvent::StageFinished(
-                            stage_key.0.clone(),
-                            stage_key.1,
-                        ));
-                    }
-                }
-            } else {
-                error!("Fail to find stage for {:?}/{}", &stage_key.0, stage_key.1);
-            }
-        }
-
-        ret
-    }
-
-    pub fn fetch_pending_tasks<F>(
-        &self,
-        max_num: usize,
-        cond: F,
-    ) -> Option<(String, u32, Vec<u32>)>
-    where
-        F: Fn(&StageKey) -> bool,
-    {
-        if let Some(next_stage) = self.fetch_schedulable_stage(cond) {
-            if let Some(next_tasks) =
-                self.find_stage_pending_tasks(&next_stage.0, next_stage.1, max_num)
-            {
-                Some((next_stage.0.to_owned(), next_stage.1, next_tasks))
-            } else {
-                warn!(
-                    "Fail to find pending tasks for stage {}/{}",
-                    next_stage.0, next_stage.1
-                );
-                None
-            }
-        } else {
-            None
-        }
-    }
-
-    fn find_stage_pending_tasks(
-        &self,
-        job_id: &str,
-        stage_id: u32,
-        max_num: usize,
-    ) -> Option<Vec<u32>> {
-        let stage_key = (job_id.to_owned(), stage_id);
-        let stage_distribution = self.stage_distribution.read();
-        stage_distribution
-            .stages_running
-            .get(&stage_key)
-            .map(|stage| stage.find_pending_tasks(max_num))
-    }
-
-    pub fn has_running_tasks(&self) -> bool {
-        let stage_distribution = self.stage_distribution.read();
-        for stage in stage_distribution.stages_running.values() {
-            if !stage.get_running_tasks().is_empty() {
-                return true;
-            }
-        }
-
-        false
-    }
-}
-
-// TODO Currently, it will randomly choose a stage. In the future, we can add more sophisticated stage choose algorithm here, like priority, etc.
-impl StageScheduler for StageManager {
-    fn fetch_schedulable_stage<F>(&self, cond: F) -> Option<StageKey>
-    where
-        F: Fn(&StageKey) -> bool,
-    {
-        let mut rng = rand::thread_rng();
-        let stage_distribution = self.stage_distribution.read();
-        let stages_running = &stage_distribution.stages_running;
-        if stages_running.is_empty() {
-            debug!("There's no running stages");
-            return None;
-        }
-        let stages = stages_running
-            .iter()
-            .filter(|entry| entry.1.is_schedulable() && cond(entry.0))
-            .map(|entry| entry.0)
-            .collect::<Vec<&StageKey>>();
-        if stages.is_empty() {
-            None
-        } else {
-            let n_th = rng.gen_range(0..stages.len());
-            Some(stages[n_th].clone())
-        }
-    }
-}
-
-struct StageDistribution {
-    // The key is (job_id, stage_id)
-    stages_running: HashMap<StageKey, Stage>,
-    stages_completed: HashMap<StageKey, Stage>,
-}
-
-impl StageDistribution {
-    fn new() -> Self {
-        Self {
-            stages_running: HashMap::new(),
-            stages_completed: HashMap::new(),
-        }
-    }
-
-    fn complete_stage(&mut self, stage_key: StageKey) {
-        if let Some(stage) = self.stages_running.remove(&stage_key) {
-            assert!(
-                stage.is_completed(),
-                "Stage {}/{} is not completed",
-                stage_key.0,
-                stage_key.1
-            );
-            self.stages_completed.insert(stage_key, stage);
-        } else {
-            warn!(
-                "Fail to find running stage {:?}/{}",
-                stage_key.0, stage_key.1
-            );
-        }
-    }
-}
-
-pub struct Stage {
-    pub stage_id: u32,
-    tasks: Vec<Arc<TaskStatus>>,
-
-    tasks_distribution: TaskStatusDistribution,
-}
-
-impl Stage {
-    fn new(job_id: &str, stage_id: u32, num_partitions: u32) -> Self {
-        let mut tasks = vec![];
-        for partition_id in 0..num_partitions {
-            let pending_status = Arc::new(TaskStatus {
-                task_id: Some(protobuf::PartitionId {
-                    job_id: job_id.to_owned(),
-                    stage_id,
-                    partition_id,
-                }),
-                status: None,
-            });
-
-            tasks.push(pending_status);
-        }
-
-        Stage {
-            stage_id,
-            tasks,
-            tasks_distribution: TaskStatusDistribution::new(num_partitions as usize),
-        }
-    }
-
-    // If error happens for updating some task status, just quietly print the error message
-    fn update_task_status(&mut self, task: &TaskStatus) {
-        if let Some(task_id) = &task.task_id {
-            let task_idx = task_id.partition_id as usize;
-            if task_idx < self.tasks.len() {
-                let existing_task_status = self.tasks[task_idx].clone();
-                if self.tasks_distribution.update(
-                    task_idx,
-                    &existing_task_status.status,
-                    &task.status,
-                ) {
-                    self.tasks[task_idx] = Arc::new(task.clone());
-                } else {
-                    error!(
-                        "Fail to update status from {:?} to {:?} for task: {:?}/{:?}/{:?}", &existing_task_status.status, &task.status,
-                        &task_id.job_id, &task_id.stage_id, task_idx
-                    )
-                }
-            } else {
-                error!(
-                    "Fail to find existing task: {:?}/{:?}/{:?}",
-                    &task_id.job_id, &task_id.stage_id, task_idx
-                )
-            }
-        } else {
-            error!("Fail to update task status due to no task id");
-        }
-    }
-
-    fn is_schedulable(&self) -> bool {
-        self.tasks_distribution.is_schedulable()
-    }
-
-    fn is_completed(&self) -> bool {
-        self.tasks_distribution.is_completed()
-    }
-
-    // If return None, means no failed tasks
-    fn get_fail_message(&self) -> Option<String> {
-        if self.tasks_distribution.is_failed() {
-            let task_idx = self.tasks_distribution.sample_failed_index();
-            if let Some(task) = self.tasks.get(task_idx) {
-                if let Some(task_status::Status::Failed(FailedTask { error })) =
-                    &task.status
-                {
-                    Some(error.clone())
-                } else {
-                    warn!("task {:?} is not failed", task);
-                    None
-                }
-            } else {
-                warn!("Could not find error tasks");
-                None
-            }
-        } else {
-            None
-        }
-    }
-
-    pub fn find_pending_tasks(&self, max_num: usize) -> Vec<u32> {
-        self.tasks_distribution.find_pending_indicators(max_num)
-    }
-
-    fn get_running_tasks(&self) -> Vec<Arc<TaskStatus>> {
-        self.tasks_distribution
-            .running_indicator
-            .indicator
-            .iter()
-            .enumerate()
-            .filter(|(_i, is_running)| **is_running)
-            .map(|(i, _is_running)| self.tasks[i].clone())
-            .collect()
-    }
-}
-
-#[derive(Clone)]
-struct TaskStatusDistribution {
-    len: usize,
-    pending_indicator: TaskStatusIndicator,
-    running_indicator: TaskStatusIndicator,
-    failed_indicator: TaskStatusIndicator,
-    completed_indicator: TaskStatusIndicator,
-}
-
-impl TaskStatusDistribution {
-    /// Creates a distribution for `len` tasks, all of which start out pending.
-    fn new(len: usize) -> Self {
-        Self {
-            len,
-            pending_indicator: TaskStatusIndicator {
-                indicator: (0..len).map(|_| true).collect::<Vec<bool>>(),
-                n_of_true: len,
-            },
-            running_indicator: TaskStatusIndicator {
-                indicator: (0..len).map(|_| false).collect::<Vec<bool>>(),
-                n_of_true: 0,
-            },
-            failed_indicator: TaskStatusIndicator {
-                indicator: (0..len).map(|_| false).collect::<Vec<bool>>(),
-                n_of_true: 0,
-            },
-            completed_indicator: TaskStatusIndicator {
-                indicator: (0..len).map(|_| false).collect::<Vec<bool>>(),
-                n_of_true: 0,
-            },
-        }
-    }
-
-    /// Returns true if any task is pending.
-    fn is_schedulable(&self) -> bool {
-        self.pending_indicator.n_of_true != 0
-    }
-
-    /// Returns true if all `len` tasks are completed.
-    fn is_completed(&self) -> bool {
-        self.completed_indicator.n_of_true == self.len
-    }
-
-    /// Returns true if any task has failed.
-    fn is_failed(&self) -> bool {
-        self.failed_indicator.n_of_true != 0
-    }
-
-    /// Returns the lowest index of a failed task, or `self.len` (an
-    /// out-of-range sentinel) when no task has failed.
-    fn sample_failed_index(&self) -> usize {
-        for i in 0..self.len {
-            if self.failed_indicator.indicator[i] {
-                return i;
-            }
-        }
-
-        self.len
-    }
-
-    /// Returns up to `max_num` indices of pending tasks in ascending order;
-    /// returns an empty vector when `max_num` is 0.
-    fn find_pending_indicators(&self, max_num: usize) -> Vec<u32> {
-        let mut ret = vec![];
-        if max_num < 1 {
-            return ret;
-        }
-
-        // Never collect more indices than there are tasks.
-        let len = std::cmp::min(max_num, self.len);
-        for idx in 0..self.len {
-            if self.pending_indicator.indicator[idx] {
-                ret.push(idx as u32);
-                if ret.len() >= len {
-                    break;
-                }
-            }
-        }
-
-        ret
-    }
-
-    /// Moves task `idx` between indicators for the status transition
-    /// `from` -> `to`, where `None` means pending. Accepted transitions:
-    /// * pending -> Running
-    /// * Running -> Failed, Running -> Completed
-    /// * Failed -> pending, Completed -> pending
-    ///
-    /// Returns `true` when the transition was applied. For every other
-    /// transition it returns `false` and leaves all indicators untouched.
-    ///
-    /// # Panics
-    /// Panics if `idx >= self.len`.
-    fn update(
-        &mut self,
-        idx: usize,
-        from: &Option<task_status::Status>,
-        to: &Option<task_status::Status>,
-    ) -> bool {
-        assert!(
-            idx < self.len,
-            "task index {} should be smaller than {}",
-            idx,
-            self.len
-        );
-
-        match (from, to) {
-            (Some(from), Some(to)) => match (from, to) {
-                (task_status::Status::Running(_), task_status::Status::Failed(_)) => {
-                    self.running_indicator.set_false(idx);
-                    self.failed_indicator.set_true(idx);
-                }
-                (task_status::Status::Running(_), task_status::Status::Completed(_)) => {
-                    self.running_indicator.set_false(idx);
-                    self.completed_indicator.set_true(idx);
-                }
-                _ => {
-                    return false;
-                }
-            },
-            (None, Some(task_status::Status::Running(_))) => {
-                self.pending_indicator.set_false(idx);
-                self.running_indicator.set_true(idx);
-            }
-            // Resetting a finished task back to pending.
-            (Some(from), None) => match from {
-                task_status::Status::Failed(_) => {
-                    self.failed_indicator.set_false(idx);
-                    self.pending_indicator.set_true(idx);
-                }
-                task_status::Status::Completed(_) => {
-                    self.completed_indicator.set_false(idx);
-                    self.pending_indicator.set_true(idx);
-                }
-                _ => {
-                    return false;
-                }
-            },
-            _ => {
-                return false;
-            }
-        }
-
-        true
-    }
-}
-
-/// One boolean flag per task index, plus a cached count of the `true` flags.
-#[derive(Clone)]
-struct TaskStatusIndicator {
-    /// Flag per task index.
-    indicator: Vec<bool>,
-    /// Cached number of `true` entries in `indicator`.
-    n_of_true: usize,
-}
-
-impl TaskStatusIndicator {
-    /// Clears the flag at `idx` and decrements the cached count.
-    ///
-    /// The current value is not checked. Clearing a flag that is already
-    /// `false` makes `n_of_true` wrong, and underflows it when it is 0, so
-    /// callers must only clear flags that are set.
-    fn set_false(&mut self, idx: usize) {
-        self.indicator[idx] = false;
-        self.n_of_true -= 1;
-    }
-
-    /// Sets the flag at `idx` and increments the cached count.
-    ///
-    /// The current value is not checked. Setting a flag that is already
-    /// `true` over-counts `n_of_true`.
-    fn set_true(&mut self, idx: usize) {
-        self.indicator[idx] = true;
-        self.n_of_true += 1;
-    }
-}
-
-/// Tests for the `StageManager` task and stage state machines. Each test
-/// drives tasks through `update_tasks_status` and inspects the outcome with
-/// `get_stage_tasks`, `is_running_stage` or `is_completed_stage`.
-#[cfg(test)]
-mod test {
-    use crate::state::stage_manager::StageManager;
-    use ballista_core::error::Result;
-    use ballista_core::serde::protobuf::{
-        task_status, CompletedTask, FailedTask, PartitionId, RunningTask, TaskStatus,
-    };
-
-    #[tokio::test]
-    async fn test_task_status_state_machine_failed() -> Result<()> {
-        let stage_manager = StageManager::new();
-
-        let num_partitions = 3;
-        let job_id = "job";
-        let stage_id = 1u32;
-
-        stage_manager.add_running_stage(job_id, stage_id, num_partitions);
-
-        let task_id = PartitionId {
-            job_id: job_id.to_owned(),
-            stage_id,
-            partition_id: 2,
-        };
-
-        {
-            // Invalid transformation from Pending to Failed
-            stage_manager.update_tasks_status(vec![TaskStatus {
-                status: Some(task_status::Status::Failed(FailedTask {
-                    error: "error".to_owned(),
-                })),
-                task_id: Some(task_id.clone()),
-            }]);
-            // The rejected update must leave the task pending (status None).
-            let ret = stage_manager.get_stage_tasks(job_id, stage_id);
-            assert!(ret.is_some());
-            assert!(ret
-                .unwrap()
-                .get(task_id.partition_id as usize)
-                .unwrap()
-                .status
-                .is_none());
-        }
-
-        {
-            // Valid transformation from Pending to Running to Failed
-            stage_manager.update_tasks_status(vec![TaskStatus {
-                status: Some(task_status::Status::Running(RunningTask {
-                    executor_id: "localhost".to_owned(),
-                })),
-                task_id: Some(task_id.clone()),
-            }]);
-            stage_manager.update_tasks_status(vec![TaskStatus {
-                status: Some(task_status::Status::Failed(FailedTask {
-                    error: "error".to_owned(),
-                })),
-                task_id: Some(task_id.clone()),
-            }]);
-            let ret = stage_manager.get_stage_tasks(job_id, stage_id);
-            assert!(ret.is_some());
-            match ret
-                .unwrap()
-                .get(task_id.partition_id as usize)
-                .unwrap()
-                .status
-                .as_ref()
-                .unwrap()
-            {
-                task_status::Status::Failed(_) => (),
-                _ => panic!("Unexpected status"),
-            }
-        }
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_task_status_state_machine_completed() -> Result<()> {
-        let stage_manager = StageManager::new();
-
-        let num_partitions = 3;
-        let job_id = "job";
-        let stage_id = 1u32;
-
-        stage_manager.add_running_stage(job_id, stage_id, num_partitions);
-
-        let task_id = PartitionId {
-            job_id: job_id.to_owned(),
-            stage_id,
-            partition_id: 2,
-        };
-
-        // Valid transformation from Pending to Running to Completed to Pending
-        task_from_pending_to_completed(&stage_manager, &task_id);
-        let ret = stage_manager.get_stage_tasks(job_id, stage_id);
-        assert!(ret.is_some());
-        match ret
-            .unwrap()
-            .get(task_id.partition_id as usize)
-            .unwrap()
-            .status
-            .as_ref()
-            .unwrap()
-        {
-            task_status::Status::Completed(_) => (),
-            _ => panic!("Unexpected status"),
-        }
-        // A `None` status resets the completed task back to pending.
-        stage_manager.update_tasks_status(vec![TaskStatus {
-            status: None,
-            task_id: Some(task_id.clone()),
-        }]);
-        let ret = stage_manager.get_stage_tasks(job_id, stage_id);
-        assert!(ret.is_some());
-        assert!(ret
-            .unwrap()
-            .get(task_id.partition_id as usize)
-            .unwrap()
-            .status
-            .is_none());
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_stage_state_machine_completed() -> Result<()> {
-        let stage_manager = StageManager::new();
-
-        let num_partitions = 3;
-        let job_id = "job";
-        let stage_id = 1u32;
-
-        // Valid transformation from Running to Completed
-        stage_manager.add_running_stage(job_id, stage_id, num_partitions);
-        assert!(stage_manager.is_running_stage(job_id, stage_id));
-        for partition_id in 0..num_partitions {
-            task_from_pending_to_completed(
-                &stage_manager,
-                &PartitionId {
-                    job_id: job_id.to_owned(),
-                    stage_id,
-                    partition_id,
-                },
-            );
-        }
-        assert!(stage_manager.is_completed_stage(job_id, stage_id));
-
-        // Valid transformation from Completed to Running
-        // NOTE(review): despite the comment above, the assertion below checks
-        // that the stage is *not* running after task 0 is reset to pending;
-        // confirm which stage state is actually intended here.
-        stage_manager.update_tasks_status(vec![TaskStatus {
-            status: None,
-            task_id: Some(PartitionId {
-                job_id: job_id.to_owned(),
-                stage_id,
-                partition_id: 0,
-            }),
-        }]);
-        assert!(!stage_manager.is_running_stage(job_id, stage_id));
-
-        Ok(())
-    }
-
-    /// Drives `task_id` from pending to Running and then to Completed,
-    /// reporting executor "localhost" for both updates.
-    fn task_from_pending_to_completed(
-        stage_manager: &StageManager,
-        task_id: &PartitionId,
-    ) {
-        stage_manager.update_tasks_status(vec![TaskStatus {
-            status: Some(task_status::Status::Running(RunningTask {
-                executor_id: "localhost".to_owned(),
-            })),
-            task_id: Some(task_id.clone()),
-        }]);
-        stage_manager.update_tasks_status(vec![TaskStatus {
-            status: Some(task_status::Status::Completed(CompletedTask {
-                executor_id: "localhost".to_owned(),
-                partitions: Vec::new(),
-            })),
-            task_id: Some(task_id.clone()),
-        }]);
-    }
-}
diff --git a/ballista/rust/scheduler/src/state/task_manager.rs b/ballista/rust/scheduler/src/state/task_manager.rs
new file mode 100644
index 00000000..e3ceb610
--- /dev/null
+++ b/ballista/rust/scheduler/src/state/task_manager.rs
@@ -0,0 +1,751 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::scheduler_server::event::QueryStageSchedulerEvent;
+use crate::scheduler_server::SessionBuilder;
+use crate::state::backend::{Keyspace, StateBackendClient};
+use crate::state::execution_graph::{ExecutionGraph, ExecutionStage, StageOutput, Task};
+use crate::state::executor_manager::ExecutorReservation;
+use crate::state::{decode_protobuf, encode_protobuf, with_lock};
+use ballista_core::config::BallistaConfig;
+use ballista_core::error::{BallistaError, Result};
+use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning;
+use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient;
+
+use crate::state::session_manager::create_datafusion_context;
+use ballista_core::serde::protobuf::{
+    self, job_status, task_status, FailedJob, JobStatus, PartitionId, QueuedJob,
+    TaskDefinition, TaskStatus,
+};
+use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto;
+use ballista_core::serde::scheduler::{ExecutorMetadata, PartitionLocation};
+use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
+use datafusion::physical_plan::{ExecutionPlan, Partitioning};
+use datafusion::prelude::SessionContext;
+use datafusion_proto::logical_plan::AsLogicalPlan;
+use log::{debug, info, warn};
+use rand::distributions::Alphanumeric;
+use rand::{thread_rng, Rng};
+use std::collections::{HashMap, HashSet};
+use std::convert::TryInto;
+use std::default::Default;
+use std::sync::Arc;
+use tokio::sync::RwLock;
+use tonic::transport::Channel;
+
+/// Cache of executor gRPC clients keyed by executor ID. It sits behind an
+/// `Arc` so every clone of a `TaskManager` shares the same cache, and behind
+/// an async `RwLock` so it can be updated from async code.
+type ExecutorClients =
+    Arc<RwLock<HashMap<String, ExecutorGrpcClient<Channel>>>>;
+
+/// Manages job state (queued, active, completed, failed) in the persistent
+/// state backend, and hands out tasks from each job's `ExecutionGraph` to
+/// executor reservations.
+#[derive(Clone)]
+pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
+    /// Persistent state backend holding jobs, sessions and locks.
+    state: Arc<dyn StateBackendClient>,
+    /// Lazily populated gRPC client per executor. Only the non-test
+    /// `launch_task` reads it (the test build replaces that method with a
+    /// no-op), so the dead-code allowance is limited to test builds.
+    #[cfg_attr(test, allow(dead_code))]
+    clients: ExecutorClients,
+    /// Used when creating the DataFusion context for a stored session.
+    session_builder: SessionBuilder,
+    /// Codec used to (de)serialize plans, including physical-plan extensions.
+    codec: BallistaCodec<T, U>,
+}
+
+impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> {
+    /// Creates a `TaskManager` over the given state backend, starting with an
+    /// empty executor client cache.
+    pub fn new(
+        state: Arc<dyn StateBackendClient>,
+        session_builder: SessionBuilder,
+        codec: BallistaCodec<T, U>,
+    ) -> Self {
+        let clients = ExecutorClients::default();
+        Self {
+            state,
+            clients,
+            session_builder,
+            codec,
+        }
+    }
+
+    /// Plans `plan` into an `ExecutionGraph`, stores it under `ActiveJobs`,
+    /// and then removes the job's `QueuedJobs` marker on a best-effort basis.
+    /// A failure to remove the marker is only logged as a warning.
+    pub async fn submit_job(
+        &self,
+        job_id: &str,
+        session_id: &str,
+        plan: Arc<dyn ExecutionPlan>,
+    ) -> Result<()> {
+        let graph = ExecutionGraph::new(job_id, session_id, plan)?;
+        let encoded = self.encode_execution_graph(graph)?;
+        self.state
+            .put(Keyspace::ActiveJobs, job_id.to_owned(), encoded)
+            .await?;
+
+        let cleanup = self.state.delete(Keyspace::QueuedJobs, job_id).await;
+        if let Err(e) = cleanup {
+            warn!("Failed to remove key in QueuedJobs for {}: {:?}", job_id, e);
+        }
+
+        Ok(())
+    }
+
+    /// Records a placeholder for a job that has been accepted but not yet
+    /// planned. Batch jobs are planned asynchronously, so this non-empty
+    /// marker lets `get_job_status` report the job as queued in the meantime.
+    pub async fn queue_job(&self, job_id: &str) -> Result<()> {
+        let marker = vec![0x0];
+        self.state
+            .put(Keyspace::QueuedJobs, job_id.to_owned(), marker)
+            .await
+    }
+
+    /// Looks up the status of `job_id`. The checks run in this order:
+    /// 1. the `QueuedJobs` marker,
+    /// 2. the execution graph (active, then completed),
+    /// 3. the `FailedJobs` keyspace.
+    ///
+    /// Any error while loading the execution graph is treated as "not there"
+    /// and falls through to the `FailedJobs` lookup. Returns `Ok(None)` if no
+    /// step finds the job.
+    pub async fn get_job_status(&self, job_id: &str) -> Result<Option<JobStatus>> {
+        let queued = self.state.get(Keyspace::QueuedJobs, job_id).await?;
+        if !queued.is_empty() {
+            return Ok(Some(JobStatus {
+                status: Some(job_status::Status::Queued(QueuedJob {})),
+            }));
+        }
+
+        if let Ok(graph) = self.get_execution_graph(job_id).await {
+            return Ok(Some(graph.status()));
+        }
+
+        let failed = self.state.get(Keyspace::FailedJobs, job_id).await?;
+        if failed.is_empty() {
+            return Ok(None);
+        }
+
+        let status = decode_protobuf(&failed)?;
+        Ok(Some(status))
+    }
+
+    /// Generates a random job ID of 7 alphanumeric characters.
+    pub fn generate_job_id(&self) -> String {
+        thread_rng()
+            .sample_iter(Alphanumeric)
+            .take(7)
+            .map(char::from)
+            .collect()
+    }
+
+    /// Applies the given task status updates to their jobs' execution graphs
+    /// while holding the `ActiveJobs` lock, and returns:
+    /// 1. the `QueryStageSchedulerEvent`s to publish, and
+    /// 2. the executor reservations that can now be offered.
+    ///
+    /// Updates are grouped by job ID. Updates without a task ID are logged
+    /// and dropped. Each updated task yields one reservation on `executor`:
+    /// * if its job completed (the graph is then finalized) or failed, the
+    ///   reservation is free;
+    /// * otherwise the reservation carries the job's ID, so the scheduler first
+    ///   tries to fill the slot with another task from the same job. There is
+    ///   no guarantee such a task is still available at that point.
+    ///
+    /// All updated graphs are written back with a single `put_txn`. If any
+    /// graph fails to load, update, finalize or encode, nothing is written.
+    pub(crate) async fn update_task_statuses(
+        &self,
+        executor: &ExecutorMetadata,
+        task_status: Vec<TaskStatus>,
+    ) -> Result<(Vec<QueryStageSchedulerEvent>, Vec<ExecutorReservation>)> {
+        let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
+
+        with_lock(lock, async {
+            // Group the incoming updates by the job they belong to.
+            let mut by_job: HashMap<String, Vec<TaskStatus>> = HashMap::new();
+            for status in task_status {
+                debug!("Task Update\n{:?}", status);
+                let owner = status.task_id.as_ref().map(|id| id.job_id.clone());
+                match owner {
+                    Some(job_id) => by_job.entry(job_id).or_default().push(status),
+                    None => warn!("Received task with no job ID"),
+                }
+            }
+
+            let mut events: Vec<QueryStageSchedulerEvent> = vec![];
+            let mut reservations: Vec<ExecutorReservation> = vec![];
+            let mut txn_ops: Vec<(Keyspace, String, Vec<u8>)> = vec![];
+
+            for (job_id, statuses) in by_job {
+                let num_tasks = statuses.len();
+                debug!("Updating {} tasks in job {}", num_tasks, job_id);
+
+                let mut graph = self.get_execution_graph(&job_id).await?;
+                graph.update_task_status(executor, statuses)?;
+
+                // Finished (completed or failed) jobs give their slots back;
+                // jobs still running keep the slots reserved for themselves.
+                let release_slots = if graph.complete() {
+                    info!(
+                        "Job {} is complete, finalizing output partitions",
+                        graph.job_id()
+                    );
+                    graph.finalize()?;
+                    events.push(QueryStageSchedulerEvent::JobFinished(job_id.clone()));
+                    true
+                } else if let Some(job_status::Status::Failed(failure)) =
+                    graph.status().status
+                {
+                    events.push(QueryStageSchedulerEvent::JobFailed(
+                        job_id.clone(),
+                        failure.error,
+                    ));
+                    true
+                } else {
+                    false
+                };
+
+                reservations.extend((0..num_tasks).map(|_| {
+                    if release_slots {
+                        ExecutorReservation::new_free(executor.id.to_owned())
+                    } else {
+                        ExecutorReservation::new_assigned(
+                            executor.id.to_owned(),
+                            job_id.clone(),
+                        )
+                    }
+                }));
+
+                txn_ops.push((
+                    Keyspace::ActiveJobs,
+                    job_id.clone(),
+                    self.encode_execution_graph(graph)?,
+                ));
+            }
+
+            self.state.put_txn(txn_ops).await?;
+
+            Ok((events, reservations))
+        })
+        .await
+    }
+
+    /// Take a list of executor reservations and fill them with tasks that are ready
+    /// to be scheduled. When a reservation is filled, the underlying stage task in the
+    /// `ExecutionGraph` is set to a status of Running, so if the task is not subsequently
+    /// launched the caller must ensure that the task status is reset.
+    ///
+    /// The whole operation runs while holding the `(ActiveJobs, "")` lock and uses the
+    /// following algorithm:
+    ///
+    /// 1. For each reservation with a `job_id` assigned, try to assign another task from the same job.
+    /// 2. If a reservation either does not have a `job_id` or there are no available tasks for its
+    ///    `job_id`, add it to a list of "free" reservations (a job-bound reservation is replaced by a
+    ///    new free reservation for the same executor).
+    /// 3. For each free reservation, try to assign a task from one of the graphs already loaded by
+    ///    this call (graphs from which a task has already been assigned).
+    /// 4. If that fails, search the remaining active jobs. That list is consumed as it is searched,
+    ///    so a job that had no task for one reservation is not retried for later ones.
+    /// 5. If no task can be found, add the reservation to the list of unassigned reservations.
+    ///
+    /// An `Err` from `ExecutionGraph::pop_next_task` is treated the same as "no task available";
+    /// an error loading a graph aborts the whole call. Every graph from which a task was
+    /// assigned is written back with a single `put_txn`.
+    ///
+    /// Finally, we return:
+    /// 1. A list of assignments which is a (Executor ID, Task) tuple
+    /// 2. A list of unassigned reservations which we could not find tasks for
+    /// 3. The number of tasks still available in the graphs this call assigned tasks from
+    ///    (not across all active jobs: graphs that yielded no task are not counted)
+    pub async fn fill_reservations(
+        &self,
+        reservations: &[ExecutorReservation],
+    ) -> Result<(Vec<(String, Task)>, Vec<ExecutorReservation>, usize)> {
+        // Same lock as `update_task_statuses`, `complete_job` and `fail_job`, so the
+        // graph read-modify-write below is not interleaved with theirs.
+        let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
+
+        with_lock(lock, async {
+            let mut assignments: Vec<(String, Task)> = vec![];
+            let mut free_reservations: Vec<ExecutorReservation> = vec![];
+
+            // Graphs we assigned tasks from; they are written back to storage when we are done
+            let mut graphs: HashMap<String, ExecutionGraph> = HashMap::new();
+
+            // First try and fill reservations for particular jobs. If the job has no more tasks
+            // free the reservation.
+            for reservation in reservations {
+                debug!(
+                "Filling reservation for executor {} from job {:?}",
+                reservation.executor_id, reservation.job_id
+            );
+                let executor_id = &reservation.executor_id;
+                if let Some(job_id) = &reservation.job_id {
+                    if let Some(graph) = graphs.get_mut(job_id) {
+                        if let Ok(Some(next_task)) = graph.pop_next_task(executor_id) {
+                            debug!(
+                            "Filled reservation for executor {} with task {:?}",
+                            executor_id, next_task
+                        );
+                            assignments.push((executor_id.clone(), next_task));
+                        } else {
+                            debug!("Cannot fill reservation for executor {} from job {}, freeing reservation", executor_id, job_id);
+                            free_reservations
+                                .push(ExecutorReservation::new_free(executor_id.clone()));
+                        }
+                    } else {
+                        let mut graph = self.get_execution_graph(job_id).await?;
+
+                        if let Ok(Some(next_task)) = graph.pop_next_task(executor_id) {
+                            debug!(
+                            "Filled reservation for executor {} with task {:?}",
+                            executor_id, next_task
+                        );
+                            assignments.push((executor_id.clone(), next_task));
+                            graphs.insert(job_id.clone(), graph);
+                        } else {
+                            // The graph is not kept when it yields no task, so it is neither
+                            // written back nor counted in `pending_tasks`.
+                            debug!("Cannot fill reservation for executor {} from job {}, freeing reservation", executor_id, job_id);
+                            free_reservations
+                                .push(ExecutorReservation::new_free(executor_id.clone()));
+                        }
+                    }
+                } else {
+                    free_reservations.push(reservation.clone());
+                }
+            }
+
+            let mut other_jobs: Vec<String> =
+                self.get_active_jobs().await?.into_iter().collect();
+
+            let mut unassigned: Vec<ExecutorReservation> = vec![];
+            // Now try and find tasks for free reservations from current set of graphs
+            for reservation in free_reservations {
+                debug!(
+                "Filling free reservation for executor {}",
+                reservation.executor_id
+            );
+                let mut assigned = false;
+                let executor_id = reservation.executor_id.clone();
+
+                // Try to find a task in the graphs this call has already loaded
+                if let Ok(Some(assignment)) = find_next_task(&executor_id, &mut graphs) {
+                    debug!(
+                    "Filled free reservation for executor {} with task {:?}",
+                    reservation.executor_id, assignment.1
+                );
+                    assignments.push(assignment);
+                    assigned = true;
+                } else {
+                    // Otherwise start searching through other active jobs.
+                    debug!(
+                    "Filling free reservation for executor {} from active jobs {:?}",
+                    reservation.executor_id, other_jobs
+                );
+                    // Jobs are popped off `other_jobs`, so each one is fetched at most once here.
+                    while let Some(job_id) = other_jobs.pop() {
+                        if graphs.get(&job_id).is_none() {
+                            let mut graph = self.get_execution_graph(&job_id).await?;
+
+                            if let Ok(Some(task)) = graph.pop_next_task(&executor_id) {
+                                debug!(
+                                "Filled free reservation for executor {} with task {:?}",
+                                reservation.executor_id, task
+                            );
+                                assignments.push((executor_id.clone(), task));
+                                graphs.insert(job_id, graph);
+                                assigned = true;
+                                break;
+                            } else {
+                                debug!("No available tasks for job {}", job_id);
+                            }
+                        }
+                    }
+                }
+
+                if !assigned {
+                    debug!(
+                    "Unable to fill reservation for executor {}, no tasks available",
+                    executor_id
+                );
+                    unassigned.push(reservation);
+                }
+            }
+
+            // Sum of `available_tasks()` over the graphs written back below
+            let mut pending_tasks = 0;
+
+            // Transactional update graphs now that we have assigned tasks
+            let txn_ops: Vec<(Keyspace, String, Vec<u8>)> = graphs
+                .into_iter()
+                .map(|(job_id, graph)| {
+                    pending_tasks += graph.available_tasks();
+                    let value = self.encode_execution_graph(graph)?;
+                    Ok((Keyspace::ActiveJobs, job_id, value))
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            self.state.put_txn(txn_ops).await?;
+
+            Ok((assignments, unassigned, pending_tasks))
+        }).await
+    }
+
+    /// Moves the job's stored graph from `ActiveJobs` to `CompletedJobs`
+    /// while holding the `ActiveJobs` lock.
+    pub async fn complete_job(&self, job_id: &str) -> Result<()> {
+        debug!("Moving job {} from Active to Completed", job_id);
+        let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
+        let move_job =
+            self.state
+                .mv(Keyspace::ActiveJobs, Keyspace::CompletedJobs, job_id);
+        with_lock(lock, move_job).await
+    }
+
+    /// Marks a job as failed. It removes the job from `ActiveJobs` (under the
+    /// `ActiveJobs` lock) and from `QueuedJobs`, then stores a `FailedJob`
+    /// status carrying `error_message` under `FailedJobs`.
+    ///
+    /// These are separate backend calls, and an error aborts the remaining
+    /// ones. A failure part-way can therefore leave the job removed with no
+    /// `FailedJobs` record.
+    /// TODO this should be atomic
+    pub async fn fail_job(&self, job_id: &str, error_message: String) -> Result<()> {
+        let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
+        let remove_active = self.state.delete(Keyspace::ActiveJobs, job_id);
+        with_lock(lock, remove_active).await?;
+        self.state.delete(Keyspace::QueuedJobs, job_id).await?;
+
+        let failure = FailedJob {
+            error: error_message,
+        };
+        let status = JobStatus {
+            status: Some(job_status::Status::Failed(failure)),
+        };
+        let encoded = encode_protobuf(&status)?;
+        self.state
+            .put(Keyspace::FailedJobs, job_id.to_owned(), encoded)
+            .await
+    }
+
+    /// Launch the given task on the specified executor.
+    ///
+    /// The task is encoded with `prepare_task_definition` and sent with the
+    /// executor's `launch_task` gRPC call. One client per executor is cached
+    /// in `self.clients`, connecting to `http://{host}:{grpc_port}` on first
+    /// use. The cache lock is released before the RPC is issued, so launches
+    /// are not serialized behind one another.
+    ///
+    /// # Errors
+    /// Fails if the task cannot be encoded, the connection cannot be
+    /// established, or the `launch_task` call returns an error.
+    #[cfg(not(test))]
+    pub async fn launch_task(
+        &self,
+        executor: &ExecutorMetadata,
+        task: Task,
+    ) -> Result<()> {
+        info!("Launching task {:?} on executor {:?}", task, executor.id);
+        let task_definition = self.prepare_task_definition(task)?;
+
+        // Look up (or create) the executor's client while holding the cache
+        // lock, then clone it out so the lock is not held across the RPC.
+        // Cloning a tonic client is cheap because clones share the channel.
+        let mut client = {
+            let mut clients = self.clients.write().await;
+            let cached = clients.get(&executor.id).cloned();
+            match cached {
+                Some(client) => client,
+                None => {
+                    let executor_url =
+                        format!("http://{}:{}", executor.host, executor.grpc_port);
+                    let client = ExecutorGrpcClient::connect(executor_url).await?;
+                    clients.insert(executor.id.clone(), client.clone());
+                    client
+                }
+            }
+        };
+
+        client
+            .launch_task(protobuf::LaunchTaskParams {
+                task: vec![task_definition],
+            })
+            .await
+            .map_err(|e| {
+                BallistaError::Internal(format!(
+                    "Failed to connect to executor {}: {:?}",
+                    executor.id, e
+                ))
+            })?;
+
+        Ok(())
+    }
+
+    /// Test-build replacement for `launch_task`. Unit tests run without real
+    /// executors, so launching does nothing and always succeeds.
+    #[cfg(test)]
+    pub async fn launch_task(
+        &self,
+        _executor: &ExecutorMetadata,
+        _task: Task,
+    ) -> Result<()> {
+        // Intentionally a no-op.
+        Ok(())
+    }
+
+    /// Returns how many tasks of `job_id` are currently available to be
+    /// scheduled. The graph is read without taking the `ActiveJobs` lock, so
+    /// the count is only a point-in-time snapshot.
+    pub async fn get_available_task_count(&self, job_id: &str) -> Result<usize> {
+        self.get_execution_graph(job_id)
+            .await
+            .map(|graph| graph.available_tasks())
+    }
+
+    /// Encodes `task` into the protobuf `TaskDefinition` sent to executors.
+    /// The physical plan is serialized with the configured codec, and the
+    /// output partitioning is converted to its protobuf form.
+    ///
+    /// # Errors
+    /// Fails if the plan or the output partitioning cannot be encoded, or if
+    /// the stage or partition ID does not fit in the protobuf's `u32` fields.
+    // Only called by the non-test `launch_task`, hence unused in test builds.
+    #[cfg_attr(test, allow(dead_code))]
+    pub fn prepare_task_definition(&self, task: Task) -> Result<TaskDefinition> {
+        debug!("Preparing task definition for {:?}", task);
+        let mut plan_buf: Vec<u8> = vec![];
+        let plan_proto =
+            U::try_from_physical_plan(task.plan, self.codec.physical_extension_codec())?;
+        plan_proto.try_encode(&mut plan_buf)?;
+
+        let output_partitioning =
+            hash_partitioning_to_proto(task.output_partitioning.as_ref())?;
+
+        // Checked conversions: an `as u32` cast would silently truncate IDs.
+        let partition = &task.partition;
+        let stage_id: u32 = partition.stage_id.try_into().map_err(|_| {
+            BallistaError::Internal(format!(
+                "Stage ID {} of job {} does not fit in u32",
+                partition.stage_id, partition.job_id
+            ))
+        })?;
+        let partition_id: u32 = partition.partition_id.try_into().map_err(|_| {
+            BallistaError::Internal(format!(
+                "Partition ID {} of job {} stage {} does not fit in u32",
+                partition.partition_id, partition.job_id, partition.stage_id
+            ))
+        })?;
+
+        let task_definition = TaskDefinition {
+            task_id: Some(PartitionId {
+                job_id: partition.job_id.clone(),
+                stage_id,
+                partition_id,
+            }),
+            plan: plan_buf,
+            output_partitioning,
+            session_id: task.session_id,
+            props: vec![],
+        };
+        Ok(task_definition)
+    }
+
+    /// Returns the IDs of all jobs in the `ActiveJobs` keyspace, as reported
+    /// by the backend's `scan_keys`. The keys are expected to be plain job
+    /// IDs, with any storage-layer prefix already stripped.
+    async fn get_active_jobs(&self) -> Result<HashSet<String>> {
+        debug!("Scanning for active job IDs");
+        let job_ids = self.state.scan_keys(Keyspace::ActiveJobs).await?;
+        Ok(job_ids)
+    }
+
+    /// Get the `ExecutionGraph` for the given job ID. This searches first in
+    /// the `ActiveJobs` keyspace and then, if nothing is stored there, in the
+    /// `CompletedJobs` keyspace.
+    ///
+    /// # Errors
+    /// Returns an error if the job is in neither keyspace, instead of trying
+    /// to decode an empty value as a graph. Also fails if the stored graph
+    /// cannot be decoded.
+    pub(crate) async fn get_execution_graph(
+        &self,
+        job_id: &str,
+    ) -> Result<ExecutionGraph> {
+        let mut value = self.state.get(Keyspace::ActiveJobs, job_id).await?;
+        if value.is_empty() {
+            value = self.state.get(Keyspace::CompletedJobs, job_id).await?;
+        }
+
+        // `state.get` yields an empty value for a missing key, so an empty
+        // value here means the job is in neither keyspace.
+        if value.is_empty() {
+            return Err(BallistaError::General(format!(
+                "No active or completed job found with ID {}",
+                job_id
+            )));
+        }
+
+        self.decode_execution_graph(value).await
+    }
+
+    /// Rebuilds a `SessionContext` for `session_id` from the
+    /// `protobuf::SessionSettings` stored in the `Sessions` keyspace.
+    ///
+    /// Each stored key/value pair is applied, in order, to a fresh
+    /// `BallistaConfig` builder. The context is then created with the
+    /// scheduler's session builder.
+    async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>> {
+        let encoded = self.state.get(Keyspace::Sessions, session_id).await?;
+        let settings: protobuf::SessionSettings = decode_protobuf(&encoded)?;
+
+        let config = settings
+            .configs
+            .iter()
+            .fold(BallistaConfig::builder(), |builder, kv_pair| {
+                builder.set(&kv_pair.key, &kv_pair.value)
+            })
+            .build()?;
+
+        Ok(create_datafusion_context(&config, self.session_builder))
+    }
+
+    /// Decodes a protobuf-encoded `ExecutionGraph` read from the state backend.
+    ///
+    /// The graph's session is restored first, because it is needed to rebuild
+    /// each stage's physical plan and output partitioning.
+    ///
+    /// # Errors
+    /// Fails in any of these cases:
+    /// - the bytes are not a valid `protobuf::ExecutionGraph`;
+    /// - the session cannot be restored;
+    /// - a plan, partitioning or partition location cannot be converted;
+    /// - a task status references a partition outside its stage;
+    /// - the job status is missing.
+    async fn decode_execution_graph(&self, value: Vec<u8>) -> Result<ExecutionGraph> {
+        let proto: protobuf::ExecutionGraph = decode_protobuf(&value)?;
+
+        let session_id = &proto.session_id;
+
+        let session_ctx = self.get_session(session_id).await?;
+        let mut stages: HashMap<usize, ExecutionStage> = HashMap::new();
+        for stage in proto.stages {
+            let plan_proto = U::try_decode(stage.plan.as_slice())?;
+            let plan = plan_proto.try_into_physical_plan(
+                session_ctx.as_ref(),
+                session_ctx.runtime_env().as_ref(),
+                self.codec.physical_extension_codec(),
+            )?;
+
+            let stage_id = stage.stage_id as usize;
+            let partitions: usize = stage.partitions as usize;
+
+            // Partitions without a stored status stay `None`.
+            let mut task_statuses: Vec<Option<task_status::Status>> =
+                vec![None; partitions];
+
+            for status in stage.task_statuses {
+                if let Some(task_id) = status.task_id.as_ref() {
+                    let partition = task_id.partition_id as usize;
+                    // Stored state may be corrupted or inconsistent. Indexing
+                    // directly would panic the scheduler instead of returning
+                    // an error.
+                    let slot = task_statuses.get_mut(partition).ok_or_else(|| {
+                        BallistaError::Internal(format!(
+                            "Invalid Execution Graph: task status for partition {} of stage {} which has only {} partitions",
+                            partition, stage_id, partitions
+                        ))
+                    })?;
+                    *slot = status.status;
+                }
+            }
+
+            // This is a little hacky but since we can't make an optional
+            // primitive field in protobuf, we just use 0 to encode None.
+            // Should work since stage IDs are 1-indexed.
+            let output_link = if stage.output_link == 0 {
+                None
+            } else {
+                Some(stage.output_link as usize)
+            };
+
+            let output_partitioning: Option<Partitioning> =
+                parse_protobuf_hash_partitioning(
+                    stage.output_partitioning.as_ref(),
+                    session_ctx.as_ref(),
+                    plan.schema().as_ref(),
+                )?;
+
+            let mut inputs: HashMap<usize, StageOutput> = HashMap::new();
+
+            for input in stage.inputs {
+                let input_stage_id = input.stage_id as usize;
+
+                let outputs = input
+                    .partition_locations
+                    .into_iter()
+                    .map(|loc| {
+                        let partition = loc.partition as usize;
+                        let locations = loc
+                            .partition_location
+                            .into_iter()
+                            .map(|l| l.try_into())
+                            .collect::<Result<Vec<_>>>()?;
+                        Ok((partition, locations))
+                    })
+                    .collect::<Result<HashMap<usize, Vec<PartitionLocation>>>>()?;
+
+                inputs.insert(
+                    input_stage_id,
+                    StageOutput {
+                        partition_locations: outputs,
+                        complete: input.complete,
+                    },
+                );
+            }
+
+            let execution_stage = ExecutionStage {
+                stage_id,
+                partitions,
+                output_partitioning,
+                inputs,
+                plan,
+                task_statuses,
+                output_link,
+                resolved: stage.resolved,
+            };
+            stages.insert(stage_id, execution_stage);
+        }
+
+        let output_locations: Vec<PartitionLocation> = proto
+            .output_locations
+            .into_iter()
+            .map(|loc| loc.try_into())
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(ExecutionGraph {
+            job_id: proto.job_id,
+            session_id: proto.session_id,
+            status: proto.status.ok_or_else(|| {
+                BallistaError::Internal(
+                    "Invalid Execution Graph: missing job status".to_owned(),
+                )
+            })?,
+            stages,
+            output_partitions: proto.output_partitions as usize,
+            output_locations,
+        })
+    }
+
+    /// Encodes `graph` as the protobuf bytes stored in the state backend.
+    ///
+    /// The field layout mirrors what `decode_execution_graph` reads back:
+    /// - each stage's plan is serialized with the scheduler's codec;
+    /// - stage inputs and task statuses are converted to protobuf;
+    /// - the job status, output partition count and output locations are
+    ///   copied over.
+    ///
+    /// # Errors
+    /// Fails if any physical plan, partition location or output partitioning
+    /// cannot be converted or encoded, or if the final protobuf encoding fails.
+    fn encode_execution_graph(&self, graph: ExecutionGraph) -> Result<Vec<u8>> {
+        // Owned copy for the per-stage closure below; `graph` itself is moved
+        // out of field by field.
+        let job_id = graph.job_id().to_owned();
+
+        // NOTE(review): stages are visited in the map's iteration order, so the
+        // order of stages in the encoded bytes is presumably unspecified; the
+        // decoder keys stages by `stage_id`, so order should not matter.
+        let stages = graph
+            .stages
+            .into_iter()
+            .map(|(stage_id, stage)| {
+                // This is a little hacky but since we can't make an optional
+                // primitive field in protobuf, we just use 0 to encode None.
+                // Should work since stage IDs are 1-indexed.
+                let output_link = if let Some(link) = stage.output_link {
+                    link as u32
+                } else {
+                    0
+                };
+
+                let mut plan: Vec<u8> = vec![];
+
+                U::try_from_physical_plan(
+                    stage.plan,
+                    self.codec.physical_extension_codec(),
+                )
+                .and_then(|proto| proto.try_encode(&mut plan))?;
+
+                let mut inputs: Vec<protobuf::GraphStageInput> = vec![];
+
+                // Each input stage's partition -> locations map becomes a list of
+                // `TaskInputPartitions`; the first failing conversion aborts.
+                for (stage, output) in stage.inputs.into_iter() {
+                    inputs.push(protobuf::GraphStageInput {
+                        stage_id: stage as u32,
+                        partition_locations: output
+                            .partition_locations
+                            .into_iter()
+                            .map(|(partition, locations)| {
+                                Ok(protobuf::TaskInputPartitions {
+                                    partition: partition as u32,
+                                    partition_location: locations
+                                        .into_iter()
+                                        .map(|l| l.try_into())
+                                        .collect::<Result<Vec<_>>>()?,
+                                })
+                            })
+                            .collect::<Result<Vec<_>>>()?,
+                        complete: output.complete,
+                    });
+                }
+
+                // Only partitions that have a status are written. The partition
+                // index is stored in `task_id`; the decoder fills every other
+                // partition with `None`.
+                let task_statuses: Vec<protobuf::TaskStatus> = stage
+                    .task_statuses
+                    .into_iter()
+                    .enumerate()
+                    .filter_map(|(partition, status)| {
+                        status.map(|status| protobuf::TaskStatus {
+                            task_id: Some(protobuf::PartitionId {
+                                job_id: job_id.clone(),
+                                stage_id: stage_id as u32,
+                                partition_id: partition as u32,
+                            }),
+                            status: Some(status),
+                        })
+                    })
+                    .collect();
+
+                let output_partitioning =
+                    hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
+
+                Ok(protobuf::ExecutionGraphStage {
+                    stage_id: stage_id as u64,
+                    partitions: stage.partitions as u32,
+                    output_partitioning,
+                    inputs,
+                    plan,
+                    task_statuses,
+                    output_link,
+                    resolved: stage.resolved,
+                })
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let output_locations: Vec<protobuf::PartitionLocation> = graph
+            .output_locations
+            .into_iter()
+            .map(|loc| loc.try_into())
+            .collect::<Result<Vec<_>>>()?;
+
+        encode_protobuf(&protobuf::ExecutionGraph {
+            job_id: graph.job_id,
+            session_id: graph.session_id,
+            status: Some(graph.status),
+            stages,
+            output_partitions: graph.output_partitions as u64,
+            output_locations,
+        })
+    }
+}
+
+/// Pops the next available task for `executor_id` from the first graph in
+/// `graphs` that has one.
+///
+/// Returns `Some((executor_id, task))` on success and `None` when no graph
+/// offers a task. Graphs are visited in the map's iteration order.
+///
+/// A graph whose `pop_next_task` returns an error is skipped as if it had no
+/// task, so this function never actually returns `Err`.
+fn find_next_task(
+    executor_id: &str,
+    graphs: &mut HashMap<String, ExecutionGraph>,
+) -> Result<Option<(String, Task)>> {
+    let next = graphs
+        .values_mut()
+        .find_map(|graph| graph.pop_next_task(executor_id).ok().flatten());
+
+    Ok(next.map(|task| (executor_id.to_owned(), task)))
+}
diff --git a/ballista/rust/scheduler/src/state/task_scheduler.rs b/ballista/rust/scheduler/src/state/task_scheduler.rs
deleted file mode 100644
index 0b40091e..00000000
--- a/ballista/rust/scheduler/src/state/task_scheduler.rs
+++ /dev/null
@@ -1,212 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use crate::state::stage_manager::StageKey;
-use crate::state::SchedulerState;
-use async_trait::async_trait;
-use ballista_core::error::BallistaError;
-use ballista_core::execution_plans::ShuffleWriterExec;
-use ballista_core::serde::protobuf::{
-    job_status, task_status, FailedJob, KeyValuePair, RunningTask, TaskDefinition,
-    TaskStatus,
-};
-use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto;
-use ballista_core::serde::scheduler::{ExecutorData, PartitionId};
-use ballista_core::serde::AsExecutionPlan;
-use datafusion_proto::logical_plan::AsLogicalPlan;
-use log::{debug, info};
-
-#[async_trait]
-pub trait TaskScheduler {
-    // For each round, it will fetch tasks from one stage
-    async fn fetch_schedulable_tasks(
-        &self,
-        available_executors: &mut [ExecutorData],
-        n_round: u32,
-    ) -> Result<(Vec<Vec<TaskDefinition>>, usize), BallistaError>;
-}
-
-pub trait StageScheduler {
-    fn fetch_schedulable_stage<F>(&self, cond: F) -> Option<StageKey>
-    where
-        F: Fn(&StageKey) -> bool;
-}
-
-#[async_trait]
-impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskScheduler
-    for SchedulerState<T, U>
-{
-    async fn fetch_schedulable_tasks(
-        &self,
-        available_executors: &mut [ExecutorData],
-        n_round: u32,
-    ) -> Result<(Vec<Vec<TaskDefinition>>, usize), BallistaError> {
-        let mut ret: Vec<Vec<TaskDefinition>> =
-            Vec::with_capacity(available_executors.len());
-        let mut max_task_num = 0u32;
-        for executor in available_executors.iter() {
-            ret.push(Vec::new());
-            max_task_num += executor.available_task_slots;
-        }
-
-        let mut tasks_status = vec![];
-        let mut has_resources = true;
-        for i in 0..n_round {
-            if !has_resources {
-                break;
-            }
-            let mut num_tasks = 0;
-            // For each round, it will fetch tasks from one stage
-            if let Some((job_id, stage_id, tasks)) =
-                self.stage_manager.fetch_pending_tasks(
-                    max_task_num as usize - tasks_status.len(),
-                    |stage_key| {
-                        // Don't scheduler stages for jobs with error status
-                        if let Some(job_meta) = self.get_job_metadata(&stage_key.0) {
-                            if !matches!(
-                                &job_meta.status,
-                                Some(job_status::Status::Failed(FailedJob { error: _ }))
-                            ) {
-                                true
-                            } else {
-                                info!("Stage {}/{} not to be scheduled due to its job failed", stage_key.0, stage_key.1);
-                                false
-                            }
-                        } else {
-                            false
-                        }
-                    },
-                )
-            {
-                let plan =
-                    self.get_stage_plan(&job_id, stage_id as usize)
-                        .ok_or_else(|| {
-                            BallistaError::General(format!(
-                                "Fail to find execution plan for stage {}/{}",
-                                job_id, stage_id
-                            ))
-                        })?;
-                loop {
-                    debug!("Go inside fetching task loop for stage {}/{}", job_id, stage_id);
-
-                    let mut has_tasks = true;
-                    for (idx, executor) in available_executors.iter_mut().enumerate() {
-                        if executor.available_task_slots == 0 {
-                            has_resources = false;
-                            break;
-                        }
-
-                        if num_tasks >= tasks.len() {
-                            has_tasks = false;
-                            break;
-                        }
-
-                        let task_id = PartitionId {
-                            job_id: job_id.clone(),
-                            stage_id: stage_id as usize,
-                            partition_id: tasks[num_tasks] as usize,
-                        };
-
-                        let task_id = Some(task_id.into());
-                        let running_task = TaskStatus {
-                            task_id: task_id.clone(),
-                            status: Some(task_status::Status::Running(RunningTask {
-                                executor_id: executor.executor_id.to_owned(),
-                            })),
-                        };
-                        tasks_status.push(running_task);
-
-                        let plan_clone = plan.clone();
-                        let output_partitioning = if let Some(shuffle_writer) =
-                            plan_clone.as_any().downcast_ref::<ShuffleWriterExec>()
-                        {
-                            shuffle_writer.shuffle_output_partitioning()
-                        } else {
-                            return Err(BallistaError::General(format!(
-                                "Task root plan was not a ShuffleWriterExec: {:?}",
-                                plan_clone
-                            )));
-                        };
-
-                        let mut buf: Vec<u8> = vec![];
-                        U::try_from_physical_plan(
-                            plan.clone(),
-                            self.get_codec().physical_extension_codec(),
-                        )
-                        .and_then(|m| m.try_encode(&mut buf))
-                        .map_err(|e| {
-                            tonic::Status::internal(format!(
-                                "error serializing execution plan: {:?}",
-                                e
-                            ))
-                        })?;
-
-                        let session_id = self.get_session_from_job(&job_id).expect("session id does not exist for job");
-                        let session_props = self
-                            .session_registry()
-                            .lookup_session(&session_id)
-                            .await
-                            .expect("SessionContext does not exist in SessionContextRegistry.")
-                            .copied_config()
-                            .to_props();
-                        let task_props = session_props
-                            .iter()
-                            .map(|(k, v)| KeyValuePair {
-                                key: k.to_owned(),
-                                value: v.to_owned(),
-                            })
-                            .collect::<Vec<_>>();
-
-                        ret[idx].push(TaskDefinition {
-                            plan: buf,
-                            task_id,
-                            output_partitioning: hash_partitioning_to_proto(
-                                output_partitioning,
-                            )
-                            .map_err(|_| tonic::Status::internal("TBD".to_string()))?,
-                            session_id,
-                            props: task_props,
-                        });
-                        executor.available_task_slots -= 1;
-                        num_tasks += 1;
-                    }
-                    if !has_tasks {
-                        break;
-                    }
-                    if !has_resources {
-                        break;
-                    }
-                }
-            }
-            if !has_resources {
-                info!(
-                    "Not enough resource for task running. Stopped at round {}",
-                    i
-                );
-                break;
-            }
-        }
-
-        let total_task_num = tasks_status.len();
-        debug!("{} tasks to be scheduled", total_task_num);
-
-        // No need to deal with the stage event, since the task status is changing from pending to running
-        self.stage_manager.update_tasks_status(tasks_status);
-
-        Ok((ret, total_task_num))
-    }
-}
diff --git a/ballista/rust/scheduler/src/test_utils.rs b/ballista/rust/scheduler/src/test_utils.rs
index 3f8b5767..6cbd36ff 100644
--- a/ballista/rust/scheduler/src/test_utils.rs
+++ b/ballista/rust/scheduler/src/test_utils.rs
@@ -15,16 +15,121 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use ballista_core::error::Result;
+use ballista_core::error::{BallistaError, Result};
+use std::any::Any;
+use std::future::Future;
+use std::sync::Arc;
+use std::time::Duration;
 
-use datafusion::arrow::datatypes::{DataType, Field, Schema};
-use datafusion::execution::context::{SessionConfig, SessionContext};
+use crate::scheduler_server::event::SchedulerServerEvent;
+
+use async_trait::async_trait;
+use ballista_core::event_loop::EventAction;
+
+use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use datafusion::common::DataFusionError;
+use datafusion::datasource::{TableProvider, TableType};
+use datafusion::execution::context::{SessionConfig, SessionContext, SessionState};
+use datafusion::logical_expr::Expr;
+use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::CsvReadOptions;
+use tokio::sync::mpsc::Sender;
 
 pub const TPCH_TABLES: &[&str] = &[
     "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region",
 ];
 
+/// Test utility that allows observing scheduler events.
+///
+/// Holds the sending halves of two channels. A test keeps the receiving halves
+/// and inspects what the observer forwards to them.
+pub struct SchedulerEventObserver {
+    /// Destination for observed `SchedulerServerEvent`s.
+    sender: Sender<SchedulerServerEvent>,
+    /// Destination for observed `BallistaError`s.
+    errors: Sender<BallistaError>,
+}
+
+impl SchedulerEventObserver {
+    /// Creates an observer that forwards events to `sender` and errors to `errors`.
+    pub fn new(
+        sender: Sender<SchedulerServerEvent>,
+        errors: Sender<BallistaError>,
+    ) -> Self {
+        SchedulerEventObserver {
+            errors,
+            sender,
+        }
+    }
+}
+
+#[async_trait]
+impl EventAction<SchedulerServerEvent> for SchedulerEventObserver {
+    /// No setup is needed when the event loop starts.
+    fn on_start(&self) {}
+
+    /// No teardown is needed when the event loop stops.
+    fn on_stop(&self) {}
+
+    /// Forwards `event` to the observer's event channel and never emits a
+    /// follow-up event. Panics if the receiving side has been dropped.
+    async fn on_receive(
+        &self,
+        event: SchedulerServerEvent,
+    ) -> Result<Option<SchedulerServerEvent>> {
+        let event_sink = &self.sender;
+        event_sink.send(event).await.unwrap();
+        Ok(None)
+    }
+
+    /// Forwards `error` to the observer's error channel.
+    ///
+    /// This hook is synchronous, so the async send runs on a spawned task. The
+    /// task's handle is not kept.
+    fn on_error(&self, error: BallistaError) {
+        let error_sink = Sender::clone(&self.errors);
+        tokio::task::spawn(async move {
+            error_sink.send(error).await.unwrap();
+        });
+    }
+}
+
+/// A `TableProvider` whose `scan` always fails.
+///
+/// Some tests need a logical plan that builds successfully but fails when it is
+/// turned into a physical plan. Scanning this provider produces that failure.
+pub struct ExplodingTableProvider;
+
+#[async_trait]
+impl TableProvider for ExplodingTableProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    /// An empty schema: the table exposes no columns.
+    fn schema(&self) -> SchemaRef {
+        SchemaRef::new(Schema::empty())
+    }
+
+    fn table_type(&self) -> TableType {
+        TableType::Base
+    }
+
+    /// Always returns a `DataFusionError::Plan`, ignoring every argument.
+    async fn scan(
+        &self,
+        _ctx: &SessionState,
+        _projection: &Option<Vec<usize>>,
+        _filters: &[Expr],
+        _limit: Option<usize>,
+    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
+        let message = String::from("ExplodingTableProvider just throws an error!");
+        Err(DataFusionError::Plan(message))
+    }
+}
+
+/// Repeatedly evaluates an async condition until it holds or the attempts run out.
+///
+/// `cond` is called at most `iterations` times, with a wait of `interval`
+/// between consecutive checks. There is no wait after the final check.
+///
+/// Returns `Ok(true)` as soon as a check succeeds. Returns `Ok(false)` if no
+/// check succeeds, including when `iterations` is 0.
+///
+/// # Errors
+/// Propagates the first error returned by `cond`.
+pub async fn await_condition<Fut: Future<Output = Result<bool>>, F: Fn() -> Fut>(
+    interval: Duration,
+    iterations: usize,
+    cond: F,
+) -> Result<bool> {
+    for iteration in 0..iterations {
+        if cond().await? {
+            return Ok(true);
+        }
+
+        // Only wait if another check follows. Sleeping after the final failed
+        // check would just delay the `Ok(false)` result by `interval`.
+        if iteration + 1 < iterations {
+            tokio::time::sleep(interval).await;
+        }
+    }
+
+    Ok(false)
+}
+
 pub async fn datafusion_test_context(path: &str) -> Result<SessionContext> {
     let default_shuffle_partitions = 2;
     let config = SessionConfig::new().with_target_partitions(default_shuffle_partitions);
diff --git a/benchmarks/docker-compose.yaml b/benchmarks/docker-compose.yaml
index 1aa8da50..3ebe597e 100644
--- a/benchmarks/docker-compose.yaml
+++ b/benchmarks/docker-compose.yaml
@@ -19,11 +19,13 @@ services:
   etcd:
     image: quay.io/coreos/etcd:v3.4.9
     command: "etcd -advertise-client-urls http://etcd:2379 -listen-client-urls http://0.0.0.0:2379"
+    ports:
+      - 2379:2379
   ballista-scheduler:
     image: ballista:0.7.0
     command: "/scheduler --config-backend etcd --etcd-urls etcd:2379 --bind-host 0.0.0.0 --bind-port 50050"
     environment:
-      - RUST_LOG=ballista=debug
+      - RUST_LOG=info
     volumes:
       - ./data:/data
     depends_on:
@@ -33,7 +35,7 @@ services:
     command: "/executor --bind-host 0.0.0.0 --bind-port 50051 --scheduler-host ballista-scheduler"
     scale: 2
     environment:
-      - RUST_LOG=info
+      - RUST_LOG=ballista=debug,info
     volumes:
       - ./data:/data
     depends_on:
@@ -48,4 +50,3 @@ services:
     depends_on:
       - ballista-scheduler
       - ballista-executor
-