You are viewing a plain-text version of this content; the link to its canonical version is not preserved in this copy.
Posted to commits@cassandra.apache.org by br...@apache.org on 2012/02/21 15:14:15 UTC

[5/5] git commit: cqlsh: handle deserialization errors. Patch by paul cannon, reviewed by brandonwilliams for CASSANDRA-3874

cqlsh: handle deserialization errors.
Patch by paul cannon, reviewed by brandonwilliams for CASSANDRA-3874


Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/cd36f975
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/cd36f975
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/cd36f975

Branch: refs/heads/cassandra-1.1
Commit: cd36f9757150db6a369d70f633971dadd571509a
Parents: 636e41d
Author: Brandon Williams <br...@apache.org>
Authored: Thu Feb 16 09:21:55 2012 -0600
Committer: Brandon Williams <br...@apache.org>
Committed: Thu Feb 16 09:21:55 2012 -0600

----------------------------------------------------------------------
 bin/cqlsh |  162 +++++++++++++++++++++++++++++++++++++++++---------------
 1 files changed, 120 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/cassandra/blob/cd36f975/bin/cqlsh
----------------------------------------------------------------------
diff --git a/bin/cqlsh b/bin/cqlsh
index c89aa16..47763ec 100755
--- a/bin/cqlsh
+++ b/bin/cqlsh
@@ -214,6 +214,25 @@ class NoKeyspaceError(Exception):
 class KeyspaceNotFound(Exception):
     pass
 
+class DecodeError(Exception):
+    def __init__(self, thebytes, err, expectedtype, colname=None):
+        self.thebytes = thebytes
+        self.err = err
+        self.expectedtype = expectedtype
+        self.colname = colname
+
+    def __str__(self):
+        return str(self.thebytes)
+
+    def message(self):
+        what = 'column name %r' % (self.thebytes,)
+        if self.colname is not None:
+            what = 'value %r (for column %r)' % (self.thebytes, self.colname)
+        return 'Failed to decode %s as %s: %s' % (what, self.expectedtype, self.err)
+
+    def __repr__(self):
+        return '<%s %s>' % (self.__class__.__name__, self.message())
+
 def trim_if_present(s, prefix):
     if s.startswith(prefix):
         return s[len(prefix):]
@@ -225,12 +244,23 @@ class FormattedValue:
         self.coloredval = coloredval
         self.displaywidth = displaywidth
 
+    def __len__(self):
+        return len(self.strval)
+
     def _pad(self, width, fill=' '):
         if width > self.displaywidth:
             return fill * (width - self.displaywidth)
         else:
             return ''
 
+    def ljust(self, width, fill=' '):
+        """
+        Similar to self.strval.ljust(width), but takes expected terminal
+        display width into account for special characters, and does not
+        take color escape codes into account.
+        """
+        return self.strval + self._pad(width, fill)
+
     def rjust(self, width, fill=' '):
         """
         Similar to self.strval.rjust(width), but takes expected terminal
@@ -247,7 +277,16 @@ class FormattedValue:
         """
         return self._pad(width, fill) + self.coloredval
 
-controlchars_re = re.compile(r'[\x00-\x31\x7f-\xa0]')
+    def color_ljust(self, width, fill=' '):
+        """
+        Similar to self.ljust(width), but uses this value's colored
+        representation, and does not take color escape codes into account
+        in determining width.
+        """
+        return self.coloredval + self._pad(width, fill)
+
+unicode_controlchars_re = re.compile(r'[\x00-\x31\x7f-\xa0]')
+controlchars_re = re.compile(r'[\x00-\x31\x7f-\xff]')
 
 def _show_control_chars(match):
     txt = repr(match.group(0))
@@ -273,9 +312,13 @@ def format_value(val, casstype, output_encoding, addcolor=False, time_format='',
     if val is None:
         bval = 'null'
         color = RED
+    elif isinstance(val, DecodeError):
+        casstype = 'BytesType'
+        bval = repr(val.thebytes)
+        color = RED
     elif casstype == 'UTF8Type':
         escapedval = val.replace(u'\\', u'\\\\')
-        escapedval = controlchars_re.sub(_show_control_chars, escapedval)
+        escapedval = unicode_controlchars_re.sub(_show_control_chars, escapedval)
         bval = escapedval.encode(output_encoding, 'backslashreplace')
         displaywidth = wcwidth.wcswidth(bval.decode(output_encoding))
         if addcolor:
@@ -352,10 +395,22 @@ class Shell(cmd.Cmd):
             self.prompt = ""
 
     def myformat_value(self, val, casstype):
+        if isinstance(val, DecodeError):
+            self.decoding_errors.append(val)
         return format_value(val, casstype, self.output_codec.name,
                             addcolor=self.color, time_format=self.display_time_format,
                             float_precision=self.display_float_precision)
 
+    def myformat_colname(self, name):
+        if isinstance(name, DecodeError):
+            self.decoding_errors.append(name)
+            name = str(name)
+            color = RED
+        else:
+            color = MAGENTA
+        return FormattedValue(name, self.applycolor(name, color),
+                              wcwidth.wcswidth(name.decode(self.output_codec.name)))
+
     def report_connection(self):
         self.show_host()
         self.show_version()
@@ -632,58 +687,68 @@ class Shell(cmd.Cmd):
                 return False
 
         if self.cursor.description is _COUNT_DESCRIPTION:
-            self.print_count_result()
+            self.print_count_result(self.cursor)
         elif self.cursor.description is not _VOID_DESCRIPTION:
-            self.print_result()
+            self.print_result(self.cursor)
         return True
 
     def determine_decoder_for(self, cfname, ksname=None):
+        decoder = ErrorHandlingSchemaDecoder
         if ksname is None:
             ksname = self.current_keyspace
-        schema = self.schema_overrides.get((ksname, cfname), None)
-        if schema:
-            def use_my_schema_decoder(real_schema):
-                return cql.decoders.SchemaDecoder(schema.join(real_schema))
-            return use_my_schema_decoder
-
-    def print_count_result(self):
-        if not self.cursor.result:
+        overrides = self.schema_overrides.get((ksname, cfname), None)
+        if overrides:
+            decoder = partial(decoder, overrides=overrides)
+        return decoder
+
+    def print_count_result(self, cursor):
+        if not cursor.result:
             return
         self.printout('count')
         self.printout('-----')
-        self.printout(self.cursor.result[0])
+        self.printout(cursor.result[0])
         self.printout("")
 
-    def print_result(self):
+    def print_result(self, cursor):
+        self.decoding_errors = []
+
         # first pass: see if we have a static column set
         last_description = None
-        for row in self.cursor:
-            if last_description is not None and self.cursor.description != last_description:
+        for row in cursor:
+            if last_description is not None and cursor.description != last_description:
                 static = False
                 break
-            last_description = self.cursor.description
+            last_description = cursor.description
         else:
             static = True
-        self.cursor._reset()
+        cursor._reset()
 
         if static:
-            self.print_static_result()
+            self.print_static_result(cursor)
         else:
-            self.print_dynamic_result()
+            self.print_dynamic_result(cursor)
         self.printout("")
 
-    def print_static_result(self):
-        colnames, coltypes = zip(*self.cursor.description)[:2]
-        formatted_data = [map(self.myformat_value, row, coltypes) for row in self.cursor]
+        if self.decoding_errors:
+            for err in self.decoding_errors[:2]:
+                self.printout(err.message(), color=RED)
+            if len(self.decoding_errors) > 2:
+                self.printout('%d more decoding errors suppressed.'
+                              % (len(self.decoding_errors) - 2), color=RED)
+
+    def print_static_result(self, cursor):
+        colnames, coltypes = zip(*cursor.description)[:2]
+        formatted_names = map(self.myformat_colname, colnames)
+        formatted_data = [map(self.myformat_value, row, coltypes) for row in cursor]
 
         # determine column widths
-        widths = map(len, colnames)
+        widths = map(len, formatted_names)
         for fmtrow in formatted_data:
             for num, col in enumerate(fmtrow):
-                widths[num] = max(widths[num], len(col.strval))
+                widths[num] = max(widths[num], len(col))
 
         # print header
-        header = ' | '.join(self.applycolor(name.ljust(w), MAGENTA) for (name, w) in zip(colnames, widths))
+        header = ' | '.join(hdr.color_ljust(w) for (hdr, w) in zip(formatted_names, widths))
         print ' ' + header.rstrip()
         print '-%s-' % '-+-'.join('-' * w for w in widths)
 
@@ -692,12 +757,12 @@ class Shell(cmd.Cmd):
             line = ' | '.join(col.color_rjust(w) for (col, w) in zip(row, widths))
             print ' ' + line
 
-    def print_dynamic_result(self):
-        for row in self.cursor:
-            colnames, coltypes = zip(*self.cursor.description)[:2]
-            colnames = [self.applycolor(name, MAGENTA) for name in colnames]
+    def print_dynamic_result(self, cursor):
+        for row in cursor:
+            colnames, coltypes = zip(*cursor.description)[:2]
+            colnames = [self.myformat_colname(name) for name in colnames]
             colvals = [self.myformat_value(val, casstype) for (val, casstype) in zip(row, coltypes)]
-            line = ' | '.join(name + ',' + col.coloredval for (col, name) in zip(colvals, colnames))
+            line = ' | '.join('%s,%s' % (n.coloredval, v.coloredval) for (n, v) in zip(colnames, colvals))
             print ' ' + line
 
     def emptyline(self):
@@ -995,8 +1060,9 @@ class Shell(cmd.Cmd):
                 validator_class = cqlhandling.find_validator_class(cqltype)
             except KeyError:
                 self.printerr('Error: validator type %s not found.' % cqltype)
-            self.add_assumption(params['ks'], params['cf'], params['colname'],
-                                overridetype, validator_class)
+            else:
+                self.add_assumption(params['ks'], params['cf'], params['colname'],
+                                    overridetype, validator_class)
 
     def do_EOF(self, parsed):
         """
@@ -1696,15 +1762,27 @@ class FakeCqlMetadata:
         self.default_name_type = None
         self.default_value_type = None
 
-    def join(self, realschema):
-        f = self.__class__()
-        f.default_name_type = self.default_name_type or realschema.default_name_type
-        f.default_value_types = self.default_value_type or realschema.default_value_type
-        f.name_types = realschema.name_types.copy()
-        f.name_types.update(self.name_types)
-        f.value_types = realschema.value_types.copy()
-        f.value_types.update(self.value_types)
-        return f
+class OverrideableSchemaDecoder(cql.decoders.SchemaDecoder):
+    def __init__(self, schema, overrides=None):
+        cql.decoders.SchemaDecoder.__init__(self, schema)
+        self.apply_schema_overrides(overrides)
+
+    def apply_schema_overrides(self, overrides):
+        if overrides is None:
+            return
+        if overrides.default_name_type is not None:
+            self.schema.default_name_type = overrides.default_name_type
+        if overrides.default_value_type is not None:
+            self.schema.default_value_type = overrides.default_value_type
+        self.schema.name_types.update(overrides.name_types)
+        self.schema.value_types.update(overrides.value_types)
+
+class ErrorHandlingSchemaDecoder(OverrideableSchemaDecoder):
+    def name_decode_error(self, err, namebytes, expectedtype):
+        return DecodeError(namebytes, err, expectedtype)
+
+    def value_decode_error(self, err, namebytes, valuebytes, expectedtype):
+        return DecodeError(valuebytes, err, expectedtype, colname=namebytes)
 
 
 def option_with_default(cparser_getter, section, option, default=None):