Repository: cassandra
Updated Branches:
  refs/heads/trunk b42afc424 -> 8c511cc45
cqlsh: Improve COPY TO perf and error handling Patch by Stefania Alborghetti; reviewed by Tyler Hobbs for CASSANDRA-9304 Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/1b629c10 Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/1b629c10 Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/1b629c10 Branch: refs/heads/trunk Commit: 1b629c101bbf793f8e248bbf4396bb41adc0af97 Parents: 246cb88 Author: Stefania Alborghetti <[email protected]> Authored: Wed Nov 18 18:22:28 2015 -0600 Committer: Tyler Hobbs <[email protected]> Committed: Wed Nov 18 18:22:28 2015 -0600 ---------------------------------------------------------------------- CHANGES.txt | 1 + bin/cqlsh | 180 +++-------- pylib/cqlshlib/copy.py | 644 ++++++++++++++++++++++++++++++++++++++ pylib/cqlshlib/displaying.py | 10 + pylib/cqlshlib/formatting.py | 34 +- 5 files changed, 729 insertions(+), 140 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/CHANGES.txt ---------------------------------------------------------------------- diff --git a/CHANGES.txt b/CHANGES.txt index 6ccde28..42dcf3e 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,5 @@ 2.1.12 + * (cqlsh) Improve COPY TO performance and error handling (CASSANDRA-9304) * Don't remove level info when running upgradesstables (CASSANDRA-10692) * Create compression chunk for sending file only (CASSANDRA-10680) * Make buffered read size configurable (CASSANDRA-10249) http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/bin/cqlsh ---------------------------------------------------------------------- diff --git a/bin/cqlsh b/bin/cqlsh index 7291803..5459d67 100755 --- a/bin/cqlsh +++ b/bin/cqlsh @@ -37,6 +37,7 @@ import ConfigParser import csv import getpass import locale +import multiprocessing as mp import optparse import os import platform @@ -44,10 +45,11 @@ import sys import time import traceback import warnings + +from StringIO import StringIO from contextlib import contextmanager from functools import partial from glob import glob -from StringIO import StringIO from uuid import UUID description = "CQL Shell for Apache Cassandra" @@ -119,7 +121,7 @@ cqlshlibdir = os.path.join(CASSANDRA_PATH, 'pylib') if os.path.isdir(cqlshlibdir): sys.path.insert(0, cqlshlibdir) -from cqlshlib import cql3handling, cqlhandling, pylexotron, sslhandling +from cqlshlib import cql3handling, cqlhandling, pylexotron, sslhandling, copy from cqlshlib.displaying import (ANSI_RESET, BLUE, COLUMN_NAME_COLORS, CYAN, RED, FormattedValue, colorme) from cqlshlib.formatting import (format_by_type, format_value_utype, @@ -410,7 +412,8 @@ def complete_copy_column_names(ctxt, cqlsh): return set(colnames[1:]) - set(existcols) -COPY_OPTIONS = ('DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'ENCODING', 'TIMEFORMAT', 'NULL') +COPY_OPTIONS = ['DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'NULL', 'ENCODING', + 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS'] @cqlsh_syntax_completer('copyOption', 'optnames') @@ -419,8 +422,7 @@ def complete_copy_options(ctxt, cqlsh): direction = ctxt.get_binding('dir').upper() opts = set(COPY_OPTIONS) - set(optnames) if direction == 'FROM': - opts -= ('ENCODING',) - opts -= ('TIMEFORMAT',) + opts -= set(['ENCODING', 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS']) return opts @@ -535,6 +537,19 @@ def describe_interval(seconds): return words +def insert_driver_hooks(): 
+ extend_cql_deserialization() + auto_format_udts() + + +def extend_cql_deserialization(): + """ + The python driver returns BLOBs as string, but we expect them as bytearrays + """ + cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts)) + cassandra.cqltypes.CassandraType.support_empty_values = True + + def auto_format_udts(): # when we see a new user defined type, set up the shell formatting for it udt_apply_params = cassandra.cqltypes.UserType.apply_parameters @@ -673,11 +688,6 @@ class Shell(cmd.Cmd): self.query_out = sys.stdout self.consistency_level = cassandra.ConsistencyLevel.ONE self.serial_consistency_level = cassandra.ConsistencyLevel.SERIAL - # the python driver returns BLOBs as string, but we expect them as bytearrays - cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts)) - cassandra.cqltypes.CassandraType.support_empty_values = True - - auto_format_udts() self.empty_lines = 0 self.statement_error = False @@ -807,11 +817,9 @@ class Shell(cmd.Cmd): def get_keyspaces(self): return self.conn.metadata.keyspaces.values() - def get_ring(self): - if self.current_keyspace is None or self.current_keyspace == 'system': - raise NoKeyspaceError("Ring view requires a current non-system keyspace") - self.conn.metadata.token_map.rebuild_keyspace(self.current_keyspace, build_if_absent=True) - return self.conn.metadata.token_map.tokens_to_hosts_by_ks[self.current_keyspace] + def get_ring(self, ks): + self.conn.metadata.token_map.rebuild_keyspace(ks, build_if_absent=True) + return self.conn.metadata.token_map.tokens_to_hosts_by_ks[ks] def get_table_meta(self, ksname, tablename): if ksname is None: @@ -1369,7 +1377,7 @@ class Shell(cmd.Cmd): # print 'Snitch: %s\n' % snitch if self.current_keyspace is not None and self.current_keyspace != 'system': print "Range ownership:" - ring = self.get_ring() + ring = self.get_ring(self.current_keyspace) for entry in ring.items(): print ' %39s [%s]' % (str(entry[0].value), ', '.join([host.address for host in entry[1]])) print @@ -1506,10 +1514,14 @@ class Shell(cmd.Cmd): ENCODING='utf8' - encoding for CSV output (COPY TO only) TIMEFORMAT= - timestamp strftime format (COPY TO only) '%Y-%m-%d %H:%M:%S%z' defaults to time_format value in cqlshrc + PAGESIZE='1000' - the page size for fetching results (COPY TO only) + PAGETIMEOUT=10 - the page timeout for fetching results (COPY TO only) + MAXATTEMPTS='5' - the maximum number of attempts for errors (COPY TO only) When entering CSV data on STDIN, you can use the sequence "\." on a line by itself to end the data input. """ + ks = self.cql_unprotect_name(parsed.get_binding('ksname', None)) if ks is None: ks = self.current_keyspace @@ -1546,22 +1558,12 @@ class Shell(cmd.Cmd): print "\n%d rows %s in %s." 
% (rows, verb, describe_interval(timeend - timestart)) def perform_csv_import(self, ks, cf, columns, fname, opts): - dialect_options = self.csv_dialect_defaults.copy() - if 'quote' in opts: - dialect_options['quotechar'] = opts.pop('quote') - if 'escape' in opts: - dialect_options['escapechar'] = opts.pop('escape') - if 'delimiter' in opts: - dialect_options['delimiter'] = opts.pop('delimiter') - nullval = opts.pop('null', '') - header = bool(opts.pop('header', '').lower() == 'true') - if dialect_options['quotechar'] == dialect_options['escapechar']: - dialect_options['doublequote'] = True - del dialect_options['escapechar'] - if opts: + csv_options, dialect_options, unrecognized_options = copy.parse_options(self, opts) + if unrecognized_options: self.printerr('Unrecognized COPY FROM options: %s' - % ', '.join(opts.keys())) + % ', '.join(unrecognized_options.keys())) return 0 + nullval, header = csv_options['nullval'], csv_options['header'] if fname is None: do_close = False @@ -1576,33 +1578,24 @@ class Shell(cmd.Cmd): return 0 current_record = None + processes, pipes = [], [], try: if header: linesource.next() reader = csv.reader(linesource, **dialect_options) - from multiprocessing import Process, Pipe, cpu_count + num_processes = copy.get_num_processes(cap=4) - # Pick a resonable number of child processes. We need to leave at - # least one core for the parent process. This doesn't necessarily - # need to be capped at 4, but it's currently enough to keep - # a single local Cassandra node busy, and I see lower throughput - # with more processes. - try: - num_processes = max(1, min(4, cpu_count() - 1)) - except NotImplementedError: - num_processes = 1 - - processes, pipes = [], [], for i in range(num_processes): - parent_conn, child_conn = Pipe() + parent_conn, child_conn = mp.Pipe() pipes.append(parent_conn) - processes.append(Process(target=self.multiproc_import, args=(child_conn, ks, cf, columns, nullval))) + proc_args = (child_conn, ks, cf, columns, nullval) + processes.append(mp.Process(target=self.multiproc_import, args=proc_args)) for process in processes: process.start() - meter = RateMeter(10000) + meter = copy.RateMeter(10000) for current_record, row in enumerate(reader, start=1): # write to the child process pipes[current_record % num_processes].send((current_record, row)) @@ -1612,7 +1605,7 @@ class Shell(cmd.Cmd): # check for any errors reported by the children if (current_record % 100) == 0: - if self._check_child_pipes(current_record, pipes): + if self._check_import_processes(current_record, pipes): # no errors seen, continue with outer loop continue else: @@ -1641,7 +1634,7 @@ class Shell(cmd.Cmd): for process in processes: process.join() - self._check_child_pipes(current_record, pipes) + self._check_import_processes(current_record, pipes) for pipe in pipes: pipe.close() @@ -1653,8 +1646,7 @@ class Shell(cmd.Cmd): return current_record - def _check_child_pipes(self, current_record, pipes): - # check the pipes for errors from child processes + def _check_import_processes(self, current_record, pipes): for pipe in pipes: if pipe.poll(): try: @@ -1802,62 +1794,13 @@ class Shell(cmd.Cmd): new_cluster.shutdown() def perform_csv_export(self, ks, cf, columns, fname, opts): - dialect_options = self.csv_dialect_defaults.copy() - - if 'quote' in opts: - dialect_options['quotechar'] = opts.pop('quote') - if 'escape' in opts: - dialect_options['escapechar'] = opts.pop('escape') - if 'delimiter' in opts: - dialect_options['delimiter'] = opts.pop('delimiter') - encoding = 
opts.pop('encoding', 'utf8') - nullval = opts.pop('null', '') - header = bool(opts.pop('header', '').lower() == 'true') - time_format = opts.pop('timeformat', self.display_time_format) - if dialect_options['quotechar'] == dialect_options['escapechar']: - dialect_options['doublequote'] = True - del dialect_options['escapechar'] - - if opts: - self.printerr('Unrecognized COPY TO options: %s' - % ', '.join(opts.keys())) + csv_options, dialect_options, unrecognized_options = copy.parse_options(self, opts) + if unrecognized_options: + self.printerr('Unrecognized COPY TO options: %s' % ', '.join(unrecognized_options.keys())) return 0 - if fname is None: - do_close = False - csvdest = sys.stdout - else: - do_close = True - try: - csvdest = open(fname, 'wb') - except IOError, e: - self.printerr("Can't open %r for writing: %s" % (fname, e)) - return 0 - - meter = RateMeter(10000) - try: - dump = self.prep_export_dump(ks, cf, columns) - writer = csv.writer(csvdest, **dialect_options) - if header: - writer.writerow(columns) - for row in dump: - fmt = lambda v: \ - format_value(v, output_encoding=encoding, nullval=nullval, - time_format=time_format, - float_precision=self.display_float_precision).strval - writer.writerow(map(fmt, row.values())) - meter.increment() - finally: - if do_close: - csvdest.close() - return meter.current_record - - def prep_export_dump(self, ks, cf, columns): - if columns is None: - columns = self.get_column_names(ks, cf) - columnlist = ', '.join(protect_names(columns)) - query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(ks), protect_name(cf)) - return self.session.execute(query) + return copy.ExportTask(self, ks, cf, columns, fname, csv_options, dialect_options, + DEFAULT_PROTOCOL_VERSION, CONFIG_FILE).run() def do_show(self, parsed): """ @@ -2215,34 +2158,6 @@ class Shell(cmd.Cmd): self.writeresult(text, color, newline=newline, out=sys.stderr) -class RateMeter(object): - - def __init__(self, log_rate): - self.log_rate = log_rate - self.last_checkpoint_time = time.time() - self.current_rate = 0.0 - self.current_record = 0 - - def increment(self): - self.current_record += 1 - - if (self.current_record % self.log_rate) == 0: - new_checkpoint_time = time.time() - new_rate = self.log_rate / (new_checkpoint_time - self.last_checkpoint_time) - self.last_checkpoint_time = new_checkpoint_time - - # smooth the rate a bit - if self.current_rate == 0.0: - self.current_rate = new_rate - else: - self.current_rate = (self.current_rate + new_rate) / 2.0 - - output = 'Processed %s rows; Write: %.2f rows/s\r' % \ - (self.current_record, self.current_rate) - sys.stdout.write(output) - sys.stdout.flush() - - class SwitchCommand(object): command = None description = None @@ -2487,6 +2402,9 @@ def main(options, hostname, port): if batch_mode and shell.statement_error: sys.exit(2) +# always call this regardless of module name: when a sub-process is spawned +# on Windows then the module name is not __main__, see CASSANDRA-9304 +insert_driver_hooks() if __name__ == '__main__': main(*read_options(sys.argv[1:], os.environ)) http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/pylib/cqlshlib/copy.py ---------------------------------------------------------------------- diff --git a/pylib/cqlshlib/copy.py b/pylib/cqlshlib/copy.py new file mode 100644 index 0000000..8534b98 --- /dev/null +++ b/pylib/cqlshlib/copy.py @@ -0,0 +1,644 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import json +import multiprocessing as mp +import os +import Queue +import sys +import time +import traceback + +from StringIO import StringIO +from random import randrange +from threading import Lock + +from cassandra.cluster import Cluster +from cassandra.metadata import protect_name, protect_names +from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy +from cassandra.query import tuple_factory + + +import sslhandling +from displaying import NO_COLOR_MAP +from formatting import format_value_default, EMPTY, get_formatter + + +def parse_options(shell, opts): + """ + Parse options for import (COPY FROM) and export (COPY TO) operations. + Extract from opts csv and dialect options. + + :return: 3 dictionaries: the csv options, the dialect options, any unrecognized options. + """ + dialect_options = shell.csv_dialect_defaults.copy() + if 'quote' in opts: + dialect_options['quotechar'] = opts.pop('quote') + if 'escape' in opts: + dialect_options['escapechar'] = opts.pop('escape') + if 'delimiter' in opts: + dialect_options['delimiter'] = opts.pop('delimiter') + if dialect_options['quotechar'] == dialect_options['escapechar']: + dialect_options['doublequote'] = True + del dialect_options['escapechar'] + + csv_options = dict() + csv_options['nullval'] = opts.pop('null', '') + csv_options['header'] = bool(opts.pop('header', '').lower() == 'true') + csv_options['encoding'] = opts.pop('encoding', 'utf8') + csv_options['jobs'] = int(opts.pop('jobs', 12)) + csv_options['pagesize'] = int(opts.pop('pagesize', 1000)) + # by default the page timeout is 10 seconds per 1000 entries in the page size or 10 seconds if pagesize is smaller + csv_options['pagetimeout'] = int(opts.pop('pagetimeout', max(10, 10 * (csv_options['pagesize'] / 1000)))) + csv_options['maxattempts'] = int(opts.pop('maxattempts', 5)) + csv_options['dtformats'] = opts.pop('timeformat', shell.display_time_format) + csv_options['float_precision'] = shell.display_float_precision + + return csv_options, dialect_options, opts + + +def get_num_processes(cap): + """ + Pick a reasonable number of child processes. We need to leave at + least one core for the parent process. This doesn't necessarily + need to be capped, but 4 is currently enough to keep + a single local Cassandra node busy so we use this for import, whilst + for export we use 16 since we can connect to multiple Cassandra nodes. + Eventually this parameter will become an option. + """ + try: + return max(1, min(cap, mp.cpu_count() - 1)) + except NotImplementedError: + return 1 + + +class ExportTask(object): + """ + A class that exports data to .csv by instantiating one or more processes that work in parallel (ExportProcess). 
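As a small illustration (not from the patch itself; the helper name is hypothetical), the PAGETIMEOUT default computed in parse_options() above works out to 10 seconds per 1000 rows of PAGESIZE, with a 10-second floor; // is used below to make the Python 2 integer division in the patch explicit:

def default_page_timeout(page_size):
    # 10 seconds per 1000 rows requested per page, never less than 10 seconds
    return max(10, 10 * (page_size // 1000))

assert default_page_timeout(1000) == 10   # the default PAGESIZE
assert default_page_timeout(5000) == 50
assert default_page_timeout(500) == 10    # the floor applies to small pages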
+ """ + def __init__(self, shell, ks, cf, columns, fname, csv_options, dialect_options, protocol_version, config_file): + self.shell = shell + self.csv_options = csv_options + self.dialect_options = dialect_options + self.ks = ks + self.cf = cf + self.columns = shell.get_column_names(ks, cf) if columns is None else columns + self.fname = fname + self.protocol_version = protocol_version + self.config_file = config_file + + def run(self): + """ + Initiates the export by creating the processes. + """ + shell = self.shell + fname = self.fname + + if fname is None: + do_close = False + csvdest = sys.stdout + else: + do_close = True + try: + csvdest = open(fname, 'wb') + except IOError, e: + shell.printerr("Can't open %r for writing: %s" % (fname, e)) + return 0 + + if self.csv_options['header']: + writer = csv.writer(csvdest, **self.dialect_options) + writer.writerow(self.columns) + + ranges = self.get_ranges() + num_processes = get_num_processes(cap=min(16, len(ranges))) + + inmsg = mp.Queue() + outmsg = mp.Queue() + processes = [] + for i in xrange(num_processes): + process = ExportProcess(outmsg, inmsg, self.ks, self.cf, self.columns, self.dialect_options, + self.csv_options, shell.debug, shell.port, shell.conn.cql_version, + shell.auth_provider, shell.ssl, self.protocol_version, self.config_file) + process.start() + processes.append(process) + + try: + return self.check_processes(csvdest, ranges, inmsg, outmsg, processes) + finally: + for process in processes: + process.terminate() + + inmsg.close() + outmsg.close() + if do_close: + csvdest.close() + + def get_ranges(self): + """ + return a queue of tuples, where the first tuple entry is a token range (from, to] + and the second entry is a list of hosts that own that range. Each host is responsible + for all the tokens in the rage (from, to]. + + The ring information comes from the driver metadata token map, which is built by + querying System.PEERS. + + We only consider replicas that are in the local datacenter. If there are no local replicas + we use the cqlsh session host. + """ + shell = self.shell + hostname = shell.hostname + ranges = dict() + + def make_range(hosts): + return {'hosts': tuple(hosts), 'attempts': 0, 'rows': 0} + + min_token = self.get_min_token() + if shell.conn.metadata.token_map is None or min_token is None: + ranges[(None, None)] = make_range([hostname]) + return ranges + + local_dc = shell.conn.metadata.get_host(hostname).datacenter + ring = shell.get_ring(self.ks).items() + ring.sort() + + previous_previous = None + previous = None + for token, replicas in ring: + if previous is None and token.value == min_token: + continue # avoids looping entire ring + + hosts = [] + for host in replicas: + if host.datacenter == local_dc: + hosts.append(host.address) + if len(hosts) == 0: + hosts.append(hostname) # fallback to default host if no replicas in current dc + ranges[(previous, token.value)] = make_range(hosts) + previous_previous = previous + previous = token.value + + # If the ring is empty we get the entire ring from the + # host we are currently connected to, otherwise for the last ring interval + # we query the same replicas that hold the last token in the ring + if len(ranges) == 0: + ranges[(None, None)] = make_range([hostname]) + else: + ranges[(previous, None)] = ranges[(previous_previous, previous)].copy() + + return ranges + + def get_min_token(self): + """ + :return the minimum token, which depends on the partitioner. 
+ For partitioners that do not support tokens we return None, in + this cases we will not work in parallel, we'll just send all requests + to the cqlsh session host. + """ + partitioner = self.shell.conn.metadata.partitioner + + if partitioner.endswith('RandomPartitioner'): + return -1 + elif partitioner.endswith('Murmur3Partitioner'): + return -(2 ** 63) # Long.MIN_VALUE in Java + else: + return None + + @staticmethod + def send_work(ranges, tokens_to_send, queue): + for token_range in tokens_to_send: + queue.put((token_range, ranges[token_range])) + ranges[token_range]['attempts'] += 1 + + def check_processes(self, csvdest, ranges, inmsg, outmsg, processes): + """ + Here we monitor all child processes by collecting their results + or any errors. We terminate when we have processed all the ranges or when there + are no more processes. + """ + shell = self.shell + meter = RateMeter(10000) + total_jobs = len(ranges) + max_attempts = self.csv_options['maxattempts'] + + self.send_work(ranges, ranges.keys(), outmsg) + + num_processes = len(processes) + succeeded = 0 + failed = 0 + while (failed + succeeded) < total_jobs and self.num_live_processes(processes) == num_processes: + try: + token_range, result = inmsg.get(timeout=1.0) + if token_range is None and result is None: # a job has finished + succeeded += 1 + elif isinstance(result, Exception): # an error occurred + if token_range is None: # the entire process failed + shell.printerr('Error from worker process: %s' % (result)) + else: # only this token_range failed, retry up to max_attempts if no rows received yet, + # if rows are receive we risk duplicating data, there is a back-off policy in place + # in the worker process as well, see ExpBackoffRetryPolicy + if ranges[token_range]['attempts'] < max_attempts and ranges[token_range]['rows'] == 0: + shell.printerr('Error for %s: %s (will try again later attempt %d of %d)' + % (token_range, result, ranges[token_range]['attempts'], max_attempts)) + self.send_work(ranges, [token_range], outmsg) + else: + shell.printerr('Error for %s: %s (permanently given up after %d rows and %d attempts)' + % (token_range, result, ranges[token_range]['rows'], + ranges[token_range]['attempts'])) + failed += 1 + else: # partial result received + data, num = result + csvdest.write(data) + meter.increment(n=num) + ranges[token_range]['rows'] += num + except Queue.Empty: + pass + + if self.num_live_processes(processes) < len(processes): + for process in processes: + if not process.is_alive(): + shell.printerr('Child process %d died with exit code %d' % (process.pid, process.exitcode)) + + if succeeded < total_jobs: + shell.printerr('Exported %d ranges out of %d total ranges, some records might be missing' + % (succeeded, total_jobs)) + + return meter.get_total_records() + + @staticmethod + def num_live_processes(processes): + return sum(1 for p in processes if p.is_alive()) + + +class ExpBackoffRetryPolicy(RetryPolicy): + """ + A retry policy with exponential back-off for read timeouts, + see ExportProcess. 
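For reference, a minimal sketch (not from the patch; the helper is purely illustrative) of the (token_range, result) replies that worker processes put on the queue drained by check_processes() above:

def classify_reply(token_range, result):
    if token_range is None and result is None:
        return 'a token range finished'          # success for one range
    if isinstance(result, Exception):
        if token_range is None:
            return 'the whole worker process failed'
        return 'this range failed; retried only while no rows were received'
    data, num_rows = result                      # a page of CSV text for this range
    return 'partial result: %d rows' % num_rows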
+ """ + def __init__(self, export_process): + RetryPolicy.__init__(self) + self.max_attempts = export_process.csv_options['maxattempts'] + self.printmsg = lambda txt: export_process.printmsg(txt) + + def on_read_timeout(self, query, consistency, required_responses, + received_responses, data_retrieved, retry_num): + delay = self.backoff(retry_num) + if delay > 0: + self.printmsg("Timeout received, retrying after %d seconds" % (delay)) + time.sleep(delay) + return self.RETRY, consistency + elif delay == 0: + self.printmsg("Timeout received, retrying immediately") + return self.RETRY, consistency + else: + self.printmsg("Timeout received, giving up after %d attempts" % (retry_num + 1)) + return self.RETHROW, None + + def backoff(self, retry_num): + """ + Perform exponential back-off up to a maximum number of times, where + this maximum is per query. + To back-off we should wait a random number of seconds + between 0 and 2^c - 1, where c is the number of total failures. + randrange() excludes the last value, so we drop the -1. + + :return : the number of seconds to wait for, -1 if we should not retry + """ + if retry_num >= self.max_attempts: + return -1 + + delay = randrange(0, pow(2, retry_num + 1)) + return delay + + +class ExportSession(object): + """ + A class for connecting to a cluster and storing the number + of jobs that this connection is processing. It wraps the methods + for executing a query asynchronously and for shutting down the + connection to the cluster. + """ + def __init__(self, cluster, export_process): + session = cluster.connect(export_process.ks) + session.row_factory = tuple_factory + session.default_fetch_size = export_process.csv_options['pagesize'] + session.default_timeout = export_process.csv_options['pagetimeout'] + + export_process.printmsg("Created connection to %s with page size %d and timeout %d seconds per page" + % (session.hosts, session.default_fetch_size, session.default_timeout)) + + self.cluster = cluster + self.session = session + self.jobs = 1 + self.lock = Lock() + + def add_job(self): + with self.lock: + self.jobs += 1 + + def complete_job(self): + with self.lock: + self.jobs -= 1 + + def num_jobs(self): + with self.lock: + return self.jobs + + def execute_async(self, query): + return self.session.execute_async(query) + + def shutdown(self): + self.cluster.shutdown() + + +class ExportProcess(mp.Process): + """ + An child worker process for the export task, ExportTask. 
+ """ + + def __init__(self, inmsg, outmsg, ks, cf, columns, dialect_options, csv_options, + debug, port, cql_version, auth_provider, ssl, protocol_version, config_file): + mp.Process.__init__(self, target=self.run) + self.inmsg = inmsg + self.outmsg = outmsg + self.ks = ks + self.cf = cf + self.columns = columns + self.dialect_options = dialect_options + self.hosts_to_sessions = dict() + + self.debug = debug + self.port = port + self.cql_version = cql_version + self.auth_provider = auth_provider + self.ssl = ssl + self.protocol_version = protocol_version + self.config_file = config_file + + self.encoding = csv_options['encoding'] + self.time_format = csv_options['dtformats'] + self.float_precision = csv_options['float_precision'] + self.nullval = csv_options['nullval'] + self.maxjobs = csv_options['jobs'] + self.csv_options = csv_options + self.formatters = dict() + + # Here we inject some failures for testing purposes, only if this environment variable is set + if os.environ.get('CQLSH_COPY_TEST_FAILURES', ''): + self.test_failures = json.loads(os.environ.get('CQLSH_COPY_TEST_FAILURES', '')) + else: + self.test_failures = None + + def printmsg(self, text): + if self.debug: + sys.stderr.write(text + os.linesep) + + def run(self): + try: + self.inner_run() + finally: + self.close() + + def inner_run(self): + """ + The parent sends us (range, info) on the inbound queue (inmsg) + in order to request us to process a range, for which we can + select any of the hosts in info, which also contains other information for this + range such as the number of attempts already performed. We can signal errors + on the outbound queue (outmsg) by sending (range, error) or + we can signal a global error by sending (None, error). + We terminate when the inbound queue is closed. + """ + while True: + if self.num_jobs() > self.maxjobs: + time.sleep(0.001) # 1 millisecond + continue + + token_range, info = self.inmsg.get() + self.start_job(token_range, info) + + def report_error(self, err, token_range=None): + if isinstance(err, str): + msg = err + elif isinstance(err, BaseException): + msg = "%s - %s" % (err.__class__.__name__, err) + if self.debug: + traceback.print_exc(err) + else: + msg = str(err) + + self.printmsg(msg) + self.outmsg.put((token_range, Exception(msg))) + + def start_job(self, token_range, info): + """ + Begin querying a range by executing an async query that + will later on invoke the callbacks attached in attach_callbacks. + """ + session = self.get_session(info['hosts']) + metadata = session.cluster.metadata.keyspaces[self.ks].tables[self.cf] + query = self.prepare_query(metadata.partition_key, token_range, info['attempts']) + future = session.execute_async(query) + self.attach_callbacks(token_range, future, session) + + def num_jobs(self): + return sum(session.num_jobs() for session in self.hosts_to_sessions.values()) + + def get_session(self, hosts): + """ + We select a host to connect to. If we have no connections to one of the hosts + yet then we select this host, else we pick the one with the smallest number + of jobs. + + :return: An ExportSession connected to the chosen host. 
+ """ + new_hosts = [h for h in hosts if h not in self.hosts_to_sessions] + if new_hosts: + host = new_hosts[0] + new_cluster = Cluster( + contact_points=(host,), + port=self.port, + cql_version=self.cql_version, + protocol_version=self.protocol_version, + auth_provider=self.auth_provider, + ssl_options=sslhandling.ssl_settings(host, self.config_file) if self.ssl else None, + load_balancing_policy=TokenAwarePolicy(WhiteListRoundRobinPolicy(hosts)), + default_retry_policy=ExpBackoffRetryPolicy(self), + compression=None, + executor_threads=max(2, self.csv_options['jobs'] / 2)) + + session = ExportSession(new_cluster, self) + self.hosts_to_sessions[host] = session + return session + else: + host = min(hosts, key=lambda h: self.hosts_to_sessions[h].jobs) + session = self.hosts_to_sessions[host] + session.add_job() + return session + + def attach_callbacks(self, token_range, future, session): + def result_callback(rows): + if future.has_more_pages: + future.start_fetching_next_page() + self.write_rows_to_csv(token_range, rows) + else: + self.write_rows_to_csv(token_range, rows) + self.outmsg.put((None, None)) + session.complete_job() + + def err_callback(err): + self.report_error(err, token_range) + session.complete_job() + + future.add_callbacks(callback=result_callback, errback=err_callback) + + def write_rows_to_csv(self, token_range, rows): + if len(rows) == 0: + return # no rows in this range + + try: + output = StringIO() + writer = csv.writer(output, **self.dialect_options) + + for row in rows: + writer.writerow(map(self.format_value, row)) + + data = (output.getvalue(), len(rows)) + self.outmsg.put((token_range, data)) + output.close() + + except Exception, e: + self.report_error(e, token_range) + + def format_value(self, val): + if val is None or val == EMPTY: + return format_value_default(self.nullval, colormap=NO_COLOR_MAP) + + ctype = type(val) + formatter = self.formatters.get(ctype, None) + if not formatter: + formatter = get_formatter(ctype) + self.formatters[ctype] = formatter + + return formatter(val, encoding=self.encoding, colormap=NO_COLOR_MAP, time_format=self.time_format, + float_precision=self.float_precision, nullval=self.nullval, quote=False) + + def close(self): + self.printmsg("Export process terminating...") + self.inmsg.close() + self.outmsg.close() + for session in self.hosts_to_sessions.values(): + session.shutdown() + self.printmsg("Export process terminated") + + def prepare_query(self, partition_key, token_range, attempts): + """ + Return the export query or a fake query with some failure injected. + """ + if self.test_failures: + return self.maybe_inject_failures(partition_key, token_range, attempts) + else: + return self.prepare_export_query(partition_key, token_range) + + def maybe_inject_failures(self, partition_key, token_range, attempts): + """ + Examine self.test_failures and see if token_range is either a token range + supposed to cause a failure (failing_range) or to terminate the worker process + (exit_range). If not then call prepare_export_query(), which implements the + normal behavior. 
+ """ + start_token, end_token = token_range + + if not start_token or not end_token: + # exclude first and last ranges to make things simpler + return self.prepare_export_query(partition_key, token_range) + + if 'failing_range' in self.test_failures: + failing_range = self.test_failures['failing_range'] + if start_token >= failing_range['start'] and end_token <= failing_range['end']: + if attempts < failing_range['num_failures']: + return 'SELECT * from bad_table' + + if 'exit_range' in self.test_failures: + exit_range = self.test_failures['exit_range'] + if start_token >= exit_range['start'] and end_token <= exit_range['end']: + sys.exit(1) + + return self.prepare_export_query(partition_key, token_range) + + def prepare_export_query(self, partition_key, token_range): + """ + Return a query where we select all the data for this token range + """ + pk_cols = ", ".join(protect_names(col.name for col in partition_key)) + columnlist = ', '.join(protect_names(self.columns)) + start_token, end_token = token_range + query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(self.ks), protect_name(self.cf)) + if start_token is not None or end_token is not None: + query += ' WHERE' + if start_token is not None: + query += ' token(%s) > %s' % (pk_cols, start_token) + if start_token is not None and end_token is not None: + query += ' AND' + if end_token is not None: + query += ' token(%s) <= %s' % (pk_cols, end_token) + return query + + +class RateMeter(object): + + def __init__(self, log_threshold): + self.log_threshold = log_threshold # number of records after which we log + self.last_checkpoint_time = time.time() # last time we logged + self.current_rate = 0.0 # rows per second + self.current_record = 0 # number of records since we last logged + self.total_records = 0 # total number of records + + def increment(self, n=1): + self.current_record += n + + if self.current_record >= self.log_threshold: + self.update() + self.log() + + def update(self): + new_checkpoint_time = time.time() + time_difference = new_checkpoint_time - self.last_checkpoint_time + if time_difference != 0.0: + self.current_rate = self.get_new_rate(self.current_record / time_difference) + + self.last_checkpoint_time = new_checkpoint_time + self.total_records += self.current_record + self.current_record = 0 + + def get_new_rate(self, new_rate): + """ + return the previous rate averaged with the new rate to smooth a bit + """ + if self.current_rate == 0.0: + return new_rate + else: + return (self.current_rate + new_rate) / 2.0 + + def log(self): + output = 'Processed %d rows; Written: %f rows/s\r' % (self.total_records, self.current_rate,) + sys.stdout.write(output) + sys.stdout.flush() + + def get_total_records(self): + self.update() + self.log() + return self.total_records http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/pylib/cqlshlib/displaying.py ---------------------------------------------------------------------- diff --git a/pylib/cqlshlib/displaying.py b/pylib/cqlshlib/displaying.py index f3a016e..7b260c2 100644 --- a/pylib/cqlshlib/displaying.py +++ b/pylib/cqlshlib/displaying.py @@ -28,11 +28,19 @@ ANSI_RESET = '\033[0m' def colorme(bval, colormap, colorkey): + if colormap is NO_COLOR_MAP: + return bval if colormap is None: colormap = DEFAULT_VALUE_COLORS return FormattedValue(bval, colormap[colorkey] + bval + colormap['reset']) +def get_str(val): + if isinstance(val, FormattedValue): + return val.strval + return val + + class FormattedValue: def __init__(self, strval, coloredval=None, 
displaywidth=None): @@ -112,3 +120,5 @@ COLUMN_NAME_COLORS = defaultdict(lambda: MAGENTA, blob=DARK_MAGENTA, reset=ANSI_RESET, ) + +NO_COLOR_MAP = dict() http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/pylib/cqlshlib/formatting.py ---------------------------------------------------------------------- diff --git a/pylib/cqlshlib/formatting.py b/pylib/cqlshlib/formatting.py index 79e661b..54dde0f 100644 --- a/pylib/cqlshlib/formatting.py +++ b/pylib/cqlshlib/formatting.py @@ -14,13 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import binascii import sys import re import calendar import math from collections import defaultdict from . import wcwidth -from .displaying import colorme, FormattedValue, DEFAULT_VALUE_COLORS +from .displaying import colorme, get_str, FormattedValue, DEFAULT_VALUE_COLORS, NO_COLOR_MAP from cassandra.cqltypes import EMPTY from cassandra.util import datetime_from_timestamp from util import UTC @@ -83,7 +84,6 @@ def color_text(bval, colormap, displaywidth=None): # adding the smarts to handle that in to FormattedValue, we just # make an explicit check to see if a null colormap is being used or # not. - if displaywidth is None: displaywidth = len(bval) tbr = _make_turn_bits_red_f(colormap['blob'], colormap['text']) @@ -97,7 +97,7 @@ def format_value_default(val, colormap, **_): val = str(val) escapedval = val.replace('\\', '\\\\') bval = controlchars_re.sub(_show_control_chars, escapedval) - return color_text(bval, colormap) + return bval if colormap is NO_COLOR_MAP else color_text(bval, colormap) # Mapping cql type base names ("int", "map", etc) to formatter functions, # making format_value a generic function @@ -111,6 +111,10 @@ def format_value(type, val, **kwargs): return formatter(val, **kwargs) +def get_formatter(type): + return _formatters.get(type.__name__, format_value_default) + + def formatter_for(typname): def registrator(f): _formatters[typname] = f @@ -120,7 +124,7 @@ def formatter_for(typname): @formatter_for('bytearray') def format_value_blob(val, colormap, **_): - bval = '0x' + ''.join('%02x' % c for c in val) + bval = '0x' + binascii.hexlify(val) return colorme(bval, colormap, 'blob') formatter_for('buffer')(format_value_blob) @@ -204,8 +208,8 @@ def format_value_text(val, encoding, colormap, quote=False, **_): bval = escapedval.encode(encoding, 'backslashreplace') if quote: bval = "'%s'" % bval - displaywidth = wcwidth.wcswidth(bval.decode(encoding)) - return color_text(bval, colormap, displaywidth) + + return bval if colormap is NO_COLOR_MAP else color_text(bval, colormap, wcwidth.wcswidth(bval.decode(encoding))) # name alias formatter_for('unicode')(format_value_text) @@ -217,7 +221,10 @@ def format_simple_collection(val, lbracket, rbracket, encoding, time_format=time_format, float_precision=float_precision, nullval=nullval, quote=True) for sval in val] - bval = lbracket + ', '.join(sval.strval for sval in subs) + rbracket + bval = lbracket + ', '.join(get_str(sval) for sval in subs) + rbracket + if colormap is NO_COLOR_MAP: + return bval + lb, sep, rb = [colormap['collection'] + s + colormap['reset'] for s in (lbracket, ', ', rbracket)] coloredval = lb + sep.join(sval.coloredval for sval in subs) + rb @@ -242,6 +249,9 @@ def format_value_set(val, encoding, colormap, time_format, float_precision, null return format_simple_collection(sorted(val), '{', '}', encoding, colormap, time_format, float_precision, nullval) formatter_for('frozenset')(format_value_set) +# This code 
is used by cqlsh (bundled driver version 2.7.2 using sortedset), +# and the dtests, which use whichever driver on the machine, i.e. 3.0.0 (SortedSet) +formatter_for('SortedSet')(format_value_set) formatter_for('sortedset')(format_value_set) @@ -253,7 +263,10 @@ def format_value_map(val, encoding, colormap, time_format, float_precision, null nullval=nullval, quote=True) subs = [(subformat(k), subformat(v)) for (k, v) in sorted(val.items())] - bval = '{' + ', '.join(k.strval + ': ' + v.strval for (k, v) in subs) + '}' + bval = '{' + ', '.join(get_str(k) + ': ' + get_str(v) for (k, v) in subs) + '}' + if colormap is NO_COLOR_MAP: + return bval + lb, comma, colon, rb = [colormap['collection'] + s + colormap['reset'] for s in ('{', ', ', ': ', '}')] coloredval = lb \ @@ -278,7 +291,10 @@ def format_value_utype(val, encoding, colormap, time_format, float_precision, nu return format_value_text(name, encoding=encoding, colormap=colormap, quote=False) subs = [(format_field_name(k), format_field_value(v)) for (k, v) in val._asdict().items()] - bval = '{' + ', '.join(k.strval + ': ' + v.strval for (k, v) in subs) + '}' + bval = '{' + ', '.join(get_str(k) + ': ' + get_str(v) for (k, v) in subs) + '}' + if colormap is NO_COLOR_MAP: + return bval + lb, comma, colon, rb = [colormap['collection'] + s + colormap['reset'] for s in ('{', ', ', ': ', '}')] coloredval = lb \
