Posted to commits@cassandra.apache.org by al...@apache.org on 2015/12/15 22:40:43 UTC

[2/6] cassandra git commit: Merge branch 'cassandra-2.1' into cassandra-2.2

http://git-wip-us.apache.org/repos/asf/cassandra/blob/57d558fc/pylib/cqlshlib/copyutil.py
----------------------------------------------------------------------
diff --cc pylib/cqlshlib/copyutil.py
index a2fab00,f699e64..a117ec3
--- a/pylib/cqlshlib/copyutil.py
+++ b/pylib/cqlshlib/copyutil.py
@@@ -23,19 -26,25 +26,25 @@@ import sy
  import time
  import traceback
  
- from StringIO import StringIO
+ from calendar import timegm
+ from collections import defaultdict, deque, namedtuple
+ from decimal import Decimal
  from random import randrange
+ from StringIO import StringIO
  from threading import Lock
+ from uuid import UUID
  
  from cassandra.cluster import Cluster
+ from cassandra.cqltypes import ReversedType, UserType
  from cassandra.metadata import protect_name, protect_names
- from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy
- from cassandra.query import tuple_factory
+ from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy, DCAwareRoundRobinPolicy
+ from cassandra.query import BatchStatement, BatchType, SimpleStatement, tuple_factory
+ from cassandra.util import Date, Time
  
- 
- import sslhandling
+ from cql3handling import CqlRuleSet
  from displaying import NO_COLOR_MAP
 -from formatting import format_value_default, EMPTY, get_formatter
 +from formatting import format_value_default, DateTimeFormat, EMPTY, get_formatter
+ from sslhandling import ssl_settings
  
  
  def parse_options(shell, opts):
@@@ -65,10 -74,13 +74,15 @@@
      # by default the page timeout is 10 seconds per 1000 entries in the page size or 10 seconds if pagesize is smaller
      csv_options['pagetimeout'] = int(opts.pop('pagetimeout', max(10, 10 * (csv_options['pagesize'] / 1000))))
      csv_options['maxattempts'] = int(opts.pop('maxattempts', 5))
-     csv_options['float_precision'] = shell.display_float_precision
 -    csv_options['dtformats'] = opts.pop('timeformat', shell.display_time_format)
 +    csv_options['dtformats'] = DateTimeFormat(opts.pop('timeformat', shell.display_timestamp_format),
 +                                              shell.display_date_format,
 +                                              shell.display_nanotime_format)
+     csv_options['float_precision'] = shell.display_float_precision
+     csv_options['chunksize'] = int(opts.pop('chunksize', 1000))
+     csv_options['ingestrate'] = int(opts.pop('ingestrate', 100000))
+     csv_options['maxbatchsize'] = int(opts.pop('maxbatchsize', 20))
+     csv_options['minbatchsize'] = int(opts.pop('minbatchsize', 2))
+     csv_options['reportfrequency'] = float(opts.pop('reportfrequency', 0.25))
  
      return csv_options, dialect_options, opts
  
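As a rough illustration of the pagetimeout default above (10 seconds per 1000
entries of page size, with a floor of 10 seconds), the integer arithmetic works
out as follows; the page sizes below are arbitrary examples, not part of the patch:

    # minimal sketch; '/' is integer division under Python 2, as in the code above
    for pagesize in (500, 1000, 5000):
        pagetimeout = max(10, 10 * (pagesize / 1000))
        print pagesize, pagetimeout   # 500 -> 10, 1000 -> 10, 5000 -> 50
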
@@@ -371,30 -648,18 +650,18 @@@ class ExportProcess(ChildProcess)
      A child worker process for the export task, ExportTask.
      """
  
-     def __init__(self, inmsg, outmsg, ks, cf, columns, dialect_options, csv_options,
-                  debug, port, cql_version, auth_provider, ssl, protocol_version, config_file):
-         mp.Process.__init__(self, target=self.run)
-         self.inmsg = inmsg
-         self.outmsg = outmsg
-         self.ks = ks
-         self.cf = cf
-         self.columns = columns
-         self.dialect_options = dialect_options
+     def __init__(self, params):
+         ChildProcess.__init__(self, params=params, target=self.run)
+         self.dialect_options = params['dialect_options']
          self.hosts_to_sessions = dict()
  
-         self.debug = debug
-         self.port = port
-         self.cql_version = cql_version
-         self.auth_provider = auth_provider
-         self.ssl = ssl
-         self.protocol_version = protocol_version
-         self.config_file = config_file
- 
+         csv_options = params['csv_options']
          self.encoding = csv_options['encoding']
 -        self.time_format = csv_options['dtformats']
 +        self.date_time_format = csv_options['dtformats']
          self.float_precision = csv_options['float_precision']
          self.nullval = csv_options['nullval']
-         self.maxjobs = csv_options['jobs']
+         self.max_attempts = csv_options['maxattempts']
+         self.max_requests = csv_options['maxrequests']
          self.csv_options = csv_options
          self.formatters = dict()
  
@@@ -600,13 -851,424 +853,424 @@@
          return query
  
  
+ class ImportConversion(object):
+     """
+     A class for converting strings to values when importing from csv, used by its
+     parent, ImportProcess.
+     """
+     def __init__(self, parent, table_meta, statement):
+         self.ks = parent.ks
+         self.cf = parent.cf
+         self.columns = parent.columns
+         self.nullval = parent.nullval
+         self.printmsg = parent.printmsg
+         self.table_meta = table_meta
+         self.primary_key_indexes = [self.columns.index(col.name) for col in self.table_meta.primary_key]
+         self.partition_key_indexes = [self.columns.index(col.name) for col in self.table_meta.partition_key]
+ 
+         self.proto_version = statement.protocol_version
+         self.cqltypes = dict([(c.name, c.type) for c in statement.column_metadata])
+         self.converters = dict([(c.name, self._get_converter(c.type)) for c in statement.column_metadata])
+ 
+     def _get_converter(self, cql_type):
+         """
+         Return a function that converts a string into a value that can be passed
+         into BoundStatement.bind() for the given cql type. See cassandra.cqltypes
+         for more details.
+         """
+         def unprotect(v):
+             if v is not None:
+                 return CqlRuleSet.dequote_value(v)
+ 
+         def convert(t, v):
+             return converters.get(t.typename, convert_unknown)(unprotect(v), ct=t)
+ 
+         def split(val, sep=','):
+             """
+             Split val into a list of values on the separator, ignoring separators
+             nested inside brackets, braces, parentheses or single quotes. The two
+             outermost enclosing characters are stripped rather than counted, so we
+             expect val to be at least 2 characters long.
+             """
+             ret = []
+             last = 1
+             level = 0
+             quote = False
+             for i, c in enumerate(val):
+                 if c == '{' or c == '[' or c == '(':
+                     level += 1
+                 elif c == '}' or c == ']' or c == ')':
+                     level -= 1
+                 elif c == '\'':
+                     quote = not quote
+                 elif c == sep and level == 1 and not quote:
+                     ret.append(val[last:i])
+                     last = i + 1
+             else:
+                 if last < len(val) - 1:
+                     ret.append(val[last:-1])
+ 
+             return ret
+ 
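A hedged sketch of how split() above behaves on typical collection literals
(split is a closure inside _get_converter, so the calls below are illustrative
rather than directly importable):

    # split("{1, 2, 3}")         -> ['1', ' 2', ' 3']
    # split("{'a': 1, 'b': 2}")  -> ["'a': 1", " 'b': 2"]
    # split("{'a,b', 'c'}")      -> ["'a,b'", " 'c'"]   (the quoted comma is kept)
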
+         # this should match all possible CQL datetime formats
+         p = re.compile("(\d{4})\-(\d{2})\-(\d{2})\s?(?:'T')?" +  # YYYY-MM-DD[( |'T')]
+                        "(?:(\d{2}):(\d{2})(?::(\d{2}))?)?" +  # [HH:MM[:SS]]
+                        "(?:([+\-])(\d{2}):?(\d{2}))?")  # [(+|-)HH[:]MM]]
+ 
+         def convert_date(val, **_):
+             m = p.match(val)
+             if not m:
+                 raise ValueError("can't interpret %r as a date" % (val,))
+ 
+             # https://docs.python.org/2/library/time.html#time.struct_time
+             tval = time.struct_time((int(m.group(1)), int(m.group(2)), int(m.group(3)),  # year, month, day
+                                      int(m.group(4)) if m.group(4) else 0,  # hour
+                                      int(m.group(5)) if m.group(5) else 0,  # minute
+                                      int(m.group(6)) if m.group(6) else 0,  # second
+                                      0, 1, -1))  # day of week, day of year, dst-flag
+ 
+             if m.group(7):
+                 offset = (int(m.group(8)) * 3600 + int(m.group(9)) * 60) * int(m.group(7) + '1')
+             else:
+                 offset = -time.timezone
+ 
+             # scale seconds to millis for the raw value
+             return (timegm(tval) + offset) * 1e3
+ 
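For example (illustrative only), an input carrying an explicit UTC offset, such
as this commit's own timestamp, would convert like so:

    # convert_date("2015-12-15 22:40:43+0000")
    #   -> struct_time((2015, 12, 15, 22, 40, 43, 0, 1, -1)) with offset 0
    #   -> timegm(...) * 1e3 == 1450219243000.0 milliseconds since the epoch
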
+         def convert_tuple(val, ct=cql_type):
+             return tuple(convert(t, v) for t, v in zip(ct.subtypes, split(val)))
+ 
+         def convert_list(val, ct=cql_type):
+             return list(convert(ct.subtypes[0], v) for v in split(val))
+ 
+         def convert_set(val, ct=cql_type):
+             return frozenset(convert(ct.subtypes[0], v) for v in split(val))
+ 
+         def convert_map(val, ct=cql_type):
+             """
+             We need to pass a dict() to BoundStatement.bind() because it calls iteritems(),
+             but we can't create a dict with another dict as a key. Hence we use a class
+             that adds iteritems() to a frozenset of tuples (which is how dicts are normally
+             made immutable in Python).
+             """
+             class ImmutableDict(frozenset):
+                 iteritems = frozenset.__iter__
+ 
+             return ImmutableDict(frozenset((convert(ct.subtypes[0], v[0]), convert(ct.subtypes[1], v[1]))
+                                  for v in [split('{%s}' % vv, sep=':') for vv in split(val)]))
+ 
+         def convert_user_type(val, ct=cql_type):
+             """
+             A user type is like a dictionary, except that we must convert each key
+             into an attribute, so we use named tuples. It must also be hashable,
+             so we cannot use dictionaries. Maybe there is a way to instantiate ct
+             directly, but I could not work it out.
+             """
+             vals = [v for v in [split('{%s}' % vv, sep=':') for vv in split(val)]]
+             ret_type = namedtuple(ct.typename, [unprotect(v[0]) for v in vals])
+             return ret_type(*tuple(convert(t, v[1]) for t, v in zip(ct.subtypes, vals)))
+ 
+         def convert_single_subtype(val, ct=cql_type):
+             return converters.get(ct.subtypes[0].typename, convert_unknown)(val, ct=ct.subtypes[0])
+ 
+         def convert_unknown(val, ct=cql_type):
+             if issubclass(ct, UserType):
+                 return convert_user_type(val, ct=ct)
+             elif issubclass(ct, ReversedType):
+                 return convert_single_subtype(val, ct=ct)
+ 
+             self.printmsg("Unknown type %s (%s) for val %s" % (ct, ct.typename, val))
+             return val
+ 
+         converters = {
+             'blob': (lambda v, ct=cql_type: bytearray.fromhex(v[2:])),
+             'decimal': (lambda v, ct=cql_type: Decimal(v)),
+             'uuid': (lambda v, ct=cql_type: UUID(v)),
+             'boolean': (lambda v, ct=cql_type: bool(v)),
+             'tinyint': (lambda v, ct=cql_type: int(v)),
+             'ascii': (lambda v, ct=cql_type: v),
+             'float': (lambda v, ct=cql_type: float(v)),
+             'double': (lambda v, ct=cql_type: float(v)),
+             'bigint': (lambda v, ct=cql_type: long(v)),
+             'int': (lambda v, ct=cql_type: int(v)),
+             'varint': (lambda v, ct=cql_type: int(v)),
+             'inet': (lambda v, ct=cql_type: v),
+             'counter': (lambda v, ct=cql_type: long(v)),
+             'timestamp': convert_date,
+             'timeuuid': (lambda v, ct=cql_type: UUID(v)),
+             'date': (lambda v, ct=cql_type: Date(v)),
+             'smallint': (lambda v, ct=cql_type: int(v)),
+             'time': (lambda v, ct=cql_type: Time(v)),
+             'text': (lambda v, ct=cql_type: v),
+             'varchar': (lambda v, ct=cql_type: v),
+             'list': convert_list,
+             'set': convert_set,
+             'map': convert_map,
+             'tuple': convert_tuple,
+             'frozen': convert_single_subtype,
+         }
+ 
+         return converters.get(cql_type.typename, convert_unknown)
+ 
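As a sketch of the dispatch above (type names as reported by the driver; the
examples are hypothetical):

    # _get_converter(<list<int>>)    -> convert_list; each element produced by
    #                                   split() is converted with the 'int' converter
    # _get_converter(<a user type>)  -> convert_unknown -> convert_user_type
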
+     def get_row_values(self, row):
+         """
+         Parse the row into a list of row values to be returned
+         """
+         ret = [None] * len(row)
+         for i, val in enumerate(row):
+             if val != self.nullval:
+                 ret[i] = self.converters[self.columns[i]](val)
+             else:
+                 if i in self.primary_key_indexes:
+                     message = "Cannot insert null value for primary key column '%s'." % (self.columns[i],)
+                     if self.nullval == '':
+                         message += " If you want to insert empty strings, consider using" \
+                                    " the WITH NULL=<marker> option for COPY."
+                     raise Exception(message)
+ 
+                 ret[i] = None
+ 
+         return ret
+ 
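A brief sketch of the null handling above, with hypothetical columns and the
default nullval of '':

    # columns = ['pk', 'v'], where only 'pk' is part of the primary key
    # row ['1', '']  -> ['pk' converter applied to '1', None]
    # row ['', 'x']  -> raises, since a primary key column may not be null
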
+     def get_row_partition_key_values(self, row):
+         """
+         Return a string composed of the partition key values, serialized and binary packed -
+         as expected by metadata.get_replicas(), see also BoundStatement.routing_key.
+         """
+         def serialize(n):
+             c, v = self.columns[n], row[n]
+             return self.cqltypes[c].serialize(self.converters[c](v), self.proto_version)
+ 
+         partition_key_indexes = self.partition_key_indexes
+         if len(partition_key_indexes) == 1:
+             return serialize(partition_key_indexes[0])
+         else:
+             pk_values = []
+             for i in partition_key_indexes:
+                 val = serialize(i)
+                 l = len(val)
+                 pk_values.append(struct.pack(">H%dsB" % l, l, val, 0))
+             return b"".join(pk_values)
+ 
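A minimal, self-contained sketch of the composite routing-key framing used
above; the component bytes are made up, real values come from cqltypes
serialization:

    import struct

    component = '\x00\x00\x00\x2a'  # hypothetical serialized partition key component
    l = len(component)
    framed = struct.pack(">H%dsB" % l, l, component, 0)
    assert framed == '\x00\x04\x00\x00\x00\x2a\x00'  # 2-byte length, bytes, trailing 0
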
+ 
+ class ImportProcess(ChildProcess):
+ 
+     def __init__(self, params):
+         ChildProcess.__init__(self, params=params, target=self.run)
+ 
+         csv_options = params['csv_options']
+         self.nullval = csv_options['nullval']
+         self.max_attempts = csv_options['maxattempts']
+         self.min_batch_size = csv_options['minbatchsize']
+         self.max_batch_size = csv_options['maxbatchsize']
+         self._session = None
+ 
+     @property
+     def session(self):
+         if not self._session:
+             cluster = Cluster(
+                 contact_points=(self.hostname,),
+                 port=self.port,
+                 cql_version=self.cql_version,
+                 protocol_version=self.protocol_version,
+                 auth_provider=self.auth_provider,
+                 load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
+                 ssl_options=ssl_settings(self.hostname, self.config_file) if self.ssl else None,
+                 default_retry_policy=ExpBackoffRetryPolicy(self),
+                 compression=None,
+                 connect_timeout=self.connect_timeout)
+ 
+             self._session = cluster.connect(self.ks)
+             self._session.default_timeout = None
+         return self._session
+ 
+     def run(self):
+         try:
+             table_meta = self.session.cluster.metadata.keyspaces[self.ks].tables[self.cf]
 -            is_counter = ("counter" in [table_meta.columns[name].typestring for name in self.columns])
++            is_counter = ("counter" in [table_meta.columns[name].cql_type for name in self.columns])
+ 
+             if is_counter:
+                 self.run_counter(table_meta)
+             else:
+                 self.run_normal(table_meta)
+ 
+         except Exception, exc:
+             if self.debug:
+                 traceback.print_exc(exc)
+ 
+         finally:
+             self.close()
+ 
+     def close(self):
+         if self._session:
+             self._session.cluster.shutdown()
+         ChildProcess.close(self)
+ 
+     def run_counter(self, table_meta):
+         """
+         Main run method for tables that contain counter columns.
+         """
+         query = 'UPDATE %s.%s SET %%s WHERE %%s' % (protect_name(self.ks), protect_name(self.cf))
+ 
+         # We prepare a SELECT statement only to find out the types of the partition key columns,
+         # so that we can route each update to the correct replicas. As far as I understand, this
+         # is the easiest way to find out those types; we never actually execute this prepared statement.
+         where_clause = ' AND '.join(['%s = ?' % (protect_name(c.name)) for c in table_meta.partition_key])
+         select_query = 'SELECT * FROM %s.%s WHERE %s' % (protect_name(self.ks), protect_name(self.cf), where_clause)
+         conv = ImportConversion(self, table_meta, self.session.prepare(select_query))
+ 
+         while True:
+             try:
+                 batch = self.inmsg.get()
+ 
+                 for batches in self.split_batches(batch, conv):
+                     for b in batches:
+                         self.send_counter_batch(query, conv, b)
+ 
+             except Exception, exc:
+                 self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, exc.message)))
+                 if self.debug:
+                     traceback.print_exc(exc)
+ 
+     def run_normal(self, table_meta):
+         """
+         Main run method for normal tables, i.e. tables that do not contain counter columns.
+         """
+         query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % (protect_name(self.ks),
+                                                         protect_name(self.cf),
+                                                         ', '.join(protect_names(self.columns),),
+                                                         ', '.join(['?' for _ in self.columns]))
+         query_statement = self.session.prepare(query)
+         conv = ImportConversion(self, table_meta, query_statement)
+ 
+         while True:
+             try:
+                 batch = self.inmsg.get()
+ 
+                 for batches in self.split_batches(batch, conv):
+                     for b in batches:
+                         self.send_normal_batch(conv, query_statement, b)
+ 
+             except Exception, exc:
+                 self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, exc.message)))
+                 if self.debug:
+                     traceback.print_exc(exc)
+ 
+     def send_counter_batch(self, query_text, conv, batch):
+         if self.test_failures and self.maybe_inject_failures(batch):
+             return
+ 
+         columns = self.columns
+         batch_statement = BatchStatement(batch_type=BatchType.COUNTER, consistency_level=self.consistency_level)
+         for row in batch['rows']:
+             where_clause = []
+             set_clause = []
+             for i, value in enumerate(row):
+                 if i in conv.primary_key_indexes:
+                     where_clause.append("%s=%s" % (columns[i], value))
+                 else:
+                     set_clause.append("%s=%s+%s" % (columns[i], columns[i], value))
+ 
+             full_query_text = query_text % (','.join(set_clause), ' AND '.join(where_clause))
+             batch_statement.add(full_query_text)
+ 
+         self.execute_statement(batch_statement, batch)
+ 
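To illustrate the statement text built above (keyspace, table and column names
are hypothetical):

    # query_text is 'UPDATE ks.cf SET %s WHERE %s' after protect_name()
    # with columns ['pk', 'cnt'], where 'pk' is the primary key, and row ['1', '5'],
    # the statement added to the COUNTER batch reads:
    #   UPDATE ks.cf SET cnt=cnt+5 WHERE pk=1
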
+     def send_normal_batch(self, conv, query_statement, batch):
+         try:
+             if self.test_failures and self.maybe_inject_failures(batch):
+                 return
+ 
+             batch_statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
+             for row in batch['rows']:
+                 batch_statement.add(query_statement, conv.get_row_values(row))
+ 
+             self.execute_statement(batch_statement, batch)
+ 
+         except Exception, exc:
+             self.err_callback(exc, batch)
+ 
+     def maybe_inject_failures(self, batch):
+         """
+         Examine self.test_failures and see if this batch is either a batch that is
+         supposed to cause a failure (failing_batch) or one that should terminate the
+         worker process (exit_batch). Return True if a failure was injected, False if
+         the batch should be processed normally.
+         """
+         if 'failing_batch' in self.test_failures:
+             failing_batch = self.test_failures['failing_batch']
+             if failing_batch['id'] == batch['id']:
+                 if batch['attempts'] < failing_batch['failures']:
+                     statement = SimpleStatement("INSERT INTO badtable (a, b) VALUES (1, 2)",
+                                                 consistency_level=self.consistency_level)
+                     self.execute_statement(statement, batch)
+                     return True
+ 
+         if 'exit_batch' in self.test_failures:
+             exit_batch = self.test_failures['exit_batch']
+             if exit_batch['id'] == batch['id']:
+                 sys.exit(1)
+ 
+         return False  # carry on as normal
+ 
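For reference, test_failures might look like the following (hypothetical
values), matching the keys checked above:

    # {'failing_batch': {'id': 3, 'failures': 2}}  -> batch 3 fails its first 2 attempts
    # {'exit_batch': {'id': 5}}                    -> the worker exits when it sees batch 5
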
+     def execute_statement(self, statement, batch):
+         future = self.session.execute_async(statement)
+         future.add_callbacks(callback=self.result_callback, callback_args=(batch, ),
+                              errback=self.err_callback, errback_args=(batch, ))
+ 
+     def split_batches(self, batch, conv):
+         """
+         Split a batch into sub-batches that share the same
+         partition key, if possible. If there are at least
+         min_batch_size rows with the same partition key value then
+         create sub-batches for that partition key value, else
+         aggregate the remaining rows into 'left-overs' sub-batches.
+         """
+         rows_by_pk = defaultdict(list)
+ 
+         for row in batch['rows']:
+             pk = conv.get_row_partition_key_values(row)
+             rows_by_pk[pk].append(row)
+ 
+         ret = dict()
+         remaining_rows = []
+ 
+         for pk, rows in rows_by_pk.items():
+             if len(rows) >= self.min_batch_size:
+                 ret[pk] = self.batches(rows, batch)
+             else:
+                 remaining_rows.extend(rows)
+ 
+         if remaining_rows:
+             ret[self.hostname] = self.batches(remaining_rows, batch)
+ 
+         return ret.itervalues()
+ 
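Roughly, with hypothetical min_batch_size=2 and max_batch_size=2, the grouping
above behaves like this:

    # rows grouped by partition key: {'a': [r1, r2, r3], 'b': [r4]}
    #   'a' has >= min_batch_size rows -> sub-batches [r1, r2] and [r3]
    #   'b' has fewer                  -> r4 joins the 'left-overs' group, which is
    #                                     keyed by the local hostname
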
+     def batches(self, rows, batch):
+         for i in xrange(0, len(rows), self.max_batch_size):
+             yield ImportTask.make_batch(batch['id'], rows[i:i + self.max_batch_size], batch['attempts'])
+ 
+     def result_callback(self, result, batch):
+         batch['imported'] = len(batch['rows'])
+         batch['rows'] = []  # no need to resend these
+         self.outmsg.put((batch, None))
+ 
+     def err_callback(self, response, batch):
+         batch['imported'] = len(batch['rows'])
+         self.outmsg.put((batch, '%s - %s' % (response.__class__.__name__, response.message)))
+         if self.debug:
+             traceback.print_exc(response)
+ 
+ 
  class RateMeter(object):
  
-     def __init__(self, log_threshold):
-         self.log_threshold = log_threshold  # number of records after which we log
-         self.last_checkpoint_time = time.time()  # last time we logged
+     def __init__(self, update_interval=0.25, log=True):
+         self.log = log  # true if we should log
+         self.update_interval = update_interval  # how often we update in seconds
+         self.start_time = time.time()  # the start time
+         self.last_checkpoint_time = self.start_time  # last time we logged
          self.current_rate = 0.0  # rows per second
-         self.current_record = 0  # number of records since we last logged
+         self.current_record = 0  # number of records since we last updated
          self.total_records = 0   # total number of records
  
      def increment(self, n=1):