cqlsh copy: fixed possible race in initializing feeding thread

patch by Stefania Alborghetti; reviewed by Paulo Motta for CASSANDRA-11701


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/59da2756
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/59da2756
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/59da2756

Branch: refs/heads/trunk
Commit: 59da27560c1532bb4956c661da25992840996706
Parents: f03b10c
Author: Stefania Alborghetti <stefania.alborghe...@datastax.com>
Authored: Tue May 3 09:46:33 2016 +0800
Committer: Stefania Alborghetti <stefania.alborghe...@datastax.com>
Committed: Mon Aug 22 09:00:49 2016 +0800

----------------------------------------------------------------------
 CHANGES.txt                |   1 +
 pylib/cqlshlib/copyutil.py | 151 +++++++++++++++++++++++++---------------
 2 files changed, 95 insertions(+), 57 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/59da2756/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index c421398..d28e419 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
 2.2.8
+ * cqlsh copy: fixed possible race in initializing feeding thread (CASSANDRA-11701)
  * Only set broadcast_rpc_address on Ec2MultiRegionSnitch if it's not set (CASSANDRA-11357)
  * Update StorageProxy range metrics for timeouts, failures and unavailables (CASSANDRA-9507)
  * Add Sigar to classes included in clientutil.jar (CASSANDRA-11635)

http://git-wip-us.apache.org/repos/asf/cassandra/blob/59da2756/pylib/cqlshlib/copyutil.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/copyutil.py b/pylib/cqlshlib/copyutil.py
index 14172ef..460ae6a 100644
--- a/pylib/cqlshlib/copyutil.py
+++ b/pylib/cqlshlib/copyutil.py
@@ -83,7 +83,7 @@ def printmsg(msg, eol='\n', encoding='utf8'):
     sys.stdout.flush()
 
 
-class OneWayChannel(object):
+class OneWayPipe(object):
     """
     A one way pipe protected by two process level locks, one for reading and one for writing.
     """
@@ -91,27 +91,47 @@ class OneWayChannel(object):
         self.reader, self.writer = mp.Pipe(duplex=False)
         self.rlock = mp.Lock()
         self.wlock = mp.Lock()
-        self.feeding_thread = None
-        self.pending_messages = None
 
-    def init_feeding_thread(self):
-        """
-        Initialize a thread that fetches messages from a queue and sends them to the channel.
-        We initialize the feeding thread lazily to avoid the fork(), since the channels are passed to child processes.
-        """
-        if self.feeding_thread is not None or self.pending_messages is not None:
-            raise RuntimeError("Feeding thread already initialized")
+    def send(self, obj):
+        with self.wlock:
+            self.writer.send(obj)
+
+    def recv(self):
+        with self.rlock:
+            return self.reader.recv()
+
+    def close(self):
+        self.reader.close()
+        self.writer.close()
 
+
+class ReceivingChannel(object):
+    """
+    A one way channel that wraps a pipe to receive messages.
+    """
+    def __init__(self, pipe):
+        self.pipe = pipe
+
+    def recv(self):
+        return self.pipe.recv()
+
+    def close(self):
+        self.pipe.close()
+
+
+class SendingChannel(object):
+    """
+    A one way channel that wraps a pipe and provides a feeding thread to send messages asynchronously.
+    """
+    def __init__(self, pipe):
+        self.pipe = pipe
         self.pending_messages = Queue()
 
         def feed():
-            send = self._send
-            pending_messages = self.pending_messages
-
             while True:
                 try:
-                    msg = pending_messages.get()
-                    send(msg)
+                    msg = self.pending_messages.get()
+                    self.pipe.send(msg)
                 except Exception, e:
                     printmsg('%s: %s' % (e.__class__.__name__, e.message))
 
@@ -119,39 +139,43 @@ class OneWayChannel(object):
         feeding_thread.setDaemon(True)
         feeding_thread.start()
 
-        self.feeding_thread = feeding_thread
-
     def send(self, obj):
-        if self.feeding_thread is None:
-            self.init_feeding_thread()
-
         self.pending_messages.put(obj)
 
-    def _send(self, obj):
-        with self.wlock:
-            self.writer.send(obj)
-
     def num_pending(self):
         return self.pending_messages.qsize() if self.pending_messages else 0
 
-    def recv(self):
-        with self.rlock:
-            return self.reader.recv()
+    def close(self):
+        self.pipe.close()
+
+
+class SendingChannels(object):
+    """
+    A group of one way channels for sending messages.
+    """
+    def __init__(self, num_channels):
+        self.pipes = [OneWayPipe() for _ in xrange(num_channels)]
+        self.channels = [SendingChannel(p) for p in self.pipes]
+        self.num_channels = num_channels
 
     def close(self):
-        self.reader.close()
-        self.writer.close()
+        for ch in self.channels:
+            try:
+                ch.close()
+            except:
+                pass
 
 
-class OneWayChannels(object):
+class ReceivingChannels(object):
     """
-    A group of one way channels.
+    A group of one way channels for receiving messages.
     """
     def __init__(self, num_channels):
-        self.channels = [OneWayChannel() for _ in xrange(num_channels)]
-        self._readers = [ch.reader for ch in self.channels]
-        self._rlocks = [ch.rlock for ch in self.channels]
-        self._rlocks_by_readers = dict([(ch.reader, ch.rlock) for ch in self.channels])
+        self.pipes = [OneWayPipe() for _ in xrange(num_channels)]
+        self.channels = [ReceivingChannel(p) for p in self.pipes]
+        self._readers = [p.reader for p in self.pipes]
+        self._rlocks = [p.rlock for p in self.pipes]
+        self._rlocks_by_readers = dict([(p.reader, p.rlock) for p in self.pipes])
         self.num_channels = num_channels
 
         self.recv = self.recv_select if IS_LINUX else self.recv_polling
@@ -228,8 +252,8 @@ class CopyTask(object):
             self.num_processes += 1  # add the feeder process
 
         self.processes = []
-        self.inmsg = OneWayChannels(self.num_processes)
-        self.outmsg = OneWayChannels(self.num_processes)
+        self.inmsg = ReceivingChannels(self.num_processes)
+        self.outmsg = SendingChannels(self.num_processes)
 
         self.columns = CopyTask.get_columns(shell, ks, table, columns)
         self.time_start = time.time()
@@ -466,13 +490,13 @@ class CopyTask(object):
 
     def update_params(self, params, i):
         """
-        Add the communication channels to the parameters to be passed to the worker process:
-            inmsg is the message queue flowing from parent to child process, so outmsg from the parent point
-            of view and, vice-versa,  outmsg is the message queue flowing from child to parent, so inmsg
+        Add the communication pipes to the parameters to be passed to the worker process:
+            inpipe is the message pipe flowing from parent to child process, so outpipe from the parent point
+            of view and, vice-versa,  outpipe is the message pipe flowing from child to parent, so inpipe
             from the parent point of view, hence the two are swapped below.
         """
-        params['inmsg'] = self.outmsg.channels[i]
-        params['outmsg'] = self.inmsg.channels[i]
+        params['inpipe'] = self.outmsg.pipes[i]
+        params['outpipe'] = self.inmsg.pipes[i]
         return params
 
 
@@ -912,8 +936,8 @@ class PipeReader(object):
     """
     A class for reading rows received on a pipe, this is used for reading input from STDIN
     """
-    def __init__(self, inmsg, options):
-        self.inmsg = inmsg
+    def __init__(self, inpipe, options):
+        self.inpipe = inpipe
         self.chunk_size = options.copy['chunksize']
         self.header = options.copy['header']
         self.max_rows = options.copy['maxrows']
@@ -928,7 +952,7 @@ class PipeReader(object):
     def read_rows(self, max_rows):
         rows = []
         for i in xrange(min(max_rows, self.chunk_size)):
-            row = self.inmsg.recv()
+            row = self.inpipe.recv()
             if row is None:
                 self.exhausted = True
                 break
@@ -1108,8 +1132,8 @@ class ImportTask(CopyTask):
             for i in range(self.num_processes - 1):
                 self.processes.append(ImportProcess(self.update_params(params, i)))
 
-            feeder = FeedingProcess(self.outmsg.channels[-1], self.inmsg.channels[-1],
-                                    self.outmsg.channels[:-1], self.fname, self.options,
+            feeder = FeedingProcess(self.outmsg.pipes[-1], self.inmsg.pipes[-1],
+                                    self.outmsg.pipes[:-1], self.fname, self.options,
                                     self.shell.conn if not IS_WINDOWS else None)
             self.processes.append(feeder)
 
@@ -1139,7 +1163,7 @@ class ImportTask(CopyTask):
         """
         shell = self.shell
 
-        self.printmsg("[Use \. on a line by itself to end input]")
+        self.printmsg("[Use . on a line by itself to end input]")
         for row in shell.use_stdin_reader(prompt='[copy] ', until=r'.'):
             self.outmsg.channels[-1].send(row)
 
@@ -1217,12 +1241,15 @@ class FeedingProcess(mp.Process):
     """
     A process that reads from import sources and sends chunks to worker processes.
     """
-    def __init__(self, inmsg, outmsg, worker_channels, fname, options, parent_cluster):
+    def __init__(self, inpipe, outpipe, worker_pipes, fname, options, parent_cluster):
         mp.Process.__init__(self, target=self.run)
-        self.inmsg = inmsg
-        self.outmsg = outmsg
-        self.worker_channels = worker_channels
-        self.reader = FilesReader(fname, options) if fname else PipeReader(inmsg, options)
+        self.inpipe = inpipe
+        self.outpipe = outpipe
+        self.worker_pipes = worker_pipes
+        self.inmsg = None  # must be created after forking on Windows
+        self.outmsg = None  # must be created after forking on Windows
+        self.worker_channels = None  # must be created after forking on Windows
+        self.reader = FilesReader(fname, options) if fname else PipeReader(inpipe, options)
         self.send_meter = RateMeter(log_fcn=None, update_interval=1)
         self.ingest_rate = options.copy['ingestrate']
         self.num_worker_processes = options.copy['numprocesses']
@@ -1231,8 +1258,13 @@ class FeedingProcess(mp.Process):
 
     def on_fork(self):
         """
-        Release any parent connections after forking, see CASSANDRA-11749 for details.
+        Create the channels and release any parent connections after forking,
+        see CASSANDRA-11749 for details.
         """
+        self.inmsg = ReceivingChannel(self.inpipe)
+        self.outmsg = SendingChannel(self.outpipe)
+        self.worker_channels = [SendingChannel(p) for p in self.worker_pipes]
+
         if self.parent_cluster:
             printdebugmsg("Closing parent cluster sockets")
             self.parent_cluster.shutdown()
@@ -1306,8 +1338,10 @@ class ChildProcess(mp.Process):
 
     def __init__(self, params, target):
         mp.Process.__init__(self, target=target)
-        self.inmsg = params['inmsg']
-        self.outmsg = params['outmsg']
+        self.inpipe = params['inpipe']
+        self.outpipe = params['outpipe']
+        self.inmsg = None  # must be initialized after fork on Windows
+        self.outmsg = None  # must be initialized after fork on Windows
         self.ks = params['ks']
         self.table = params['table']
         self.local_dc = params['local_dc']
@@ -1339,8 +1373,11 @@ class ChildProcess(mp.Process):
 
     def on_fork(self):
         """
-        Release any parent connections after forking, see CASSANDRA-11749 for details.
+        Create the channels and release any parent connections after forking, see CASSANDRA-11749 for details.
         """
+        self.inmsg = ReceivingChannel(self.inpipe)
+        self.outmsg = SendingChannel(self.outpipe)
+
         if self.parent_cluster:
             printdebugmsg("Closing parent cluster sockets")
             self.parent_cluster.shutdown()

Reply via email to