This is an automated email from the ASF dual-hosted git repository. thinkharderdev pushed a commit to branch cluster-state-refactor-2 in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git
commit f03f208e355dbe5fe5c3688e2d856ea2b82a689b Author: Dan Harris <[email protected]> AuthorDate: Mon Feb 6 16:45:22 2023 -0500 Implement JobState --- ballista/core/src/serde/mod.rs | 11 +- ballista/scheduler/scheduler_config_spec.toml | 8 +- ballista/scheduler/src/bin/main.rs | 108 +++----------- ballista/scheduler/src/cluster/kv.rs | 156 +++++++++++++++++++-- ballista/scheduler/src/cluster/memory.rs | 87 ++++++++++-- ballista/scheduler/src/cluster/mod.rs | 143 ++++++++++++++++++- ballista/scheduler/src/cluster/storage/mod.rs | 11 +- ballista/scheduler/src/cluster/storage/sled.rs | 4 +- ballista/scheduler/src/config.rs | 46 ++++++ ballista/scheduler/src/scheduler_process.rs | 9 +- ballista/scheduler/src/scheduler_server/grpc.rs | 36 ++--- ballista/scheduler/src/scheduler_server/mod.rs | 69 ++------- ballista/scheduler/src/standalone.rs | 22 ++- ballista/scheduler/src/state/execution_graph.rs | 9 +- .../scheduler/src/state/execution_graph_dot.rs | 4 +- ballista/scheduler/src/state/executor_manager.rs | 36 ++--- ballista/scheduler/src/state/mod.rs | 64 +++------ ballista/scheduler/src/state/session_manager.rs | 69 +-------- ballista/scheduler/src/state/task_manager.rs | 150 +++++--------------- ballista/scheduler/src/test_utils.rs | 21 ++- 20 files changed, 583 insertions(+), 480 deletions(-) diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index b1b2ab38..ded11f8b 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -27,6 +27,7 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::logical_plan::{ AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, }; +use datafusion_proto::protobuf::LogicalPlanNode; use prost::bytes::BufMut; use prost::Message; use std::fmt::Debug; @@ -34,6 +35,7 @@ use std::marker::PhantomData; use std::sync::Arc; use std::{convert::TryInto, io::Cursor}; +use crate::serde::protobuf::PhysicalPlanNode; pub use generated::ballista as protobuf; pub mod generated; @@ -132,16 +134,17 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { } #[derive(Clone, Debug)] -pub struct BallistaCodec<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> { +pub struct BallistaCodec< + T: 'static + AsLogicalPlan = LogicalPlanNode, + U: 'static + AsExecutionPlan = PhysicalPlanNode, +> { logical_extension_codec: Arc<dyn LogicalExtensionCodec>, physical_extension_codec: Arc<dyn PhysicalExtensionCodec>, logical_plan_repr: PhantomData<T>, physical_plan_repr: PhantomData<U>, } -impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> Default - for BallistaCodec<T, U> -{ +impl Default for BallistaCodec { fn default() -> Self { Self { logical_extension_codec: Arc::new(DefaultLogicalExtensionCodec {}), diff --git a/ballista/scheduler/scheduler_config_spec.toml b/ballista/scheduler/scheduler_config_spec.toml index 0bb609cf..4ae8969d 100644 --- a/ballista/scheduler/scheduler_config_spec.toml +++ b/ballista/scheduler/scheduler_config_spec.toml @@ -32,16 +32,16 @@ doc = "Route for proxying flight results via scheduler. Should be of the form 'I [[param]] abbr = "b" name = "config_backend" -type = "ballista_scheduler::state::backend::StateBackend" +type = "ballista_scheduler::cluster::ClusterStorage" doc = "The configuration backend for the scheduler, possible values: etcd, memory, sled. Default: sled" -default = "ballista_scheduler::state::backend::StateBackend::Sled" +default = "ballista_scheduler::cluster::ClusterStorage::Sled" [[param]] abbr = "c" name = "cluster_backend" -type = "ballista_scheduler::state::backend::StateBackend" +type = "ballista_scheduler::cluster::ClusterStorage" doc = "The configuration backend for the scheduler cluster state, possible values: etcd, memory, sled. Default: sled" -default = "ballista_scheduler::state::backend::StateBackend::Sled" +default = "ballista_scheduler::cluster::ClusterStorage::Sled" [[param]] abbr = "n" diff --git a/ballista/scheduler/src/bin/main.rs b/ballista/scheduler/src/bin/main.rs index ca248825..62236309 100644 --- a/ballista/scheduler/src/bin/main.rs +++ b/ballista/scheduler/src/bin/main.rs @@ -17,18 +17,18 @@ //! Ballista Rust scheduler binary. -use std::{env, io, sync::Arc}; +use std::{env, io}; -use anyhow::{Context, Result}; +use anyhow::Result; use ballista_core::print_version; use ballista_scheduler::scheduler_process::start_server; -#[cfg(feature = "etcd")] -use ballista_scheduler::state::backend::etcd::EtcdClient; -use ballista_scheduler::state::backend::memory::MemoryBackendClient; -#[cfg(feature = "sled")] -use ballista_scheduler::state::backend::sled::SledClient; -use ballista_scheduler::state::backend::{StateBackend, StateBackendClient}; + +use crate::config::{Config, ResultExt}; +use ballista_core::config::LogRotationPolicy; +use ballista_scheduler::cluster::BallistaCluster; +use ballista_scheduler::config::{ClusterStorageConfig, SchedulerConfig}; +use tracing_subscriber::EnvFilter; #[macro_use] extern crate configure_me; @@ -43,15 +43,6 @@ mod config { )); } -use ballista_core::config::LogRotationPolicy; -use ballista_core::utils::default_session_builder; -use ballista_scheduler::cluster::memory::{InMemoryClusterState, InMemoryJobState}; -use ballista_scheduler::cluster::BallistaCluster; -use ballista_scheduler::config::SchedulerConfig; -use ballista_scheduler::state::backend::cluster::DefaultClusterState; -use config::prelude::*; -use tracing_subscriber::EnvFilter; - #[tokio::main] async fn main() -> Result<()> { // parse options @@ -64,26 +55,14 @@ async fn main() -> Result<()> { std::process::exit(0); } - let config_backend = init_kv_backend(&opt.config_backend, &opt).await?; - - let cluster_state = if opt.cluster_backend == opt.config_backend { - Arc::new(DefaultClusterState::new(config_backend.clone())) - } else { - let cluster_kv_store = init_kv_backend(&opt.cluster_backend, &opt).await?; - - Arc::new(DefaultClusterState::new(cluster_kv_store)) - }; - let special_mod_log_level = opt.log_level_setting; - let namespace = opt.namespace; - let external_host = opt.external_host; - let bind_host = opt.bind_host; - let port = opt.bind_port; let log_dir = opt.log_dir; let print_thread_info = opt.print_thread_info; - let log_file_name_prefix = - format!("scheduler_{}_{}_{}", namespace, external_host, port); - let scheduler_name = format!("{}:{}", external_host, port); + + let log_file_name_prefix = format!( + "scheduler_{}_{}_{}", + opt.namespace, opt.external_host, opt.bind_port + ); let rust_log = env::var(EnvFilter::DEFAULT_ENV); let log_filter = EnvFilter::new(rust_log.unwrap_or(special_mod_log_level)); @@ -121,10 +100,13 @@ async fn main() -> Result<()> { .init(); } - let addr = format!("{}:{}", bind_host, port); + let addr = format!("{}:{}", opt.bind_host, opt.bind_port); let addr = addr.parse()?; let config = SchedulerConfig { + namespace: opt.namespace, + external_host: opt.external_host, + bind_port: opt.bind_port, scheduling_policy: opt.scheduler_policy, event_loop_buffer_size: opt.event_loop_buffer_size, executor_slots_policy: opt.executor_slots_policy, @@ -133,61 +115,11 @@ async fn main() -> Result<()> { finished_job_state_clean_up_interval_seconds: opt .finished_job_state_clean_up_interval_seconds, advertise_flight_sql_endpoint: opt.advertise_flight_sql_endpoint, + cluster_storage: ClusterStorageConfig::Memory, }; - let cluster = BallistaCluster::new( - Arc::new(InMemoryClusterState::default()), - Arc::new(InMemoryJobState::new( - &scheduler_name, - default_session_builder, - )), - ); + let cluster = BallistaCluster::new_from_config(&config).await?; - start_server(scheduler_name, cluster, config_backend, addr, config).await?; + start_server(cluster, addr, config).await?; Ok(()) } - -async fn init_kv_backend( - backend: &StateBackend, - opt: &Config, -) -> Result<Arc<dyn StateBackendClient>> { - let cluster_backend: Arc<dyn StateBackendClient> = match backend { - #[cfg(feature = "etcd")] - StateBackend::Etcd => { - let etcd = etcd_client::Client::connect(&[opt.etcd_urls.clone()], None) - .await - .context("Could not connect to etcd")?; - Arc::new(EtcdClient::new(opt.namespace.clone(), etcd)) - } - #[cfg(not(feature = "etcd"))] - StateBackend::Etcd => { - unimplemented!( - "build the scheduler with the `etcd` feature to use the etcd config backend" - ) - } - #[cfg(feature = "sled")] - StateBackend::Sled => { - if opt.sled_dir.is_empty() { - Arc::new( - SledClient::try_new_temporary() - .context("Could not create sled config backend")?, - ) - } else { - println!("{}", opt.sled_dir); - Arc::new( - SledClient::try_new(opt.sled_dir.clone()) - .context("Could not create sled config backend")?, - ) - } - } - #[cfg(not(feature = "sled"))] - StateBackend::Sled => { - unimplemented!( - "build the scheduler with the `sled` feature to use the sled config backend" - ) - } - StateBackend::Memory => Arc::new(MemoryBackendClient::new()), - }; - - Ok(cluster_backend) -} diff --git a/ballista/scheduler/src/cluster/kv.rs b/ballista/scheduler/src/cluster/kv.rs index d130f339..25f09ee3 100644 --- a/ballista/scheduler/src/cluster/kv.rs +++ b/ballista/scheduler/src/cluster/kv.rs @@ -18,30 +18,33 @@ use crate::cluster::storage::{KeyValueStore, Keyspace, Lock, Operation, WatchEvent}; use crate::cluster::{ reserve_slots_bias, reserve_slots_round_robin, ClusterState, ExecutorHeartbeatStream, - JobState, JobStateEventStream, JobStatus, TaskDistribution, + JobState, JobStateEvent, JobStateEventStream, JobStatus, TaskDistribution, }; +use crate::scheduler_server::SessionBuilder; use crate::state::execution_graph::ExecutionGraph; use crate::state::executor_manager::ExecutorReservation; +use crate::state::session_manager::create_datafusion_context; use crate::state::{decode_into, decode_protobuf}; use async_trait::async_trait; use ballista_core::config::BallistaConfig; use ballista_core::error::{BallistaError, Result}; +use ballista_core::serde::protobuf::job_status::Status; use ballista_core::serde::protobuf::{ - self, AvailableTaskSlots, ExecutorHeartbeat, ExecutorTaskSlots, + self, ExecutorHeartbeat, ExecutorTaskSlots, FailedJob, KeyValuePair, }; use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; use ballista_core::serde::AsExecutionPlan; use ballista_core::serde::BallistaCodec; +use dashmap::DashMap; use datafusion::prelude::SessionContext; use datafusion_proto::logical_plan::AsLogicalPlan; use futures::StreamExt; -use itertools::Itertools; +use log::warn; use prost::Message; use std::collections::{HashMap, HashSet}; use std::future::Future; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; -use dashmap::DashMap; /// State implementation based on underlying `KeyValueStore` pub struct KeyValueState< @@ -54,9 +57,12 @@ pub struct KeyValueState< /// Codec used to serialize/deserialize execution plan codec: BallistaCodec<T, U>, /// Name of current scheduler. Should be `{host}:{port}` + #[allow(dead_code)] scheduler: String, /// In-memory store of queued jobs. Map from Job ID -> (Job Name, queued_at timestamp) queued_jobs: DashMap<String, (String, u64)>, + //// `SessionBuilder` for constructing `SessionContext` from stored `BallistaConfig` + session_builder: SessionBuilder, } impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> @@ -66,12 +72,14 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> scheduler: impl Into<String>, store: S, codec: BallistaCodec<T, U>, + session_builder: SessionBuilder, ) -> Self { Self { store, scheduler: scheduler.into(), codec, queued_jobs: DashMap::new(), + session_builder, } } } @@ -362,8 +370,16 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> JobState for KeyValueState<S, T, U> { - async fn accept_job(&self, job_id: &str, job_name: &str, queued_at: u64) -> Result<()> { - todo!() + async fn accept_job( + &self, + job_id: &str, + job_name: &str, + queued_at: u64, + ) -> Result<()> { + self.queued_jobs + .insert(job_id.to_string(), (job_name.to_string(), queued_at)); + + Ok(()) } async fn submit_job(&self, job_id: String, graph: &ExecutionGraph) -> Result<()> { @@ -439,8 +455,44 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> .await } + async fn fail_unscheduled_job(&self, job_id: &str, reason: String) -> Result<()> { + if let Some((job_id, (job_name, queued_at))) = self.queued_jobs.remove(job_id) { + let status = JobStatus { + job_id: job_id.clone(), + job_name, + status: Some(Status::Failed(FailedJob { + error: reason, + queued_at, + started_at: 0, + ended_at: 0, + })), + }; + + self.store + .put(Keyspace::JobStatus, job_id, status.encode_to_vec()) + .await + } else { + Err(BallistaError::Internal(format!( + "Could not fail unscheduled job {job_id}, not found in queued jobs" + ))) + } + } + async fn remove_job(&self, job_id: &str) -> Result<()> { - todo!() + if self.queued_jobs.remove(job_id).is_none() { + self.store + .apply_txn(vec![ + (Operation::Delete, Keyspace::JobStatus, job_id.to_string()), + ( + Operation::Delete, + Keyspace::ExecutionGraph, + job_id.to_string(), + ), + ]) + .await + } else { + Ok(()) + } } async fn try_acquire_job(&self, _job_id: &str) -> Result<Option<ExecutionGraph>> { @@ -449,19 +501,81 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> )) } - async fn job_state_events(&self) -> JobStateEventStream { - todo!() + async fn job_state_events(&self) -> Result<JobStateEventStream> { + let watch = self + .store + .watch(Keyspace::JobStatus, String::default()) + .await?; + + let stream = watch + .filter_map(|event| { + futures::future::ready(match event { + WatchEvent::Put(key, value) => { + if let Some(job_id) = Keyspace::JobStatus.strip_prefix(&key) { + match JobStatus::decode(value.as_slice()) { + Ok(status) => Some(JobStateEvent::JobUpdated { + job_id: job_id.to_string(), + status, + }), + Err(err) => { + warn!( + "Error decoding job status from watch event: {err:?}" + ); + None + } + } + } else { + None + } + } + _ => None, + }) + }) + .boxed(); + + Ok(stream) } async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>> { - todo!() + let value = self.store.get(Keyspace::Sessions, session_id).await?; + + let settings: protobuf::SessionSettings = decode_protobuf(&value)?; + + let mut config_builder = BallistaConfig::builder(); + for kv_pair in &settings.configs { + config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); + } + let config = config_builder.build()?; + + Ok(create_datafusion_context(&config, self.session_builder)) } async fn create_session( &self, config: &BallistaConfig, ) -> Result<Arc<SessionContext>> { - todo!() + let mut settings: Vec<KeyValuePair> = vec![]; + + for (key, value) in config.settings() { + settings.push(KeyValuePair { + key: key.clone(), + value: value.clone(), + }) + } + + let value = protobuf::SessionSettings { configs: settings }; + + let session = create_datafusion_context(config, self.session_builder); + + self.store + .put( + Keyspace::Sessions, + session.session_id(), + value.encode_to_vec(), + ) + .await?; + + Ok(session) } async fn update_session( @@ -469,7 +583,25 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> session_id: &str, config: &BallistaConfig, ) -> Result<Arc<SessionContext>> { - todo!() + let mut settings: Vec<KeyValuePair> = vec![]; + + for (key, value) in config.settings() { + settings.push(KeyValuePair { + key: key.clone(), + value: value.clone(), + }) + } + + let value = protobuf::SessionSettings { configs: settings }; + self.store + .put( + Keyspace::Sessions, + session_id.to_owned(), + value.encode_to_vec(), + ) + .await?; + + Ok(create_datafusion_context(config, self.session_builder)) } } diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs index 258444ec..2d59fd9b 100644 --- a/ballista/scheduler/src/cluster/memory.rs +++ b/ballista/scheduler/src/cluster/memory.rs @@ -24,13 +24,15 @@ use crate::state::executor_manager::ExecutorReservation; use async_trait::async_trait; use ballista_core::config::BallistaConfig; use ballista_core::error::{BallistaError, Result}; -use ballista_core::serde::protobuf::{ExecutorHeartbeat, ExecutorTaskSlots}; +use ballista_core::serde::protobuf::{ + executor_status, ExecutorHeartbeat, ExecutorStatus, ExecutorTaskSlots, FailedJob, +}; use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; use dashmap::DashMap; use datafusion::prelude::SessionContext; use crate::cluster::event::ClusterEventSender; -use crate::scheduler_server::SessionBuilder; +use crate::scheduler_server::{timestamp_millis, SessionBuilder}; use crate::state::session_manager::{ create_datafusion_context, update_datafusion_context, }; @@ -39,8 +41,8 @@ use itertools::Itertools; use log::warn; use parking_lot::Mutex; use std::collections::{HashMap, HashSet}; -use std::ops::{Deref, DerefMut}; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::ops::DerefMut; + use std::sync::Arc; #[derive(Default)] @@ -145,6 +147,15 @@ impl ClusterState for InMemoryClusterState { mut spec: ExecutorData, reserve: bool, ) -> Result<Vec<ExecutorReservation>> { + let heartbeat = ExecutorHeartbeat { + executor_id: metadata.id.clone(), + timestamp: timestamp_millis(), + metrics: vec![], + status: Some(ExecutorStatus { + status: Some(executor_status::Status::Active(String::default())), + }), + }; + if reserve { let slots = std::mem::take(&mut spec.available_task_slots) as usize; @@ -155,20 +166,39 @@ impl ClusterState for InMemoryClusterState { self.executors.insert(metadata.id.clone(), (metadata, spec)); + self.heartbeat_sender.send(&heartbeat); + Ok(reservations) } else { self.executors.insert(metadata.id.clone(), (metadata, spec)); + self.heartbeat_sender.send(&heartbeat); + Ok(vec![]) } } async fn save_executor_metadata(&self, metadata: ExecutorMetadata) -> Result<()> { - todo!() + if let Some(pair) = self.executors.get_mut(&metadata.id).as_deref_mut() { + pair.0 = metadata; + } else { + warn!( + "Failed to update executor metadata, executor with ID {} not found", + metadata.id + ); + } + Ok(()) } async fn get_executor_metadata(&self, executor_id: &str) -> Result<ExecutorMetadata> { - todo!() + self.executors + .get(executor_id) + .map(|pair| pair.value().0.clone()) + .ok_or_else(|| { + BallistaError::Internal(format!( + "Not executor with ID {executor_id} found" + )) + }) } async fn save_executor_heartbeat(&self, heartbeat: ExecutorHeartbeat) -> Result<()> { @@ -247,7 +277,6 @@ impl InMemoryJobState { #[async_trait] impl JobState for InMemoryJobState { - async fn submit_job(&self, job_id: String, _graph: &ExecutionGraph) -> Result<()> { self.job_event_sender.send(&JobStateEvent::JobAcquired { job_id, @@ -328,8 +357,8 @@ impl JobState for InMemoryJobState { } } - async fn job_state_events(&self) -> JobStateEventStream { - Box::pin(self.job_event_sender.subscribe()) + async fn job_state_events(&self) -> Result<JobStateEventStream> { + Ok(Box::pin(self.job_event_sender.subscribe())) } async fn remove_job(&self, job_id: &str) -> Result<()> { @@ -347,7 +376,43 @@ impl JobState for InMemoryJobState { .collect()) } - async fn accept_job(&self, job_id: &str, job_name: &str, queued_at: u64) -> Result<()> { - todo!() + async fn accept_job( + &self, + job_id: &str, + job_name: &str, + queued_at: u64, + ) -> Result<()> { + self.queued_jobs + .insert(job_id.to_string(), (job_name.to_string(), queued_at)); + + Ok(()) + } + + async fn fail_unscheduled_job(&self, job_id: &str, reason: String) -> Result<()> { + if let Some(pair) = self.queued_jobs.get(job_id) { + let (job_name, queued_at) = pair.value(); + self.completed_jobs.insert( + job_id.to_string(), + ( + JobStatus { + job_id: job_id.to_string(), + job_name: job_name.clone(), + status: Some(Status::Failed(FailedJob { + error: reason, + queued_at: *queued_at, + started_at: 0, + ended_at: timestamp_millis(), + })), + }, + None, + ), + ); + + Ok(()) + } else { + Err(BallistaError::Internal(format!( + "Could not fail unscheduler job {job_id}, job not found in queued jobs" + ))) + } } } diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs index da7c16e8..f23fa070 100644 --- a/ballista/scheduler/src/cluster/mod.rs +++ b/ballista/scheduler/src/cluster/mod.rs @@ -20,21 +20,55 @@ pub mod kv; pub mod memory; pub mod storage; +use crate::cluster::kv::KeyValueState; +use crate::cluster::memory::{InMemoryClusterState, InMemoryJobState}; +use crate::cluster::storage::etcd::EtcdClient; +use crate::cluster::storage::sled::SledClient; use crate::cluster::storage::KeyValueStore; +use crate::config::{ClusterStorageConfig, SchedulerConfig}; +use crate::scheduler_server::SessionBuilder; use crate::state::execution_graph::ExecutionGraph; use crate::state::executor_manager::ExecutorReservation; use ballista_core::config::BallistaConfig; -use ballista_core::error::Result; +use ballista_core::error::{BallistaError, Result}; use ballista_core::serde::protobuf::{AvailableTaskSlots, ExecutorHeartbeat, JobStatus}; use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; -use ballista_core::serde::AsExecutionPlan; +use ballista_core::serde::{AsExecutionPlan, BallistaCodec}; +use ballista_core::utils::default_session_builder; +use clap::ArgEnum; use datafusion::prelude::SessionContext; use datafusion_proto::logical_plan::AsLogicalPlan; use futures::Stream; +use log::info; use std::collections::{HashMap, HashSet}; +use std::fmt; use std::pin::Pin; use std::sync::Arc; +// an enum used to configure the backend +// needs to be visible to code generated by configure_me +#[derive(Debug, Clone, ArgEnum, serde::Deserialize, PartialEq, Eq)] +pub enum ClusterStorage { + Etcd, + Memory, + Sled, +} + +impl std::str::FromStr for ClusterStorage { + type Err = String; + + fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { + ArgEnum::from_str(s, true) + } +} + +impl parse_arg::ParseArgFromStr for ClusterStorage { + fn describe_type<W: fmt::Write>(mut writer: W) -> fmt::Result { + write!(writer, "The cluster storage backend for the scheduler") + } +} + +#[derive(Clone)] pub struct BallistaCluster { cluster_state: Arc<dyn ClusterState>, job_state: Arc<dyn JobState>, @@ -51,6 +85,98 @@ impl BallistaCluster { } } + pub fn new_memory( + scheduler: impl Into<String>, + session_builder: SessionBuilder, + ) -> Self { + Self { + cluster_state: Arc::new(InMemoryClusterState::default()), + job_state: Arc::new(InMemoryJobState::new(scheduler, session_builder)), + } + } + + pub fn new_kv< + S: KeyValueStore, + T: 'static + AsLogicalPlan, + U: 'static + AsExecutionPlan, + >( + store: S, + scheduler: impl Into<String>, + session_builder: SessionBuilder, + codec: BallistaCodec<T, U>, + ) -> Self { + let kv_state = + Arc::new(KeyValueState::new(scheduler, store, codec, session_builder)); + Self { + cluster_state: kv_state.clone(), + job_state: kv_state, + } + } + + pub async fn new_from_config(config: &SchedulerConfig) -> Result<Self> { + let scheduler = config.scheduler_name(); + + match &config.cluster_storage { + #[cfg(feature = "etcd")] + ClusterStorageConfig::Etcd(urls) => { + let etcd = etcd_client::Client::connect(urls.as_slice(), None) + .await + .map_err(|err| { + BallistaError::Internal(format!( + "Could not connect to etcd: {err:?}" + )) + })?; + + Ok(Self::new_kv( + EtcdClient::new(config.namespace.clone(), etcd), + scheduler, + default_session_builder, + BallistaCodec::default(), + )) + } + #[cfg(not(feature = "etcd"))] + StateBackend::Etcd => { + unimplemented!( + "build the scheduler with the `etcd` feature to use the etcd config backend" + ) + } + #[cfg(feature = "sled")] + ClusterStorageConfig::Sled(dir) => { + if let Some(dir) = dir.as_ref() { + info!("Initializing Sled database in directory {}", dir); + let sled = SledClient::try_new(dir)?; + + Ok(Self::new_kv( + sled, + scheduler, + default_session_builder, + BallistaCodec::default(), + )) + } else { + info!("Initializing Sled database in temp directory"); + let sled = SledClient::try_new_temporary()?; + + Ok(Self::new_kv( + sled, + scheduler, + default_session_builder, + BallistaCodec::default(), + )) + } + } + #[cfg(not(feature = "sled"))] + StateBackend::Sled => { + unimplemented!( + "build the scheduler with the `sled` feature to use the sled config backend" + ) + } + ClusterStorageConfig::Memory => Ok(BallistaCluster::new_memory( + scheduler, + default_session_builder, + )), + } + } + pub fn cluster_state(&self) -> Arc<dyn ClusterState> { self.cluster_state.clone() } @@ -184,7 +310,12 @@ pub trait JobState: Send + Sync { /// Accept job into a scheduler's job queue. This should be called when a job is /// received by the scheduler but before it is planned and may or may not be saved /// in global state - async fn accept_job(&self, job_id: &str, job_name: &str, queued_at: u64) -> Result<()>; + async fn accept_job( + &self, + job_id: &str, + job_name: &str, + queued_at: u64, + ) -> Result<()>; /// Submit a new job to the `JobState`. It is assumed that the submitter owns the job. /// In local state the job should be save as `JobStatus::Active` and in shared state @@ -207,6 +338,10 @@ pub trait JobState: Send + Sync { /// if the job is not owned by the caller. async fn save_job(&self, job_id: &str, graph: &ExecutionGraph) -> Result<()>; + /// Mark a job which has not been submitted as failed. This should be called if a job fails + /// during planning (and does not yet have an `ExecutionGraph`) + async fn fail_unscheduled_job(&self, job_id: &str, reason: String) -> Result<()>; + /// Delete a job from the global state async fn remove_job(&self, job_id: &str) -> Result<()>; @@ -217,7 +352,7 @@ pub trait JobState: Send + Sync { /// Get a stream of all `JobState` events. An event should be published any time that status /// of a job changes in state - async fn job_state_events(&self) -> JobStateEventStream; + async fn job_state_events(&self) -> Result<JobStateEventStream>; /// Get the `SessionContext` associated with `session_id`. Returns an error if the /// session does not exist diff --git a/ballista/scheduler/src/cluster/storage/mod.rs b/ballista/scheduler/src/cluster/storage/mod.rs index 0cee7a1e..a7f412a4 100644 --- a/ballista/scheduler/src/cluster/storage/mod.rs +++ b/ballista/scheduler/src/cluster/storage/mod.rs @@ -31,14 +31,17 @@ pub enum Keyspace { Executors, JobStatus, ExecutionGraph, - ActiveJobs, - CompletedJobs, - FailedJobs, Slots, Sessions, Heartbeats, } +impl Keyspace { + pub fn strip_prefix<'a>(&'a self, key: &'a str) -> Option<&'a str> { + key.strip_prefix(&format!("{:?}/", self)) + } +} + #[derive(Debug, Eq, PartialEq, Hash)] pub enum Operation { Put(Vec<u8>), @@ -47,7 +50,7 @@ pub enum Operation { /// A trait that defines a KeyValue interface with basic locking primitives for persisting Ballista cluster state #[async_trait] -pub trait KeyValueStore: Send + Sync + Clone { +pub trait KeyValueStore: Send + Sync + Clone + 'static { /// Retrieve the data associated with a specific key in a given keyspace. /// /// An empty vec is returned if the key does not exist. diff --git a/ballista/scheduler/src/cluster/storage/sled.rs b/ballista/scheduler/src/cluster/storage/sled.rs index 67700d2c..27ed8ac6 100644 --- a/ballista/scheduler/src/cluster/storage/sled.rs +++ b/ballista/scheduler/src/cluster/storage/sled.rs @@ -288,8 +288,8 @@ impl Stream for SledWatch { mod tests { use super::{KeyValueStore, SledClient, Watch, WatchEvent}; - use crate::cluster::storage::{Keyspace, Operation}; - use crate::state::with_locks; + use crate::cluster::storage::Keyspace; + use futures::StreamExt; use std::result::Result; diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs index 97f79787..e10a1b69 100644 --- a/ballista/scheduler/src/config.rs +++ b/ballista/scheduler/src/config.rs @@ -25,6 +25,13 @@ use std::fmt; /// Configurations for the ballista scheduler of scheduling jobs and tasks #[derive(Debug, Clone)] pub struct SchedulerConfig { + /// Namespace of this scheduler. Schedulers using the same cluster storage and namespace + /// will share gloabl cluster state. + pub namespace: String, + /// The external hostname of the scheduler + pub external_host: String, + /// The bind port for the scheduler's gRPC service + pub bind_port: u16, /// The task scheduling policy for the scheduler pub scheduling_policy: TaskSchedulingPolicy, /// The event loop buffer size. for a system of high throughput, a larger value like 1000000 is recommended @@ -37,26 +44,51 @@ pub struct SchedulerConfig { pub finished_job_state_clean_up_interval_seconds: u64, /// The route endpoint for proxying flight sql results via scheduler pub advertise_flight_sql_endpoint: Option<String>, + /// Configuration for ballista cluster storage + pub cluster_storage: ClusterStorageConfig, } impl Default for SchedulerConfig { fn default() -> Self { Self { + namespace: String::default(), + external_host: "localhost".to_string(), + bind_port: 50050, scheduling_policy: TaskSchedulingPolicy::PullStaged, event_loop_buffer_size: 10000, executor_slots_policy: SlotsPolicy::Bias, finished_job_data_clean_up_interval_seconds: 300, finished_job_state_clean_up_interval_seconds: 3600, advertise_flight_sql_endpoint: None, + cluster_storage: ClusterStorageConfig::Memory, } } } impl SchedulerConfig { + pub fn scheduler_name(&self) -> String { + format!("{}:{}", self.external_host, self.bind_port) + } + pub fn is_push_staged_scheduling(&self) -> bool { matches!(self.scheduling_policy, TaskSchedulingPolicy::PushStaged) } + pub fn with_namespace(mut self, namespace: impl Into<String>) -> Self { + self.namespace = namespace.into(); + self + } + + pub fn with_hostname(mut self, hostname: impl Into<String>) -> Self { + self.external_host = hostname.into(); + self + } + + pub fn with_port(mut self, port: u16) -> Self { + self.bind_port = port; + self + } + pub fn with_scheduler_policy(mut self, policy: TaskSchedulingPolicy) -> Self { self.scheduling_policy = policy; self @@ -95,6 +127,20 @@ impl SchedulerConfig { self.executor_slots_policy = policy; self } + + pub fn with_cluster_storage(mut self, config: ClusterStorageConfig) -> Self { + self.cluster_storage = config; + self + } +} + +#[derive(Clone, Debug)] +pub enum ClusterStorageConfig { + Memory, + #[cfg(feature = "etcd")] + Etcd(Vec<String>), + #[cfg(feature = "sled")] + Sled(Option<String>), } // an enum used to configure the executor slots policy diff --git a/ballista/scheduler/src/scheduler_process.rs b/ballista/scheduler/src/scheduler_process.rs index 440baca2..3fe2ef04 100644 --- a/ballista/scheduler/src/scheduler_process.rs +++ b/ballista/scheduler/src/scheduler_process.rs @@ -22,7 +22,7 @@ use futures::future::{self, Either, TryFutureExt}; use hyper::{server::conn::AddrStream, service::make_service_fn, Server}; use log::info; use std::convert::Infallible; -use std::{net::SocketAddr, sync::Arc}; +use std::net::SocketAddr; use tonic::transport::server::Connected; use tower::Service; @@ -42,13 +42,9 @@ use crate::flight_sql::FlightSqlServiceImpl; use crate::metrics::default_metrics_collector; use crate::scheduler_server::externalscaler::external_scaler_server::ExternalScalerServer; use crate::scheduler_server::SchedulerServer; -use crate::state::backend::cluster::ClusterState; -use crate::state::backend::StateBackendClient; pub async fn start_server( - scheduler_name: String, cluster: BallistaCluster, - config_backend: Arc<dyn StateBackendClient>, addr: SocketAddr, config: SchedulerConfig, ) -> Result<()> { @@ -66,8 +62,7 @@ pub async fn start_server( let mut scheduler_server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> = SchedulerServer::new( - scheduler_name, - config_backend.clone(), + config.scheduler_name(), cluster, BallistaCodec::default(), config, diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index ad04efb2..5ab7ba53 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -570,7 +570,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc #[cfg(all(test, feature = "sled"))] mod test { - use std::sync::Arc; + use std::time::Duration; use datafusion_proto::protobuf::LogicalPlanNode; @@ -586,23 +586,21 @@ mod test { }; use ballista_core::serde::scheduler::ExecutorSpecification; use ballista_core::serde::BallistaCodec; - use ballista_core::utils::default_session_builder; - use crate::state::backend::cluster::DefaultClusterState; use crate::state::executor_manager::DEFAULT_EXECUTOR_TIMEOUT_SECONDS; - use crate::state::{backend::sled::SledClient, SchedulerState}; + use crate::state::SchedulerState; + use crate::test_utils::test_cluster_context; use super::{SchedulerGrpc, SchedulerServer}; #[tokio::test] async fn test_poll_work() -> Result<(), BallistaError> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = test_cluster_context(); + let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> = SchedulerServer::new( "localhost:50050".to_owned(), - state_storage.clone(), - cluster_state.clone(), + cluster.clone(), BallistaCodec::default(), SchedulerConfig::default(), default_metrics_collector().unwrap(), @@ -629,9 +627,7 @@ mod test { assert!(response.tasks.is_empty()); let state: SchedulerState<LogicalPlanNode, PhysicalPlanNode> = SchedulerState::new_with_default_scheduler_name( - state_storage.clone(), - cluster_state.clone(), - default_session_builder, + cluster.clone(), BallistaCodec::default(), ); state.init().await?; @@ -663,9 +659,7 @@ mod test { assert!(response.tasks.is_empty()); let state: SchedulerState<LogicalPlanNode, PhysicalPlanNode> = SchedulerState::new_with_default_scheduler_name( - state_storage.clone(), - cluster_state, - default_session_builder, + cluster.clone(), BallistaCodec::default(), ); state.init().await?; @@ -687,13 +681,12 @@ mod test { #[tokio::test] async fn test_stop_executor() -> Result<(), BallistaError> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = test_cluster_context(); + let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> = SchedulerServer::new( "localhost:50050".to_owned(), - state_storage, - cluster_state, + cluster.clone(), BallistaCodec::default(), SchedulerConfig::default(), default_metrics_collector().unwrap(), @@ -770,13 +763,12 @@ mod test { #[tokio::test] #[ignore] async fn test_expired_executor() -> Result<(), BallistaError> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = test_cluster_context(); + let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> = SchedulerServer::new( "localhost:50050".to_owned(), - state_storage, - cluster_state, + cluster.clone(), BallistaCodec::default(), SchedulerConfig::default(), default_metrics_collector().unwrap(), diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index 1feb01b1..155e2725 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -22,7 +22,6 @@ use ballista_core::error::Result; use ballista_core::event_loop::{EventLoop, EventSender}; use ballista_core::serde::protobuf::{StopExecutorParams, TaskStatus}; use ballista_core::serde::{AsExecutionPlan, BallistaCodec}; -use ballista_core::utils::default_session_builder; use datafusion::execution::context::SessionState; use datafusion::logical_expr::LogicalPlan; @@ -36,8 +35,7 @@ use log::{error, warn}; use crate::scheduler_server::event::QueryStageSchedulerEvent; use crate::scheduler_server::query_stage_scheduler::QueryStageScheduler; -use crate::state::backend::cluster::ClusterState; -use crate::state::backend::StateBackendClient; + use crate::state::executor_manager::{ ExecutorManager, ExecutorReservation, DEFAULT_EXECUTOR_TIMEOUT_SECONDS, }; @@ -70,51 +68,13 @@ pub struct SchedulerServer<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T, U> { pub fn new( scheduler_name: String, - config_backend: Arc<dyn StateBackendClient>, - cluster: BallistaCluster, - codec: BallistaCodec<T, U>, - config: SchedulerConfig, - metrics_collector: Arc<dyn SchedulerMetricsCollector>, - ) -> Self { - let state = Arc::new(SchedulerState::new( - config_backend, - cluster, - default_session_builder, - codec, - scheduler_name.clone(), - config.clone(), - )); - let query_stage_scheduler = - Arc::new(QueryStageScheduler::new(state.clone(), metrics_collector)); - let query_stage_event_loop = EventLoop::new( - "query_stage".to_owned(), - config.event_loop_buffer_size as usize, - query_stage_scheduler.clone(), - ); - - Self { - scheduler_name, - start_time: timestamp_millis() as u128, - state, - query_stage_event_loop, - query_stage_scheduler, - } - } - - pub fn with_session_builder( - scheduler_name: String, - config_backend: Arc<dyn StateBackendClient>, cluster: BallistaCluster, - codec: BallistaCodec<T, U>, config: SchedulerConfig, - session_builder: SessionBuilder, metrics_collector: Arc<dyn SchedulerMetricsCollector>, ) -> Self { let state = Arc::new(SchedulerState::new( - config_backend, cluster, - session_builder, codec, scheduler_name.clone(), config.clone(), @@ -137,19 +97,16 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T } #[allow(dead_code)] - pub(crate) fn with_task_launcher( + pub(crate) fn new_with_task_launcher( scheduler_name: String, - config_backend: Arc<dyn StateBackendClient>, cluster: BallistaCluster, codec: BallistaCodec<T, U>, config: SchedulerConfig, metrics_collector: Arc<dyn SchedulerMetricsCollector>, task_launcher: Arc<dyn TaskLauncher>, ) -> Self { - let state = Arc::new(SchedulerState::with_task_launcher( - config_backend, + let state = Arc::new(SchedulerState::new_with_task_launcher( cluster, - default_session_builder, codec, scheduler_name.clone(), config.clone(), @@ -371,13 +328,11 @@ mod test { use ballista_core::serde::BallistaCodec; use crate::scheduler_server::{timestamp_millis, SchedulerServer}; - use crate::state::backend::cluster::DefaultClusterState; - use crate::state::backend::sled::SledClient; use crate::test_utils::{ assert_completed_event, assert_failed_event, assert_no_submitted_event, - assert_submitted_event, ExplodingTableProvider, SchedulerTest, TaskRunnerFn, - TestMetricsCollector, + assert_submitted_event, test_cluster_context, ExplodingTableProvider, + SchedulerTest, TaskRunnerFn, TestMetricsCollector, }; #[tokio::test] @@ -508,6 +463,7 @@ mod test { match status.status { Some(job_status::Status::Successful(SuccessfulJob { partition_location, + .. })) => { assert_eq!(partition_location.len(), 4); } @@ -583,7 +539,8 @@ mod test { matches!( status, JobStatus { - status: Some(job_status::Status::Failed(_)) + status: Some(job_status::Status::Failed(_)), + .. } ), "Expected job status to be failed but it was {:?}", @@ -624,7 +581,8 @@ mod test { matches!( status, JobStatus { - status: Some(job_status::Status::Failed(_)) + status: Some(job_status::Status::Failed(_)), + .. } ), "Expected job status to be failed but it was {:?}", @@ -640,13 +598,12 @@ mod test { async fn test_scheduler( scheduling_policy: TaskSchedulingPolicy, ) -> Result<SchedulerServer<LogicalPlanNode, PhysicalPlanNode>> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = test_cluster_context(); + let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> = SchedulerServer::new( "localhost:50050".to_owned(), - state_storage, - cluster_state, + cluster, BallistaCodec::default(), SchedulerConfig::default().with_scheduler_policy(scheduling_policy), Arc::new(TestMetricsCollector::default()), diff --git a/ballista/scheduler/src/standalone.rs b/ballista/scheduler/src/standalone.rs index c8405347..f3dea2bc 100644 --- a/ballista/scheduler/src/standalone.rs +++ b/ballista/scheduler/src/standalone.rs @@ -15,13 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::cluster::event::ClusterEventSender; -use crate::cluster::memory::{InMemoryClusterState, InMemoryJobState}; use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::metrics::default_metrics_collector; -use crate::state::backend::cluster::DefaultClusterState; -use crate::{scheduler_server::SchedulerServer, state::backend::sled::SledClient}; +use crate::{cluster::storage::sled::SledClient, scheduler_server::SchedulerServer}; use ballista_core::serde::protobuf::PhysicalPlanNode; use ballista_core::serde::BallistaCodec; use ballista_core::utils::{create_grpc_server, default_session_builder}; @@ -31,31 +28,28 @@ use ballista_core::{ }; use datafusion_proto::protobuf::LogicalPlanNode; use log::info; -use std::{net::SocketAddr, sync::Arc}; +use std::net::SocketAddr; use tokio::net::TcpListener; pub async fn new_standalone_scheduler() -> Result<SocketAddr> { - let backend = Arc::new(SledClient::try_new_temporary()?); - let metrics_collector = default_metrics_collector()?; - let cluster = BallistaCluster::new( - Arc::new(InMemoryClusterState::default()), - Arc::new(InMemoryJobState::new( - "localhost:50050", - default_session_builder, - )), + let cluster = BallistaCluster::new_kv( + SledClient::try_new_temporary()?, + "localhost:50050", + default_session_builder, + BallistaCodec::default(), ); let mut scheduler_server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> = SchedulerServer::new( "localhost:50050".to_owned(), - backend.clone(), cluster, BallistaCodec::default(), SchedulerConfig::default(), metrics_collector, ); + scheduler_server.init().await?; let server = SchedulerGrpcServer::new(scheduler_server.clone()); // Let the OS assign a random, free port diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs index a86f82ad..dbd29226 100644 --- a/ballista/scheduler/src/state/execution_graph.rs +++ b/ballista/scheduler/src/state/execution_graph.rs @@ -1703,7 +1703,8 @@ mod test { assert!(matches!( status, protobuf::JobStatus { - status: Some(job_status::Status::Successful(_)) + status: Some(job_status::Status::Successful(_)), + .. } )); @@ -1915,7 +1916,8 @@ mod test { matches!( agg_graph.status, JobStatus { - status: Some(job_status::Status::Queued(_)) + status: Some(job_status::Status::Queued(_)), + .. } ), "Expected job status to be running" @@ -2000,7 +2002,8 @@ mod test { matches!( agg_graph.status, JobStatus { - status: Some(job_status::Status::Failed(_)) + status: Some(job_status::Status::Failed(_)), + .. } ), "Expected job status to be Failed" diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs b/ballista/scheduler/src/state/execution_graph_dot.rs index 521271e0..dd0b7b33 100644 --- a/ballista/scheduler/src/state/execution_graph_dot.rs +++ b/ballista/scheduler/src/state/execution_graph_dot.rs @@ -430,7 +430,7 @@ mod tests { #[tokio::test] async fn dot() -> Result<()> { let graph = test_graph().await?; - let dot = ExecutionGraphDot::generate(Arc::new(graph)) + let dot = ExecutionGraphDot::generate(&graph) .map_err(|e| BallistaError::Internal(format!("{:?}", e)))?; let expected = r#"digraph G { @@ -531,7 +531,7 @@ filter_expr="] #[tokio::test] async fn dot_optimized() -> Result<()> { let graph = test_graph_optimized().await?; - let dot = ExecutionGraphDot::generate(Arc::new(graph)) + let dot = ExecutionGraphDot::generate(&graph) .map_err(|e| BallistaError::Internal(format!("{:?}", e)))?; let expected = r#"digraph G { diff --git a/ballista/scheduler/src/state/executor_manager.rs b/ballista/scheduler/src/state/executor_manager.rs index 6bd18336..9155c015 100644 --- a/ballista/scheduler/src/state/executor_manager.rs +++ b/ballista/scheduler/src/state/executor_manager.rs @@ -650,15 +650,15 @@ impl ExecutorManager { #[cfg(test)] mod test { + use crate::config::SlotsPolicy; - use crate::state::backend::cluster::DefaultClusterState; - use crate::state::backend::sled::SledClient; + use crate::state::executor_manager::{ExecutorManager, ExecutorReservation}; + use crate::test_utils::test_cluster_context; use ballista_core::error::Result; use ballista_core::serde::scheduler::{ ExecutorData, ExecutorMetadata, ExecutorSpecification, }; - use std::sync::Arc; #[tokio::test] async fn test_reserve_and_cancel() -> Result<()> { @@ -670,11 +670,10 @@ mod test { } async fn test_reserve_and_cancel_inner(slots_policy: SlotsPolicy) -> Result<()> { - let cluster_state = Arc::new(DefaultClusterState::new(Arc::new( - SledClient::try_new_temporary()?, - ))); + let cluster = test_cluster_context(); - let executor_manager = ExecutorManager::new(cluster_state, slots_policy); + let executor_manager = + ExecutorManager::new(cluster.cluster_state(), slots_policy); let executors = test_executors(10, 4); @@ -720,11 +719,9 @@ mod test { } async fn test_reserve_partial_inner(slots_policy: SlotsPolicy) -> Result<()> { - let cluster_state = Arc::new(DefaultClusterState::new(Arc::new( - SledClient::try_new_temporary()?, - ))); - - let executor_manager = ExecutorManager::new(cluster_state, slots_policy); + let cluster = test_cluster_context(); + let executor_manager = + ExecutorManager::new(cluster.cluster_state(), slots_policy); let executors = test_executors(10, 4); @@ -777,11 +774,9 @@ mod test { let executors = test_executors(10, 4); - let cluster_state = Arc::new(DefaultClusterState::new(Arc::new( - SledClient::try_new_temporary()?, - ))); - - let executor_manager = ExecutorManager::new(cluster_state, slots_policy); + let cluster = test_cluster_context(); + let executor_manager = + ExecutorManager::new(cluster.cluster_state(), slots_policy); for (executor_metadata, executor_data) in executors { executor_manager @@ -824,11 +819,10 @@ mod test { } async fn test_register_reserve_inner(slots_policy: SlotsPolicy) -> Result<()> { - let cluster_state = Arc::new(DefaultClusterState::new(Arc::new( - SledClient::try_new_temporary()?, - ))); + let cluster = test_cluster_context(); - let executor_manager = ExecutorManager::new(cluster_state, slots_policy); + let executor_manager = + ExecutorManager::new(cluster.cluster_state(), slots_policy); let executors = test_executors(10, 4); diff --git a/ballista/scheduler/src/state/mod.rs b/ballista/scheduler/src/state/mod.rs index da0a36cb..342dd7aa 100644 --- a/ballista/scheduler/src/state/mod.rs +++ b/ballista/scheduler/src/state/mod.rs @@ -25,16 +25,15 @@ use std::sync::Arc; use std::time::Instant; use crate::scheduler_server::event::QueryStageSchedulerEvent; -use crate::scheduler_server::SessionBuilder; -use crate::state::backend::{Lock, StateBackendClient}; + +use crate::state::backend::Lock; use crate::state::executor_manager::{ExecutorManager, ExecutorReservation}; use crate::state::session_manager::SessionManager; use crate::state::task_manager::{TaskLauncher, TaskManager}; -use crate::cluster::{BallistaCluster, JobState}; +use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::state::execution_graph::TaskDescription; -use backend::cluster::ClusterState; use ballista_core::error::{BallistaError, Result}; use ballista_core::serde::protobuf::TaskStatus; use ballista_core::serde::{AsExecutionPlan, BallistaCodec}; @@ -100,15 +99,11 @@ pub(super) struct SchedulerState<T: 'static + AsLogicalPlan, U: 'static + AsExec impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, U> { #[cfg(test)] pub fn new_with_default_scheduler_name( - config_client: Arc<dyn StateBackendClient>, cluster: BallistaCluster, - session_builder: SessionBuilder, codec: BallistaCodec<T, U>, ) -> Self { SchedulerState::new( - config_client, - cluster_state, - session_builder, + cluster, codec, "localhost:50050".to_owned(), SchedulerConfig::default(), @@ -116,9 +111,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, } pub fn new( - config_client: Arc<dyn StateBackendClient>, cluster: BallistaCluster, - session_builder: SessionBuilder, codec: BallistaCodec<T, U>, scheduler_name: String, config: SchedulerConfig, @@ -129,23 +122,19 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, config.executor_slots_policy, ), task_manager: TaskManager::new( - config_client.clone(), cluster.job_state(), - session_builder, codec.clone(), scheduler_name, ), - session_manager: SessionManager::new(config_client, session_builder), + session_manager: SessionManager::new(cluster.job_state()), codec, config, } } #[allow(dead_code)] - pub(crate) fn with_task_launcher( - config_client: Arc<dyn StateBackendClient>, + pub(crate) fn new_with_task_launcher( cluster: BallistaCluster, - session_builder: SessionBuilder, codec: BallistaCodec<T, U>, scheduler_name: String, config: SchedulerConfig, @@ -157,14 +146,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, config.executor_slots_policy, ), task_manager: TaskManager::with_launcher( - config_client.clone(), cluster.job_state(), - session_builder, codec.clone(), scheduler_name, dispatcher, ), - session_manager: SessionManager::new(config_client, session_builder), + session_manager: SessionManager::new(cluster.job_state()), codec, config, } @@ -455,7 +442,7 @@ pub async fn with_locks<Out, F: Future<Output = Out>>( #[cfg(test)] mod test { - use crate::state::backend::sled::SledClient; + use crate::state::SchedulerState; use ballista_core::config::{BallistaConfig, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS}; use ballista_core::error::Result; @@ -466,11 +453,10 @@ mod test { ExecutorData, ExecutorMetadata, ExecutorSpecification, }; use ballista_core::serde::BallistaCodec; - use ballista_core::utils::default_session_builder; use crate::config::SchedulerConfig; - use crate::state::backend::cluster::DefaultClusterState; - use crate::test_utils::BlackholeTaskLauncher; + + use crate::test_utils::{test_cluster_context, BlackholeTaskLauncher}; use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::logical_expr::{col, sum}; use datafusion::physical_plan::ExecutionPlan; @@ -479,16 +465,14 @@ mod test { use datafusion_proto::protobuf::LogicalPlanNode; use std::sync::Arc; + const TEST_SCHEDULER_NAME: &str = "localhost:50050"; + // We should free any reservations which are not assigned #[tokio::test] async fn test_offer_free_reservations() -> Result<()> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); let state: Arc<SchedulerState<LogicalPlanNode, PhysicalPlanNode>> = Arc::new(SchedulerState::new_with_default_scheduler_name( - state_storage, - cluster_state, - default_session_builder, + test_cluster_context(), BallistaCodec::default(), )); @@ -520,15 +504,12 @@ mod test { let config = BallistaConfig::builder() .set(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, "4") .build()?; - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let state: Arc<SchedulerState<LogicalPlanNode, PhysicalPlanNode>> = - Arc::new(SchedulerState::with_task_launcher( - state_storage, - cluster_state, - default_session_builder, + Arc::new(SchedulerState::new_with_task_launcher( + test_cluster_context(), BallistaCodec::default(), - String::default(), + TEST_SCHEDULER_NAME.into(), SchedulerConfig::default(), Arc::new(BlackholeTaskLauncher::default()), )); @@ -607,15 +588,12 @@ mod test { let config = BallistaConfig::builder() .set(BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, "4") .build()?; - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let state: Arc<SchedulerState<LogicalPlanNode, PhysicalPlanNode>> = - Arc::new(SchedulerState::with_task_launcher( - state_storage, - cluster_state, - default_session_builder, + Arc::new(SchedulerState::new_with_task_launcher( + test_cluster_context(), BallistaCodec::default(), - String::default(), + TEST_SCHEDULER_NAME.into(), SchedulerConfig::default(), Arc::new(BlackholeTaskLauncher::default()), )); diff --git a/ballista/scheduler/src/state/session_manager.rs b/ballista/scheduler/src/state/session_manager.rs index eb7df9f2..a860aa6e 100644 --- a/ballista/scheduler/src/state/session_manager.rs +++ b/ballista/scheduler/src/state/session_manager.rs @@ -16,32 +16,23 @@ // under the License. use crate::scheduler_server::SessionBuilder; -use crate::state::backend::{Keyspace, StateBackendClient}; -use crate::state::{decode_protobuf, encode_protobuf}; use ballista_core::config::BallistaConfig; use ballista_core::error::Result; -use ballista_core::serde::protobuf::{self, KeyValuePair}; use datafusion::prelude::{SessionConfig, SessionContext}; +use crate::cluster::JobState; use datafusion::common::ScalarValue; use log::warn; use std::sync::Arc; #[derive(Clone)] pub struct SessionManager { - state: Arc<dyn StateBackendClient>, - session_builder: SessionBuilder, + state: Arc<dyn JobState>, } impl SessionManager { - pub fn new( - state: Arc<dyn StateBackendClient>, - session_builder: SessionBuilder, - ) -> Self { - Self { - state, - session_builder, - } + pub fn new(state: Arc<dyn JobState>) -> Self { + Self { state } } pub async fn update_session( @@ -49,64 +40,18 @@ impl SessionManager { session_id: &str, config: &BallistaConfig, ) -> Result<Arc<SessionContext>> { - let mut settings: Vec<KeyValuePair> = vec![]; - - for (key, value) in config.settings() { - settings.push(KeyValuePair { - key: key.clone(), - value: value.clone(), - }) - } - - let value = encode_protobuf(&protobuf::SessionSettings { configs: settings })?; - self.state - .put(Keyspace::Sessions, session_id.to_owned(), value) - .await?; - - Ok(create_datafusion_context(config, self.session_builder)) + self.state.update_session(session_id, config).await } pub async fn create_session( &self, config: &BallistaConfig, ) -> Result<Arc<SessionContext>> { - let mut settings: Vec<KeyValuePair> = vec![]; - - for (key, value) in config.settings() { - settings.push(KeyValuePair { - key: key.clone(), - value: value.clone(), - }) - } - - let mut config_builder = BallistaConfig::builder(); - for kv_pair in &settings { - config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); - } - let config = config_builder.build()?; - - let ctx = create_datafusion_context(&config, self.session_builder); - - let value = encode_protobuf(&protobuf::SessionSettings { configs: settings })?; - self.state - .put(Keyspace::Sessions, ctx.session_id(), value) - .await?; - - Ok(ctx) + self.state.create_session(config).await } pub async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>> { - let value = self.state.get(Keyspace::Sessions, session_id).await?; - - let settings: protobuf::SessionSettings = decode_protobuf(&value)?; - - let mut config_builder = BallistaConfig::builder(); - for kv_pair in &settings.configs { - config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); - } - let config = config_builder.build()?; - - Ok(create_datafusion_context(&config, self.session_builder)) + self.state.get_session(session_id).await } } diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs index 25dcdaea..ebc58c5f 100644 --- a/ballista/scheduler/src/state/task_manager.rs +++ b/ballista/scheduler/src/state/task_manager.rs @@ -16,31 +16,25 @@ // under the License. use crate::scheduler_server::event::QueryStageSchedulerEvent; -use crate::scheduler_server::SessionBuilder; -use crate::state::backend::{Keyspace, Operation, StateBackendClient}; + use crate::state::execution_graph::{ ExecutionGraph, ExecutionStage, RunningTaskInfo, TaskDescription, }; use crate::state::executor_manager::{ExecutorManager, ExecutorReservation}; -use crate::state::{decode_protobuf, encode_protobuf, with_lock, with_locks}; -use ballista_core::config::BallistaConfig; + use ballista_core::error::BallistaError; use ballista_core::error::Result; -use crate::state::backend::Keyspace::{CompletedJobs, FailedJobs}; -use crate::state::session_manager::create_datafusion_context; - use crate::cluster::JobState; use ballista_core::serde::protobuf::{ - self, job_status, FailedJob, JobStatus, MultiTaskDefinition, TaskDefinition, TaskId, - TaskStatus, + self, JobStatus, MultiTaskDefinition, TaskDefinition, TaskId, TaskStatus, }; use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto; use ballista_core::serde::scheduler::ExecutorMetadata; use ballista_core::serde::{AsExecutionPlan, BallistaCodec}; use dashmap::DashMap; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::SessionContext; + use datafusion_proto::logical_plan::AsLogicalPlan; use log::{debug, error, info, warn}; use rand::distributions::Alphanumeric; @@ -51,7 +45,7 @@ use std::sync::Arc; use std::time::Duration; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; -use tokio::sync::RwLockReadGuard; + use tracing::trace; type ActiveJobCache = Arc<DashMap<String, JobInfoCache>>; @@ -110,9 +104,7 @@ impl TaskLauncher for DefaultTaskLauncher { #[derive(Clone)] pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> { - state: Arc<dyn StateBackendClient>, - job_state: Arc<dyn JobState>, - session_builder: SessionBuilder, + state: Arc<dyn JobState>, codec: BallistaCodec<T, U>, scheduler_id: String, // Cache for active jobs curated by this scheduler @@ -148,16 +140,12 @@ pub struct UpdatedStages { impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> { pub fn new( - state: Arc<dyn StateBackendClient>, - job_state: Arc<dyn JobState>, - session_builder: SessionBuilder, + state: Arc<dyn JobState>, codec: BallistaCodec<T, U>, scheduler_id: String, ) -> Self { Self { state, - job_state, - session_builder, codec, scheduler_id: scheduler_id.clone(), active_job_cache: Arc::new(DashMap::new()), @@ -167,17 +155,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> #[allow(dead_code)] pub(crate) fn with_launcher( - state: Arc<dyn StateBackendClient>, - job_state: Arc<dyn JobState>, - session_builder: SessionBuilder, + state: Arc<dyn JobState>, codec: BallistaCodec<T, U>, scheduler_id: String, launcher: Arc<dyn TaskLauncher>, ) -> Self { Self { state, - job_state, - session_builder, codec, scheduler_id, active_job_cache: Arc::new(DashMap::new()), @@ -206,9 +190,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> )?; info!("Submitting execution graph: {:?}", graph); - self.job_state - .submit_job(job_id.to_string(), &graph) - .await?; + self.state.submit_job(job_id.to_string(), &graph).await?; graph.revive(); self.active_job_cache @@ -219,7 +201,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> /// Get a list of active job ids pub async fn get_jobs(&self) -> Result<Vec<JobOverview>> { - let job_ids = self.job_state.get_jobs().await?; + let job_ids = self.state.get_jobs().await?; let mut jobs = vec![]; for job_id in &job_ids { @@ -227,7 +209,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> let graph = cached.execution_graph.read().await; jobs.push(graph.deref().into()); } else { - let graph = self.job_state + let graph = self.state .get_execution_graph(job_id) .await? .ok_or_else(|| BallistaError::Internal(format!("Error getting job overview, no execution graph found for job {job_id}")))?; @@ -240,7 +222,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> /// Get the status of of a job. First look in the active cache. /// If no one found, then in the Active/Completed jobs, and then in Failed jobs pub async fn get_job_status(&self, job_id: &str) -> Result<Option<JobStatus>> { - self.job_state.get_job_status(job_id).await + self.state.get_job_status(job_id).await } /// Get the execution graph of of a job. First look in the active cache. @@ -254,7 +236,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> Ok(Some(Arc::new(guard.deref().clone()))) } else { - let graph = self.job_state.get_execution_graph(job_id).await?; + let graph = self.state.get_execution_graph(job_id).await?; Ok(graph.map(Arc::new)) } @@ -367,7 +349,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> if let Some(graph) = self.remove_active_execution_graph(job_id).await { let graph = graph.read().await.clone(); if graph.is_successful() { - self.job_state.save_job(job_id, &graph).await?; + self.state.save_job(job_id, &graph).await?; } else { error!("Job {} has not finished and cannot be completed", job_id); return Ok(()); @@ -393,35 +375,28 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> job_id: &str, failure_reason: String, ) -> Result<(Vec<RunningTaskInfo>, usize)> { - let locks = self - .state - .acquire_locks(vec![ - (Keyspace::ActiveJobs, job_id), - (Keyspace::FailedJobs, job_id), - ]) - .await?; let (tasks_to_cancel, pending_tasks) = if let Some(graph) = self.get_active_execution_graph(job_id).await { - let (pending_tasks, running_tasks) = { - let guard = graph.read().await; - (guard.available_tasks(), guard.running_tasks()) - }; + let mut guard = graph.write().await; + + let pending_tasks = guard.available_tasks(); + let running_tasks = guard.running_tasks(); info!( "Cancelling {} running tasks for job {}", running_tasks.len(), job_id ); - with_locks(locks, self.fail_job_state(job_id, failure_reason)) - .await - .unwrap(); + + guard.fail_job(failure_reason); + + self.state.save_job(job_id, &guard).await?; (running_tasks, pending_tasks) } else { // TODO listen the job state update event and fix task cancelling warn!("Fail to find job {} in the cache, unable to cancel tasks for job, fail the job state only.", job_id); - with_locks(locks, self.fail_job_state(job_id, failure_reason)).await?; (vec![], 0) }; @@ -435,62 +410,9 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> job_id: &str, failure_reason: String, ) -> Result<()> { - debug!("Moving job {} from Active or Queue to Failed", job_id); - let locks = self - .state - .acquire_locks(vec![ - (Keyspace::ActiveJobs, job_id), - (Keyspace::FailedJobs, job_id), - ]) - .await?; - with_locks(locks, self.fail_job_state(job_id, failure_reason)).await?; - - Ok(()) - } - - async fn fail_job_state(&self, job_id: &str, failure_reason: String) -> Result<()> { - let txn_operations = |value: Vec<u8>| -> Vec<(Operation, Keyspace, String)> { - vec![ - (Operation::Delete, Keyspace::ActiveJobs, job_id.to_string()), - ( - Operation::Put(value), - Keyspace::FailedJobs, - job_id.to_string(), - ), - ] - }; - - if let Some(graph) = self.remove_active_execution_graph(job_id).await { - let mut graph = graph.write().await; - let previous_status = graph.status(); - graph.fail_job(failure_reason); - - let value = encode_protobuf(&graph.status())?; - let txn_ops = txn_operations(value); - let result = self.state.apply_txn(txn_ops).await; - if result.is_err() { - // Rollback - graph.update_status(previous_status); - warn!("Rollback Execution Graph state change since it did not persisted due to a possible connection error.") - }; - } else { - warn!( - "Fail to find job {} in the cache, not updating status to failed", - job_id - ); - // let status = JobStatus { - // job_id: job_id.to_string(), - // job_name: "".to_string(), - // status: Some(job_status::Status::Failed(FailedJob { - // error: failure_reason.clone(), - // })), - // }; - // let value = encode_protobuf(&status)?; - // let txn_ops = txn_operations(value); - // self.state.apply_txn(txn_ops).await?; - }; - - Ok(()) + self.state + .fail_unscheduled_job(job_id, failure_reason) + .await } pub async fn update_job(&self, job_id: &str) -> Result<usize> { @@ -503,7 +425,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> graph.revive(); let graph = graph.clone(); - self.job_state.save_job(job_id, &graph).await?; + self.state.save_job(job_id, &graph).await?; let new_tasks = graph.available_tasks() - curr_available_tasks; @@ -728,15 +650,15 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> Ok(guard.clone()) } else { - let graph = self - .job_state - .get_execution_graph(job_id) - .await? - .ok_or_else(|| { - BallistaError::Internal(format!( - "No ExecutionGraph found for job {job_id}" - )) - })?; + let graph = + self.state + .get_execution_graph(job_id) + .await? + .ok_or_else(|| { + BallistaError::Internal(format!( + "No ExecutionGraph found for job {job_id}" + )) + })?; Ok(graph) } @@ -759,7 +681,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> return; } - let state = self.job_state.clone(); + let state = self.state.clone(); tokio::spawn(async move { let job_id = job_id; tokio::time::sleep(Duration::from_secs(clean_up_interval)).await; diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs index 3f98e370..36704503 100644 --- a/ballista/scheduler/src/test_utils.rs +++ b/ballista/scheduler/src/test_utils.rs @@ -27,7 +27,6 @@ use async_trait::async_trait; use crate::config::SchedulerConfig; use crate::metrics::SchedulerMetricsCollector; use crate::scheduler_server::{timestamp_millis, SchedulerServer}; -use crate::state::backend::sled::SledClient; use crate::state::executor_manager::ExecutorManager; use crate::state::task_manager::TaskLauncher; @@ -50,8 +49,10 @@ use datafusion::logical_expr::{Expr, LogicalPlan}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::CsvReadOptions; +use crate::cluster::BallistaCluster; use crate::scheduler_server::event::QueryStageSchedulerEvent; -use crate::state::backend::cluster::DefaultClusterState; + +use ballista_core::utils::default_session_builder; use datafusion_proto::protobuf::LogicalPlanNode; use parking_lot::Mutex; use tokio::sync::mpsc::{channel, Receiver, Sender}; @@ -60,6 +61,8 @@ pub const TPCH_TABLES: &[&str] = &[ "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", ]; +const TEST_SCHEDULER_NAME: &str = "localhost:50050"; + /// Sometimes we need to construct logical plans that will produce errors /// when we try and create physical plan. A scan using `ExplodingTableProvider` /// will do the trick @@ -115,6 +118,10 @@ pub async fn await_condition<Fut: Future<Output = Result<bool>>, F: Fn() -> Fut> Ok(false) } +pub fn test_cluster_context() -> BallistaCluster { + BallistaCluster::new_memory(TEST_SCHEDULER_NAME, default_session_builder) +} + pub async fn datafusion_test_context(path: &str) -> Result<SessionContext> { let default_shuffle_partitions = 2; let config = SessionConfig::new().with_target_partitions(default_shuffle_partitions); @@ -381,8 +388,7 @@ impl SchedulerTest { task_slots_per_executor: usize, runner: Option<Arc<dyn TaskRunner>>, ) -> Result<Self> { - let state_storage = Arc::new(SledClient::try_new_temporary()?); - let cluster_state = Arc::new(DefaultClusterState::new(state_storage.clone())); + let cluster = BallistaCluster::new_from_config(&config).await?; let ballista_config = BallistaConfig::builder() .set( @@ -415,10 +421,9 @@ impl SchedulerTest { }; let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> = - SchedulerServer::with_task_launcher( + SchedulerServer::new_with_task_launcher( "localhost:50050".to_owned(), - state_storage, - cluster_state, + cluster, BallistaCodec::default(), config, metrics_collector, @@ -526,6 +531,7 @@ impl SchedulerTest { if let Some(JobStatus { status: Some(inner), + .. }) = status.as_ref() { match inner { @@ -581,6 +587,7 @@ impl SchedulerTest { if let Some(JobStatus { status: Some(inner), + .. }) = status.as_ref() { match inner {
