You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted on 2016/03/08 10:26:41 UTC (the mailing-list and sender names were not preserved in this copy)

[02/23] cassandra git commit: COPY FROM on large datasets: fix progress report and debug performance

COPY FROM on large datasets: fix progress report and debug performance

patch by Stefania Alborghetti; reviewed by Adam Holmberg for CASSANDRA-11053


Branch: refs/heads/cassandra-2.2
Commit: c3d2f26f46c2d37b6cf918cbb5565fe57a5904cc
Parents: 0129f70
Author: Stefania Alborghetti <>
Authored: Thu Jan 28 14:31:55 2016 +0800
Committer: Sylvain Lebresne <>
Committed: Tue Mar 8 10:19:13 2016 +0100

 CHANGES.txt                |    1 +
 bin/cqlsh                  |   28 +-
 pylib/cqlshlib/ | 1160 +++++++++++++++++++++++++--------------
 pylib/cqlshlib/     |   35 +-
 pylib/             |    2 +
 5 files changed, 809 insertions(+), 417 deletions(-)
diff --git a/CHANGES.txt b/CHANGES.txt
index eed9035..d6b085c 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1,4 +1,5 @@
+ * COPY FROM on large datasets: fix progress report and debug performance (CASSANDRA-11053)
  * InvalidateKeys should have a weak ref to key cache (CASSANDRA-11176)
  * Don't remove FailureDetector history on removeEndpoint (CASSANDRA-10371)
  * Only notify if repair status changed (CASSANDRA-11172)
diff --git a/bin/cqlsh b/bin/cqlsh
index 7a39636..374e588 100755
--- a/bin/cqlsh
+++ b/bin/cqlsh
                        'NUMPROCESSES', 'CONFIGFILE', 'RATEFILE']
@@ -533,8 +533,23 @@ def insert_driver_hooks():
 def extend_cql_deserialization():
-    The python driver returns BLOBs as string, but we expect them as bytearrays
+    The python driver returns BLOBs as string, but we expect them as bytearrays; therefore we change
+    the implementation of cassandra.cqltypes.BytesType.deserialize.
+    The deserializers package exists only when the driver has been compiled with cython extensions and
+    cassandra.deserializers.DesBytesType replaces cassandra.cqltypes.BytesType.deserialize.
+    DesBytesTypeByteArray is a fast deserializer that converts blobs into bytearrays but it was
+    only introduced recently (3.1.0). If it is available we use it, otherwise we remove
+    cassandra.deserializers.DesBytesType so that we fall back onto cassandra.cqltypes.BytesType.deserialize
+    just like in the case where no cython extensions are present.
+    if hasattr(cassandra, 'deserializers'):
+        if hasattr(cassandra.deserializers, 'DesBytesTypeByteArray'):
+            cassandra.deserializers.DesBytesType = cassandra.deserializers.DesBytesTypeByteArray
+        else:
+            del cassandra.deserializers.DesBytesType
     cassandra.cqltypes.BytesType.deserialize = staticmethod(lambda byts, protocol_version: bytearray(byts))
     cassandra.cqltypes.CassandraType.support_empty_values = True
@@ -1534,9 +1549,9 @@ class Shell(cmd.Cmd):
         Available COPY FROM options and defaults:
-          CHUNKSIZE=1000          - the size of chunks passed to worker processes
+          CHUNKSIZE=5000          - the size of chunks passed to worker processes
           INGESTRATE=100000       - an approximate ingest rate in rows per second
-          MINBATCHSIZE=2          - the minimum size of an import batch
+          MINBATCHSIZE=10         - the minimum size of an import batch
           MAXBATCHSIZE=20         - the maximum size of an import batch
           MAXROWS=-1              - the maximum number of rows, -1 means no maximum
           SKIPROWS=0              - the number of rows to skip
@@ -1545,6 +1560,11 @@ class Shell(cmd.Cmd):
           MAXINSERTERRORS=-1      - the maximum global number of insert errors, -1 means no maximum
           ERRFILE=''              - a file where to store all rows that could not be imported, by default this is
                                     import_ks_table.err where <ks> is your keyspace and <table> is your table name.
+          PREPAREDSTATEMENTS=True - whether to use prepared statements when importing, by default True. Set this to
+                                    False if you don't mind shifting data parsing to the cluster. The cluster will also
+                                    have to compile every batch statement. For large and oversized clusters
+                                    this will result in a faster import but for smaller clusters it may generate
+                                    timeouts.
         Available COPY TO options and defaults:
diff --git a/pylib/cqlshlib/ b/pylib/cqlshlib/
index f9e4a85..cd03765 100644
--- a/pylib/cqlshlib/
+++ b/pylib/cqlshlib/
@@ -1,3 +1,5 @@
+# cython: profile=True
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -21,25 +23,29 @@ import json
 import glob
 import multiprocessing as mp
 import os
-import Queue
+import platform
+import random
 import re
 import struct
 import sys
 import time
 import traceback
+from bisect import bisect_right
 from calendar import timegm
-from collections import defaultdict, deque, namedtuple
+from collections import defaultdict, namedtuple
 from decimal import Decimal
 from random import randrange
 from StringIO import StringIO
+from select import select
 from threading import Lock
 from uuid import UUID
+from util import profile_on, profile_off
 from cassandra.cluster import Cluster
 from cassandra.cqltypes import ReversedType, UserType
-from cassandra.metadata import protect_name, protect_names
-from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, TokenAwarePolicy, DCAwareRoundRobinPolicy
+from cassandra.metadata import protect_name, protect_names, protect_value
+from cassandra.policies import RetryPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy
 from cassandra.query import BatchStatement, BatchType, SimpleStatement, tuple_factory
 from cassandra.util import Date, Time
@@ -48,6 +54,10 @@ from displaying import NO_COLOR_MAP
 from formatting import format_value_default, EMPTY, get_formatter
 from sslhandling import ssl_settings
+STRACE_ON = False
+IS_LINUX = platform.system() == 'Linux'
 CopyOptions = namedtuple('CopyOptions', 'copy dialect unrecognized')
@@ -59,6 +69,81 @@ def safe_normpath(fname):
     return os.path.normpath(os.path.expanduser(fname)) if fname else fname
+class OneWayChannel(object):
+    """
+    A one way pipe protected by two process level locks, one for reading and one for writing.
+    """
+    def __init__(self):
+        self.reader, self.writer = mp.Pipe(duplex=False)
+        self.rlock = mp.Lock()
+        self.wlock = mp.Lock()
+    def send(self, obj):
+        with self.wlock:
+            self.writer.send(obj)
+    def recv(self):
+        with self.rlock:
+            return self.reader.recv()
+    def close(self):
+        self.reader.close()
+        self.writer.close()
+class OneWayChannels(object):
+    """
+    A group of one way channels.
+    """
+    def __init__(self, num_channels):
+        self.channels = [OneWayChannel() for _ in xrange(num_channels)]
+        self._readers = [ch.reader for ch in self.channels]
+        self._rlocks = [ch.rlock for ch in self.channels]
+        self._rlocks_by_readers = dict([(ch.reader, ch.rlock) for ch in self.channels])
+        self.num_channels = num_channels
+        self.recv = self.recv_select if IS_LINUX else self.recv_polling
+    def recv_select(self, timeout):
+        """
+        Implementation of the recv method for Linux, where select is available. Receive an object from
+        all pipes that are ready for reading without blocking.
+        """
+        readable, _, _ = select(self._readers, [], [], timeout)
+        for r in readable:
+            with self._rlocks_by_readers[r]:
+                try:
+                    yield r.recv()
+                except EOFError:
+                    continue
+    def recv_polling(self, timeout):
+        """
+        Implementation of the recv method for platforms where select() is not available for pipes.
+        We poll on all of the readers with a very small timeout. We stop when the timeout specified
+        has been received but we may exceed it since we check all processes during each sweep.
+        """
+        start = time.time()
+        while True:
+            for i, r in enumerate(self._readers):
+                with self._rlocks[i]:
+                    if r.poll(0.000000001):
+                        try:
+                            yield r.recv()
+                        except EOFError:
+                            continue
+            if time.time() - start > timeout:
+                break
+    def close(self):
+        for ch in self.channels:
+            try:
+                ch.close()
+            except:
+                pass
 class CopyTask(object):
     A base class for ImportTask and ExportTask
@@ -72,15 +157,18 @@ class CopyTask(object):
         self.protocol_version = protocol_version
         self.config_file = config_file
         # do not display messages when exporting to STDOUT
-        self.printmsg = self._printmsg if self.fname is not None or direction == 'in' else lambda _, eol='\n': None
+        self.printmsg = self._printmsg if self.fname is not None or direction == 'from' else lambda _, eol='\n': None
         self.options = self.parse_options(opts, direction)
         self.num_processes = self.options.copy['numprocesses']
+        if direction == 'in':
+            self.num_processes += 1  # add the feeder process
         self.printmsg('Using %d child processes' % (self.num_processes,))
         self.processes = []
-        self.inmsg = mp.Queue()
-        self.outmsg = mp.Queue()
+        self.inmsg = OneWayChannels(self.num_processes)
+        self.outmsg = OneWayChannels(self.num_processes)
         self.columns = CopyTask.get_columns(shell, ks, table, columns)
         self.time_start = time.time()
@@ -166,10 +254,10 @@ class CopyTask(object):
         copy_options['maxattempts'] = int(opts.pop('maxattempts', 5))
         copy_options['dtformats'] = opts.pop('datetimeformat', shell.display_time_format)
         copy_options['float_precision'] = shell.display_float_precision
-        copy_options['chunksize'] = int(opts.pop('chunksize', 1000))
+        copy_options['chunksize'] = int(opts.pop('chunksize', 5000))
         copy_options['ingestrate'] = int(opts.pop('ingestrate', 100000))
         copy_options['maxbatchsize'] = int(opts.pop('maxbatchsize', 20))
-        copy_options['minbatchsize'] = int(opts.pop('minbatchsize', 2))
+        copy_options['minbatchsize'] = int(opts.pop('minbatchsize', 10))
         copy_options['reportfrequency'] = float(opts.pop('reportfrequency', 0.25))
         copy_options['consistencylevel'] = shell.consistency_level
         copy_options['decimalsep'] = opts.pop('decimalsep', '.')
@@ -186,6 +274,7 @@ class CopyTask(object):
         copy_options['errfile'] = safe_normpath(opts.pop('errfile', 'import_%s_%s.err' % (self.ks, self.table,)))
         copy_options['ratefile'] = safe_normpath(opts.pop('ratefile', ''))
         copy_options['maxoutputsize'] = int(opts.pop('maxoutputsize', '-1'))
+        copy_options['preparedstatements'] = bool(opts.pop('preparedstatements', 'true').lower() == 'true')
         return CopyOptions(copy=copy_options, dialect=dialect_options, unrecognized=opts)
@@ -206,14 +295,17 @@ class CopyTask(object):
     def get_num_processes(cap):
         Pick a reasonable number of child processes. We need to leave at
-        least one core for the parent process.  This doesn't necessarily
-        need to be capped, but 4 is currently enough to keep
-        a single local Cassandra node busy so we use this for import, whilst
-        for export we use 16 since we can connect to multiple Cassandra nodes.
-        Eventually this parameter will become an option.
+        least one core for the parent process.
+        """
+        return max(1, min(cap, CopyTask.get_num_cores() - 1))
+    @staticmethod
+    def get_num_cores():
+        """
+        Return the number of cores if available.
-            return max(1, min(cap, mp.cpu_count() - 1))
+            return mp.cpu_count()
         except NotImplementedError:
             return 1
@@ -244,28 +336,40 @@ class CopyTask(object):
         return shell.get_column_names(ks, table) if not columns else columns
     def close(self):
-        for process in self.processes:
-            process.terminate()
+        self.stop_processes()
     def num_live_processes(self):
         return sum(1 for p in self.processes if p.is_alive())
+    @staticmethod
+    def get_pid():
+        return os.getpid() if hasattr(os, 'getpid') else None
+    @staticmethod
+    def trace_process(pid):
+        if pid and STRACE_ON:
+            os.system("strace -vvvv -c -o strace.{pid}.out -e trace=all -p {pid}&".format(pid=pid))
+    def start_processes(self):
+        for i, process in enumerate(self.processes):
+            process.start()
+            self.trace_process(
+        self.trace_process(self.get_pid())
+    def stop_processes(self):
+        for process in self.processes:
+            process.terminate()
     def make_params(self):
         Return a dictionary of parameters to be used by the worker processes.
         On Windows this dictionary must be pickle-able.
-        inmsg is the message queue flowing from parent to child process, so outmsg from the parent point
-        of view and, vice-versa,  outmsg is the message queue flowing from child to parent, so inmsg
-        from the parent point of view, hence the two are swapped below.
         shell =
-        return dict(inmsg=self.outmsg,  # see comment above
-                    outmsg=self.inmsg,  # see comment above
-                    ks=self.ks,
+        return dict(ks=self.ks,
@@ -281,6 +385,17 @@ class CopyTask(object):
+    def update_params(self, params, i):
+        """
+        Add the communication channels to the parameters to be passed to the worker process:
+            inmsg is the message queue flowing from parent to child process, so outmsg from the parent point
+            of view and, vice-versa,  outmsg is the message queue flowing from child to parent, so inmsg
+            from the parent point of view, hence the two are swapped below.
+        """
+        params['inmsg'] = self.outmsg.channels[i]
+        params['outmsg'] = self.inmsg.channels[i]
+        return params
 class ExportWriter(object):
@@ -414,10 +529,9 @@ class ExportTask(CopyTask):
         params = self.make_params()
         for i in xrange(self.num_processes):
-            self.processes.append(ExportProcess(params))
+            self.processes.append(ExportProcess(self.update_params(params, i)))
-        for process in self.processes:
-            process.start()
+        self.start_processes()
@@ -468,11 +582,12 @@ class ExportTask(CopyTask):
             return ret
-        def make_range_data(replicas=[]):
+        def make_range_data(replicas=None):
             hosts = []
-            for r in replicas:
-                if r.is_up and r.datacenter == local_dc:
-                    hosts.append(r.address)
+            if replicas:
+                for r in replicas:
+                    if r.is_up and r.datacenter == local_dc:
+                        hosts.append(r.address)
             if not hosts:
                 hosts.append(hostname)  # fallback to default host if no replicas in current dc
             return {'hosts': tuple(hosts), 'attempts': 0, 'rows': 0}
@@ -542,10 +657,13 @@ class ExportTask(CopyTask):
             return None
     def send_work(self, ranges, tokens_to_send):
+        i = 0
         for token_range in tokens_to_send:
-            self.outmsg.put((token_range, ranges[token_range]))
+            self.outmsg.channels[i].send((token_range, ranges[token_range]))
             ranges[token_range]['attempts'] += 1
+            i = i + 1 if i < self.num_processes - 1 else 0
     def export_records(self, ranges):
         Send records to child processes and monitor them by collecting their results
@@ -568,8 +686,7 @@ class ExportTask(CopyTask):
         succeeded = 0
         failed = 0
         while (failed + succeeded) < total_requests and self.num_live_processes() == num_processes:
-            try:
-                token_range, result = self.inmsg.get(timeout=1.0)
+            for token_range, result in self.inmsg.recv(timeout=0.1):
                 if token_range is None and result is None:  # a request has finished
                     succeeded += 1
                 elif isinstance(result, Exception):  # an error occurred
@@ -594,8 +711,6 @@ class ExportTask(CopyTask):
                     self.writer.write(data, num)
                     ranges[token_range]['rows'] += num
-            except Queue.Empty:
-                pass
         if self.num_live_processes() < len(processes):
             for process in processes:
@@ -612,7 +727,7 @@ class ExportTask(CopyTask):
                        self.describe_interval(time.time() - self.time_start)))
-class ImportReader(object):
+class FilesReader(object):
     A wrapper around a csv reader to keep track of when we have
     exhausted reading input files. We are passed a comma separated
@@ -620,18 +735,15 @@ class ImportReader(object):
     We generate a source generator and we read each source one
     by one.
-    def __init__(self, task):
- =
-        self.options = task.options
-        self.printmsg = task.printmsg
-        self.chunk_size = self.options.copy['chunksize']
-        self.header = self.options.copy['header']
-        self.max_rows = self.options.copy['maxrows']
-        self.skip_rows = self.options.copy['skiprows']
-        self.sources = self.get_source(task.fname)
+    def __init__(self, fname, options):
+        self.chunk_size = options.copy['chunksize']
+        self.header = options.copy['header']
+        self.max_rows = options.copy['maxrows']
+        self.skip_rows = options.copy['skiprows']
+        self.fname = fname
+        self.sources = None  # must be created later due to pickle problems on Windows
         self.num_sources = 0
         self.current_source = None
-        self.current_reader = None
         self.num_read = 0
     def get_source(self, paths):
@@ -640,35 +752,33 @@ class ImportReader(object):
          wrapping the source input, file name and a boolean indicating
          if it requires closing.
-        shell =
-        LineSource = namedtuple('LineSource', 'input close fname')
         def make_source(fname):
-                ret = LineSource(input=open(fname, 'rb'), close=True, fname=fname)
-                return ret
+                return open(fname, 'rb')
             except IOError, e:
-                shell.printerr("Can't open %r for reading: %s" % (fname, e))
+                self.printmsg("Can't open %r for reading: %s" % (fname, e))
                 return None
-        if paths is None:
-            self.printmsg("[Use \. on a line by itself to end input]")
-            yield LineSource(input=shell.use_stdin_reader(prompt='[copy] ', until=r'\.'), close=False, fname='')
-        else:
-            for path in paths.split(','):
-                path = path.strip()
-                if os.path.isfile(path):
-                    yield make_source(path)
-                else:
-                    for f in glob.glob(path):
-                        yield (make_source(f))
+        for path in paths.split(','):
+            path = path.strip()
+            if os.path.isfile(path):
+                yield make_source(path)
+            else:
+                for f in glob.glob(path):
+                    yield (make_source(f))
+    @staticmethod
+    def printmsg(msg, eol='\n'):
+        sys.stdout.write(msg + eol)
+        sys.stdout.flush()
     def start(self):
+        self.sources = self.get_source(self.fname)
     def exhausted(self):
-        return not self.current_reader
+        return not self.current_source
     def next_source(self):
@@ -679,40 +789,34 @@ class ImportReader(object):
         while self.current_source is None:
                 self.current_source =
-                if self.current_source and self.current_source.fname:
+                if self.current_source:
                     self.num_sources += 1
             except StopIteration:
                 return False
         if self.header:
-        self.current_reader = csv.reader(self.current_source.input, **self.options.dialect)
         return True
     def close_current_source(self):
         if not self.current_source:
-        if self.current_source.close:
-            self.current_source.input.close()
-        elif
-            print
+        self.current_source.close()
         self.current_source = None
-        self.current_reader = None
     def close(self):
     def read_rows(self, max_rows):
-        if not self.current_reader:
+        if not self.current_source:
             return []
         rows = []
         for i in xrange(min(max_rows, self.chunk_size)):
-                row =
+                row =
                 self.num_read += 1
                 if 0 <= self.max_rows < self.num_read:
@@ -729,13 +833,91 @@ class ImportReader(object):
         return filter(None, rows)
-class ImportErrors(object):
+class PipeReader(object):
-    A small class for managing import errors
+    A class for reading rows received on a pipe, this is used for reading input from STDIN
+    """
+    def __init__(self, inmsg, options):
+        self.inmsg = inmsg
+        self.chunk_size = options.copy['chunksize']
+        self.header = options.copy['header']
+        self.max_rows = options.copy['maxrows']
+        self.skip_rows = options.copy['skiprows']
+        self.num_read = 0
+        self.exhausted = False
+        self.num_sources = 1
+    def start(self):
+        pass
+    def read_rows(self, max_rows):
+        rows = []
+        for i in xrange(min(max_rows, self.chunk_size)):
+            row = self.inmsg.recv()
+            if row is None:
+                self.exhausted = True
+                break
+            self.num_read += 1
+            if 0 <= self.max_rows < self.num_read:
+                self.exhausted = True
+                break  # max rows exceeded
+            if self.header or self.num_read < self.skip_rows:
+                self.header = False  # skip header or initial skip_rows rows
+                continue
+            rows.append(row)
+        return rows
+class ImportProcessResult(object):
+    """
+    An object sent from ImportProcess instances to the parent import task in order to indicate progress.
+    """
+    def __init__(self, imported=0):
+        self.imported = imported
+class FeedingProcessResult(object):
+    """
+    An object sent from FeedingProcess instances to the parent import task in order to indicate progress.
+    """
+    def __init__(self, sent, reader):
+        self.sent = sent
+        self.num_sources = reader.num_sources
+        self.skip_rows = reader.skip_rows
+class ImportTaskError(object):
+    """
+    An object sent from child processes (feeder or workers) to the parent import task to indicate an error.
+    """
+    def __init__(self, name, msg, rows=None, attempts=1, final=True):
+ = name
+        self.msg = msg
+        self.rows = rows if rows else []
+        self.attempts = attempts
+ = final
+    def is_parse_error(self):
+        """
+        We treat read and parse errors as unrecoverable and we have different global counters for giving up when
+        a maximum has been reached. We consider value and type errors as parse errors as well since they
+        are typically non recoverable.
+        """
+        name =
+        return name.startswith('ValueError') or name.startswith('TypeError') or \
+            name.startswith('ParseError') or name.startswith('IndexError') or name.startswith('ReadError')
+class ImportErrorHandler(object):
+    """
+    A class for managing import errors
     def __init__(self, task): =
-        self.reader = task.reader
         self.options = task.options
         self.printmsg = task.printmsg
         self.max_attempts = self.options.copy['maxattempts']
@@ -771,42 +953,26 @@ class ImportErrors(object):
             for row in rows:
-    def handle_error(self, err, batch):
+    def handle_error(self, err):
         Handle an error by printing the appropriate error message and incrementing the correct counter.
-        Return true if we should retry this batch, false if the error is non-recoverable
         shell =
-        err = str(err)
-        if self.is_parse_error(err):
-            self.parse_errors += len(batch['rows'])
-            self.add_failed_rows(batch['rows'])
-            shell.printerr("Failed to import %d rows: %s -  given up without retries"
-                           % (len(batch['rows']), err))
-            return False
+        if err.is_parse_error():
+            self.parse_errors += len(err.rows)
+            self.add_failed_rows(err.rows)
+            shell.printerr("Failed to import %d rows: %s - %s,  given up without retries"
+                           % (len(err.rows),, err.msg))
-            self.insert_errors += len(batch['rows'])
-            if batch['attempts'] < self.max_attempts:
-                shell.printerr("Failed to import %d rows: %s -  will retry later, attempt %d of %d"
-                               % (len(batch['rows']), err, batch['attempts'],
-                                  self.max_attempts))
-                return True
+            self.insert_errors += len(err.rows)
+            if not
+                shell.printerr("Failed to import %d rows: %s - %s,  will retry later, attempt %d of %d"
+                               % (len(err.rows),, err.msg, err.attempts, self.max_attempts))
-                self.add_failed_rows(batch['rows'])
-                shell.printerr("Failed to import %d rows: %s -  given up after %d attempts"
-                               % (len(batch['rows']), err, batch['attempts']))
-                return False
-    @staticmethod
-    def is_parse_error(err):
-        """
-        We treat parse errors as unrecoverable and we have different global counters for giving up when
-        a maximum has been reached. We consider value and type errors as parse errors as well since they
-        are typically non recoverable.
-        """
-        return err.startswith('ValueError') or err.startswith('TypeError') or \
-            err.startswith('ParseError') or err.startswith('IndexError')
+                self.add_failed_rows(err.rows)
+                shell.printerr("Failed to import %d rows: %s - %s,  given up after %d attempts"
+                               % (len(err.rows),, err.msg, err.attempts))
 class ImportTask(CopyTask):
@@ -818,22 +984,14 @@ class ImportTask(CopyTask):
         CopyTask.__init__(self, shell, ks, table, columns, fname, opts, protocol_version, config_file, 'from')
         options = self.options
-        self.ingest_rate = options.copy['ingestrate']
-        self.max_attempts = options.copy['maxattempts']
-        self.header = options.copy['header']
         self.skip_columns = [c.strip() for c in self.options.copy['skipcols'].split(',')]
         self.valid_columns = [c for c in self.columns if c not in self.skip_columns]
         self.table_meta =, self.table)
-        self.batch_id = 0
         self.receive_meter = RateMeter(log_fcn=self.printmsg,
-        self.send_meter = RateMeter(log_fcn=None, update_interval=1)
-        self.reader = ImportReader(self)
-        self.import_errors = ImportErrors(self)
-        self.retries = deque([])
-        self.failed = 0
-        self.succeeded = 0
+        self.error_handler = ImportErrorHandler(self)
+        self.feeding_result = None
         self.sent = 0
     def make_params(self):
@@ -861,17 +1019,24 @@ class ImportTask(CopyTask):
         self.printmsg("\nStarting copy of %s.%s with columns %s." % (self.ks, self.table, self.valid_columns))
-            self.reader.start()
             params = self.make_params()
-            for i in range(self.num_processes):
-                self.processes.append(ImportProcess(params))
+            for i in range(self.num_processes - 1):
+                self.processes.append(ImportProcess(self.update_params(params, i)))
+            feeder = FeedingProcess(self.outmsg.channels[-1], self.inmsg.channels[-1],
+                                    self.outmsg.channels[:-1], self.fname, self.options)
+            self.processes.append(feeder)
-            for process in self.processes:
-                process.start()
+            self.start_processes()
+            pr = profile_on() if PROFILE_ON else None
+            if pr:
+                profile_off(pr, file_name='parent_profile_%d.txt' % (os.getpid(),))
         except Exception, exc:
             if shell.debug:
@@ -880,9 +1045,22 @@ class ImportTask(CopyTask):
-    def close(self):
-        CopyTask.close(self)
-        self.reader.close()
+    def send_stdin_rows(self):
+        """
+        We need to pass stdin rows to the feeder process as it is not safe to pickle or share stdin
+        directly (in case of file the child process would close it). This is a very primitive support
+        for STDIN import in that we we won't start reporting progress until STDIN is fully consumed. I
+        think this is reasonable.
+        """
+        shell =
+        self.printmsg("[Use \. on a line by itself to end input]")
+        for row in shell.use_stdin_reader(prompt='[copy] ', until=r'.'):
+            self.outmsg.channels[-1].send(row)
+        self.outmsg.channels[-1].send(None)
+        if shell.tty:
+            print
     def import_records(self):
@@ -890,114 +1068,137 @@ class ImportTask(CopyTask):
         Send data (batches or retries) up to the max ingest rate. If we are waiting for stuff to
         receive check the incoming queue.
-        reader = self.reader
-        while self.has_more_to_send(reader) or self.has_more_to_receive():
-            if self.has_more_to_send(reader):
-                self.send_batches(reader)
+        if not self.fname:
+            self.send_stdin_rows()
-            if self.has_more_to_receive():
-                self.receive()
+        while self.feeding_result is None or self.receive_meter.total_records < self.feeding_result.sent:
+            self.receive_results()
-            if self.import_errors.max_exceeded() or not self.all_processes_running():
+            if self.error_handler.max_exceeded() or not self.all_processes_running():
-        if self.import_errors.num_rows_failed:
+        if self.error_handler.num_rows_failed:
   "Failed to process %d rows; failed rows written to %s" %
-                                (self.import_errors.num_rows_failed,
-                                 self.import_errors.err_file))
+                                (self.error_handler.num_rows_failed,
+                                 self.error_handler.err_file))
         if not self.all_processes_running():
   "{} child process(es) died unexpectedly, aborting"
                                 .format(self.num_processes - self.num_live_processes()))
+        else:
+            # it is only safe to write to processes if they are all running because the feeder process
+            # at the moment hangs whilst sending messages to a crashed worker process; in future
+            # we could do something about this by using a BoundedSemaphore to keep track of how many messages are
+            # queued on a pipe
+            for i, _ in enumerate(self.processes):
+                self.outmsg.channels[i].send(None)
+            if PROFILE_ON:
+                # allow time for worker processes to write profile results (only works if processes received
+                # the poison pill above)
+                time.sleep(5)
         self.printmsg("\n%d rows imported from %d files in %s (%d skipped)." %
-                       self.reader.num_sources,
+                       self.feeding_result.num_sources if self.feeding_result else 0,
                        self.describe_interval(time.time() - self.time_start),
-                       self.reader.skip_rows))
-    def has_more_to_receive(self):
-        return (self.succeeded + self.failed) < self.sent
-    def has_more_to_send(self, reader):
-        return (not reader.exhausted) or self.retries
+                       self.feeding_result.skip_rows if self.feeding_result else 0))
     def all_processes_running(self):
-        return self.num_live_processes() == self.num_processes
+        return self.num_live_processes() == len(self.processes)
-    def receive(self):
-        start_time = time.time()
+    def receive_results(self):
+        """
+        Process one round of messages received from the child processes (recv timeout of 0.1).
+        ImportProcessResult messages carry the number of rows a worker imported and are summed;
+        ImportTaskError messages are handed to the error handler; the FeedingProcessResult, sent by the
+        feeder once it has finished, carries the number of rows sent and is stored in feeding_result.
+        The summed imported count is added to the receive meter even if an unexpected message aborts
+        the loop with a ValueError.
+        """
+        totals = ImportProcessResult()
+        try:
+            for msg in self.inmsg.recv(timeout=0.1):
+                if isinstance(msg, ImportProcessResult):
+                    totals.imported += msg.imported
+                    continue
+                if isinstance(msg, ImportTaskError):
+                    self.error_handler.handle_error(msg)
+                    continue
+                if not isinstance(msg, FeedingProcessResult):
+                    raise ValueError("Unexpected result: %s" % (msg,))
+                self.feeding_result = msg
+        finally:
+            self.receive_meter.increment(totals.imported)
-        while time.time() - start_time < 0.001:
-            try:
-                batch, err = self.inmsg.get(timeout=0.00001)
-                if err is None:
-                    self.succeeded += batch['imported']
-                    self.receive_meter.increment(batch['imported'])
-                else:
-                    err = str(err)
+class FeedingProcess(mp.Process):
+    """
+    A process that reads from import sources and sends chunks to worker processes.
+    """
+    def __init__(self, inmsg, outmsg, worker_channels, fname, options):
+        """
+        :param inmsg: channel carrying messages from the parent: the STDIN rows when fname is empty
+                      (terminated by None) and finally the poison pill (None) that lets this process exit
+        :param outmsg: channel carrying messages to the parent: ImportTaskError instances and the final
+                       FeedingProcessResult
+        :param worker_channels: one channel per worker process; rows are sent to them in chunks
+        :param fname: the source file name(s) handed to FilesReader; when empty, rows are read from inmsg
+        :param options: the COPY options; ingestrate and numprocesses are read from options.copy
+        """
+        mp.Process.__init__(self, target=self.run)
+        self.inmsg = inmsg
+        self.outmsg = outmsg
+        self.worker_channels = worker_channels
+        self.reader = FilesReader(fname, options) if fname else PipeReader(inmsg, options)
+        # inner_run() compares send_meter.current_record with ingest_rate to throttle sending
+        self.send_meter = RateMeter(log_fcn=None, update_interval=1)
+        self.ingest_rate = options.copy['ingestrate']
+        self.num_worker_processes = options.copy['numprocesses']
+        self.chunk_id = 0
+    def run(self):
+        pr = profile_on() if PROFILE_ON else None
-                    if self.import_errors.handle_error(err, batch):
-                        self.retries.append(self.reset_batch(batch))
-                    else:
-                        self.failed += len(batch['rows'])
+        self.inner_run()
-            except Queue.Empty:
-                pass
+        if pr:
+            profile_off(pr, file_name='feeder_profile_%d.txt' % (os.getpid(),))
-    def send_batches(self, reader):
+    def inner_run(self):
         Send one batch per worker process to the queue unless we have exceeded the ingest rate.
         In the export case we queue everything and let the worker processes throttle using max_requests,
-        here we throttle using the ingest rate in the parent process because of memory usage concerns.
-        When we have finished reading the csv file, then send any retries.
+        here we throttle using the ingest rate in the feeding process because of memory usage concerns.
+        When finished we send back to the parent process the total number of rows sent.
-        for _ in xrange(self.num_processes):
-            max_rows = self.ingest_rate - self.send_meter.current_record
-            if max_rows <= 0:
-                self.send_meter.maybe_update()
-                break
+        reader = self.reader
+        reader.start()
+        channels = self.worker_channels
+        sent = 0
+        while not reader.exhausted:
+            for ch in channels:
+                try:
+                    max_rows = self.ingest_rate - self.send_meter.current_record
+                    if max_rows <= 0:
+                        self.send_meter.maybe_update(sleep=False)
+                        continue
+                    rows = reader.read_rows(max_rows)
+                    if rows:
+                        sent += self.send_chunk(ch, rows)
+                except Exception, exc:
+                    self.outmsg.send(ImportTaskError(exc.__class__.__name__, exc.message))
+                if reader.exhausted:
+                    break
-            if not reader.exhausted:
-                rows = reader.read_rows(max_rows)
-                if rows:
-                    self.sent += self.send_batch(self.new_batch(rows))
-            elif self.retries:
-                batch = self.retries.popleft()
-                if len(batch['rows']) <= max_rows:
-                    self.send_batch(batch)
-                else:
-                    self.send_batch(self.split_batch(batch, batch['rows'][:max_rows]))
-                    self.retries.append(self.split_batch(batch, batch['rows'][max_rows:]))
-            else:
-                break
+        # send back to the parent process the number of rows sent to the worker processes
+        self.outmsg.send(FeedingProcessResult(sent, reader))
+        # wait for poison pill (None)
+        self.inmsg.recv()
-    def send_batch(self, batch):
-        batch['attempts'] += 1
-        num_rows = len(batch['rows'])
+    def send_chunk(self, ch, rows):
+        self.chunk_id += 1
+        num_rows = len(rows)
-        self.outmsg.put(batch)
+        ch.send({'id': self.chunk_id, 'rows': rows, 'imported': 0, 'num_rows_sent': num_rows})
         return num_rows
-    def new_batch(self, rows):
-        self.batch_id += 1
-        return self.make_batch(self.batch_id, rows, 0)
-    @staticmethod
-    def reset_batch(batch):
-        batch['imported'] = 0
-        return batch
-    @staticmethod
-    def split_batch(batch, rows):
-        return ImportTask.make_batch(batch['id'], rows, batch['attempts'])
+    def close(self):
+        self.reader.close()
+        self.inmsg.close()
+        self.outmsg.close()
-    @staticmethod
-    def make_batch(batch_id, rows, attempts):
-        return {'id': batch_id, 'rows': rows, 'attempts': attempts, 'imported': 0}
+        for ch in self.worker_channels:
+            ch.close()
 class ChildProcess(mp.Process):
@@ -1029,6 +1230,7 @@ class ChildProcess(mp.Process):
         self.decimal_sep = options.copy['decimalsep']
         self.thousands_sep = options.copy['thousandssep']
         self.boolean_styles = options.copy['boolstyle']
+        self.max_attempts = options.copy['maxattempts']
         # Here we inject some failures for testing purposes, only if this environment variable is set
         if os.environ.get('CQLSH_COPY_TEST_FAILURES', ''):
             self.test_failures = json.loads(os.environ.get('CQLSH_COPY_TEST_FAILURES', ''))
@@ -1144,7 +1346,6 @@ class ExportProcess(ChildProcess):
         self.encoding = options.copy['encoding']
         self.float_precision = options.copy['float_precision']
         self.nullval = options.copy['nullval']
-        self.max_attempts = options.copy['maxattempts']
         self.max_requests = options.copy['maxrequests']
         self.hosts_to_sessions = dict()
@@ -1172,7 +1373,7 @@ class ExportProcess(ChildProcess):
                 time.sleep(0.001)  # 1 millisecond
-            token_range, info = self.inmsg.get()
+            token_range, info = self.inmsg.recv()
             self.start_request(token_range, info)
@@ -1190,7 +1391,7 @@ class ExportProcess(ChildProcess):
     def report_error(self, err, token_range=None):
         msg = self.get_error_message(err, print_traceback=self.debug)
-        self.outmsg.put((token_range, Exception(msg)))
+        self.outmsg.send((token_range, Exception(msg)))
     def start_request(self, token_range, info):
@@ -1253,7 +1454,8 @@ class ExportProcess(ChildProcess):
-            connect_timeout=self.connect_timeout)
+            connect_timeout=self.connect_timeout,
+            idle_heartbeat_interval=0)
         session = ExportSession(new_cluster, self)
         self.hosts_to_sessions[host] = session
         return session
@@ -1265,7 +1467,7 @@ class ExportProcess(ChildProcess):
                 self.write_rows_to_csv(token_range, rows)
                 self.write_rows_to_csv(token_range, rows)
-                self.outmsg.put((None, None))
+                self.outmsg.send((None, None))
         def err_callback(err):
@@ -1286,7 +1488,7 @@ class ExportProcess(ChildProcess):
                 writer.writerow(map(self.format_value, row))
             data = (output.getvalue(), len(rows))
-            self.outmsg.put((token_range, data))
+            self.outmsg.send((token_range, data))
         except Exception, e:
@@ -1376,7 +1578,7 @@ class ImportConversion(object):
     A class for converting strings to values when importing from csv, used by ImportProcess,
     the parent.
-    def __init__(self, parent, table_meta, statement):
+    def __init__(self, parent, table_meta, statement=None):
         self.ks = parent.ks
         self.table = parent.table
         self.columns = parent.valid_columns
@@ -1391,9 +1593,37 @@ class ImportConversion(object):
         self.primary_key_indexes = [self.columns.index( for col in self.table_meta.primary_key]
         self.partition_key_indexes = [self.columns.index( for col in self.table_meta.partition_key]
+        if statement is None:
+            self.use_prepared_statements = False
+            statement = self._get_primary_key_statement(parent, table_meta)
+        else:
+            self.use_prepared_statements = True
         self.proto_version = statement.protocol_version
-        self.cqltypes = dict([(, c.type) for c in statement.column_metadata])
-        self.converters = dict([(, self._get_converter(c.type)) for c in statement.column_metadata])
+        # the cql types and converters for the prepared statement, either the full statement or only the primary keys
+        self.cqltypes = [c.type for c in statement.column_metadata]
+        self.converters = [self._get_converter(c.type) for c in statement.column_metadata]
+        # the cql types for the entire statement, these are the same as the types above but
+        # only when using prepared statements
+        self.coltypes = [table_meta.columns[name].typestring for name in parent.valid_columns]
+        # these functions are used for non-prepared statements to protect values with quotes if required
+        self.protectors = [protect_value if t in ('ascii', 'text', 'timestamp', 'date', 'time', 'inet') else lambda v: v
+                           for t in self.coltypes]
+    @staticmethod
+    def _get_primary_key_statement(parent, table_meta):
+        """
+        Prepare a SELECT statement whose bind markers are exactly the partition key columns.
+        The statement is never executed: we only read its column metadata to find out the CQL types of
+        the partition key columns, so that rows can be routed to the correct replicas when we are not
+        using prepared statements. As far as we know this is the easiest way to obtain those types.
+        """
+        where_clause = ' AND '.join(['%s = ?' % (protect_name(c.name),) for c in table_meta.partition_key])
+        select_query = 'SELECT * FROM %s.%s WHERE %s' % (protect_name(parent.ks),
+                                                         protect_name(parent.table),
+                                                         where_clause)
+        return parent.session.prepare(select_query)
     def _get_converter(self, cql_type):
@@ -1581,27 +1811,25 @@ class ImportConversion(object):
         return converters.get(cql_type.typename, convert_unknown)
-    def get_row_values(self, row):
+    def convert_row(self, row):
-        Parse the row into a list of row values to be returned
+        Convert the row into a list of parsed values if using prepared statements, else simply apply the
+        protection functions to escape values with quotes when required. Also check on the row length and
+        make sure primary partition key values aren't missing.
-        def convert(n, val):
-            try:
-                return self.converters[self.columns[n]](val)
-            except Exception, e:
-                raise ParseError(e.message)
+        converters = self.converters if self.use_prepared_statements else self.protectors
-        ret = [None] * len(row)
-        for i, val in enumerate(row):
-            if val != self.nullval:
-                ret[i] = convert(i, val)
-            else:
-                if i in self.primary_key_indexes:
-                    raise ParseError(self.get_null_primary_key_message(i))
+        if len(row) != len(converters):
+            raise ParseError('Invalid row length %d should be %d' % (len(row), len(converters)))
-                ret[i] = None
+        for i in self.primary_key_indexes:
+            if row[i] == self.nullval:
+                raise ParseError(self.get_null_primary_key_message(i))
-        return ret
+        try:
+            return [conv(val) for conv, val in zip(converters, row)]
+        except Exception, e:
+            raise ParseError(e.message)
     def get_null_primary_key_message(self, idx):
         message = "Cannot insert null value for primary key column '%s'." % (self.columns[idx],)
@@ -1610,31 +1838,111 @@ class ImportConversion(object):
                        " the WITH NULL=<marker> option for COPY."
         return message
-    def get_row_partition_key_values(self, row):
+    def get_row_partition_key_values_fcn(self):
-        Return a string composed of the partition key values, serialized and binary packed -
-        as expected by metadata.get_replicas(), see also BoundStatement.routing_key.
+        Return a function to convert a row into a string composed of the partition key values serialized
+        and binary packed (the tokens on the ring). Depending on whether we are using prepared statements, we
+        may have to convert the primary key values first, so we have two different serialize_value implementations.
+        We also return different functions depending on how many partition key indexes we have (single or multiple).
+        See also BoundStatement.routing_key.
-        def serialize(n):
-            try:
-                c, v = self.columns[n], row[n]
-                if v == self.nullval:
-                    raise ParseError(self.get_null_primary_key_message(n))
-                return self.cqltypes[c].serialize(self.converters[c](v), self.proto_version)
-            except Exception, e:
-                raise ParseError(e.message)
+        def serialize_value_prepared(n, v):
+            return self.cqltypes[n].serialize(v, self.proto_version)
+        def serialize_value_not_prepared(n, v):
+            return self.cqltypes[n].serialize(self.converters[n](v), self.proto_version)
         partition_key_indexes = self.partition_key_indexes
-        if len(partition_key_indexes) == 1:
-            return serialize(partition_key_indexes[0])
-        else:
+        serialize = serialize_value_prepared if self.use_prepared_statements else serialize_value_not_prepared
+        def serialize_row_single(row):
+            return serialize(partition_key_indexes[0], row[partition_key_indexes[0]])
+        def serialize_row_multiple(row):
             pk_values = []
             for i in partition_key_indexes:
-                val = serialize(i)
+                val = serialize(i, row[i])
                 l = len(val)
                 pk_values.append(struct.pack(">H%dsB" % l, l, val, 0))
             return b"".join(pk_values)
+        if len(partition_key_indexes) == 1:
+            return serialize_row_single
+        return serialize_row_multiple
+class TokenMap(object):
+    """
+    A wrapper around the metadata token map to speed things up by caching ring token *values* and
+    replicas. It is very important that we use the token values, which are primitive types, rather
+    than the token classes when calling bisect_right() (see get_ring_pos()). If we use primitive values,
+    the bisect is done in compiled code whilst with token classes each comparison requires a call
+    into the interpreter to perform the cmp operation defined in Python. A simple test with 1 million bisect
+    operations on an array of 2048 tokens was done in 0.37 seconds with primitives and 2.25 seconds with
+    token classes. This is significant for large datasets because we need to do a bisect for each single row,
+    and if VNODES are used, the size of the token map can get quite large too.
+    """
+    def __init__(self, ks, hostname, local_dc, session):
+        self.ks = ks
+        self.hostname = hostname
+        self.local_dc = local_dc
+        self.metadata = session.cluster.metadata
+        self._initialize_ring()
+        # Note that refresh metadata is disabled by default and we currently do not intercept it.
+        # If hosts are added, removed or moved during a COPY operation our token map is no longer optimal.
+        # However we can cope with hosts going down and up since we filter for replicas that are up when
+        # making each batch.
+    def _initialize_ring(self):
+        """
+        Cache the ring token values, the replicas of each token and a function mapping a serialized
+        partition key to its token value. Without a driver token map, fall back to a single-entry ring
+        whose only replica is the host named by self.hostname.
+        """
+        token_map = self.metadata.token_map
+        if token_map is None:
+            self.ring = [0]
+            self.replicas = [(self.metadata.get_host(self.hostname),)]
+            self.pk_to_token_value = lambda pk: 0
+            return
+        token_map.rebuild_keyspace(self.ks, build_if_absent=True)
+        tokens_to_hosts = token_map.tokens_to_hosts_by_ks.get(self.ks, None)
+        from_key = token_map.token_class.from_key
+        # self.ring and self.replicas are index-aligned: replicas[i] are the hosts mapped to token ring[i]
+        self.ring = [token.value for token in token_map.ring]
+        self.replicas = [tuple(tokens_to_hosts[token]) for token in token_map.ring]
+        self.pk_to_token_value = lambda pk: from_key(pk).value
+    @staticmethod
+    def get_ring_pos(ring, val):
+        """
+        Return the index of the first token value in `ring` strictly greater than `val`, wrapping
+        around to 0 past the last token. Assumes `ring` is sorted, as bisect requires.
+        """
+        # NOTE(review): Cassandra token ranges are (previous, token], so a value exactly equal to a ring
+        # token belongs to that token, whereas bisect_right maps it to the next one; confirm whether
+        # bisect_left was intended (presumably this only affects which replicas are preferred).
+        idx = bisect_right(ring, val)
+        return idx if idx < len(ring) else 0
+    def filter_replicas(self, hosts):
+        """
+        Return a tuple with the hosts that are up and in the local data center, in random order.
+        An empty or None `hosts` yields an empty tuple.
+        """
+        if not hosts:
+            return ()
+        shuffled = list(hosts)
+        random.shuffle(shuffled)
+        return tuple(r for r in shuffled if r.is_up and r.datacenter == self.local_dc)
+class FastTokenAwarePolicy(DCAwareRoundRobinPolicy):
+    """
+    A token-aware load balancing policy that relies on replicas already computed by the caller.
+    The batch statements built by ImportProcess carry a `replicas` attribute; those hosts are tried
+    first, in the given order, followed by every other host proposed by DCAwareRoundRobinPolicy.
+    It is used instead of wrapping DCAwareRoundRobinPolicy in TokenAwarePolicy so that we choose the
+    same replicas in preference and, most importantly, avoid repeating the (slow) bisect per query.
+    """
+    def __init__(self, local_dc='', used_hosts_per_remote_dc=0):
+        DCAwareRoundRobinPolicy.__init__(self, local_dc, used_hosts_per_remote_dc)
+    def make_query_plan(self, working_keyspace=None, query=None):
+        """
+        Yield the replicas attached to `query` (if any) first, then the hosts of the
+        DCAwareRoundRobinPolicy plan that were not already yielded.
+        """
+        replicas = getattr(query, 'replicas', ())
+        for r in replicas:
+            yield r
+        for r in DCAwareRoundRobinPolicy.make_query_plan(self, working_keyspace, query):
+            if r not in replicas:
+                yield r
 class ImportProcess(ChildProcess):
@@ -1650,7 +1958,12 @@ class ImportProcess(ChildProcess):
         self.max_attempts = options.copy['maxattempts']
         self.min_batch_size = options.copy['minbatchsize']
         self.max_batch_size = options.copy['maxbatchsize']
+        self.use_prepared_statements = options.copy['preparedstatements']
+        self.dialect_options = options.dialect
         self._session = None
+        self.query = None
+        self.conv = None
+        self.make_statement = None
     def session(self):
@@ -1661,12 +1974,13 @@ class ImportProcess(ChildProcess):
-                load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy(local_dc=self.local_dc)),
+                load_balancing_policy=FastTokenAwarePolicy(local_dc=self.local_dc),
                 ssl_options=ssl_settings(self.hostname, self.config_file) if self.ssl else None,
-                connect_timeout=self.connect_timeout)
+                connect_timeout=self.connect_timeout,
+                idle_heartbeat_interval=0)
             self._session = cluster.connect(self.ks)
             self._session.default_timeout = None
@@ -1674,13 +1988,12 @@ class ImportProcess(ChildProcess):
     def run(self):
-            table_meta = self.session.cluster.metadata.keyspaces[self.ks].tables[self.table]
-            is_counter = ("counter" in [table_meta.columns[name].typestring for name in self.valid_columns])
+            pr = profile_on() if PROFILE_ON else None
-            if is_counter:
-                self.run_counter(table_meta)
-            else:
-                self.run_normal(table_meta)
+            self.inner_run(*self.make_params())
+            if pr:
+                profile_off(pr, file_name='worker_profile_%d.txt' % (os.getpid(),))
         except Exception, exc:
             if self.debug:
@@ -1694,67 +2007,88 @@ class ImportProcess(ChildProcess):
-    def run_counter(self, table_meta):
-        """
-        Main run method for tables that contain counter columns.
-        """
-        query = 'UPDATE %s.%s SET %%s WHERE %%s' % (protect_name(self.ks), protect_name(self.table))
-        # We prepare a query statement to find out the types of the partition key columns so we can
-        # route the update query to the correct replicas. As far as I understood this is the easiest
-        # way to find out the types of the partition columns, we will never use this prepared statement
-        where_clause = ' AND '.join(['%s = ?' % (protect_name( for c in table_meta.partition_key])
-        select_query = 'SELECT * FROM %s.%s WHERE %s' % (protect_name(self.ks), protect_name(self.table), where_clause)
-        conv = ImportConversion(self, table_meta, self.session.prepare(select_query))
-        while True:
-            batch = self.inmsg.get()
-            try:
-                for b in self.split_batches(batch, conv):
-                    self.send_counter_batch(query, conv, b)
+    def make_params(self):
+        metadata = self.session.cluster.metadata
+        table_meta = metadata.keyspaces[self.ks].tables[self.table]
+        prepared_statement = None
+        is_counter = ("counter" in [table_meta.columns[name].typestring for name in self.valid_columns])
+        if is_counter:
+            query = 'UPDATE %s.%s SET %%s WHERE %%s' % (protect_name(self.ks), protect_name(self.table))
+            make_statement = self.wrap_make_statement(self.make_counter_batch_statement)
+        elif self.use_prepared_statements:
+            query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % (protect_name(self.ks),
+                                                            protect_name(self.table),
+                                                            ', '.join(protect_names(self.valid_columns),),
+                                                            ', '.join(['?' for _ in self.valid_columns]))
+            query = self.session.prepare(query)
+            query.consistency_level = self.consistency_level
+            prepared_statement = query
+            make_statement = self.wrap_make_statement(self.make_prepared_batch_statement)
+        else:
+            query = 'INSERT INTO %s.%s (%s) VALUES (%%s)' % (protect_name(self.ks),
+                                                             protect_name(self.table),
+                                                             ', '.join(protect_names(self.valid_columns),))
+            make_statement = self.wrap_make_statement(self.make_non_prepared_batch_statement)
-            except Exception, exc:
-                self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, exc.message)))
-                if self.debug:
-                    traceback.print_exc(exc)
+        conv = ImportConversion(self, table_meta, prepared_statement)
+        tm = TokenMap(self.ks, self.hostname, self.local_dc, self.session)
+        return query, conv, tm, make_statement
-    def run_normal(self, table_meta):
+    def inner_run(self, query, conv, tm, make_statement):
-        Main run method for normal tables, i.e. tables that do not contain counter columns.
+        Main run method. Note that we bind self methods that are called inside loops
+        for performance reasons.
-        query = 'INSERT INTO %s.%s (%s) VALUES (%s)' % (protect_name(self.ks),
-                                                        protect_name(self.table),
-                                                        ', '.join(protect_names(self.valid_columns),),
-                                                        ', '.join(['?' for _ in self.valid_columns]))
+        self.query = query
+        self.conv = conv
+        self.make_statement = make_statement
-        query_statement = self.session.prepare(query)
-        query_statement.consistency_level = self.consistency_level
-        conv = ImportConversion(self, table_meta, query_statement)
+        convert_rows = self.convert_rows
+        split_into_batches = self.split_into_batches
+        result_callback = self.result_callback
+        err_callback = self.err_callback
+        session = self.session
         while True:
-            batch = self.inmsg.get()
+            chunk = self.inmsg.recv()
+            if chunk is None:
+                break
-                for b in self.split_batches(batch, conv):
-                    self.send_normal_batch(conv, query_statement, b)
+                chunk['rows'] = convert_rows(conv, chunk)
+                for replicas, batch in split_into_batches(chunk, conv, tm):
+                    statement = make_statement(query, conv, chunk, batch, replicas)
+                    future = session.execute_async(statement)
+                    future.add_callbacks(callback=result_callback, callback_args=(batch, chunk),
+                                         errback=err_callback, errback_args=(batch, chunk, replicas))
             except Exception, exc:
-                self.outmsg.put((batch, '%s - %s' % (exc.__class__.__name__, exc.message)))
-                if self.debug:
-                    traceback.print_exc(exc)
+                self.report_error(exc, chunk, chunk['rows'])
-    def send_counter_batch(self, query_text, conv, batch):
-        if self.test_failures and self.maybe_inject_failures(batch):
-            return
+    def wrap_make_statement(self, inner_make_statement):
+        def make_statement(query, conv, chunk, batch, replicas):
+            try:
+                return inner_make_statement(query, conv, batch, replicas)
+            except Exception, exc:
+                print "Failed to make batch statement: {}".format(exc)
+                self.report_error(exc, chunk, batch['rows'])
+                return None
-        error_rows = []
-        batch_statement = BatchStatement(batch_type=BatchType.COUNTER, consistency_level=self.consistency_level)
+        def make_statement_with_failures(query, conv, chunk, batch, replicas):
+            failed_batch = self.maybe_inject_failures(batch)
+            if failed_batch:
+                return failed_batch
+            return make_statement(query, conv, chunk, batch, replicas)
-        for r in batch['rows']:
-            row = self.filter_row_values(r)
-            if len(row) != len(self.valid_columns):
-                error_rows.append(row)
-                continue
+        return make_statement_with_failures if self.test_failures else make_statement
+    def make_counter_batch_statement(self, query, conv, batch, replicas):
+        """
+        Return a COUNTER BatchStatement with one statement per row of batch['rows']. Each statement is
+        built by formatting query with two strings: the comma-separated "col=col+value" assignments and
+        the AND-joined where conditions for that row.
+        """
+        statement = BatchStatement(batch_type=BatchType.COUNTER, consistency_level=self.consistency_level)
+        statement.replicas = replicas
+        statement.keyspace = self.ks
+        for row in batch['rows']:
             where_clause = []
             set_clause = []
             for i, value in enumerate(row):
@@ -1763,65 +2097,61 @@ class ImportProcess(ChildProcess):
                     set_clause.append("%s=%s+%s" % (self.valid_columns[i], self.valid_columns[i], value))
-            full_query_text = query_text % (','.join(set_clause), ' AND '.join(where_clause))
-            batch_statement.add(full_query_text)
-        self.execute_statement(batch_statement, batch)
+            # NOTE(review): values are interpolated into the CQL text rather than bound as parameters
+            full_query_text = query % (','.join(set_clause), ' AND '.join(where_clause))
+            statement.add(full_query_text)
+        return statement
-        if error_rows:
-            self.outmsg.put((ImportTask.split_batch(batch, error_rows),
-                            '%s - %s' % (ParseError.__name__, "Failed to parse one or more rows")))
+    def make_prepared_batch_statement(self, query, _, batch, replicas):
+        """
+        Return a batch statement. This is an optimized version of:
-    def send_normal_batch(self, conv, query_statement, batch):
-        if self.test_failures and self.maybe_inject_failures(batch):
-            return
+            statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
+            for row in batch['rows']:
+                statement.add(query, row)
-        good_rows, converted_rows, errors = self.convert_rows(conv, batch['rows'])
+        We could optimize further by removing bound_statements altogether but we'd have to duplicate much
+        more of the driver's code (BoundStatement.bind()).
+        """
+        statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
+        statement.replicas = replicas
+        statement.keyspace = self.ks
+        # Populate the driver's private statement list directly instead of calling add() per row.
+        # Each entry is (True, prepared query id, bound values); this relies on driver internals.
+        statement._statements_and_parameters = [(True, query.query_id, query.bind(r).values) for r in batch['rows']]
+        return statement
-        if converted_rows:
-            try:
-                statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
-                for row in converted_rows:
-                    statement.add(query_statement, row)
-                self.execute_statement(statement, ImportTask.split_batch(batch, good_rows))
-            except Exception, exc:
-                self.err_callback(exc, ImportTask.split_batch(batch, good_rows))
+    def make_non_prepared_batch_statement(self, query, _, batch, replicas):
+        """
+        Return an UNLOGGED batch statement holding one plain CQL string per row, produced by formatting
+        query with the comma-joined row values. As in make_prepared_batch_statement, the driver's private
+        _statements_and_parameters list is written directly.
+        """
+        statement = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=self.consistency_level)
+        statement.replicas = replicas
+        statement.keyspace = self.ks
+        # NOTE(review): row values are joined verbatim into the query text, so they are presumably
+        # already formatted as CQL literals by the converter -- confirm.
+        statement._statements_and_parameters = [(False, query % (','.join(r),), ()) for r in batch['rows']]
+        return statement
-        if errors:
-            for msg, rows in errors.iteritems():
-                self.outmsg.put((ImportTask.split_batch(batch, rows),
-                                '%s - %s' % (ParseError.__name__, msg)))
-    def convert_rows(self, conv, rows):
+    def convert_rows(self, conv, chunk):
-        Try to convert each row. If conversion is OK then add the converted result to converted_rows
-        and the original string to good_rows. Else add the original string to error_rows. Return the three
-        arrays.
+        Return converted rows and report any errors during conversion.
+        The CSV lines in chunk['rows'] are parsed with csv.reader(**self.dialect_options), and the columns
+        listed in self.skip_column_indexes are dropped. Rows for which conv.convert_row raises any exception
+        are grouped by error message, reported as ParseError and left out of the result. Falsy conversion
+        results are also dropped by filter(None, ...).
-        good_rows = []
-        errors = defaultdict(list)
-        converted_rows = []
+        def filter_row_values(row):
+            return [v for i, v in enumerate(row) if i not in self.skip_column_indexes]
-        for r in rows:
-            row = self.filter_row_values(r)
-            if len(row) != len(self.valid_columns):
-                msg = 'Invalid row length %d should be %d' % (len(row), len(self.valid_columns))
-                errors[msg].append(row)
-                continue
+        if self.skip_column_indexes:
+            rows = [filter_row_values(r) for r in list(csv.reader(chunk['rows'], **self.dialect_options))]
+        else:
+            rows = list(csv.reader(chunk['rows'], **self.dialect_options))
-            try:
-                converted_rows.append(conv.get_row_values(row))
-                good_rows.append(row)
-            except ParseError, err:
-                errors[err.message].append(row)
+        errors = defaultdict(list)
-        return good_rows, converted_rows, errors
+        def convert_row(r):
+            try:
+                return conv.convert_row(r)
+            except Exception, err:
+                errors[err.message].append(r)
+                return None
-    def filter_row_values(self, row):
-        if not self.skip_column_indexes:
-            return row
+        converted_rows = filter(None, [convert_row(r) for r in rows])
-        return [v for i, v in enumerate(row) if i not in self.skip_column_indexes]
+        if errors:
+            for msg, rows in errors.iteritems():
+                self.report_error(ParseError(msg), chunk, rows)
+        return converted_rows
     def maybe_inject_failures(self, batch):
@@ -1836,86 +2166,94 @@ class ImportProcess(ChildProcess):
                 if batch['attempts'] < failing_batch['failures']:
                     statement = SimpleStatement("INSERT INTO badtable (a, b) VALUES (1, 2)",
-                    self.execute_statement(statement, batch)
-                    return True
+                    # handed back to the caller, which executes it in place of the real batch statement
+                    return statement
         if 'exit_batch' in self.test_failures:
             exit_batch = self.test_failures['exit_batch']
             if exit_batch['id'] == batch['id']:
-        return False  # carry on as normal
+        # None lets make_statement_with_failures build the real statement for this batch
+        return None  # carry on as normal
-    def execute_statement(self, statement, batch):
-        future = self.session.execute_async(statement)
-        future.add_callbacks(callback=self.result_callback, callback_args=(batch, ),
-                             errback=self.err_callback, errback_args=(batch, ))
+    @staticmethod
+    def make_batch(batch_id, rows, attempts=1):
+        """Return a batch dict holding the given id, rows and attempt counter."""
+        return {'id': batch_id, 'rows': rows, 'attempts': attempts}
-    def split_batches(self, batch, conv):
+    def split_into_batches(self, chunk, conv, tm):
-        Batch rows by partition key, if there are at least min_batch_size (2)
-        rows with the same partition key. These batches can be as big as they want
-        since this translates to a single insert operation server side.
-        If there are less than min_batch_size rows for a partition, work out the
-        first replica for this partition and add the rows to replica left-over rows.
-        Then batch the left-overs of each replica up to max_batch_size.
+        Batch rows by ring position or replica.
+        If there are at least min_batch_size rows for a ring position then split these rows into
+        groups of max_batch_size and send a batch for each group, using all replicas for this ring position.
+        Otherwise, we are forced to batch by replica, and here unfortunately we can only choose one replica to
+        guarantee common replicas across partition keys. We are typically able
+        to batch by ring position for small clusters or when VNODES are not used. For large clusters with VNODES
+        it may not be possible; in this case it helps to increase the CHUNK SIZE, but only up to a limit, otherwise
+        we may choke the cluster.
+        Yields (replicas, batch) tuples, where each batch is built by make_batch with the chunk id.
-        rows_by_pk = defaultdict(list)
+        rows_by_ring_pos = defaultdict(list)
         errors = defaultdict(list)
-        for row in batch['rows']:
+        # bind frequently used attributes and methods to locals ahead of the per-row loop
+        min_batch_size = self.min_batch_size
+        max_batch_size = self.max_batch_size
+        ring = tm.ring
+        get_row_partition_key_values = conv.get_row_partition_key_values_fcn()
+        pk_to_token_value = tm.pk_to_token_value
+        get_ring_pos = tm.get_ring_pos
+        make_batch = self.make_batch
+        for row in chunk['rows']:
-                pk = conv.get_row_partition_key_values(row)
-                rows_by_pk[pk].append(row)
-            except ParseError, e:
+                pk = get_row_partition_key_values(row)
+                rows_by_ring_pos[get_ring_pos(ring, pk_to_token_value(pk))].append(row)
+            except Exception, e:
         if errors:
             for msg, rows in errors.iteritems():
-                self.outmsg.put((ImportTask.split_batch(batch, rows),
-                                 '%s - %s' % (ParseError.__name__, msg)))
+                self.report_error(ParseError(msg), chunk, rows)
+        replicas = tm.replicas
+        filter_replicas = tm.filter_replicas
         rows_by_replica = defaultdict(list)
-        for pk, rows in rows_by_pk.iteritems():
-            if len(rows) >= self.min_batch_size:
-                yield ImportTask.make_batch(batch['id'], rows, batch['attempts'])
+        for ring_pos, rows in rows_by_ring_pos.iteritems():
+            # NOTE(review): the docstring says "at least min_batch_size" but this test is strict (>),
+            # while the removed code used >=; confirm which threshold is intended.
+            if len(rows) > min_batch_size:
+                for i in xrange(0, len(rows), max_batch_size):
+                    yield filter_replicas(replicas[ring_pos]), make_batch(chunk['id'], rows[i:i + max_batch_size])
-                replica = self.get_replica(pk)
-                rows_by_replica[replica].extend(rows)
-        for replica, rows in rows_by_replica.iteritems():
-            for b in self.batches(rows, batch):
-                yield b
-    def get_replica(self, pk):
-        """
-        Return the first replica or the host we are already connected to if there are no local
-        replicas that are up. We always use the first replica to match the replica chosen by the driver
-        TAR, see TokenAwarePolicy.make_query_plan().
-        """
-        metadata = self.session.cluster.metadata
-        replicas = filter(lambda r: r.is_up and r.datacenter == self.local_dc, metadata.get_replicas(self.ks, pk))
-        ret = replicas[0].address if len(replicas) > 0 else self.hostname
-        return ret
-    def batches(self, rows, batch):
-        """
-        Split rows into batches of max_batch_size
-        """
-        for i in xrange(0, len(rows), self.max_batch_size):
-            yield ImportTask.make_batch(batch['id'], rows[i:i + self.max_batch_size], batch['attempts'])
-    def result_callback(self, _, batch):
-        batch['imported'] = len(batch['rows'])
-        batch['rows'] = []  # no need to resend these, just send the count in 'imported'
-        self.outmsg.put((batch, None))
-    def err_callback(self, response, batch):
-        self.outmsg.put((batch, '%s - %s' % (response.__class__.__name__, response.message)))
+                # select only the first valid replica to guarantee more overlap or none at all
+                # NOTE(review): the slice is used as a dict key, so filter_replicas must return a
+                # hashable sequence (e.g. a tuple) -- confirm.
+                rows_by_replica[filter_replicas(replicas[ring_pos])[:1]].extend(rows)
+        # Now send the batches by replica
+        # (this loop rebinds `replicas`; tm.replicas is not needed past this point)
+        for replicas, rows in rows_by_replica.iteritems():
+            for i in xrange(0, len(rows), max_batch_size):
+                yield replicas, make_batch(chunk['id'], rows[i:i + max_batch_size])
+    def result_callback(self, _, batch, chunk):
+        # Success callback: the driver result is ignored; the batch rows count as imported for the chunk.
+        self.update_chunk(batch['rows'], chunk)
+    def err_callback(self, response, batch, chunk, replicas):
+        """
+        Driver error callback for a batch. The error is reported; it is final once the batch has used
+        self.max_attempts attempts. Otherwise the statement is rebuilt and resent to the same replicas.
+        """
+        err_is_final = batch['attempts'] >= self.max_attempts
+        self.report_error(response, chunk, batch['rows'], batch['attempts'], err_is_final)
+        if not err_is_final:
+            batch['attempts'] += 1
+            statement = self.make_statement(self.query, self.conv, chunk, batch, replicas)
+            if statement is None:
+                # presumably self.make_statement is the wrap_make_statement wrapper, which has already
+                # reported the failure as final and returns None; there is nothing to execute
+                return
+            future = self.session.execute_async(statement)
+            future.add_callbacks(callback=self.result_callback, callback_args=(batch, chunk),
+                                 errback=self.err_callback, errback_args=(batch, chunk, replicas))
+    def report_error(self, err, chunk, rows=None, attempts=1, final=True):
+        """
+        Send an ImportTaskError describing err through self.outmsg. When final is True the rows are
+        also counted towards the chunk via update_chunk, since they will not be retried.
+        """
         if self.debug:
-            traceback.print_exc(response)
+            # print_exc() takes (limit, file); the exception object must not be passed as the limit
+            traceback.print_exc()
+        self.outmsg.send(ImportTaskError(err.__class__.__name__, err.message, rows, attempts, final))
+        if final:
+            self.update_chunk(rows, chunk)
+    def update_chunk(self, rows, chunk):
+        """
+        Add len(rows) to chunk['imported'] (rows may be None, which counts as zero) and send an
+        ImportProcessResult once every row sent in the chunk has been accounted for.
+        """
+        chunk['imported'] += len(rows) if rows else 0
+        if chunk['imported'] == chunk['num_rows_sent']:
+            self.outmsg.send(ImportProcessResult(chunk['num_rows_sent']))
 class RateMeter(object):
@@ -1937,11 +2275,19 @@ class RateMeter(object):
         self.current_record += n
-    def maybe_update(self):
+    def maybe_update(self, sleep=False):
+        """
+        Return immediately while no records have been counted. If at least update_interval seconds have
+        elapsed since the last checkpoint, take the update branch; otherwise, when sleep is True, sleep
+        for whatever is left of the current interval.
+        """
+        if self.current_record == 0:
+            return
         new_checkpoint_time = time.time()
-        if new_checkpoint_time - self.last_checkpoint_time >= self.update_interval:
+        time_difference = new_checkpoint_time - self.last_checkpoint_time
+        if time_difference >= self.update_interval:
+        elif sleep:
+            # time_difference < update_interval here, so what is left of the interval is
+            # update_interval - time_difference (the reverse subtraction is always negative)
+            remaining_time = self.update_interval - time_difference
+            if remaining_time > 0.000001:
+                time.sleep(remaining_time)
     def update(self, new_checkpoint_time):
         time_difference = new_checkpoint_time - self.last_checkpoint_time
diff --git a/pylib/cqlshlib/util.py b/pylib/cqlshlib/util.py
index 281aad6..3ee128d 100644
--- a/pylib/cqlshlib/util.py
+++ b/pylib/cqlshlib/util.py
@@ -23,6 +23,12 @@ from itertools import izip
 from datetime import timedelta, tzinfo
 from StringIO import StringIO
+    from line_profiler import LineProfiler
+except ImportError:
 ZERO = timedelta(0)
@@ -126,18 +132,35 @@ def get_file_encoding_bomsize(filename):
         file_encoding, size = "utf-8", 0
-    return (file_encoding, size)
+    return file_encoding, size
+def profile_on(fcn_names=None):
+    """
+    Start profiling and return the profiler object. If fcn_names is non-empty and HAS_LINE_PROFILER is
+    set (presumably by the guarded line_profiler import), return an enabled LineProfiler tracking those
+    entries; otherwise fall back to a cProfile.Profile.
+    NOTE(review): the entries are passed to LineProfiler.add_function, which expects function objects,
+    so despite the parameter name these are presumably functions rather than names -- confirm.
+    """
+    if fcn_names and HAS_LINE_PROFILER:
+        pr = LineProfiler()
+        for fcn_name in fcn_names:
+            pr.add_function(fcn_name)
+        pr.enable()
+        return pr
-def profile_on():
     pr = cProfile.Profile()
     return pr
-def profile_off(pr):
+def profile_off(pr, file_name):
+    """
+    Return the statistics collected by pr as a string: line-by-line stats for a LineProfiler,
+    otherwise cProfile stats sorted by cumulative time. If file_name is set, the report is also
+    written to that file.
+    """
     s = StringIO()
-    ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
-    ps.print_stats()
-    print s.getvalue()
+    if HAS_LINE_PROFILER and isinstance(pr, LineProfiler):
+        pr.print_stats(s)
+    else:
+        ps = pstats.Stats(pr, stream=s).sort_stats('cumulative')
+        ps.print_stats()
+    ret = s.getvalue()
+    if file_name:
+        with open(file_name, 'w') as f:
+            print "Writing to %s\n" % (file_name, )
+            f.write(ret)
+    return ret
diff --git a/pylib/setup.py b/pylib/setup.py
index 704d077..3654502 100755
--- a/pylib/setup.py
+++ b/pylib/setup.py
@@ -16,9 +16,11 @@
 # limitations under the License.
 from distutils.core import setup
+from Cython.Build import cythonize
     description="Cassandra Python Libraries",
+    # build cqlshlib/copyutil.py as a Cython extension module
+    ext_modules=cythonize("cqlshlib/copyutil.py"),