Repository: cassandra
Updated Branches:
  refs/heads/trunk a5e501f09 -> a0e8de99d


cqlsh: Fix potential COPY deadlock

This deadlock could occur when the parent process is terminating child
processes (partial backport of CASSANDRA-11320).
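
The shape of the fix, as a minimal hedged sketch (the class and names
below are illustrative, not the patch's own code): outgoing messages are
queued locally and a daemon thread performs the blocking Pipe.send(), so
no caller can end up blocked on a full pipe with nobody left to recv().
The patch itself creates this thread lazily, after fork(), because the
channels are shared with child processes.

    import multiprocessing as mp
    import threading
    from Queue import Queue  # stdlib 'queue' on Python 3; cqlsh here is Python 2

    class NonBlockingChannel(object):
        def __init__(self):
            self.reader, self.writer = mp.Pipe(duplex=False)
            self.pending = Queue()
            feeder = threading.Thread(target=self._feed)
            feeder.setDaemon(True)  # must not keep the process alive at exit
            feeder.start()

        def _feed(self):
            # Only this thread ever blocks on the pipe.
            while True:
                self.writer.send(self.pending.get())

        def send(self, obj):
            self.pending.put(obj)  # returns immediately; never blocks the caller

        def recv(self):
            return self.reader.recv()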

Patch by Stefania Alborghetti; reviewed by Tyler Hobbs for
CASSANDRA-11505


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/4389c9cf
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/4389c9cf
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/4389c9cf

Branch: refs/heads/trunk
Commit: 4389c9cfd86fb3f31a9419c44f0521604be3637b
Parents: 209ebd3
Author: Stefania Alborghetti <stefania.alborghe...@datastax.com>
Authored: Mon Apr 11 10:31:34 2016 +0800
Committer: Tyler Hobbs <tylerlho...@gmail.com>
Committed: Wed Apr 20 13:52:08 2016 -0500

----------------------------------------------------------------------
 CHANGES.txt                |   2 +
 pylib/cqlshlib/copyutil.py | 171 +++++++++++++++++++++-------------------
 2 files changed, 91 insertions(+), 82 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/4389c9cf/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 76d3673..4a91a58 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,6 @@
 2.1.14
+ * (cqlsh) Fix potential COPY deadlock when parent process is terminating child
+   processes (CASSANDRA-11505)
  * Replace sstables on DataTracker before marking them as non-compacting during anti-compaction (CASSANDRA-11548)
  * Checking if an unlogged batch is local is inefficient (CASSANDRA-11529)
  * Fix paging for COMPACT tables without clustering columns (CASSANDRA-11467)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/4389c9cf/pylib/cqlshlib/copyutil.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/copyutil.py b/pylib/cqlshlib/copyutil.py
index 28e08b1..12239d8 100644
--- a/pylib/cqlshlib/copyutil.py
+++ b/pylib/cqlshlib/copyutil.py
@@ -28,8 +28,8 @@ import random
 import re
 import struct
 import sys
-import time
 import threading
+import time
 import traceback
 
 from bisect import bisect_right
@@ -57,6 +57,7 @@ from sslhandling import ssl_settings
 
 PROFILE_ON = False
 STRACE_ON = False
+DEBUG = False  # This may be set to True when initializing the task
 IS_LINUX = platform.system() == 'Linux'
 
 CopyOptions = namedtuple('CopyOptions', 'copy dialect unrecognized')
@@ -70,6 +71,16 @@ def safe_normpath(fname):
     return os.path.normpath(os.path.expanduser(fname)) if fname else fname
 
 
+def printdebugmsg(msg):
+    if DEBUG:
+        printmsg(msg)
+
+
+def printmsg(msg, eol='\n'):
+    sys.stdout.write(msg + eol)
+    sys.stdout.flush()
+
+
 class OneWayChannel(object):
     """
     A one way pipe protected by two process level locks, one for reading and one for writing.
@@ -78,11 +89,49 @@ class OneWayChannel(object):
         self.reader, self.writer = mp.Pipe(duplex=False)
         self.rlock = mp.Lock()
         self.wlock = mp.Lock()
+        self.feeding_thread = None
+        self.pending_messages = None
+
+    def init_feeding_thread(self):
+        """
+        Initialize a thread that fetches messages from a queue and sends them to the channel.
+        We initialize the feeding thread lazily to avoid the fork(), since the channels are passed to child processes.
+        """
+        if self.feeding_thread is not None or self.pending_messages is not None:
+            raise RuntimeError("Feeding thread already initialized")
+
+        self.pending_messages = Queue()
+
+        def feed():
+            send = self._send
+            pending_messages = self.pending_messages
+
+            while True:
+                try:
+                    msg = pending_messages.get()
+                    send(msg)
+                except Exception, e:
+                    printmsg('%s: %s' % (e.__class__.__name__, e.message))
+
+        feeding_thread = threading.Thread(target=feed)
+        feeding_thread.setDaemon(True)
+        feeding_thread.start()
+
+        self.feeding_thread = feeding_thread
 
     def send(self, obj):
+        if self.feeding_thread is None:
+            self.init_feeding_thread()
+
+        self.pending_messages.put(obj)
+
+    def _send(self, obj):
         with self.wlock:
             self.writer.send(obj)
 
+    def num_pending(self):
+        return self.pending_messages.qsize() if self.pending_messages else 0
+
     def recv(self):
         with self.rlock:
             return self.reader.recv()
@@ -157,8 +206,15 @@ class CopyTask(object):
         self.fname = safe_normpath(fname)
         self.protocol_version = protocol_version
         self.config_file = config_file
-        # do not display messages when exporting to STDOUT
-        self.printmsg = self._printmsg if self.fname is not None or direction == 'from' else lambda _, eol='\n': None
+
+        # if cqlsh is invoked with --debug then set the global debug flag to True
+        if shell.debug:
+            global DEBUG
+            DEBUG = True
+
+        # do not display messages when exporting to STDOUT unless --debug is set
+        self.printmsg = printmsg if self.fname is not None or direction == 'from' or DEBUG \
+            else lambda _, eol='\n': None
         self.options = self.parse_options(opts, direction)
 
         self.num_processes = self.options.copy['numprocesses']
@@ -174,11 +230,6 @@ class CopyTask(object):
         self.columns = CopyTask.get_columns(shell, ks, table, columns)
         self.time_start = time.time()
 
-    @staticmethod
-    def _printmsg(msg, eol='\n'):
-        sys.stdout.write(msg + eol)
-        sys.stdout.flush()
-
     def maybe_read_config_file(self, opts, direction):
         """
         Read optional sections from a configuration file that  was specified in the command options or from the default
@@ -758,7 +809,7 @@ class FilesReader(object):
             try:
                 return open(fname, 'rb')
             except IOError, e:
-                self.printmsg("Can't open %r for reading: %s" % (fname, e))
+                printdebugmsg("Can't open %r for reading: %s" % (fname, e))
                 return None
 
         for path in paths.split(','):
@@ -769,11 +820,6 @@ class FilesReader(object):
                 for f in glob.glob(path):
                     yield (make_source(f))
 
-    @staticmethod
-    def printmsg(msg, eol='\n'):
-        sys.stdout.write(msg + eol)
-        sys.stdout.flush()
-
     def start(self):
         self.sources = self.get_source(self.fname)
         self.next_source()
@@ -921,7 +967,6 @@ class ImportErrorHandler(object):
     def __init__(self, task):
         self.shell = task.shell
         self.options = task.options
-        self.printmsg = task.printmsg
         self.max_attempts = self.options.copy['maxattempts']
         self.max_parse_errors = self.options.copy['maxparseerrors']
         self.max_insert_errors = self.options.copy['maxinserterrors']
@@ -933,7 +978,7 @@ class ImportErrorHandler(object):
         if os.path.isfile(self.err_file):
             now = datetime.datetime.now()
             old_err_file = self.err_file + now.strftime('.%Y%m%d_%H%M%S')
-            self.printmsg("Renaming existing %s to %s\n" % (self.err_file, 
old_err_file))
+            printdebugmsg("Renaming existing %s to %s\n" % (self.err_file, 
old_err_file))
             os.rename(self.err_file, old_err_file)
 
     def max_exceeded(self):
@@ -1088,17 +1133,18 @@ class ImportTask(CopyTask):
             self.shell.printerr("{} child process(es) died unexpectedly, 
aborting"
                                 .format(self.num_processes - 
self.num_live_processes()))
         else:
-            # it is only safe to write to processes if they are all running because the feeder process
-            # at the moment hangs whilst sending messages to a crashed worker process; in future
-            # we could do something about this by using a BoundedSemaphore to keep track of how many messages are
-            # queued on a pipe
+            if self.error_handler.max_exceeded():
+                self.processes[-1].terminate()  # kill the feeder
+
             for i, _ in enumerate(self.processes):
-                self.outmsg.channels[i].send(None)
+                if self.processes[i].is_alive():
+                    self.outmsg.channels[i].send(None)
 
-            if PROFILE_ON:
-                # allow time for worker processes to write profile results (only works if processes received
-                # the poison pill above)
-                time.sleep(5)
+        # allow time for worker processes to exit cleanly
+        attempts = 50  # 100 milliseconds per attempt, so 5 seconds total
+        while attempts > 0 and self.num_live_processes() > 0:
+            time.sleep(0.1)
+            attempts -= 1
 
         self.printmsg("\n%d rows imported from %d files in %s (%d skipped)." %
                       (self.receive_meter.get_total_records(),
@@ -1239,12 +1285,8 @@ class ChildProcess(mp.Process):
         else:
             self.test_failures = None
 
-    def printdebugmsg(self, text):
-        if self.debug:
-            sys.stdout.write(text + '\n')
-
     def close(self):
-        self.printdebugmsg("Closing queues...")
+        printdebugmsg("Closing queues...")
         self.inmsg.close()
         self.outmsg.close()
 
@@ -1256,7 +1298,6 @@ class ExpBackoffRetryPolicy(RetryPolicy):
     def __init__(self, parent_process):
         RetryPolicy.__init__(self)
         self.max_attempts = parent_process.max_attempts
-        self.printdebugmsg = parent_process.printdebugmsg
 
     def on_read_timeout(self, query, consistency, required_responses,
                         received_responses, data_retrieved, retry_num):
@@ -1269,14 +1310,14 @@ class ExpBackoffRetryPolicy(RetryPolicy):
     def _handle_timeout(self, consistency, retry_num):
         delay = self.backoff(retry_num)
         if delay > 0:
-            self.printdebugmsg("Timeout received, retrying after %d seconds" % 
(delay,))
+            printdebugmsg("Timeout received, retrying after %d seconds" % 
(delay,))
             time.sleep(delay)
             return self.RETRY, consistency
         elif delay == 0:
-            self.printdebugmsg("Timeout received, retrying immediately")
+            printdebugmsg("Timeout received, retrying immediately")
             return self.RETRY, consistency
         else:
-            self.printdebugmsg("Timeout received, giving up after %d attempts" 
% (retry_num + 1))
+            printdebugmsg("Timeout received, giving up after %d attempts" % 
(retry_num + 1))
             return self.RETHROW, None
 
     def backoff(self, retry_num):
@@ -1309,8 +1350,8 @@ class ExportSession(object):
         session.default_fetch_size = export_process.options.copy['pagesize']
         session.default_timeout = export_process.options.copy['pagetimeout']
 
-        export_process.printdebugmsg("Created connection to %s with page size 
%d and timeout %d seconds per page"
-                                     % (cluster.contact_points, 
session.default_fetch_size, session.default_timeout))
+        printdebugmsg("Created connection to %s with page size %d and timeout 
%d seconds per page"
+                      % (cluster.contact_points, session.default_fetch_size, 
session.default_timeout))
 
         self.cluster = cluster
         self.session = session
@@ -1353,7 +1394,6 @@ class ExportProcess(ChildProcess):
         self.hosts_to_sessions = dict()
         self.formatters = dict()
         self.options = options
-        self.responses = None
 
     def run(self):
         try:
@@ -1371,8 +1411,6 @@ class ExportProcess(ChildProcess):
         we can signal a global error by sending (None, error).
         We terminate when the inbound queue is closed.
         """
-        self.init_feeder_thread()
-
         while True:
             if self.num_requests() > self.max_requests:
                 time.sleep(0.001)  # 1 millisecond
@@ -1381,56 +1419,25 @@ class ExportProcess(ChildProcess):
             token_range, info = self.inmsg.recv()
             self.start_request(token_range, info)
 
-    def init_feeder_thread(self):
-        """
-        Start a thread to feed response messages to the parent process.
-
-        It is not safe to write on the pipe from the main thread if the parent process is still sending work and
-        not receiving yet. This will in fact block the main thread on the send, which in turn won't be able to call
-        recv(), and will therefore block the parent process on its send().
-
-        It is also not safe to write on the pipe from the driver receiving thread whilst the parent process is
-        sending work, because if the receiving thread stops making progress, then the main thread may no longer
-        call recv() due to the check on the maximum number of requests in inner_run().
-
-        These deadlocks are easiest to reproduce with a single worker process, but may well affect multiple worker
-        processes too.
-
-        It is important that the order of the responses in the queue is respected, or else the parent process may
-        kill off worker processes before it has received all the pages of the last token range.
-        """
-        def feed_errors():
-            while True:
-                try:
-                    self.outmsg.send(self.responses.get())
-                except Exception, e:
-                    self.printdebugmsg(e.message)
-
-        self.responses = Queue()
-
-        thread = threading.Thread(target=feed_errors)
-        thread.setDaemon(True)
-        thread.start()
-
     @staticmethod
     def get_error_message(err, print_traceback=False):
         if isinstance(err, str):
             msg = err
         elif isinstance(err, BaseException):
             msg = "%s - %s" % (err.__class__.__name__, err)
-            if print_traceback:
-                traceback.print_exc(err)
+            if print_traceback and sys.exc_info()[1] == err:
+                traceback.print_exc()
         else:
             msg = str(err)
         return msg
 
     def report_error(self, err, token_range):
         msg = self.get_error_message(err, print_traceback=self.debug)
-        self.printdebugmsg(msg)
+        printdebugmsg(msg)
         self.send((token_range, Exception(msg)))
 
     def send(self, response):
-        self.responses.put(response)
+        self.outmsg.send(response)
 
     def start_request(self, token_range, info):
         """
@@ -1470,7 +1477,7 @@ class ExportProcess(ChildProcess):
 
             if ret:
                 if errors:
-                    self.printdebugmsg("Warning: failed to connect to some 
replicas: %s" % (errors,))
+                    printdebugmsg("Warning: failed to connect to some 
replicas: %s" % (errors,))
                 return ret
 
         self.report_error("Failed to connect to all replicas %s for %s, 
errors: %s" % (hosts, token_range, errors),
@@ -1623,7 +1630,6 @@ class ImportConversion(object):
         self.table = parent.table
         self.columns = parent.valid_columns
         self.nullval = parent.nullval
-        self.printdebugmsg = parent.printdebugmsg
         self.decimal_sep = parent.decimal_sep
         self.thousands_sep = parent.thousands_sep
         self.boolean_styles = parent.boolean_styles
@@ -1822,7 +1828,7 @@ class ImportConversion(object):
             elif issubclass(ct, ReversedType):
                 return convert_single_subtype(val, ct=ct)
 
-            self.printdebugmsg("Unknown type %s (%s) for val %s" % (ct, 
ct.typename, val))
+            printdebugmsg("Unknown type %s (%s) for val %s" % (ct, 
ct.typename, val))
             return val
 
         converters = {
@@ -2104,9 +2110,10 @@ class ImportProcess(ChildProcess):
                 chunk['rows'] = convert_rows(conv, chunk)
                 for replicas, batch in split_into_batches(chunk, conv, tm):
                     statement = make_statement(query, conv, chunk, batch, replicas)
-                    future = session.execute_async(statement)
-                    future.add_callbacks(callback=result_callback, callback_args=(batch, chunk),
-                                         errback=err_callback, errback_args=(batch, chunk, replicas))
+                    if statement:
+                        future = session.execute_async(statement)
+                        future.add_callbacks(callback=result_callback, callback_args=(batch, chunk),
+                                             errback=err_callback, errback_args=(batch, chunk, replicas))
 
             except Exception, exc:
                 self.report_error(exc, chunk, chunk['rows'])
@@ -2288,8 +2295,8 @@ class ImportProcess(ChildProcess):
                                  errback=self.err_callback, errback_args=(batch, chunk, replicas))
 
     def report_error(self, err, chunk, rows=None, attempts=1, final=True):
-        if self.debug:
-            traceback.print_exc(err)
+        if self.debug and sys.exc_info()[1] == err:
+            traceback.print_exc()
+        self.outmsg.send(ImportTaskError(err.__class__.__name__, str(err), rows, attempts, final))
         if final:
             self.update_chunk(rows, chunk)

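A hedged sketch of the shutdown sequence this patch adopts in
ImportTask (the helper below and its names are assumptions for
illustration, not the patch's own API): poison pills go only to workers
that are still alive, the feeder process is terminated outright once the
error threshold is exceeded, and the parent then waits a bounded five
seconds for children to exit instead of sleeping unconditionally.

    import time

    def shutdown(processes, channels, max_errors_exceeded):
        # The feeder is the last entry in the process list (an assumption
        # mirroring self.processes[-1].terminate() in the patch).
        if max_errors_exceeded:
            processes[-1].terminate()
        for i, proc in enumerate(processes):
            if proc.is_alive():
                channels[i].send(None)  # poison pill: worker exits its loop
        attempts = 50  # 50 attempts * 100 ms = 5 seconds total
        while attempts > 0 and any(p.is_alive() for p in processes):
            time.sleep(0.1)
            attempts -= 1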