This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git
The following commit(s) were added to refs/heads/master by this push:
new caff2fd4 Atomic support for enhancement (#319)
caff2fd4 is described below
commit caff2fd413c31b456ccd73febd1de998731e0589
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Fri Oct 7 15:29:39 2022 +0300
Atomic support for enhancement (#319)
* Atomic support for enhancement
* Incorporate review feedback, simplify with_lock(s) functions
Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
ballista/rust/executor/Cargo.toml | 1 +
ballista/rust/executor/src/executor_server.rs | 20 ++--
ballista/rust/scheduler/Cargo.toml | 1 +
ballista/rust/scheduler/src/state/backend/etcd.rs | 24 +++--
ballista/rust/scheduler/src/state/backend/mod.rs | 23 ++++-
.../rust/scheduler/src/state/backend/standalone.rs | 48 ++++++++--
.../rust/scheduler/src/state/executor_manager.rs | 89 ++++++++----------
ballista/rust/scheduler/src/state/mod.rs | 18 +++-
.../rust/scheduler/src/state/session_registry.rs | 21 +++--
ballista/rust/scheduler/src/state/task_manager.rs | 101 ++++++++++++---------
10 files changed, 211 insertions(+), 135 deletions(-)
diff --git a/ballista/rust/executor/Cargo.toml
b/ballista/rust/executor/Cargo.toml
index 253052fd..ca3532ac 100644
--- a/ballista/rust/executor/Cargo.toml
+++ b/ballista/rust/executor/Cargo.toml
@@ -40,6 +40,7 @@ async-trait = "0.1.41"
ballista-core = { path = "../core", version = "0.8.0" }
chrono = { version = "0.4", default-features = false }
configure_me = "0.4.0"
+dashmap = "5.4.0"
datafusion = { git = "https://github.com/apache/arrow-datafusion", rev =
"06a4f79f02fcb6ea85303925b7c5a9b0231e3fee" }
datafusion-proto = { git = "https://github.com/apache/arrow-datafusion", rev =
"06a4f79f02fcb6ea85303925b7c5a9b0231e3fee" }
futures = "0.3"
diff --git a/ballista/rust/executor/src/executor_server.rs
b/ballista/rust/executor/src/executor_server.rs
index 2bc84f72..bf036d6f 100644
--- a/ballista/rust/executor/src/executor_server.rs
+++ b/ballista/rust/executor/src/executor_server.rs
@@ -21,7 +21,7 @@ use std::convert::TryInto;
use std::ops::Deref;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
-use tokio::sync::{mpsc, RwLock};
+use tokio::sync::mpsc;
use log::{debug, error, info, warn};
use tonic::transport::Channel;
@@ -45,6 +45,7 @@ use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
use ballista_core::utils::{
collect_plan_metrics, create_grpc_client_connection, create_grpc_server,
};
+use dashmap::DashMap;
use datafusion::execution::context::TaskContext;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -57,7 +58,7 @@ use crate::shutdown::ShutdownNotifier;
use crate::{as_task_status, TaskExecutionTimes};
type ServerHandle = JoinHandle<Result<(), BallistaError>>;
-type SchedulerClients = Arc<RwLock<HashMap<String,
SchedulerGrpcClient<Channel>>>>;
+type SchedulerClients = Arc<DashMap<String, SchedulerGrpcClient<Channel>>>;
/// Wrap TaskDefinition with its curator scheduler id for task update to its
specific curator scheduler later
#[derive(Debug)]
@@ -216,10 +217,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> ExecutorServer<T,
&self,
scheduler_id: &str,
) -> Result<SchedulerGrpcClient<Channel>, BallistaError> {
- let scheduler = {
- let schedulers = self.schedulers.read().await;
- schedulers.get(scheduler_id).cloned()
- };
+ let scheduler = self.schedulers.get(scheduler_id).map(|value|
value.clone());
// If channel does not exist, create a new one
if let Some(scheduler) = scheduler {
Ok(scheduler)
@@ -229,8 +227,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> ExecutorServer<T,
let scheduler = SchedulerGrpcClient::new(connection);
{
- let mut schedulers = self.schedulers.write().await;
- schedulers.insert(scheduler_id.to_owned(), scheduler.clone());
+ self.schedulers
+ .insert(scheduler_id.to_owned(), scheduler.clone());
}
Ok(scheduler)
@@ -263,8 +261,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> ExecutorServer<T,
}
};
- let schedulers = self.schedulers.read().await.clone();
- for (scheduler_id, mut scheduler) in schedulers {
+ for mut item in self.schedulers.iter_mut() {
+ let scheduler_id = item.key().clone();
+ let scheduler = item.value_mut();
+
match scheduler
.heart_beat_from_executor(heartbeat_params.clone())
.await
diff --git a/ballista/rust/scheduler/Cargo.toml
b/ballista/rust/scheduler/Cargo.toml
index 954e5a48..90e3f020 100644
--- a/ballista/rust/scheduler/Cargo.toml
+++ b/ballista/rust/scheduler/Cargo.toml
@@ -45,6 +45,7 @@ ballista-core = { path = "../core", version = "0.8.0" }
base64 = { version = "0.13", default-features = false }
clap = { version = "3", features = ["derive", "cargo"] }
configure_me = "0.4.0"
+dashmap = "5.4.0"
datafusion = { git = "https://github.com/apache/arrow-datafusion", rev =
"06a4f79f02fcb6ea85303925b7c5a9b0231e3fee" }
datafusion-proto = { git = "https://github.com/apache/arrow-datafusion", rev =
"06a4f79f02fcb6ea85303925b7c5a9b0231e3fee" }
etcd-client = { version = "0.9", optional = true }
diff --git a/ballista/rust/scheduler/src/state/backend/etcd.rs
b/ballista/rust/scheduler/src/state/backend/etcd.rs
index 4b24b7aa..e753df66 100644
--- a/ballista/rust/scheduler/src/state/backend/etcd.rs
+++ b/ballista/rust/scheduler/src/state/backend/etcd.rs
@@ -30,7 +30,9 @@ use etcd_client::{
use futures::{Stream, StreamExt};
use log::{debug, error, warn};
-use crate::state::backend::{Keyspace, Lock, StateBackendClient, Watch,
WatchEvent};
+use crate::state::backend::{
+ Keyspace, Lock, Operation, StateBackendClient, Watch, WatchEvent,
+};
/// A [`StateBackendClient`] implementation that uses etcd to save cluster
configuration.
#[derive(Clone)]
@@ -137,29 +139,31 @@ impl StateBackendClient for EtcdClient {
.await
.map_err(|e| {
warn!("etcd put failed: {}", e);
- ballista_error("etcd put failed")
+ ballista_error(&*format!("etcd put failed: {}", e))
})
.map(|_| ())
}
- async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec<u8>)>) ->
Result<()> {
+ /// Apply multiple operations in a single transaction.
+ async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) ->
Result<()> {
let mut etcd = self.etcd.clone();
let txn_ops: Vec<TxnOp> = ops
.into_iter()
- .map(|(ks, key, value)| {
+ .map(|(operation, ks, key)| {
let key = format!("/{}/{:?}/{}", self.namespace, ks, key);
- TxnOp::put(key, value, None)
+ match operation {
+ Operation::Put(value) => TxnOp::put(key, value, None),
+ Operation::Delete => TxnOp::delete(key, None),
+ }
})
.collect();
- let txn = Txn::new().and_then(txn_ops);
-
- etcd.txn(txn)
+ etcd.txn(Txn::new().and_then(txn_ops))
.await
.map_err(|e| {
- error!("etcd put failed: {}", e);
- ballista_error("etcd transaction put failed")
+ error!("etcd operation failed: {}", e);
+ ballista_error(&*format!("etcd operation failed: {}", e))
})
.map(|_| ())
}
diff --git a/ballista/rust/scheduler/src/state/backend/mod.rs
b/ballista/rust/scheduler/src/state/backend/mod.rs
index b69403b2..85c0f34d 100644
--- a/ballista/rust/scheduler/src/state/backend/mod.rs
+++ b/ballista/rust/scheduler/src/state/backend/mod.rs
@@ -17,7 +17,7 @@
use ballista_core::error::Result;
use clap::ArgEnum;
-use futures::Stream;
+use futures::{future, Stream};
use std::collections::HashSet;
use std::fmt;
use tokio::sync::OwnedMutexGuard;
@@ -60,6 +60,12 @@ pub enum Keyspace {
Heartbeats,
}
+#[derive(Debug, Eq, PartialEq, Hash)]
+pub enum Operation {
+ Put(Vec<u8>),
+ Delete,
+}
+
/// A trait that contains the necessary methods to save and retrieve the state
and configuration of a cluster.
#[tonic::async_trait]
pub trait StateBackendClient: Send + Sync {
@@ -90,8 +96,19 @@ pub trait StateBackendClient: Send + Sync {
/// Saves the value into the provided key, overriding any previous data
that might have been associated to that key.
async fn put(&self, keyspace: Keyspace, key: String, value: Vec<u8>) ->
Result<()>;
- /// Save multiple values in a single transaction. Either all values should
be saved, or all should fail
- async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec<u8>)>) ->
Result<()>;
+ /// Bundle multiple operations into a single transaction. Either all operations should be applied, or all should fail.
+ /// It supports multiple operation types and keyspaces. If the operations span more than one unique keyspace,
+ /// more than one lock has to be acquired.
+ async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) ->
Result<()>;
+ /// Acquire locks for the specified (keyspace, key) IDs.
+ async fn acquire_locks(
+ &self,
+ mut ids: Vec<(Keyspace, &str)>,
+ ) -> Result<Vec<Box<dyn Lock>>> {
+ // We always acquire locks in a specific order to avoid deadlocks.
+ ids.sort_by_key(|n| format!("/{:?}/{}", n.0, n.1));
+ future::try_join_all(ids.into_iter().map(|(ks, key)| self.lock(ks,
key))).await
+ }
/// Atomically move the given key from one keyspace to another
async fn mv(
diff --git a/ballista/rust/scheduler/src/state/backend/standalone.rs
b/ballista/rust/scheduler/src/state/backend/standalone.rs
index 4e5dc063..57bf7470 100644
--- a/ballista/rust/scheduler/src/state/backend/standalone.rs
+++ b/ballista/rust/scheduler/src/state/backend/standalone.rs
@@ -25,7 +25,9 @@ use log::warn;
use sled_package as sled;
use tokio::sync::Mutex;
-use crate::state::backend::{Keyspace, Lock, StateBackendClient, Watch,
WatchEvent};
+use crate::state::backend::{
+ Keyspace, Lock, Operation, StateBackendClient, Watch, WatchEvent,
+};
/// A [`StateBackendClient`] implementation that uses file-based storage to
save cluster configuration.
#[derive(Clone)]
@@ -162,17 +164,20 @@ impl StateBackendClient for StandaloneClient {
.map(|_| ())
}
- async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec<u8>)>) ->
Result<()> {
+ async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) ->
Result<()> {
let mut batch = sled::Batch::default();
- for (ks, key, value) in ops {
- let key = format!("/{:?}/{}", ks, key);
- batch.insert(key.as_str(), value);
+ for (op, keyspace, key_str) in ops {
+ let key = format!("/{:?}/{}", &keyspace, key_str);
+ match op {
+ Operation::Put(value) => batch.insert(key.as_str(), value),
+ Operation::Delete => batch.remove(key.as_str()),
+ }
}
self.db.apply_batch(batch).map_err(|e| {
warn!("sled transaction insert failed: {}", e);
- ballista_error("sled insert failed")
+ ballista_error("sled operations failed")
})
}
@@ -279,7 +284,8 @@ impl Stream for SledWatch {
mod tests {
use super::{StandaloneClient, StateBackendClient, Watch, WatchEvent};
- use crate::state::backend::Keyspace;
+ use crate::state::backend::{Keyspace, Operation};
+ use crate::state::with_locks;
use futures::StreamExt;
use std::result::Result;
@@ -299,6 +305,34 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn multiple_operation() -> Result<(), Box<dyn std::error::Error>> {
+ let client = create_instance()?;
+ let key = "key".to_string();
+ let value = "value".as_bytes().to_vec();
+ let locks = client
+ .acquire_locks(vec![(Keyspace::ActiveJobs, ""), (Keyspace::Slots,
"")])
+ .await?;
+
+ let _r: ballista_core::error::Result<()> = with_locks(locks, async {
+ let txn_ops = vec![
+ (Operation::Put(value.clone()), Keyspace::Slots, key.clone()),
+ (
+ Operation::Put(value.clone()),
+ Keyspace::ActiveJobs,
+ key.clone(),
+ ),
+ ];
+ client.apply_txn(txn_ops).await?;
+ Ok(())
+ })
+ .await;
+
+ assert_eq!(client.get(Keyspace::Slots, key.as_str()).await?, value);
+ assert_eq!(client.get(Keyspace::ActiveJobs, key.as_str()).await?,
value);
+ Ok(())
+ }
+
#[tokio::test]
async fn read_empty() -> Result<(), Box<dyn std::error::Error>> {
let client = create_instance()?;
diff --git a/ballista/rust/scheduler/src/state/executor_manager.rs
b/ballista/rust/scheduler/src/state/executor_manager.rs
index 1d135ef8..322e5a7b 100644
--- a/ballista/rust/scheduler/src/state/executor_manager.rs
+++ b/ballista/rust/scheduler/src/state/executor_manager.rs
@@ -17,7 +17,7 @@
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
-use crate::state::backend::{Keyspace, StateBackendClient, WatchEvent};
+use crate::state::backend::{Keyspace, Operation, StateBackendClient,
WatchEvent};
use crate::state::{decode_into, decode_protobuf, encode_protobuf, with_lock};
use ballista_core::error::{BallistaError, Result};
@@ -30,14 +30,14 @@ use ballista_core::serde::protobuf::{
};
use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
use ballista_core::utils::create_grpc_client_connection;
+use dashmap::{DashMap, DashSet};
use futures::StreamExt;
use log::{debug, error, info};
-use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tonic::transport::Channel;
-type ExecutorClients = Arc<RwLock<HashMap<String,
ExecutorGrpcClient<Channel>>>>;
+type ExecutorClients = Arc<DashMap<String, ExecutorGrpcClient<Channel>>>;
/// Represents a task slot that is reserved (i.e. available for scheduling but
not visible to the
/// rest of the system).
@@ -85,11 +85,11 @@ pub const DEFAULT_EXECUTOR_TIMEOUT_SECONDS: u64 = 180;
pub(crate) struct ExecutorManager {
state: Arc<dyn StateBackendClient>,
// executor_id -> ExecutorMetadata map
- executor_metadata: Arc<RwLock<HashMap<String, ExecutorMetadata>>>,
+ executor_metadata: Arc<DashMap<String, ExecutorMetadata>>,
// executor_id -> ExecutorHeartbeat map
- executors_heartbeat: Arc<RwLock<HashMap<String,
protobuf::ExecutorHeartbeat>>>,
+ executors_heartbeat: Arc<DashMap<String, protobuf::ExecutorHeartbeat>>,
// dead executor sets:
- dead_executors: Arc<RwLock<HashSet<String>>>,
+ dead_executors: Arc<DashSet<String>>,
clients: ExecutorClients,
}
@@ -97,9 +97,9 @@ impl ExecutorManager {
pub(crate) fn new(state: Arc<dyn StateBackendClient>) -> Self {
Self {
state,
- executor_metadata: Arc::new(RwLock::new(HashMap::new())),
- executors_heartbeat: Arc::new(RwLock::new(HashMap::new())),
- dead_executors: Arc::new(RwLock::new(HashSet::new())),
+ executor_metadata: Arc::new(DashMap::new()),
+ executors_heartbeat: Arc::new(DashMap::new()),
+ dead_executors: Arc::new(DashSet::new()),
clients: Default::default(),
}
}
@@ -130,7 +130,7 @@ impl ExecutorManager {
let alive_executors = self.get_alive_executors_within_one_minute();
- let mut txn_ops: Vec<(Keyspace, String, Vec<u8>)> = vec![];
+ let mut txn_ops: Vec<(Operation, Keyspace, String)> = vec![];
for executor_id in alive_executors {
let value = self.state.get(Keyspace::Slots,
&executor_id).await?;
@@ -146,14 +146,14 @@ impl ExecutorManager {
let proto: protobuf::ExecutorData = data.into();
let new_data = encode_protobuf(&proto)?;
- txn_ops.push((Keyspace::Slots, executor_id, new_data));
+ txn_ops.push((Operation::Put(new_data), Keyspace::Slots,
executor_id));
if desired == 0 {
break;
}
}
- self.state.put_txn(txn_ops).await?;
+ self.state.apply_txn(txn_ops).await?;
let elapsed = start.elapsed();
info!(
@@ -195,16 +195,16 @@ impl ExecutorManager {
}
}
- let txn_ops: Vec<(Keyspace, String, Vec<u8>)> = executor_slots
+ let txn_ops: Vec<(Operation, Keyspace, String)> = executor_slots
.into_iter()
.map(|(executor_id, data)| {
let proto: protobuf::ExecutorData = data.into();
let new_data = encode_protobuf(&proto)?;
- Ok((Keyspace::Slots, executor_id, new_data))
+ Ok((Operation::Put(new_data), Keyspace::Slots,
executor_id))
})
.collect::<Result<Vec<_>>>()?;
- self.state.put_txn(txn_ops).await?;
+ self.state.apply_txn(txn_ops).await?;
let elapsed = start.elapsed();
info!(
@@ -262,10 +262,7 @@ impl ExecutorManager {
&self,
executor_id: &str,
) -> Result<ExecutorGrpcClient<Channel>> {
- let client = {
- let clients = self.clients.read();
- clients.get(executor_id).cloned()
- };
+ let client = self.clients.get(executor_id).map(|value| value.clone());
if let Some(client) = client {
Ok(client)
@@ -279,8 +276,7 @@ impl ExecutorManager {
let client = ExecutorGrpcClient::new(connection);
{
- let mut clients = self.clients.write();
- clients.insert(executor_id.to_owned(), client.clone());
+ self.clients.insert(executor_id.to_owned(), client.clone());
}
Ok(client)
}
@@ -289,11 +285,10 @@ impl ExecutorManager {
/// Get a list of all executors along with the timestamp of their last
recorded heartbeat
pub async fn get_executor_state(&self) -> Result<Vec<(ExecutorMetadata,
Duration)>> {
let heartbeat_timestamps: Vec<(String, u64)> = {
- let heartbeats = self.executors_heartbeat.read();
-
- heartbeats
+ self.executors_heartbeat
.iter()
- .map(|(executor_id, heartbeat)| {
+ .map(|item| {
+ let (executor_id, heartbeat) = item.pair();
(executor_id.clone(), heartbeat.timestamp)
})
.collect()
@@ -316,8 +311,7 @@ impl ExecutorManager {
executor_id: &str,
) -> Result<ExecutorMetadata> {
{
- let metadata_cache = self.executor_metadata.read();
- if let Some(cached) = metadata_cache.get(executor_id) {
+ if let Some(cached) = self.executor_metadata.get(executor_id) {
return Ok(cached.clone());
}
}
@@ -468,8 +462,8 @@ impl ExecutorManager {
.put(Keyspace::Heartbeats, executor_id, value)
.await?;
- let mut executors_heartbeat = self.executors_heartbeat.write();
- executors_heartbeat.insert(heartbeat.executor_id.clone(), heartbeat);
+ self.executors_heartbeat
+ .insert(heartbeat.executor_id.clone(), heartbeat);
Ok(())
}
@@ -484,22 +478,19 @@ impl ExecutorManager {
.put(Keyspace::Heartbeats, executor_id.clone(), value)
.await?;
- let mut executors_heartbeat = self.executors_heartbeat.write();
- executors_heartbeat.remove(&heartbeat.executor_id.clone());
-
- let mut dead_executors = self.dead_executors.write();
- dead_executors.insert(executor_id);
+ self.executors_heartbeat
+ .remove(&heartbeat.executor_id.clone());
+ self.dead_executors.insert(executor_id);
Ok(())
}
pub(crate) fn is_dead_executor(&self, executor_id: &str) -> bool {
- self.dead_executors.read().contains(executor_id)
+ self.dead_executors.contains(executor_id)
}
/// Initialize the set of active executor heartbeats from storage
async fn init_active_executor_heartbeats(&self) -> Result<()> {
let heartbeats = self.state.scan(Keyspace::Heartbeats, None).await?;
- let mut cache = self.executors_heartbeat.write();
for (_, value) in heartbeats {
let data: protobuf::ExecutorHeartbeat = decode_protobuf(&value)?;
@@ -508,7 +499,7 @@ impl ExecutorManager {
status: Some(executor_status::Status::Active(_)),
}) = data.status
{
- cache.insert(executor_id, data);
+ self.executors_heartbeat.insert(executor_id, data);
}
}
Ok(())
@@ -520,10 +511,10 @@ impl ExecutorManager {
&self,
last_seen_ts_threshold: u64,
) -> HashSet<String> {
- let executors_heartbeat = self.executors_heartbeat.read();
- executors_heartbeat
+ self.executors_heartbeat
.iter()
- .filter_map(|(exec, heartbeat)| {
+ .filter_map(|pair| {
+ let (exec, heartbeat) = pair.pair();
(heartbeat.timestamp > last_seen_ts_threshold).then(||
exec.clone())
})
.collect()
@@ -539,10 +530,11 @@ impl ExecutorManager {
.unwrap_or_else(|| Duration::from_secs(0))
.as_secs();
- let lock = self.executors_heartbeat.read();
- let expired_executors = lock
+ let expired_executors = self
+ .executors_heartbeat
.iter()
- .filter_map(|(_exec, heartbeat)| {
+ .filter_map(|pair| {
+ let (_exec, heartbeat) = pair.pair();
(heartbeat.timestamp <= last_seen_threshold).then(||
heartbeat.clone())
})
.collect::<Vec<_>>();
@@ -565,15 +557,15 @@ impl ExecutorManager {
/// and maintain an in-memory copy of the executor heartbeats.
struct ExecutorHeartbeatListener {
state: Arc<dyn StateBackendClient>,
- executors_heartbeat: Arc<RwLock<HashMap<String,
protobuf::ExecutorHeartbeat>>>,
- dead_executors: Arc<RwLock<HashSet<String>>>,
+ executors_heartbeat: Arc<DashMap<String, protobuf::ExecutorHeartbeat>>,
+ dead_executors: Arc<DashSet<String>>,
}
impl ExecutorHeartbeatListener {
pub fn new(
state: Arc<dyn StateBackendClient>,
- executors_heartbeat: Arc<RwLock<HashMap<String,
protobuf::ExecutorHeartbeat>>>,
- dead_executors: Arc<RwLock<HashSet<String>>>,
+ executors_heartbeat: Arc<DashMap<String, protobuf::ExecutorHeartbeat>>,
+ dead_executors: Arc<DashSet<String>>,
) -> Self {
Self {
state,
@@ -598,14 +590,13 @@ impl ExecutorHeartbeatListener {
decode_protobuf::<protobuf::ExecutorHeartbeat>(&value)
{
let executor_id = data.executor_id.clone();
- let mut heartbeats = heartbeats.write();
// Remove dead executors
if let Some(ExecutorStatus {
status: Some(executor_status::Status::Dead(_)),
}) = data.status
{
heartbeats.remove(&executor_id);
- dead_executors.write().insert(executor_id);
+ dead_executors.insert(executor_id);
} else {
heartbeats.insert(executor_id, data);
}
diff --git a/ballista/rust/scheduler/src/state/mod.rs
b/ballista/rust/scheduler/src/state/mod.rs
index 2168ab64..3319cb12 100644
--- a/ballista/rust/scheduler/src/state/mod.rs
+++ b/ballista/rust/scheduler/src/state/mod.rs
@@ -274,11 +274,23 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> SchedulerState<T,
}
}
-pub async fn with_lock<Out, F: Future<Output = Out>>(lock: Box<dyn Lock>, op:
F) -> Out {
- let mut lock = lock;
+pub async fn with_lock<Out, F: Future<Output = Out>>(
+ mut lock: Box<dyn Lock>,
+ op: F,
+) -> Out {
let result = op.await;
lock.unlock().await;
-
+ result
+}
+/// Takes multiple locks, awaits `op`, and then releases the locks in reverse order to prevent a race condition.
+pub async fn with_locks<Out, F: Future<Output = Out>>(
+ locks: Vec<Box<dyn Lock>>,
+ op: F,
+) -> Out {
+ let result = op.await;
+ for mut lock in locks.into_iter().rev() {
+ lock.unlock().await;
+ }
result
}
diff --git a/ballista/rust/scheduler/src/state/session_registry.rs
b/ballista/rust/scheduler/src/state/session_registry.rs
index 1281449b..b6f214e5 100644
--- a/ballista/rust/scheduler/src/state/session_registry.rs
+++ b/ballista/rust/scheduler/src/state/session_registry.rs
@@ -15,15 +15,14 @@
// specific language governing permissions and limitations
// under the License.
+use dashmap::DashMap;
use datafusion::prelude::SessionContext;
-use std::collections::HashMap;
use std::sync::Arc;
-use tokio::sync::RwLock;
/// A Registry holds all the datafusion session contexts
pub struct SessionContextRegistry {
/// A map from session_id to SessionContext
- pub running_sessions: RwLock<HashMap<String, Arc<SessionContext>>>,
+ pub running_sessions: DashMap<String, Arc<SessionContext>>,
}
impl Default for SessionContextRegistry {
@@ -37,7 +36,7 @@ impl SessionContextRegistry {
/// ['LocalFileSystem'] store is registered in by default to support read
local files natively.
pub fn new() -> Self {
Self {
- running_sessions: RwLock::new(HashMap::new()),
+ running_sessions: DashMap::new(),
}
}
@@ -47,14 +46,14 @@ impl SessionContextRegistry {
session_ctx: Arc<SessionContext>,
) -> Option<Arc<SessionContext>> {
let session_id = session_ctx.session_id();
- let mut sessions = self.running_sessions.write().await;
- sessions.insert(session_id, session_ctx)
+ self.running_sessions.insert(session_id, session_ctx)
}
/// Lookup the session context registered
pub async fn lookup_session(&self, session_id: &str) ->
Option<Arc<SessionContext>> {
- let sessions = self.running_sessions.read().await;
- sessions.get(session_id).cloned()
+ self.running_sessions
+ .get(session_id)
+ .map(|value| value.clone())
}
/// Remove a session from this registry.
@@ -62,7 +61,9 @@ impl SessionContextRegistry {
&self,
session_id: &str,
) -> Option<Arc<SessionContext>> {
- let mut sessions = self.running_sessions.write().await;
- sessions.remove(session_id)
+ match self.running_sessions.remove(session_id) {
+ None => None,
+ Some(value) => Some(value.1),
+ }
}
}
diff --git a/ballista/rust/scheduler/src/state/task_manager.rs
b/ballista/rust/scheduler/src/state/task_manager.rs
index 810d0f37..acafbcd8 100644
--- a/ballista/rust/scheduler/src/state/task_manager.rs
+++ b/ballista/rust/scheduler/src/state/task_manager.rs
@@ -17,12 +17,12 @@
use crate::scheduler_server::event::QueryStageSchedulerEvent;
use crate::scheduler_server::SessionBuilder;
-use crate::state::backend::{Keyspace, Lock, StateBackendClient};
+use crate::state::backend::{Keyspace, Operation, StateBackendClient};
use crate::state::execution_graph::{
ExecutionGraph, ExecutionStage, RunningTaskInfo, TaskDescription,
};
use crate::state::executor_manager::{ExecutorManager, ExecutorReservation};
-use crate::state::{decode_protobuf, encode_protobuf, with_lock};
+use crate::state::{decode_protobuf, encode_protobuf, with_lock, with_locks};
use ballista_core::config::BallistaConfig;
#[cfg(not(test))]
use ballista_core::error::BallistaError;
@@ -35,6 +35,7 @@ use ballista_core::serde::protobuf::{
use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto;
use ballista_core::serde::scheduler::ExecutorMetadata;
use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
+use dashmap::DashMap;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -45,8 +46,7 @@ use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
-
-type ExecutionGraphCache = Arc<RwLock<HashMap<String,
Arc<RwLock<ExecutionGraph>>>>>;
+type ExecutionGraphCache = Arc<DashMap<String, Arc<RwLock<ExecutionGraph>>>>;
// TODO move to configuration file
/// Default max failure attempts for task level retry
@@ -85,7 +85,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
session_builder,
codec,
scheduler_id,
- active_job_cache: Arc::new(RwLock::new(HashMap::new())),
+ active_job_cache: Arc::new(DashMap::new()),
}
}
@@ -111,9 +111,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
.await?;
graph.revive();
-
- let mut active_graph_cache = self.active_job_cache.write().await;
- active_graph_cache.insert(job_id.to_owned(),
Arc::new(RwLock::new(graph)));
+ self.active_job_cache
+ .insert(job_id.to_owned(), Arc::new(RwLock::new(graph)));
Ok(())
}
@@ -262,8 +261,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
let mut assignments: Vec<(String, TaskDescription)> = vec![];
let mut pending_tasks = 0usize;
let mut assign_tasks = 0usize;
- let job_cache = self.active_job_cache.read().await;
- for (_job_id, graph) in job_cache.iter() {
+ for pairs in self.active_job_cache.iter() {
+ let (_job_id, graph) = pairs.pair();
let mut graph = graph.write().await;
for reservation in free_reservations.iter().skip(assign_tasks) {
if let Some(task) =
graph.pop_next_task(&reservation.executor_id)? {
@@ -321,7 +320,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
job_id: &str,
failure_reason: String,
) -> Result<Vec<RunningTaskInfo>> {
- let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
+ let locks = self
+ .state
+ .acquire_locks(vec![
+ (Keyspace::ActiveJobs, job_id),
+ (Keyspace::FailedJobs, job_id),
+ ])
+ .await?;
if let Some(graph) = self.get_active_execution_graph(job_id).await {
let running_tasks = graph.read().await.running_tasks();
info!(
@@ -329,12 +334,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
running_tasks.len(),
job_id
);
- self.fail_job_state(lock, job_id, failure_reason).await?;
+ with_locks(locks, self.fail_job_state(job_id,
failure_reason)).await?;
Ok(running_tasks)
} else {
// TODO listen the job state update event and fix task cancelling
warn!("Fail to find job {} in the cache, unable to cancel tasks
for job, fail the job state only.", job_id);
- self.fail_job_state(lock, job_id, failure_reason).await?;
+ with_locks(locks, self.fail_job_state(job_id,
failure_reason)).await?;
Ok(vec![])
}
}
@@ -348,44 +353,55 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
failure_reason: String,
) -> Result<()> {
debug!("Moving job {} from Active or Queue to Failed", job_id);
- let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
- self.fail_job_state(lock, job_id, failure_reason).await
+ let locks = self
+ .state
+ .acquire_locks(vec![
+ (Keyspace::ActiveJobs, job_id),
+ (Keyspace::FailedJobs, job_id),
+ ])
+ .await?;
+ with_locks(locks, self.fail_job_state(job_id, failure_reason)).await
}
- async fn fail_job_state(
- &self,
- lock: Box<dyn Lock>,
- job_id: &str,
- failure_reason: String,
- ) -> Result<()> {
- with_lock(lock, self.state.delete(Keyspace::ActiveJobs,
job_id)).await?;
+ async fn fail_job_state(&self, job_id: &str, failure_reason: String) ->
Result<()> {
+ let txn_operations = |value: Vec<u8>| -> Vec<(Operation, Keyspace,
String)> {
+ vec![
+ (Operation::Delete, Keyspace::ActiveJobs, job_id.to_string()),
+ (
+ Operation::Put(value),
+ Keyspace::FailedJobs,
+ job_id.to_string(),
+ ),
+ ]
+ };
- let value = if let Some(graph) =
self.get_active_execution_graph(job_id).await {
+ let _res = if let Some(graph) =
self.get_active_execution_graph(job_id).await {
let mut graph = graph.write().await;
- for stage_id in graph.running_stages() {
- graph.fail_stage(stage_id, failure_reason.clone());
- }
+ let previous_status = graph.status();
graph.fail_job(failure_reason);
- let graph = graph.clone();
- self.encode_execution_graph(graph)?
+ let value = self.encode_execution_graph(graph.clone())?;
+ let txn_ops = txn_operations(value);
+ let result = self.state.apply_txn(txn_ops).await;
+ if result.is_err() {
+ // Rollback
+ graph.update_status(previous_status);
+ warn!("Rollback Execution Graph state change since it did not
persisted due to a possible connection error.")
+ };
+ result
} else {
warn!("Fail to find job {} in the cache", job_id);
-
let status = JobStatus {
status: Some(job_status::Status::Failed(FailedJob {
error: failure_reason.clone(),
})),
};
- encode_protobuf(&status)?
+ let value = encode_protobuf(&status)?;
+ let txn_ops = txn_operations(value);
+ self.state.apply_txn(txn_ops).await
};
- self.state
- .put(Keyspace::FailedJobs, job_id.to_owned(), value)
- .await?;
-
Ok(())
}
-
pub async fn update_job(&self, job_id: &str) -> Result<()> {
debug!("Update job {} in Active", job_id);
if let Some(graph) = self.get_active_execution_graph(job_id).await {
@@ -408,10 +424,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
// Collect all the running task need to cancel when there are running
stages rolled back.
let mut running_tasks_to_cancel: Vec<RunningTaskInfo> = vec![];
// Collect graphs we update so we can update them in storage
- let mut updated_graphs: HashMap<String, ExecutionGraph> =
HashMap::new();
+ let updated_graphs: DashMap<String, ExecutionGraph> = DashMap::new();
{
- let job_cache = self.active_job_cache.read().await;
- for (job_id, graph) in job_cache.iter() {
+ for pairs in self.active_job_cache.iter() {
+ let (job_id, graph) = pairs.pair();
let mut graph = graph.write().await;
let reset = graph.reset_stages_on_lost_executor(executor_id)?;
if !reset.0.is_empty() {
@@ -424,14 +440,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
with_lock(lock, async {
// Transactional update graphs
- let txn_ops: Vec<(Keyspace, String, Vec<u8>)> = updated_graphs
+ let txn_ops: Vec<(Operation, Keyspace, String)> = updated_graphs
.into_iter()
.map(|(job_id, graph)| {
let value = self.encode_execution_graph(graph)?;
- Ok((Keyspace::ActiveJobs, job_id, value))
+ Ok((Operation::Put(value), Keyspace::ActiveJobs, job_id))
})
.collect::<Result<Vec<_>>>()?;
- self.state.put_txn(txn_ops).await?;
+ self.state.apply_txn(txn_ops).await?;
Ok(running_tasks_to_cancel)
})
.await
@@ -524,8 +540,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static +
AsExecutionPlan> TaskManager<T, U>
&self,
job_id: &str,
) -> Option<Arc<RwLock<ExecutionGraph>>> {
- let active_graph_cache = self.active_job_cache.read().await;
- active_graph_cache.get(job_id).cloned()
+ self.active_job_cache.get(job_id).map(|value| value.clone())
}
/// Get the `ExecutionGraph` for the given job ID. This will search fist
in the `ActiveJobs`