You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this plain-text copy.
Posted to commits@cassandra.apache.org by ty...@apache.org on 2015/11/19 01:31:45 UTC
[1/5] cassandra git commit: cqlsh: Improve COPY TO perf and error handling
Repository: cassandra
Updated Branches:
refs/heads/cassandra-3.0 12fd5d270 -> 74070ee4a
cqlsh: Improve COPY TO perf and error handling
Patch by Stefania Alborghetti; reviewed by Tyler Hobbs for CASSANDRA-9304
Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/1b629c10
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/1b629c10
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/1b629c10
Branch: refs/heads/cassandra-3.0
Commit: 1b629c101bbf793f8e248bbf4396bb41adc0af97
Parents: 246cb88
Author: Stefania Alborghetti <st...@datastax.com>
Authored: Wed Nov 18 18:22:28 2015 -0600
Committer: Tyler Hobbs <ty...@gmail.com>
Committed: Wed Nov 18 18:22:28 2015 -0600
----------------------------------------------------------------------
CHANGES.txt | 1 +
bin/cqlsh | 180 +++--------
pylib/cqlshlib/copy.py | 644 ++++++++++++++++++++++++++++++++++++++
pylib/cqlshlib/displaying.py | 10 +
pylib/cqlshlib/formatting.py | 34 +-
5 files changed, 729 insertions(+), 140 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 6ccde28..42dcf3e 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
2.1.12
+ * (cqlsh) Improve COPY TO performance and error handling (CASSANDRA-9304)
* Don't remove level info when running upgradesstables (CASSANDRA-10692)
* Create compression chunk for sending file only (CASSANDRA-10680)
* Make buffered read size configurable (CASSANDRA-10249)
http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/bin/cqlsh
----------------------------------------------------------------------
diff --git a/bin/cqlsh b/bin/cqlsh
index 7291803..5459d67 100755
--- a/bin/cqlsh
+++ b/bin/cqlsh
@@ -37,6 +37,7 @@ import ConfigParser
import csv
import getpass
import locale
+import multiprocessing as mp
import optparse
import os
import platform
@@ -44,10 +45,11 @@ import sys
import time
import traceback
import warnings
+
+from StringIO import StringIO
from contextlib import contextmanager
from functools import partial
from glob import glob
-from StringIO import StringIO
from uuid import UUID
description = "CQL Shell for Apache Cassandra"
@@ -119,7 +121,7 @@ cqlshlibdir = os.path.join(CASSANDRA_PATH, 'pylib')
if os.path.isdir(cqlshlibdir):
sys.path.insert(0, cqlshlibdir)
-from cqlshlib import cql3handling, cqlhandling, pylexotron, sslhandling
+from cqlshlib import cql3handling, cqlhandling, pylexotron, sslhandling, copy
from cqlshlib.displaying import (ANSI_RESET, BLUE, COLUMN_NAME_COLORS, CYAN,
RED, FormattedValue, colorme)
from cqlshlib.formatting import (format_by_type, format_value_utype,
@@ -410,7 +412,8 @@ def complete_copy_column_names(ctxt, cqlsh):
return set(colnames[1:]) - set(existcols)
-COPY_OPTIONS = ('DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'ENCODING', 'TIMEFORMAT', 'NULL')
+COPY_OPTIONS = ['DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'NULL', 'ENCODING',
+ 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS']
@cqlsh_syntax_completer('copyOption', 'optnames')
@@ -419,8 +422,7 @@ def complete_copy_options(ctxt, cqlsh):
direction = ctxt.get_binding('dir').upper()
opts = set(COPY_OPTIONS) - set(optnames)
if direction == 'FROM':
- opts -= ('ENCODING',)
- opts -= ('TIMEFORMAT',)
+ opts -= set(['ENCODING', 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS'])
return opts
@@ -535,6 +537,19 @@ def describe_interval(seconds):
return words
+def insert_driver_hooks():  # install cqlsh's python driver patches (BLOB deserialization, UDT formatting)
+    extend_cql_deserialization()
+    auto_format_udts()
+
+
+def extend_cql_deserialization():
+    """
+    The python driver returns BLOBs as string, but we expect them as bytearrays; also turn on CassandraType.support_empty_values
+    """
+    cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts))
+    cassandra.cqltypes.CassandraType.support_empty_values = True
+
+
def auto_format_udts():
# when we see a new user defined type, set up the shell formatting for it
udt_apply_params = cassandra.cqltypes.UserType.apply_parameters
@@ -673,11 +688,6 @@ class Shell(cmd.Cmd):
self.query_out = sys.stdout
self.consistency_level = cassandra.ConsistencyLevel.ONE
self.serial_consistency_level = cassandra.ConsistencyLevel.SERIAL
- # the python driver returns BLOBs as string, but we expect them as bytearrays
- cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts))
- cassandra.cqltypes.CassandraType.support_empty_values = True
-
- auto_format_udts()
self.empty_lines = 0
self.statement_error = False
@@ -807,11 +817,9 @@ class Shell(cmd.Cmd):
def get_keyspaces(self):
return self.conn.metadata.keyspaces.values()
- def get_ring(self):
- if self.current_keyspace is None or self.current_keyspace == 'system':
- raise NoKeyspaceError("Ring view requires a current non-system keyspace")
- self.conn.metadata.token_map.rebuild_keyspace(self.current_keyspace, build_if_absent=True)
- return self.conn.metadata.token_map.tokens_to_hosts_by_ks[self.current_keyspace]
+ def get_ring(self, ks):
+ self.conn.metadata.token_map.rebuild_keyspace(ks, build_if_absent=True)
+ return self.conn.metadata.token_map.tokens_to_hosts_by_ks[ks]
def get_table_meta(self, ksname, tablename):
if ksname is None:
@@ -1369,7 +1377,7 @@ class Shell(cmd.Cmd):
# print 'Snitch: %s\n' % snitch
if self.current_keyspace is not None and self.current_keyspace != 'system':
print "Range ownership:"
- ring = self.get_ring()
+ ring = self.get_ring(self.current_keyspace)
for entry in ring.items():
print ' %39s [%s]' % (str(entry[0].value), ', '.join([host.address for host in entry[1]]))
print
@@ -1506,10 +1514,14 @@ class Shell(cmd.Cmd):
ENCODING='utf8' - encoding for CSV output (COPY TO only)
TIMEFORMAT= - timestamp strftime format (COPY TO only)
'%Y-%m-%d %H:%M:%S%z' defaults to time_format value in cqlshrc
+ PAGESIZE='1000' - the page size for fetching results (COPY TO only)
+ PAGETIMEOUT=10 - the page timeout for fetching results (COPY TO only)
+ MAXATTEMPTS='5' - the maximum number of attempts for errors (COPY TO only)
When entering CSV data on STDIN, you can use the sequence "\."
on a line by itself to end the data input.
"""
+
ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
if ks is None:
ks = self.current_keyspace
@@ -1546,22 +1558,12 @@ class Shell(cmd.Cmd):
print "\n%d rows %s in %s." % (rows, verb, describe_interval(timeend - timestart))
def perform_csv_import(self, ks, cf, columns, fname, opts):
- dialect_options = self.csv_dialect_defaults.copy()
- if 'quote' in opts:
- dialect_options['quotechar'] = opts.pop('quote')
- if 'escape' in opts:
- dialect_options['escapechar'] = opts.pop('escape')
- if 'delimiter' in opts:
- dialect_options['delimiter'] = opts.pop('delimiter')
- nullval = opts.pop('null', '')
- header = bool(opts.pop('header', '').lower() == 'true')
- if dialect_options['quotechar'] == dialect_options['escapechar']:
- dialect_options['doublequote'] = True
- del dialect_options['escapechar']
- if opts:
+ csv_options, dialect_options, unrecognized_options = copy.parse_options(self, opts)
+ if unrecognized_options:
self.printerr('Unrecognized COPY FROM options: %s'
- % ', '.join(opts.keys()))
+ % ', '.join(unrecognized_options.keys()))
return 0
+ nullval, header = csv_options['nullval'], csv_options['header']
if fname is None:
do_close = False
@@ -1576,33 +1578,24 @@ class Shell(cmd.Cmd):
return 0
current_record = None
+ processes, pipes = [], [],
try:
if header:
linesource.next()
reader = csv.reader(linesource, **dialect_options)
- from multiprocessing import Process, Pipe, cpu_count
+ num_processes = copy.get_num_processes(cap=4)
- # Pick a resonable number of child processes. We need to leave at
- # least one core for the parent process. This doesn't necessarily
- # need to be capped at 4, but it's currently enough to keep
- # a single local Cassandra node busy, and I see lower throughput
- # with more processes.
- try:
- num_processes = max(1, min(4, cpu_count() - 1))
- except NotImplementedError:
- num_processes = 1
-
- processes, pipes = [], [],
for i in range(num_processes):
- parent_conn, child_conn = Pipe()
+ parent_conn, child_conn = mp.Pipe()
pipes.append(parent_conn)
- processes.append(Process(target=self.multiproc_import, args=(child_conn, ks, cf, columns, nullval)))
+ proc_args = (child_conn, ks, cf, columns, nullval)
+ processes.append(mp.Process(target=self.multiproc_import, args=proc_args))
for process in processes:
process.start()
- meter = RateMeter(10000)
+ meter = copy.RateMeter(10000)
for current_record, row in enumerate(reader, start=1):
# write to the child process
pipes[current_record % num_processes].send((current_record, row))
@@ -1612,7 +1605,7 @@ class Shell(cmd.Cmd):
# check for any errors reported by the children
if (current_record % 100) == 0:
- if self._check_child_pipes(current_record, pipes):
+ if self._check_import_processes(current_record, pipes):
# no errors seen, continue with outer loop
continue
else:
@@ -1641,7 +1634,7 @@ class Shell(cmd.Cmd):
for process in processes:
process.join()
- self._check_child_pipes(current_record, pipes)
+ self._check_import_processes(current_record, pipes)
for pipe in pipes:
pipe.close()
@@ -1653,8 +1646,7 @@ class Shell(cmd.Cmd):
return current_record
- def _check_child_pipes(self, current_record, pipes):
- # check the pipes for errors from child processes
+ def _check_import_processes(self, current_record, pipes):
for pipe in pipes:
if pipe.poll():
try:
@@ -1802,62 +1794,13 @@ class Shell(cmd.Cmd):
new_cluster.shutdown()
def perform_csv_export(self, ks, cf, columns, fname, opts):
- dialect_options = self.csv_dialect_defaults.copy()
-
- if 'quote' in opts:
- dialect_options['quotechar'] = opts.pop('quote')
- if 'escape' in opts:
- dialect_options['escapechar'] = opts.pop('escape')
- if 'delimiter' in opts:
- dialect_options['delimiter'] = opts.pop('delimiter')
- encoding = opts.pop('encoding', 'utf8')
- nullval = opts.pop('null', '')
- header = bool(opts.pop('header', '').lower() == 'true')
- time_format = opts.pop('timeformat', self.display_time_format)
- if dialect_options['quotechar'] == dialect_options['escapechar']:
- dialect_options['doublequote'] = True
- del dialect_options['escapechar']
-
- if opts:
- self.printerr('Unrecognized COPY TO options: %s'
- % ', '.join(opts.keys()))
+ csv_options, dialect_options, unrecognized_options = copy.parse_options(self, opts)
+ if unrecognized_options:
+ self.printerr('Unrecognized COPY TO options: %s' % ', '.join(unrecognized_options.keys()))
return 0
- if fname is None:
- do_close = False
- csvdest = sys.stdout
- else:
- do_close = True
- try:
- csvdest = open(fname, 'wb')
- except IOError, e:
- self.printerr("Can't open %r for writing: %s" % (fname, e))
- return 0
-
- meter = RateMeter(10000)
- try:
- dump = self.prep_export_dump(ks, cf, columns)
- writer = csv.writer(csvdest, **dialect_options)
- if header:
- writer.writerow(columns)
- for row in dump:
- fmt = lambda v: \
- format_value(v, output_encoding=encoding, nullval=nullval,
- time_format=time_format,
- float_precision=self.display_float_precision).strval
- writer.writerow(map(fmt, row.values()))
- meter.increment()
- finally:
- if do_close:
- csvdest.close()
- return meter.current_record
-
- def prep_export_dump(self, ks, cf, columns):
- if columns is None:
- columns = self.get_column_names(ks, cf)
- columnlist = ', '.join(protect_names(columns))
- query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(ks), protect_name(cf))
- return self.session.execute(query)
+ return copy.ExportTask(self, ks, cf, columns, fname, csv_options, dialect_options,
+ DEFAULT_PROTOCOL_VERSION, CONFIG_FILE).run()
def do_show(self, parsed):
"""
@@ -2215,34 +2158,6 @@ class Shell(cmd.Cmd):
self.writeresult(text, color, newline=newline, out=sys.stderr)
-class RateMeter(object):
-
- def __init__(self, log_rate):
- self.log_rate = log_rate
- self.last_checkpoint_time = time.time()
- self.current_rate = 0.0
- self.current_record = 0
-
- def increment(self):
- self.current_record += 1
-
- if (self.current_record % self.log_rate) == 0:
- new_checkpoint_time = time.time()
- new_rate = self.log_rate / (new_checkpoint_time - self.last_checkpoint_time)
- self.last_checkpoint_time = new_checkpoint_time
-
- # smooth the rate a bit
- if self.current_rate == 0.0:
- self.current_rate = new_rate
- else:
- self.current_rate = (self.current_rate + new_rate) / 2.0
-
- output = 'Processed %s rows; Write: %.2f rows/s\r' % \
- (self.current_record, self.current_rate)
- sys.stdout.write(output)
- sys.stdout.flush()
-
-
class SwitchCommand(object):
command = None
description = None
@@ -2487,6 +2402,9 @@ def main(options, hostname, port):
if batch_mode and shell.statement_error:
sys.exit(2)
+# always call this regardless of module name: when a sub-process is spawned
+# on Windows then the module name is not __main__, see CASSANDRA-9304
+insert_driver_hooks()
if __name__ == '__main__':
main(*read_options(sys.argv[1:], os.environ))
http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/pylib/cqlshlib/copy.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/copy.py b/pylib/cqlshlib/copy.py
new file mode 100644
index 0000000..8534b98
--- /dev/null
+++ b/pylib/cqlshlib/copy.py
@@ -0,0 +1,644 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import csv
+import json
+import multiprocessing as mp
+import os
+import Queue
+import sys
+import time
+import traceback
+
+from StringIO import StringIO
+from random import randrange
+from threading import Lock
+
+from cassandra.cluster import Cluster
+from cassandra.metadata import protect_name, protect_names
+from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy
+from cassandra.query import tuple_factory
+
+
+import sslhandling
+from displaying import NO_COLOR_MAP
+from formatting import format_value_default, EMPTY, get_formatter
+
+
+def parse_options(shell, opts):
+    """
+    Parse options for import (COPY FROM) and export (COPY TO) operations. Every recognized key is
+    popped from opts (the dict is modified in place) and stored in the csv or dialect options;
+    int() raises ValueError for non-numeric jobs, pagesize, pagetimeout or maxattempts values.
+    :return: 3 dictionaries: the csv options, the dialect options, any unrecognized options (what is left of opts).
+    """
+    dialect_options = shell.csv_dialect_defaults.copy()
+    if 'quote' in opts:
+        dialect_options['quotechar'] = opts.pop('quote')
+    if 'escape' in opts:
+        dialect_options['escapechar'] = opts.pop('escape')
+    if 'delimiter' in opts:
+        dialect_options['delimiter'] = opts.pop('delimiter')
+    if dialect_options['quotechar'] == dialect_options['escapechar']:  # same char: escape quotes by doubling them
+        dialect_options['doublequote'] = True
+        del dialect_options['escapechar']
+
+    csv_options = dict()
+    csv_options['nullval'] = opts.pop('null', '')
+    csv_options['header'] = bool(opts.pop('header', '').lower() == 'true')
+    csv_options['encoding'] = opts.pop('encoding', 'utf8')
+    csv_options['jobs'] = int(opts.pop('jobs', 12))
+    csv_options['pagesize'] = int(opts.pop('pagesize', 1000))
+    # by default the page timeout is 10 seconds per 1000 entries in the page size or 10 seconds if pagesize is smaller
+    csv_options['pagetimeout'] = int(opts.pop('pagetimeout', max(10, 10 * (csv_options['pagesize'] / 1000))))
+    csv_options['maxattempts'] = int(opts.pop('maxattempts', 5))
+    csv_options['dtformats'] = opts.pop('timeformat', shell.display_time_format)
+    csv_options['float_precision'] = shell.display_float_precision
+
+    return csv_options, dialect_options, opts
+
+
+def get_num_processes(cap):
+    """
+    Pick a reasonable number of child processes: cpu_count() - 1, capped at cap and never below 1
+    (1 is also returned when the CPU count is not available). We need to leave at least one core
+    for the parent process. This doesn't necessarily need to be capped, but 4 is currently enough
+    to keep a single local Cassandra node busy so we use this for import, whilst for export we
+    use up to 16 since we can connect to multiple Cassandra nodes.
+    Eventually this parameter will become an option.
+    """
+    try:
+        return max(1, min(cap, mp.cpu_count() - 1))
+    except NotImplementedError:
+        return 1
+
+
+class ExportTask(object):
+    """
+    A class that exports data to .csv by instantiating one or more processes that work in parallel (ExportProcess).
+    """
+    def __init__(self, shell, ks, cf, columns, fname, csv_options, dialect_options, protocol_version, config_file):
+        self.shell = shell
+        self.csv_options = csv_options
+        self.dialect_options = dialect_options
+        self.ks = ks
+        self.cf = cf
+        self.columns = shell.get_column_names(ks, cf) if columns is None else columns
+        self.fname = fname
+        self.protocol_version = protocol_version
+        self.config_file = config_file
+
+    def run(self):
+        """
+        Run the export and return the number of rows written (0 if the destination cannot be opened).
+        """
+        shell = self.shell
+        fname = self.fname
+
+        if fname is None:
+            do_close = False
+            csvdest = sys.stdout
+        else:
+            do_close = True
+            try:
+                csvdest = open(fname, 'wb')
+            except IOError, e:
+                shell.printerr("Can't open %r for writing: %s" % (fname, e))
+                return 0
+
+        inmsg = mp.Queue()
+        outmsg = mp.Queue()
+        processes = []
+
+        try:
+            if self.csv_options['header']:
+                writer = csv.writer(csvdest, **self.dialect_options)
+                writer.writerow(self.columns)
+
+            ranges = self.get_ranges()
+            num_processes = get_num_processes(cap=min(16, len(ranges)))
+            for i in xrange(num_processes):
+                process = ExportProcess(outmsg, inmsg, self.ks, self.cf, self.columns, self.dialect_options,
+                                        self.csv_options, shell.debug, shell.port, shell.conn.cql_version,
+                                        shell.auth_provider, shell.ssl, self.protocol_version, self.config_file)
+                process.start()
+                processes.append(process)
+
+            return self.check_processes(csvdest, ranges, inmsg, outmsg, processes)
+        finally:
+            # always stop the workers and release the queues and the destination, even if setup failed
+            for process in processes:
+                process.terminate()
+            inmsg.close()
+            outmsg.close()
+            if do_close:
+                csvdest.close()
+
+    def get_ranges(self):
+        """
+        Return a dict keyed by token range (from, to]; each value is a dict with the hosts that
+        own that range ('hosts') plus retry bookkeeping ('attempts', 'rows'). Each host is responsible
+        for all the tokens in the range (from, to].
+
+        The ring information comes from the driver metadata token map, which is built by
+        querying System.PEERS.
+
+        We only consider replicas that are in the local datacenter. If there are no local replicas
+        we use the cqlsh session host.
+        """
+        shell = self.shell
+        hostname = shell.hostname
+        ranges = dict()
+
+        def make_range(hosts):
+            return {'hosts': tuple(hosts), 'attempts': 0, 'rows': 0}
+
+        min_token = self.get_min_token()
+        if shell.conn.metadata.token_map is None or min_token is None:
+            ranges[(None, None)] = make_range([hostname])
+            return ranges
+
+        local_dc = shell.conn.metadata.get_host(hostname).datacenter
+        ring = shell.get_ring(self.ks).items()
+        ring.sort()
+
+        first_range_data = None
+        previous = None
+        for token, replicas in ring:
+            hosts = []
+            for host in replicas:
+                if host.datacenter == local_dc:
+                    hosts.append(host.address)
+            if len(hosts) == 0:
+                hosts.append(hostname)  # fallback to default host if no replicas in current dc
+            if first_range_data is None:  # replicas of the first token also own the wrap-around range
+                first_range_data = make_range(hosts)
+            if previous is None and token.value == min_token:
+                continue  # avoids looping entire ring
+            ranges[(previous, token.value)] = make_range(hosts)
+            previous = token.value
+
+        # If the ring is empty we get the entire ring from the host we are currently connected to,
+        # otherwise the last ring interval (last token, max] wraps around the ring, so we query
+        # the same replicas that hold the first token in the ring
+        if len(ranges) == 0:
+            ranges[(None, None)] = make_range([hostname])
+        else:
+            ranges[(previous, None)] = first_range_data
+
+        return ranges
+
+    def get_min_token(self):
+        """
+        :return: the minimum token, which depends on the partitioner.
+        For partitioners that do not support tokens we return None, in
+        these cases we will not work in parallel, we'll just send all requests
+        to the cqlsh session host.
+        """
+        partitioner = self.shell.conn.metadata.partitioner
+
+        if partitioner.endswith('RandomPartitioner'):
+            return -1
+        elif partitioner.endswith('Murmur3Partitioner'):
+            return -(2 ** 63)  # Long.MIN_VALUE in Java
+        else:
+            return None
+
+    @staticmethod
+    def send_work(ranges, tokens_to_send, queue):
+        for token_range in tokens_to_send:
+            queue.put((token_range, ranges[token_range]))
+            ranges[token_range]['attempts'] += 1
+
+    def check_processes(self, csvdest, ranges, inmsg, outmsg, processes):
+        """
+        Here we monitor all child processes by collecting their results
+        or any errors. We terminate when we have processed all the ranges or when there
+        are no more processes.
+        """
+        shell = self.shell
+        meter = RateMeter(10000)
+        total_jobs = len(ranges)
+        max_attempts = self.csv_options['maxattempts']
+
+        self.send_work(ranges, ranges.keys(), outmsg)
+
+        num_processes = len(processes)
+        succeeded = 0
+        failed = 0
+        while (failed + succeeded) < total_jobs and self.num_live_processes(processes) == num_processes:
+            try:
+                token_range, result = inmsg.get(timeout=1.0)
+                if token_range is None and result is None:  # a job has finished
+                    succeeded += 1
+                elif isinstance(result, Exception):  # an error occurred
+                    if token_range is None:  # the entire process failed
+                        shell.printerr('Error from worker process: %s' % (result))
+                    else:  # only this token_range failed, retry up to max_attempts if no rows received yet,
+                        # if rows are received we risk duplicating data, there is a back-off policy in place
+                        # in the worker process as well, see ExpBackoffRetryPolicy
+                        if ranges[token_range]['attempts'] < max_attempts and ranges[token_range]['rows'] == 0:
+                            shell.printerr('Error for %s: %s (will try again later attempt %d of %d)'
+                                           % (token_range, result, ranges[token_range]['attempts'], max_attempts))
+                            self.send_work(ranges, [token_range], outmsg)
+                        else:
+                            shell.printerr('Error for %s: %s (permanently given up after %d rows and %d attempts)'
+                                           % (token_range, result, ranges[token_range]['rows'],
+                                              ranges[token_range]['attempts']))
+                            failed += 1
+                else:  # partial result received
+                    data, num = result
+                    csvdest.write(data)
+                    meter.increment(n=num)
+                    ranges[token_range]['rows'] += num
+            except Queue.Empty:
+                pass
+
+        if self.num_live_processes(processes) < len(processes):
+            for process in processes:
+                if not process.is_alive():
+                    shell.printerr('Child process %d died with exit code %d' % (process.pid, process.exitcode))
+
+        if succeeded < total_jobs:
+            shell.printerr('Exported %d ranges out of %d total ranges, some records might be missing'
+                           % (succeeded, total_jobs))
+
+        return meter.get_total_records()
+
+    @staticmethod
+    def num_live_processes(processes):
+        return sum(1 for p in processes if p.is_alive())
+
+
+class ExpBackoffRetryPolicy(RetryPolicy):
+    """
+    A retry policy with exponential back-off for read timeouts (see ExportProcess), retrying up to
+    csv_options['maxattempts'] times. NOTE(review): time.sleep() blocks whichever driver thread calls on_read_timeout.
+    """
+    def __init__(self, export_process):
+        RetryPolicy.__init__(self)
+        self.max_attempts = export_process.csv_options['maxattempts']
+        self.printmsg = lambda txt: export_process.printmsg(txt)
+
+    def on_read_timeout(self, query, consistency, required_responses,
+                        received_responses, data_retrieved, retry_num):
+        delay = self.backoff(retry_num)  # -1 means give up
+        if delay > 0:
+            self.printmsg("Timeout received, retrying after %d seconds" % (delay))
+            time.sleep(delay)
+            return self.RETRY, consistency
+        elif delay == 0:
+            self.printmsg("Timeout received, retrying immediately")
+            return self.RETRY, consistency
+        else:
+            self.printmsg("Timeout received, giving up after %d attempts" % (retry_num + 1))
+            return self.RETHROW, None
+
+    def backoff(self, retry_num):
+        """
+        Perform exponential back-off up to a maximum number of times, where
+        this maximum is per query.
+        To back-off we should wait a random number of seconds
+        between 0 and 2^c - 1, where c is the number of total failures.
+        randrange() excludes the last value, so we drop the -1.
+
+        :return : the number of seconds to wait for, -1 if we should not retry
+        """
+        if retry_num >= self.max_attempts:
+            return -1
+
+        delay = randrange(0, pow(2, retry_num + 1))
+        return delay
+
+
+class ExportSession(object):
+    """
+    A connection to a cluster (cluster + session) together with a lock-protected count of the
+    jobs this connection is processing. The count starts at 1 because ExportProcess.get_session
+    creates a session for the job that needs it. It wraps the methods for executing a query
+    asynchronously and for shutting down the connection to the cluster.
+    """
+    def __init__(self, cluster, export_process):
+        session = cluster.connect(export_process.ks)
+        session.row_factory = tuple_factory
+        session.default_fetch_size = export_process.csv_options['pagesize']
+        session.default_timeout = export_process.csv_options['pagetimeout']
+
+        export_process.printmsg("Created connection to %s with page size %d and timeout %d seconds per page"
+                                % (session.hosts, session.default_fetch_size, session.default_timeout))
+
+        self.cluster = cluster
+        self.session = session
+        self.jobs = 1
+        self.lock = Lock()
+
+    def add_job(self):
+        with self.lock:
+            self.jobs += 1
+
+    def complete_job(self):
+        with self.lock:
+            self.jobs -= 1
+
+    def num_jobs(self):
+        with self.lock:
+            return self.jobs
+
+    def execute_async(self, query):
+        return self.session.execute_async(query)
+
+    def shutdown(self):
+        self.cluster.shutdown()
+
+
+class ExportProcess(mp.Process):
+    """
+    A child worker process for the export task, ExportTask.
+    """
+
+    def __init__(self, inmsg, outmsg, ks, cf, columns, dialect_options, csv_options,
+                 debug, port, cql_version, auth_provider, ssl, protocol_version, config_file):
+        mp.Process.__init__(self, target=self.run)
+        self.inmsg = inmsg
+        self.outmsg = outmsg
+        self.ks = ks
+        self.cf = cf
+        self.columns = columns
+        self.dialect_options = dialect_options
+        self.hosts_to_sessions = dict()
+
+        self.debug = debug
+        self.port = port
+        self.cql_version = cql_version
+        self.auth_provider = auth_provider
+        self.ssl = ssl
+        self.protocol_version = protocol_version
+        self.config_file = config_file
+
+        self.encoding = csv_options['encoding']
+        self.time_format = csv_options['dtformats']
+        self.float_precision = csv_options['float_precision']
+        self.nullval = csv_options['nullval']
+        self.maxjobs = max(1, csv_options['jobs'])  # at least one job, otherwise inner_run would spin forever
+        self.csv_options = csv_options
+        self.formatters = dict()
+
+        # Here we inject some failures for testing purposes, only if this environment variable is set
+        test_failures = os.environ.get('CQLSH_COPY_TEST_FAILURES', '')
+        self.test_failures = json.loads(test_failures) if test_failures else None
+
+    def printmsg(self, text):
+        if self.debug:
+            sys.stderr.write(text + os.linesep)
+
+    def run(self):
+        try:
+            self.inner_run()
+        finally:
+            self.close()
+
+    def inner_run(self):
+        """
+        The parent sends us (range, info) on the inbound queue (inmsg)
+        in order to request us to process a range, for which we can
+        select any of the hosts in info, which also contains other information for this
+        range such as the number of attempts already performed. We can signal errors
+        on the outbound queue (outmsg) by sending (range, error) or
+        we can signal a global error by sending (None, error).
+        We terminate when the inbound queue is closed.
+        """
+        while True:
+            if self.num_jobs() >= self.maxjobs:  # never run more than maxjobs jobs at the same time
+                time.sleep(0.001)  # 1 millisecond
+                continue
+
+            token_range, info = self.inmsg.get()
+            self.start_job(token_range, info)
+
+    def report_error(self, err, token_range=None):  # send (token_range, Exception) to the parent
+        if isinstance(err, str):
+            msg = err
+        elif isinstance(err, BaseException):
+            msg = "%s - %s" % (err.__class__.__name__, err)
+            if self.debug and sys.exc_info()[0] is not None:
+                traceback.print_exc()  # prints the exception being handled; its first argument is a limit
+        else:
+            msg = str(err)
+
+        self.printmsg(msg)
+        self.outmsg.put((token_range, Exception(msg)))
+
+    def start_job(self, token_range, info):
+        """
+        Begin querying a range by executing an async query that will later on invoke the
+        callbacks attached in attach_callbacks. If the query cannot even be started (e.g. the
+        connection fails), report the error for this range instead of killing the worker.
+        """
+        session = None
+        try:
+            session = self.get_session(info['hosts'])
+            metadata = session.cluster.metadata.keyspaces[self.ks].tables[self.cf]
+            query = self.prepare_query(metadata.partition_key, token_range, info['attempts'])
+            future = session.execute_async(query)
+            self.attach_callbacks(token_range, future, session)
+        except Exception, e:
+            if session is not None:
+                session.complete_job()  # release the job slot taken in get_session()
+            self.report_error(e, token_range)
+
+    def num_jobs(self):
+        return sum(session.num_jobs() for session in self.hosts_to_sessions.values())
+
+    def get_session(self, hosts):
+        """
+        We select a host to connect to. If we have no connections to one of the hosts
+        yet then we select this host, else we pick the one with the smallest number of jobs.
+        The job is counted on the returned session, so callers must call complete_job() later.
+
+        :return: An ExportSession connected to the chosen host.
+        """
+        new_hosts = [h for h in hosts if h not in self.hosts_to_sessions]
+        if new_hosts:
+            host = new_hosts[0]
+            new_cluster = Cluster(
+                contact_points=(host,),
+                port=self.port,
+                cql_version=self.cql_version,
+                protocol_version=self.protocol_version,
+                auth_provider=self.auth_provider,
+                ssl_options=sslhandling.ssl_settings(host, self.config_file) if self.ssl else None,
+                load_balancing_policy=TokenAwarePolicy(WhiteListRoundRobinPolicy(hosts)),
+                default_retry_policy=ExpBackoffRetryPolicy(self),
+                compression=None,
+                executor_threads=max(2, self.csv_options['jobs'] / 2))
+
+            session = ExportSession(new_cluster, self)
+            self.hosts_to_sessions[host] = session
+            return session
+        else:
+            host = min(hosts, key=lambda h: self.hosts_to_sessions[h].num_jobs())
+            session = self.hosts_to_sessions[host]
+            session.add_job()
+            return session
+
+    def attach_callbacks(self, token_range, future, session):
+        def result_callback(rows):
+            if future.has_more_pages:
+                future.start_fetching_next_page()
+                self.write_rows_to_csv(token_range, rows)
+            else:
+                self.write_rows_to_csv(token_range, rows)
+                self.outmsg.put((None, None))
+                session.complete_job()
+
+        def err_callback(err):
+            self.report_error(err, token_range)
+            session.complete_job()
+
+        future.add_callbacks(callback=result_callback, errback=err_callback)
+
+    def write_rows_to_csv(self, token_range, rows):
+        if len(rows) == 0:
+            return  # no rows in this range
+
+        try:
+            output = StringIO()
+            writer = csv.writer(output, **self.dialect_options)
+
+            for row in rows:
+                writer.writerow(map(self.format_value, row))
+
+            data = (output.getvalue(), len(rows))
+            self.outmsg.put((token_range, data))
+            output.close()
+
+        except Exception, e:
+            self.report_error(e, token_range)
+
+    def format_value(self, val):
+        if val is None or val == EMPTY:
+            return format_value_default(self.nullval, colormap=NO_COLOR_MAP)
+
+        ctype = type(val)
+        formatter = self.formatters.get(ctype, None)
+        if not formatter:
+            formatter = get_formatter(ctype)
+            self.formatters[ctype] = formatter
+
+        return formatter(val, encoding=self.encoding, colormap=NO_COLOR_MAP, time_format=self.time_format,
+                         float_precision=self.float_precision, nullval=self.nullval, quote=False)
+
+    def close(self):
+        self.printmsg("Export process terminating...")
+        self.inmsg.close()
+        self.outmsg.close()
+        for session in self.hosts_to_sessions.values():
+            session.shutdown()
+        self.printmsg("Export process terminated")
+
+    def prepare_query(self, partition_key, token_range, attempts):
+        """Return the export query or, if test failures are configured, a query with some failure injected."""
+        if self.test_failures:
+            return self.maybe_inject_failures(partition_key, token_range, attempts)
+        return self.prepare_export_query(partition_key, token_range)
+
+    def maybe_inject_failures(self, partition_key, token_range, attempts):
+        """
+        Examine self.test_failures and see if token_range is either a token range
+        supposed to cause a failure (failing_range) or to terminate the worker process
+        (exit_range). If not then call prepare_export_query(), which implements the
+        normal behavior.
+        """
+        start_token, end_token = token_range
+
+        if start_token is None or end_token is None:
+            # exclude first and last ranges to make things simpler (a token of 0 is a valid bound)
+            return self.prepare_export_query(partition_key, token_range)
+
+        if 'failing_range' in self.test_failures:
+            failing_range = self.test_failures['failing_range']
+            if start_token >= failing_range['start'] and end_token <= failing_range['end']:
+                if attempts < failing_range['num_failures']:
+                    return 'SELECT * from bad_table'
+
+        if 'exit_range' in self.test_failures:
+            exit_range = self.test_failures['exit_range']
+            if start_token >= exit_range['start'] and end_token <= exit_range['end']:
+                sys.exit(1)
+
+        return self.prepare_export_query(partition_key, token_range)
+
+    def prepare_export_query(self, partition_key, token_range):
+        """Return a query where we select all the data for this token range."""
+        pk_cols = ", ".join(protect_names(col.name for col in partition_key))
+        columnlist = ', '.join(protect_names(self.columns))
+        start_token, end_token = token_range
+        query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(self.ks), protect_name(self.cf))
+        if start_token is not None or end_token is not None:
+            query += ' WHERE'
+        if start_token is not None:
+            query += ' token(%s) > %s' % (pk_cols, start_token)
+        if start_token is not None and end_token is not None:
+            query += ' AND'
+        if end_token is not None:
+            query += ' token(%s) <= %s' % (pk_cols, end_token)
+        return query
+
+
+class RateMeter(object):
+    """Count processed records and periodically print the processing rate to stdout."""
+    def __init__(self, log_threshold):
+        self.log_threshold = log_threshold  # number of records after which we log
+        self.last_checkpoint_time = time.time()  # last time we logged
+        self.current_rate = 0.0  # rows per second
+        self.current_record = 0  # number of records since we last logged
+        self.total_records = 0  # total number of records
+
+    def increment(self, n=1):
+        self.current_record += n
+
+        if self.current_record >= self.log_threshold:
+            self.update()
+            self.log()
+
+    def update(self):
+        new_checkpoint_time = time.time()
+        time_difference = new_checkpoint_time - self.last_checkpoint_time
+        if time_difference != 0.0:
+            self.current_rate = self.get_new_rate(self.current_record / time_difference)
+
+        self.last_checkpoint_time = new_checkpoint_time
+        self.total_records += self.current_record
+        self.current_record = 0
+
+    def get_new_rate(self, new_rate):
+        """
+        return the previous rate averaged with the new rate to smooth a bit
+        """
+        if self.current_rate == 0.0:
+            return new_rate
+        else:
+            return (self.current_rate + new_rate) / 2.0
+
+    def log(self):
+        output = 'Processed %d rows; Written: %f rows/s\r' % (self.total_records, self.current_rate,)
+        sys.stdout.write(output)
+        sys.stdout.flush()
+
+    def get_total_records(self):
+        self.update()
+        self.log()
+        return self.total_records
http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/pylib/cqlshlib/displaying.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/displaying.py b/pylib/cqlshlib/displaying.py
index f3a016e..7b260c2 100644
--- a/pylib/cqlshlib/displaying.py
+++ b/pylib/cqlshlib/displaying.py
@@ -28,11 +28,19 @@ ANSI_RESET = '\033[0m'
def colorme(bval, colormap, colorkey):
    """
    Wrap bval in the color that colormap assigns to colorkey.

    With NO_COLOR_MAP the plain string is returned unchanged; with None the
    default value colors are used. Otherwise a FormattedValue is returned.
    """
    if colormap is NO_COLOR_MAP:
        return bval
    palette = DEFAULT_VALUE_COLORS if colormap is None else colormap
    return FormattedValue(bval, palette[colorkey] + bval + palette['reset'])
def get_str(val):
    """Return the uncolored string of a FormattedValue; any other value is returned as is."""
    return val.strval if isinstance(val, FormattedValue) else val
+
+
class FormattedValue:
def __init__(self, strval, coloredval=None, displaywidth=None):
@@ -112,3 +120,5 @@ COLUMN_NAME_COLORS = defaultdict(lambda: MAGENTA,
blob=DARK_MAGENTA,
reset=ANSI_RESET,
)
+
# Sentinel colormap, compared by identity ("is"): functions that receive it,
# such as colorme, return plain strings instead of colored values.
NO_COLOR_MAP = {}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/1b629c10/pylib/cqlshlib/formatting.py
----------------------------------------------------------------------
diff --git a/pylib/cqlshlib/formatting.py b/pylib/cqlshlib/formatting.py
index 79e661b..54dde0f 100644
--- a/pylib/cqlshlib/formatting.py
+++ b/pylib/cqlshlib/formatting.py
@@ -14,13 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import binascii
import sys
import re
import calendar
import math
from collections import defaultdict
from . import wcwidth
-from .displaying import colorme, FormattedValue, DEFAULT_VALUE_COLORS
+from .displaying import colorme, get_str, FormattedValue, DEFAULT_VALUE_COLORS, NO_COLOR_MAP
from cassandra.cqltypes import EMPTY
from cassandra.util import datetime_from_timestamp
from util import UTC
@@ -83,7 +84,6 @@ def color_text(bval, colormap, displaywidth=None):
# adding the smarts to handle that in to FormattedValue, we just
# make an explicit check to see if a null colormap is being used or
# not.
-
if displaywidth is None:
displaywidth = len(bval)
tbr = _make_turn_bits_red_f(colormap['blob'], colormap['text'])
@@ -97,7 +97,7 @@ def format_value_default(val, colormap, **_):
val = str(val)
escapedval = val.replace('\\', '\\\\')
bval = controlchars_re.sub(_show_control_chars, escapedval)
- return color_text(bval, colormap)
+ return bval if colormap is NO_COLOR_MAP else color_text(bval, colormap)
# Mapping cql type base names ("int", "map", etc) to formatter functions,
# making format_value a generic function
@@ -111,6 +111,10 @@ def format_value(type, val, **kwargs):
return formatter(val, **kwargs)
def get_formatter(type):
    """
    Return the formatter registered under the class name of type, or
    format_value_default when none is registered for it.
    """
    type_name = type.__name__
    return _formatters.get(type_name, format_value_default)
+
+
def formatter_for(typname):
def registrator(f):
_formatters[typname] = f
@@ -120,7 +124,7 @@ def formatter_for(typname):
@formatter_for('bytearray')
def format_value_blob(val, colormap, **_):
    """Format a blob as '0x' followed by its lowercase hex digits, colored as a blob."""
    hex_digits = binascii.hexlify(val)
    return colorme('0x' + hex_digits, colormap, 'blob')
formatter_for('buffer')(format_value_blob)
@@ -204,8 +208,8 @@ def format_value_text(val, encoding, colormap, quote=False, **_):
bval = escapedval.encode(encoding, 'backslashreplace')
if quote:
bval = "'%s'" % bval
- displaywidth = wcwidth.wcswidth(bval.decode(encoding))
- return color_text(bval, colormap, displaywidth)
+
+ return bval if colormap is NO_COLOR_MAP else color_text(bval, colormap, wcwidth.wcswidth(bval.decode(encoding)))
# name alias
formatter_for('unicode')(format_value_text)
@@ -217,7 +221,10 @@ def format_simple_collection(val, lbracket, rbracket, encoding,
time_format=time_format, float_precision=float_precision,
nullval=nullval, quote=True)
for sval in val]
- bval = lbracket + ', '.join(sval.strval for sval in subs) + rbracket
+ bval = lbracket + ', '.join(get_str(sval) for sval in subs) + rbracket
+ if colormap is NO_COLOR_MAP:
+ return bval
+
lb, sep, rb = [colormap['collection'] + s + colormap['reset']
for s in (lbracket, ', ', rbracket)]
coloredval = lb + sep.join(sval.coloredval for sval in subs) + rb
@@ -242,6 +249,9 @@ def format_value_set(val, encoding, colormap, time_format, float_precision, null
return format_simple_collection(sorted(val), '{', '}', encoding, colormap,
time_format, float_precision, nullval)
formatter_for('frozenset')(format_value_set)
# Sorted sets are formatted like sets. The driver bundled with cqlsh (2.7.2)
# returns 'sortedset' instances, while the dtests use whichever driver is
# installed on the machine (e.g. 3.0.0), which returns 'SortedSet' instances,
# so both class names are registered.
formatter_for('SortedSet')(format_value_set)
formatter_for('sortedset')(format_value_set)
@@ -253,7 +263,10 @@ def format_value_map(val, encoding, colormap, time_format, float_precision, null
nullval=nullval, quote=True)
subs = [(subformat(k), subformat(v)) for (k, v) in sorted(val.items())]
- bval = '{' + ', '.join(k.strval + ': ' + v.strval for (k, v) in subs) + '}'
+ bval = '{' + ', '.join(get_str(k) + ': ' + get_str(v) for (k, v) in subs) + '}'
+ if colormap is NO_COLOR_MAP:
+ return bval
+
lb, comma, colon, rb = [colormap['collection'] + s + colormap['reset']
for s in ('{', ', ', ': ', '}')]
coloredval = lb \
@@ -278,7 +291,10 @@ def format_value_utype(val, encoding, colormap, time_format, float_precision, nu
return format_value_text(name, encoding=encoding, colormap=colormap, quote=False)
subs = [(format_field_name(k), format_field_value(v)) for (k, v) in val._asdict().items()]
- bval = '{' + ', '.join(k.strval + ': ' + v.strval for (k, v) in subs) + '}'
+ bval = '{' + ', '.join(get_str(k) + ': ' + get_str(v) for (k, v) in subs) + '}'
+ if colormap is NO_COLOR_MAP:
+ return bval
+
lb, comma, colon, rb = [colormap['collection'] + s + colormap['reset']
for s in ('{', ', ', ': ', '}')]
coloredval = lb \
[2/5] cassandra git commit: Merge branch 'cassandra-2.1' into
cassandra-2.2
Posted by ty...@apache.org.
http://git-wip-us.apache.org/repos/asf/cassandra/blob/d2f243ee/pylib/cqlshlib/copy.py
----------------------------------------------------------------------
diff --cc pylib/cqlshlib/copy.py
index 0000000,8534b98..8ff474f
mode 000000,100644..100644
--- a/pylib/cqlshlib/copy.py
+++ b/pylib/cqlshlib/copy.py
@@@ -1,0 -1,644 +1,647 @@@
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import csv
+ import json
+ import multiprocessing as mp
+ import os
+ import Queue
++import random
+ import sys
+ import time
+ import traceback
+
+ from StringIO import StringIO
+ from random import randrange
+ from threading import Lock
+
+ from cassandra.cluster import Cluster
+ from cassandra.metadata import protect_name, protect_names
+ from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy
+ from cassandra.query import tuple_factory
+
+
+ import sslhandling
+ from displaying import NO_COLOR_MAP
-from formatting import format_value_default, EMPTY, get_formatter
++from formatting import format_value_default, DateTimeFormat, EMPTY, get_formatter
+
+
def parse_options(shell, opts):
    """
    Split the options of an import (COPY FROM) or export (COPY TO) operation
    into csv options and csv dialect options.

    Recognized options are popped from opts, so whatever is left in it
    afterwards was not recognized.

    :param shell: the cqlsh shell, providing the default dialect and display formats
    :param opts: dict of option name -> string value, modified in place
    :return: 3 dictionaries: the csv options, the dialect options, any unrecognized options.
    """
    dialect_options = shell.csv_dialect_defaults.copy()
    for opt_name, dialect_key in (('quote', 'quotechar'), ('escape', 'escapechar'), ('delimiter', 'delimiter')):
        if opt_name in opts:
            dialect_options[dialect_key] = opts.pop(opt_name)
    if dialect_options['quotechar'] == dialect_options['escapechar']:
        # a quote escaping itself is expressed to the csv module as doublequote
        dialect_options['doublequote'] = True
        del dialect_options['escapechar']

    nullval = opts.pop('null', '')
    header = opts.pop('header', '').lower() == 'true'
    encoding = opts.pop('encoding', 'utf8')
    jobs = int(opts.pop('jobs', 12))
    pagesize = int(opts.pop('pagesize', 1000))
    # default page timeout: 10 seconds per 1000 rows of page size, and never less than 10 seconds
    pagetimeout = int(opts.pop('pagetimeout', max(10, 10 * (pagesize / 1000))))
    maxattempts = int(opts.pop('maxattempts', 5))
    float_precision = shell.display_float_precision
    dtformats = DateTimeFormat(opts.pop('timeformat', shell.display_timestamp_format),
                               shell.display_date_format,
                               shell.display_nanotime_format)

    csv_options = dict()
    csv_options['nullval'] = nullval
    csv_options['header'] = header
    csv_options['encoding'] = encoding
    csv_options['jobs'] = jobs
    csv_options['pagesize'] = pagesize
    csv_options['pagetimeout'] = pagetimeout
    csv_options['maxattempts'] = maxattempts
    csv_options['float_precision'] = float_precision
    csv_options['dtformats'] = dtformats

    return csv_options, dialect_options, opts
+
+
def get_num_processes(cap):
    """
    Choose how many child processes to start: one fewer than the number of
    CPUs, so that a core is left for the parent process, bounded to [1, cap].

    Callers pass the cap: 4 is currently enough to keep a single local Cassandra
    node busy for import, whilst export uses up to 16 since it can connect to
    multiple Cassandra nodes. Eventually this parameter will become an option.

    :param cap: the maximum number of processes wanted
    :return: the number of child processes, 1 if the CPU count is unavailable
    """
    try:
        spare_cpus = mp.cpu_count() - 1
    except NotImplementedError:
        return 1
    return max(1, min(cap, spare_cpus))
+
+
class ExportTask(object):
    """
    A class that exports data to .csv by instantiating one or more processes that work in parallel (ExportProcess).
    """
    def __init__(self, shell, ks, cf, columns, fname, csv_options, dialect_options, protocol_version, config_file):
        self.shell = shell
        self.csv_options = csv_options
        self.dialect_options = dialect_options
        self.ks = ks
        self.cf = cf
        # export every column of the table unless an explicit column list was given
        self.columns = shell.get_column_names(ks, cf) if columns is None else columns
        self.fname = fname
        self.protocol_version = protocol_version
        self.config_file = config_file

    def run(self):
        """
        Initiates the export: opens the destination (stdout when no file name was
        given), writes the optional header, starts the child processes and collects
        their output.

        Child processes, queues and the destination file are released even when
        setting up the export fails part way through.

        :return: the number of rows exported, 0 if the destination file cannot be opened
        """
        shell = self.shell
        fname = self.fname

        if fname is None:
            do_close = False
            csvdest = sys.stdout
        else:
            do_close = True
            try:
                csvdest = open(fname, 'wb')
            except IOError as e:
                shell.printerr("Can't open %r for writing: %s" % (fname, e))
                return 0

        inmsg = None
        outmsg = None
        processes = []
        try:
            if self.csv_options['header']:
                writer = csv.writer(csvdest, **self.dialect_options)
                writer.writerow(self.columns)

            ranges = self.get_ranges()
            num_processes = get_num_processes(cap=min(16, len(ranges)))

            inmsg = mp.Queue()
            outmsg = mp.Queue()
            for _ in range(num_processes):
                # the children read work from our outmsg and report back on our inmsg
                process = ExportProcess(outmsg, inmsg, self.ks, self.cf, self.columns, self.dialect_options,
                                        self.csv_options, shell.debug, shell.port, shell.conn.cql_version,
                                        shell.auth_provider, shell.ssl, self.protocol_version, self.config_file)
                process.start()
                processes.append(process)

            return self.check_processes(csvdest, ranges, inmsg, outmsg, processes)
        finally:
            for process in processes:
                process.terminate()
            for msg_queue in (inmsg, outmsg):
                if msg_queue is not None:
                    msg_queue.close()
            if do_close:
                csvdest.close()

    def get_ranges(self):
        """
        Return a dict mapping each token range (start, end] to a dict with the hosts
        that own it ('hosts'), the attempts made so far ('attempts') and the rows
        received so far ('rows'). A start or end of None leaves that side unbounded.

        The ring comes from the driver metadata token map. We only consider replicas
        in the local datacenter; if a range has no local replica we use the cqlsh
        session host. When the partitioner does not support tokens, or there is no
        token map, the whole ring is a single range (None, None) queried through
        the session host.
        """
        shell = self.shell
        hostname = shell.hostname
        ranges = dict()

        def make_range(hosts):
            return {'hosts': tuple(hosts), 'attempts': 0, 'rows': 0}

        min_token = self.get_min_token()
        if shell.conn.metadata.token_map is None or min_token is None:
            ranges[(None, None)] = make_range([hostname])
            return ranges

        local_dc = shell.conn.metadata.get_host(hostname).datacenter
        ring = sorted(shell.get_ring(self.ks).items())

        first_hosts = None
        previous = None
        for token, replicas in ring:
            if previous is None and token.value == min_token:
                continue  # avoids looping entire ring

            hosts = [host.address for host in replicas if host.datacenter == local_dc]
            if not hosts:
                hosts.append(hostname)  # fallback to default host if no replicas in current dc
            ranges[(previous, token.value)] = make_range(hosts)
            if first_hosts is None:
                first_hosts = hosts
            previous = token.value

        if not ranges:
            # If the ring is empty we get the entire ring from the host we are currently connected to
            ranges[(None, None)] = make_range([hostname])
        else:
            # The last interval wraps around the ring past the maximum token, so it is
            # owned by the replicas of the first token in the ring, not of the last one
            ranges[(previous, None)] = make_range(first_hosts)

        return ranges

    def get_min_token(self):
        """
        :return the minimum token, which depends on the partitioner.
        For partitioners that do not support tokens, or when the partitioner is
        not known (None), we return None. In this case we do not work in
        parallel; we send all requests to the cqlsh session host.
        """
        partitioner = self.shell.conn.metadata.partitioner
        if partitioner is None:
            return None

        if partitioner.endswith('RandomPartitioner'):
            return -1
        elif partitioner.endswith('Murmur3Partitioner'):
            return -(2 ** 63)  # Long.MIN_VALUE in Java
        else:
            return None

    @staticmethod
    def send_work(ranges, tokens_to_send, queue):
        """
        Queue each token range in tokens_to_send as (token_range, info) and count
        the attempt. info is a snapshot taken before counting: an mp.Queue pickles
        objects later on a feeder thread, so putting the live dict would let the
        worker see a racy attempt count.
        """
        for token_range in tokens_to_send:
            info = ranges[token_range]
            queue.put((token_range, dict(info)))
            info['attempts'] += 1

    def check_processes(self, csvdest, ranges, inmsg, outmsg, processes):
        """
        Here we monitor all child processes by collecting their results
        or any errors. Received csv data is written to csvdest, and failed ranges
        are resent up to the 'maxattempts' option, but only while no rows have been
        received for them. We terminate when we have processed all the ranges or as
        soon as a child process has died.

        :return: the total number of rows exported
        """
        shell = self.shell
        meter = RateMeter(10000)
        total_jobs = len(ranges)
        max_attempts = self.csv_options['maxattempts']

        self.send_work(ranges, list(ranges), outmsg)

        num_processes = len(processes)
        succeeded = 0
        failed = 0
        while (failed + succeeded) < total_jobs and self.num_live_processes(processes) == num_processes:
            try:
                token_range, result = inmsg.get(timeout=1.0)
            except Queue.Empty:
                continue

            if token_range is None and result is None:  # a job has finished
                succeeded += 1
            elif isinstance(result, Exception):  # an error occurred
                if token_range is None:  # the entire process failed
                    shell.printerr('Error from worker process: %s' % (result))
                else:
                    # only this token_range failed: retry up to max_attempts if no rows were received
                    # yet, since if rows were received we risk duplicating data; the worker process
                    # also has a back-off policy in place, see ExpBackoffRetryPolicy
                    info = ranges[token_range]
                    if info['attempts'] < max_attempts and info['rows'] == 0:
                        shell.printerr('Error for %s: %s (will try again later attempt %d of %d)'
                                       % (token_range, result, info['attempts'], max_attempts))
                        self.send_work(ranges, [token_range], outmsg)
                    else:
                        shell.printerr('Error for %s: %s (permanently given up after %d rows and %d attempts)'
                                       % (token_range, result, info['rows'], info['attempts']))
                        failed += 1
            else:  # partial result received: (csv data, number of rows)
                data, num = result
                csvdest.write(data)
                meter.increment(n=num)
                ranges[token_range]['rows'] += num

        if self.num_live_processes(processes) < len(processes):
            for process in processes:
                if not process.is_alive():
                    shell.printerr('Child process %d died with exit code %d' % (process.pid, process.exitcode))

        if succeeded < total_jobs:
            shell.printerr('Exported %d ranges out of %d total ranges, some records might be missing'
                           % (succeeded, total_jobs))

        return meter.get_total_records()

    @staticmethod
    def num_live_processes(processes):
        return sum(1 for p in processes if p.is_alive())
+
+
class ExpBackoffRetryPolicy(RetryPolicy):
    """
    A retry policy for the export worker processes (see ExportProcess): read
    timeouts are retried after a random, exponentially growing delay, until
    the 'maxattempts' option is exhausted.
    """
    def __init__(self, export_process):
        RetryPolicy.__init__(self)
        self.max_attempts = export_process.csv_options['maxattempts']
        self.printmsg = lambda txt: export_process.printmsg(txt)

    def on_read_timeout(self, query, consistency, required_responses,
                        received_responses, data_retrieved, retry_num):
        """Retry at the same consistency after the back-off delay, or rethrow when out of attempts."""
        delay = self.backoff(retry_num)
        if delay < 0:
            self.printmsg("Timeout received, giving up after %d attempts" % (retry_num + 1))
            return self.RETHROW, None

        if delay == 0:
            self.printmsg("Timeout received, retrying immediately")
        else:
            self.printmsg("Timeout received, retrying after %d seconds" % (delay))
            time.sleep(delay)
        return self.RETRY, consistency

    def backoff(self, retry_num):
        """
        Perform exponential back-off up to a maximum number of times per query:
        wait a random number of seconds between 0 and 2^c - 1, where c is the
        total number of failures so far (retry_num + 1).

        :return : the number of seconds to wait for, -1 if we should not retry
        """
        if retry_num >= self.max_attempts:
            return -1
        # randrange excludes its upper bound, which provides the "- 1"
        return randrange(0, 2 ** (retry_num + 1))
+
+
class ExportSession(object):
    """
    A connection to the cluster plus a thread-safe count of the jobs (token
    ranges) being processed through it. It wraps asynchronous query execution
    and the shutdown of the cluster connection.
    """
    def __init__(self, cluster, export_process):
        options = export_process.csv_options
        session = cluster.connect(export_process.ks)
        session.row_factory = tuple_factory
        session.default_fetch_size = options['pagesize']
        session.default_timeout = options['pagetimeout']

        export_process.printmsg("Created connection to %s with page size %d and timeout %d seconds per page"
                                % (session.hosts, session.default_fetch_size, session.default_timeout))

        self.cluster = cluster
        self.session = session
        # the session is created to serve a job, so that job is counted straight away
        self.jobs = 1
        self.lock = Lock()

    def _change_jobs(self, delta):
        # adjust the job counter under the lock
        with self.lock:
            self.jobs += delta

    def add_job(self):
        """Count one more job running through this connection."""
        self._change_jobs(1)

    def complete_job(self):
        """Count one job fewer running through this connection."""
        self._change_jobs(-1)

    def num_jobs(self):
        """Return the number of jobs currently running through this connection."""
        with self.lock:
            return self.jobs

    def execute_async(self, query):
        return self.session.execute_async(query)

    def shutdown(self):
        self.cluster.shutdown()
+
+
class ExportProcess(mp.Process):
    """
    A child worker process for the export task, ExportTask.

    It reads (token_range, info) work items from inmsg and queries each range
    asynchronously. On outmsg it sends (token_range, (csv_data, num_rows)) for
    every page of rows, (None, None) when a range is complete and
    (token_range, Exception) when a range fails.
    """

    def __init__(self, inmsg, outmsg, ks, cf, columns, dialect_options, csv_options,
                 debug, port, cql_version, auth_provider, ssl, protocol_version, config_file):
        mp.Process.__init__(self, target=self.run)
        self.inmsg = inmsg
        self.outmsg = outmsg
        self.ks = ks
        self.cf = cf
        self.columns = columns
        self.dialect_options = dialect_options
        self.hosts_to_sessions = dict()  # host address -> ExportSession

        # connection settings, used to create one cluster connection per host
        self.debug = debug
        self.port = port
        self.cql_version = cql_version
        self.auth_provider = auth_provider
        self.ssl = ssl
        self.protocol_version = protocol_version
        self.config_file = config_file

        self.encoding = csv_options['encoding']
        self.date_time_format = csv_options['dtformats']
        self.float_precision = csv_options['float_precision']
        self.nullval = csv_options['nullval']
        self.maxjobs = csv_options['jobs']
        self.csv_options = csv_options
        self.formatters = dict()  # python type of a value -> its formatter, filled lazily

        # Here we inject some failures for testing purposes, only if this environment
        # variable is set (to a JSON document, see maybe_inject_failures)
        test_failures = os.environ.get('CQLSH_COPY_TEST_FAILURES', '')
        self.test_failures = json.loads(test_failures) if test_failures else None

    def printmsg(self, text):
        if self.debug:
            sys.stderr.write(text + os.linesep)

    def run(self):
        try:
            self.inner_run()
        finally:
            self.close()

    def inner_run(self):
        """
        The parent sends us (range, info) on the inbound queue (inmsg)
        in order to request us to process a range, for which we can
        select any of the hosts in info, which also contains other information for this
        range such as the number of attempts already performed. We can signal errors
        on the outbound queue (outmsg) by sending (range, error) or
        we can signal a global error by sending (None, error).
        This loop never returns on its own: ExportTask terminates its worker processes
        once the export is over. While more than maxjobs jobs are in flight we wait.
        """
        while True:
            if self.num_jobs() > self.maxjobs:
                time.sleep(0.001)  # 1 millisecond
                continue

            token_range, info = self.inmsg.get()
            self.start_job(token_range, info)

    def report_error(self, err, token_range=None):
        """
        Log err and send it to the parent as (token_range, Exception(message)).
        """
        if isinstance(err, str):
            msg = err
        elif isinstance(err, BaseException):
            msg = "%s - %s" % (err.__class__.__name__, err)
            # print_exc() takes a limit, not an exception, and only has a traceback
            # to print while an exception is being handled (not in driver errbacks)
            if self.debug and sys.exc_info()[0] is not None:
                traceback.print_exc()
        else:
            msg = str(err)

        self.printmsg(msg)
        self.outmsg.put((token_range, Exception(msg)))

    def start_job(self, token_range, info):
        """
        Begin querying a range by executing an async query that
        will later on invoke the callbacks attached in attach_callbacks.
        """
        session = self.get_session(info['hosts'])
        metadata = session.cluster.metadata.keyspaces[self.ks].tables[self.cf]
        query = self.prepare_query(metadata.partition_key, token_range, info['attempts'])
        future = session.execute_async(query)
        self.attach_callbacks(token_range, future, session)

    def num_jobs(self):
        return sum(session.num_jobs() for session in self.hosts_to_sessions.values())

    def get_session(self, hosts):
        """
        We select a host to connect to. If we have no connections to one of the hosts
        yet then we select this host, else we pick the one with the smallest number
        of jobs. The job being started is counted on the returned session.

        :return: An ExportSession connected to the chosen host.
        """
        new_hosts = [h for h in hosts if h not in self.hosts_to_sessions]
        if new_hosts:
            host = new_hosts[0]
            new_cluster = Cluster(
                contact_points=(host,),
                port=self.port,
                cql_version=self.cql_version,
                protocol_version=self.protocol_version,
                auth_provider=self.auth_provider,
                ssl_options=sslhandling.ssl_settings(host, self.config_file) if self.ssl else None,
                load_balancing_policy=TokenAwarePolicy(WhiteListRoundRobinPolicy(hosts)),
                default_retry_policy=ExpBackoffRetryPolicy(self),
                compression=None,
                executor_threads=max(2, self.csv_options['jobs'] // 2))

            # a new ExportSession already counts the job it is created for
            session = ExportSession(new_cluster, self)
            self.hosts_to_sessions[host] = session
            return session
        else:
            host = min(hosts, key=lambda h: self.hosts_to_sessions[h].num_jobs())
            session = self.hosts_to_sessions[host]
            session.add_job()
            return session

    def attach_callbacks(self, token_range, future, session):
        def result_callback(rows):
            if future.has_more_pages:
                # request the next page before converting the current one
                future.start_fetching_next_page()
                self.write_rows_to_csv(token_range, rows)
            else:
                self.write_rows_to_csv(token_range, rows)
                self.outmsg.put((None, None))
                session.complete_job()

        def err_callback(err):
            self.report_error(err, token_range)
            session.complete_job()

        future.add_callbacks(callback=result_callback, errback=err_callback)

    def write_rows_to_csv(self, token_range, rows):
        """
        Format rows as csv and send (token_range, (csv_data, num_rows)) to the parent;
        formatting errors are reported for token_range instead.
        """
        if not rows:
            return  # no rows in this range

        output = StringIO()
        try:
            writer = csv.writer(output, **self.dialect_options)
            for row in rows:
                writer.writerow(map(self.format_value, row))

            data = (output.getvalue(), len(rows))
            self.outmsg.put((token_range, data))
        except Exception as e:
            self.report_error(e, token_range)
        finally:
            output.close()

    def format_value(self, val):
        """Format one value as a plain (uncolored) csv field."""
        if val is None or val == EMPTY:
            return format_value_default(self.nullval, colormap=NO_COLOR_MAP)

        ctype = type(val)
        formatter = self.formatters.get(ctype, None)
        if not formatter:
            formatter = get_formatter(ctype)
            self.formatters[ctype] = formatter

        return formatter(val, encoding=self.encoding, colormap=NO_COLOR_MAP, date_time_format=self.date_time_format,
                         float_precision=self.float_precision, nullval=self.nullval, quote=False)

    def close(self):
        self.printmsg("Export process terminating...")
        self.inmsg.close()
        self.outmsg.close()
        for session in self.hosts_to_sessions.values():
            session.shutdown()
        self.printmsg("Export process terminated")

    def prepare_query(self, partition_key, token_range, attempts):
        """
        Return the export query or a fake query with some failure injected.
        """
        if self.test_failures:
            return self.maybe_inject_failures(partition_key, token_range, attempts)
        else:
            return self.prepare_export_query(partition_key, token_range)

    def maybe_inject_failures(self, partition_key, token_range, attempts):
        """
        Examine self.test_failures and see if token_range is either a token range
        supposed to cause a failure (failing_range) or to terminate the worker process
        (exit_range). If not then call prepare_export_query(), which implements the
        normal behavior.
        """
        start_token, end_token = token_range

        # exclude first and last ranges (unbounded on one side) to make things simpler;
        # compare with None since 0 is a valid token
        if start_token is None or end_token is None:
            return self.prepare_export_query(partition_key, token_range)

        if 'failing_range' in self.test_failures:
            failing_range = self.test_failures['failing_range']
            if start_token >= failing_range['start'] and end_token <= failing_range['end']:
                if attempts < failing_range['num_failures']:
                    return 'SELECT * from bad_table'

        if 'exit_range' in self.test_failures:
            exit_range = self.test_failures['exit_range']
            if start_token >= exit_range['start'] and end_token <= exit_range['end']:
                sys.exit(1)

        return self.prepare_export_query(partition_key, token_range)

    def prepare_export_query(self, partition_key, token_range):
        """
        Return a query where we select all the data for this token range
        """
        pk_cols = ", ".join(protect_names(col.name for col in partition_key))
        columnlist = ', '.join(protect_names(self.columns))
        start_token, end_token = token_range
        query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(self.ks), protect_name(self.cf))
        if start_token is not None or end_token is not None:
            query += ' WHERE'
        if start_token is not None:
            query += ' token(%s) > %s' % (pk_cols, start_token)
        if start_token is not None and end_token is not None:
            query += ' AND'
        if end_token is not None:
            query += ' token(%s) <= %s' % (pk_cols, end_token)
        return query
+
+
class RateMeter(object):
    """
    Count processed records and, every log_threshold records, print the running
    total together with a smoothed rate in rows per second.
    """

    def __init__(self, log_threshold):
        # number of records to accumulate before refreshing the rate and logging
        self.log_threshold = log_threshold
        # wall-clock time of the most recent update
        self.last_checkpoint_time = time.time()
        # smoothed rate, in rows per second
        self.current_rate = 0.0
        # records counted since the most recent update
        self.current_record = 0
        # records folded in by previous updates
        self.total_records = 0

    def increment(self, n=1):
        """Count n more records; update and log once the threshold is reached."""
        self.current_record += n
        if self.current_record < self.log_threshold:
            return
        self.update()
        self.log()

    def update(self):
        """Fold the pending records into the total and refresh the rate."""
        now = time.time()
        elapsed = now - self.last_checkpoint_time
        if elapsed != 0.0:
            instant_rate = self.current_record / elapsed
            self.current_rate = self.get_new_rate(instant_rate)

        self.last_checkpoint_time = now
        self.total_records += self.current_record
        self.current_record = 0

    def get_new_rate(self, new_rate):
        """
        Return new_rate averaged with the previous rate to smooth fluctuations;
        the very first measurement is returned unchanged.
        """
        if self.current_rate == 0.0:
            return new_rate
        return (self.current_rate + new_rate) / 2.0

    def log(self):
        """Write the progress so far, ending with a carriage return so the next log overwrites it."""
        progress = 'Processed %d rows; Written: %f rows/s\r' % (self.total_records, self.current_rate,)
        sys.stdout.write(progress)
        sys.stdout.flush()

    def get_total_records(self):
        """Fold in any pending records, log a last progress line and return the total."""
        self.update()
        self.log()
        return self.total_records
http://git-wip-us.apache.org/repos/asf/cassandra/blob/d2f243ee/pylib/cqlshlib/displaying.py
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/cassandra/blob/d2f243ee/pylib/cqlshlib/formatting.py
----------------------------------------------------------------------
diff --cc pylib/cqlshlib/formatting.py
index fe1786a,54dde0f..8b66bce
--- a/pylib/cqlshlib/formatting.py
+++ b/pylib/cqlshlib/formatting.py
@@@ -14,16 -14,14 +14,16 @@@
# See the License for the specific language governing permissions and
# limitations under the License.
+ import binascii
-import sys
-import re
import calendar
import math
- import platform
+import re
+import sys
+import platform
from collections import defaultdict
+
from . import wcwidth
- from .displaying import colorme, FormattedValue, DEFAULT_VALUE_COLORS
+ from .displaying import colorme, get_str, FormattedValue, DEFAULT_VALUE_COLORS, NO_COLOR_MAP
from cassandra.cqltypes import EMPTY
from cassandra.util import datetime_from_timestamp
from util import UTC
@@@ -239,12 -216,15 +242,15 @@@ formatter_for('unicode')(format_value_t
def format_simple_collection(val, lbracket, rbracket, encoding,
- colormap, time_format, float_precision, nullval):
+ colormap, date_time_format, float_precision, nullval):
subs = [format_value(type(sval), sval, encoding=encoding, colormap=colormap,
- time_format=time_format, float_precision=float_precision,
+ date_time_format=date_time_format, float_precision=float_precision,
nullval=nullval, quote=True)
for sval in val]
- bval = lbracket + ', '.join(sval.strval for sval in subs) + rbracket
+ bval = lbracket + ', '.join(get_str(sval) for sval in subs) + rbracket
+ if colormap is NO_COLOR_MAP:
+ return bval
+
lb, sep, rb = [colormap['collection'] + s + colormap['reset']
for s in (lbracket, ', ', rbracket)]
coloredval = lb + sep.join(sval.coloredval for sval in subs) + rb
[5/5] cassandra git commit: Merge branch 'cassandra-2.2' into
cassandra-3.0
Posted by ty...@apache.org.
Merge branch 'cassandra-2.2' into cassandra-3.0
Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/74070ee4
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/74070ee4
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/74070ee4
Branch: refs/heads/cassandra-3.0
Commit: 74070ee4aba7d8d735cbe2ce5453801a0563b043
Parents: 12fd5d2 d2f243e
Author: Tyler Hobbs <ty...@gmail.com>
Authored: Wed Nov 18 18:31:31 2015 -0600
Committer: Tyler Hobbs <ty...@gmail.com>
Committed: Wed Nov 18 18:31:31 2015 -0600
----------------------------------------------------------------------
CHANGES.txt | 1 +
bin/cqlsh.py | 153 +++------
pylib/cqlshlib/copy.py | 647 ++++++++++++++++++++++++++++++++++++++
pylib/cqlshlib/displaying.py | 10 +
pylib/cqlshlib/formatting.py | 32 +-
5 files changed, 725 insertions(+), 118 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/cassandra/blob/74070ee4/CHANGES.txt
----------------------------------------------------------------------
diff --cc CHANGES.txt
index 4510462,4e19b23..2a9fbe7
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@@ -9,8 -3,19 +9,9 @@@ Merged from 2.2
* Fix SimpleDateType type compatibility (CASSANDRA-10027)
* (Hadoop) fix splits calculation (CASSANDRA-10640)
* (Hadoop) ensure that Cluster instances are always closed (CASSANDRA-10058)
- * (cqlsh) show partial trace if incomplete after max_trace_wait (CASSANDRA-7645)
- * Use most up-to-date version of schema for system tables (CASSANDRA-10652)
- * Deprecate memory_allocator in cassandra.yaml (CASSANDRA-10581,10628)
- * Expose phi values from failure detector via JMX and tweak debug
- and trace logging (CASSANDRA-9526)
- * Fix RangeNamesQueryPager (CASSANDRA-10509)
- * Deprecate Pig support (CASSANDRA-10542)
- * Reduce contention getting instances of CompositeType (CASSANDRA-10433)
Merged from 2.1:
- * * (cqlsh) Improve COPY TO performance and error handling (CASSANDRA-9304)
- * Don't remove level info when running upgradesstables (CASSANDRA-10692)
++ * (cqlsh) Improve COPY TO performance and error handling (CASSANDRA-9304)
* Create compression chunk for sending file only (CASSANDRA-10680)
- * Make buffered read size configurable (CASSANDRA-10249)
* Forbid compact clustering column type changes in ALTER TABLE (CASSANDRA-8879)
* Reject incremental repair with subrange repair (CASSANDRA-10422)
* Add a nodetool command to refresh size_estimates (CASSANDRA-9579)
http://git-wip-us.apache.org/repos/asf/cassandra/blob/74070ee4/bin/cqlsh.py
----------------------------------------------------------------------
diff --cc bin/cqlsh.py
index 33533c5,94c7af3..793afe5
--- a/bin/cqlsh.py
+++ b/bin/cqlsh.py
@@@ -2394,11 -2259,11 +2328,11 @@@ class ImportProcess(mp.Process)
table_meta = new_cluster.metadata.keyspaces[self.ks].tables[self.cf]
pk_cols = [col.name for col in table_meta.primary_key]
- cqltypes = [table_meta.columns[name].typestring for name in self.columns]
+ cqltypes = [table_meta.columns[name].cql_type for name in self.columns]
pk_indexes = [self.columns.index(col.name) for col in table_meta.primary_key]
query = 'INSERT INTO %s.%s (%s) VALUES (%%s)' % (
- protect_name(table_meta.keyspace_name),
- protect_name(table_meta.name),
+ protect_name(self.ks),
+ protect_name(self.cf),
', '.join(protect_names(self.columns)))
# we need to handle some types specially
[4/5] cassandra git commit: Merge branch 'cassandra-2.1' into
cassandra-2.2
Posted by ty...@apache.org.
Merge branch 'cassandra-2.1' into cassandra-2.2
Project: http://git-wip-us.apache.org/repos/asf/cassandra/repo
Commit: http://git-wip-us.apache.org/repos/asf/cassandra/commit/d2f243ee
Tree: http://git-wip-us.apache.org/repos/asf/cassandra/tree/d2f243ee
Diff: http://git-wip-us.apache.org/repos/asf/cassandra/diff/d2f243ee
Branch: refs/heads/cassandra-3.0
Commit: d2f243ee5c656eb51053e553d7371802255dfc54
Parents: d09b6c6 1b629c1
Author: Tyler Hobbs <ty...@gmail.com>
Authored: Wed Nov 18 18:25:18 2015 -0600
Committer: Tyler Hobbs <ty...@gmail.com>
Committed: Wed Nov 18 18:25:18 2015 -0600
----------------------------------------------------------------------
CHANGES.txt | 1 +
bin/cqlsh.py | 153 +++------
pylib/cqlshlib/copy.py | 647 ++++++++++++++++++++++++++++++++++++++
pylib/cqlshlib/displaying.py | 10 +
pylib/cqlshlib/formatting.py | 32 +-
5 files changed, 725 insertions(+), 118 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/cassandra/blob/d2f243ee/CHANGES.txt
----------------------------------------------------------------------
diff --cc CHANGES.txt
index c3dacc2,42dcf3e..4e19b23
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@@ -1,17 -1,5 +1,18 @@@
-2.1.12
- * (cqlsh) Improve COPY TO performance and error handling (CASSANDRA-9304)
+2.2.4
+ * Don't do anticompaction after subrange repair (CASSANDRA-10422)
+ * Fix SimpleDateType type compatibility (CASSANDRA-10027)
+ * (Hadoop) fix splits calculation (CASSANDRA-10640)
+ * (Hadoop) ensure that Cluster instances are always closed (CASSANDRA-10058)
+ * (cqlsh) show partial trace if incomplete after max_trace_wait (CASSANDRA-7645)
+ * Use most up-to-date version of schema for system tables (CASSANDRA-10652)
+ * Deprecate memory_allocator in cassandra.yaml (CASSANDRA-10581,10628)
+ * Expose phi values from failure detector via JMX and tweak debug
+ and trace logging (CASSANDRA-9526)
+ * Fix RangeNamesQueryPager (CASSANDRA-10509)
+ * Deprecate Pig support (CASSANDRA-10542)
+ * Reduce contention getting instances of CompositeType (CASSANDRA-10433)
+Merged from 2.1:
++ * * (cqlsh) Improve COPY TO performance and error handling (CASSANDRA-9304)
* Don't remove level info when running upgradesstables (CASSANDRA-10692)
* Create compression chunk for sending file only (CASSANDRA-10680)
* Make buffered read size configurable (CASSANDRA-10249)
[3/5] cassandra git commit: Merge branch 'cassandra-2.1' into
cassandra-2.2
Posted by ty...@apache.org.
http://git-wip-us.apache.org/repos/asf/cassandra/blob/d2f243ee/bin/cqlsh.py
----------------------------------------------------------------------
diff --cc bin/cqlsh.py
index 64536e1,0000000..94c7af3
mode 100644,000000..100644
--- a/bin/cqlsh.py
+++ b/bin/cqlsh.py
@@@ -1,2737 -1,0 +1,2674 @@@
+#!/bin/sh
+# -*- mode: Python -*-
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""":"
+# bash code here; finds a suitable python interpreter and execs this file.
+# prefer unqualified "python" if suitable:
+python -c 'import sys; sys.exit(not (0x020500b0 < sys.hexversion < 0x03000000))' 2>/dev/null \
+ && exec python "$0" "$@"
+for pyver in 2.6 2.7 2.5; do
+ which python$pyver > /dev/null 2>&1 && exec python$pyver "$0" "$@"
+done
+echo "No appropriate python interpreter found." >&2
+exit 1
+":"""
+
+from __future__ import with_statement
+
+import cmd
+import codecs
+import ConfigParser
+import csv
+import getpass
+import locale
- import multiprocessing
++import multiprocessing as mp
+import optparse
+import os
+import platform
+import sys
+import time
+import traceback
+import warnings
+from contextlib import contextmanager
+from functools import partial
+from glob import glob
+from StringIO import StringIO
+from uuid import UUID
+
+if sys.version_info[0] != 2 or sys.version_info[1] != 7:
+ sys.exit("\nCQL Shell supports only Python 2.7\n")
+
+description = "CQL Shell for Apache Cassandra"
+version = "5.0.1"
+
+readline = None
+try:
+ # check if tty first, cause readline doesn't check, and only cares
+ # about $TERM. we don't want the funky escape code stuff to be
+ # output if not a tty.
+ if sys.stdin.isatty():
+ import readline
+except ImportError:
+ pass
+
+CQL_LIB_PREFIX = 'cassandra-driver-internal-only-'
+
+CASSANDRA_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')
+
+# use bundled libs for python-cql and thrift, if available. if there
+# is a ../lib dir, use bundled libs there preferentially.
+ZIPLIB_DIRS = [os.path.join(CASSANDRA_PATH, 'lib')]
+myplatform = platform.system()
+if myplatform == 'Linux':
+ ZIPLIB_DIRS.append('/usr/share/cassandra/lib')
+
+if os.environ.get('CQLSH_NO_BUNDLED', ''):
+ ZIPLIB_DIRS = ()
+
+
+def find_zip(libprefix):
+ for ziplibdir in ZIPLIB_DIRS:
+ zips = glob(os.path.join(ziplibdir, libprefix + '*.zip'))
+ if zips:
+ return max(zips) # probably the highest version, if multiple
+
+cql_zip = find_zip(CQL_LIB_PREFIX)
+if cql_zip:
+ ver = os.path.splitext(os.path.basename(cql_zip))[0][len(CQL_LIB_PREFIX):]
+ sys.path.insert(0, os.path.join(cql_zip, 'cassandra-driver-' + ver))
+
+third_parties = ('futures-', 'six-')
+
+for lib in third_parties:
+ lib_zip = find_zip(lib)
+ if lib_zip:
+ sys.path.insert(0, lib_zip)
+
+warnings.filterwarnings("ignore", r".*blist.*")
+try:
+ import cassandra
+except ImportError, e:
+ sys.exit("\nPython Cassandra driver not installed, or not on PYTHONPATH.\n"
+ 'You might try "pip install cassandra-driver".\n\n'
+ 'Python: %s\n'
+ 'Module load path: %r\n\n'
+ 'Error: %s\n' % (sys.executable, sys.path, e))
+
+from cassandra.auth import PlainTextAuthProvider
+from cassandra.cluster import Cluster
+from cassandra.metadata import (ColumnMetadata, KeyspaceMetadata,
+ TableMetadata, protect_name, protect_names,
+ protect_value)
+from cassandra.policies import WhiteListRoundRobinPolicy
+from cassandra.protocol import QueryMessage, ResultMessage
+from cassandra.query import SimpleStatement, ordered_dict_factory, TraceUnavailable
+
+# cqlsh should run correctly when run out of a Cassandra source tree,
+# out of an unpacked Cassandra tarball, and after a proper package install.
+cqlshlibdir = os.path.join(CASSANDRA_PATH, 'pylib')
+if os.path.isdir(cqlshlibdir):
+ sys.path.insert(0, cqlshlibdir)
+
- from cqlshlib import cql3handling, cqlhandling, pylexotron, sslhandling
++from cqlshlib import cql3handling, cqlhandling, pylexotron, sslhandling, copy
+from cqlshlib.displaying import (ANSI_RESET, BLUE, COLUMN_NAME_COLORS, CYAN,
+ RED, FormattedValue, colorme)
+from cqlshlib.formatting import (DEFAULT_DATE_FORMAT, DEFAULT_NANOTIME_FORMAT,
+ DEFAULT_TIMESTAMP_FORMAT, DateTimeFormat,
+ format_by_type, format_value_utype,
+ formatter_for)
+from cqlshlib.tracing import print_trace, print_trace_session
+from cqlshlib.util import get_file_encoding_bomsize, trim_if_present
+
+DEFAULT_HOST = '127.0.0.1'
+DEFAULT_PORT = 9042
+DEFAULT_CQLVER = '3.3.1'
+DEFAULT_PROTOCOL_VERSION = 4
+DEFAULT_CONNECT_TIMEOUT_SECONDS = 5
+
+DEFAULT_FLOAT_PRECISION = 5
+DEFAULT_MAX_TRACE_WAIT = 10
+
+if readline is not None and readline.__doc__ is not None and 'libedit' in readline.__doc__:
+ DEFAULT_COMPLETEKEY = '\t'
+else:
+ DEFAULT_COMPLETEKEY = 'tab'
+
+cqldocs = None
+cqlruleset = None
+
+epilog = """Connects to %(DEFAULT_HOST)s:%(DEFAULT_PORT)d by default. These
+defaults can be changed by setting $CQLSH_HOST and/or $CQLSH_PORT. When a
+host (and optional port number) are given on the command line, they take
+precedence over any defaults.""" % globals()
+
+parser = optparse.OptionParser(description=description, epilog=epilog,
+ usage="Usage: %prog [options] [host [port]]",
+ version='cqlsh ' + version)
+parser.add_option("-C", "--color", action='store_true', dest='color',
+ help='Always use color output')
+parser.add_option("--no-color", action='store_false', dest='color',
+ help='Never use color output')
+parser.add_option('--ssl', action='store_true', help='Use SSL', default=False)
+parser.add_option("-u", "--username", help="Authenticate as user.")
+parser.add_option("-p", "--password", help="Authenticate using password.")
+parser.add_option('-k', '--keyspace', help='Authenticate to the given keyspace.')
+parser.add_option("-f", "--file", help="Execute commands from FILE, then exit")
+parser.add_option('--debug', action='store_true',
+ help='Show additional debugging information')
+parser.add_option("--encoding", help="Specify a non-default encoding for output. If you are " +
+ "experiencing problems with unicode characters, using utf8 may fix the problem." +
+ " (Default from system preferences: %s)" % (locale.getpreferredencoding(),))
+parser.add_option("--cqlshrc", help="Specify an alternative cqlshrc file location.")
+parser.add_option('--cqlversion', default=DEFAULT_CQLVER,
+ help='Specify a particular CQL version (default: %default).'
+ ' Examples: "3.0.3", "3.1.0"')
+parser.add_option("-e", "--execute", help='Execute the statement and quit.')
+parser.add_option("--connect-timeout", default=DEFAULT_CONNECT_TIMEOUT_SECONDS, dest='connect_timeout',
+ help='Specify the connection timeout in seconds (default: %default seconds).')
+
+optvalues = optparse.Values()
+(options, arguments) = parser.parse_args(sys.argv[1:], values=optvalues)
+
+# BEGIN history/config definition
+HISTORY_DIR = os.path.expanduser(os.path.join('~', '.cassandra'))
+
+if hasattr(options, 'cqlshrc'):
+ CONFIG_FILE = options.cqlshrc
+ if not os.path.exists(CONFIG_FILE):
+ print '\nWarning: Specified cqlshrc location `%s` does not exist. Using `%s` instead.\n' % (CONFIG_FILE, HISTORY_DIR)
+ CONFIG_FILE = os.path.join(HISTORY_DIR, 'cqlshrc')
+else:
+ CONFIG_FILE = os.path.join(HISTORY_DIR, 'cqlshrc')
+
+HISTORY = os.path.join(HISTORY_DIR, 'cqlsh_history')
+if not os.path.exists(HISTORY_DIR):
+ try:
+ os.mkdir(HISTORY_DIR)
+ except OSError:
+ print '\nWarning: Cannot create directory at `%s`. Command history will not be saved.\n' % HISTORY_DIR
+
+OLD_CONFIG_FILE = os.path.expanduser(os.path.join('~', '.cqlshrc'))
+if os.path.exists(OLD_CONFIG_FILE):
+ if os.path.exists(CONFIG_FILE):
+ print '\nWarning: cqlshrc config files were found at both the old location (%s) and \
+ the new location (%s), the old config file will not be migrated to the new \
+ location, and the new location will be used for now. You should manually \
+ consolidate the config files at the new location and remove the old file.' \
+ % (OLD_CONFIG_FILE, CONFIG_FILE)
+ else:
+ os.rename(OLD_CONFIG_FILE, CONFIG_FILE)
+OLD_HISTORY = os.path.expanduser(os.path.join('~', '.cqlsh_history'))
+if os.path.exists(OLD_HISTORY):
+ os.rename(OLD_HISTORY, HISTORY)
+# END history/config definition
+
+CQL_ERRORS = (
+ cassandra.AlreadyExists, cassandra.AuthenticationFailed, cassandra.InvalidRequest,
+ cassandra.Timeout, cassandra.Unauthorized, cassandra.OperationTimedOut,
+ cassandra.cluster.NoHostAvailable,
+ cassandra.connection.ConnectionBusy, cassandra.connection.ProtocolError, cassandra.connection.ConnectionException,
+ cassandra.protocol.ErrorMessage, cassandra.protocol.InternalError, cassandra.query.TraceUnavailable
+)
+
+debug_completion = bool(os.environ.get('CQLSH_DEBUG_COMPLETION', '') == 'YES')
+
+# we want the cql parser to understand our cqlsh-specific commands too
+my_commands_ending_with_newline = (
+ 'help',
+ '?',
+ 'consistency',
+ 'serial',
+ 'describe',
+ 'desc',
+ 'show',
+ 'source',
+ 'capture',
+ 'login',
+ 'debug',
+ 'tracing',
+ 'expand',
+ 'paging',
+ 'exit',
+ 'quit',
+ 'clear',
+ 'cls'
+)
+
+
+cqlsh_syntax_completers = []
+
+
+def cqlsh_syntax_completer(rulename, termname):
+ def registrator(f):
+ cqlsh_syntax_completers.append((rulename, termname, f))
+ return f
+ return registrator
+
+
+cqlsh_extra_syntax_rules = r'''
+<cqlshCommand> ::= <CQL_Statement>
+ | <specialCommand> ( ";" | "\n" )
+ ;
+
+<specialCommand> ::= <describeCommand>
+ | <consistencyCommand>
+ | <serialConsistencyCommand>
+ | <showCommand>
+ | <sourceCommand>
+ | <captureCommand>
+ | <copyCommand>
+ | <loginCommand>
+ | <debugCommand>
+ | <helpCommand>
+ | <tracingCommand>
+ | <expandCommand>
+ | <exitCommand>
+ | <pagingCommand>
+ | <clearCommand>
+ ;
+
+<describeCommand> ::= ( "DESCRIBE" | "DESC" )
+ ( "FUNCTIONS" ksname=<keyspaceName>?
+ | "FUNCTION" udf=<anyFunctionName>
+ | "AGGREGATES" ksname=<keyspaceName>?
+ | "AGGREGATE" uda=<userAggregateName>
+ | "KEYSPACES"
+ | "KEYSPACE" ksname=<keyspaceName>?
+ | ( "COLUMNFAMILY" | "TABLE" ) cf=<columnFamilyName>
+ | "INDEX" idx=<indexName>
+ | ( "COLUMNFAMILIES" | "TABLES" )
+ | "FULL"? "SCHEMA"
+ | "CLUSTER"
+ | "TYPES"
+ | "TYPE" ut=<userTypeName>
+ | (ksname=<keyspaceName> | cf=<columnFamilyName> | idx=<indexName>))
+ ;
+
+<consistencyCommand> ::= "CONSISTENCY" ( level=<consistencyLevel> )?
+ ;
+
+<consistencyLevel> ::= "ANY"
+ | "ONE"
+ | "TWO"
+ | "THREE"
+ | "QUORUM"
+ | "ALL"
+ | "LOCAL_QUORUM"
+ | "EACH_QUORUM"
+ | "SERIAL"
+ | "LOCAL_SERIAL"
+ | "LOCAL_ONE"
+ ;
+
+<serialConsistencyCommand> ::= "SERIAL" "CONSISTENCY" ( level=<serialConsistencyLevel> )?
+ ;
+
+<serialConsistencyLevel> ::= "SERIAL"
+ | "LOCAL_SERIAL"
+ ;
+
+<showCommand> ::= "SHOW" what=( "VERSION" | "HOST" | "SESSION" sessionid=<uuid> )
+ ;
+
+<sourceCommand> ::= "SOURCE" fname=<stringLiteral>
+ ;
+
+<captureCommand> ::= "CAPTURE" ( fname=( <stringLiteral> | "OFF" ) )?
+ ;
+
+<copyCommand> ::= "COPY" cf=<columnFamilyName>
+ ( "(" [colnames]=<colname> ( "," [colnames]=<colname> )* ")" )?
+ ( dir="FROM" ( fname=<stringLiteral> | "STDIN" )
+ | dir="TO" ( fname=<stringLiteral> | "STDOUT" ) )
+ ( "WITH" <copyOption> ( "AND" <copyOption> )* )?
+ ;
+
+<copyOption> ::= [optnames]=(<identifier>|<reserved_identifier>) "=" [optvals]=<copyOptionVal>
+ ;
+
+<copyOptionVal> ::= <identifier>
+ | <reserved_identifier>
+ | <stringLiteral>
+ ;
+
+# avoiding just "DEBUG" so that this rule doesn't get treated as a terminal
+<debugCommand> ::= "DEBUG" "THINGS"?
+ ;
+
+<helpCommand> ::= ( "HELP" | "?" ) [topic]=( /[a-z_]*/ )*
+ ;
+
+<tracingCommand> ::= "TRACING" ( switch=( "ON" | "OFF" ) )?
+ ;
+
+<expandCommand> ::= "EXPAND" ( switch=( "ON" | "OFF" ) )?
+ ;
+
+<pagingCommand> ::= "PAGING" ( switch=( "ON" | "OFF" | /[0-9]+/) )?
+ ;
+
+<loginCommand> ::= "LOGIN" username=<username> (password=<stringLiteral>)?
+ ;
+
+<exitCommand> ::= "exit" | "quit"
+ ;
+
+<clearCommand> ::= "CLEAR" | "CLS"
+ ;
+
+<qmark> ::= "?" ;
+'''
+
+
+@cqlsh_syntax_completer('helpCommand', 'topic')
+def complete_help(ctxt, cqlsh):
+ return sorted([t.upper() for t in cqldocs.get_help_topics() + cqlsh.get_help_topics()])
+
+
+def complete_source_quoted_filename(ctxt, cqlsh):
+ partial_path = ctxt.get_binding('partial', '')
+ head, tail = os.path.split(partial_path)
+ exhead = os.path.expanduser(head)
+ try:
+ contents = os.listdir(exhead or '.')
+ except OSError:
+ return ()
+ matches = filter(lambda f: f.startswith(tail), contents)
+ annotated = []
+ for f in matches:
+ match = os.path.join(head, f)
+ if os.path.isdir(os.path.join(exhead, f)):
+ match += '/'
+ annotated.append(match)
+ return annotated
+
+
+cqlsh_syntax_completer('sourceCommand', 'fname')(complete_source_quoted_filename)
+cqlsh_syntax_completer('captureCommand', 'fname')(complete_source_quoted_filename)
+
+
+@cqlsh_syntax_completer('copyCommand', 'fname')
+def copy_fname_completer(ctxt, cqlsh):
+ lasttype = ctxt.get_binding('*LASTTYPE*')
+ if lasttype == 'unclosedString':
+ return complete_source_quoted_filename(ctxt, cqlsh)
+ partial_path = ctxt.get_binding('partial')
+ if partial_path == '':
+ return ["'"]
+ return ()
+
+
+@cqlsh_syntax_completer('copyCommand', 'colnames')
+def complete_copy_column_names(ctxt, cqlsh):
+ existcols = map(cqlsh.cql_unprotect_name, ctxt.get_binding('colnames', ()))
+ ks = cqlsh.cql_unprotect_name(ctxt.get_binding('ksname', None))
+ cf = cqlsh.cql_unprotect_name(ctxt.get_binding('cfname'))
+ colnames = cqlsh.get_column_names(ks, cf)
+ if len(existcols) == 0:
+ return [colnames[0]]
+ return set(colnames[1:]) - set(existcols)
+
+
- COPY_OPTIONS = ('DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'ENCODING', 'TIMEFORMAT', 'NULL')
++COPY_OPTIONS = ['DELIMITER', 'QUOTE', 'ESCAPE', 'HEADER', 'NULL', 'ENCODING',
++ 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS']
+
+
+@cqlsh_syntax_completer('copyOption', 'optnames')
+def complete_copy_options(ctxt, cqlsh):
+ optnames = map(str.upper, ctxt.get_binding('optnames', ()))
+ direction = ctxt.get_binding('dir').upper()
+ opts = set(COPY_OPTIONS) - set(optnames)
+ if direction == 'FROM':
- opts -= ('ENCODING',)
- opts -= ('TIMEFORMAT',)
++ opts -= set(['ENCODING', 'TIMEFORMAT', 'JOBS', 'PAGESIZE', 'PAGETIMEOUT', 'MAXATTEMPTS'])
+ return opts
+
+
+@cqlsh_syntax_completer('copyOption', 'optvals')
+def complete_copy_opt_values(ctxt, cqlsh):
+ optnames = ctxt.get_binding('optnames', ())
+ lastopt = optnames[-1].lower()
+ if lastopt == 'header':
+ return ['true', 'false']
+ return [cqlhandling.Hint('<single_character_string>')]
+
+
+class NoKeyspaceError(Exception):
+ pass
+
+
+class KeyspaceNotFound(Exception):
+ pass
+
+
+class ColumnFamilyNotFound(Exception):
+ pass
+
+
+class IndexNotFound(Exception):
+ pass
+
+
+class ObjectNotFound(Exception):
+ pass
+
+
+class VersionNotSupported(Exception):
+ pass
+
+
+class UserTypeNotFound(Exception):
+ pass
+
+
+class FunctionNotFound(Exception):
+ pass
+
+
+class AggregateNotFound(Exception):
+ pass
+
+
+class DecodeError(Exception):
+ verb = 'decode'
+
+ def __init__(self, thebytes, err, colname=None):
+ self.thebytes = thebytes
+ self.err = err
+ self.colname = colname
+
+ def __str__(self):
+ return str(self.thebytes)
+
+ def message(self):
+ what = 'value %r' % (self.thebytes,)
+ if self.colname is not None:
+ what = 'value %r (for column %r)' % (self.thebytes, self.colname)
+ return 'Failed to %s %s : %s' \
+ % (self.verb, what, self.err)
+
+ def __repr__(self):
+ return '<%s %s>' % (self.__class__.__name__, self.message())
+
+
+class FormatError(DecodeError):
+ verb = 'format'
+
+
+def full_cql_version(ver):
+ while ver.count('.') < 2:
+ ver += '.0'
+ ver_parts = ver.split('-', 1) + ['']
+ vertuple = tuple(map(int, ver_parts[0].split('.')) + [ver_parts[1]])
+ return ver, vertuple
+
+
+def format_value(val, output_encoding, addcolor=False, date_time_format=None,
+ float_precision=None, colormap=None, nullval=None):
+ if isinstance(val, DecodeError):
+ if addcolor:
+ return colorme(repr(val.thebytes), colormap, 'error')
+ else:
+ return FormattedValue(repr(val.thebytes))
+ return format_by_type(type(val), val, output_encoding, colormap=colormap,
+ addcolor=addcolor, nullval=nullval, date_time_format=date_time_format,
+ float_precision=float_precision)
+
+
+def show_warning_without_quoting_line(message, category, filename, lineno, file=None, line=None):
+ if file is None:
+ file = sys.stderr
+ try:
+ file.write(warnings.formatwarning(message, category, filename, lineno, line=''))
+ except IOError:
+ pass
+warnings.showwarning = show_warning_without_quoting_line
+warnings.filterwarnings('always', category=cql3handling.UnexpectedTableStructure)
+
+
+def describe_interval(seconds):
+ desc = []
+ for length, unit in ((86400, 'day'), (3600, 'hour'), (60, 'minute')):
+ num = int(seconds) / length
+ if num > 0:
+ desc.append('%d %s' % (num, unit))
+ if num > 1:
+ desc[-1] += 's'
+ seconds %= length
+ words = '%.03f seconds' % seconds
+ if len(desc) > 1:
+ words = ', '.join(desc) + ', and ' + words
+ elif len(desc) == 1:
+ words = desc[0] + ' and ' + words
+ return words
+
+
++def insert_driver_hooks():
++ extend_cql_deserialization()
++ auto_format_udts()
++
++
++def extend_cql_deserialization():
++ """
++ The python driver returns BLOBs as string, but we expect them as bytearrays
++ """
++ cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts))
++ cassandra.cqltypes.CassandraType.support_empty_values = True
++
++
+def auto_format_udts():
+ # when we see a new user defined type, set up the shell formatting for it
+ udt_apply_params = cassandra.cqltypes.UserType.apply_parameters
+
+ def new_apply_params(cls, *args, **kwargs):
+ udt_class = udt_apply_params(*args, **kwargs)
+ formatter_for(udt_class.typename)(format_value_utype)
+ return udt_class
+
+ cassandra.cqltypes.UserType.udt_apply_parameters = classmethod(new_apply_params)
+
+ make_udt_class = cassandra.cqltypes.UserType.make_udt_class
+
+ def new_make_udt_class(cls, *args, **kwargs):
+ udt_class = make_udt_class(*args, **kwargs)
+ formatter_for(udt_class.tuple_type.__name__)(format_value_utype)
+ return udt_class
+
+ cassandra.cqltypes.UserType.make_udt_class = classmethod(new_make_udt_class)
+
+
+class FrozenType(cassandra.cqltypes._ParameterizedType):
+ """
+ Needed until the bundled python driver adds FrozenType.
+ """
+ typename = "frozen"
+ num_subtypes = 1
+
+ @classmethod
+ def deserialize_safe(cls, byts, protocol_version):
+ subtype, = cls.subtypes
+ return subtype.from_binary(byts)
+
+ @classmethod
+ def serialize_safe(cls, val, protocol_version):
+ subtype, = cls.subtypes
+ return subtype.to_binary(val, protocol_version)
+
+
+class Shell(cmd.Cmd):
+ custom_prompt = os.getenv('CQLSH_PROMPT', '')
+ if custom_prompt is not '':
+ custom_prompt += "\n"
+ default_prompt = custom_prompt + "cqlsh> "
+ continue_prompt = " ... "
+ keyspace_prompt = custom_prompt + "cqlsh:%s> "
+ keyspace_continue_prompt = "%s ... "
+ show_line_nums = False
+ debug = False
+ stop = False
+ last_hist = None
+ shunted_query_out = None
+ use_paging = True
+ csv_dialect_defaults = dict(delimiter=',', doublequote=False,
+ escapechar='\\', quotechar='"')
+ default_page_size = 100
+
+ def __init__(self, hostname, port, color=False,
+ username=None, password=None, encoding=None, stdin=None, tty=True,
+ completekey=DEFAULT_COMPLETEKEY, use_conn=None,
+ cqlver=DEFAULT_CQLVER, keyspace=None,
+ tracing_enabled=False, expand_enabled=False,
+ display_nanotime_format=DEFAULT_NANOTIME_FORMAT,
+ display_timestamp_format=DEFAULT_TIMESTAMP_FORMAT,
+ display_date_format=DEFAULT_DATE_FORMAT,
+ display_float_precision=DEFAULT_FLOAT_PRECISION,
+ max_trace_wait=DEFAULT_MAX_TRACE_WAIT,
+ ssl=False,
+ single_statement=None,
+ client_timeout=10,
+ protocol_version=DEFAULT_PROTOCOL_VERSION,
+ connect_timeout=DEFAULT_CONNECT_TIMEOUT_SECONDS):
+ cmd.Cmd.__init__(self, completekey=completekey)
+ self.hostname = hostname
+ self.port = port
+ self.auth_provider = None
+ if username:
+ if not password:
+ password = getpass.getpass()
+ self.auth_provider = PlainTextAuthProvider(username=username, password=password)
+ self.username = username
+ self.keyspace = keyspace
+ self.ssl = ssl
+ self.tracing_enabled = tracing_enabled
+ self.page_size = self.default_page_size
+ self.expand_enabled = expand_enabled
+ if use_conn:
+ self.conn = use_conn
+ else:
+ self.conn = Cluster(contact_points=(self.hostname,), port=self.port, cql_version=cqlver,
+ protocol_version=protocol_version,
+ auth_provider=self.auth_provider,
+ ssl_options=sslhandling.ssl_settings(hostname, CONFIG_FILE) if ssl else None,
+ load_balancing_policy=WhiteListRoundRobinPolicy([self.hostname]),
+ connect_timeout=connect_timeout)
+ self.owns_connection = not use_conn
+ self.set_expanded_cql_version(cqlver)
+
+ if keyspace:
+ self.session = self.conn.connect(keyspace)
+ else:
+ self.session = self.conn.connect()
+
+ self.color = color
+
+ self.display_nanotime_format = display_nanotime_format
+ self.display_timestamp_format = display_timestamp_format
+ self.display_date_format = display_date_format
+
+ self.display_float_precision = display_float_precision
+
+ # If there is no schema metadata present (due to a schema mismatch), force schema refresh
+ if not self.conn.metadata.keyspaces:
+ self.refresh_schema_metadata_best_effort()
+
+ self.session.default_timeout = client_timeout
+ self.session.row_factory = ordered_dict_factory
+ self.session.default_consistency_level = cassandra.ConsistencyLevel.ONE
+ self.get_connection_versions()
+
+ self.current_keyspace = keyspace
+
+ self.display_timestamp_format = display_timestamp_format
+ self.display_nanotime_format = display_nanotime_format
+ self.display_date_format = display_date_format
+
+ self.max_trace_wait = max_trace_wait
+ self.session.max_trace_wait = max_trace_wait
+ if encoding is None:
+ encoding = locale.getpreferredencoding()
+ self.encoding = encoding
+ self.output_codec = codecs.lookup(encoding)
+
+ self.statement = StringIO()
+ self.lineno = 1
+ self.in_comment = False
+
+ self.prompt = ''
+ if stdin is None:
+ stdin = sys.stdin
+ self.tty = tty
+ if tty:
+ self.reset_prompt()
+ self.report_connection()
+ print 'Use HELP for help.'
+ else:
+ self.show_line_nums = True
+ self.stdin = stdin
+ self.query_out = sys.stdout
+ self.consistency_level = cassandra.ConsistencyLevel.ONE
+ self.serial_consistency_level = cassandra.ConsistencyLevel.SERIAL
- # the python driver returns BLOBs as string, but we expect them as bytearrays
- cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts))
- cassandra.cqltypes.CassandraType.support_empty_values = True
-
- auto_format_udts()
+
+ self.empty_lines = 0
+ self.statement_error = False
+ self.single_statement = single_statement
+
+ def refresh_schema_metadata_best_effort(self):
+ try:
+ self.conn.refresh_schema_metadata(5) # will throw exception if there is a schema mismatch
+ except Exception:
+ self.printerr("Warning: schema version mismatch detected, which might be caused by DOWN nodes; if "
+ "this is not the case, check the schema versions of your nodes in system.local and "
+ "system.peers.")
+ self.conn.refresh_schema_metadata(0)
+
+ def set_expanded_cql_version(self, ver):
+ ver, vertuple = full_cql_version(ver)
+ self.cql_version = ver
+ self.cql_ver_tuple = vertuple
+
+ def cqlver_atleast(self, major, minor=0, patch=0):
+ return self.cql_ver_tuple[:3] >= (major, minor, patch)
+
+ def myformat_value(self, val, **kwargs):
+ if isinstance(val, DecodeError):
+ self.decoding_errors.append(val)
+ try:
+ dtformats = DateTimeFormat(timestamp_format=self.display_timestamp_format,
+ date_format=self.display_date_format, nanotime_format=self.display_nanotime_format)
+ return format_value(val, self.output_codec.name,
+ addcolor=self.color, date_time_format=dtformats,
+ float_precision=self.display_float_precision, **kwargs)
+ except Exception, e:
+ err = FormatError(val, e)
+ self.decoding_errors.append(err)
+ return format_value(err, self.output_codec.name, addcolor=self.color)
+
+ def myformat_colname(self, name, table_meta=None):
+ column_colors = COLUMN_NAME_COLORS.copy()
+ # check column role and color appropriately
+ if table_meta:
+ if name in [col.name for col in table_meta.partition_key]:
+ column_colors.default_factory = lambda: RED
+ elif name in [col.name for col in table_meta.clustering_key]:
+ column_colors.default_factory = lambda: CYAN
+ return self.myformat_value(name, colormap=column_colors)
+
+ def report_connection(self):
+ self.show_host()
+ self.show_version()
+
+ def show_host(self):
+ print "Connected to %s at %s:%d." % \
+ (self.applycolor(self.get_cluster_name(), BLUE),
+ self.hostname,
+ self.port)
+
+ def show_version(self):
+ vers = self.connection_versions.copy()
+ vers['shver'] = version
+ # system.Versions['cql'] apparently does not reflect changes with
+ # set_cql_version.
+ vers['cql'] = self.cql_version
+ print "[cqlsh %(shver)s | Cassandra %(build)s | CQL spec %(cql)s | Native protocol v%(protocol)s]" % vers
+
+ def show_session(self, sessionid, partial_session=False):
+ print_trace_session(self, self.session, sessionid, partial_session)
+
+ def get_connection_versions(self):
+ result, = self.session.execute("select * from system.local where key = 'local'")
+ vers = {
+ 'build': result['release_version'],
+ 'protocol': result['native_protocol_version'],
+ 'cql': result['cql_version'],
+ }
+ self.connection_versions = vers
+
+ def get_keyspace_names(self):
+ return map(str, self.conn.metadata.keyspaces.keys())
+
+ def get_columnfamily_names(self, ksname=None):
+ if ksname is None:
+ ksname = self.current_keyspace
+
+ return map(str, self.get_keyspace_meta(ksname).tables.keys())
+
+ def get_index_names(self, ksname=None):
+ if ksname is None:
+ ksname = self.current_keyspace
+
+ return map(str, self.get_keyspace_meta(ksname).indexes.keys())
+
+ def get_column_names(self, ksname, cfname):
+ if ksname is None:
+ ksname = self.current_keyspace
+ layout = self.get_table_meta(ksname, cfname)
+ return [str(col) for col in layout.columns]
+
+ def get_usertype_names(self, ksname=None):
+ if ksname is None:
+ ksname = self.current_keyspace
+
+ return self.get_keyspace_meta(ksname).user_types.keys()
+
+ def get_usertype_layout(self, ksname, typename):
+ if ksname is None:
+ ksname = self.current_keyspace
+
+ ks_meta = self.get_keyspace_meta(ksname)
+
+ try:
+ user_type = ks_meta.user_types[typename]
+ except KeyError:
+ raise UserTypeNotFound("User type %r not found" % typename)
+
+ return [(field_name, field_type.cql_parameterized_type())
+ for field_name, field_type in zip(user_type.field_names, user_type.field_types)]
+
+ def get_userfunction_names(self, ksname=None):
+ if ksname is None:
+ ksname = self.current_keyspace
+
+ return map(lambda f: f.name, self.get_keyspace_meta(ksname).functions.values())
+
+ def get_useraggregate_names(self, ksname=None):
+ if ksname is None:
+ ksname = self.current_keyspace
+
+ return map(lambda f: f.name, self.get_keyspace_meta(ksname).aggregates.values())
+
+ def get_cluster_name(self):
+ return self.conn.metadata.cluster_name
+
+ def get_partitioner(self):
+ return self.conn.metadata.partitioner
+
+ def get_keyspace_meta(self, ksname):
+ if ksname not in self.conn.metadata.keyspaces:
+ raise KeyspaceNotFound('Keyspace %r not found.' % ksname)
+ return self.conn.metadata.keyspaces[ksname]
+
+ def get_keyspaces(self):
+ return self.conn.metadata.keyspaces.values()
+
- def get_ring(self):
- if self.current_keyspace is None or self.current_keyspace == 'system':
- raise NoKeyspaceError("Ring view requires a current non-system keyspace")
- self.conn.metadata.token_map.rebuild_keyspace(self.current_keyspace, build_if_absent=True)
- return self.conn.metadata.token_map.tokens_to_hosts_by_ks[self.current_keyspace]
++ def get_ring(self, ks):
++ self.conn.metadata.token_map.rebuild_keyspace(ks, build_if_absent=True)
++ return self.conn.metadata.token_map.tokens_to_hosts_by_ks[ks]
+
def get_table_meta(self, ksname, tablename):
    """Return the table metadata for ``ksname.tablename``.

    ``ksname`` defaults to the current keyspace when None.  If the table is
    missing but is one of the ``system_auth`` tables ``roles`` or
    ``role_permissions``, metadata synthesized by get_fake_auth_table_meta
    is returned instead.

    Raises ColumnFamilyNotFound for any other unknown table.
    """
    if ksname is None:
        ksname = self.current_keyspace
    ksmeta = self.get_keyspace_meta(ksname)

    if tablename in ksmeta.tables:
        return ksmeta.tables[tablename]

    if ksname == 'system_auth' and tablename in ('roles', 'role_permissions'):
        # These tables can be absent from the schema (see
        # get_fake_auth_table_meta).  The synthesized metadata must be
        # returned: callers such as the LIST USERS/ROLES output path use it
        # to format results.
        return self.get_fake_auth_table_meta(ksname, tablename)

    raise ColumnFamilyNotFound("Column family %r not found" % tablename)
+
def get_fake_auth_table_meta(self, ksname, tablename):
    """Synthesize table metadata for an auth table missing from the schema.

    An external auth implementation may be in use, in which case the
    internal ``roles`` / ``role_permissions`` tables are not defined in the
    schema.  This fakes up their metadata (all columns as in the original
    hard-coded layout) so results can still be formatted.

    Returns the synthesized table metadata.  Raises ColumnFamilyNotFound for
    any other ``tablename``.
    """
    utf8 = cassandra.cqltypes.UTF8Type
    boolean = cassandra.cqltypes.BooleanType
    if tablename == 'roles':
        columns = [('role', utf8), ('is_superuser', boolean), ('can_login', boolean)]
    elif tablename == 'role_permissions':
        columns = [('role', utf8), ('resource', utf8), ('permission', utf8)]
    else:
        raise ColumnFamilyNotFound("Column family %r not found" % tablename)

    ks_meta = KeyspaceMetadata(ksname, True, None, None)
    table_meta = TableMetadata(ks_meta, tablename)
    # Insert in the listed order so the columns come out in a stable order.
    for column_name, column_type in columns:
        table_meta.columns[column_name] = ColumnMetadata(table_meta, column_name, column_type)
    # The metadata was previously built and then dropped; return it.
    return table_meta
+
def get_index_meta(self, ksname, idxname):
    """Return the metadata of index ``idxname``.

    The index is looked up in ``ksname``, or in the current keyspace when
    ``ksname`` is None.  Raises IndexNotFound if it does not exist.
    """
    keyspace = ksname if ksname is not None else self.current_keyspace
    indexes = self.get_keyspace_meta(keyspace).indexes
    if idxname in indexes:
        return indexes[idxname]
    raise IndexNotFound("Index %r not found" % idxname)
+
def get_object_meta(self, ks, name):
    """Resolve a DESCRIBE target to a keyspace, table or index metadata object.

    With ``name`` None, ``ks`` is first tried as a keyspace name; failing
    that it is treated as an object name inside the current keyspace.
    Tables are checked before indexes.

    Raises ObjectNotFound when nothing matches (or when a lone name is not a
    keyspace and there is no current keyspace).
    """
    if name is None:
        keyspaces = self.conn.metadata.keyspaces
        if ks and ks in keyspaces:
            return keyspaces[ks]
        if self.current_keyspace is None:
            raise ObjectNotFound("%r not found in keyspaces" % (ks))
        # The single identifier names an object in the current keyspace.
        ks, name = self.current_keyspace, ks

    if ks is None:
        ks = self.current_keyspace

    ksmeta = self.get_keyspace_meta(ks)
    for candidates in (ksmeta.tables, ksmeta.indexes):
        if name in candidates:
            return candidates[name]

    raise ObjectNotFound("%r not found in keyspace %r" % (name, ks))
+
def get_usertypes_meta(self):
    """Build a cql3handling.UserTypesMeta from ``system.schema_usertypes``.

    An empty UserTypesMeta is returned when the query result is falsy
    (e.g. no rows).
    """
    rows = self.session.execute("select * from system.schema_usertypes")
    if rows:
        return cql3handling.UserTypesMeta.from_layout(rows)
    return cql3handling.UserTypesMeta({})
+
def get_trigger_names(self, ksname=None):
    """Return the names of all triggers on all tables of a keyspace.

    Uses the current keyspace when ``ksname`` is None.  Names are listed
    table by table.
    """
    if ksname is None:
        ksname = self.current_keyspace

    names = []
    for table in self.get_keyspace_meta(ksname).tables.values():
        names.extend(trig.name for trig in table.triggers.values())
    return names
+
def reset_statement(self):
    """Discard the buffered (possibly partial) statement.

    Restores the main prompt, empties ``self.statement`` and clears the
    count of consecutive empty continuation lines.
    """
    self.reset_prompt()
    # Rewind before truncating: io.StringIO.truncate() does not move the
    # stream position, so without the seek the next write would land past
    # the old end and pad the buffer with NUL characters.
    self.statement.seek(0)
    self.statement.truncate(0)
    self.empty_lines = 0
+
def reset_prompt(self):
    """Show the primary prompt, including the keyspace name when one is in use."""
    if self.current_keyspace is not None:
        prompt = self.keyspace_prompt % self.current_keyspace
    else:
        prompt = self.default_prompt
    self.set_prompt(prompt, True)
+
def set_continue_prompt(self):
    """Switch to the prompt used while a statement is still incomplete.

    After three consecutive empty input lines a hint about the ';'
    terminator is shown instead and the counter restarts.  Otherwise the
    counter grows on each empty line and resets on any non-empty one.
    """
    if self.empty_lines >= 3:
        self.set_prompt("Statements are terminated with a ';'. You can press CTRL-C to cancel an incomplete statement.")
        self.empty_lines = 0
        return

    if self.current_keyspace is not None:
        # Padding as wide as the keyspace name, presumably so the
        # continuation prompt lines up with the keyspace prompt.
        padding = ' ' * len(str(self.current_keyspace))
        self.set_prompt(self.keyspace_continue_prompt % padding)
    else:
        self.set_prompt(self.continue_prompt)

    if self.lastcmd:
        self.empty_lines = 0
    else:
        self.empty_lines += 1
+
@contextmanager
def prepare_loop(self):
    """Context manager that sets up readline tab completion for the loop.

    Only when input is a terminal and a completion key is configured: the
    readline completer is pointed at ``self.complete`` and the key is bound,
    using libedit syntax when readline's docstring mentions libedit.  The
    previous completer is restored on exit.  A missing readline module is
    tolerated; on Windows a hint about pyreadline is printed.
    """
    readline = None
    if self.tty and self.completekey:
        try:
            import readline
        except ImportError:
            if myplatform == 'Windows':
                print("WARNING: pyreadline dependency missing. Install to enable tab completion.")
        else:
            old_completer = readline.get_completer()
            readline.set_completer(self.complete)
            uses_libedit = readline.__doc__ is not None and 'libedit' in readline.__doc__
            if uses_libedit:
                for binding in ("bind -e",
                                "bind '" + self.completekey + "' rl_complete",
                                "bind ^R em-inc-search-prev"):
                    readline.parse_and_bind(binding)
            else:
                readline.parse_and_bind(self.completekey + ": complete")
    try:
        yield
    finally:
        if readline is not None:
            readline.set_completer(old_completer)
+
def get_input_line(self, prompt=''):
    """Read one line of input.

    On a terminal the text comes from ``raw_input(prompt)`` and a newline is
    appended; otherwise a line is read from ``self.stdin`` as-is.  The raw
    text is stored in ``self.lastcmd`` and ``self.lineno`` is incremented.

    Raises EOFError when no more input is available.
    """
    if not self.tty:
        self.lastcmd = self.stdin.readline()
        line = self.lastcmd
        if not line:
            raise EOFError
    else:
        self.lastcmd = raw_input(prompt)
        line = self.lastcmd + '\n'
    self.lineno += 1
    return line
+
def use_stdin_reader(self, until='', prompt=''):
    """Yield input lines until a line consisting of ``until`` or end of input.

    Lines are read with get_input_line(prompt=prompt) and yielded with their
    trailing newline; the terminating line itself is not yielded.
    """
    terminator = until + '\n'
    while True:
        try:
            line = self.get_input_line(prompt=prompt)
        except EOFError:
            return
        if line == terminator:
            return
        yield line
+
def cmdloop(self):
    """Read, buffer and execute statements until ``self.stop`` is set.

    Replaces cmd.Cmd.cmdloop() so that a literal "EOF" in the input can be
    told apart from a real end of input.  A ``single_statement``, if set, is
    processed once and then the loop stops.  EOF is delegated to
    handle_eof, CQL errors are printed, and Ctrl-C discards the buffered
    statement.
    """
    with self.prepare_loop():
        while not self.stop:
            try:
                if not self.single_statement:
                    line = self.get_input_line(self.prompt)
                else:
                    line = self.single_statement
                    self.stop = True
                self.statement.write(line)
                finished = self.onecmd(self.statement.getvalue())
                if finished:
                    self.reset_statement()
            except EOFError:
                self.handle_eof()
            except CQL_ERRORS as cqlerr:
                self.printerr(str(cqlerr))
            except KeyboardInterrupt:
                self.reset_statement()
                print('')
+
def onecmd(self, statementtext):
    """Split the buffered text into statements and run the complete ones.

    Returns True when the buffer can be reset: it was fully handled, was
    empty, or had a syntax error (which is reported with a caret marker).
    Returns None when the input is still incomplete (inside a batch, or the
    last statement lacks its terminator); the continuation prompt is shown.
    Exceptions from individual statements are printed (with a traceback in
    debug mode) and do not stop the remaining statements.
    """
    try:
        statements, in_batch = cqlruleset.cql_split_statements(statementtext)
    except pylexotron.LexingError as e:
        if self.show_line_nums:
            self.printerr('Invalid syntax at char %d' % (e.charnum,))
        else:
            self.printerr('Invalid syntax at line %d, char %d'
                          % (e.linenum, e.charnum))
        bad_line = statementtext.split('\n')[e.linenum - 1]
        self.printerr(' %s' % bad_line)
        self.printerr(' %s^' % (' ' * e.charnum))
        return True

    # Drop trailing empty statements.
    end = len(statements)
    while end and not statements[end - 1]:
        end -= 1
    statements = statements[:end]
    if not statements:
        return True

    if in_batch or statements[-1][-1][0] != 'endtoken':
        self.set_continue_prompt()
        return None

    for st in statements:
        try:
            self.handle_statement(st, statementtext)
        except Exception as e:
            if self.debug:
                traceback.print_exc()
            else:
                self.printerr(e)
    return True
+
def handle_eof(self):
    """Finish the session at end of input.

    Any non-blank buffered text gets one last chance to run; if it is still
    incomplete an error is printed.  The shell then exits via do_exit().
    """
    if self.tty:
        print('')
    pending = self.statement.getvalue()
    if pending.strip() and not self.onecmd(pending):
        self.printerr('Incomplete statement at end of file')
    self.do_exit()
+
def handle_statement(self, tokens, srcstr):
    """Dispatch one tokenized statement.

    Multi-line statements are added to readline history as a single line.
    A ``do_<word>`` method handles shell commands (``?`` means ``help``);
    everything else is sent to the server via perform_statement.
    """
    if readline is not None:
        flattened = srcstr.replace("\n", " ").rstrip()
        if srcstr.count("\n") > 1 and self.last_hist != flattened:
            readline.add_history(flattened)
        self.last_hist = flattened

    cmdword = tokens[0][1]
    if cmdword == '?':
        cmdword = 'help'

    custom_handler = getattr(self, 'do_' + cmdword.lower(), None)
    if not custom_handler:
        return self.perform_statement(cqlruleset.cql_extract_orig(tokens, srcstr))

    parsed = cqlruleset.cql_whole_parse_tokens(tokens, srcstr=srcstr,
                                               startsymbol='cqlshCommand')
    if parsed and not parsed.remainder:
        # successful complete parse
        return custom_handler(parsed)
    return self.handle_parse_error(cmdword, tokens, parsed, srcstr)
+
def handle_parse_error(self, cmdword, tokens, parsed, srcstr):
    """React to a statement the cqlsh grammar could not fully parse.

    For ordinary CQL verbs the original text is sent to the server anyway,
    in case it uses syntax this grammar does not know yet.  For anything
    else an "Improper ... command" error is printed, naming the first
    unparsed token when a partial parse exists.
    """
    passthrough_verbs = ('select', 'insert', 'update', 'delete', 'truncate',
                         'create', 'drop', 'alter', 'grant', 'revoke',
                         'batch', 'list')
    if cmdword.lower() in passthrough_verbs:
        return self.perform_statement(cqlruleset.cql_extract_orig(tokens, srcstr))

    if not parsed:
        self.printerr('Improper %s command.' % cmdword)
    else:
        self.printerr('Improper %s command (problem at %r).' % (cmdword, parsed.remainder[0]))
+
def do_use(self, parsed):
    # USE <keyspace>: execute the statement and, if it succeeded, remember
    # the keyspace locally.  A double-quoted name is unquoted as-is; an
    # unquoted name is lower-cased.
    ksname = parsed.get_binding('ksname')
    success, _ = self.perform_simple_statement(SimpleStatement(parsed.extract_orig()))
    if not success:
        return
    is_quoted = ksname[0] == '"' and ksname[-1] == '"'
    if is_quoted:
        self.current_keyspace = self.cql_unprotect_name(ksname)
    else:
        self.current_keyspace = ksname.lower()
+
def do_select(self, parsed):
    # Run a SELECT.  Tracing is switched off for queries on system_traces
    # (named explicitly, or as the current keyspace when none is given).
    # The previous tracing setting is restored in a finally block so that
    # an exception (e.g. Ctrl-C while paging) cannot leave tracing disabled
    # for the rest of the session.
    tracing_was_enabled = self.tracing_enabled
    ksname = parsed.get_binding('ksname')
    stop_tracing = ksname == 'system_traces' or (ksname is None and self.current_keyspace == 'system_traces')
    self.tracing_enabled = self.tracing_enabled and not stop_tracing
    statement = parsed.extract_orig()
    try:
        self.perform_statement(statement)
    finally:
        self.tracing_enabled = tracing_was_enabled
+
def perform_statement(self, statement):
    """Execute a CQL string with the shell's consistency and paging settings.

    Prints server warnings and, when tracing is enabled, the query traces.
    If traces do not arrive within ``self.max_trace_wait`` seconds, the
    partial sessions are shown with a warning.  Returns the success flag
    from perform_simple_statement.
    """
    stmt = SimpleStatement(statement,
                           consistency_level=self.consistency_level,
                           serial_consistency_level=self.serial_consistency_level,
                           fetch_size=self.page_size if self.use_paging else None)
    success, future = self.perform_simple_statement(stmt)

    if future:
        if future.warnings:
            self.print_warnings(future.warnings)

        if self.tracing_enabled:
            try:
                for trace in future.get_all_query_traces(self.max_trace_wait):
                    print_trace(self, trace)
            except TraceUnavailable:
                # Report the wait limit that was actually passed to
                # get_all_query_traces() above, not the session's default.
                msg = "Statement trace did not complete within %d seconds; trace data may be incomplete." % (self.max_trace_wait,)
                self.writeresult(msg, color=RED)
                for trace_id in future.get_query_trace_ids():
                    self.show_session(trace_id, partial_session=True)
            except Exception as err:
                self.printerr("Unable to fetch query trace: %s" % (str(err),))

    return success
+
def parse_for_table_meta(self, query_string):
    """Return the metadata of the table a query refers to, or None.

    The query is parsed with cqlruleset; None is returned when parsing
    yields no statement to inspect.
    """
    try:
        parsed = cqlruleset.cql_parse(query_string)[1]
    except IndexError:
        return None
    keyspace = self.cql_unprotect_name(parsed.get_binding('ksname', None))
    table = self.cql_unprotect_name(parsed.get_binding('cfname'))
    return self.get_table_meta(keyspace, table)
+
def perform_simple_statement(self, statement):
    """Execute ``statement`` and print its result.

    Returns ``(success, future)``; on failure (or when ``statement`` is
    falsy) returns ``(False, None)``.  Errors are reported through printerr
    instead of being raised.  After a timeout, schema metadata is refreshed
    on a best-effort basis.
    """
    if not statement:
        return False, None

    # The original wrapped this in a `while True` loop whose every path
    # broke out or returned; a plain try/except is equivalent.
    try:
        future = self.session.execute_async(statement, trace=self.tracing_enabled)
        result = future.result()
    except cassandra.OperationTimedOut as err:
        self.refresh_schema_metadata_best_effort()
        self.printerr(str(err.__class__.__name__) + ": " + str(err))
        return False, None
    except CQL_ERRORS as err:
        self.printerr(str(err.__class__.__name__) + ": " + str(err))
        return False, None
    except Exception:
        import traceback
        self.printerr(traceback.format_exc())
        return False, None

    query = statement.query_string
    if query[:6].lower() == 'select':
        self.print_result(result, self.parse_for_table_meta(query))
    elif query.lower().startswith("list users") or query.lower().startswith("list roles"):
        self.print_result(result, self.get_table_meta('system_auth', 'roles'))
    elif query.lower().startswith("list"):
        self.print_result(result, self.get_table_meta('system_auth', 'role_permissions'))
    elif result:
        # CAS INSERT/UPDATE
        self.writeresult("")
        self.print_static_result(list(result), self.parse_for_table_meta(query))
    self.flush_output()
    return True, future
+
def print_result(self, result, table_meta):
    """Print a query result followed by a "(N rows)" footer.

    On a terminal, results with more pages are shown one page at a time,
    waiting at a ---MORE--- prompt between pages.  Up to two decoding errors
    collected in ``self.decoding_errors`` while printing are shown, and the
    rest are summarized.
    """
    self.decoding_errors = []

    self.writeresult("")
    if result.has_more_pages and self.tty:
        num_rows = 0
        while True:
            page = result.current_rows
            if page:
                num_rows += len(page)
                self.print_static_result(page, table_meta)
            if not result.has_more_pages:
                break
            raw_input("---MORE---")
            result.fetch_next_page()
    else:
        all_rows = list(result)
        num_rows = len(all_rows)
        self.print_static_result(all_rows, table_meta)
    self.writeresult("(%d rows)" % num_rows)

    errors = self.decoding_errors
    if errors:
        for err in errors[:2]:
            self.writeresult(err.message(), color=RED)
        suppressed = len(errors) - 2
        if suppressed > 0:
            self.writeresult('%d more decoding errors suppressed.'
                             % suppressed, color=RED)
+
def print_static_result(self, rows, table_meta):
    """Format and print a list of rows (each a mapping of column -> value).

    With no rows, only the header from ``table_meta`` is printed, or nothing
    at all when there is no table metadata.  Expanded mode prints one
    column/value pair per line.
    """
    if rows:
        formatted_names = [self.myformat_colname(name, table_meta) for name in rows[0].keys()]
        formatted_values = [[self.myformat_value(value) for value in row.values()] for row in rows]
        if self.expand_enabled:
            self.print_formatted_result_vertically(formatted_names, formatted_values)
        else:
            self.print_formatted_result(formatted_names, formatted_values)
        return

    if table_meta:
        # header only, taken from the full column list
        formatted_names = [self.myformat_colname(name, table_meta)
                           for name in table_meta.columns.keys()]
        self.print_formatted_result(formatted_names, None)
+
def print_formatted_result(self, formatted_names, formatted_values):
    """Print formatted cells as a table: header, separator, then rows.

    Each column is as wide as the widest display width of its header or
    any of its cells.  Headers are left-justified and cells right-justified.
    When ``formatted_values`` is None only the header is printed.  A blank
    line always follows.
    """
    widths = [name.displaywidth for name in formatted_names]
    if formatted_values is not None:
        for fmtrow in formatted_values:
            for idx, cell in enumerate(fmtrow):
                widths[idx] = max(widths[idx], cell.displaywidth)

    header_cells = [name.ljust(width, color=self.color)
                    for name, width in zip(formatted_names, widths)]
    self.writeresult(' ' + ' | '.join(header_cells).rstrip())
    self.writeresult('-%s-' % '-+-'.join('-' * width for width in widths))

    if formatted_values is not None:
        for fmtrow in formatted_values:
            cells = [cell.rjust(width, color=self.color) for cell, width in zip(fmtrow, widths)]
            self.writeresult(' ' + ' | '.join(cells))

    self.writeresult("")
+
def print_formatted_result_vertically(self, formatted_names, formatted_values):
    """Print each row as a block of "column | value" lines (expanded mode).

    Every block starts with an "@ Row N" title and a separator sized to the
    widest column name and the widest value, and ends with a blank line.
    """
    name_width = max(name.displaywidth for name in formatted_names)
    value_width = max(cell.displaywidth for fmtrow in formatted_values for cell in fmtrow)
    separator = '-%s-' % '-+-'.join(['-' * name_width, '-' * value_width])

    for row_number, fmtrow in enumerate(formatted_values, 1):
        self.writeresult("@ Row %d" % row_number)
        self.writeresult(separator)
        for idx, cell in enumerate(fmtrow):
            label = formatted_names[idx].ljust(name_width, color=self.color)
            value = cell.ljust(cell.displaywidth, color=self.color)
            self.writeresult(' ' + " | ".join([label, value]))
        self.writeresult('')
+
def print_warnings(self, warnings):
    """Print server-side warnings under a "Warnings :" heading.

    Prints nothing when ``warnings`` is None or empty.
    """
    if warnings is None or len(warnings) == 0:
        return
    for line in ('', 'Warnings :'):
        self.writeresult(line)
    for warning_text in warnings:
        self.writeresult(warning_text)
    self.writeresult('')
+
def emptyline(self):
    """Ignore empty input lines.

    cmd.Cmd's default implementation repeats the last command instead.
    """
    return None
+
def parseline(self, line):
    """Unsupported: input is parsed by onecmd() via cqlruleset instead."""
    raise NotImplementedError()
+
def complete(self, text, state):
    """readline completer: return the ``state``-th completion for ``text``.

    Candidates are computed once per completion attempt (when ``state`` is
    0) and cached in ``self.completion_matches``.  Returns None when there
    are no more candidates or readline is unavailable.  Errors while
    computing candidates are printed when ``debug_completion`` is set and
    re-raised otherwise.
    """
    if readline is None:
        return
    if state == 0:
        try:
            self.completion_matches = self.find_completions(text)
        except Exception:
            if not debug_completion:
                raise
            import traceback
            traceback.print_exc()
    try:
        return self.completion_matches[state]
    except IndexError:
        return None
+
def find_completions(self, text):
    """Return completion candidates for ``text`` at the readline cursor.

    Earlier lines of a multi-line statement are prepended to the current
    readline buffer, and everything before the start of the word being
    completed is handed to cqlruleset.cql_complete.
    """
    current_line = readline.get_line_buffer()
    earlier_lines = self.statement.getvalue()
    whole_statement = earlier_lines + current_line
    word_start = len(earlier_lines) + readline.get_begidx()
    return cqlruleset.cql_complete(whole_statement[:word_start], text, cassandra_conn=self,
                                   debug=debug_completion, startsymbol='cqlshCommand')
+
def set_prompt(self, prompt, prepend_user=False):
    """Set ``self.prompt``, optionally prefixed with "<username>@"."""
    if prepend_user and self.username:
        prompt = "%s@%s" % (self.username, prompt)
    self.prompt = prompt
+
def cql_unprotect_name(self, namestr):
    """Return ``namestr`` passed through cqlruleset.dequote_name (None stays None)."""
    if namestr is not None:
        return cqlruleset.dequote_name(namestr)
    return None
+
def cql_unprotect_value(self, valstr):
    """Return ``valstr`` passed through cqlruleset.dequote_value (None stays None)."""
    if valstr is None:
        return None
    return cqlruleset.dequote_value(valstr)
+
def print_recreate_keyspace(self, ksdef, out):
    """Write ``ksdef.export_as_string()`` and a newline to the ``out`` stream."""
    ddl = ksdef.export_as_string()
    out.write(ddl)
    out.write("\n")
+
def print_recreate_columnfamily(self, ksname, cfname, out):
    """Write CQL that recreates table ``ksname.cfname`` to the ``out`` stream.

    The text is the table metadata's export_as_string(), followed by a
    newline; the result is meant to be pasteable back into a CQL session.
    """
    table_meta = self.get_table_meta(ksname, cfname)
    out.write(table_meta.export_as_string())
    out.write("\n")
+
def print_recreate_index(self, ksname, idxname, out):
    """Write CQL that recreates index ``idxname`` to the ``out`` stream.

    The text is the index metadata's export_as_string(), followed by a
    newline; the result is meant to be pasteable back into a CQL session.
    """
    index_meta = self.get_index_meta(ksname, idxname)
    out.write(index_meta.export_as_string())
    out.write("\n")
+
def print_recreate_object(self, ks, name, out):
    """Write CQL that recreates a keyspace, table or index to the ``out`` stream.

    The object is resolved by get_object_meta; its export_as_string() is
    written followed by a newline.
    """
    object_meta = self.get_object_meta(ks, name)
    out.write(object_meta.export_as_string())
    out.write("\n")
+
def describe_keyspaces(self):
    """Print all keyspace names (protected as CQL names) in columns."""
    print('')
    names = protect_names(self.get_keyspace_names())
    cmd.Cmd.columnize(self, names)
    print('')
+
def describe_keyspace(self, ksname):
    """Print the CQL that recreates keyspace ``ksname``, between blank lines."""
    ksmeta = self.get_keyspace_meta(ksname)
    print('')
    self.print_recreate_keyspace(ksmeta, sys.stdout)
    print('')
+
def describe_columnfamily(self, ksname, cfname):
    """Print the CQL that recreates table ``cfname``.

    ``ksname`` defaults to the current keyspace; NoKeyspaceError is raised
    when neither is available.
    """
    keyspace = self.current_keyspace if ksname is None else ksname
    if keyspace is None:
        raise NoKeyspaceError("No keyspace specified and no current keyspace")
    print('')
    self.print_recreate_columnfamily(keyspace, cfname, sys.stdout)
    print('')
+
def describe_index(self, ksname, idxname):
    """Print the CQL that recreates index ``idxname``, between blank lines."""
    print('')
    self.print_recreate_index(ksname, idxname, sys.stdout)
    print('')
+
def describe_object(self, ks, name):
    """Print the CQL that recreates a keyspace, table or index, between blank lines."""
    print('')
    self.print_recreate_object(ks, name, sys.stdout)
    print('')
+
def describe_columnfamilies(self, ksname):
    """Print table names in columns.

    With ``ksname`` set, only that keyspace's tables are listed; otherwise
    every keyspace gets its own underlined heading followed by its tables.
    """
    print('')
    if ksname is not None:
        cmd.Cmd.columnize(self, protect_names(self.get_columnfamily_names(ksname)))
        print('')
        return

    for ksmeta in self.get_keyspaces():
        heading = protect_name(ksmeta.name)
        print('Keyspace %s' % (heading,))
        print('---------%s' % ('-' * len(heading)))
        cmd.Cmd.columnize(self, protect_names(self.get_columnfamily_names(ksmeta.name)))
        print('')
+
def describe_functions(self, ksname=None):
    """Print user-defined function keys in columns.

    With ``ksname`` set, only that keyspace is listed; otherwise every
    keyspace gets its own underlined heading.
    """
    print('')
    if ksname is not None:
        ksmeta = self.get_keyspace_meta(ksname)
        cmd.Cmd.columnize(self, protect_names(ksmeta.functions.keys()))
        print('')
        return

    for ksmeta in self.get_keyspaces():
        heading = protect_name(ksmeta.name)
        print('Keyspace %s' % (heading,))
        print('---------%s' % ('-' * len(heading)))
        cmd.Cmd.columnize(self, protect_names(ksmeta.functions.keys()))
        print('')
+
def describe_function(self, ksname, functionname):
    """Print the CQL of every user-defined function named ``functionname``.

    ``ksname`` defaults to the current keyspace (NoKeyspaceError when there
    is none).  Raises FunctionNotFound when no function has that name.
    """
    if ksname is None:
        ksname = self.current_keyspace
        if ksname is None:
            raise NoKeyspaceError("No keyspace specified and no current keyspace")
    print('')
    ksmeta = self.get_keyspace_meta(ksname)
    matches = [udf for udf in ksmeta.functions.values() if udf.name == functionname]
    if not matches:
        raise FunctionNotFound("User defined function %r not found" % functionname)
    print("\n\n".join(udf.as_cql_query(formatted=True) for udf in matches))
    print('')
+
def describe_aggregates(self, ksname=None):
    """Print user-defined aggregate keys in columns.

    With ``ksname`` set, only that keyspace is listed; otherwise every
    keyspace gets its own underlined heading.
    """
    print('')
    if ksname is not None:
        ksmeta = self.get_keyspace_meta(ksname)
        cmd.Cmd.columnize(self, protect_names(ksmeta.aggregates.keys()))
        print('')
        return

    for ksmeta in self.get_keyspaces():
        heading = protect_name(ksmeta.name)
        print('Keyspace %s' % (heading,))
        print('---------%s' % ('-' * len(heading)))
        cmd.Cmd.columnize(self, protect_names(ksmeta.aggregates.keys()))
        print('')
+
def describe_aggregate(self, ksname, aggregatename):
    """Print the CQL of every user-defined aggregate named ``aggregatename``.

    ``ksname`` defaults to the current keyspace (NoKeyspaceError when there
    is none).  Raises FunctionNotFound when no aggregate has that name.
    """
    if ksname is None:
        ksname = self.current_keyspace
        if ksname is None:
            raise NoKeyspaceError("No keyspace specified and no current keyspace")
    print('')
    ksmeta = self.get_keyspace_meta(ksname)
    matches = [uda for uda in ksmeta.aggregates.values() if uda.name == aggregatename]
    if not matches:
        raise FunctionNotFound("User defined aggregate %r not found" % aggregatename)
    print("\n\n".join(uda.as_cql_query(formatted=True) for uda in matches))
    print('')
+
def describe_usertypes(self, ksname):
    """Print user-defined type names in columns.

    With ``ksname`` set, only that keyspace is listed; otherwise every
    keyspace gets its own underlined heading.
    """
    print('')
    if ksname is not None:
        ksmeta = self.get_keyspace_meta(ksname)
        cmd.Cmd.columnize(self, protect_names(ksmeta.user_types.keys()))
        print('')
        return

    for ksmeta in self.get_keyspaces():
        heading = protect_name(ksmeta.name)
        print('Keyspace %s' % (heading,))
        print('---------%s' % ('-' * len(heading)))
        cmd.Cmd.columnize(self, protect_names(ksmeta.user_types.keys()))
        print('')
+
def describe_usertype(self, ksname, typename):
    """Print the CQL of user-defined type ``typename``.

    ``ksname`` defaults to the current keyspace (NoKeyspaceError when there
    is none).  Raises UserTypeNotFound when the type does not exist.
    """
    if ksname is None:
        ksname = self.current_keyspace
        if ksname is None:
            raise NoKeyspaceError("No keyspace specified and no current keyspace")
    print('')
    user_types = self.get_keyspace_meta(ksname).user_types
    try:
        usertype = user_types[typename]
    except KeyError:
        raise UserTypeNotFound("User type %r not found" % typename)
    print(usertype.as_cql_query(formatted=True))
    print('')
+
def describe_cluster(self):
    """Print the cluster name and partitioner.

    When a non-system keyspace is in use, also print that keyspace's token
    ring: each token followed by its replica host addresses.
    """
    print('\nCluster: %s' % self.get_cluster_name())
    partitioner = trim_if_present(self.get_partitioner(), 'org.apache.cassandra.dht.')
    print('Partitioner: %s\n' % partitioner)
    # TODO: also report the snitch (trimmed of 'org.apache.cassandra.locator.')
    # once it can be fetched.

    keyspace = self.current_keyspace
    if keyspace is None or keyspace == 'system':
        return
    print("Range ownership:")
    ring = self.get_ring(keyspace)
    for token, hosts in ring.items():
        print(' %39s [%s]' % (str(token.value), ', '.join([host.address for host in hosts])))
    print('')
+
def describe_schema(self, include_system=False):
    """Print the CQL that recreates every keyspace.

    Keyspaces listed in cql3handling.SYSTEM_KEYSPACES are skipped unless
    ``include_system`` is true.
    """
    print('')
    for ksmeta in self.get_keyspaces():
        if not include_system and ksmeta.name in cql3handling.SYSTEM_KEYSPACES:
            continue
        self.print_recreate_keyspace(ksmeta, sys.stdout)
        print('')
+
def do_describe(self, parsed):
    """
    DESCRIBE [cqlsh only]

    (DESC may be used as a shorthand.)

    Outputs information about the connected Cassandra cluster, or about
    the data stored on it. Use in one of the following ways:

    DESCRIBE KEYSPACES

    Output the names of all keyspaces.

    DESCRIBE KEYSPACE [<keyspacename>]

    Output CQL commands that could be used to recreate the given
    keyspace, and the tables in it. In some cases, as the CQL interface
    matures, there will be some metadata about a keyspace that is not
    representable with CQL. That metadata will not be shown.

    The '<keyspacename>' argument may be omitted when using a non-system
    keyspace; in that case, the current keyspace will be described.

    DESCRIBE TABLES

    Output the names of all tables in the current keyspace, or in all
    keyspaces if there is no current keyspace.

    DESCRIBE TABLE <tablename>

    Output CQL commands that could be used to recreate the given table.
    In some cases, as above, there may be table metadata which is not
    representable and which will not be shown.

    DESCRIBE INDEX <indexname>

    Output CQL commands that could be used to recreate the given index.
    In some cases, there may be index metadata which is not representable
    and which will not be shown.

    DESCRIBE CLUSTER

    Output information about the connected Cassandra cluster, such as the
    cluster name, and the partitioner and snitch in use. When you are
    connected to a non-system keyspace, also shows endpoint-range
    ownership information for the Cassandra ring.

    DESCRIBE [FULL] SCHEMA

    Output CQL commands that could be used to recreate the entire (non-system) schema.
    Works as though "DESCRIBE KEYSPACE k" was invoked for each non-system keyspace
    k. Use DESCRIBE FULL SCHEMA to include the system keyspaces.

    DESCRIBE FUNCTIONS <keyspace>

    Output the names of all user defined functions in the given keyspace.

    DESCRIBE FUNCTION [<keyspace>.]<function>

    Describe the given user defined function.

    DESCRIBE AGGREGATES <keyspace>

    Output the names of all user defined aggregates in the given keyspace.

    DESCRIBE AGGREGATE [<keyspace>.]<aggregate>

    Describe the given user defined aggregate.

    DESCRIBE <objname>

    Output CQL commands that could be used to recreate the entire object schema,
    where object can be either a keyspace or a table or an index (in this order).
    """
    # The docstring above is the user-facing HELP text; keep it verbatim.
    # The keyword after DESCRIBE selects the sub-command; any other word is
    # treated as an object name resolved by describe_object.
    what = parsed.matched[1][1].lower()

    def binding(name, *default):
        # Unquoted value of a parse binding (optional default as in get_binding).
        return self.cql_unprotect_name(parsed.get_binding(name, *default))

    if what == 'functions':
        self.describe_functions(binding('ksname', None))
    elif what == 'function':
        ksname = binding('ksname', None)
        self.describe_function(ksname, binding('udfname'))
    elif what == 'aggregates':
        self.describe_aggregates(binding('ksname', None))
    elif what == 'aggregate':
        ksname = binding('ksname', None)
        self.describe_aggregate(ksname, binding('udaname'))
    elif what == 'keyspaces':
        self.describe_keyspaces()
    elif what == 'keyspace':
        ksname = binding('ksname', '') or self.current_keyspace
        if ksname is None:
            self.printerr('Not in any keyspace.')
            return
        self.describe_keyspace(ksname)
    elif what in ('columnfamily', 'table'):
        ks = binding('ksname', None)
        self.describe_columnfamily(ks, binding('cfname'))
    elif what == 'index':
        ks = binding('ksname', None)
        self.describe_index(ks, binding('idxname', None))
    elif what in ('columnfamilies', 'tables'):
        self.describe_columnfamilies(self.current_keyspace)
    elif what == 'types':
        self.describe_usertypes(self.current_keyspace)
    elif what == 'type':
        ks = binding('ksname', None)
        self.describe_usertype(ks, binding('utname'))
    elif what == 'cluster':
        self.describe_cluster()
    elif what == 'schema':
        self.describe_schema(False)
    elif what == 'full' and parsed.matched[2][1].lower() == 'schema':
        self.describe_schema(True)
    elif what:
        ks = binding('ksname', None)
        name = binding('cfname') or binding('idxname', None)
        self.describe_object(ks, name)
do_desc = do_describe
+
def do_copy(self, parsed):
    r"""
    COPY [cqlsh only]

    COPY x FROM: Imports CSV data into a Cassandra table
    COPY x TO: Exports data from a Cassandra table in CSV format.

    COPY <table_name> [ ( column [, ...] ) ]
    FROM ( '<filename>' | STDIN )
    [ WITH <option>='value' [AND ...] ];

    COPY <table_name> [ ( column [, ...] ) ]
    TO ( '<filename>' | STDOUT )
    [ WITH <option>='value' [AND ...] ];

    Available options and defaults:

    DELIMITER=',' - character that appears between records
    QUOTE='"' - quoting character to be used to quote fields
    ESCAPE='\' - character to appear before the QUOTE char when quoted
    HEADER=false - whether to ignore the first line
    NULL='' - string that represents a null value
    ENCODING='utf8' - encoding for CSV output (COPY TO only)
    TIMEFORMAT= - timestamp strftime format (COPY TO only)
    '%Y-%m-%d %H:%M:%S%z' defaults to time_format value in cqlshrc

    When entering CSV data on STDIN, you can use the sequence "\."
    on a line by itself to end the data input.
    """
    # The docstring above is the user-facing HELP text; keep it verbatim.
    # Resolves keyspace/table/columns/file/options from the parse, runs the
    # import or export and prints how many rows were moved and how long it
    # took.  Raises NoKeyspaceError when no keyspace is given or in use.
    ks = self.cql_unprotect_name(parsed.get_binding('ksname', None))
    if ks is None:
        ks = self.current_keyspace
        if ks is None:
            raise NoKeyspaceError("Not in any keyspace.")
    cf = self.cql_unprotect_name(parsed.get_binding('cfname'))
    columns = parsed.get_binding('colnames', None)
    if columns is not None:
        columns = [self.cql_unprotect_name(column) for column in columns]
    else:
        # default to all known columns
        columns = self.get_column_names(ks, cf)
    fname = parsed.get_binding('fname', None)
    if fname is not None:
        fname = os.path.expanduser(self.cql_unprotect_value(fname))
    copyoptnames = [str.lower(optname) for optname in parsed.get_binding('optnames', ())]
    copyoptvals = [self.cql_unprotect_value(optval) for optval in parsed.get_binding('optvals', ())]
    # Interpret backslash escapes in option values (e.g. DELIMITER='\t').
    cleancopyoptvals = [optval.decode('string-escape') for optval in copyoptvals]
    opts = dict(zip(copyoptnames, cleancopyoptvals))

    timestart = time.time()

    direction = parsed.get_binding('dir').upper()
    if direction == 'FROM':
        rows = self.perform_csv_import(ks, cf, columns, fname, opts)
        verb = 'imported'
    elif direction == 'TO':
        rows = self.perform_csv_export(ks, cf, columns, fname, opts)
        verb = 'exported'
    else:
        raise SyntaxError("Unknown direction %s" % direction)

    timeend = time.time()
    # The import helper can report None when no record was read (empty
    # input or an early failure); "%d" would raise TypeError on None.
    if rows is None:
        rows = 0
    print("\n%d rows %s in %s." % (rows, verb, describe_interval(timeend - timestart)))
+
def perform_csv_import(self, ks, cf, columns, fname, opts):
    r"""Import CSV rows into ``ks.cf`` using ImportProcess children.

    Rows are read from ``fname``, or from stdin until a line containing only
    ``\.`` when ``fname`` is None.  They are dealt round-robin to the
    children over multiprocessing pipes.  The children are polled for errors
    every 100 records, and the first reported error aborts the import.

    Returns the number of the last record read, or 0 when the options or
    file were rejected or no record was read at all.
    """
    csv_options, dialect_options, unrecognized_options = copy.parse_options(self, opts)
    if unrecognized_options:
        self.printerr('Unrecognized COPY FROM options: %s'
                      % ', '.join(unrecognized_options.keys()))
        return 0
    nullval, header = csv_options['nullval'], csv_options['header']

    if fname is None:
        do_close = False
        print(r"[Use \. on a line by itself to end input]")
        linesource = self.use_stdin_reader(prompt='[copy] ', until=r'\.')
    else:
        do_close = True
        try:
            linesource = open(fname, 'rb')
        except IOError as e:
            self.printerr("Can't open %r for reading: %s" % (fname, e))
            return 0

    current_record = None
    processes, pipes = [], []
    try:
        if header:
            # skip the header line
            next(linesource)
        reader = csv.reader(linesource, **dialect_options)

        num_processes = copy.get_num_processes(cap=4)

        for _ in range(num_processes):
            parent_conn, child_conn = mp.Pipe()
            pipes.append(parent_conn)
            processes.append(ImportProcess(self, child_conn, ks, cf, columns, nullval))

        for process in processes:
            process.start()

        meter = copy.RateMeter(10000)
        for current_record, row in enumerate(reader, start=1):
            # write to the child process
            pipes[current_record % num_processes].send((current_record, row))

            # update the progress and current rate periodically
            meter.increment()

            # check for any errors reported by the children; stop on the first
            if current_record % 100 == 0 and not self._check_import_processes(current_record, pipes):
                break
    except Exception as exc:
        if current_record is None:
            # we failed before we started
            self.printerr("\nError starting import process:\n")
            self.printerr(str(exc))
            if self.debug:
                traceback.print_exc()
        else:
            self.printerr("\n" + str(exc))
            self.printerr("\nAborting import at record #%d. "
                          "Previously inserted records and some records after "
                          "this number may be present."
                          % (current_record,))
            if self.debug:
                traceback.print_exc()
    finally:
        # send a message that indicates we're done
        for pipe in pipes:
            pipe.send((None, None))

        for process in processes:
            process.join()

        # Pick up errors reported after the last periodic check.  Never pass
        # None: it would be formatted with "%d".
        self._check_import_processes(current_record or 0, pipes)

        for pipe in pipes:
            pipe.close()

        if do_close:
            linesource.close()
        elif self.tty:
            print('')

    # current_record is still None when nothing was read (empty input or an
    # early failure); do_copy formats the count with "%d", so report 0.
    return current_record or 0
+
- def _check_child_pipes(self, current_record, pipes):
- # check the pipes for errors from child processes
++ def _check_import_processes(self, current_record, pipes):
++ # Non-blocking check of the import child processes' pipes. Any message
++ # found on a pipe is treated as an (record_num, error) failure report:
++ # it is printed together with an abort notice and False is returned.
++ # Returns True only when no pipe had anything to read, i.e. the import
++ # may continue.
+ for pipe in pipes:
+ if pipe.poll():
+ try:
+ (record_num, error) = pipe.recv()
+ self.printerr("\n" + str(error))
+ self.printerr(
+ "Aborting import at record #%d. "
+ "Previously inserted records are still present, "
+ "and some records after that may be present as well."
+ % (record_num,))
+ return False
+ except EOFError:
+ # pipe is closed, nothing to read
++ # No report was received, so the failing record is unknown; the
++ # parent's current_record is used for the message instead.
+ self.printerr("\nChild process died without notification, "
+ "aborting import at record #%d. Previously "
+ "inserted records are probably still present, "
+ "and some records after that may be present "
+ "as well." % (current_record,))
+ return False
+ return True
+
+ def perform_csv_export(self, ks, cf, columns, fname, opts):
++ # COPY TO entry point: options are parsed by copy.parse_options and the
++ # export itself is delegated to copy.ExportTask. The removed lines below
++ # are the former in-process implementation this change replaces.
- dialect_options = self.csv_dialect_defaults.copy()
-
- if 'quote' in opts:
- dialect_options['quotechar'] = opts.pop('quote')
- if 'escape' in opts:
- dialect_options['escapechar'] = opts.pop('escape')
- if 'delimiter' in opts:
- dialect_options['delimiter'] = opts.pop('delimiter')
- encoding = opts.pop('encoding', 'utf8')
- nullval = opts.pop('null', '')
- header = bool(opts.pop('header', '').lower() == 'true')
- timestamp_format = opts.pop('timeformat', self.display_timestamp_format)
- if dialect_options['quotechar'] == dialect_options['escapechar']:
- dialect_options['doublequote'] = True
- del dialect_options['escapechar']
-
- if opts:
- self.printerr('Unrecognized COPY TO options: %s'
- % ', '.join(opts.keys()))
++ csv_options, dialect_options, unrecognized_options = copy.parse_options(self, opts)
++ # Any unrecognized option aborts the export before anything is written.
++ if unrecognized_options:
++ self.printerr('Unrecognized COPY TO options: %s' % ', '.join(unrecognized_options.keys()))
+ return 0
+
- if fname is None:
- do_close = False
- csvdest = sys.stdout
- else:
- do_close = True
- try:
- csvdest = open(fname, 'wb')
- except IOError, e:
- self.printerr("Can't open %r for writing: %s" % (fname, e))
- return 0
-
- meter = RateMeter(10000)
- try:
- dtformats = DateTimeFormat(timestamp_format, self.display_date_format, self.display_nanotime_format)
- dump = self.prep_export_dump(ks, cf, columns)
- writer = csv.writer(csvdest, **dialect_options)
- if header:
- writer.writerow(columns)
- for row in dump:
- fmt = lambda v: \
- format_value(v, output_encoding=encoding, nullval=nullval,
- date_time_format=dtformats,
- float_precision=self.display_float_precision).strval
- writer.writerow(map(fmt, row.values()))
- meter.increment()
- finally:
- if do_close:
- csvdest.close()
- return meter.current_record
-
- def prep_export_dump(self, ks, cf, columns):
- if columns is None:
- columns = self.get_column_names(ks, cf)
- columnlist = ', '.join(protect_names(columns))
- query = 'SELECT %s FROM %s.%s' % (columnlist, protect_name(ks), protect_name(cf))
- return self.session.execute(query)
++ # NOTE(review): returns whatever ExportTask.run() returns; the old code
++ # returned the exported row count -- confirm run() keeps that contract.
++ return copy.ExportTask(self, ks, cf, columns, fname, csv_options, dialect_options,
++ DEFAULT_PROTOCOL_VERSION, CONFIG_FILE).run()
+
+ def do_show(self, parsed):
+ """
+ SHOW [cqlsh only]
+
+ Displays information about the current cqlsh session. Can be called in
+ the following ways:
+
+ SHOW VERSION
+
+ Shows the version and build of the connected Cassandra instance, as
+ well as the versions of the CQL spec and the Thrift protocol that
+ the connected Cassandra instance understands.
+
+ SHOW HOST
+
+ Shows where cqlsh is currently connected.
+
+ SHOW SESSION <sessionid>
+
+ Pretty-prints the requested tracing session.
+ """
++ # NOTE(review): the docstring above reads as user-facing help text for
++ # SHOW, so reword it with care.
++ # Dispatch on the lower-cased 'what' binding; values other than
++ # version / host / session* fall through to the error message.
+ showwhat = parsed.get_binding('what').lower()
+ if showwhat == 'version':
++ # versions are fetched from the connection before being printed
+ self.get_connection_versions()
+ self.show_version()
+ elif showwhat == 'host':
+ self.show_host()
+ elif showwhat.startswith('session'):
++ # the session id is lower-cased and parsed with UUID(); a parse
++ # failure is not caught here
+ session_id = parsed.get_binding('sessionid').lower()
+ self.show_session(UUID(session_id))
+ else:
+ self.printerr('Wait, how do I show %r?' % (showwhat,))
+
+ def do_source(self, parsed):
+ """
+ SOURCE [cqlsh only]
+
+ Executes a file containing CQL statements. Gives the output for each
+ statement in turn, if any, or any errors that occur along the way.
+
+ Errors do NOT abort execution of the CQL source file.
+
+ Usage:
+
+ SOURCE '<file>';
+
+ That is, the path to the file to be executed must be given inside a
+ string literal. The path is interpreted relative to the current working
+ directory. The tilde shorthand notation ('~/mydir') is supported for
+ referring to $HOME.
+
+ See also the --file option to cqlsh.
+ """
+ fname = parsed.get_binding('fname')
+ fname = os.path.expanduser(self.cql_unprotect_value(fname))
+ try:
+ encoding, bom_size = get_file_encoding_bomsize(fname)
+ f = codecs.open(fname, 'r', encoding)
+ f.seek(bom_size)
+ except IOError, e:
+ self.printerr('Could not open %r: %s' % (fname, e))
+ return
+ subshell = Shell(self.hostname, self.port,
+ color=self.color, encoding=self.encoding, stdin=f,
+ tty=False, use_conn=self.conn, cqlver=self.cql_version,
+ display_timestamp_format=self.display_timestamp_format,
+ display_date_format=self.display_date_format,
+ display_nanotime_format=self.display_nanotime_format,
+ display_float_precision=self.display_float_precision,
+ max_trace_wait=self.max_trace_wait)
+ subshell.cmdloop()
+ f.close()
+
+ def do_capture(self, parsed):
+ """
+ CAPTURE [cqlsh only]
+
+ Begins capturing command output and appending it to a specified file.
+ Output will not be shown at the console while it is captured.
+
+ Usage:
+
+ CAPTURE '<file>';
+ CAPTURE OFF;
+ CAPTURE;
+
+ That is, the path to the file to be appended to must be given inside a
+ string literal. The path is interpreted relative to the current working
+ directory. The tilde shorthand notation ('~/mydir') is supported for
+ referring to $HOME.
+
+ Only query result output is captured. Errors and output from cqlsh-only
+ commands will still be shown in the cqlsh session.
+
+ To stop capturing output and show it in the cqlsh session again, use
+ CAPTURE OFF.
+
+ To inspect the current capture configuration, use CAPTURE with no
+ arguments.
+ """
++ # State: shunted_query_out is None when not capturing; while capturing it
++ # holds the original query_out, and shunted_color the original color flag.
+ fname = parsed.get_binding('fname')
++ # No argument: only report the current capture state.
+ if fname is None:
+ if self.shunted_query_out is not None:
+ print "Currently capturing query output to %r." % (self.query_out.name,)
+ else:
+ print "Currently not capturing query output."
+ return
+
++ # CAPTURE OFF: close the capture file and restore the saved output/color.
+ if fname.upper() == 'OFF':
+ if self.shunted_query_out is None:
+ self.printerr('Not currently capturing output.')
+ return
+ self.query_out.close()
+ self.query_out = self.shunted_query_out
+ self.color = self.shunted_color
+ self.shunted_query_out = None
+ del self.shunted_color
+ return
+
++ # Only one capture at a time.
+ if self.shunted_query_out is not None:
+ self.printerr('Already capturing output to %s. Use CAPTURE OFF'
+ ' to disable.' % (self.query_out.name,))
+ return
+
++ # Start capturing: open in append mode, stash the current output and color
++ # setting, and disable color so no color codes go into the file.
+ fname = os.path.expanduser(self.cql_unprotect_value(fname))
+ try:
+ f = open(fname, 'a')
+ except IOError, e:
+ self.printerr('Could not open %r for append: %s' % (fname, e))
+ return
+ self.shunted_query_out = self.query_out
+ self.shunted_color = self.color
+ self.query_out = f
+ self.color = False
+ print 'Now capturing query output to %r.' % (fname,)
+
+ def do_tracing(self, parsed):
+ """
+ TRACING [cqlsh]
+
+ Enables or disables request tracing.
+
+ TRACING ON
+
+ Enables tracing for all further requests.
+
+ TRACING OFF
+
+ Disables tracing.
+
+ TRACING
+
+ TRACING with no arguments shows the current tracing status.
+ """
++ # SwitchCommand interprets ON / OFF / no argument and its result becomes the
++ # new tracing flag; self.printerr is passed in for its error output.
++ # NOTE(review): presumably the flag is returned unchanged for the
++ # no-argument form -- confirm in SwitchCommand.
+ self.tracing_enabled = SwitchCommand("TRACING", "Tracing").execute(self.tracing_enabled, parsed, self.printerr)
+
+ def do_expand(self, parsed):
+ """
+ EXPAND [cqlsh]
+
+ Enables or disables expanded (vertical) output.
+
+ EXPAND ON
+
+ Enables expanded (vertical) output.
+
+ EXPAND OFF
+
+ Disables expanded (vertical) output.
+
+ EXPAND
+
+ EXPAND with no arguments shows the current value of expand setting.
+ """
++ # Same ON / OFF / status pattern as TRACING, applied to the expanded-output
++ # flag; the value SwitchCommand returns becomes the new setting.
+ self.expand_enabled = SwitchCommand("EXPAND", "Expanded output").execute(self.expand_enabled, parsed, self.printerr)
+
+ def do_consistency(self, parsed):
+ """
+ CONSISTENCY [cqlsh only]
+
+ Overrides default consistency level (default level is ONE).
+
+ CONSISTENCY <level>
+
+ Sets consistency level for future requests.
+
+ Valid consistency levels:
+
+ ANY, ONE, TWO, THREE, QUORUM, ALL, LOCAL_ONE, LOCAL_QUORUM, EACH_QUORUM, SERIAL and LOCAL_SERIAL.
+
+ SERIAL and LOCAL_SERIAL may be used only for SELECTs; will be rejected with updates.
+
+ CONSISTENCY
+
+ CONSISTENCY with no arguments shows the current consistency level.
+ """
+ level = parsed.get_binding('level')
++ # No argument: print the current level, mapped back to its name.
+ if level is None:
+ print 'Current consistency level is %s.' % (cassandra.ConsistencyLevel.value_to_name[self.consistency_level])
+ return
+
++ # The name lookup is case-insensitive via upper(). An unknown name would
++ # raise KeyError here.
++ # NOTE(review): presumably the CQL grammar only admits valid level names.
+ self.consistency_level = cassandra.ConsistencyLevel.name_to_value[level.upper()]
+ print 'Consistency level set to %s.' % (level.upper(),)
+
+ def do_serial(self, parsed):
+ """
+ SERIAL CONSISTENCY [cqlsh only]
+
+ Overrides serial consistency level (default level is SERIAL).
+
+ SERIAL CONSISTENCY <level>
+
+ Sets consistency level for future conditional updates.
+
+ Valid consistency levels:
+
+ SERIAL, LOCAL_SERIAL.
+
+ SERIAL CONSISTENCY
+
+ SERIAL CONSISTENCY with no arguments shows the current consistency level.
+ """
+ level = parsed.get_binding('level')
++ # No argument: print the current serial level by name.
+ if level is None:
+ print 'Current serial consistency level is %s.' % (cassandra.ConsistencyLevel.value_to_name[self.serial_consistency_level])
+ return
+
++ # Same name-to-value mapping as CONSISTENCY, but stored separately as the
++ # serial level. Nothing here restricts the value to SERIAL / LOCAL_SERIAL.
++ # NOTE(review): presumably the grammar enforces that restriction.
+ self.serial_consistency_level = cassandra.ConsistencyLevel.name_to_value[level.upper()]
+ print 'Serial consistency level set to %s.' % (level.upper(),)
+
+ def do_login(self, parsed):
+ """
+ LOGIN [cqlsh only]
+
+ Changes login information without requiring restart.
+
+ LOGIN <username> (<password>)
+
+ Login using the specified username. If password is specified, it will be used
+ otherwise, you will be prompted to enter.
+ """
+ username = parsed.get_binding('username')
+ password = parsed.get_binding('password')
+ if password is None:
+ password = getpass.getpass()
+ else:
+ password = password[1:-1]
+
+ auth_provider = PlainTextAuthProvider(username=username, password=password)
+
+ conn = Cluster(contact_points=(self.hostname,), port=self.port, cql_version=self.conn.cql_version,
+
<TRUNCATED>