This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/master by this push:
     new caff2fd4 Atomic support for enhancement (#319)
caff2fd4 is described below

commit caff2fd413c31b456ccd73febd1de998731e0589
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Fri Oct 7 15:29:39 2022 +0300

    Atomic support for enhancement (#319)
    
    * Atomic support for enhancement
    
    * Incorporate review feedback, simplify with_lock(s) functions
    
    Co-authored-by: Mehmet Ozan Kabak <[email protected]>
---
 ballista/rust/executor/Cargo.toml                  |   1 +
 ballista/rust/executor/src/executor_server.rs      |  20 ++--
 ballista/rust/scheduler/Cargo.toml                 |   1 +
 ballista/rust/scheduler/src/state/backend/etcd.rs  |  24 +++--
 ballista/rust/scheduler/src/state/backend/mod.rs   |  23 ++++-
 .../rust/scheduler/src/state/backend/standalone.rs |  48 ++++++++--
 .../rust/scheduler/src/state/executor_manager.rs   |  89 ++++++++----------
 ballista/rust/scheduler/src/state/mod.rs           |  18 +++-
 .../rust/scheduler/src/state/session_registry.rs   |  21 +++--
 ballista/rust/scheduler/src/state/task_manager.rs  | 101 ++++++++++++---------
 10 files changed, 211 insertions(+), 135 deletions(-)

diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml
index 253052fd..ca3532ac 100644
--- a/ballista/rust/executor/Cargo.toml
+++ b/ballista/rust/executor/Cargo.toml
@@ -40,6 +40,7 @@ async-trait = "0.1.41"
 ballista-core = { path = "../core", version = "0.8.0" }
 chrono = { version = "0.4", default-features = false }
 configure_me = "0.4.0"
+dashmap = "5.4.0"
 datafusion = { git = "https://github.com/apache/arrow-datafusion", rev = "06a4f79f02fcb6ea85303925b7c5a9b0231e3fee" }
 datafusion-proto = { git = "https://github.com/apache/arrow-datafusion", rev = "06a4f79f02fcb6ea85303925b7c5a9b0231e3fee" }
 futures = "0.3"
diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs
index 2bc84f72..bf036d6f 100644
--- a/ballista/rust/executor/src/executor_server.rs
+++ b/ballista/rust/executor/src/executor_server.rs
@@ -21,7 +21,7 @@ use std::convert::TryInto;
 use std::ops::Deref;
 use std::sync::Arc;
 use std::time::{Duration, SystemTime, UNIX_EPOCH};
-use tokio::sync::{mpsc, RwLock};
+use tokio::sync::mpsc;
 
 use log::{debug, error, info, warn};
 use tonic::transport::Channel;
@@ -45,6 +45,7 @@ use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
 use ballista_core::utils::{
     collect_plan_metrics, create_grpc_client_connection, create_grpc_server,
 };
+use dashmap::DashMap;
 use datafusion::execution::context::TaskContext;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -57,7 +58,7 @@ use crate::shutdown::ShutdownNotifier;
 use crate::{as_task_status, TaskExecutionTimes};
 
 type ServerHandle = JoinHandle<Result<(), BallistaError>>;
-type SchedulerClients = Arc<RwLock<HashMap<String, SchedulerGrpcClient<Channel>>>>;
+type SchedulerClients = Arc<DashMap<String, SchedulerGrpcClient<Channel>>>;
 
 /// Wrap TaskDefinition with its curator scheduler id for task update to its specific curator scheduler later
 #[derive(Debug)]
@@ -216,10 +217,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         &self,
         scheduler_id: &str,
     ) -> Result<SchedulerGrpcClient<Channel>, BallistaError> {
-        let scheduler = {
-            let schedulers = self.schedulers.read().await;
-            schedulers.get(scheduler_id).cloned()
-        };
+        let scheduler = self.schedulers.get(scheduler_id).map(|value| value.clone());
         // If channel does not exist, create a new one
         if let Some(scheduler) = scheduler {
             Ok(scheduler)
@@ -229,8 +227,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
             let scheduler = SchedulerGrpcClient::new(connection);
 
             {
-                let mut schedulers = self.schedulers.write().await;
-                schedulers.insert(scheduler_id.to_owned(), scheduler.clone());
+                self.schedulers
+                    .insert(scheduler_id.to_owned(), scheduler.clone());
             }
 
             Ok(scheduler)
@@ -263,8 +261,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
             }
         };
 
-        let schedulers = self.schedulers.read().await.clone();
-        for (scheduler_id, mut scheduler) in schedulers {
+        for mut item in self.schedulers.iter_mut() {
+            let scheduler_id = item.key().clone();
+            let scheduler = item.value_mut();
+
             match scheduler
                 .heart_beat_from_executor(heartbeat_params.clone())
                 .await
diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml
index 954e5a48..90e3f020 100644
--- a/ballista/rust/scheduler/Cargo.toml
+++ b/ballista/rust/scheduler/Cargo.toml
@@ -45,6 +45,7 @@ ballista-core = { path = "../core", version = "0.8.0" }
 base64 = { version = "0.13", default-features = false }
 clap = { version = "3", features = ["derive", "cargo"] }
 configure_me = "0.4.0"
+dashmap = "5.4.0"
 datafusion = { git = "https://github.com/apache/arrow-datafusion", rev = "06a4f79f02fcb6ea85303925b7c5a9b0231e3fee" }
 datafusion-proto = { git = "https://github.com/apache/arrow-datafusion", rev = "06a4f79f02fcb6ea85303925b7c5a9b0231e3fee" }
 etcd-client = { version = "0.9", optional = true }
diff --git a/ballista/rust/scheduler/src/state/backend/etcd.rs b/ballista/rust/scheduler/src/state/backend/etcd.rs
index 4b24b7aa..e753df66 100644
--- a/ballista/rust/scheduler/src/state/backend/etcd.rs
+++ b/ballista/rust/scheduler/src/state/backend/etcd.rs
@@ -30,7 +30,9 @@ use etcd_client::{
 use futures::{Stream, StreamExt};
 use log::{debug, error, warn};
 
-use crate::state::backend::{Keyspace, Lock, StateBackendClient, Watch, WatchEvent};
+use crate::state::backend::{
+    Keyspace, Lock, Operation, StateBackendClient, Watch, WatchEvent,
+};
 
 /// A [`StateBackendClient`] implementation that uses etcd to save cluster configuration.
 #[derive(Clone)]
@@ -137,29 +139,31 @@ impl StateBackendClient for EtcdClient {
             .await
             .map_err(|e| {
                 warn!("etcd put failed: {}", e);
-                ballista_error("etcd put failed")
+                ballista_error(&*format!("etcd put failed: {}", e))
             })
             .map(|_| ())
     }
 
-    async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec<u8>)>) -> Result<()> {
+    /// Apply multiple operations in a single transaction.
+    async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) -> Result<()> {
         let mut etcd = self.etcd.clone();
 
         let txn_ops: Vec<TxnOp> = ops
             .into_iter()
-            .map(|(ks, key, value)| {
+            .map(|(operation, ks, key)| {
                 let key = format!("/{}/{:?}/{}", self.namespace, ks, key);
-                TxnOp::put(key, value, None)
+                match operation {
+                    Operation::Put(value) => TxnOp::put(key, value, None),
+                    Operation::Delete => TxnOp::delete(key, None),
+                }
             })
             .collect();
 
-        let txn = Txn::new().and_then(txn_ops);
-
-        etcd.txn(txn)
+        etcd.txn(Txn::new().and_then(txn_ops))
             .await
             .map_err(|e| {
-                error!("etcd put failed: {}", e);
-                ballista_error("etcd transaction put failed")
+                error!("etcd operation failed: {}", e);
+                ballista_error(&*format!("etcd operation failed: {}", e))
             })
             .map(|_| ())
     }
diff --git a/ballista/rust/scheduler/src/state/backend/mod.rs b/ballista/rust/scheduler/src/state/backend/mod.rs
index b69403b2..85c0f34d 100644
--- a/ballista/rust/scheduler/src/state/backend/mod.rs
+++ b/ballista/rust/scheduler/src/state/backend/mod.rs
@@ -17,7 +17,7 @@
 
 use ballista_core::error::Result;
 use clap::ArgEnum;
-use futures::Stream;
+use futures::{future, Stream};
 use std::collections::HashSet;
 use std::fmt;
 use tokio::sync::OwnedMutexGuard;
@@ -60,6 +60,12 @@ pub enum Keyspace {
     Heartbeats,
 }
 
+#[derive(Debug, Eq, PartialEq, Hash)]
+pub enum Operation {
+    Put(Vec<u8>),
+    Delete,
+}
+
 /// A trait that contains the necessary methods to save and retrieve the state and configuration of a cluster.
 #[tonic::async_trait]
 pub trait StateBackendClient: Send + Sync {
@@ -90,8 +96,19 @@ pub trait StateBackendClient: Send + Sync {
     /// Saves the value into the provided key, overriding any previous data that might have been associated to that key.
     async fn put(&self, keyspace: Keyspace, key: String, value: Vec<u8>) -> Result<()>;
 
-    /// Save multiple values in a single transaction. Either all values should be saved, or all should fail
-    async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec<u8>)>) -> Result<()>;
+    /// Bundle multiple operation in a single transaction. Either all values should be saved, or all should fail.
+    /// It can support multiple types of operations and keyspaces. If the count of the unique keyspace is more than one,
+    /// more than one locks has to be acquired.
+    async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) -> Result<()>;
+    /// Acquire mutex with specified IDs.
+    async fn acquire_locks(
+        &self,
+        mut ids: Vec<(Keyspace, &str)>,
+    ) -> Result<Vec<Box<dyn Lock>>> {
+        // We always acquire locks in a specific order to avoid deadlocks.
+        ids.sort_by_key(|n| format!("/{:?}/{}", n.0, n.1));
+        future::try_join_all(ids.into_iter().map(|(ks, key)| self.lock(ks, key))).await
+    }
 
     /// Atomically move the given key from one keyspace to another
     async fn mv(
diff --git a/ballista/rust/scheduler/src/state/backend/standalone.rs b/ballista/rust/scheduler/src/state/backend/standalone.rs
index 4e5dc063..57bf7470 100644
--- a/ballista/rust/scheduler/src/state/backend/standalone.rs
+++ b/ballista/rust/scheduler/src/state/backend/standalone.rs
@@ -25,7 +25,9 @@ use log::warn;
 use sled_package as sled;
 use tokio::sync::Mutex;
 
-use crate::state::backend::{Keyspace, Lock, StateBackendClient, Watch, WatchEvent};
+use crate::state::backend::{
+    Keyspace, Lock, Operation, StateBackendClient, Watch, WatchEvent,
+};
 
 /// A [`StateBackendClient`] implementation that uses file-based storage to save cluster configuration.
 #[derive(Clone)]
@@ -162,17 +164,20 @@ impl StateBackendClient for StandaloneClient {
             .map(|_| ())
     }
 
-    async fn put_txn(&self, ops: Vec<(Keyspace, String, Vec<u8>)>) -> Result<()> {
+    async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) -> Result<()> {
         let mut batch = sled::Batch::default();
 
-        for (ks, key, value) in ops {
-            let key = format!("/{:?}/{}", ks, key);
-            batch.insert(key.as_str(), value);
+        for (op, keyspace, key_str) in ops {
+            let key = format!("/{:?}/{}", &keyspace, key_str);
+            match op {
+                Operation::Put(value) => batch.insert(key.as_str(), value),
+                Operation::Delete => batch.remove(key.as_str()),
+            }
         }
 
         self.db.apply_batch(batch).map_err(|e| {
             warn!("sled transaction insert failed: {}", e);
-            ballista_error("sled insert failed")
+            ballista_error("sled operations failed")
         })
     }
 
@@ -279,7 +284,8 @@ impl Stream for SledWatch {
 mod tests {
     use super::{StandaloneClient, StateBackendClient, Watch, WatchEvent};
 
-    use crate::state::backend::Keyspace;
+    use crate::state::backend::{Keyspace, Operation};
+    use crate::state::with_locks;
     use futures::StreamExt;
     use std::result::Result;
 
@@ -299,6 +305,34 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn multiple_operation() -> Result<(), Box<dyn std::error::Error>> {
+        let client = create_instance()?;
+        let key = "key".to_string();
+        let value = "value".as_bytes().to_vec();
+        let locks = client
+            .acquire_locks(vec![(Keyspace::ActiveJobs, ""), (Keyspace::Slots, "")])
+            .await?;
+
+        let _r: ballista_core::error::Result<()> = with_locks(locks, async {
+            let txn_ops = vec![
+                (Operation::Put(value.clone()), Keyspace::Slots, key.clone()),
+                (
+                    Operation::Put(value.clone()),
+                    Keyspace::ActiveJobs,
+                    key.clone(),
+                ),
+            ];
+            client.apply_txn(txn_ops).await?;
+            Ok(())
+        })
+        .await;
+
+        assert_eq!(client.get(Keyspace::Slots, key.as_str()).await?, value);
+        assert_eq!(client.get(Keyspace::ActiveJobs, key.as_str()).await?, value);
+        Ok(())
+    }
+
     #[tokio::test]
     async fn read_empty() -> Result<(), Box<dyn std::error::Error>> {
         let client = create_instance()?;
diff --git a/ballista/rust/scheduler/src/state/executor_manager.rs b/ballista/rust/scheduler/src/state/executor_manager.rs
index 1d135ef8..322e5a7b 100644
--- a/ballista/rust/scheduler/src/state/executor_manager.rs
+++ b/ballista/rust/scheduler/src/state/executor_manager.rs
@@ -17,7 +17,7 @@
 
 use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
 
-use crate::state::backend::{Keyspace, StateBackendClient, WatchEvent};
+use crate::state::backend::{Keyspace, Operation, StateBackendClient, WatchEvent};
 
 use crate::state::{decode_into, decode_protobuf, encode_protobuf, with_lock};
 use ballista_core::error::{BallistaError, Result};
@@ -30,14 +30,14 @@ use ballista_core::serde::protobuf::{
 };
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
 use ballista_core::utils::create_grpc_client_connection;
+use dashmap::{DashMap, DashSet};
 use futures::StreamExt;
 use log::{debug, error, info};
-use parking_lot::RwLock;
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 use tonic::transport::Channel;
 
-type ExecutorClients = Arc<RwLock<HashMap<String, ExecutorGrpcClient<Channel>>>>;
+type ExecutorClients = Arc<DashMap<String, ExecutorGrpcClient<Channel>>>;
 
 /// Represents a task slot that is reserved (i.e. available for scheduling but not visible to the
 /// rest of the system).
@@ -85,11 +85,11 @@ pub const DEFAULT_EXECUTOR_TIMEOUT_SECONDS: u64 = 180;
 pub(crate) struct ExecutorManager {
     state: Arc<dyn StateBackendClient>,
     // executor_id -> ExecutorMetadata map
-    executor_metadata: Arc<RwLock<HashMap<String, ExecutorMetadata>>>,
+    executor_metadata: Arc<DashMap<String, ExecutorMetadata>>,
     // executor_id -> ExecutorHeartbeat map
-    executors_heartbeat: Arc<RwLock<HashMap<String, protobuf::ExecutorHeartbeat>>>,
+    executors_heartbeat: Arc<DashMap<String, protobuf::ExecutorHeartbeat>>,
     // dead executor sets:
-    dead_executors: Arc<RwLock<HashSet<String>>>,
+    dead_executors: Arc<DashSet<String>>,
     clients: ExecutorClients,
 }
 
@@ -97,9 +97,9 @@ impl ExecutorManager {
     pub(crate) fn new(state: Arc<dyn StateBackendClient>) -> Self {
         Self {
             state,
-            executor_metadata: Arc::new(RwLock::new(HashMap::new())),
-            executors_heartbeat: Arc::new(RwLock::new(HashMap::new())),
-            dead_executors: Arc::new(RwLock::new(HashSet::new())),
+            executor_metadata: Arc::new(DashMap::new()),
+            executors_heartbeat: Arc::new(DashMap::new()),
+            dead_executors: Arc::new(DashSet::new()),
             clients: Default::default(),
         }
     }
@@ -130,7 +130,7 @@ impl ExecutorManager {
 
             let alive_executors = self.get_alive_executors_within_one_minute();
 
-            let mut txn_ops: Vec<(Keyspace, String, Vec<u8>)> = vec![];
+            let mut txn_ops: Vec<(Operation, Keyspace, String)> = vec![];
 
             for executor_id in alive_executors {
                 let value = self.state.get(Keyspace::Slots, &executor_id).await?;
@@ -146,14 +146,14 @@ impl ExecutorManager {
 
                 let proto: protobuf::ExecutorData = data.into();
                 let new_data = encode_protobuf(&proto)?;
-                txn_ops.push((Keyspace::Slots, executor_id, new_data));
+                txn_ops.push((Operation::Put(new_data), Keyspace::Slots, executor_id));
 
                 if desired == 0 {
                     break;
                 }
             }
 
-            self.state.put_txn(txn_ops).await?;
+            self.state.apply_txn(txn_ops).await?;
 
             let elapsed = start.elapsed();
             info!(
@@ -195,16 +195,16 @@ impl ExecutorManager {
                 }
             }
 
-            let txn_ops: Vec<(Keyspace, String, Vec<u8>)> = executor_slots
+            let txn_ops: Vec<(Operation, Keyspace, String)> = executor_slots
                 .into_iter()
                 .map(|(executor_id, data)| {
                     let proto: protobuf::ExecutorData = data.into();
                     let new_data = encode_protobuf(&proto)?;
-                    Ok((Keyspace::Slots, executor_id, new_data))
+                    Ok((Operation::Put(new_data), Keyspace::Slots, executor_id))
                 })
                 .collect::<Result<Vec<_>>>()?;
 
-            self.state.put_txn(txn_ops).await?;
+            self.state.apply_txn(txn_ops).await?;
 
             let elapsed = start.elapsed();
             info!(
@@ -262,10 +262,7 @@ impl ExecutorManager {
         &self,
         executor_id: &str,
     ) -> Result<ExecutorGrpcClient<Channel>> {
-        let client = {
-            let clients = self.clients.read();
-            clients.get(executor_id).cloned()
-        };
+        let client = self.clients.get(executor_id).map(|value| value.clone());
 
         if let Some(client) = client {
             Ok(client)
@@ -279,8 +276,7 @@ impl ExecutorManager {
             let client = ExecutorGrpcClient::new(connection);
 
             {
-                let mut clients = self.clients.write();
-                clients.insert(executor_id.to_owned(), client.clone());
+                self.clients.insert(executor_id.to_owned(), client.clone());
             }
             Ok(client)
         }
@@ -289,11 +285,10 @@ impl ExecutorManager {
     /// Get a list of all executors along with the timestamp of their last recorded heartbeat
     pub async fn get_executor_state(&self) -> Result<Vec<(ExecutorMetadata, Duration)>> {
         let heartbeat_timestamps: Vec<(String, u64)> = {
-            let heartbeats = self.executors_heartbeat.read();
-
-            heartbeats
+            self.executors_heartbeat
                 .iter()
-                .map(|(executor_id, heartbeat)| {
+                .map(|item| {
+                    let (executor_id, heartbeat) = item.pair();
                     (executor_id.clone(), heartbeat.timestamp)
                 })
                 .collect()
@@ -316,8 +311,7 @@ impl ExecutorManager {
         executor_id: &str,
     ) -> Result<ExecutorMetadata> {
         {
-            let metadata_cache = self.executor_metadata.read();
-            if let Some(cached) = metadata_cache.get(executor_id) {
+            if let Some(cached) = self.executor_metadata.get(executor_id) {
                 return Ok(cached.clone());
             }
         }
@@ -468,8 +462,8 @@ impl ExecutorManager {
             .put(Keyspace::Heartbeats, executor_id, value)
             .await?;
 
-        let mut executors_heartbeat = self.executors_heartbeat.write();
-        executors_heartbeat.insert(heartbeat.executor_id.clone(), heartbeat);
+        self.executors_heartbeat
+            .insert(heartbeat.executor_id.clone(), heartbeat);
 
         Ok(())
     }
@@ -484,22 +478,19 @@ impl ExecutorManager {
             .put(Keyspace::Heartbeats, executor_id.clone(), value)
             .await?;
 
-        let mut executors_heartbeat = self.executors_heartbeat.write();
-        executors_heartbeat.remove(&heartbeat.executor_id.clone());
-
-        let mut dead_executors = self.dead_executors.write();
-        dead_executors.insert(executor_id);
+        self.executors_heartbeat
+            .remove(&heartbeat.executor_id.clone());
+        self.dead_executors.insert(executor_id);
         Ok(())
     }
 
     pub(crate) fn is_dead_executor(&self, executor_id: &str) -> bool {
-        self.dead_executors.read().contains(executor_id)
+        self.dead_executors.contains(executor_id)
     }
 
     /// Initialize the set of active executor heartbeats from storage
     async fn init_active_executor_heartbeats(&self) -> Result<()> {
         let heartbeats = self.state.scan(Keyspace::Heartbeats, None).await?;
-        let mut cache = self.executors_heartbeat.write();
 
         for (_, value) in heartbeats {
             let data: protobuf::ExecutorHeartbeat = decode_protobuf(&value)?;
@@ -508,7 +499,7 @@ impl ExecutorManager {
                 status: Some(executor_status::Status::Active(_)),
             }) = data.status
             {
-                cache.insert(executor_id, data);
+                self.executors_heartbeat.insert(executor_id, data);
             }
         }
         Ok(())
@@ -520,10 +511,10 @@ impl ExecutorManager {
         &self,
         last_seen_ts_threshold: u64,
     ) -> HashSet<String> {
-        let executors_heartbeat = self.executors_heartbeat.read();
-        executors_heartbeat
+        self.executors_heartbeat
             .iter()
-            .filter_map(|(exec, heartbeat)| {
+            .filter_map(|pair| {
+                let (exec, heartbeat) = pair.pair();
                 (heartbeat.timestamp > last_seen_ts_threshold).then(|| exec.clone())
             })
             .collect()
@@ -539,10 +530,11 @@ impl ExecutorManager {
             .unwrap_or_else(|| Duration::from_secs(0))
             .as_secs();
 
-        let lock = self.executors_heartbeat.read();
-        let expired_executors = lock
+        let expired_executors = self
+            .executors_heartbeat
             .iter()
-            .filter_map(|(_exec, heartbeat)| {
+            .filter_map(|pair| {
+                let (_exec, heartbeat) = pair.pair();
                 (heartbeat.timestamp <= last_seen_threshold).then(|| heartbeat.clone())
             })
             .collect::<Vec<_>>();
@@ -565,15 +557,15 @@ impl ExecutorManager {
 /// and maintain an in-memory copy of the executor heartbeats.
 struct ExecutorHeartbeatListener {
     state: Arc<dyn StateBackendClient>,
-    executors_heartbeat: Arc<RwLock<HashMap<String, protobuf::ExecutorHeartbeat>>>,
-    dead_executors: Arc<RwLock<HashSet<String>>>,
+    executors_heartbeat: Arc<DashMap<String, protobuf::ExecutorHeartbeat>>,
+    dead_executors: Arc<DashSet<String>>,
 }
 
 impl ExecutorHeartbeatListener {
     pub fn new(
         state: Arc<dyn StateBackendClient>,
-        executors_heartbeat: Arc<RwLock<HashMap<String, protobuf::ExecutorHeartbeat>>>,
-        dead_executors: Arc<RwLock<HashSet<String>>>,
+        executors_heartbeat: Arc<DashMap<String, protobuf::ExecutorHeartbeat>>,
+        dead_executors: Arc<DashSet<String>>,
     ) -> Self {
         Self {
             state,
@@ -598,14 +590,13 @@ impl ExecutorHeartbeatListener {
                         decode_protobuf::<protobuf::ExecutorHeartbeat>(&value)
                     {
                         let executor_id = data.executor_id.clone();
-                        let mut heartbeats = heartbeats.write();
                         // Remove dead executors
                         if let Some(ExecutorStatus {
                             status: Some(executor_status::Status::Dead(_)),
                         }) = data.status
                         {
                             heartbeats.remove(&executor_id);
-                            dead_executors.write().insert(executor_id);
+                            dead_executors.insert(executor_id);
                         } else {
                             heartbeats.insert(executor_id, data);
                         }
diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs
index 2168ab64..3319cb12 100644
--- a/ballista/rust/scheduler/src/state/mod.rs
+++ b/ballista/rust/scheduler/src/state/mod.rs
@@ -274,11 +274,23 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T,
     }
 }
 
-pub async fn with_lock<Out, F: Future<Output = Out>>(lock: Box<dyn Lock>, op: F) -> Out {
-    let mut lock = lock;
+pub async fn with_lock<Out, F: Future<Output = Out>>(
+    mut lock: Box<dyn Lock>,
+    op: F,
+) -> Out {
     let result = op.await;
     lock.unlock().await;
-
+    result
+}
+/// It takes multiple locks and reverse the order for releasing them to prevent a race condition.
+pub async fn with_locks<Out, F: Future<Output = Out>>(
+    locks: Vec<Box<dyn Lock>>,
+    op: F,
+) -> Out {
+    let result = op.await;
+    for mut lock in locks.into_iter().rev() {
+        lock.unlock().await;
+    }
     result
 }
 
diff --git a/ballista/rust/scheduler/src/state/session_registry.rs b/ballista/rust/scheduler/src/state/session_registry.rs
index 1281449b..b6f214e5 100644
--- a/ballista/rust/scheduler/src/state/session_registry.rs
+++ b/ballista/rust/scheduler/src/state/session_registry.rs
@@ -15,15 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use dashmap::DashMap;
 use datafusion::prelude::SessionContext;
-use std::collections::HashMap;
 use std::sync::Arc;
-use tokio::sync::RwLock;
 
 /// A Registry holds all the datafusion session contexts
 pub struct SessionContextRegistry {
     /// A map from session_id to SessionContext
-    pub running_sessions: RwLock<HashMap<String, Arc<SessionContext>>>,
+    pub running_sessions: DashMap<String, Arc<SessionContext>>,
 }
 
 impl Default for SessionContextRegistry {
@@ -37,7 +36,7 @@ impl SessionContextRegistry {
     /// ['LocalFileSystem'] store is registered in by default to support read local files natively.
     pub fn new() -> Self {
         Self {
-            running_sessions: RwLock::new(HashMap::new()),
+            running_sessions: DashMap::new(),
         }
     }
 
@@ -47,14 +46,14 @@ impl SessionContextRegistry {
         session_ctx: Arc<SessionContext>,
     ) -> Option<Arc<SessionContext>> {
         let session_id = session_ctx.session_id();
-        let mut sessions = self.running_sessions.write().await;
-        sessions.insert(session_id, session_ctx)
+        self.running_sessions.insert(session_id, session_ctx)
     }
 
     /// Lookup the session context registered
     pub async fn lookup_session(&self, session_id: &str) -> Option<Arc<SessionContext>> {
-        let sessions = self.running_sessions.read().await;
-        sessions.get(session_id).cloned()
+        self.running_sessions
+            .get(session_id)
+            .map(|value| value.clone())
     }
 
     /// Remove a session from this registry.
@@ -62,7 +61,9 @@ impl SessionContextRegistry {
         &self,
         session_id: &str,
     ) -> Option<Arc<SessionContext>> {
-        let mut sessions = self.running_sessions.write().await;
-        sessions.remove(session_id)
+        match self.running_sessions.remove(session_id) {
+            None => None,
+            Some(value) => Some(value.1),
+        }
     }
 }
diff --git a/ballista/rust/scheduler/src/state/task_manager.rs b/ballista/rust/scheduler/src/state/task_manager.rs
index 810d0f37..acafbcd8 100644
--- a/ballista/rust/scheduler/src/state/task_manager.rs
+++ b/ballista/rust/scheduler/src/state/task_manager.rs
@@ -17,12 +17,12 @@
 
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
 use crate::scheduler_server::SessionBuilder;
-use crate::state::backend::{Keyspace, Lock, StateBackendClient};
+use crate::state::backend::{Keyspace, Operation, StateBackendClient};
 use crate::state::execution_graph::{
     ExecutionGraph, ExecutionStage, RunningTaskInfo, TaskDescription,
 };
 use crate::state::executor_manager::{ExecutorManager, ExecutorReservation};
-use crate::state::{decode_protobuf, encode_protobuf, with_lock};
+use crate::state::{decode_protobuf, encode_protobuf, with_lock, with_locks};
 use ballista_core::config::BallistaConfig;
 #[cfg(not(test))]
 use ballista_core::error::BallistaError;
@@ -35,6 +35,7 @@ use ballista_core::serde::protobuf::{
 use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto;
 use ballista_core::serde::scheduler::ExecutorMetadata;
 use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
+use dashmap::DashMap;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionContext;
 use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -45,8 +46,7 @@ use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
 use tokio::sync::RwLock;
-
-type ExecutionGraphCache = Arc<RwLock<HashMap<String, Arc<RwLock<ExecutionGraph>>>>>;
+type ExecutionGraphCache = Arc<DashMap<String, Arc<RwLock<ExecutionGraph>>>>;
 
 // TODO move to configuration file
 /// Default max failure attempts for task level retry
@@ -85,7 +85,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
             session_builder,
             codec,
             scheduler_id,
-            active_job_cache: Arc::new(RwLock::new(HashMap::new())),
+            active_job_cache: Arc::new(DashMap::new()),
         }
     }
 
@@ -111,9 +111,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
             .await?;
 
         graph.revive();
-
-        let mut active_graph_cache = self.active_job_cache.write().await;
-        active_graph_cache.insert(job_id.to_owned(), Arc::new(RwLock::new(graph)));
+        self.active_job_cache
+            .insert(job_id.to_owned(), Arc::new(RwLock::new(graph)));
 
         Ok(())
     }
@@ -262,8 +261,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         let mut assignments: Vec<(String, TaskDescription)> = vec![];
         let mut pending_tasks = 0usize;
         let mut assign_tasks = 0usize;
-        let job_cache = self.active_job_cache.read().await;
-        for (_job_id, graph) in job_cache.iter() {
+        for pairs in self.active_job_cache.iter() {
+            let (_job_id, graph) = pairs.pair();
             let mut graph = graph.write().await;
             for reservation in free_reservations.iter().skip(assign_tasks) {
                 if let Some(task) = graph.pop_next_task(&reservation.executor_id)? {
@@ -321,7 +320,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         job_id: &str,
         failure_reason: String,
     ) -> Result<Vec<RunningTaskInfo>> {
-        let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
+        let locks = self
+            .state
+            .acquire_locks(vec![
+                (Keyspace::ActiveJobs, job_id),
+                (Keyspace::FailedJobs, job_id),
+            ])
+            .await?;
         if let Some(graph) = self.get_active_execution_graph(job_id).await {
             let running_tasks = graph.read().await.running_tasks();
             info!(
@@ -329,12 +334,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
                 running_tasks.len(),
                 job_id
             );
-            self.fail_job_state(lock, job_id, failure_reason).await?;
+            with_locks(locks, self.fail_job_state(job_id, 
failure_reason)).await?;
             Ok(running_tasks)
         } else {
             // TODO listen the job state update event and fix task cancelling
             warn!("Fail to find job {} in the cache, unable to cancel tasks 
for job, fail the job state only.", job_id);
-            self.fail_job_state(lock, job_id, failure_reason).await?;
+            with_locks(locks, self.fail_job_state(job_id, 
failure_reason)).await?;
             Ok(vec![])
         }
     }
@@ -348,44 +353,55 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
         failure_reason: String,
     ) -> Result<()> {
         debug!("Moving job {} from Active or Queue to Failed", job_id);
-        let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
-        self.fail_job_state(lock, job_id, failure_reason).await
+        let locks = self
+            .state
+            .acquire_locks(vec![
+                (Keyspace::ActiveJobs, job_id),
+                (Keyspace::FailedJobs, job_id),
+            ])
+            .await?;
+        with_locks(locks, self.fail_job_state(job_id, failure_reason)).await
     }
 
-    async fn fail_job_state(
-        &self,
-        lock: Box<dyn Lock>,
-        job_id: &str,
-        failure_reason: String,
-    ) -> Result<()> {
-        with_lock(lock, self.state.delete(Keyspace::ActiveJobs, 
job_id)).await?;
+    async fn fail_job_state(&self, job_id: &str, failure_reason: String) -> 
Result<()> {
+        let txn_operations = |value: Vec<u8>| -> Vec<(Operation, Keyspace, 
String)> {
+            vec![
+                (Operation::Delete, Keyspace::ActiveJobs, job_id.to_string()),
+                (
+                    Operation::Put(value),
+                    Keyspace::FailedJobs,
+                    job_id.to_string(),
+                ),
+            ]
+        };
 
-        let value = if let Some(graph) = 
self.get_active_execution_graph(job_id).await {
+        let _res = if let Some(graph) = 
self.get_active_execution_graph(job_id).await {
             let mut graph = graph.write().await;
-            for stage_id in graph.running_stages() {
-                graph.fail_stage(stage_id, failure_reason.clone());
-            }
+            let previous_status = graph.status();
             graph.fail_job(failure_reason);
-            let graph = graph.clone();
-            self.encode_execution_graph(graph)?
+            let value = self.encode_execution_graph(graph.clone())?;
+            let txn_ops = txn_operations(value);
+            let result = self.state.apply_txn(txn_ops).await;
+            if result.is_err() {
+                // Rollback
+                graph.update_status(previous_status);
+                warn!("Rollback Execution Graph state change since it did not 
persisted due to a possible connection error.")
+            };
+            result
         } else {
             warn!("Fail to find job {} in the cache", job_id);
-
             let status = JobStatus {
                 status: Some(job_status::Status::Failed(FailedJob {
                     error: failure_reason.clone(),
                 })),
             };
-            encode_protobuf(&status)?
+            let value = encode_protobuf(&status)?;
+            let txn_ops = txn_operations(value);
+            self.state.apply_txn(txn_ops).await
         };
 
-        self.state
-            .put(Keyspace::FailedJobs, job_id.to_owned(), value)
-            .await?;
-
         Ok(())
     }
-
     pub async fn update_job(&self, job_id: &str) -> Result<()> {
         debug!("Update job {} in Active", job_id);
         if let Some(graph) = self.get_active_execution_graph(job_id).await {
@@ -408,10 +424,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
         // Collect all the running task need to cancel when there are running 
stages rolled back.
         let mut running_tasks_to_cancel: Vec<RunningTaskInfo> = vec![];
         // Collect graphs we update so we can update them in storage
-        let mut updated_graphs: HashMap<String, ExecutionGraph> = 
HashMap::new();
+        let updated_graphs: DashMap<String, ExecutionGraph> = DashMap::new();
         {
-            let job_cache = self.active_job_cache.read().await;
-            for (job_id, graph) in job_cache.iter() {
+            for pairs in self.active_job_cache.iter() {
+                let (job_id, graph) = pairs.pair();
                 let mut graph = graph.write().await;
                 let reset = graph.reset_stages_on_lost_executor(executor_id)?;
                 if !reset.0.is_empty() {
@@ -424,14 +440,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
         let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
         with_lock(lock, async {
             // Transactional update graphs
-            let txn_ops: Vec<(Keyspace, String, Vec<u8>)> = updated_graphs
+            let txn_ops: Vec<(Operation, Keyspace, String)> = updated_graphs
                 .into_iter()
                 .map(|(job_id, graph)| {
                     let value = self.encode_execution_graph(graph)?;
-                    Ok((Keyspace::ActiveJobs, job_id, value))
+                    Ok((Operation::Put(value), Keyspace::ActiveJobs, job_id))
                 })
                 .collect::<Result<Vec<_>>>()?;
-            self.state.put_txn(txn_ops).await?;
+            self.state.apply_txn(txn_ops).await?;
             Ok(running_tasks_to_cancel)
         })
         .await
@@ -524,8 +540,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
         &self,
         job_id: &str,
     ) -> Option<Arc<RwLock<ExecutionGraph>>> {
-        let active_graph_cache = self.active_job_cache.read().await;
-        active_graph_cache.get(job_id).cloned()
+        self.active_job_cache.get(job_id).map(|value| value.clone())
     }
 
     /// Get the `ExecutionGraph` for the given job ID. This will search fist 
in the `ActiveJobs`


Reply via email to