Posted to commits@cassandra.apache.org by ty...@apache.org on 2015/11/19 01:23:12 UTC
cassandra git commit: cqlsh: Improve COPY TO perf and error handling
Repository: cassandra
Updated Branches:
refs/heads/cassandra-2.1 246cb883a -> 1b629c101
cqlsh: Improve COPY TO perf and error handling
Patch by Stefania Alborghetti; reviewed by Tyler Hobbs for
CASSANDRA-9304
Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/1b629c10
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/1b629c10
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/1b629c10
Branch: refs/heads/cassandra-2.1
Commit: 1b629c101bbf793f8e248bbf4396bb41adc0af97
Parents: 246cb88
Author: Stefania Alborghetti <st...@datastax.com>
Authored: Wed Nov 18 18:22:28 2015 -0600
Committer: Tyler Hobbs <ty...@gmail.com>
Committed: Wed Nov 18 18:22:28 2015 -0600
----------------------------------------------------------------------
CHANGES.txt | 1 +
bin/cqlsh | 180 +++--------
pylib/cqlshlib/copy.py | 644 ++++++++++++++++++++++++++++++++++++++
pylib/cqlshlib/displaying.py | 10 +
pylib/cqlshlib/formatting.py | 34 +-
5 files changed, 729 insertions(+), 140 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 6ccde28..42dcf3e 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
2.1.12
+ * (cqlsh) Improve COPY TO performance and error handling (CASSANDRA-9304)
* Don't remove level info when running upgradesstables (CASSANDRA-10692)
* Create compression chunk for sending file only (CASSANDRA-10680)
* Make buffered read size configurable (CASSANDRA-10249)
http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/bin/cqlsh
----------------------------------------------------------------------
diff --git a/bin/cqlsh b/bin/cqlsh
index 7291803..5459d67 100755
--- a/bin/cqlsh
+++ b/bin/cqlsh
@@ -37,6 +37,7 @@ import ConfigParser
import csv
import getpass
import locale
+import multiprocessing as mp
import optparse
import os
import platform
@@ -44,10 +45,11 @@ import sys
import time
import traceback
import warnings
+
+from StringIO import StringIO
from contextlib import contextmanager
from functools import partial
from glob import glob
-from StringIO import StringIO
from uuid import UUID
description = "CQL Shell for Apache Cassandra"
@@ -119,7 +121,7 @@ cqlshlibdir = os.path.join(CASSANDRA_PATH, 'pylib')
if os.path.isdir(cqlshlibdir):
sys.path.insert(0, cqlshlibdir)
-from cqlshlib import cql3handling, cqlhandling, pylexotron, sslhandling
+from cqlshlib import cql3handling, cqlhandling, pylexotron, sslhandling, copy
from cqlshlib.displaying import (ANSI_RESET, BLUE, COLUMN_NAME_COLORS, CYAN,
RED, FormattedValue, colorme)
from cqlshlib.formatting import (format_by_type, format_value_utype,
@@ -410,7 +412,8 @@ def complete_copy_column_names(ctxt, cqlsh):
return set(colnames[1:]) - set(existcols)
-COPY_OPTIONS = ('DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'ENCODING', 'TIMEFORMAT', 'NULL')
+COPY_OPTIONS = ['DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'NULL', 'ENCODING',
+ 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS']
@cqlsh_syntax_completer('copyOption', 'optnames')
@@ -419,8 +422,7 @@ def complete_copy_options(ctxt, cqlsh):
direction = ctxt.get_binding('dir').upper()
opts = set(COPY_OPTIONS) - set(optnames)
if direction == 'FROM':
- opts -= ('ENCODING',)
- opts -= ('TIMEFORMAT',)
+ opts -= set(['ENCODING', 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS'])
return opts
@@ -535,6 +537,19 @@ def describe_interval(seconds):
return words
+def insert_driver_hooks():
+ extend_cql_deserialization()
+ auto_format_udts()
+
+
+def extend_cql_deserialization():
+ """
+ The Python driver returns BLOBs as strings, but we expect them as bytearrays
+ """
+ cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts))
+ cassandra.cqltypes.CassandraType.support_empty_values = True
+
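A minimal sketch of the hook's effect, for illustration only (assumes
insert_driver_hooks() has run; the input bytes are made up):

    import cassandra.cqltypes
    # after the hook, blob values come back as bytearray instead of str
    raw = cassandra.cqltypes.BytesType.deserialize('\x00\xff', 3)  # protocol v3
    assert isinstance(raw, bytearray)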
+
def auto_format_udts():
# when we see a new user defined type, set up the shell formatting for it
udt_apply_params = cassandra.cqltypes.UserType.apply_parameters
@@ -673,11 +688,6 @@ class Shell(cmd.Cmd):
self.query_out = sys.stdout
self.consistency_level = cassandra.ConsistencyLevel.ONE
self.serial_consistency_level = cassandra.ConsistencyLevel.SERIAL
- # the python driver returns BLOBs as string, but we expect them as bytearrays
- cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts))
- cassandra.cqltypes.CassandraType.support_empty_values = True
-
- auto_format_udts()
self.empty_lines = 0
self.statement_error = False
@@ -807,11 +817,9 @@ class Shell(cmd.Cmd):
def get_keyspaces(self):
return self.conn.metadata.keyspaces.values()
- def get_ring(self):
- if self.current_keyspace is None or self.current_keyspace == 'system':
- raise NoKeyspaceError("Ring view requires a current non-system keyspace")
- self.conn.metadata.token_map.rebuild_keyspace(self.current_keyspace, build_if_absent=True)
- return self.conn.metadata.token_map.tokens_to_hosts_by_ks[self.current_keyspace]
+ def get_ring(self, ks):
+ self.conn.metadata.token_map.rebuild_keyspace(ks, build_if_absent=True)
+ return self.conn.metadata.token_map.tokens_to_hosts_by_ks[ks]
def get_table_meta(self, ksname, tablename):
if ksname is None:
@@ -1369,7 +1377,7 @@ class Shell(cmd.Cmd):
# print 'Snitch: %s\n' % snitch
if self.current_keyspace is not None and self.current_keyspace != 'system':
print "Range ownership:"
- ring = self.get_ring()
+ ring = self.get_ring(self.current_keyspace)
for entry in ring.items():
print ' %39s [%s]' % (str(entry[0].value), ', '.join([host.address for host in entry[1]]))
print
@@ -1506,10 +1514,14 @@ class Shell(cmd.Cmd):
ENCODING='utf8' - encoding for CSV output (COPY TO only)
TIMEFORMAT= - timestamp strftime format (COPY TO only)
'%Y-%m-%d %H:%M:%S%z' defaults to time_format value in cqlshrc
+ PAGESIZE='1000' - the page size for fetching results (COPY TO only)
PAGETIMEOUT='10' - the page timeout for fetching results (COPY TO only)
+ MAXATTEMPTS='5' - the maximum number of attempts for errors (COPY TO only)
When entering CSV data on STDIN, you can use the sequence "\."
on a line by itself to end the data input.
"""
+
ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
if ks is None:
ks = self.current_keyspace
@@ -1546,22 +1558,12 @@ class Shell(cmd.Cmd):
print "\n%d rows %s in %s." % (rows, verb, describe_interval(timeend - timestart))
def perform_csv_import(self, ks, cf, columns, fname, opts):
- dialect_options = self.csv_dialect_defaults.copy()
- if 'quote' in opts:
- dialect_options['quotechar'] = opts.pop('quote')
- if 'escape' in opts:
- dialect_options['escapechar'] = opts.pop('escape')
- if 'delimiter' in opts:
- dialect_options['delimiter'] = opts.pop('delimiter')
- nullval = opts.pop('null', '')
- header = bool(opts.pop('header', '').lower() == 'true')
- if dialect_options['quotechar'] == dialect_options['escapechar']:
- dialect_options['doublequote'] = True
- del dialect_options['escapechar']
- if opts:
+ csv_options, dialect_options, unrecognized_options = copy.parse_options(self, opts)
+ if unrecognized_options:
self.printerr('Unrecognized COPY FROM options: %s'
- % ', '.join(opts.keys()))
+ % ', '.join(unrecognized_options.keys()))
return 0
+ nullval, header = csv_options['nullval'], csv_options['header']
if fname is None:
do_close = False
@@ -1576,33 +1578,24 @@ class Shell(cmd.Cmd):
return 0
current_record = None
+ processes, pipes = [], [],
try:
if header:
linesource.next()
reader = csv.reader(linesource, **dialect_options)
- from multiprocessing import Process, Pipe, cpu_count
+ num_processes = copy.get_num_processes(cap=4)
- # Pick a resonable number of child processes. We need to leave at
- # least one core for the parent process. This doesn't necessarily
- # need to be capped at 4, but it's currently enough to keep
- # a single local Cassandra node busy, and I see lower throughput
- # with more processes.
- try:
- num_processes = max(1, min(4, cpu_count() - 1))
- except NotImplementedError:
- num_processes = 1
-
- processes, pipes = [], [],
for i in range(num_processes):
- parent_conn, child_conn = Pipe()
+ parent_conn, child_conn = mp.Pipe()
pipes.append(parent_conn)
- processes.append(Process(target=self.multiproc_import, args=(child_conn, ks, cf, columns, nullval)))
+ proc_args = (child_conn, ks, cf, columns, nullval)
+ processes.append(mp.Process(target=self.multiproc_import, args=proc_args))
for process in processes:
process.start()
- meter = RateMeter(10000)
+ meter = copy.RateMeter(10000)
for current_record, row in enumerate(reader, start=1):
# write to the child process
pipes[current_record % num_processes].send((current_record, row))
@@ -1612,7 +1605,7 @@ class Shell(cmd.Cmd):
# check for any errors reported by the children
if (current_record % 100) == 0:
- if self._check_child_pipes(current_record, pipes):
+ if self._check_import_processes(current_record, pipes):
# no errors seen, continue with outer loop
continue
else:
@@ -1641,7 +1634,7 @@ class Shell(cmd.Cmd):
for process in processes:
process.join()
- self._check_child_pipes(current_record, pipes)
+ self._check_import_processes(current_record, pipes)
for pipe in pipes:
pipe.close()
@@ -1653,8 +1646,7 @@ class Shell(cmd.Cmd):
return current_record
- def _check_child_pipes(self, current_record, pipes):
- # check the pipes for errors from child processes
+ def _check_import_processes(self, current_record, pipes):
for pipe in pipes:
if pipe.poll():
try:
@@ -1802,62 +1794,13 @@ class Shell(cmd.Cmd):
new_cluster.shutdown()
def perform_csv_export(self, ks, cf, columns, fname, opts):
- dialect_options = self.csv_dialect_defaults.copy()
-
- if 'quote' in opts:
- dialect_options['quotechar'] = opts.pop('quote')
- if 'escape' in opts:
- dialect_options['escapechar'] = opts.pop('escape')
- if 'delimiter' in opts:
- dialect_options['delimiter'] = opts.pop('delimiter')
- encoding = opts.pop('encoding', 'utf8')
- nullval = opts.pop('null', '')
- header = bool(opts.pop('header', '').lower() == 'true')
- time_format = opts.pop('timeformat', self.display_time_format)
- if dialect_options['quotechar'] == dialect_options['escapechar']:
- dialect_options['doublequote'] = True
- del dialect_options['escapechar']
-
- if opts:
- self.printerr('Unrecognized COPY TO options: %s'
- % ', '.join(opts.keys()))
+ csv_options, dialect_options, unrecognized_options = copy.parse_options(self, opts)
+ if unrecognized_options:
+ self.printerr('Unrecognized COPY TO options: %s' % ', '.join(unrecognized_options.keys()))
return 0
- if fname is None:
- do_close = False
- csvdest = sys.stdout
- else:
- do_close = True
- try:
- csvdest = open(fname, 'wb')
- except IOError, e:
- self.printerr("Can't open %r for writing: %s" % (fname, e))
- return 0
-
- meter = RateMeter(10000)
- try:
- dump = self.prep_export_dump(ks, cf, columns)
- writer = csv.writer(csvdest, **dialect_options)
- if header:
- writer.writerow(columns)
- for row in dump:
- fmt = lambda v: \
- format_value(v, output_encoding=encoding, nullval=nullval,
- time_format=time_format,
- float_precision=self.display_float_precision).strval
- writer.writerow(map(fmt, row.values()))
- meter.increment()
- finally:
- if do_close:
- csvdest.close()
- return meter.current_record
-
- def prep_export_dump(self, ks, cf, columns):
- if columns is None:
- columns = self.get_column_names(ks, cf)
- columnlist = ', '.join(protect_names(columns))
- query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(ks), protect_name(cf))
- return self.session.execute(query)
+ return copy.ExportTask(self, ks, cf, columns, fname, csv_options, dialect_options,
+ DEFAULT_PROTOCOL_VERSION, CONFIG_FILE).run()
def do_show(self, parsed):
"""
@@ -2215,34 +2158,6 @@ class Shell(cmd.Cmd):
self.writeresult(text, color, newline=newline, out=sys.stderr)
-class RateMeter(object):
-
- def __init__(self, log_rate):
- self.log_rate = log_rate
- self.last_checkpoint_time = time.time()
- self.current_rate = 0.0
- self.current_record = 0
-
- def increment(self):
- self.current_record += 1
-
- if (self.current_record % self.log_rate) == 0:
- new_checkpoint_time = time.time()
- new_rate = self.log_rate / (new_checkpoint_time - self.last_checkpoint_time)
- self.last_checkpoint_time = new_checkpoint_time
-
- # smooth the rate a bit
- if self.current_rate == 0.0:
- self.current_rate = new_rate
- else:
- self.current_rate = (self.current_rate + new_rate) / 2.0
-
- output = 'Processed %s rows; Write: %.2f rows/s\r' % \
- (self.current_record, self.current_rate)
- sys.stdout.write(output)
- sys.stdout.flush()
-
-
class SwitchCommand(object):
command = None
description = None
@@ -2487,6 +2402,9 @@ def main(options, hostname, port):
if batch_mode and shell.statement_error:
sys.exit(2)
+# always call this regardless of module name: when a sub-process is spawned
+# on Windows, the module name is not __main__, see CASSANDRA-9304
+insert_driver_hooks()
if __name__ == '__main__':
main(*read_options(sys.argv[1:], os.environ))
http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/pylib/cqlshlib/copy.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/copy.py b/pylib/cqlshlib/copy.py
new file mode 100644
index 0000000..8534b98
--- /dev/null
+++ b/pylib/cqlshlib/copy.py
@@ -0,0 +1,644 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import csv
+import json
+import multiprocessing as mp
+import os
+import Queue
+import sys
+import time
+import traceback
+
+from StringIO import StringIO
+from random import randrange
+from threading import Lock
+
+from cassandra.cluster import Cluster
+from cassandra.metadata import protect_name, protect_names
+from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy
+from cassandra.query import tuple_factory
+
+
+import sslhandling
+from displaying import NO_COLOR_MAP
+from formatting import format_value_default, EMPTY, get_formatter
+
+
+def parse_options(shell, opts):
+ """
+ Parse options for import (COPY FROM) and export (COPY TO) operations.
+ Extract from opts csv and dialect options.
+
+ :return: 3 dictionaries: the csv options, the dialect options, any unrecognized options.
+ """
+ dialect_options = shell.csv_dialect_defaults.copy()
+ if 'quote' in opts:
+ dialect_options['quotechar'] = opts.pop('quote')
+ if 'escape' in opts:
+ dialect_options['escapechar'] = opts.pop('escape')
+ if 'delimiter' in opts:
+ dialect_options['delimiter'] = opts.pop('delimiter')
+ if dialect_options['quotechar'] == dialect_options['escapechar']:
+ dialect_options['doublequote'] = True
+ del dialect_options['escapechar']
+
+ csv_options = dict()
+ csv_options['nullval'] = opts.pop('null', '')
+ csv_options['header'] = bool(opts.pop('header', '').lower() == 'true')
+ csv_options['encoding'] = opts.pop('encoding', 'utf8')
+ csv_options['jobs'] = int(opts.pop('jobs', 12))
+ csv_options['pagesize'] = int(opts.pop('pagesize', 1000))
+ # by default the page timeout is 10 seconds per 1000 entries in the page size or 10 seconds if pagesize is smaller
+ csv_options['pagetimeout'] = int(opts.pop('pagetimeout', max(10, 10 * (csv_options['pagesize'] / 1000))))
+ csv_options['maxattempts'] = int(opts.pop('maxattempts', 5))
+ csv_options['dtformats'] = opts.pop('timeformat', shell.display_time_format)
+ csv_options['float_precision'] = shell.display_float_precision
+
+ return csv_options, dialect_options, opts
+
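For reference, the PAGETIMEOUT default computed above grows linearly with
PAGESIZE, with a 10-second floor; a small sketch (the helper name is
illustrative, not part of the patch):

    def default_page_timeout(page_size):
        # 10 seconds per 1000 rows in a page, never less than 10 seconds
        return max(10, 10 * (page_size / 1000))  # Python 2 integer division

    default_page_timeout(500)    # -> 10
    default_page_timeout(1000)   # -> 10
    default_page_timeout(5000)   # -> 50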
+
+def get_num_processes(cap):
+ """
+ Pick a reasonable number of child processes. We need to leave at
+ least one core for the parent process. This doesn't necessarily
+ need to be capped, but 4 is currently enough to keep
+ a single local Cassandra node busy so we use this for import, whilst
+ for export we use 16 since we can connect to multiple Cassandra nodes.
+ Eventually this parameter will become an option.
+ """
+ try:
+ return max(1, min(cap, mp.cpu_count() - 1))
+ except NotImplementedError:
+ return 1
+
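For illustration, the two call sites in this patch behave as follows on an
assumed 8-core machine (mp.cpu_count() == 8):

    get_num_processes(cap=4)             # COPY FROM: min(4, 8 - 1)  -> 4
    get_num_processes(cap=min(16, 100))  # COPY TO with 100 ranges   -> 7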
+
+class ExportTask(object):
+ """
+ A class that exports data to .csv by instantiating one or more processes that work in parallel (ExportProcess).
+ """
+ def __init__(self, shell, ks, cf, columns, fname, csv_options, dialect_options, protocol_version, config_file):
+ self.shell = shell
+ self.csv_options = csv_options
+ self.dialect_options = dialect_options
+ self.ks = ks
+ self.cf = cf
+ self.columns = shell.get_column_names(ks, cf) if columns is None else columns
+ self.fname = fname
+ self.protocol_version = protocol_version
+ self.config_file = config_file
+
+ def run(self):
+ """
+ Initiates the export by creating the processes.
+ """
+ shell = self.shell
+ fname = self.fname
+
+ if fname is None:
+ do_close = False
+ csvdest = sys.stdout
+ else:
+ do_close = True
+ try:
+ csvdest = open(fname, 'wb')
+ except IOError, e:
+ shell.printerr("Can't open %r for writing: %s" % (fname, e))
+ return 0
+
+ if self.csv_options['header']:
+ writer = csv.writer(csvdest, **self.dialect_options)
+ writer.writerow(self.columns)
+
+ ranges = self.get_ranges()
+ num_processes = get_num_processes(cap=min(16, len(ranges)))
+
+ inmsg = mp.Queue()
+ outmsg = mp.Queue()
+ processes = []
+ for i in xrange(num_processes):
+ process = ExportProcess(outmsg, inmsg, self.ks, self.cf, self.columns, self.dialect_options,
+ self.csv_options, shell.debug, shell.port, shell.conn.cql_version,
+ shell.auth_provider, shell.ssl, self.protocol_version, self.config_file)
+ process.start()
+ processes.append(process)
+
+ try:
+ return self.check_processes(csvdest, ranges, inmsg, outmsg, processes)
+ finally:
+ for process in processes:
+ process.terminate()
+
+ inmsg.close()
+ outmsg.close()
+ if do_close:
+ csvdest.close()
+
+ def get_ranges(self):
+ """
+ Return a dict keyed by token range (from, to]; each value holds the hosts
+ that own that range plus retry book-keeping. Each host is responsible
+ for all the tokens in the range (from, to].
+
+ The ring information comes from the driver metadata token map, which is built by
+ querying System.PEERS.
+
+ We only consider replicas that are in the local datacenter. If there are no local replicas
+ we use the cqlsh session host.
+ """
+ shell = self.shell
+ hostname = shell.hostname
+ ranges = dict()
+
+ def make_range(hosts):
+ return {'hosts': tuple(hosts), 'attempts': 0, 'rows': 0}
+
+ min_token = self.get_min_token()
+ if shell.conn.metadata.token_map is None or min_token is None:
+ ranges[(None, None)] = make_range([hostname])
+ return ranges
+
+ local_dc = shell.conn.metadata.get_host(hostname).datacenter
+ ring = shell.get_ring(self.ks).items()
+ ring.sort()
+
+ previous_previous = None
+ previous = None
+ for token, replicas in ring:
+ if previous is None and token.value == min_token:
+ continue # avoids looping entire ring
+
+ hosts = []
+ for host in replicas:
+ if host.datacenter == local_dc:
+ hosts.append(host.address)
+ if len(hosts) == 0:
+ hosts.append(hostname) # fallback to default host if no replicas in current dc
+ ranges[(previous, token.value)] = make_range(hosts)
+ previous_previous = previous
+ previous = token.value
+
+ # If the ring is empty we get the entire ring from the
+ # host we are currently connected to, otherwise for the last ring interval
+ # we query the same replicas that hold the last token in the ring
+ if len(ranges) == 0:
+ ranges[(None, None)] = make_range([hostname])
+ else:
+ ranges[(previous, None)] = ranges[(previous_previous, previous)].copy()
+
+ return ranges
+
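The shape of the dict returned by get_ranges above, sketched with made-up
tokens t1 < t2 and replica addresses:

    # {(None, t1): {'hosts': ('10.0.0.1', '10.0.0.2'), 'attempts': 0, 'rows': 0},
    #  (t1, t2):   {'hosts': ('10.0.0.2', '10.0.0.3'), 'attempts': 0, 'rows': 0},
    #  (t2, None): {'hosts': ('10.0.0.2', '10.0.0.3'), 'attempts': 0, 'rows': 0}}
    # the last interval reuses the replicas holding the last token in the ring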
+ def get_min_token(self):
+ """
+ :return the minimum token, which depends on the partitioner.
+ For partitioners that do not support tokens we return None; in
+ that case we do not work in parallel, we just send all requests
+ to the cqlsh session host.
+ """
+ partitioner = self.shell.conn.metadata.partitioner
+
+ if partitioner.endswith('RandomPartitioner'):
+ return -1
+ elif partitioner.endswith('Murmur3Partitioner'):
+ return -(2 ** 63) # Long.MIN_VALUE in Java
+ else:
+ return None
+
+ @staticmethod
+ def send_work(ranges, tokens_to_send, queue):
+ for token_range in tokens_to_send:
+ queue.put((token_range, ranges[token_range]))
+ ranges[token_range]['attempts'] += 1
+
+ def check_processes(self, csvdest, ranges, inmsg, outmsg, processes):
+ """
+ Here we monitor all child processes by collecting their results
+ or any errors. We terminate when we have processed all the ranges or when there
+ are no more processes.
+ """
+ shell = self.shell
+ meter = RateMeter(10000)
+ total_jobs = len(ranges)
+ max_attempts = self.csv_options['maxattempts']
+
+ self.send_work(ranges, ranges.keys(), outmsg)
+
+ num_processes = len(processes)
+ succeeded = 0
+ failed = 0
+ while (failed + succeeded) < total_jobs and self.num_live_processes(processes) == num_processes:
+ try:
+ token_range, result = inmsg.get(timeout=1.0)
+ if token_range is None and result is None: # a job has finished
+ succeeded += 1
+ elif isinstance(result, Exception): # an error occurred
+ if token_range is None: # the entire process failed
+ shell.printerr('Error from worker process: %s' % (result))
+ else: # only this token_range failed, retry up to max_attempts if no rows received yet,
+ # if rows were received we risk duplicating data; there is a back-off policy in place
+ # in the worker process as well, see ExpBackoffRetryPolicy
+ if ranges[token_range]['attempts'] < max_attempts and ranges[token_range]['rows'] == 0:
+ shell.printerr('Error for %s: %s (will try again later attempt %d of %d)'
+ % (token_range, result, ranges[token_range]['attempts'], max_attempts))
+ self.send_work(ranges, [token_range], outmsg)
+ else:
+ shell.printerr('Error for %s: %s (permanently given up after %d rows and %d attempts)'
+ % (token_range, result, ranges[token_range]['rows'],
+ ranges[token_range]['attempts']))
+ failed += 1
+ else: # partial result received
+ data, num = result
+ csvdest.write(data)
+ meter.increment(n=num)
+ ranges[token_range]['rows'] += num
+ except Queue.Empty:
+ pass
+
+ if self.num_live_processes(processes) < len(processes):
+ for process in processes:
+ if not process.is_alive():
+ shell.printerr('Child process %d died with exit code %d' % (process.pid, process.exitcode))
+
+ if succeeded < total_jobs:
+ shell.printerr('Exported %d ranges out of %d total ranges, some records might be missing'
+ % (succeeded, total_jobs))
+
+ return meter.get_total_records()
+
+ @staticmethod
+ def num_live_processes(processes):
+ return sum(1 for p in processes if p.is_alive())
+
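The per-range retry decision in check_processes above reduces to a small
predicate (a hypothetical helper, not part of the patch):

    def should_retry(range_info, max_attempts):
        # retry only while no rows were received, to avoid duplicating data
        return range_info['attempts'] < max_attempts and range_info['rows'] == 0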
+
+class ExpBackoffRetryPolicy(RetryPolicy):
+ """
+ A retry policy with exponential back-off for read timeouts,
+ see ExportProcess.
+ """
+ def __init__(self, export_process):
+ RetryPolicy.__init__(self)
+ self.max_attempts = export_process.csv_options['maxattempts']
+ self.printmsg = lambda txt: export_process.printmsg(txt)
+
+ def on_read_timeout(self, query, consistency, required_responses,
+ received_responses, data_retrieved, retry_num):
+ delay = self.backoff(retry_num)
+ if delay > 0:
+ self.printmsg("Timeout received, retrying after %d seconds" % (delay))
+ time.sleep(delay)
+ return self.RETRY, consistency
+ elif delay == 0:
+ self.printmsg("Timeout received, retrying immediately")
+ return self.RETRY, consistency
+ else:
+ self.printmsg("Timeout received, giving up after %d attempts" % (retry_num + 1))
+ return self.RETHROW, None
+
+ def backoff(self, retry_num):
+ """
+ Perform exponential back-off up to a maximum number of times, where
+ this maximum is per query.
+ To back-off we should wait a random number of seconds
+ between 0 and 2^c - 1, where c is the number of total failures.
+ randrange() excludes the last value, so we drop the -1.
+
+ :return : the number of seconds to wait for, -1 if we should not retry
+ """
+ if retry_num >= self.max_attempts:
+ return -1
+
+ delay = randrange(0, pow(2, retry_num + 1))
+ return delay
+
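Given the randrange() call above, the back-off window per retry works out as
follows (assuming the default MAXATTEMPTS of 5):

    # retry_num 0 -> randrange(0, 2)  : wait 0..1  seconds
    # retry_num 1 -> randrange(0, 4)  : wait 0..3  seconds
    # retry_num 4 -> randrange(0, 32) : wait 0..31 seconds
    # retry_num 5 -> backoff() returns -1 : RETHROW, give up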
+
+class ExportSession(object):
+ """
+ A class for connecting to a cluster and storing the number
+ of jobs that this connection is processing. It wraps the methods
+ for executing a query asynchronously and for shutting down the
+ connection to the cluster.
+ """
+ def __init__(self, cluster, export_process):
+ session = cluster.connect(export_process.ks)
+ session.row_factory = tuple_factory
+ session.default_fetch_size = export_process.csv_options['pagesize']
+ session.default_timeout = export_process.csv_options['pagetimeout']
+
+ export_process.printmsg("Created connection to %s with page size %d and timeout %d seconds per page"
+ % (session.hosts, session.default_fetch_size, session.default_timeout))
+
+ self.cluster = cluster
+ self.session = session
+ self.jobs = 1
+ self.lock = Lock()
+
+ def add_job(self):
+ with self.lock:
+ self.jobs += 1
+
+ def complete_job(self):
+ with self.lock:
+ self.jobs -= 1
+
+ def num_jobs(self):
+ with self.lock:
+ return self.jobs
+
+ def execute_async(self, query):
+ return self.session.execute_async(query)
+
+ def shutdown(self):
+ self.cluster.shutdown()
+
+
+class ExportProcess(mp.Process):
+ """
+ A child worker process for the export task, ExportTask.
+ """
+
+ def __init__(self, inmsg, outmsg, ks, cf, columns, dialect_options, csv_options,
+ debug, port, cql_version, auth_provider, ssl, protocol_version, config_file):
+ mp.Process.__init__(self, target=self.run)
+ self.inmsg = inmsg
+ self.outmsg = outmsg
+ self.ks = ks
+ self.cf = cf
+ self.columns = columns
+ self.dialect_options = dialect_options
+ self.hosts_to_sessions = dict()
+
+ self.debug = debug
+ self.port = port
+ self.cql_version = cql_version
+ self.auth_provider = auth_provider
+ self.ssl = ssl
+ self.protocol_version = protocol_version
+ self.config_file = config_file
+
+ self.encoding = csv_options['encoding']
+ self.time_format = csv_options['dtformats']
+ self.float_precision = csv_options['float_precision']
+ self.nullval = csv_options['nullval']
+ self.maxjobs = csv_options['jobs']
+ self.csv_options = csv_options
+ self.formatters = dict()
+
+ # Here we inject some failures for testing purposes, only if this environment variable is set
+ if os.environ.get('CQLSH_COPY_TEST_FAILURES', ''):
+ self.test_failures = json.loads(os.environ.get('CQLSH_COPY_TEST_FAILURES', ''))
+ else:
+ self.test_failures = None
+
+ def printmsg(self, text):
+ if self.debug:
+ sys.stderr.write(text + os.linesep)
+
+ def run(self):
+ try:
+ self.inner_run()
+ finally:
+ self.close()
+
+ def inner_run(self):
+ """
+ The parent sends us (range, info) on the inbound queue (inmsg)
+ to request that we process a range. We may select any of the hosts
+ in info, which also carries other details for the range, such as the
+ number of attempts already performed. We can signal errors
+ on the outbound queue (outmsg) by sending (range, error) or
+ we can signal a global error by sending (None, error).
+ We terminate when the inbound queue is closed.
+ """
+ while True:
+ if self.num_jobs() > self.maxjobs:
+ time.sleep(0.001) # 1 millisecond
+ continue
+
+ token_range, info = self.inmsg.get()
+ self.start_job(token_range, info)
+
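The queue protocol described above, summarized (tuple shapes as used
elsewhere in this file):

    # parent -> worker (inmsg) : (token_range, {'hosts': (...), 'attempts': n, 'rows': m})
    # worker -> parent (outmsg): (token_range, (csv_data, num_rows))  partial result
    #                            (None, None)                         range completed
    #                            (token_range, Exception(msg))        range failed
    #                            (None, Exception(msg))               process-level error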
+ def report_error(self, err, token_range=None):
+ if isinstance(err, str):
+ msg = err
+ elif isinstance(err, BaseException):
+ msg = "%s - %s" % (err.__class__.__name__, err)
+ if self.debug:
+ traceback.print_exc(err)
+ else:
+ msg = str(err)
+
+ self.printmsg(msg)
+ self.outmsg.put((token_range, Exception(msg)))
+
+ def start_job(self, token_range, info):
+ """
+ Begin querying a range by executing an async query that
+ will later on invoke the callbacks attached in attach_callbacks.
+ """
+ session = self.get_session(info['hosts'])
+ metadata = session.cluster.metadata.keyspaces[self.ks].tables[self.cf]
+ query = self.prepare_query(metadata.partition_key, token_range, info['attempts'])
+ future = session.execute_async(query)
+ self.attach_callbacks(token_range, future, session)
+
+ def num_jobs(self):
+ return sum(session.num_jobs() for session in self.hosts_to_sessions.values())
+
+ def get_session(self, hosts):
+ """
+ We select a host to connect to. If we have no connection to one of the hosts
+ yet, we select that host; otherwise we pick the one with the smallest number
+ of jobs.
+
+ :return: An ExportSession connected to the chosen host.
+ """
+ new_hosts = [h for h in hosts if h not in self.hosts_to_sessions]
+ if new_hosts:
+ host = new_hosts[0]
+ new_cluster = Cluster(
+ contact_points=(host,),
+ port=self.port,
+ cql_version=self.cql_version,
+ protocol_version=self.protocol_version,
+ auth_provider=self.auth_provider,
+ ssl_options=sslhandling.ssl_settings(host, self.config_file) if self.ssl else None,
+ load_balancing_policy=TokenAwarePolicy(WhiteListRoundRobinPolicy(hosts)),
+ default_retry_policy=ExpBackoffRetryPolicy(self),
+ compression=None,
+ executor_threads=max(2, self.csv_options['jobs'] / 2))
+
+ session = ExportSession(new_cluster, self)
+ self.hosts_to_sessions[host] = session
+ return session
+ else:
+ host = min(hosts, key=lambda h: self.hosts_to_sessions[h].jobs)
+ session = self.hosts_to_sessions[host]
+ session.add_job()
+ return session
+
+ def attach_callbacks(self, token_range, future, session):
+ def result_callback(rows):
+ if future.has_more_pages:
+ future.start_fetching_next_page()
+ self.write_rows_to_csv(token_range, rows)
+ else:
+ self.write_rows_to_csv(token_range, rows)
+ self.outmsg.put((None, None))
+ session.complete_job()
+
+ def err_callback(err):
+ self.report_error(err, token_range)
+ session.complete_job()
+
+ future.add_callbacks(callback=result_callback, errback=err_callback)
+
+ def write_rows_to_csv(self, token_range, rows):
+ if len(rows) == 0:
+ return # no rows in this range
+
+ try:
+ output = StringIO()
+ writer = csv.writer(output, **self.dialect_options)
+
+ for row in rows:
+ writer.writerow(map(self.format_value, row))
+
+ data = (output.getvalue(), len(rows))
+ self.outmsg.put((token_range, data))
+ output.close()
+
+ except Exception, e:
+ self.report_error(e, token_range)
+
+ def format_value(self, val):
+ if val is None or val == EMPTY:
+ return format_value_default(self.nullval, colormap=NO_COLOR_MAP)
+
+ ctype = type(val)
+ formatter = self.formatters.get(ctype, None)
+ if not formatter:
+ formatter = get_formatter(ctype)
+ self.formatters[ctype] = formatter
+
+ return formatter(val, encoding=self.encoding, colormap=NO_COLOR_MAP, time_format=self.time_format,
+ float_precision=self.float_precision, nullval=self.nullval, quote=False)
+
+ def close(self):
+ self.printmsg("Export process terminating...")
+ self.inmsg.close()
+ self.outmsg.close()
+ for session in self.hosts_to_sessions.values():
+ session.shutdown()
+ self.printmsg("Export process terminated")
+
+ def prepare_query(self, partition_key, token_range, attempts):
+ """
+ Return the export query or a fake query with some failure injected.
+ """
+ if self.test_failures:
+ return self.maybe_inject_failures(partition_key, token_range, attempts)
+ else:
+ return self.prepare_export_query(partition_key, token_range)
+
+ def maybe_inject_failures(self, partition_key, token_range, attempts):
+ """
+ Examine self.test_failures and see if token_range is either a token range
+ supposed to cause a failure (failing_range) or to terminate the worker process
+ (exit_range). If not then call prepare_export_query(), which implements the
+ normal behavior.
+ """
+ start_token, end_token = token_range
+
+ if not start_token or not end_token:
+ # exclude first and last ranges to make things simpler
+ return self.prepare_export_query(partition_key, token_range)
+
+ if 'failing_range' in self.test_failures:
+ failing_range = self.test_failures['failing_range']
+ if start_token >= failing_range['start'] and end_token <= failing_range['end']:
+ if attempts < failing_range['num_failures']:
+ return 'SELECT * from bad_table'
+
+ if 'exit_range' in self.test_failures:
+ exit_range = self.test_failures['exit_range']
+ if start_token >= exit_range['start'] and end_token <= exit_range['end']:
+ sys.exit(1)
+
+ return self.prepare_export_query(partition_key, token_range)
+
+ def prepare_export_query(self, partition_key, token_range):
+ """
+ Return a query where we select all the data for this token range
+ """
+ pk_cols = ", ".join(protect_names(col.name for col in partition_key))
+ columnlist = ', '.join(protect_names(self.columns))
+ start_token, end_token = token_range
+ query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(self.ks), protect_name(self.cf))
+ if start_token is not None or end_token is not None:
+ query += ' WHERE'
+ if start_token is not None:
+ query += ' token(%s) > %s' % (pk_cols, start_token)
+ if start_token is not None and end_token is not None:
+ query += ' AND'
+ if end_token is not None:
+ query += ' token(%s) <= %s' % (pk_cols, end_token)
+ return query
+
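For illustration, with assumed names ks1, t1, partition key pk and columns
a, b, an interior token range produces:

    # prepare_export_query(partition_key, (-100, 200)) ->
    #   'SELECT a, b FROM ks1.t1 WHERE token(pk) > -100 AND token(pk) <= 200'
    # the first range (None, t) keeps only the <= clause, the last (t, None)
    # keeps only the > clause, and (None, None) selects the whole table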
+
+class RateMeter(object):
+
+ def __init__(self, log_threshold):
+ self.log_threshold = log_threshold # number of records after which we log
+ self.last_checkpoint_time = time.time() # last time we logged
+ self.current_rate = 0.0 # rows per second
+ self.current_record = 0 # number of records since we last logged
+ self.total_records = 0 # total number of records
+
+ def increment(self, n=1):
+ self.current_record += n
+
+ if self.current_record >= self.log_threshold:
+ self.update()
+ self.log()
+
+ def update(self):
+ new_checkpoint_time = time.time()
+ time_difference = new_checkpoint_time - self.last_checkpoint_time
+ if time_difference != 0.0:
+ self.current_rate = self.get_new_rate(self.current_record / time_difference)
+
+ self.last_checkpoint_time = new_checkpoint_time
+ self.total_records += self.current_record
+ self.current_record = 0
+
+ def get_new_rate(self, new_rate):
+ """
+ return the previous rate averaged with the new rate to smooth a bit
+ """
+ if self.current_rate == 0.0:
+ return new_rate
+ else:
+ return (self.current_rate + new_rate) / 2.0
+
+ def log(self):
+ output = 'Processed %d rows; Written: %f rows/s\r' % (self.total_records, self.current_rate,)
+ sys.stdout.write(output)
+ sys.stdout.flush()
+
+ def get_total_records(self):
+ self.update()
+ self.log()
+ return self.total_records
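A minimal usage sketch of RateMeter as wired up above (illustrative counts):

    meter = RateMeter(10000)            # log every ~10000 records
    for _ in xrange(25000):
        meter.increment()               # logs after 10000 and 20000 records
    total = meter.get_total_records()   # flushes the tail and logs -> 25000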
http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/pylib/cqlshlib/displaying.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/displaying.py b/pylib/cqlshlib/displaying.py
index f3a016e..7b260c2 100644
--- a/pylib/cqlshlib/displaying.py
+++ b/pylib/cqlshlib/displaying.py
@@ -28,11 +28,19 @@ ANSI_RESET = '\033[0m'
def colorme(bval, colormap, colorkey):
+ if colormap is NO_COLOR_MAP:
+ return bval
if colormap is None:
colormap = DEFAULT_VALUE_COLORS
return FormattedValue(bval, colormap[colorkey] + bval + colormap['reset'])
+def get_str(val):
+ if isinstance(val, FormattedValue):
+ return val.strval
+ return val
+
+
class FormattedValue:
def __init__(self, strval, coloredval=None, displaywidth=None):
@@ -112,3 +120,5 @@ COLUMN_NAME_COLORS = defaultdict(lambda: MAGENTA,
blob=DARK_MAGENTA,
reset=ANSI_RESET,
)
+
+NO_COLOR_MAP = dict()
http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/pylib/cqlshlib/formatting.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/formatting.py b/pylib/cqlshlib/formatting.py
index 79e661b..54dde0f 100644
--- a/pylib/cqlshlib/formatting.py
+++ b/pylib/cqlshlib/formatting.py
@@ -14,13 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import binascii
import sys
import re
import calendar
import math
from collections import defaultdict
from . import wcwidth
-from .displaying import colorme, FormattedValue, DEFAULT_VALUE_COLORS
+from .displaying import colorme, get_str, FormattedValue, DEFAULT_VALUE_COLORS, NO_COLOR_MAP
from cassandra.cqltypes import EMPTY
from cassandra.util import datetime_from_timestamp
from util import UTC
@@ -83,7 +84,6 @@ def color_text(bval, colormap, displaywidth=None):
# adding the smarts to handle that in to FormattedValue, we just
# make an explicit check to see if a null colormap is being used or
# not.
-
if displaywidth is None:
displaywidth = len(bval)
tbr = _make_turn_bits_red_f(colormap['blob'], colormap['text'])
@@ -97,7 +97,7 @@ def format_value_default(val, colormap, **_):
val = str(val)
escapedval = val.replace('\\', '\\\\')
bval = controlchars_re.sub(_show_control_chars, escapedval)
- return color_text(bval, colormap)
+ return bval if colormap is NO_COLOR_MAP else color_text(bval, colormap)
# Mapping cql type base names ("int", "map", etc) to formatter functions,
# making format_value a generic function
@@ -111,6 +111,10 @@ def format_value(type, val, **kwargs):
return formatter(val, **kwargs)
+def get_formatter(type):
+ return _formatters.get(type.__name__, format_value_default)
+
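A sketch of the lookup: registered type names resolve to their formatter,
anything else falls back to the default:

    get_formatter(bytearray)   # -> format_value_blob (registered just below)
    get_formatter(object)      # no 'object' registration -> format_value_default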
+
def formatter_for(typname):
def registrator(f):
_formatters[typname] = f
@@ -120,7 +124,7 @@ def formatter_for(typname):
@formatter_for('bytearray')
def format_value_blob(val, colormap, **_):
- bval = '0x' + ''.join('%02x' % c for c in val)
+ bval = '0x' + binascii.hexlify(val)
return colorme(bval, colormap, 'blob')
formatter_for('buffer')(format_value_blob)
@@ -204,8 +208,8 @@ def format_value_text(val, encoding, colormap, quote=False, **_):
bval = escapedval.encode(encoding, 'backslashreplace')
if quote:
bval = "'%s'" % bval
- displaywidth = wcwidth.wcswidth(bval.decode(encoding))
- return color_text(bval, colormap, displaywidth)
+
+ return bval if colormap is NO_COLOR_MAP else color_text(bval, colormap, wcwidth.wcswidth(bval.decode(encoding)))
# name alias
formatter_for('unicode')(format_value_text)
@@ -217,7 +221,10 @@ def format_simple_collection(val, lbracket, rbracket, encoding,
time_format=time_format, float_precision=float_precision,
nullval=nullval, quote=True)
for sval in val]
- bval = lbracket + ', '.join(sval.strval for sval in subs) + rbracket
+ bval = lbracket + ', '.join(get_str(sval) for sval in subs) + rbracket
+ if colormap is NO_COLOR_MAP:
+ return bval
+
lb, sep, rb = [colormap['collection'] + s + colormap['reset']
for s in (lbracket, ', ', rbracket)]
coloredval = lb + sep.join(sval.coloredval for sval in subs) + rb
@@ -242,6 +249,9 @@ def format_value_set(val, encoding, colormap, time_format, float_precision, null
return format_simple_collection(sorted(val), '{', '}', encoding, colormap,
time_format, float_precision, nullval)
formatter_for('frozenset')(format_value_set)
+# This code is used by cqlsh (bundled driver version 2.7.2 using sortedset),
+# and the dtests, which use whichever driver is on the machine, e.g. 3.0.0 (SortedSet)
+formatter_for('SortedSet')(format_value_set)
formatter_for('sortedset')(format_value_set)
@@ -253,7 +263,10 @@ def format_value_map(val, encoding, colormap, time_format, float_precision, null
nullval=nullval, quote=True)
subs = [(subformat(k), subformat(v)) for (k, v) in sorted(val.items())]
- bval = '{' + ', '.join(k.strval + ': ' + v.strval for (k, v) in subs) + '}'
+ bval = '{' + ', '.join(get_str(k) + ': ' + get_str(v) for (k, v) in subs) + '}'
+ if colormap is NO_COLOR_MAP:
+ return bval
+
lb, comma, colon, rb = [colormap['collection'] + s + colormap['reset']
for s in ('{', ', ', ': ', '}')]
coloredval = lb \
@@ -278,7 +291,10 @@ def format_value_utype(val, encoding, colormap, time_format, float_precision, nu
return format_value_text(name, encoding=encoding, colormap=colormap, quote=False)
subs = [(format_field_name(k), format_field_value(v)) for (k, v) in val._asdict().items()]
- bval = '{' + ', '.join(k.strval + ': ' + v.strval for (k, v) in subs) + '}'
+ bval = '{' + ', '.join(get_str(k) + ': ' + get_str(v) for (k, v) in subs) + '}'
+ if colormap is NO_COLOR_MAP:
+ return bval
+
lb, comma, colon, rb = [colormap['collection'] + s + colormap['reset']
for s in ('{', ', ', ': ', '}')]
coloredval = lb \