Repository: cassandra
Updated Branches:
  refs/heads/trunk a5e501f09 -> a0e8de99d
cqlsh: Fix potential COPY deadlock

This deadlock could occur when the parent process is terminating child
processes (partial backport of CASSANDRA-11320).

Patch by Stefania Alborghetti; reviewed by Tyler Hobbs for CASSANDRA-11505


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/4389c9cf
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/4389c9cf
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/4389c9cf

Branch: refs/heads/trunk
Commit: 4389c9cfd86fb3f31a9419c44f0521604be3637b
Parents: 209ebd3
Author: Stefania Alborghetti <stefania.alborghe...@datastax.com>
Authored: Mon Apr 11 10:31:34 2016 +0800
Committer: Tyler Hobbs <tylerlho...@gmail.com>
Committed: Wed Apr 20 13:52:08 2016 -0500

----------------------------------------------------------------------
 CHANGES.txt                |   2 +
 pylib/cqlshlib/copyutil.py | 171 +++++++++++++++++++++-------------------
 2 files changed, 91 insertions(+), 82 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/4389c9cf/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 76d3673..4a91a58 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,6 @@
 2.1.14
+ * (cqlsh) Fix potential COPY deadlock when parent process is terminating child
+   processes (CASSANDRA-11505)
 * Replace sstables on DataTracker before marking them as non-compacting during anti-compaction (CASSANDRA-11548)
 * Checking if an unlogged batch is local is inefficient (CASSANDRA-11529)
 * Fix paging for COMPACT tables without clustering columns (CASSANDRA-11467)
http://git-wip-us.apache.org/repos/asf/cassandra/blob/4389c9cf/pylib/cqlshlib/copyutil.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/copyutil.py b/pylib/cqlshlib/copyutil.py
index 28e08b1..12239d8 100644
--- a/pylib/cqlshlib/copyutil.py
+++ b/pylib/cqlshlib/copyutil.py
@@ -28,8 +28,8 @@ import random
 import re
 import struct
 import sys
-import time
 import threading
+import time
 import traceback
 
 from bisect import bisect_right
@@ -57,6 +57,7 @@ from sslhandling import ssl_settings
 
 PROFILE_ON = False
 STRACE_ON = False
+DEBUG = False  # This may be set to True when initializing the task
 IS_LINUX = platform.system() == 'Linux'
 
 CopyOptions = namedtuple('CopyOptions', 'copy dialect unrecognized')
@@ -70,6 +71,16 @@ def safe_normpath(fname):
     return os.path.normpath(os.path.expanduser(fname)) if fname else fname
 
 
+def printdebugmsg(msg):
+    if DEBUG:
+        printmsg(msg)
+
+
+def printmsg(msg, eol='\n'):
+    sys.stdout.write(msg + eol)
+    sys.stdout.flush()
+
+
 class OneWayChannel(object):
     """
     A one way pipe protected by two process level locks, one for reading and one for writing.
@@ -78,11 +89,49 @@ class OneWayChannel(object):
         self.reader, self.writer = mp.Pipe(duplex=False)
         self.rlock = mp.Lock()
         self.wlock = mp.Lock()
+        self.feeding_thread = None
+        self.pending_messages = None
+
+    def init_feeding_thread(self):
+        """
+        Initialize a thread that fetches messages from a queue and sends them to the channel.
+        We initialize the feeding thread lazily to avoid the fork(), since the channels are passed to child processes.
+        """
+        if self.feeding_thread is not None or self.pending_messages is not None:
+            raise RuntimeError("Feeding thread already initialized")
+
+        self.pending_messages = Queue()
+
+        def feed():
+            send = self._send
+            pending_messages = self.pending_messages
+
+            while True:
+                try:
+                    msg = pending_messages.get()
+                    send(msg)
+                except Exception, e:
+                    printmsg('%s: %s' % (e.__class__.__name__, e.message))
+
+        feeding_thread = threading.Thread(target=feed)
+        feeding_thread.setDaemon(True)
+        feeding_thread.start()
+
+        self.feeding_thread = feeding_thread
 
     def send(self, obj):
+        if self.feeding_thread is None:
+            self.init_feeding_thread()
+
+        self.pending_messages.put(obj)
+
+    def _send(self, obj):
         with self.wlock:
             self.writer.send(obj)
 
+    def num_pending(self):
+        return self.pending_messages.qsize() if self.pending_messages else 0
+
     def recv(self):
         with self.rlock:
             return self.reader.recv()
@@ -157,8 +206,15 @@ class CopyTask(object):
         self.fname = safe_normpath(fname)
         self.protocol_version = protocol_version
         self.config_file = config_file
-        # do not display messages when exporting to STDOUT
-        self.printmsg = self._printmsg if self.fname is not None or direction == 'from' else lambda _, eol='\n': None
+
+        # if cqlsh is invoked with --debug then set the global debug flag to True
+        if shell.debug:
+            global DEBUG
+            DEBUG = True
+
+        # do not display messages when exporting to STDOUT unless --debug is set
+        self.printmsg = printmsg if self.fname is not None or direction == 'from' or DEBUG \
+            else lambda _, eol='\n': None
 
         self.options = self.parse_options(opts, direction)
         self.num_processes = self.options.copy['numprocesses']
@@ -174,11 +230,6 @@ class CopyTask(object):
         self.columns = CopyTask.get_columns(shell, ks, table, columns)
         self.time_start = time.time()
 
-    @staticmethod
-    def _printmsg(msg, eol='\n'):
-        sys.stdout.write(msg + eol)
-        sys.stdout.flush()
-
     def maybe_read_config_file(self, opts, direction):
         """
         Read optional sections from a configuration file that was specified in the command options or from the default
@@ -758,7 +809,7 @@ class FilesReader(object):
             try:
                 return open(fname, 'rb')
             except IOError, e:
-                self.printmsg("Can't open %r for reading: %s" % (fname, e))
+                printdebugmsg("Can't open %r for reading: %s" % (fname, e))
                 return None
 
         for path in paths.split(','):
@@ -769,11 +820,6 @@ class FilesReader(object):
             for f in glob.glob(path):
                 yield (make_source(f))
 
-    @staticmethod
-    def printmsg(msg, eol='\n'):
-        sys.stdout.write(msg + eol)
-        sys.stdout.flush()
-
     def start(self):
         self.sources = self.get_source(self.fname)
         self.next_source()
@@ -921,7 +967,6 @@ class ImportErrorHandler(object):
     def __init__(self, task):
         self.shell = task.shell
         self.options = task.options
-        self.printmsg = task.printmsg
         self.max_attempts = self.options.copy['maxattempts']
         self.max_parse_errors = self.options.copy['maxparseerrors']
         self.max_insert_errors = self.options.copy['maxinserterrors']
@@ -933,7 +978,7 @@ class ImportErrorHandler(object):
         if os.path.isfile(self.err_file):
             now = datetime.datetime.now()
             old_err_file = self.err_file + now.strftime('.%Y%m%d_%H%M%S')
-            self.printmsg("Renaming existing %s to %s\n" % (self.err_file, old_err_file))
+            printdebugmsg("Renaming existing %s to %s\n" % (self.err_file, old_err_file))
             os.rename(self.err_file, old_err_file)
 
     def max_exceeded(self):
@@ -1088,17 +1133,18 @@ class ImportTask(CopyTask):
                 self.shell.printerr("{} child process(es) died unexpectedly, aborting"
                                     .format(self.num_processes - self.num_live_processes()))
         else:
-            # it is only safe to write to processes if they are all running because the feeder process
-            # at the moment hangs whilst sending messages to a crashed worker process; in future
-            # we could do something about this by using a BoundedSemaphore to keep track of how many messages are
-            # queued on a pipe
+            if self.error_handler.max_exceeded():
+                self.processes[-1].terminate()  # kill the feeder
+
             for i, _ in enumerate(self.processes):
-                self.outmsg.channels[i].send(None)
+                if self.processes[i].is_alive():
+                    self.outmsg.channels[i].send(None)
 
-            if PROFILE_ON:
-                # allow time for worker processes to write profile results (only works if processes received
-                # the poison pill above)
-                time.sleep(5)
+            # allow time for worker processes to exit cleanly
+            attempts = 50  # 100 milliseconds per attempt, so 5 seconds total
+            while attempts > 0 and self.num_live_processes() > 0:
+                time.sleep(0.1)
+                attempts -= 1
 
         self.printmsg("\n%d rows imported from %d files in %s (%d skipped)." %
                       (self.receive_meter.get_total_records(),
@@ -1239,12 +1285,8 @@ class ChildProcess(mp.Process):
         else:
             self.test_failures = None
 
-    def printdebugmsg(self, text):
-        if self.debug:
-            sys.stdout.write(text + '\n')
-
     def close(self):
-        self.printdebugmsg("Closing queues...")
+        printdebugmsg("Closing queues...")
         self.inmsg.close()
         self.outmsg.close()
 
@@ -1256,7 +1298,6 @@ class ExpBackoffRetryPolicy(RetryPolicy):
     def __init__(self, parent_process):
         RetryPolicy.__init__(self)
         self.max_attempts = parent_process.max_attempts
-        self.printdebugmsg = parent_process.printdebugmsg
 
     def on_read_timeout(self, query, consistency, required_responses,
                         received_responses, data_retrieved, retry_num):
@@ -1269,14 +1310,14 @@ class ExpBackoffRetryPolicy(RetryPolicy):
     def _handle_timeout(self, consistency, retry_num):
         delay = self.backoff(retry_num)
         if delay > 0:
-            self.printdebugmsg("Timeout received, retrying after %d seconds" % (delay,))
+            printdebugmsg("Timeout received, retrying after %d seconds" % (delay,))
             time.sleep(delay)
             return self.RETRY, consistency
         elif delay == 0:
-            self.printdebugmsg("Timeout received, retrying immediately")
+            printdebugmsg("Timeout received, retrying immediately")
            return self.RETRY, consistency
         else:
-            self.printdebugmsg("Timeout received, giving up after %d attempts" % (retry_num + 1))
+            printdebugmsg("Timeout received, giving up after %d attempts" % (retry_num + 1))
             return self.RETHROW, None
 
     def backoff(self, retry_num):
@@ -1309,8 +1350,8 @@ class ExportSession(object):
         session.default_fetch_size = export_process.options.copy['pagesize']
         session.default_timeout = export_process.options.copy['pagetimeout']
 
-        export_process.printdebugmsg("Created connection to %s with page size %d and timeout %d seconds per page"
-                                     % (cluster.contact_points, session.default_fetch_size, session.default_timeout))
+        printdebugmsg("Created connection to %s with page size %d and timeout %d seconds per page"
+                      % (cluster.contact_points, session.default_fetch_size, session.default_timeout))
 
         self.cluster = cluster
         self.session = session
@@ -1353,7 +1394,6 @@ class ExportProcess(ChildProcess):
         self.hosts_to_sessions = dict()
         self.formatters = dict()
         self.options = options
-        self.responses = None
 
     def run(self):
         try:
@@ -1371,8 +1411,6 @@ class ExportProcess(ChildProcess):
         we can signal a global error by sending (None, error).
         We terminate when the inbound queue is closed.
         """
-        self.init_feeder_thread()
-
         while True:
             if self.num_requests() > self.max_requests:
                 time.sleep(0.001)  # 1 millisecond
                 continue
 
             token_range, info = self.inmsg.recv()
             self.start_request(token_range, info)
 
@@ -1381,56 +1419,25 @@ class ExportProcess(ChildProcess):
-    def init_feeder_thread(self):
-        """
-        Start a thread to feed response messages to the parent process.
-
-        It is not safe to write on the pipe from the main thread if the parent process is still sending work and
-        not receiving yet. This will in fact block the main thread on the send, which in turn won't be able to call
-        recv(), and will therefore block the parent process on its send().
-
-        It is also not safe to write on the pipe from the driver receiving thread whilst the parent process is
-        sending work, because if the receiving thread stops making progress, then the main thread may no longer
-        call recv() due to the check on the maximum number of requests in inner_run().
-
-        These deadlocks are easiest to reproduce with a single worker process, but may well affect multiple worker
-        processes too.
-
-        It is important that the order of the responses in the queue is respected, or else the parent process may
-        kill off worker processes before it has received all the pages of the last token range.
-        """
-        def feed_errors():
-            while True:
-                try:
-                    self.outmsg.send(self.responses.get())
-                except Exception, e:
-                    self.printdebugmsg(e.message)
-
-        self.responses = Queue()
-
-        thread = threading.Thread(target=feed_errors)
-        thread.setDaemon(True)
-        thread.start()
-
     @staticmethod
     def get_error_message(err, print_traceback=False):
         if isinstance(err, str):
             msg = err
         elif isinstance(err, BaseException):
             msg = "%s - %s" % (err.__class__.__name__, err)
-            if print_traceback:
-                traceback.print_exc(err)
+            if print_traceback and sys.exc_info()[1] == err:
+                traceback.print_exc()
         else:
             msg = str(err)
         return msg
 
     def report_error(self, err, token_range):
         msg = self.get_error_message(err, print_traceback=self.debug)
-        self.printdebugmsg(msg)
+        printdebugmsg(msg)
         self.send((token_range, Exception(msg)))
 
     def send(self, response):
-        self.responses.put(response)
+        self.outmsg.send(response)
 
     def start_request(self, token_range, info):
         """
@@ -1470,7 +1477,7 @@
         if ret:
             if errors:
-                self.printdebugmsg("Warning: failed to connect to some replicas: %s" % (errors,))
+                printdebugmsg("Warning: failed to connect to some replicas: %s" % (errors,))
             return ret
 
         self.report_error("Failed to connect to all replicas %s for %s, errors: %s" % (hosts, token_range, errors),
@@ -1623,7 +1630,6 @@ class ImportConversion(object):
         self.table = parent.table
         self.columns = parent.valid_columns
         self.nullval = parent.nullval
-        self.printdebugmsg = parent.printdebugmsg
         self.decimal_sep = parent.decimal_sep
         self.thousands_sep = parent.thousands_sep
         self.boolean_styles = parent.boolean_styles
@@ -1822,7 +1828,7 @@ class ImportConversion(object):
             elif issubclass(ct, ReversedType):
                 return convert_single_subtype(val, ct=ct)
 
-            self.printdebugmsg("Unknown type %s (%s) for val %s" % (ct, ct.typename, val))
+            printdebugmsg("Unknown type %s (%s) for val %s" % (ct, ct.typename, val))
             return val
 
         converters = {
@@ -2104,9 +2110,10 @@ class ImportProcess(ChildProcess):
                 chunk['rows'] = convert_rows(conv, chunk)
                 for replicas, batch in split_into_batches(chunk, conv, tm):
                     statement = make_statement(query, conv, chunk, batch, replicas)
-                    future = session.execute_async(statement)
-                    future.add_callbacks(callback=result_callback, callback_args=(batch, chunk),
-                                         errback=err_callback, errback_args=(batch, chunk, replicas))
+                    if statement:
+                        future = session.execute_async(statement)
+                        future.add_callbacks(callback=result_callback, callback_args=(batch, chunk),
+                                             errback=err_callback, errback_args=(batch, chunk, replicas))
 
             except Exception, exc:
                 self.report_error(exc, chunk, chunk['rows'])
@@ -2288,8 +2295,8 @@ class ImportProcess(ChildProcess):
                              errback=self.err_callback, errback_args=(batch, chunk, replicas))
 
     def report_error(self, err, chunk, rows=None, attempts=1, final=True):
-        if self.debug:
-            traceback.print_exc(err)
+        if self.debug and sys.exc_info()[1] == err:
+            traceback.print_exc()
         self.outmsg.send(ImportTaskError(err.__class__.__name__, str(err), rows, attempts, final))
         if final:
             self.update_chunk(rows, chunk)
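The ImportTask shutdown hunk above is the other half of the fix: instead of unconditionally writing the poison pill to every channel and then sleeping a fixed five seconds, the parent terminates the feeder once the error threshold is exceeded, sends the pill only to workers that are still alive, and then polls for termination. A hedged sketch of that shutdown loop, with the process and channel lists invented for illustration:

    import time

    def shutdown_workers(processes, channels, max_errors_exceeded):
        # 'processes' and 'channels' are hypothetical stand-ins for the
        # task's child processes and their outbound channels; in the patch
        # the feeder is the last entry in the process list.
        if max_errors_exceeded:
            processes[-1].terminate()  # kill the feeder

        # Send the poison pill (None) only to workers that are still
        # alive; a dead worker will never drain its pipe.
        for proc, channel in zip(processes, channels):
            if proc.is_alive():
                channel.send(None)

        # Poll instead of sleeping a fixed interval:
        # up to 50 attempts x 100 milliseconds = 5 seconds.
        attempts = 50
        while attempts > 0 and any(p.is_alive() for p in processes):
            time.sleep(0.1)
            attempts -= 1

Combined with the queue-backed channel send() shown earlier, the parent can no longer wedge on a pipe write to a child that has already exited.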