Posted to commits@cassandra.apache.org by db...@apache.org on 2013/06/07 00:16:15 UTC
[2/5] rename
http://git-wip-us.apache.org/repos/asf/cassandra/blob/a0047797/src/java/org/apache/cassandra/hadoop/cql3/CqlPagingRecordReader.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/hadoop/cql3/CqlPagingRecordReader.java b/src/java/org/apache/cassandra/hadoop/cql3/CqlPagingRecordReader.java
new file mode 100644
index 0000000..3a0f628
--- /dev/null
+++ b/src/java/org/apache/cassandra/hadoop/cql3/CqlPagingRecordReader.java
@@ -0,0 +1,763 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.hadoop.cql3;
+
+import java.io.IOException;
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.nio.ByteBuffer;
+import java.nio.charset.CharacterCodingException;
+import java.util.*;
+
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Iterables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.cassandra.db.marshal.AbstractType;
+import org.apache.cassandra.db.marshal.CompositeType;
+import org.apache.cassandra.db.marshal.LongType;
+import org.apache.cassandra.db.marshal.TypeParser;
+import org.apache.cassandra.dht.IPartitioner;
+import org.apache.cassandra.exceptions.ConfigurationException;
+import org.apache.cassandra.exceptions.SyntaxException;
+import org.apache.cassandra.hadoop.ColumnFamilySplit;
+import org.apache.cassandra.hadoop.ConfigHelper;
+import org.apache.cassandra.thrift.*;
+import org.apache.cassandra.utils.ByteBufferUtil;
+import org.apache.cassandra.utils.FBUtilities;
+import org.apache.cassandra.utils.Pair;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.thrift.TException;
+import org.apache.thrift.transport.TTransport;
+
+/**
+ * Hadoop RecordReader that reads the rows returned by the CQL query.
+ * It uses CQL key range queries to page through the wide rows.
+ * <p/>
+ * Returns a Map<String, ByteBuffer> of the key columns as the key
+ * <p/>
+ * and a Map<String, ByteBuffer> of column name to column value as the value.
+ */
+public class CqlPagingRecordReader extends RecordReader<Map<String, ByteBuffer>, Map<String, ByteBuffer>>
+ implements org.apache.hadoop.mapred.RecordReader<Map<String, ByteBuffer>, Map<String, ByteBuffer>>
+{
+ private static final Logger logger = LoggerFactory.getLogger(CqlPagingRecordReader.class);
+
+ public static final int DEFAULT_CQL_PAGE_LIMIT = 1000; // TODO: find a page size that is large enough but does not cause OOM
+
+ private ColumnFamilySplit split;
+ private RowIterator rowIterator;
+
+ private Pair<Map<String, ByteBuffer>, Map<String, ByteBuffer>> currentRow;
+ private int totalRowCount; // total number of rows to fetch
+ private String keyspace;
+ private String cfName;
+ private Cassandra.Client client;
+ private ConsistencyLevel consistencyLevel;
+
+ // partition keys -- key aliases
+ private List<BoundColumn> partitionBoundColumns = new ArrayList<BoundColumn>();
+
+ // cluster keys -- column aliases
+ private List<BoundColumn> clusterColumns = new ArrayList<BoundColumn>();
+
+ // map prepared query type to item id
+ private Map<Integer, Integer> preparedQueryIds = new HashMap<Integer, Integer>();
+
+ // cql query select columns
+ private String columns;
+
+ // the number of cql rows per page
+ private int pageRowSize;
+
+ // user defined where clauses
+ private String userDefinedWhereClauses;
+
+ private IPartitioner partitioner;
+
+ private AbstractType<?> keyValidator;
+
+ public CqlPagingRecordReader()
+ {
+ super();
+ }
+
+ public void initialize(InputSplit split, TaskAttemptContext context) throws IOException
+ {
+ this.split = (ColumnFamilySplit) split;
+ Configuration conf = context.getConfiguration();
+ totalRowCount = (this.split.getLength() < Long.MAX_VALUE)
+ ? (int) this.split.getLength()
+ : ConfigHelper.getInputSplitSize(conf);
+ cfName = ConfigHelper.getInputColumnFamily(conf);
+ consistencyLevel = ConsistencyLevel.valueOf(ConfigHelper.getReadConsistencyLevel(conf));
+ keyspace = ConfigHelper.getInputKeyspace(conf);
+ columns = CqlConfigHelper.getInputcolumns(conf);
+ userDefinedWhereClauses = CqlConfigHelper.getInputWhereClauses(conf);
+
+ try
+ {
+ pageRowSize = Integer.parseInt(CqlConfigHelper.getInputPageRowSize(conf));
+ }
+ catch (NumberFormatException e)
+ {
+ pageRowSize = DEFAULT_CQL_PAGE_LIMIT;
+ }
+
+ partitioner = ConfigHelper.getInputPartitioner(context.getConfiguration());
+
+ try
+ {
+ if (client != null)
+ return;
+
+ // create connection using thrift
+ String location = getLocation();
+
+ int port = ConfigHelper.getInputRpcPort(conf);
+ client = CqlPagingInputFormat.createAuthenticatedClient(location, port, conf);
+
+ // retrieve partition keys and cluster keys from system.schema_columnfamilies table
+ retrieveKeys();
+
+ client.set_keyspace(keyspace);
+ }
+ catch (Exception e)
+ {
+ throw new RuntimeException(e);
+ }
+
+ rowIterator = new RowIterator();
+
+ logger.debug("created {}", rowIterator);
+ }
+
+ public void close()
+ {
+ if (client != null)
+ {
+ TTransport transport = client.getOutputProtocol().getTransport();
+ if (transport.isOpen())
+ transport.close();
+ client = null;
+ }
+ }
+
+ public Map<String, ByteBuffer> getCurrentKey()
+ {
+ return currentRow.left;
+ }
+
+ public Map<String, ByteBuffer> getCurrentValue()
+ {
+ return currentRow.right;
+ }
+
+ public float getProgress()
+ {
+ if (!rowIterator.hasNext())
+ return 1.0F;
+
+ // the progress is likely to be reported slightly off the actual but close enough
+ float progress = ((float) rowIterator.totalRead / totalRowCount);
+ return progress > 1.0F ? 1.0F : progress;
+ }
+
+ public boolean nextKeyValue() throws IOException
+ {
+ if (!rowIterator.hasNext())
+ {
+ logger.debug("Finished scanning " + rowIterator.totalRead + " rows (estimate was: " + totalRowCount + ")");
+ return false;
+ }
+
+ try
+ {
+ currentRow = rowIterator.next();
+ }
+ catch (Exception e)
+ {
+ // rethrow as an IOException, so the client can catch it and handle it on the client side
+ IOException ioe = new IOException(e.getMessage());
+ ioe.initCause(e);
+ throw ioe;
+ }
+ return true;
+ }
+
+ // we don't use endpointsnitch since we are trying to support hadoop nodes that are
+ // not necessarily on Cassandra machines, too. This should be adequate for single-DC clusters, at least.
+ private String getLocation()
+ {
+ Collection<InetAddress> localAddresses = FBUtilities.getAllLocalAddresses();
+
+ for (InetAddress address : localAddresses)
+ {
+ for (String location : split.getLocations())
+ {
+ InetAddress locationAddress;
+ try
+ {
+ locationAddress = InetAddress.getByName(location);
+ }
+ catch (UnknownHostException e)
+ {
+ throw new AssertionError(e);
+ }
+ if (address.equals(locationAddress))
+ {
+ return location;
+ }
+ }
+ }
+ return split.getLocations()[0];
+ }
+
+ // Because the old Hadoop API wants us to write to the key and value
+ // and the new asks for them, we need to copy the output of the new API
+ // to the old. Thus, expect a small performance hit.
+ // And obviously this wouldn't work for wide rows. But since ColumnFamilyInputFormat
+ // and ColumnFamilyRecordReader don't support them, it should be fine for now.
+ public boolean next(Map<String, ByteBuffer> keys, Map<String, ByteBuffer> value) throws IOException
+ {
+ if (nextKeyValue())
+ {
+ value.clear();
+ value.putAll(getCurrentValue());
+
+ keys.clear();
+ keys.putAll(getCurrentKey());
+
+ return true;
+ }
+ return false;
+ }
+
+ public long getPos() throws IOException
+ {
+ return (long) rowIterator.totalRead;
+ }
+
+ public Map<String, ByteBuffer> createKey()
+ {
+ return new LinkedHashMap<String, ByteBuffer>();
+ }
+
+ public Map<String, ByteBuffer> createValue()
+ {
+ return new LinkedHashMap<String, ByteBuffer>();
+ }
+
+ /** CQL row iterator */
+ private class RowIterator extends AbstractIterator<Pair<Map<String, ByteBuffer>, Map<String, ByteBuffer>>>
+ {
+ protected int totalRead = 0; // total number of cf rows read
+ protected Iterator<CqlRow> rows;
+ private int pageRows = 0; // the number of cql rows read of this page
+ private String previousRowKey = null; // previous CF row key
+ private String partitionKeyString; // keys in <key1>, <key2>, <key3> string format
+ private String partitionKeyMarkers; // question marks in ? , ? , ? format which matches the number of keys
+
+ public RowIterator()
+ {
+ // initial page
+ executeQuery();
+ }
+
+ protected Pair<Map<String, ByteBuffer>, Map<String, ByteBuffer>> computeNext()
+ {
+ if (rows == null)
+ return endOfData();
+
+ int index = -2;
+ // check whether there are more pages to read
+ while (!rows.hasNext())
+ {
+ // no more data
+ if (index == -1 || emptyPartitionKeyValues())
+ {
+ logger.debug("no more data.");
+ return endOfData();
+ }
+
+ index = setTailNull(clusterColumns);
+ logger.debug("set tail to null, index: " + index);
+ executeQuery();
+ pageRows = 0;
+
+ if (rows == null || (!rows.hasNext() && index < 0))
+ {
+ logger.debug("no more data.");
+ return endOfData();
+ }
+ }
+
+ Map<String, ByteBuffer> valueColumns = createValue();
+ Map<String, ByteBuffer> keyColumns = createKey();
+ int i = 0;
+ CqlRow row = rows.next();
+ for (Column column : row.columns)
+ {
+ String columnName = stringValue(ByteBuffer.wrap(column.getName()));
+ logger.debug("column: " + columnName);
+
+ if (i < partitionBoundColumns.size() + clusterColumns.size())
+ keyColumns.put(stringValue(column.name), column.value);
+ else
+ valueColumns.put(stringValue(column.name), column.value);
+
+ i++;
+ }
+
+ // increase the number of CQL rows read for this page
+ pageRows++;
+
+ // increase the total number of CF rows read
+ if (newRow(keyColumns, previousRowKey))
+ totalRead++;
+
+ // a full page has been read (or the page is exhausted); prime the bound columns and fetch the next page
+ if (pageRows >= pageRowSize || !rows.hasNext())
+ {
+ Iterator<String> newKeys = keyColumns.keySet().iterator();
+ for (BoundColumn column : partitionBoundColumns)
+ column.value = keyColumns.get(newKeys.next());
+
+ for (BoundColumn column : clusterColumns)
+ column.value = keyColumns.get(newKeys.next());
+
+ executeQuery();
+ pageRows = 0;
+ }
+
+ return Pair.create(keyColumns, valueColumns);
+ }
+
+ /** check whether we have started reading a new CF row, by comparing the partition keys */
+ private boolean newRow(Map<String, ByteBuffer> keyColumns, String previousRowKey)
+ {
+ if (keyColumns.isEmpty())
+ return false;
+
+ String rowKey = "";
+ if (keyColumns.size() == 1)
+ {
+ rowKey = partitionBoundColumns.get(0).validator.getString(keyColumns.get(partitionBoundColumns.get(0).name));
+ }
+ else
+ {
+ Iterator<ByteBuffer> iter = keyColumns.values().iterator();
+ for (BoundColumn column : partitionBoundColumns)
+ rowKey = rowKey + column.validator.getString(ByteBufferUtil.clone(iter.next())) + ":";
+ }
+
+ logger.debug("previous RowKey: " + previousRowKey + ", new row key: " + rowKey);
+ if (previousRowKey == null)
+ {
+ this.previousRowKey = rowKey;
+ return true;
+ }
+
+ if (rowKey.equals(previousRowKey))
+ return false;
+
+ this.previousRowKey = rowKey;
+ return true;
+ }
+
+ /** set the last non-null key value to null, and return the previous index */
+ private int setTailNull(List<BoundColumn> values)
+ {
+ if (values.isEmpty())
+ return -1;
+
+ Iterator<BoundColumn> iterator = values.iterator();
+ int previousIndex = -1;
+ BoundColumn current;
+ while (iterator.hasNext())
+ {
+ current = iterator.next();
+ if (current.value == null)
+ {
+ int index = previousIndex > 0 ? previousIndex : 0;
+ BoundColumn column = values.get(index);
+ logger.debug("set key " + column.name + " value to null");
+ column.value = null;
+ return previousIndex - 1;
+ }
+
+ previousIndex++;
+ }
+
+ BoundColumn column = values.get(previousIndex);
+ logger.debug("set key " + column.name + " value to null");
+ column.value = null;
+ return previousIndex - 1;
+ }
+
+ /** compose the prepared query, pair.left is query id, pair.right is query */
+ private Pair<Integer, String> composeQuery(String columns)
+ {
+ Pair<Integer, String> clause = whereClause();
+ if (columns == null)
+ {
+ columns = "*";
+ }
+ else
+ {
+ // add keys in the front in order
+ String partitionKey = keyString(partitionBoundColumns);
+ String clusterKey = keyString(clusterColumns);
+
+ columns = withoutKeyColumns(columns);
+ columns = (clusterKey == null || "".equals(clusterKey))
+ ? partitionKey + "," + columns
+ : partitionKey + "," + clusterKey + "," + columns;
+ }
+
+ return Pair.create(clause.left,
+ "SELECT " + columns
+ + " FROM " + cfName
+ + clause.right
+ + (userDefinedWhereClauses == null ? "" : " AND " + userDefinedWhereClauses)
+ + " LIMIT " + pageRowSize
+ + " ALLOW FILTERING");
+ }
+
+
+ /** remove key columns from the column string */
+ private String withoutKeyColumns(String columnString)
+ {
+ Set<String> keyNames = new HashSet<String>();
+ for (BoundColumn column : Iterables.concat(partitionBoundColumns, clusterColumns))
+ keyNames.add(column.name);
+
+ String[] columns = columnString.split(",");
+ String result = null;
+ for (String column : columns)
+ {
+ String trimmed = column.trim();
+ if (keyNames.contains(trimmed))
+ continue;
+
+ result = result == null ? trimmed : result + "," + trimmed;
+ }
+ return result;
+ }
+
+ /** compose the where clause */
+ private Pair<Integer, String> whereClause()
+ {
+ if (partitionKeyString == null)
+ partitionKeyString = keyString(partitionBoundColumns);
+
+ if (partitionKeyMarkers == null)
+ partitionKeyMarkers = partitionKeyMarkers();
+ // initial query token(k) >= start_token and token(k) <= end_token
+ if (emptyPartitionKeyValues())
+ return Pair.create(0, " WHERE token(" + partitionKeyString + ") > ? AND token(" + partitionKeyString + ") <= ?");
+
+ // query token(k) > token(pre_partition_key) and token(k) <= end_token
+ if (clusterColumns.size() == 0 || clusterColumns.get(0).value == null)
+ return Pair.create(1,
+ " WHERE token(" + partitionKeyString + ") > token(" + partitionKeyMarkers + ") "
+ + " AND token(" + partitionKeyString + ") <= ?");
+
+ // query token(k) = token(pre_partition_key) and m = pre_cluster_key_m and n > pre_cluster_key_n
+ Pair<Integer, String> clause = whereClause(clusterColumns, 0);
+ return Pair.create(clause.left,
+ " WHERE token(" + partitionKeyString + ") = token(" + partitionKeyMarkers + ") " + clause.right);
+ }
+
+ /** recursively compose the where clause */
+ private Pair<Integer, String> whereClause(List<BoundColumn> column, int position)
+ {
+ if (position == column.size() - 1 || column.get(position + 1).value == null)
+ return Pair.create(position + 2, " AND " + column.get(position).name + " > ? ");
+
+ Pair<Integer, String> clause = whereClause(column, position + 1);
+ return Pair.create(clause.left, " AND " + column.get(position).name + " = ? " + clause.right);
+ }
+
+ /** check whether all key values are null */
+ private boolean emptyPartitionKeyValues()
+ {
+ for (BoundColumn column : partitionBoundColumns)
+ {
+ if (column.value != null)
+ return false;
+ }
+ return true;
+ }
+
+ /** compose the partition key string in the format <key1>, <key2>, <key3> */
+ private String keyString(List<BoundColumn> columns)
+ {
+ String result = null;
+ for (BoundColumn column : columns)
+ result = result == null ? column.name : result + "," + column.name;
+
+ return result == null ? "" : result;
+ }
+
+ /** compose the question marks for the partition key string, in the format ?, ?, ? */
+ private String partitionKeyMarkers()
+ {
+ String result = null;
+ for (BoundColumn column : partitionBoundColumns)
+ result = result == null ? "?" : result + ",?";
+
+ return result;
+ }
+
+ /** compose the query binding variables, pair.left is query id, pair.right is the binding variables */
+ private Pair<Integer, List<ByteBuffer>> preparedQueryBindValues()
+ {
+ List<ByteBuffer> values = new LinkedList<ByteBuffer>();
+
+ // initial query token(k) >= start_token and token(k) <= end_token
+ if (emptyPartitionKeyValues())
+ {
+ values.add(partitioner.getTokenValidator().fromString(split.getStartToken()));
+ values.add(partitioner.getTokenValidator().fromString(split.getEndToken()));
+ return Pair.create(0, values);
+ }
+ else
+ {
+ for (BoundColumn partitionBoundColumn1 : partitionBoundColumns)
+ values.add(partitionBoundColumn1.value);
+
+ if (clusterColumns.size() == 0 || clusterColumns.get(0).value == null)
+ {
+ // query token(k) > token(pre_partition_key) and token(k) <= end_token
+ values.add(partitioner.getTokenValidator().fromString(split.getEndToken()));
+ return Pair.create(1, values);
+ }
+ else
+ {
+ // query token(k) = token(pre_partition_key) and m = pre_cluster_key_m and n > pre_cluster_key_n
+ int type = preparedQueryBindValues(clusterColumns, 0, values);
+ return Pair.create(type, values);
+ }
+ }
+ }
+
+ /** recursively compose the query binding variables */
+ private int preparedQueryBindValues(List<BoundColumn> column, int position, List<ByteBuffer> bindValues)
+ {
+ if (position == column.size() - 1 || column.get(position + 1).value == null)
+ {
+ bindValues.add(column.get(position).value);
+ return position + 2;
+ }
+ else
+ {
+ bindValues.add(column.get(position).value);
+ return preparedQueryBindValues(column, position + 1, bindValues);
+ }
+ }
+
+ /** get the prepared query item Id */
+ private int prepareQuery(int type) throws InvalidRequestException, TException
+ {
+ Integer itemId = preparedQueryIds.get(type);
+ if (itemId != null)
+ return itemId;
+
+ Pair<Integer, String> query = composeQuery(columns);
+ logger.debug("type:" + query.left + ", query: " + query.right);
+ CqlPreparedResult cqlPreparedResult = client.prepare_cql3_query(ByteBufferUtil.bytes(query.right), Compression.NONE);
+ preparedQueryIds.put(query.left, cqlPreparedResult.itemId);
+ return cqlPreparedResult.itemId;
+ }
+
+ /** execute the prepared query */
+ private void executeQuery()
+ {
+ Pair<Integer, List<ByteBuffer>> bindValues = preparedQueryBindValues();
+ logger.debug("query type: " + bindValues.left);
+
+ // check whether we have reached the end of the range for a type 1 query (CASSANDRA-5573)
+ if (bindValues.left == 1 && reachEndRange())
+ {
+ rows = null;
+ return;
+ }
+
+ int retries = 0;
+ // only try three times for TimedOutException and UnavailableException
+ while (retries < 3)
+ {
+ try
+ {
+ CqlResult cqlResult = client.execute_prepared_cql3_query(prepareQuery(bindValues.left), bindValues.right, consistencyLevel);
+ if (cqlResult != null && cqlResult.rows != null)
+ rows = cqlResult.rows.iterator();
+ return;
+ }
+ catch (TimedOutException e)
+ {
+ retries++;
+ if (retries >= 3)
+ {
+ rows = null;
+ RuntimeException rte = new RuntimeException(e.getMessage());
+ rte.initCause(e);
+ throw rte;
+ }
+ }
+ catch (UnavailableException e)
+ {
+ retries++;
+ if (retries >= 3)
+ {
+ rows = null;
+ RuntimeException rte = new RuntimeException(e.getMessage());
+ rte.initCause(e);
+ throw rte;
+ }
+ }
+ catch (Exception e)
+ {
+ rows = null;
+ RuntimeException rte = new RuntimeException(e.getMessage());
+ rte.initCause(e);
+ throw rte;
+ }
+ }
+ }
+ }
+
+ /** retrieve the partition keys and cluster keys from system.schema_columnfamilies table */
+ private void retrieveKeys() throws Exception
+ {
+ String query = "select key_aliases," +
+ "column_aliases, " +
+ "key_validator, " +
+ "comparator " +
+ "from system.schema_columnfamilies " +
+ "where keyspace_name='%s' and columnfamily_name='%s'";
+ String formatted = String.format(query, keyspace, cfName);
+ CqlResult result = client.execute_cql3_query(ByteBufferUtil.bytes(formatted), Compression.NONE, ConsistencyLevel.ONE);
+
+ CqlRow cqlRow = result.rows.get(0);
+ String keyString = ByteBufferUtil.string(ByteBuffer.wrap(cqlRow.columns.get(0).getValue()));
+ logger.debug("partition keys: " + keyString);
+ List<String> keys = FBUtilities.fromJsonList(keyString);
+
+ for (String key : keys)
+ partitionBoundColumns.add(new BoundColumn(key));
+
+ keyString = ByteBufferUtil.string(ByteBuffer.wrap(cqlRow.columns.get(1).getValue()));
+ logger.debug("cluster columns: " + keyString);
+ keys = FBUtilities.fromJsonList(keyString);
+
+ for (String key : keys)
+ clusterColumns.add(new BoundColumn(key));
+
+ Column rawKeyValidator = cqlRow.columns.get(2);
+ String validator = ByteBufferUtil.string(ByteBuffer.wrap(rawKeyValidator.getValue()));
+ logger.debug("row key validator: " + validator);
+ keyValidator = parseType(validator);
+
+ if (keyValidator instanceof CompositeType)
+ {
+ List<AbstractType<?>> types = ((CompositeType) keyValidator).types;
+ for (int i = 0; i < partitionBoundColumns.size(); i++)
+ partitionBoundColumns.get(i).validator = types.get(i);
+ }
+ else
+ {
+ partitionBoundColumns.get(0).validator = keyValidator;
+ }
+ }
+
+ /** check whether the current row is at the end of the range */
+ private boolean reachEndRange()
+ {
+ // current row key
+ ByteBuffer rowKey;
+ if (keyValidator instanceof CompositeType)
+ {
+ ByteBuffer[] keys = new ByteBuffer[partitionBoundColumns.size()];
+ for (int i = 0; i < partitionBoundColumns.size(); i++)
+ keys[i] = partitionBoundColumns.get(i).value.duplicate();
+
+ rowKey = ((CompositeType) keyValidator).build(keys);
+ }
+ else
+ {
+ rowKey = partitionBoundColumns.get(0).value;
+ }
+
+ String endToken = split.getEndToken();
+ String currentToken = partitioner.getToken(rowKey).toString();
+ logger.debug("End token: " + endToken + ", current token: " + currentToken);
+
+ return endToken.equals(currentToken);
+ }
+
+ private static AbstractType<?> parseType(String type) throws IOException
+ {
+ try
+ {
+ // always treat counters like longs, specifically CCT.compose is not what we need
+ if (type != null && type.equals("org.apache.cassandra.db.marshal.CounterColumnType"))
+ return LongType.instance;
+ return TypeParser.parse(type);
+ }
+ catch (ConfigurationException e)
+ {
+ throw new IOException(e);
+ }
+ catch (SyntaxException e)
+ {
+ throw new IOException(e);
+ }
+ }
+
+ private class BoundColumn
+ {
+ final String name;
+ ByteBuffer value;
+ AbstractType<?> validator;
+
+ public BoundColumn(String name)
+ {
+ this.name = name;
+ }
+ }
+
+ /** get a string from a ByteBuffer, catching the exception and rethrowing it as a runtime exception */
+ private static String stringValue(ByteBuffer value)
+ {
+ try
+ {
+ return ByteBufferUtil.string(value);
+ }
+ catch (CharacterCodingException e)
+ {
+ throw new RuntimeException(e);
+ }
+ }
+}
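
For orientation, here is a minimal sketch of how a Hadoop job might consume the reader above through CqlPagingInputFormat. The ConfigHelper calls already exist in the tree; the CqlConfigHelper page-size setter name is assumed from the companion config helper in this patch series, and the keyspace/table names are placeholders, so treat this as illustrative rather than part of the patch.

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Map;

import org.apache.cassandra.hadoop.ConfigHelper;
import org.apache.cassandra.hadoop.cql3.CqlConfigHelper;
import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

public class CqlPagingReadSketch
{
    // Keys and values arrive exactly as CqlPagingRecordReader emits them:
    // column name -> raw column value.
    public static class RowMapper extends Mapper<Map<String, ByteBuffer>, Map<String, ByteBuffer>, Text, LongWritable>
    {
        protected void map(Map<String, ByteBuffer> keys, Map<String, ByteBuffer> columns, Context context)
                throws IOException, InterruptedException
        {
            // emit one record per non-key column, keyed by column name
            for (Map.Entry<String, ByteBuffer> column : columns.entrySet())
                context.write(new Text(column.getKey()),
                              new LongWritable(column.getValue() == null ? 0 : column.getValue().remaining()));
        }
    }

    public static void main(String[] args) throws Exception
    {
        Job job = new Job(new Configuration(), "cql-paging-read-sketch");
        job.setJarByClass(CqlPagingReadSketch.class);
        job.setMapperClass(RowMapper.class);
        job.setNumReduceTasks(0);
        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(LongWritable.class);
        FileOutputFormat.setOutputPath(job, new Path(args[0]));

        job.setInputFormatClass(CqlPagingInputFormat.class);
        Configuration conf = job.getConfiguration();
        ConfigHelper.setInputColumnFamily(conf, "my_keyspace", "my_table");   // placeholder keyspace/table
        ConfigHelper.setInputInitialAddress(conf, "localhost");
        ConfigHelper.setInputRpcPort(conf, "9160");
        ConfigHelper.setInputPartitioner(conf, "Murmur3Partitioner");
        CqlConfigHelper.setInputCQLPageRowSize(conf, "1000");                 // assumed setter name for the page row size

        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }
}

Because the reader pages with LIMIT ... ALLOW FILTERING, the page row size only bounds how much is fetched per query; the iterator keeps issuing follow-up queries until the split's token range is exhausted.
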
http://git-wip-us.apache.org/repos/asf/cassandra/blob/a0047797/src/java/org/apache/cassandra/hadoop/cql3/CqlRecordWriter.java
----------------------------------------------------------------------
diff --git a/src/java/org/apache/cassandra/hadoop/cql3/CqlRecordWriter.java b/src/java/org/apache/cassandra/hadoop/cql3/CqlRecordWriter.java
new file mode 100644
index 0000000..dde6b1f
--- /dev/null
+++ b/src/java/org/apache/cassandra/hadoop/cql3/CqlRecordWriter.java
@@ -0,0 +1,383 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.cassandra.hadoop.cql3;
+
+import java.io.IOException;
+import java.net.InetAddress;
+import java.nio.ByteBuffer;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.cassandra.thrift.*;
+import org.apache.cassandra.db.marshal.AbstractType;
+import org.apache.cassandra.db.marshal.CompositeType;
+import org.apache.cassandra.db.marshal.LongType;
+import org.apache.cassandra.db.marshal.TypeParser;
+import org.apache.cassandra.dht.Range;
+import org.apache.cassandra.dht.Token;
+import org.apache.cassandra.exceptions.ConfigurationException;
+import org.apache.cassandra.exceptions.SyntaxException;
+import org.apache.cassandra.hadoop.AbstractColumnFamilyRecordWriter;
+import org.apache.cassandra.hadoop.ConfigHelper;
+import org.apache.cassandra.hadoop.Progressable;
+import org.apache.cassandra.utils.ByteBufferUtil;
+import org.apache.cassandra.utils.FBUtilities;
+import org.apache.cassandra.utils.Pair;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.thrift.TException;
+import org.apache.thrift.transport.TTransport;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The <code>CqlRecordWriter</code> maps the output <key, value>
+ * pairs to a Cassandra column family. In particular, it applies the bound variables
+ * in the value to the prepared statement, which it associates with the key, and in
+ * turn the responsible endpoint.
+ *
+ * <p>
+ * Furthermore, this writer groups the cql queries by the endpoint responsible for
+ * the rows being affected. This allows the cql queries to be executed in parallel,
+ * directly to a responsible endpoint.
+ * </p>
+ *
+ * @see CqlOutputFormat
+ */
+final class CqlRecordWriter extends AbstractColumnFamilyRecordWriter<Map<String, ByteBuffer>, List<ByteBuffer>>
+{
+ private static final Logger logger = LoggerFactory.getLogger(CqlRecordWriter.class);
+
+ // handles for clients for each range running in the threadpool
+ private final Map<Range, RangeClient> clients;
+
+ // client (one per host) to prepared statement id mappings
+ private ConcurrentHashMap<Cassandra.Client, Integer> preparedStatements = new ConcurrentHashMap<Cassandra.Client, Integer>();
+
+ private final String cql;
+
+ private AbstractType<?> keyValidator;
+ private String [] partitionkeys;
+
+ /**
+ * Upon construction, obtain the output CQL query that this writer will bind
+ * values to, and the ring cache for the given keyspace.
+ *
+ * @param context the task attempt context
+ * @throws IOException
+ */
+ CqlRecordWriter(TaskAttemptContext context) throws IOException
+ {
+ this(context.getConfiguration());
+ this.progressable = new Progressable(context);
+ }
+
+ CqlRecordWriter(Configuration conf, Progressable progressable) throws IOException
+ {
+ this(conf);
+ this.progressable = progressable;
+ }
+
+ CqlRecordWriter(Configuration conf) throws IOException
+ {
+ super(conf);
+ this.clients = new HashMap<Range, RangeClient>();
+ cql = CqlConfigHelper.getOutputCql(conf);
+
+ try
+ {
+ String host = getAnyHost();
+ int port = ConfigHelper.getOutputRpcPort(conf);
+ Cassandra.Client client = CqlOutputFormat.createAuthenticatedClient(host, port, conf);
+ retrievePartitionKeyValidator(client);
+
+ if (client != null)
+ {
+ TTransport transport = client.getOutputProtocol().getTransport();
+ if (transport.isOpen())
+ transport.close();
+ client = null;
+ }
+ }
+ catch (Exception e)
+ {
+ throw new IOException(e);
+ }
+ }
+
+ @Override
+ public void close() throws IOException
+ {
+ // close all the clients before throwing anything
+ IOException clientException = null;
+ for (RangeClient client : clients.values())
+ {
+ try
+ {
+ client.close();
+ }
+ catch (IOException e)
+ {
+ clientException = e;
+ }
+ }
+
+ if (clientException != null)
+ throw clientException;
+ }
+
+ /**
+ * Binds the given values to the prepared CQL statement and routes the pair to the
+ * client responsible for the row key's token range. The row key is composed from
+ * the key columns using the table's key validator, and the values are queued for
+ * asynchronous execution against the responsible endpoint.
+ *
+ * @param keyColumns
+ * the key to write.
+ * @param values
+ * the values to write.
+ * @throws IOException
+ */
+ @Override
+ public void write(Map<String, ByteBuffer> keyColumns, List<ByteBuffer> values) throws IOException
+ {
+ ByteBuffer rowKey = getRowKey(keyColumns);
+ Range<Token> range = ringCache.getRange(rowKey);
+
+ // get the client for the given range, or create a new one
+ RangeClient client = clients.get(range);
+ if (client == null)
+ {
+ // haven't seen keys for this range: create new client
+ client = new RangeClient(ringCache.getEndpoint(range));
+ client.start();
+ clients.put(range, client);
+ }
+
+ client.put(Pair.create(rowKey, values));
+ progressable.progress();
+ }
+
+ /**
+ * A client that runs in a threadpool and connects to the list of endpoints for a particular
+ * range. Bound variables for keys in that range are sent to this client via a queue.
+ */
+ public class RangeClient extends AbstractRangeClient<List<ByteBuffer>>
+ {
+ /**
+ * Constructs a {@link RangeClient} for the given endpoints.
+ * @param endpoints the possible endpoints to execute the mutations on
+ */
+ public RangeClient(List<InetAddress> endpoints)
+ {
+ super(endpoints);
+ }
+
+ /**
+ * Loops collecting CQL bound variable values from the queue and sending them to Cassandra
+ */
+ public void run()
+ {
+ outer:
+ while (run || !queue.isEmpty())
+ {
+ Pair<ByteBuffer, List<ByteBuffer>> item;
+ try
+ {
+ item = queue.take();
+ }
+ catch (InterruptedException e)
+ {
+ // re-check loop condition after interrupt
+ continue;
+ }
+
+ Iterator<InetAddress> iter = endpoints.iterator();
+ while (true)
+ {
+ // send the bound variables to the last-used endpoint. first time through, this will NPE harmlessly.
+ try
+ {
+ int i = 0;
+ int itemId = preparedStatement(client);
+ while (item != null)
+ {
+ List<ByteBuffer> bindVariables = item.right;
+ client.execute_prepared_cql3_query(itemId, bindVariables, ConsistencyLevel.ONE);
+ i++;
+
+ if (i >= batchThreshold)
+ break;
+
+ item = queue.poll();
+ }
+
+ break;
+ }
+ catch (Exception e)
+ {
+ closeInternal();
+ if (!iter.hasNext())
+ {
+ lastException = new IOException(e);
+ break outer;
+ }
+ }
+
+ // attempt to connect to a different endpoint
+ try
+ {
+ InetAddress address = iter.next();
+ String host = address.getHostName();
+ int port = ConfigHelper.getOutputRpcPort(conf);
+ client = CqlOutputFormat.createAuthenticatedClient(host, port, conf);
+ }
+ catch (Exception e)
+ {
+ closeInternal();
+ // TException means something unexpected went wrong to that endpoint, so
+ // we should try again to another. Other exceptions (auth or invalid request) are fatal.
+ if ((!(e instanceof TException)) || !iter.hasNext())
+ {
+ lastException = new IOException(e);
+ break outer;
+ }
+ }
+ }
+ }
+ }
+
+ /** get the prepared statement id from the cache, otherwise prepare it on the Cassandra server */
+ private int preparedStatement(Cassandra.Client client)
+ {
+ Integer itemId = preparedStatements.get(client);
+ if (itemId == null)
+ {
+ CqlPreparedResult result;
+ try
+ {
+ result = client.prepare_cql3_query(ByteBufferUtil.bytes(cql), Compression.NONE);
+ }
+ catch (InvalidRequestException e)
+ {
+ throw new RuntimeException("failed to prepare cql query " + cql, e);
+ }
+ catch (TException e)
+ {
+ throw new RuntimeException("failed to prepare cql query " + cql, e);
+ }
+
+ Integer previousId = preparedStatements.putIfAbsent(client, Integer.valueOf(result.itemId));
+ itemId = previousId == null ? result.itemId : previousId;
+ }
+ return itemId;
+ }
+ }
+
+ private ByteBuffer getRowKey(Map<String, ByteBuffer> keyColumns)
+ {
+ //current row key
+ ByteBuffer rowKey;
+ if (keyValidator instanceof CompositeType)
+ {
+ ByteBuffer[] keys = new ByteBuffer[partitionkeys.length];
+ for (int i = 0; i< keys.length; i++)
+ keys[i] = keyColumns.get(partitionkeys[i]);
+
+ rowKey = ((CompositeType) keyValidator).build(keys);
+ }
+ else
+ {
+ rowKey = keyColumns.get(partitionkeys[0]);
+ }
+ return rowKey;
+ }
+
+ /** retrieve the key validator from system.schema_columnfamilies table */
+ private void retrievePartitionKeyValidator(Cassandra.Client client) throws Exception
+ {
+ String keyspace = ConfigHelper.getOutputKeyspace(conf);
+ String cfName = ConfigHelper.getOutputColumnFamily(conf);
+ String query = "SELECT key_validator," +
+ " key_aliases " +
+ "FROM system.schema_columnfamilies " +
+ "WHERE keyspace_name='%s' and columnfamily_name='%s'";
+ String formatted = String.format(query, keyspace, cfName);
+ CqlResult result = client.execute_cql3_query(ByteBufferUtil.bytes(formatted), Compression.NONE, ConsistencyLevel.ONE);
+
+ Column rawKeyValidator = result.rows.get(0).columns.get(0);
+ String validator = ByteBufferUtil.string(ByteBuffer.wrap(rawKeyValidator.getValue()));
+ keyValidator = parseType(validator);
+
+ Column rawPartitionKeys = result.rows.get(0).columns.get(1);
+ String keyString = ByteBufferUtil.string(ByteBuffer.wrap(rawPartitionKeys.getValue()));
+ logger.debug("partition keys: " + keyString);
+
+ List<String> keys = FBUtilities.fromJsonList(keyString);
+ partitionkeys = new String[keys.size()];
+ int i = 0;
+ for (String key : keys)
+ {
+ partitionkeys[i] = key;
+ i++;
+ }
+ }
+
+ private AbstractType<?> parseType(String type) throws IOException
+ {
+ try
+ {
+ // always treat counters like longs, specifically CCT.compose is not what we need
+ if (type != null && type.equals("org.apache.cassandra.db.marshal.CounterColumnType"))
+ return LongType.instance;
+ return TypeParser.parse(type);
+ }
+ catch (ConfigurationException e)
+ {
+ throw new IOException(e);
+ }
+ catch (SyntaxException e)
+ {
+ throw new IOException(e);
+ }
+ }
+
+ private String getAnyHost() throws IOException, InvalidRequestException, TException
+ {
+ Cassandra.Client client = ConfigHelper.getClientFromOutputAddressList(conf);
+ List<TokenRange> ring = client.describe_ring(ConfigHelper.getOutputKeyspace(conf));
+ try
+ {
+ for (TokenRange range : ring)
+ return range.endpoints.get(0);
+ }
+ finally
+ {
+ TTransport transport = client.getOutputProtocol().getTransport();
+ if (transport.isOpen())
+ transport.close();
+ }
+ throw new IOException("There are no endpoints");
+ }
+
+}
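
Similarly, a sketch of the output side feeding CqlRecordWriter through CqlOutputFormat. As the run() loop above shows, only the value list is bound to the prepared statement, so the configured CQL must carry a marker for every bound variable, including the key in the WHERE clause; the key map is used only to route the write to the endpoint owning that partition. CqlConfigHelper.setOutputCql is assumed from the companion config helper, and the keyspace, table, and column names are placeholders.

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.apache.cassandra.hadoop.ConfigHelper;
import org.apache.cassandra.hadoop.cql3.CqlConfigHelper;
import org.apache.cassandra.hadoop.cql3.CqlOutputFormat;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Reducer;

public class CqlWriteSketch
{
    // The writer expects the partition key columns in the key map (for endpoint
    // routing) and the bound variables, in statement order, in the value list.
    public static class CountReducer extends Reducer<Text, LongWritable, Map<String, ByteBuffer>, List<ByteBuffer>>
    {
        protected void reduce(Text word, Iterable<LongWritable> counts, Context context)
                throws IOException, InterruptedException
        {
            long sum = 0;
            for (LongWritable count : counts)
                sum += count.get();

            ByteBuffer wordBuf = ByteBufferUtil.bytes(word.toString());
            Map<String, ByteBuffer> keys = new LinkedHashMap<String, ByteBuffer>();
            keys.put("word", wordBuf); // placeholder partition key column

            // binds "count_num = ?" then "WHERE word = ?", in order
            context.write(keys, Arrays.asList(ByteBufferUtil.bytes(sum), wordBuf));
        }
    }

    static void configureOutput(Job job)
    {
        job.setReducerClass(CountReducer.class);
        job.setOutputFormatClass(CqlOutputFormat.class);
        job.setOutputKeyClass(Map.class);
        job.setOutputValueClass(List.class);

        ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "my_keyspace", "word_count");
        ConfigHelper.setOutputInitialAddress(job.getConfiguration(), "localhost");
        ConfigHelper.setOutputRpcPort(job.getConfiguration(), "9160");
        ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner");
        // assumed setter name; placeholder table and columns
        CqlConfigHelper.setOutputCql(job.getConfiguration(),
                                     "UPDATE my_keyspace.word_count SET count_num = ? WHERE word = ?");
    }
}

Grouping writes per token range, as the clients map above does, keeps each prepared statement executing against an endpoint that owns the affected rows, which is what lets the queries run in parallel.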