cqlsh: handle deserialization errors. Patch by paul cannon, reviewed by brandonwilliams for CASSANDRA-3874
Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/cd36f975 Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/cd36f975 Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/cd36f975 Branch: refs/heads/trunk Commit: cd36f9757150db6a369d70f633971dadd571509a Parents: 636e41d Author: Brandon Williams <[email protected]> Authored: Thu Feb 16 09:21:55 2012 -0600 Committer: Brandon Williams <[email protected]> Committed: Thu Feb 16 09:21:55 2012 -0600 ---------------------------------------------------------------------- bin/cqlsh | 162 +++++++++++++++++++++++++++++++++++++++++--------------- 1 files changed, 120 insertions(+), 42 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/cassandra/blob/cd36f975/bin/cqlsh ---------------------------------------------------------------------- diff --git a/bin/cqlsh b/bin/cqlsh index c89aa16..47763ec 100755 --- a/bin/cqlsh +++ b/bin/cqlsh @@ -214,6 +214,25 @@ class NoKeyspaceError(Exception): class KeyspaceNotFound(Exception): pass +class DecodeError(Exception): + def __init__(self, thebytes, err, expectedtype, colname=None): + self.thebytes = thebytes + self.err = err + self.expectedtype = expectedtype + self.colname = colname + + def __str__(self): + return str(self.thebytes) + + def message(self): + what = 'column name %r' % (self.thebytes,) + if self.colname is not None: + what = 'value %r (for column %r)' % (self.thebytes, self.colname) + return 'Failed to decode %s as %s: %s' % (what, self.expectedtype, self.err) + + def __repr__(self): + return '<%s %s>' % (self.__class__.__name__, self.message()) + def trim_if_present(s, prefix): if s.startswith(prefix): return s[len(prefix):] @@ -225,12 +244,23 @@ class FormattedValue: self.coloredval = coloredval self.displaywidth = displaywidth + def __len__(self): + return len(self.strval) + def _pad(self, width, fill=' '): if width > self.displaywidth: return fill * (width - self.displaywidth) else: return '' + def ljust(self, width, fill=' '): + """ + Similar to self.strval.ljust(width), but takes expected terminal + display width into account for special characters, and does not + take color escape codes into account. + """ + return self.strval + self._pad(width, fill) + def rjust(self, width, fill=' '): """ Similar to self.strval.rjust(width), but takes expected terminal @@ -247,7 +277,16 @@ class FormattedValue: """ return self._pad(width, fill) + self.coloredval -controlchars_re = re.compile(r'[\x00-\x31\x7f-\xa0]') + def color_ljust(self, width, fill=' '): + """ + Similar to self.ljust(width), but uses this value's colored + representation, and does not take color escape codes into account + in determining width. + """ + return self.coloredval + self._pad(width, fill) + +unicode_controlchars_re = re.compile(r'[\x00-\x31\x7f-\xa0]') +controlchars_re = re.compile(r'[\x00-\x31\x7f-\xff]') def _show_control_chars(match): txt = repr(match.group(0)) @@ -273,9 +312,13 @@ def format_value(val, casstype, output_encoding, addcolor=False, time_format='', if val is None: bval = 'null' color = RED + elif isinstance(val, DecodeError): + casstype = 'BytesType' + bval = repr(val.thebytes) + color = RED elif casstype == 'UTF8Type': escapedval = val.replace(u'\\', u'\\\\') - escapedval = controlchars_re.sub(_show_control_chars, escapedval) + escapedval = unicode_controlchars_re.sub(_show_control_chars, escapedval) bval = escapedval.encode(output_encoding, 'backslashreplace') displaywidth = wcwidth.wcswidth(bval.decode(output_encoding)) if addcolor: @@ -352,10 +395,22 @@ class Shell(cmd.Cmd): self.prompt = "" def myformat_value(self, val, casstype): + if isinstance(val, DecodeError): + self.decoding_errors.append(val) return format_value(val, casstype, self.output_codec.name, addcolor=self.color, time_format=self.display_time_format, float_precision=self.display_float_precision) + def myformat_colname(self, name): + if isinstance(name, DecodeError): + self.decoding_errors.append(name) + name = str(name) + color = RED + else: + color = MAGENTA + return FormattedValue(name, self.applycolor(name, color), + wcwidth.wcswidth(name.decode(self.output_codec.name))) + def report_connection(self): self.show_host() self.show_version() @@ -632,58 +687,68 @@ class Shell(cmd.Cmd): return False if self.cursor.description is _COUNT_DESCRIPTION: - self.print_count_result() + self.print_count_result(self.cursor) elif self.cursor.description is not _VOID_DESCRIPTION: - self.print_result() + self.print_result(self.cursor) return True def determine_decoder_for(self, cfname, ksname=None): + decoder = ErrorHandlingSchemaDecoder if ksname is None: ksname = self.current_keyspace - schema = self.schema_overrides.get((ksname, cfname), None) - if schema: - def use_my_schema_decoder(real_schema): - return cql.decoders.SchemaDecoder(schema.join(real_schema)) - return use_my_schema_decoder - - def print_count_result(self): - if not self.cursor.result: + overrides = self.schema_overrides.get((ksname, cfname), None) + if overrides: + decoder = partial(decoder, overrides=overrides) + return decoder + + def print_count_result(self, cursor): + if not cursor.result: return self.printout('count') self.printout('-----') - self.printout(self.cursor.result[0]) + self.printout(cursor.result[0]) self.printout("") - def print_result(self): + def print_result(self, cursor): + self.decoding_errors = [] + # first pass: see if we have a static column set last_description = None - for row in self.cursor: - if last_description is not None and self.cursor.description != last_description: + for row in cursor: + if last_description is not None and cursor.description != last_description: static = False break - last_description = self.cursor.description + last_description = cursor.description else: static = True - self.cursor._reset() + cursor._reset() if static: - self.print_static_result() + self.print_static_result(cursor) else: - self.print_dynamic_result() + self.print_dynamic_result(cursor) self.printout("") - def print_static_result(self): - colnames, coltypes = zip(*self.cursor.description)[:2] - formatted_data = [map(self.myformat_value, row, coltypes) for row in self.cursor] + if self.decoding_errors: + for err in self.decoding_errors[:2]: + self.printout(err.message(), color=RED) + if len(self.decoding_errors) > 2: + self.printout('%d more decoding errors suppressed.' + % (len(self.decoding_errors) - 2), color=RED) + + def print_static_result(self, cursor): + colnames, coltypes = zip(*cursor.description)[:2] + formatted_names = map(self.myformat_colname, colnames) + formatted_data = [map(self.myformat_value, row, coltypes) for row in cursor] # determine column widths - widths = map(len, colnames) + widths = map(len, formatted_names) for fmtrow in formatted_data: for num, col in enumerate(fmtrow): - widths[num] = max(widths[num], len(col.strval)) + widths[num] = max(widths[num], len(col)) # print header - header = ' | '.join(self.applycolor(name.ljust(w), MAGENTA) for (name, w) in zip(colnames, widths)) + header = ' | '.join(hdr.color_ljust(w) for (hdr, w) in zip(formatted_names, widths)) print ' ' + header.rstrip() print '-%s-' % '-+-'.join('-' * w for w in widths) @@ -692,12 +757,12 @@ class Shell(cmd.Cmd): line = ' | '.join(col.color_rjust(w) for (col, w) in zip(row, widths)) print ' ' + line - def print_dynamic_result(self): - for row in self.cursor: - colnames, coltypes = zip(*self.cursor.description)[:2] - colnames = [self.applycolor(name, MAGENTA) for name in colnames] + def print_dynamic_result(self, cursor): + for row in cursor: + colnames, coltypes = zip(*cursor.description)[:2] + colnames = [self.myformat_colname(name) for name in colnames] colvals = [self.myformat_value(val, casstype) for (val, casstype) in zip(row, coltypes)] - line = ' | '.join(name + ',' + col.coloredval for (col, name) in zip(colvals, colnames)) + line = ' | '.join('%s,%s' % (n.coloredval, v.coloredval) for (n, v) in zip(colnames, colvals)) print ' ' + line def emptyline(self): @@ -995,8 +1060,9 @@ class Shell(cmd.Cmd): validator_class = cqlhandling.find_validator_class(cqltype) except KeyError: self.printerr('Error: validator type %s not found.' % cqltype) - self.add_assumption(params['ks'], params['cf'], params['colname'], - overridetype, validator_class) + else: + self.add_assumption(params['ks'], params['cf'], params['colname'], + overridetype, validator_class) def do_EOF(self, parsed): """ @@ -1696,15 +1762,27 @@ class FakeCqlMetadata: self.default_name_type = None self.default_value_type = None - def join(self, realschema): - f = self.__class__() - f.default_name_type = self.default_name_type or realschema.default_name_type - f.default_value_types = self.default_value_type or realschema.default_value_type - f.name_types = realschema.name_types.copy() - f.name_types.update(self.name_types) - f.value_types = realschema.value_types.copy() - f.value_types.update(self.value_types) - return f +class OverrideableSchemaDecoder(cql.decoders.SchemaDecoder): + def __init__(self, schema, overrides=None): + cql.decoders.SchemaDecoder.__init__(self, schema) + self.apply_schema_overrides(overrides) + + def apply_schema_overrides(self, overrides): + if overrides is None: + return + if overrides.default_name_type is not None: + self.schema.default_name_type = overrides.default_name_type + if overrides.default_value_type is not None: + self.schema.default_value_type = overrides.default_value_type + self.schema.name_types.update(overrides.name_types) + self.schema.value_types.update(overrides.value_types) + +class ErrorHandlingSchemaDecoder(OverrideableSchemaDecoder): + def name_decode_error(self, err, namebytes, expectedtype): + return DecodeError(namebytes, err, expectedtype) + + def value_decode_error(self, err, namebytes, valuebytes, expectedtype): + return DecodeError(valuebytes, err, expectedtype, colname=namebytes) def option_with_default(cparser_getter, section, option, default=None):
