Repository: reef
Updated Branches:
  refs/heads/master 298eaa934 -> a8fd76c19


[REEF-1060] Make TaskStatus thread-safe and improve Task state management in the C#
Evaluator

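This addresses the issue by:
  * Tightening locking in TaskStatus: all state, result, and exception
    updates, as well as ToProto(), now run under a single private lock.
  * Fixing the state-transition check in TaskStatus.cs: a transition to
    the same state is allowed (with a warning), while transitions out of
    Failed, Done, Killed, or an unknown state are now rejected.
  * Ignoring close, suspend, and kill requests once the task has ended.
  * Replacing RecordExecptionWithoutHeartbeat with SetException (which
    also sends a heartbeat) and making the State setter private.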

JIRA:
  [REEF-1060](https://issues.apache.org/jira/browse/REEF-1060)

Pull Request:
  This closes #718


Project: http://git-wip-us.apache.org/repos/asf/reef/repo
Commit: http://git-wip-us.apache.org/repos/asf/reef/commit/a8fd76c1
Tree: http://git-wip-us.apache.org/repos/asf/reef/tree/a8fd76c1
Diff: http://git-wip-us.apache.org/repos/asf/reef/diff/a8fd76c1

Branch: refs/heads/master
Commit: a8fd76c193af5025bbb0b48ae7d4b8e6c862cbf7
Parents: 298eaa9
Author: Andrew Chung <[email protected]>
Authored: Wed Dec 9 17:25:51 2015 -0800
Committer: Markus Weimer <[email protected]>
Committed: Tue Dec 15 08:59:50 2015 -0800

----------------------------------------------------------------------
 .../Runtime/Evaluator/Task/TaskRuntime.cs       |  10 +-
 .../Runtime/Evaluator/Task/TaskStatus.cs        | 196 +++++++++++--------
 2 files changed, 125 insertions(+), 81 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/reef/blob/a8fd76c1/lang/cs/Org.Apache.REEF.Common/Runtime/Evaluator/Task/TaskRuntime.cs
----------------------------------------------------------------------
diff --git 
a/lang/cs/Org.Apache.REEF.Common/Runtime/Evaluator/Task/TaskRuntime.cs 
b/lang/cs/Org.Apache.REEF.Common/Runtime/Evaluator/Task/TaskRuntime.cs
index 81b1b06..3a08dab 100644
--- a/lang/cs/Org.Apache.REEF.Common/Runtime/Evaluator/Task/TaskRuntime.cs
+++ b/lang/cs/Org.Apache.REEF.Common/Runtime/Evaluator/Task/TaskRuntime.cs
@@ -121,6 +121,7 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
                 throw new InvalidOperationException("TaskRun has already been 
called on TaskRuntime.");
             }
 
+            // Send heartbeat such that user receives a TaskRunning message.
             _currentStatus.SetRunning();
             ITask userTask;
             try
@@ -133,8 +134,11 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
                 return;
             }
 
-            System.Threading.Tasks.Task.Run(() => 
userTask.Call(null)).ContinueWith(
-                runTask =>
+            System.Threading.Tasks.Task.Run(() =>
+            {
+                Logger.Log(Level.Info, "Calling into user's task.");
+                return userTask.Call(null);
+            }).ContinueWith(runTask =>
                 {
                     try
                     {
@@ -278,7 +282,7 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
             catch (Exception e)
             {
                 Utilities.Diagnostics.Exceptions.Caught(e, Level.Warning, 
"Exception throw when handling driver message: " + e, Logger);
-                _currentStatus.RecordExecptionWithoutHeartbeat(e);
+                _currentStatus.SetException(e);
             }
         }
 

http://git-wip-us.apache.org/repos/asf/reef/blob/a8fd76c1/lang/cs/Org.Apache.REEF.Common/Runtime/Evaluator/Task/TaskStatus.cs
----------------------------------------------------------------------
diff --git 
a/lang/cs/Org.Apache.REEF.Common/Runtime/Evaluator/Task/TaskStatus.cs 
b/lang/cs/Org.Apache.REEF.Common/Runtime/Evaluator/Task/TaskStatus.cs
index c0c19da..8037639 100644
--- a/lang/cs/Org.Apache.REEF.Common/Runtime/Evaluator/Task/TaskStatus.cs
+++ b/lang/cs/Org.Apache.REEF.Common/Runtime/Evaluator/Task/TaskStatus.cs
@@ -30,24 +30,26 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
     internal sealed class TaskStatus
     {
         private static readonly Logger LOGGER = 
Logger.GetLogger(typeof(TaskStatus));
+
+        private readonly object _stateLock = new object();
         private readonly TaskLifeCycle _taskLifeCycle;
         private readonly HeartBeatManager _heartBeatManager;
         private readonly Optional<ISet<ITaskMessageSource>> 
_evaluatorMessageSources;
-
         private readonly string _taskId;
         private readonly string _contextId;
+
         private Optional<Exception> _lastException = 
Optional<Exception>.Empty();
         private Optional<byte[]> _result = Optional<byte[]>.Empty();
         private TaskState _state;
 
         public TaskStatus(HeartBeatManager heartBeatManager, string contextId, 
string taskId, Optional<ISet<ITaskMessageSource>> evaluatorMessageSources)
         {
-            _contextId = contextId;
-            _taskId = taskId;
             _heartBeatManager = heartBeatManager;
             _taskLifeCycle = new TaskLifeCycle();
             _evaluatorMessageSources = evaluatorMessageSources;
             State = TaskState.Init;
+            _taskId = taskId;
+            _contextId = contextId;
         }
 
         public TaskState State
@@ -57,7 +59,7 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
                 return _state;
             }
 
-            set
+            private set
             {
                 if (IsLegalStateTransition(_state, value))
                 {
@@ -67,7 +69,7 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
                 {
                     string message = 
string.Format(CultureInfo.InvariantCulture, "Illegal state transition from 
[{0}] to [{1}]", _state, value);
                     LOGGER.Log(Level.Error, message);
-                    Org.Apache.REEF.Utilities.Diagnostics.Exceptions.Throw(new 
InvalidOperationException(message), LOGGER);
+                    Utilities.Diagnostics.Exceptions.Throw(new 
InvalidOperationException(message), LOGGER);
                 }
             }
         }
@@ -84,66 +86,102 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
 
         public void SetException(Exception e)
         {
-            RecordExecptionWithoutHeartbeat(e);
-            Heartbeat();
-            _lastException = Optional<Exception>.Empty();
+            lock (_stateLock)
+            {
+                if (!_lastException.IsPresent())
+                {
+                    _lastException = Optional<Exception>.Of(e);
+                }
+                State = TaskState.Failed;
+                _taskLifeCycle.Stop();
+                Heartbeat();
+            }
         }
 
         public void SetResult(byte[] result)
         {
-            _result = Optional<byte[]>.OfNullable(result);
-            if (State == TaskState.Running)
-            {
-                State = TaskState.Done;
-            }
-            else if (State == TaskState.SuspendRequested)
-            {
-                State = TaskState.Suspended;
-            }
-            else if (State == TaskState.CloseRequested)
+            lock (_stateLock)
             {
-                State = TaskState.Done;
+                _result = Optional<byte[]>.OfNullable(result);
+                switch (State)
+                {
+                    case TaskState.SuspendRequested:
+                        State = TaskState.Suspended;
+                        break;
+                    case TaskState.Running:
+                    case TaskState.CloseRequested:
+                        State = TaskState.Done;
+                        break;
+                }
+                _taskLifeCycle.Stop();
+                Heartbeat();
             }
-            _taskLifeCycle.Stop();
-            Heartbeat();
         }
 
         public void SetRunning()
         {
-            LOGGER.Log(Level.Verbose, "TaskStatus::SetRunning");
-            if (_state == TaskState.Init)
+            lock (_stateLock)
             {
-                try
-                {
-                    _taskLifeCycle.Start();
-
-                    // Need to send an INIT heartbeat to the driver prompting 
it to create an RunningTask event. 
-                    LOGGER.Log(Level.Info, 
string.Format(CultureInfo.InvariantCulture, "Sending task INIT heartbeat"));
-                    Heartbeat();
-                    State = TaskState.Running;
-                }
-                catch (Exception e)
+                LOGGER.Log(Level.Verbose, "TaskStatus::SetRunning");
+                if (_state == TaskState.Init)
                 {
-                    Org.Apache.REEF.Utilities.Diagnostics.Exceptions.Caught(e, 
Level.Error, "Cannot set task status to running.", LOGGER);
-                    SetException(e);
+                    try
+                    {
+                        _taskLifeCycle.Start();
+                        
+                        // Need to send an INIT heartbeat to the driver 
prompting it to create an RunningTask event. 
+                        LOGGER.Log(Level.Info, "Sending task INIT heartbeat");
+                        Heartbeat();
+                        State = TaskState.Running;
+                    }
+                    catch (Exception e)
+                    {
+                        Utilities.Diagnostics.Exceptions.Caught(e, 
Level.Error, "Cannot set task status to running.", LOGGER);
+                        SetException(e);
+                    }
                 }
             }
         }
 
         public void SetCloseRequested()
         {
-            State = TaskState.CloseRequested;
+            lock (_stateLock)
+            {
+                if (HasEnded())
+                {
+                    return;
+                }
+
+                State = TaskState.CloseRequested;
+            }
         }
 
         public void SetSuspendRequested()
         {
-            State = TaskState.SuspendRequested;
+            lock (_stateLock)
+            {
+                if (HasEnded())
+                {
+                    return;
+                }
+
+                State = TaskState.SuspendRequested;
+            }
         }
 
         public void SetKilled()
         {
-            State = TaskState.Killed;
-            Heartbeat();
+            lock (_stateLock)
+            {
+                if (HasEnded())
+                {
+                    LOGGER.Log(Level.Warning, "Trying to kill a task that is 
in {0} state. Ignored.", State);
+                    return;
+                }
+
+                State = TaskState.Killed;
+                Heartbeat();
+            }
         }
 
         public bool IsNotRunning()
@@ -167,47 +205,40 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
 
         public TaskStatusProto ToProto()
         {
-            Check();
-            TaskStatusProto taskStatusProto = new TaskStatusProto()
+            // This is locked because the Task continuation thread which sets 
the
+            // result is potentially different from the HeartBeat thread.
+            lock (_stateLock)
             {
-                context_id = _contextId,
-                task_id = _taskId,
-                state = GetProtoState(),
-            };
-            if (_result.IsPresent())
-            {
-                taskStatusProto.result = 
ByteUtilities.CopyBytesFrom(_result.Value);
-            }
-            else if (_lastException.IsPresent())
-            {
-                // final Encoder<Throwable> codec = new 
ObjectSerializableCodec<>();
-                // final byte[] error = codec.encode(_lastException.get());
-                byte[] error = 
ByteUtilities.StringToByteArrays(_lastException.Value.ToString());
-                taskStatusProto.result = ByteUtilities.CopyBytesFrom(error);
-            }
-            else if (_state == TaskState.Running)
-            {
-                foreach (TaskMessage message in GetMessages())
+                Check();
+                TaskStatusProto taskStatusProto = new TaskStatusProto()
+                {
+                    context_id = ContextId,
+                    task_id = TaskId,
+                    state = GetProtoState()
+                };
+                if (_result.IsPresent())
+                {
+                    taskStatusProto.result = 
ByteUtilities.CopyBytesFrom(_result.Value);
+                }
+                else if (_lastException.IsPresent())
                 {
-                    TaskStatusProto.TaskMessageProto taskMessageProto = new 
TaskStatusProto.TaskMessageProto()
+                    byte[] error = 
ByteUtilities.StringToByteArrays(_lastException.Value.ToString());
+                    taskStatusProto.result = 
ByteUtilities.CopyBytesFrom(error);
+                }
+                else if (_state == TaskState.Running)
+                {
+                    foreach (TaskMessage message in GetMessages())
                     {
-                        source_id = message.MessageSourceId,
-                        message = ByteUtilities.CopyBytesFrom(message.Message),
-                    };
-                    taskStatusProto.task_message.Add(taskMessageProto);
+                        TaskStatusProto.TaskMessageProto taskMessageProto = 
new TaskStatusProto.TaskMessageProto()
+                        {
+                            source_id = message.MessageSourceId,
+                            message = 
ByteUtilities.CopyBytesFrom(message.Message),
+                        };
+                        taskStatusProto.task_message.Add(taskMessageProto);
+                    }
                 }
+                return taskStatusProto;
             }
-            return taskStatusProto;
-        }
-
-        internal void RecordExecptionWithoutHeartbeat(Exception e)
-        {
-            if (!_lastException.IsPresent())
-            {
-                _lastException = Optional<Exception>.Of(e);
-            }
-            State = TaskState.Failed;
-            _taskLifeCycle.Stop();
         }
 
         private static bool IsLegalStateTransition(TaskState? from, TaskState 
to)
@@ -216,6 +247,13 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
             {
                 return to == TaskState.Init;
             }
+
+            if (from == to)
+            {
+                LOGGER.Log(Level.Warning, "Transitioning to the same state 
from {0} to {1}.", from, to);
+                return true;
+            }
+
             switch (from)
             {
                 case TaskState.Init:
@@ -265,9 +303,11 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
 
                 case TaskState.Failed:
                 case TaskState.Done:
-                case TaskState.Killed:           
+                case TaskState.Killed:
+                    return false;
                 default:
-                    return true;
+                    LOGGER.Log(Level.Error, "Unknown \"from\" state: {0}", 
from);
+                    return false;
             }
         }
 
@@ -305,7 +345,7 @@ namespace Org.Apache.REEF.Common.Runtime.Evaluator.Task
                 case TaskState.Killed:
                     return Protobuf.ReefProtocol.State.KILLED;
                 default:
-                    Org.Apache.REEF.Utilities.Diagnostics.Exceptions.Throw(new 
InvalidOperationException("Unknown state: " + _state), LOGGER);
+                    Utilities.Diagnostics.Exceptions.Throw(new 
InvalidOperationException("Unknown state: " + _state), LOGGER);
                     break;
             }
             return Protobuf.ReefProtocol.State.FAILED; // this line should not 
be reached as default case will throw exception

Reply via email to