You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@cassandra.apache.org by br...@apache.org on 2012/02/21 15:14:15 UTC
[5/5] git commit: cqlsh: handle deserialization errors. Patch by paul cannon, reviewed by brandonwilliams for CASSANDRA-3874
cqlsh: handle deserialization errors.
Patch by paul cannon, reviewed by brandonwilliams for CASSANDRA-3874
Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/cd36f975
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/cd36f975
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/cd36f975
Branch: refs/heads/cassandra-1.1
Commit: cd36f9757150db6a369d70f633971dadd571509a
Parents: 636e41d
Author: Brandon Williams <br...@apache.org>
Authored: Thu Feb 16 09:21:55 2012 -0600
Committer: Brandon Williams <br...@apache.org>
Committed: Thu Feb 16 09:21:55 2012 -0600
----------------------------------------------------------------------
bin/cqlsh | 162 +++++++++++++++++++++++++++++++++++++++++---------------
1 files changed, 120 insertions(+), 42 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/cassandra/blob/cd36f975/bin/cqlsh
----------------------------------------------------------------------
diff --git a/bin/cqlsh b/bin/cqlsh
index c89aa16..47763ec 100755
--- a/bin/cqlsh
+++ b/bin/cqlsh
@@ -214,6 +214,25 @@ class NoKeyspaceError(Exception):
class KeyspaceNotFound(Exception):
pass
+class DecodeError(Exception):
+ def __init__(self, thebytes, err, expectedtype, colname=None):
+ self.thebytes = thebytes
+ self.err = err
+ self.expectedtype = expectedtype
+ self.colname = colname
+
+ def __str__(self):
+ return str(self.thebytes)
+
+ def message(self):
+ what = 'column name %r' % (self.thebytes,)
+ if self.colname is not None:
+ what = 'value %r (for column %r)' % (self.thebytes, self.colname)
+ return 'Failed to decode %s as %s: %s' % (what, self.expectedtype, self.err)
+
+ def __repr__(self):
+ return '<%s %s>' % (self.__class__.__name__, self.message())
+
def trim_if_present(s, prefix):
if s.startswith(prefix):
return s[len(prefix):]
@@ -225,12 +244,23 @@ class FormattedValue:
self.coloredval = coloredval
self.displaywidth = displaywidth
+ def __len__(self):
+ return len(self.strval)
+
def _pad(self, width, fill=' '):
if width > self.displaywidth:
return fill * (width - self.displaywidth)
else:
return ''
+ def ljust(self, width, fill=' '):
+ """
+ Similar to self.strval.ljust(width), but takes expected terminal
+ display width into account for special characters, and does not
+ take color escape codes into account.
+ """
+ return self.strval + self._pad(width, fill)
+
def rjust(self, width, fill=' '):
"""
Similar to self.strval.rjust(width), but takes expected terminal
@@ -247,7 +277,16 @@ class FormattedValue:
"""
return self._pad(width, fill) + self.coloredval
-controlchars_re = re.compile(r'[\x00-\x31\x7f-\xa0]')
+ def color_ljust(self, width, fill=' '):
+ """
+ Similar to self.ljust(width), but uses this value's colored
+ representation, and does not take color escape codes into account
+ in determining width.
+ """
+ return self.coloredval + self._pad(width, fill)
+
+unicode_controlchars_re = re.compile(r'[\x00-\x31\x7f-\xa0]')
+controlchars_re = re.compile(r'[\x00-\x31\x7f-\xff]')
def _show_control_chars(match):
txt = repr(match.group(0))
@@ -273,9 +312,13 @@ def format_value(val, casstype, output_encoding, addcolor=False, time_format='',
if val is None:
bval = 'null'
color = RED
+ elif isinstance(val, DecodeError):
+ casstype = 'BytesType'
+ bval = repr(val.thebytes)
+ color = RED
elif casstype == 'UTF8Type':
escapedval = val.replace(u'\\', u'\\\\')
- escapedval = controlchars_re.sub(_show_control_chars, escapedval)
+ escapedval = unicode_controlchars_re.sub(_show_control_chars, escapedval)
bval = escapedval.encode(output_encoding, 'backslashreplace')
displaywidth = wcwidth.wcswidth(bval.decode(output_encoding))
if addcolor:
@@ -352,10 +395,22 @@ class Shell(cmd.Cmd):
self.prompt = ""
def myformat_value(self, val, casstype):
+ if isinstance(val, DecodeError):
+ self.decoding_errors.append(val)
return format_value(val, casstype, self.output_codec.name,
addcolor=self.color, time_format=self.display_time_format,
float_precision=self.display_float_precision)
+ def myformat_colname(self, name):
+ if isinstance(name, DecodeError):
+ self.decoding_errors.append(name)
+ name = str(name)
+ color = RED
+ else:
+ color = MAGENTA
+ return FormattedValue(name, self.applycolor(name, color),
+ wcwidth.wcswidth(name.decode(self.output_codec.name)))
+
def report_connection(self):
self.show_host()
self.show_version()
@@ -632,58 +687,68 @@ class Shell(cmd.Cmd):
return False
if self.cursor.description is _COUNT_DESCRIPTION:
- self.print_count_result()
+ self.print_count_result(self.cursor)
elif self.cursor.description is not _VOID_DESCRIPTION:
- self.print_result()
+ self.print_result(self.cursor)
return True
def determine_decoder_for(self, cfname, ksname=None):
+ decoder = ErrorHandlingSchemaDecoder
if ksname is None:
ksname = self.current_keyspace
- schema = self.schema_overrides.get((ksname, cfname), None)
- if schema:
- def use_my_schema_decoder(real_schema):
- return cql.decoders.SchemaDecoder(schema.join(real_schema))
- return use_my_schema_decoder
-
- def print_count_result(self):
- if not self.cursor.result:
+ overrides = self.schema_overrides.get((ksname, cfname), None)
+ if overrides:
+ decoder = partial(decoder, overrides=overrides)
+ return decoder
+
+ def print_count_result(self, cursor):
+ if not cursor.result:
return
self.printout('count')
self.printout('-----')
- self.printout(self.cursor.result[0])
+ self.printout(cursor.result[0])
self.printout("")
- def print_result(self):
+ def print_result(self, cursor):
+ self.decoding_errors = []
+
# first pass: see if we have a static column set
last_description = None
- for row in self.cursor:
- if last_description is not None and self.cursor.description != last_description:
+ for row in cursor:
+ if last_description is not None and cursor.description != last_description:
static = False
break
- last_description = self.cursor.description
+ last_description = cursor.description
else:
static = True
- self.cursor._reset()
+ cursor._reset()
if static:
- self.print_static_result()
+ self.print_static_result(cursor)
else:
- self.print_dynamic_result()
+ self.print_dynamic_result(cursor)
self.printout("")
- def print_static_result(self):
- colnames, coltypes = zip(*self.cursor.description)[:2]
- formatted_data = [map(self.myformat_value, row, coltypes) for row in self.cursor]
+ if self.decoding_errors:
+ for err in self.decoding_errors[:2]:
+ self.printout(err.message(), color=RED)
+ if len(self.decoding_errors) > 2:
+ self.printout('%d more decoding errors suppressed.'
+ % (len(self.decoding_errors) - 2), color=RED)
+
+ def print_static_result(self, cursor):
+ colnames, coltypes = zip(*cursor.description)[:2]
+ formatted_names = map(self.myformat_colname, colnames)
+ formatted_data = [map(self.myformat_value, row, coltypes) for row in cursor]
# determine column widths
- widths = map(len, colnames)
+ widths = map(len, formatted_names)
for fmtrow in formatted_data:
for num, col in enumerate(fmtrow):
- widths[num] = max(widths[num], len(col.strval))
+ widths[num] = max(widths[num], len(col))
# print header
- header = ' | '.join(self.applycolor(name.ljust(w), MAGENTA) for (name, w) in zip(colnames, widths))
+ header = ' | '.join(hdr.color_ljust(w) for (hdr, w) in zip(formatted_names, widths))
print ' ' + header.rstrip()
print '-%s-' % '-+-'.join('-' * w for w in widths)
@@ -692,12 +757,12 @@ class Shell(cmd.Cmd):
line = ' | '.join(col.color_rjust(w) for (col, w) in zip(row, widths))
print ' ' + line
- def print_dynamic_result(self):
- for row in self.cursor:
- colnames, coltypes = zip(*self.cursor.description)[:2]
- colnames = [self.applycolor(name, MAGENTA) for name in colnames]
+ def print_dynamic_result(self, cursor):
+ for row in cursor:
+ colnames, coltypes = zip(*cursor.description)[:2]
+ colnames = [self.myformat_colname(name) for name in colnames]
colvals = [self.myformat_value(val, casstype) for (val, casstype) in zip(row, coltypes)]
- line = ' | '.join(name + ',' + col.coloredval for (col, name) in zip(colvals, colnames))
+ line = ' | '.join('%s,%s' % (n.coloredval, v.coloredval) for (n, v) in zip(colnames, colvals))
print ' ' + line
def emptyline(self):
@@ -995,8 +1060,9 @@ class Shell(cmd.Cmd):
validator_class = cqlhandling.find_validator_class(cqltype)
except KeyError:
self.printerr('Error: validator type %s not found.' % cqltype)
- self.add_assumption(params['ks'], params['cf'], params['colname'],
- overridetype, validator_class)
+ else:
+ self.add_assumption(params['ks'], params['cf'], params['colname'],
+ overridetype, validator_class)
def do_EOF(self, parsed):
"""
@@ -1696,15 +1762,27 @@ class FakeCqlMetadata:
self.default_name_type = None
self.default_value_type = None
- def join(self, realschema):
- f = self.__class__()
- f.default_name_type = self.default_name_type or realschema.default_name_type
- f.default_value_types = self.default_value_type or realschema.default_value_type
- f.name_types = realschema.name_types.copy()
- f.name_types.update(self.name_types)
- f.value_types = realschema.value_types.copy()
- f.value_types.update(self.value_types)
- return f
+class OverrideableSchemaDecoder(cql.decoders.SchemaDecoder):
+ def __init__(self, schema, overrides=None):
+ cql.decoders.SchemaDecoder.__init__(self, schema)
+ self.apply_schema_overrides(overrides)
+
+ def apply_schema_overrides(self, overrides):
+ if overrides is None:
+ return
+ if overrides.default_name_type is not None:
+ self.schema.default_name_type = overrides.default_name_type
+ if overrides.default_value_type is not None:
+ self.schema.default_value_type = overrides.default_value_type
+ self.schema.name_types.update(overrides.name_types)
+ self.schema.value_types.update(overrides.value_types)
+
+class ErrorHandlingSchemaDecoder(OverrideableSchemaDecoder):
+ def name_decode_error(self, err, namebytes, expectedtype):
+ return DecodeError(namebytes, err, expectedtype)
+
+ def value_decode_error(self, err, namebytes, valuebytes, expectedtype):
+ return DecodeError(valuebytes, err, expectedtype, colname=namebytes)
def option_with_default(cparser_getter, section, option, default=None):