You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2017/03/16 18:00:02 UTC
[3/4] arrow git commit: ARROW-542: Adding dictionary encoding to FileWriter
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java
index 074b0aa..a12440e 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java
@@ -24,6 +24,10 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ObjectArrays;
+
+import io.netty.buffer.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.OutOfMemoryException;
import org.apache.arrow.vector.AddOrGetResult;
@@ -42,16 +46,12 @@ import org.apache.arrow.vector.complex.writer.FieldWriter;
import org.apache.arrow.vector.schema.ArrowFieldNode;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.Types.MinorType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.util.CallBack;
import org.apache.arrow.vector.util.JsonStringArrayList;
import org.apache.arrow.vector.util.TransferPair;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ObjectArrays;
-
-import io.netty.buffer.ArrowBuf;
-
public class ListVector extends BaseRepeatedValueVector implements FieldVector {
final UInt4Vector offsets;
@@ -62,14 +62,16 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector {
private UnionListWriter writer;
private UnionListReader reader;
private CallBack callBack;
+ private final DictionaryEncoding dictionary;
- public ListVector(String name, BufferAllocator allocator, CallBack callBack) {
+ public ListVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack) {
super(name, allocator);
this.bits = new BitVector("$bits$", allocator);
this.offsets = getOffsetVector();
this.innerVectors = Collections.unmodifiableList(Arrays.<BufferBacked>asList(bits, offsets));
this.writer = new UnionListWriter(this);
this.reader = new UnionListReader(this);
+ this.dictionary = dictionary;
this.callBack = callBack;
}
@@ -80,7 +82,7 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector {
}
Field field = children.get(0);
MinorType minorType = Types.getMinorTypeForArrowType(field.getType());
- AddOrGetResult<FieldVector> addOrGetVector = addOrGetVector(minorType);
+ AddOrGetResult<FieldVector> addOrGetVector = addOrGetVector(minorType, field.getDictionary());
if (!addOrGetVector.isCreated()) {
throw new IllegalArgumentException("Child vector already existed: " + addOrGetVector.getVector());
}
@@ -151,16 +153,16 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector {
TransferPair pairs[] = new TransferPair[3];
public TransferImpl(String name, BufferAllocator allocator) {
- this(new ListVector(name, allocator, null));
+ this(new ListVector(name, allocator, dictionary, null));
}
public TransferImpl(ListVector to) {
this.to = to;
- to.addOrGetVector(vector.getMinorType());
+ to.addOrGetVector(vector.getMinorType(), vector.getField().getDictionary());
pairs[0] = offsets.makeTransferPair(to.offsets);
pairs[1] = bits.makeTransferPair(to.bits);
if (to.getDataVector() instanceof ZeroVector) {
- to.addOrGetVector(vector.getMinorType());
+ to.addOrGetVector(vector.getMinorType(), vector.getField().getDictionary());
}
pairs[2] = getDataVector().makeTransferPair(to.getDataVector());
}
@@ -232,8 +234,8 @@ public class ListVector extends BaseRepeatedValueVector implements FieldVector {
return success;
}
- public <T extends ValueVector> AddOrGetResult<T> addOrGetVector(MinorType minorType) {
- AddOrGetResult<T> result = super.addOrGetVector(minorType);
+ public <T extends ValueVector> AddOrGetResult<T> addOrGetVector(MinorType minorType, DictionaryEncoding dictionary) {
+ AddOrGetResult<T> result = super.addOrGetVector(minorType, dictionary);
reader = new UnionListReader(this);
return result;
}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java
index 31a1bb7..4d750ca 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java
@@ -160,7 +160,7 @@ public class MapVector extends AbstractMapVector {
// (This is similar to what happens in ScanBatch where the children cannot be added till they are
// read). To take care of this, we ensure that the hashCode of the MaterializedField does not
// include the hashCode of the children but is based only on MaterializedField$key.
- final FieldVector newVector = to.addOrGet(child, vector.getMinorType(), vector.getClass());
+ final FieldVector newVector = to.addOrGet(child, vector.getMinorType(), vector.getClass(), vector.getField().getDictionary());
if (allocate && to.size() != preSize) {
newVector.allocateNew();
}
@@ -314,12 +314,11 @@ public class MapVector extends AbstractMapVector {
public void initializeChildrenFromFields(List<Field> children) {
for (Field field : children) {
MinorType minorType = Types.getMinorTypeForArrowType(field.getType());
- FieldVector vector = (FieldVector)this.add(field.getName(), minorType);
+ FieldVector vector = (FieldVector)this.add(field.getName(), minorType, field.getDictionary());
vector.initializeChildrenFromFields(field.getChildren());
}
}
-
public List<FieldVector> getChildrenFromFields() {
return getChildren();
}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java
index 5fa3530..bb1fdf8 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/NullableMapVector.java
@@ -34,6 +34,7 @@ import org.apache.arrow.vector.complex.impl.NullableMapReaderImpl;
import org.apache.arrow.vector.complex.reader.FieldReader;
import org.apache.arrow.vector.holders.ComplexHolder;
import org.apache.arrow.vector.schema.ArrowFieldNode;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
import org.apache.arrow.vector.util.CallBack;
import org.apache.arrow.vector.util.TransferPair;
@@ -48,14 +49,16 @@ public class NullableMapVector extends MapVector implements FieldVector {
protected final BitVector bits;
private final List<BufferBacked> innerVectors;
+ private final DictionaryEncoding dictionary;
private final Accessor accessor;
private final Mutator mutator;
- public NullableMapVector(String name, BufferAllocator allocator, CallBack callBack) {
+ public NullableMapVector(String name, BufferAllocator allocator, DictionaryEncoding dictionary, CallBack callBack) {
super(name, checkNotNull(allocator), callBack);
this.bits = new BitVector("$bits$", allocator);
this.innerVectors = Collections.unmodifiableList(Arrays.<BufferBacked>asList(bits));
+ this.dictionary = dictionary;
this.accessor = new Accessor();
this.mutator = new Mutator();
}
@@ -83,7 +86,7 @@ public class NullableMapVector extends MapVector implements FieldVector {
@Override
public TransferPair getTransferPair(BufferAllocator allocator) {
- return new NullableMapTransferPair(this, new NullableMapVector(name, allocator, callBack), false);
+ return new NullableMapTransferPair(this, new NullableMapVector(name, allocator, dictionary, callBack), false);
}
@Override
@@ -93,7 +96,7 @@ public class NullableMapVector extends MapVector implements FieldVector {
@Override
public TransferPair getTransferPair(String ref, BufferAllocator allocator) {
- return new NullableMapTransferPair(this, new NullableMapVector(ref, allocator, callBack), false);
+ return new NullableMapTransferPair(this, new NullableMapVector(ref, allocator, dictionary, callBack), false);
}
protected class NullableMapTransferPair extends MapTransferPair {
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java
index dbdd205..6d05316 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/ComplexWriterImpl.java
@@ -149,7 +149,8 @@ public class ComplexWriterImpl extends AbstractFieldWriter implements ComplexWri
switch(mode){
case INIT:
- NullableMapVector map = container.addOrGet(name, MinorType.MAP, NullableMapVector.class);
+ // TODO allow dictionaries in complex types
+ NullableMapVector map = container.addOrGet(name, MinorType.MAP, NullableMapVector.class, null);
mapRoot = nullableMapWriterFactory.build(map);
mapRoot.setPosition(idx());
mode = Mode.MAP;
@@ -180,7 +181,8 @@ public class ComplexWriterImpl extends AbstractFieldWriter implements ComplexWri
case INIT:
int vectorCount = container.size();
- ListVector listVector = container.addOrGet(name, MinorType.LIST, ListVector.class);
+ // TODO allow dictionaries in complex types
+ ListVector listVector = container.addOrGet(name, MinorType.LIST, ListVector.class, null);
if (container.size() > vectorCount) {
listVector.allocateNew();
}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java
index 1f7253b..e33319a 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java
@@ -125,7 +125,7 @@ public class PromotableWriter extends AbstractPromotableFieldWriter {
// ???
return null;
}
- ValueVector v = listVector.addOrGetVector(type).getVector();
+ ValueVector v = listVector.addOrGetVector(type, null).getVector();
v.allocateNew();
setWriter(v);
writer.setPosition(position);
@@ -150,7 +150,8 @@ public class PromotableWriter extends AbstractPromotableFieldWriter {
TransferPair tp = vector.getTransferPair(vector.getMinorType().name().toLowerCase(), vector.getAllocator());
tp.transfer();
if (parentContainer != null) {
- unionVector = parentContainer.addOrGet(name, MinorType.UNION, UnionVector.class);
+ // TODO allow dictionaries in complex types
+ unionVector = parentContainer.addOrGet(name, MinorType.UNION, UnionVector.class, null);
unionVector.allocateNew();
} else if (listVector != null) {
unionVector = listVector.promoteToUnion();
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java
new file mode 100644
index 0000000..0c1cadf
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java
@@ -0,0 +1,83 @@
+/*******************************************************************************
+
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ******************************************************************************/
+package org.apache.arrow.vector.dictionary;
+
+import java.util.Objects;
+
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+
+/**
+ * A dictionary: a vector of values paired with the {@link DictionaryEncoding} metadata
+ * (id, index type) that describes how encoded vectors refer to it.
+ */
+public class Dictionary {
+
+  private final FieldVector values;
+  private final DictionaryEncoding encoding;
+
+  /**
+   * @param dictionary vector holding the dictionary values
+   * @param encoding encoding metadata describing the dictionary
+   */
+  public Dictionary(FieldVector dictionary, DictionaryEncoding encoding) {
+    this.encoding = encoding;
+    this.values = dictionary;
+  }
+
+  /** Returns the vector holding the dictionary values. */
+  public FieldVector getVector() {
+    return values;
+  }
+
+  /** Returns the encoding metadata for this dictionary. */
+  public DictionaryEncoding getEncoding() {
+    return encoding;
+  }
+
+  /** Returns the Arrow type of the dictionary values, as declared by the vector's field. */
+  public ArrowType getVectorType() {
+    return values.getField().getType();
+  }
+
+  @Override
+  public String toString() {
+    return "Dictionary " + encoding + " " + values;
+  }
+
+  /** Dictionaries are equal when both their encodings and their vectors are equal. */
+  @Override
+  public boolean equals(Object o) {
+    if (o == this) {
+      return true;
+    }
+    if (o == null || o.getClass() != getClass()) {
+      return false;
+    }
+    Dictionary other = (Dictionary) o;
+    return Objects.equals(encoding, other.encoding)
+        && Objects.equals(values, other.values);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(encoding, values);
+  }
+}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java
new file mode 100644
index 0000000..0666bc4
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryEncoder.java
@@ -0,0 +1,174 @@
+/*******************************************************************************
+
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ ******************************************************************************/
+package org.apache.arrow.vector.dictionary;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.HashMap;
+import java.util.Map;
+
+import com.google.common.collect.ImmutableList;
+
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.types.Types.MinorType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.util.TransferPair;
+
+public class DictionaryEncoder {
+
+  // TODO recursively examine fields?
+
+  /**
+   * Dictionary encodes a vector with a provided dictionary. The dictionary must contain all values in the vector.
+   *
+   * <p>Null values in {@code vector} are left unset in the result. If encoding fails part way, the
+   * index vector is cleared before the exception propagates so that its buffers are released.
+   *
+   * @param vector vector to encode
+   * @param dictionary dictionary used for encoding
+   * @return dictionary encoded vector
+   * @throws IllegalArgumentException if the vector type is not supported, the index vector's mutator
+   *     has no {@code set(int, int)} or {@code set(int, long)} method, or a non-null value is not
+   *     present in the dictionary
+   */
+  public static ValueVector encode(ValueVector vector, Dictionary dictionary) {
+    validateType(vector.getMinorType());
+    // load dictionary values into a hashmap for lookup
+    Map<Object, Integer> lookUps = buildLookup(dictionary);
+
+    Field valueField = vector.getField();
+    Field indexField = new Field(valueField.getName(), valueField.isNullable(),
+        dictionary.getEncoding().getIndexType(), dictionary.getEncoding(), null);
+
+    // vector to hold our indices (dictionary encoded values)
+    FieldVector indices = indexField.createVector(vector.getAllocator());
+    try {
+      ValueVector.Mutator mutator = indices.getMutator();
+      Method setter = findIndexSetter(mutator, indices);
+
+      ValueVector.Accessor accessor = vector.getAccessor();
+      int count = accessor.getValueCount();
+
+      indices.allocateNew();
+      for (int i = 0; i < count; i++) {
+        Object value = accessor.getObject(i);
+        if (value != null) { // if it's null leave it null
+          Integer encoded = lookUps.get(value);
+          if (encoded == null) {
+            throw new IllegalArgumentException("Dictionary encoding not defined for value:" + value);
+          }
+          invokeSetter(setter, mutator, i, encoded);
+        }
+      }
+      mutator.setValueCount(count);
+      return indices;
+    } catch (RuntimeException e) {
+      // release the buffers of the partially encoded result instead of leaking them
+      indices.clear();
+      throw e;
+    }
+  }
+
+  /**
+   * Decodes a dictionary encoded array using the provided dictionary.
+   *
+   * @param indices dictionary encoded values, must be int type
+   * @param dictionary dictionary used to decode the values
+   * @return vector with values restored from dictionary
+   * @throws IllegalArgumentException if a non-null index is negative or not smaller than the
+   *     number of values in the dictionary
+   */
+  public static ValueVector decode(ValueVector indices, Dictionary dictionary) {
+    ValueVector.Accessor accessor = indices.getAccessor();
+    int count = accessor.getValueCount();
+    ValueVector dictionaryVector = dictionary.getVector();
+    int dictionaryCount = dictionaryVector.getAccessor().getValueCount();
+    // copy the dictionary values into the decoded vector
+    TransferPair transfer = dictionaryVector.getTransferPair(indices.getAllocator());
+    ValueVector decoded = transfer.getTo();
+    decoded.allocateNewSafe();
+    try {
+      for (int i = 0; i < count; i++) {
+        Object index = accessor.getObject(i);
+        if (index != null) { // null indices are left unset
+          long position = ((Number) index).longValue();
+          // valid dictionary positions are 0 (inclusive) to dictionaryCount (exclusive)
+          if (position < 0 || position >= dictionaryCount) {
+            throw new IllegalArgumentException("Provided dictionary does not contain value for index " + position);
+          }
+          transfer.copyValueSafe((int) position, i);
+        }
+      }
+    } catch (RuntimeException e) {
+      // release the buffers of the partially decoded result instead of leaking them
+      decoded.clear();
+      throw e;
+    }
+    // TODO do we need to worry about the field?
+    decoded.getMutator().setValueCount(count);
+    return decoded;
+  }
+
+  /** Maps each dictionary value to its position in the dictionary vector. */
+  private static Map<Object, Integer> buildLookup(Dictionary dictionary) {
+    ValueVector.Accessor dictionaryAccessor = dictionary.getVector().getAccessor();
+    int dictionaryCount = dictionaryAccessor.getValueCount();
+    Map<Object, Integer> lookUps = new HashMap<>(dictionaryCount);
+    for (int i = 0; i < dictionaryCount; i++) {
+      // for primitive array types we need a wrapper that implements equals and hashcode appropriately
+      lookUps.put(dictionaryAccessor.getObject(i), i);
+    }
+    return lookUps;
+  }
+
+  /**
+   * Uses reflection to find the mutator's {@code set(int, int)} or, failing that, {@code set(int, long)}.
+   * TODO implement a common interface for int vectors
+   */
+  private static Method findIndexSetter(ValueVector.Mutator mutator, ValueVector indices) {
+    for (Class<?> c : ImmutableList.of(int.class, long.class)) {
+      try {
+        return mutator.getClass().getMethod("set", int.class, c);
+      } catch (NoSuchMethodException e) {
+        // no setter of this width; try the next candidate
+      }
+    }
+    throw new IllegalArgumentException("Dictionary encoding does not have a valid int type:" + indices.getClass());
+  }
+
+  /** Invokes the reflective setter, rethrowing reflection failures as {@link RuntimeException}. */
+  private static void invokeSetter(Method setter, ValueVector.Mutator mutator, int index, Integer encoded) {
+    try {
+      setter.invoke(mutator, index, encoded);
+    } catch (IllegalAccessException e) {
+      throw new RuntimeException("IllegalAccessException invoking vector mutator set():", e);
+    } catch (InvocationTargetException e) {
+      throw new RuntimeException("InvocationTargetException invoking vector mutator set():", e.getCause());
+    }
+  }
+
+  private static void validateType(MinorType type) {
+    // byte arrays don't work as keys in our dictionary map - we could wrap them with something to
+    // implement equals and hashcode if we want that functionality
+    if (type == MinorType.VARBINARY || type == MinorType.LIST || type == MinorType.MAP || type == MinorType.UNION) {
+      throw new IllegalArgumentException("Dictionary encoding for complex types not implemented: type " + type);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java
new file mode 100644
index 0000000..63fde25
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryProvider.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.vector.dictionary;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Source of {@link Dictionary} instances, looked up by dictionary id.
+ */
+public interface DictionaryProvider {
+
+  /**
+   * Looks up a dictionary by id.
+   *
+   * @param id dictionary id
+   * @return the matching dictionary; {@link MapDictionaryProvider} returns {@code null} if none is registered
+   */
+  Dictionary lookup(long id);
+
+  /** {@link DictionaryProvider} backed by a {@link HashMap} keyed on each dictionary's encoding id. */
+  class MapDictionaryProvider implements DictionaryProvider {
+
+    private final Map<Long, Dictionary> dictionariesById = new HashMap<>();
+
+    /**
+     * @param dictionaries dictionaries to register up front, each passed to {@link #put(Dictionary)}
+     */
+    public MapDictionaryProvider(Dictionary... dictionaries) {
+      for (Dictionary d : dictionaries) {
+        put(d);
+      }
+    }
+
+    /** Registers {@code dictionary} under its encoding id, replacing any entry already stored for that id. */
+    public void put(Dictionary dictionary) {
+      dictionariesById.put(dictionary.getEncoding().getId(), dictionary);
+    }
+
+    @Override
+    public Dictionary lookup(long id) {
+      return dictionariesById.get(id);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java
new file mode 100644
index 0000000..28440a1
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileReader.java
@@ -0,0 +1,170 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.vector.file;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.SeekableByteChannel;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.arrow.flatbuf.Footer;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.schema.ArrowDictionaryBatch;
+import org.apache.arrow.vector.schema.ArrowMessage;
+import org.apache.arrow.vector.schema.ArrowRecordBatch;
+import org.apache.arrow.vector.stream.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Reads the Arrow file format. A footer at the end of the file locates the schema, the dictionary
+ * batches and the record batches; {@link #readMessage} returns every dictionary batch before any
+ * record batch.
+ */
+public class ArrowFileReader extends ArrowReader<SeekableReadChannel> {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFileReader.class);
+
+  private ArrowFooter footer;
+  private int currentDictionaryBatch = 0;
+  private int currentRecordBatch = 0;
+
+  public ArrowFileReader(SeekableByteChannel in, BufferAllocator allocator) {
+    super(new SeekableReadChannel(in), allocator);
+  }
+
+  public ArrowFileReader(SeekableReadChannel in, BufferAllocator allocator) {
+    super(in, allocator);
+  }
+
+  /**
+   * Reads and caches the footer on first call, then returns the schema stored in it.
+   *
+   * <p>The file is expected to end with: footer bytes, the footer length as a 4-byte int, then the
+   * magic bytes.
+   *
+   * @throws InvalidArrowFileException if the file is too small, does not end with the magic bytes,
+   *     or records a footer length that does not fit in the file
+   */
+  @Override
+  protected Schema readSchema(SeekableReadChannel in) throws IOException {
+    if (footer == null) {
+      if (in.size() <= (ArrowMagic.MAGIC_LENGTH * 2 + 4)) {
+        throw new InvalidArrowFileException("file too small: " + in.size());
+      }
+      ByteBuffer buffer = ByteBuffer.allocate(4 + ArrowMagic.MAGIC_LENGTH);
+      long footerLengthOffset = in.size() - buffer.remaining();
+      in.setPosition(footerLengthOffset);
+      in.readFully(buffer);
+      buffer.flip();
+      byte[] array = buffer.array();
+      if (!ArrowMagic.validateMagic(Arrays.copyOfRange(array, 4, array.length))) {
+        throw new InvalidArrowFileException("missing Magic number " + Arrays.toString(buffer.array()));
+      }
+      int footerLength = MessageSerializer.bytesToInt(array);
+      // widen to long first so a corrupt length near Integer.MAX_VALUE cannot overflow past this check
+      if (footerLength <= 0 || (long) footerLength + ArrowMagic.MAGIC_LENGTH * 2 + 4 > in.size()) {
+        throw new InvalidArrowFileException("invalid footer length: " + footerLength);
+      }
+      long footerOffset = footerLengthOffset - footerLength;
+      LOGGER.debug("Footer starts at {}, length: {}", footerOffset, footerLength);
+      ByteBuffer footerBuffer = ByteBuffer.allocate(footerLength);
+      in.setPosition(footerOffset);
+      in.readFully(footerBuffer);
+      footerBuffer.flip();
+      Footer footerFB = Footer.getRootAsFooter(footerBuffer);
+      this.footer = new ArrowFooter(footerFB);
+    }
+    return footer.getSchema();
+  }
+
+  /**
+   * Returns the next unread dictionary batch; once all dictionary batches are consumed, the next
+   * unread record batch; {@code null} when both are exhausted. Assumes {@link #readSchema} has
+   * already loaded the footer.
+   */
+  @Override
+  protected ArrowMessage readMessage(SeekableReadChannel in, BufferAllocator allocator) throws IOException {
+    if (currentDictionaryBatch < footer.getDictionaries().size()) {
+      ArrowBlock block = footer.getDictionaries().get(currentDictionaryBatch++);
+      return readDictionaryBatch(in, block, allocator);
+    } else if (currentRecordBatch < footer.getRecordBatches().size()) {
+      ArrowBlock block = footer.getRecordBatches().get(currentRecordBatch++);
+      return readRecordBatch(in, block, allocator);
+    } else {
+      return null;
+    }
+  }
+
+  /** Returns the blocks of all dictionary batches listed in the footer. */
+  public List<ArrowBlock> getDictionaryBlocks() throws IOException {
+    ensureInitialized();
+    return footer.getDictionaries();
+  }
+
+  /** Returns the blocks of all record batches listed in the footer. */
+  public List<ArrowBlock> getRecordBlocks() throws IOException {
+    ensureInitialized();
+    return footer.getRecordBatches();
+  }
+
+  /**
+   * Positions the reader on {@code block} and then calls {@link #loadNextBatch()}.
+   *
+   * @throws IllegalArgumentException if {@code block} is not one of the footer's record batches
+   */
+  public void loadRecordBatch(ArrowBlock block) throws IOException {
+    ensureInitialized();
+    int blockIndex = footer.getRecordBatches().indexOf(block);
+    if (blockIndex == -1) {
+      throw new IllegalArgumentException("Arrow block does not exist in record batches: " + block);
+    }
+    currentRecordBatch = blockIndex;
+    loadNextBatch();
+  }
+
+  /** Reads the dictionary batch stored at {@code block}, failing if none can be deserialized there. */
+  private ArrowDictionaryBatch readDictionaryBatch(SeekableReadChannel in,
+                                                   ArrowBlock block,
+                                                   BufferAllocator allocator) throws IOException {
+    LOGGER.debug("DictionaryRecordBatch at {}, metadata: {}, body: {}",
+        block.getOffset(), block.getMetadataLength(), block.getBodyLength());
+    in.setPosition(block.getOffset());
+    ArrowDictionaryBatch batch = MessageSerializer.deserializeDictionaryBatch(in, block, allocator);
+    if (batch == null) {
+      throw new IOException("Invalid file. No batch at offset: " + block.getOffset());
+    }
+    return batch;
+  }
+
+  /** Reads the record batch stored at {@code block}, failing if none can be deserialized there. */
+  private ArrowRecordBatch readRecordBatch(SeekableReadChannel in,
+                                           ArrowBlock block,
+                                           BufferAllocator allocator) throws IOException {
+    LOGGER.debug("RecordBatch at {}, metadata: {}, body: {}",
+        block.getOffset(), block.getMetadataLength(), block.getBodyLength());
+    in.setPosition(block.getOffset());
+    ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(in, block, allocator);
+    if (batch == null) {
+      throw new IOException("Invalid file. No batch at offset: " + block.getOffset());
+    }
+    return batch;
+  }
+}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java
new file mode 100644
index 0000000..23d210a
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFileWriter.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.vector.file;
+
+import java.io.IOException;
+import java.nio.channels.WritableByteChannel;
+import java.util.List;
+
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Writer for the Arrow file format: {@link #startInternal} writes the leading magic bytes and
+ * {@link #endInternal} appends the footer, its length and the trailing magic bytes.
+ */
+public class ArrowFileWriter extends ArrowWriter {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFileWriter.class);
+
+  public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
+    super(root, provider, out);
+  }
+
+  /** Writes the leading magic bytes. */
+  @Override
+  protected void startInternal(WriteChannel out) throws IOException {
+    ArrowMagic.writeMagic(out);
+  }
+
+  /**
+   * Writes the footer (schema plus dictionary and record batch blocks), the footer length as a
+   * little-endian int, and the trailing magic bytes.
+   *
+   * @throws InvalidArrowFileException if the serialized footer is empty or longer than an int can hold
+   */
+  @Override
+  protected void endInternal(WriteChannel out,
+                             Schema schema,
+                             List<ArrowBlock> dictionaries,
+                             List<ArrowBlock> records) throws IOException {
+    long footerStart = out.getCurrentPosition();
+    out.write(new ArrowFooter(schema, dictionaries, records), false);
+    // keep the length as a long so an oversized footer is rejected instead of wrapping in an int cast
+    long footerLength = out.getCurrentPosition() - footerStart;
+    if (footerLength <= 0 || footerLength > Integer.MAX_VALUE) {
+      throw new InvalidArrowFileException("invalid footer");
+    }
+    out.writeIntLittleEndian((int) footerLength);
+    LOGGER.debug("Footer starts at {}, length: {}", footerStart, footerLength);
+    ArrowMagic.writeMagic(out);
+    LOGGER.debug("magic written, now at {}", out.getCurrentPosition());
+  }
+}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java
index 3890306..1c0008a 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowFooter.java
@@ -38,7 +38,6 @@ public class ArrowFooter implements FBSerializable {
private final List<ArrowBlock> recordBatches;
public ArrowFooter(Schema schema, List<ArrowBlock> dictionaries, List<ArrowBlock> recordBatches) {
- super();
this.schema = schema;
this.dictionaries = dictionaries;
this.recordBatches = recordBatches;
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java
new file mode 100644
index 0000000..99ea96b
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowMagic.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.vector.file;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+
+/**
+ * Helpers for the "ARROW1" magic marker that begins and ends an Arrow file.
+ */
+public class ArrowMagic {
+
+ private static final byte[] MAGIC = "ARROW1".getBytes(StandardCharsets.UTF_8);
+
+ /** Length of the magic marker in bytes. */
+ public static final int MAGIC_LENGTH = MAGIC.length;
+
+ private ArrowMagic() {
+ // static helpers only; not meant to be instantiated
+ }
+
+ /**
+ * Writes the magic marker to {@code out}.
+ *
+ * @param out channel to write to
+ * @throws IOException if the write fails
+ */
+ public static void writeMagic(WriteChannel out) throws IOException {
+ out.write(MAGIC);
+ }
+
+ /**
+ * Checks whether {@code array} holds exactly the magic marker.
+ *
+ * @param array bytes to check; may be null
+ * @return true if {@code array} has the same length and contents as the marker, false
+ * otherwise (including when it is null)
+ */
+ public static boolean validateMagic(byte[] array) {
+ return Arrays.equals(MAGIC, array);
+ }
+}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java
index 8f4f497..1646fbe 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowReader.java
@@ -18,90 +18,188 @@
package org.apache.arrow.vector.file;
import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.nio.channels.SeekableByteChannel;
-import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.ImmutableList;
-import org.apache.arrow.flatbuf.Footer;
import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.schema.ArrowDictionaryBatch;
+import org.apache.arrow.vector.schema.ArrowMessage;
+import org.apache.arrow.vector.schema.ArrowMessage.ArrowMessageVisitor;
import org.apache.arrow.vector.schema.ArrowRecordBatch;
-import org.apache.arrow.vector.stream.MessageSerializer;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-public class ArrowReader implements AutoCloseable {
- private static final Logger LOGGER = LoggerFactory.getLogger(ArrowReader.class);
-
- public static final byte[] MAGIC = "ARROW1".getBytes();
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.ArrowType.Int;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
- private final SeekableByteChannel in;
+/**
+ * Base class for readers of Arrow data from a {@link ReadChannel}. Subclasses supply the schema
+ * ({@link #readSchema}) and the sequence of messages ({@link #readMessage}); this class creates
+ * the in-memory vectors lazily on first use and serves as the {@link DictionaryProvider} for
+ * dictionary-encoded fields.
+ *
+ * @param <T> the type of channel read from
+ */
+public abstract class ArrowReader<T extends ReadChannel> implements DictionaryProvider, AutoCloseable {
+ private final T in;
private final BufferAllocator allocator;
- private ArrowFooter footer;
+ private VectorLoader loader;
+ private VectorSchemaRoot root;
+ // dictionary vectors keyed by dictionary id; unmodifiable once initialized
+ private Map<Long, Dictionary> dictionaries;
- public ArrowReader(SeekableByteChannel in, BufferAllocator allocator) {
- super();
+ private boolean initialized = false;
+
+ /**
+ * @param in the channel to read from; closed by {@link #close()}
+ * @param allocator allocator for the vectors and the message buffers
+ */
+ protected ArrowReader(T in, BufferAllocator allocator) {
this.in = in;
this.allocator = allocator;
}
- private int readFully(ByteBuffer buffer) throws IOException {
- int total = 0;
- int n;
- do {
- n = in.read(buffer);
- total += n;
- } while (n >= 0 && buffer.remaining() > 0);
- buffer.flip();
- return total;
+ /**
+ * Returns the vector schema root, reading the schema first if needed. The same root is
+ * refilled with new values on every call to {@link #loadNextBatch()}, and its vectors are
+ * released by {@link #close()}.
+ *
+ * @return the vector schema root
+ * @throws IOException if reading of schema fails
+ */
+ public VectorSchemaRoot getVectorSchemaRoot() throws IOException {
+ ensureInitialized();
+ return root;
}
- public ArrowFooter readFooter() throws IOException {
- if (footer == null) {
- if (in.size() <= (MAGIC.length * 2 + 4)) {
- throw new InvalidArrowFileException("file too small: " + in.size());
- }
- ByteBuffer buffer = ByteBuffer.allocate(4 + MAGIC.length);
- long footerLengthOffset = in.size() - buffer.remaining();
- in.position(footerLengthOffset);
- readFully(buffer);
- byte[] array = buffer.array();
- if (!Arrays.equals(MAGIC, Arrays.copyOfRange(array, 4, array.length))) {
- throw new InvalidArrowFileException("missing Magic number " + Arrays.toString(buffer.array()));
- }
- int footerLength = MessageSerializer.bytesToInt(array);
- if (footerLength <= 0 || footerLength + MAGIC.length * 2 + 4 > in.size()) {
- throw new InvalidArrowFileException("invalid footer length: " + footerLength);
- }
- long footerOffset = footerLengthOffset - footerLength;
- LOGGER.debug(String.format("Footer starts at %d, length: %d", footerOffset, footerLength));
- ByteBuffer footerBuffer = ByteBuffer.allocate(footerLength);
- in.position(footerOffset);
- readFully(footerBuffer);
- Footer footerFB = Footer.getRootAsFooter(footerBuffer);
- this.footer = new ArrowFooter(footerFB);
+ /**
+ * Returns the dictionaries declared by the schema, keyed by dictionary id. The map is
+ * unmodifiable; the dictionary vectors are filled as dictionary batches are read by
+ * {@link #loadNextBatch()}.
+ *
+ * @return dictionaries, if any (an empty map when the schema has no dictionary-encoded fields)
+ * @throws IOException if reading of schema fails
+ */
+ public Map<Long, Dictionary> getDictionaryVectors() throws IOException {
+ ensureInitialized();
+ return dictionaries;
+ }
+
+ /**
+ * Returns the dictionary with the given id, or null if the schema has not been read yet or
+ * declares no such dictionary. Unlike {@link #getDictionaryVectors()}, this never reads.
+ */
+ @Override
+ public Dictionary lookup(long id) {
+ if (initialized) {
+ return dictionaries.get(id);
+ } else {
+ return null;
}
- return footer;
}
- // TODO: read dictionaries
-
- public ArrowRecordBatch readRecordBatch(ArrowBlock block) throws IOException {
- LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d",
- block.getOffset(), block.getMetadataLength(),
- block.getBodyLength()));
- in.position(block.getOffset());
- ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(
- new ReadChannel(in, block.getOffset()), block, allocator);
- if (batch == null) {
- throw new IOException("Invalid file. No batch at offset: " + block.getOffset());
+ /**
+ * Loads the next record batch into the vector schema root. Dictionary batches read before it
+ * are loaded into their dictionary vectors. The root's row count is reset to 0 first, so it
+ * stays 0 when {@link #readMessage} reports that there are no further messages.
+ *
+ * @throws IOException if reading fails
+ * @throws IllegalArgumentException if a dictionary batch has an id the schema does not declare
+ */
+ public void loadNextBatch() throws IOException {
+ ensureInitialized();
+ // read in all dictionary batches, then stop after our first record batch
+ ArrowMessageVisitor<Boolean> visitor = new ArrowMessageVisitor<Boolean>() {
+ @Override
+ public Boolean visit(ArrowDictionaryBatch message) {
+ try { load(message); } finally { message.close(); }
+ return true;
+ }
+ @Override
+ public Boolean visit(ArrowRecordBatch message) {
+ try { loader.load(message); } finally { message.close(); }
+ return false;
+ }
+ };
+ root.setRowCount(0);
+ ArrowMessage message = readMessage(in, allocator);
+ while (message != null && message.accepts(visitor)) {
+ message = readMessage(in, allocator);
}
- return batch;
}
+ /** Returns the number of bytes read so far, as tracked by {@link ReadChannel#bytesRead()}. */
+ public long bytesRead() { return in.bytesRead(); }
+
+ /**
+ * Releases the vector schema root and the dictionary vectors (if they were created) and closes
+ * the underlying channel. The channel is closed even if releasing the vectors fails.
+ */
@Override
public void close() throws IOException {
+ try {
+ if (initialized) {
+ root.close();
+ for (Dictionary dictionary: dictionaries.values()) {
+ dictionary.getVector().close();
+ }
+ }
+ } finally {
in.close();
+ }
}
+
+ /** Reads the schema from {@code in}; called when the reader is first initialized. */
+ protected abstract Schema readSchema(T in) throws IOException;
+
+ /**
+ * Reads the next message (dictionary batch or record batch) from {@code in}, or returns null
+ * when there are no more messages.
+ */
+ protected abstract ArrowMessage readMessage(T in, BufferAllocator allocator) throws IOException;
+
+ /** Reads the schema and creates the vectors; once that has succeeded, later calls do nothing. */
+ protected void ensureInitialized() throws IOException {
+ if (!initialized) {
+ initialize();
+ initialized = true;
+ }
+ }
+
+ /**
+ * Reads the schema and creates the vectors: one per top-level field in the root, plus one
+ * per dictionary id referenced anywhere in the schema.
+ */
+ private void initialize() throws IOException {
+ Schema schema = readSchema(in);
+ List<Field> fields = new ArrayList<>();
+ List<FieldVector> vectors = new ArrayList<>();
+ Map<Long, Dictionary> dictionaries = new HashMap<>();
+
+ for (Field field: schema.getFields()) {
+ Field updated = toMemoryFormat(field, dictionaries);
+ fields.add(updated);
+ vectors.add(updated.createVector(allocator));
+ }
+
+ this.root = new VectorSchemaRoot(fields, vectors, 0);
+ this.loader = new VectorLoader(root);
+ this.dictionaries = Collections.unmodifiableMap(dictionaries);
+ }
+
+ /**
+ * Converts a schema field from message format to memory format, recursing into children.
+ * In the message format, fields have the dictionary (value) type; in the memory format, they
+ * have the index type. A dictionary-encoded field is re-typed to its encoding's index type
+ * (a signed 32-bit int if none is given), and a vector of the original value type is created
+ * for its dictionary id unless {@code dictionaries} already holds one.
+ */
+ private Field toMemoryFormat(Field field, Map<Long, Dictionary> dictionaries) {
+ DictionaryEncoding encoding = field.getDictionary();
+ List<Field> children = field.getChildren();
+
+ if (encoding == null && children.isEmpty()) {
+ return field;
+ }
+
+ List<Field> updatedChildren = new ArrayList<>(children.size());
+ for (Field child: children) {
+ updatedChildren.add(toMemoryFormat(child, dictionaries));
+ }
+
+ ArrowType type;
+ if (encoding == null) {
+ type = field.getType();
+ } else {
+ // re-type the field for in-memory format
+ type = encoding.getIndexType();
+ if (type == null) {
+ type = new Int(32, true);
+ }
+ // get existing or create dictionary vector
+ if (!dictionaries.containsKey(encoding.getId())) {
+ // create a new dictionary vector for the values
+ Field dictionaryField = new Field(field.getName(), field.isNullable(), field.getType(), null, children);
+ FieldVector dictionaryVector = dictionaryField.createVector(allocator);
+ dictionaries.put(encoding.getId(), new Dictionary(dictionaryVector, encoding));
+ }
+ }
+
+ return new Field(field.getName(), field.isNullable(), type, encoding, updatedChildren);
+ }
+
+ /**
+ * Loads a dictionary batch into the vector of the dictionary with the batch's id.
+ *
+ * @throws IllegalArgumentException if the schema declares no dictionary with that id
+ */
+ private void load(ArrowDictionaryBatch dictionaryBatch) {
+ long id = dictionaryBatch.getDictionaryId();
+ Dictionary dictionary = dictionaries.get(id);
+ if (dictionary == null) {
+ throw new IllegalArgumentException("Dictionary ID " + id + " not defined in schema");
+ }
+ FieldVector vector = dictionary.getVector();
+ VectorSchemaRoot root = new VectorSchemaRoot(ImmutableList.of(vector.getField()), ImmutableList.of(vector), 0);
+ VectorLoader loader = new VectorLoader(root);
+ loader.load(dictionaryBatch.getDictionary());
+ }
}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java
index 24c667e..60a6afb 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ArrowWriter.java
@@ -1,4 +1,4 @@
-/**
+/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
@@ -21,77 +21,172 @@ import java.io.IOException;
import java.nio.channels.WritableByteChannel;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
+import com.google.common.collect.ImmutableList;
+
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.schema.ArrowDictionaryBatch;
import org.apache.arrow.vector.schema.ArrowRecordBatch;
import org.apache.arrow.vector.stream.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public class ArrowWriter implements AutoCloseable {
+public abstract class ArrowWriter implements AutoCloseable {
+
private static final Logger LOGGER = LoggerFactory.getLogger(ArrowWriter.class);
+ // schema with fields in message format, not memory format
+ private final Schema schema;
private final WriteChannel out;
- private final Schema schema;
+ private final VectorUnloader unloader;
+ private final List<ArrowDictionaryBatch> dictionaries;
+
+ private final List<ArrowBlock> dictionaryBlocks = new ArrayList<>();
+ private final List<ArrowBlock> recordBlocks = new ArrayList<>();
- private final List<ArrowBlock> recordBatches = new ArrayList<>();
private boolean started = false;
+ private boolean ended = false;
- public ArrowWriter(WritableByteChannel out, Schema schema) {
+ /**
+ * Note: fields are not closed when the writer is closed
+ *
+ * @param root
+ * @param provider
+ * @param out
+ */
+ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
+ this.unloader = new VectorUnloader(root);
this.out = new WriteChannel(out);
- this.schema = schema;
+
+ List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());
+ Map<Long, ArrowDictionaryBatch> dictionaryBatches = new HashMap<>();
+
+ for (Field field: root.getSchema().getFields()) {
+ fields.add(toMessageFormat(field, provider, dictionaryBatches));
+ }
+
+ this.schema = new Schema(fields);
+ this.dictionaries = Collections.unmodifiableList(new ArrayList<>(dictionaryBatches.values()));
+ }
+
+ // in the message format, fields have the dictionary type
+ // in the memory format, they have the index type
+ private Field toMessageFormat(Field field, DictionaryProvider provider, Map<Long, ArrowDictionaryBatch> batches) {
+ DictionaryEncoding encoding = field.getDictionary();
+ List<Field> children = field.getChildren();
+
+ if (encoding == null && children.isEmpty()) {
+ return field;
+ }
+
+ List<Field> updatedChildren = new ArrayList<>(children.size());
+ for (Field child: children) {
+ updatedChildren.add(toMessageFormat(child, provider, batches));
+ }
+
+ ArrowType type;
+ if (encoding == null) {
+ type = field.getType();
+ } else {
+ long id = encoding.getId();
+ Dictionary dictionary = provider.lookup(id);
+ if (dictionary == null) {
+ throw new IllegalArgumentException("Could not find dictionary with ID " + id);
+ }
+ type = dictionary.getVectorType();
+
+ if (!batches.containsKey(id)) {
+ FieldVector vector = dictionary.getVector();
+ int count = vector.getAccessor().getValueCount();
+ VectorSchemaRoot root = new VectorSchemaRoot(ImmutableList.of(field), ImmutableList.of(vector), count);
+ VectorUnloader unloader = new VectorUnloader(root);
+ ArrowRecordBatch batch = unloader.getRecordBatch();
+ batches.put(id, new ArrowDictionaryBatch(id, batch));
+ }
+ }
+
+ return new Field(field.getName(), field.isNullable(), type, encoding, updatedChildren);
}
- private void start() throws IOException {
- writeMagic();
- MessageSerializer.serialize(out, schema);
+ public void start() throws IOException {
+ ensureStarted();
}
- // TODO: write dictionaries
+ public void writeBatch() throws IOException {
+ ensureStarted();
+ try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
+ writeRecordBatch(batch);
+ }
+ }
- public void writeRecordBatch(ArrowRecordBatch recordBatch) throws IOException {
- checkStarted();
- ArrowBlock batchDesc = MessageSerializer.serialize(out, recordBatch);
+ protected void writeRecordBatch(ArrowRecordBatch batch) throws IOException {
+ ArrowBlock block = MessageSerializer.serialize(out, batch);
LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d",
- batchDesc.getOffset(), batchDesc.getMetadataLength(), batchDesc.getBodyLength()));
+ block.getOffset(), block.getMetadataLength(), block.getBodyLength()));
+ recordBlocks.add(block);
+ }
- // add metadata to footer
- recordBatches.add(batchDesc);
+ public void end() throws IOException {
+ ensureStarted();
+ ensureEnded();
}
- private void checkStarted() throws IOException {
+ public long bytesWritten() { return out.getCurrentPosition(); }
+
+ private void ensureStarted() throws IOException {
if (!started) {
started = true;
- start();
+ startInternal(out);
+ // write the schema - for file formats this is duplicated in the footer, but matches
+ // the streaming format
+ MessageSerializer.serialize(out, schema);
+ // write out any dictionaries
+ for (ArrowDictionaryBatch batch : dictionaries) {
+ try {
+ ArrowBlock block = MessageSerializer.serialize(out, batch);
+ LOGGER.debug(String.format("DictionaryRecordBatch at %d, metadata: %d, body: %d",
+ block.getOffset(), block.getMetadataLength(), block.getBodyLength()));
+ dictionaryBlocks.add(block);
+ } finally {
+ batch.close();
+ }
+ }
}
}
- @Override
- public void close() throws IOException {
- try {
- long footerStart = out.getCurrentPosition();
- writeFooter();
- int footerLength = (int)(out.getCurrentPosition() - footerStart);
- if (footerLength <= 0 ) {
- throw new InvalidArrowFileException("invalid footer");
- }
- out.writeIntLittleEndian(footerLength);
- LOGGER.debug(String.format("Footer starts at %d, length: %d", footerStart, footerLength));
- writeMagic();
- } finally {
- out.close();
+ private void ensureEnded() throws IOException {
+ if (!ended) {
+ ended = true;
+ endInternal(out, schema, dictionaryBlocks, recordBlocks);
}
}
- private void writeMagic() throws IOException {
- out.write(ArrowReader.MAGIC);
- LOGGER.debug(String.format("magic written, now at %d", out.getCurrentPosition()));
- }
+ protected abstract void startInternal(WriteChannel out) throws IOException;
+
+ protected abstract void endInternal(WriteChannel out,
+ Schema schema,
+ List<ArrowBlock> dictionaries,
+ List<ArrowBlock> records) throws IOException;
- private void writeFooter() throws IOException {
- // TODO: dictionaries
- out.write(new ArrowFooter(schema, Collections.<ArrowBlock>emptyList(), recordBatches), false);
+ @Override
+ public void close() {
+ try {
+ end();
+ out.close();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
}
}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java
index a9dc129..b062f38 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/file/ReadChannel.java
@@ -32,16 +32,9 @@ public class ReadChannel implements AutoCloseable {
private ReadableByteChannel in;
private long bytesRead = 0;
- // The starting byte offset into 'in'.
- private final long startByteOffset;
-
- public ReadChannel(ReadableByteChannel in, long startByteOffset) {
- this.in = in;
- this.startByteOffset = startByteOffset;
- }
+ /** Wraps {@code in}; the count reported by {@link #bytesRead()} starts at 0. */
public ReadChannel(ReadableByteChannel in) {
- this(in, 0);
+ this.in = in;
}
+ /** Returns the running count held in the {@code bytesRead} field. */
public long bytesRead() { return bytesRead; }
@@ -72,8 +65,6 @@ public class ReadChannel implements AutoCloseable {
return n;
}
- public long getCurrentPositiion() { return startByteOffset + bytesRead; }
-
@Override
public void close() throws IOException {
if (this.in != null) {
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java
new file mode 100644
index 0000000..914c3cb
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/file/SeekableReadChannel.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.vector.file;
+
+import java.io.IOException;
+import java.nio.channels.SeekableByteChannel;
+
+/**
+ * A {@link ReadChannel} over a {@link SeekableByteChannel} that can also reposition the
+ * underlying channel and report its size.
+ */
+public class SeekableReadChannel extends ReadChannel {
+
+ private final SeekableByteChannel channel;
+
+ /**
+ * @param in the channel to read from; it is also the target of {@link #setPosition(long)}
+ */
+ public SeekableReadChannel(SeekableByteChannel in) {
+ super(in);
+ channel = in;
+ }
+
+ /**
+ * Moves the read position of the underlying channel. This does not change the count
+ * reported by {@link #bytesRead()}.
+ */
+ public void setPosition(long position) throws IOException {
+ channel.position(position);
+ }
+
+ /** Returns the current size of the underlying channel. */
+ public long size() throws IOException {
+ final long total = channel.size();
+ return total;
+ }
+}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java b/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java
index d99c9a6..42104d1 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/file/WriteChannel.java
@@ -21,13 +21,12 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
-import org.apache.arrow.vector.schema.FBSerializable;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
import com.google.flatbuffers.FlatBufferBuilder;
import io.netty.buffer.ArrowBuf;
+import org.apache.arrow.vector.schema.FBSerializable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* Wrapper around a WritableByteChannel that maintains the position as well adding
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java
index 24fdc18..bdb63b9 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java
@@ -88,10 +88,34 @@ public class JsonFileReader implements AutoCloseable {
}
}
+ /**
+ * Reads the next record batch from the JSON input into {@code root}, reusing its vectors.
+ * A batch object sets the row count from its "count" field and fills each schema field's
+ * vector (looked up by name) from "columns". An END_ARRAY token (presumably the end of the
+ * enclosing batches array) sets the row count to 0 instead.
+ *
+ * @param root the vectors to fill
+ * @throws IOException if the JSON cannot be read
+ * @throws IllegalArgumentException if the next token is neither a batch object nor END_ARRAY
+ */
+ public void read(VectorSchemaRoot root) throws IOException {
+ JsonToken token = parser.nextToken();
+ if (token == END_ARRAY) {
+ root.setRowCount(0);
+ return;
+ }
+ if (token != START_OBJECT) {
+ throw new IllegalArgumentException("Invalid token: " + token);
+ }
+ int rowCount = readNextField("count", Integer.class);
+ root.setRowCount(rowCount);
+ nextFieldIs("columns");
+ readToken(START_ARRAY);
+ for (Field field : schema.getFields()) {
+ readVector(field, root.getVector(field.getName()));
+ }
+ readToken(END_ARRAY);
+ readToken(END_OBJECT);
+ }
+
public VectorSchemaRoot read() throws IOException {
JsonToken t = parser.nextToken();
if (t == START_OBJECT) {
- VectorSchemaRoot recordBatch = new VectorSchemaRoot(schema, allocator);
+ VectorSchemaRoot recordBatch = VectorSchemaRoot.create(schema, allocator);
{
int count = readNextField("count", Integer.class);
recordBatch.setRowCount(count);
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java
new file mode 100644
index 0000000..901877b
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowDictionaryBatch.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.vector.schema;
+
+import com.google.flatbuffers.FlatBufferBuilder;
+import org.apache.arrow.flatbuf.DictionaryBatch;
+
+/**
+ * A dictionary batch message: the values of one dictionary, identified by its id and carried
+ * as an {@link ArrowRecordBatch}.
+ */
+public class ArrowDictionaryBatch implements ArrowMessage {
+
+ private final long dictionaryId;
+ private final ArrowRecordBatch dictionary;
+
+ /**
+ * @param dictionaryId id of the dictionary, matching a field's dictionary encoding id
+ * @param dictionary the dictionary values; released by {@link #close()}
+ */
+ public ArrowDictionaryBatch(long dictionaryId, ArrowRecordBatch dictionary) {
+ this.dictionaryId = dictionaryId;
+ this.dictionary = dictionary;
+ }
+
+ /** Returns the id of this dictionary. */
+ public long getDictionaryId() {
+ return dictionaryId;
+ }
+
+ /** Returns the record batch holding the dictionary values. */
+ public ArrowRecordBatch getDictionary() {
+ return dictionary;
+ }
+
+ /** Serializes the dictionary data, then a DictionaryBatch table that references it. */
+ @Override
+ public int writeTo(FlatBufferBuilder builder) {
+ // the nested record batch must be finished before this table is started
+ final int data = dictionary.writeTo(builder);
+ DictionaryBatch.startDictionaryBatch(builder);
+ DictionaryBatch.addId(builder, dictionaryId);
+ DictionaryBatch.addData(builder, data);
+ return DictionaryBatch.endDictionaryBatch(builder);
+ }
+
+ /** The body of a dictionary batch is the body of its record batch. */
+ @Override
+ public int computeBodyLength() {
+ return dictionary.computeBodyLength();
+ }
+
+ @Override
+ public <T> T accepts(ArrowMessageVisitor<T> visitor) {
+ return visitor.visit(this);
+ }
+
+ @Override
+ public String toString() {
+ return new StringBuilder("ArrowDictionaryBatch [dictionaryId=")
+ .append(dictionaryId)
+ .append(", dictionary=")
+ .append(dictionary)
+ .append(']')
+ .toString();
+ }
+
+ /** Releases the buffers of the underlying record batch. */
+ @Override
+ public void close() {
+ dictionary.close();
+ }
+}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java
new file mode 100644
index 0000000..d307428
--- /dev/null
+++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowMessage.java
@@ -0,0 +1,30 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.arrow.vector.schema;
+
+/**
+ * A message that can be written to an Arrow stream or file: a record batch or a dictionary batch.
+ */
+public interface ArrowMessage extends FBSerializable, AutoCloseable {
+
+ /** Returns the size of this message's serialized body. */
+ int computeBodyLength();
+
+ /**
+ * Dispatches to the {@code visit} overload for this message's concrete type.
+ *
+ * @return whatever the visitor returns
+ */
+ <T> T accepts(ArrowMessageVisitor<T> visitor);
+
+ /** Visitor over the concrete message types. */
+ interface ArrowMessageVisitor<T> {
+ T visit(ArrowDictionaryBatch message);
+ T visit(ArrowRecordBatch message);
+ }
+}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java
index 40c2fbf..6ef514e 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/ArrowRecordBatch.java
@@ -32,7 +32,8 @@ import com.google.flatbuffers.FlatBufferBuilder;
import io.netty.buffer.ArrowBuf;
-public class ArrowRecordBatch implements FBSerializable, AutoCloseable {
+public class ArrowRecordBatch implements ArrowMessage {
+
private static final Logger LOGGER = LoggerFactory.getLogger(ArrowRecordBatch.class);
/** number of records */
@@ -113,9 +114,13 @@ public class ArrowRecordBatch implements FBSerializable, AutoCloseable {
return RecordBatch.endRecordBatch(builder);
}
+ @Override
+ public <T> T accepts(ArrowMessageVisitor<T> visitor) { return visitor.visit(this); }
+
/**
* releases the buffers
*/
+ @Override
public void close() {
if (!closed) {
closed = true;
@@ -134,6 +139,7 @@ public class ArrowRecordBatch implements FBSerializable, AutoCloseable {
/**
* Computes the size of the serialized body for this recordBatch.
*/
+ @Override
public int computeBodyLength() {
int size = 0;
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java
index f32966c..2deef37 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamReader.java
@@ -17,79 +17,43 @@
*/
package org.apache.arrow.vector.stream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.channels.Channels;
-import java.nio.channels.ReadableByteChannel;
-
import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.file.ArrowReader;
import org.apache.arrow.vector.file.ReadChannel;
-import org.apache.arrow.vector.schema.ArrowRecordBatch;
+import org.apache.arrow.vector.schema.ArrowMessage;
import org.apache.arrow.vector.types.pojo.Schema;
-import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.channels.Channels;
+import java.nio.channels.ReadableByteChannel;
/**
* This classes reads from an input stream and produces ArrowRecordBatches.
*/
-public class ArrowStreamReader implements AutoCloseable {
- private ReadChannel in;
- private final BufferAllocator allocator;
- private Schema schema;
-
- /**
- * Constructs a streaming read, reading bytes from 'in'. Non-blocking.
- */
- public ArrowStreamReader(ReadableByteChannel in, BufferAllocator allocator) {
- super();
- this.in = new ReadChannel(in);
- this.allocator = allocator;
- }
-
- public ArrowStreamReader(InputStream in, BufferAllocator allocator) {
- this(Channels.newChannel(in), allocator);
- }
-
- /**
- * Initializes the reader. Must be called before the other APIs. This is blocking.
- */
- public void init() throws IOException {
- Preconditions.checkState(this.schema == null, "Cannot call init() more than once.");
- this.schema = readSchema();
- }
+public class ArrowStreamReader extends ArrowReader<ReadChannel> {
- /**
- * Returns the schema for all records in this stream.
- */
- public Schema getSchema () {
- Preconditions.checkState(this.schema != null, "Must call init() first.");
- return schema;
- }
-
- public long bytesRead() { return in.bytesRead(); }
+ /**
+ * Constructs a streaming reader, reading bytes from 'in'. Non-blocking.
+ */
+ public ArrowStreamReader(ReadableByteChannel in, BufferAllocator allocator) {
+ super(new ReadChannel(in), allocator);
+ }
- /**
- * Reads and returns the next ArrowRecordBatch. Returns null if this is the end
- * of stream.
- */
- public ArrowRecordBatch nextRecordBatch() throws IOException {
- Preconditions.checkState(this.in != null, "Cannot call after close()");
- Preconditions.checkState(this.schema != null, "Must call init() first.");
- return MessageSerializer.deserializeRecordBatch(in, allocator);
- }
+ public ArrowStreamReader(InputStream in, BufferAllocator allocator) {
+ this(Channels.newChannel(in), allocator);
+ }
- @Override
- public void close() throws IOException {
- if (this.in != null) {
- in.close();
- in = null;
+ /**
+ * Reads the schema message from the beginning of the stream.
+ */
+ @Override
+ protected Schema readSchema(ReadChannel in) throws IOException {
+ return MessageSerializer.deserializeSchema(in);
}
- }
- /**
- * Reads the schema message from the beginning of the stream.
- */
- private Schema readSchema() throws IOException {
- return MessageSerializer.deserializeSchema(in);
- }
+ @Override
+ protected ArrowMessage readMessage(ReadChannel in, BufferAllocator allocator) throws IOException {
+ return MessageSerializer.deserializeMessageBatch(in, allocator);
+ }
}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java
index 60dc586..ea29cd9 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/ArrowStreamWriter.java
@@ -17,63 +17,40 @@
*/
package org.apache.arrow.vector.stream;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.file.ArrowBlock;
+import org.apache.arrow.vector.file.ArrowWriter;
+import org.apache.arrow.vector.file.WriteChannel;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
+import java.util.List;
-import org.apache.arrow.vector.file.WriteChannel;
-import org.apache.arrow.vector.schema.ArrowRecordBatch;
-import org.apache.arrow.vector.types.pojo.Schema;
-
-public class ArrowStreamWriter implements AutoCloseable {
- private final WriteChannel out;
- private final Schema schema;
- private boolean headerSent = false;
+public class ArrowStreamWriter extends ArrowWriter {
- /**
- * Creates the stream writer. non-blocking.
- * totalBatches can be set if the writer knows beforehand. Can be -1 if unknown.
- */
- public ArrowStreamWriter(WritableByteChannel out, Schema schema) {
- this.out = new WriteChannel(out);
- this.schema = schema;
- }
-
- public ArrowStreamWriter(OutputStream out, Schema schema)
- throws IOException {
- this(Channels.newChannel(out), schema);
- }
-
- public long bytesWritten() { return out.getCurrentPosition(); }
-
- public void writeRecordBatch(ArrowRecordBatch batch) throws IOException {
- // Send the header if we have not yet.
- checkAndSendHeader();
- MessageSerializer.serialize(out, batch);
- }
+ public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, OutputStream out) {
+ this(root, provider, Channels.newChannel(out));
+ }
- /**
- * End the stream. This is not required and this object can simply be closed.
- */
- public void end() throws IOException {
- checkAndSendHeader();
- out.writeIntLittleEndian(0);
- }
+ public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
+ super(root, provider, out);
+ }
- @Override
- public void close() throws IOException {
- // The header might not have been sent if this is an empty stream. Send it even in
- // this case so readers see a valid empty stream.
- checkAndSendHeader();
- out.close();
- }
+ @Override
+ protected void startInternal(WriteChannel out) throws IOException {}
- private void checkAndSendHeader() throws IOException {
- if (!headerSent) {
- MessageSerializer.serialize(out, schema);
- headerSent = true;
+ @Override
+ protected void endInternal(WriteChannel out,
+ Schema schema,
+ List<ArrowBlock> dictionaries,
+ List<ArrowBlock> records) throws IOException {
+ out.writeIntLittleEndian(0);
}
- }
}
-
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java
index 92df250..92a6c0c 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/stream/MessageSerializer.java
@@ -22,7 +22,11 @@ import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
+import com.google.flatbuffers.FlatBufferBuilder;
+
+import io.netty.buffer.ArrowBuf;
import org.apache.arrow.flatbuf.Buffer;
+import org.apache.arrow.flatbuf.DictionaryBatch;
import org.apache.arrow.flatbuf.FieldNode;
import org.apache.arrow.flatbuf.Message;
import org.apache.arrow.flatbuf.MessageHeader;
@@ -33,14 +37,12 @@ import org.apache.arrow.vector.file.ArrowBlock;
import org.apache.arrow.vector.file.ReadChannel;
import org.apache.arrow.vector.file.WriteChannel;
import org.apache.arrow.vector.schema.ArrowBuffer;
+import org.apache.arrow.vector.schema.ArrowDictionaryBatch;
import org.apache.arrow.vector.schema.ArrowFieldNode;
+import org.apache.arrow.vector.schema.ArrowMessage;
import org.apache.arrow.vector.schema.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Schema;
-import com.google.flatbuffers.FlatBufferBuilder;
-
-import io.netty.buffer.ArrowBuf;
-
/**
* Utility class for serializing Messages. Messages are all serialized a similar way.
* 1. 4 byte little endian message header prefix
@@ -81,35 +83,39 @@ public class MessageSerializer {
* Deserializes a schema object. Format is from serialize().
*/
public static Schema deserializeSchema(ReadChannel in) throws IOException {
- Message message = deserializeMessage(in, MessageHeader.Schema);
+ Message message = deserializeMessage(in);
if (message == null) {
throw new IOException("Unexpected end of input. Missing schema.");
}
+ if (message.headerType() != MessageHeader.Schema) {
+ throw new IOException("Expected schema but header was " + message.headerType());
+ }
return Schema.convertSchema((org.apache.arrow.flatbuf.Schema)
message.header(new org.apache.arrow.flatbuf.Schema()));
}
+
/**
* Serializes an ArrowRecordBatch. Returns the offset and length of the written batch.
*/
public static ArrowBlock serialize(WriteChannel out, ArrowRecordBatch batch)
- throws IOException {
+ throws IOException {
+
long start = out.getCurrentPosition();
int bodyLength = batch.computeBodyLength();
FlatBufferBuilder builder = new FlatBufferBuilder();
int batchOffset = batch.writeTo(builder);
- ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch,
- batchOffset, bodyLength);
+ ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.RecordBatch, batchOffset, bodyLength);
int metadataLength = serializedMessage.remaining();
- // Add extra padding bytes so that length prefix + metadata is a multiple
- // of 8 after alignment
- if ((start + metadataLength + 4) % 8 != 0) {
- metadataLength += 8 - (start + metadataLength + 4) % 8;
+ // calculate alignment bytes so that metadata length points to the correct location after alignment
+ int padding = (int)((start + metadataLength + 4) % 8);
+ if (padding != 0) {
+ metadataLength += (8 - padding);
}
out.writeIntLittleEndian(metadataLength);
@@ -118,6 +124,13 @@ public class MessageSerializer {
// Align the output to 8 byte boundary.
out.align();
+ long bufferLength = writeBatchBuffers(out, batch);
+
+ // Metadata size in the Block accounts for the size prefix
+ return new ArrowBlock(start, metadataLength + 4, bufferLength);
+ }
+
+ private static long writeBatchBuffers(WriteChannel out, ArrowRecordBatch batch) throws IOException {
long bufferStart = out.getCurrentPosition();
List<ArrowBuf> buffers = batch.getBuffers();
List<ArrowBuffer> buffersLayout = batch.getBuffersLayout();
@@ -135,22 +148,14 @@ public class MessageSerializer {
" != " + startPosition + layout.getSize());
}
}
- // Metadata size in the Block account for the size prefix
- return new ArrowBlock(start, metadataLength + 4, out.getCurrentPosition() - bufferStart);
+ return out.getCurrentPosition() - bufferStart;
}
/**
* Deserializes a RecordBatch
*/
- public static ArrowRecordBatch deserializeRecordBatch(ReadChannel in,
- BufferAllocator alloc) throws IOException {
- Message message = deserializeMessage(in, MessageHeader.RecordBatch);
- if (message == null) return null;
-
- if (message.bodyLength() > Integer.MAX_VALUE) {
- throw new IOException("Cannot currently deserialize record batches over 2GB");
- }
-
+ private static ArrowRecordBatch deserializeRecordBatch(ReadChannel in, Message message, BufferAllocator alloc)
+ throws IOException {
RecordBatch recordBatchFB = (RecordBatch) message.header(new RecordBatch());
int bodyLength = (int) message.bodyLength();
@@ -191,9 +196,7 @@ public class MessageSerializer {
// Now read the body
final ArrowBuf body = buffer.slice(block.getMetadataLength(),
(int) totalLen - block.getMetadataLength());
- ArrowRecordBatch result = deserializeRecordBatch(recordBatchFB, body);
-
- return result;
+ return deserializeRecordBatch(recordBatchFB, body);
}
// Deserializes a record batch given the Flatbuffer metadata and in-memory body
@@ -219,6 +222,106 @@ public class MessageSerializer {
}
/**
+ * Serializes an ArrowDictionaryBatch. Returns the offset and length of the written batch.
+ */
+ public static ArrowBlock serialize(WriteChannel out, ArrowDictionaryBatch batch) throws IOException {
+ long start = out.getCurrentPosition();
+ int bodyLength = batch.computeBodyLength();
+
+ FlatBufferBuilder builder = new FlatBufferBuilder();
+ int batchOffset = batch.writeTo(builder);
+
+ ByteBuffer serializedMessage = serializeMessage(builder, MessageHeader.DictionaryBatch, batchOffset, bodyLength);
+
+ int metadataLength = serializedMessage.remaining();
+
+ // Add extra padding bytes so that length prefix + metadata is a multiple
+ // of 8 after alignment
+ if ((start + metadataLength + 4) % 8 != 0) {
+ metadataLength += 8 - (start + metadataLength + 4) % 8;
+ }
+
+ out.writeIntLittleEndian(metadataLength);
+ out.write(serializedMessage);
+
+ // Align the output to 8 byte boundary.
+ out.align();
+
+ // write the embedded record batch
+ long bufferLength = writeBatchBuffers(out, batch.getDictionary());
+
+ // Metadata size in the Block accounts for the size prefix
+ return new ArrowBlock(start, metadataLength + 4, bufferLength + 8);
+ }
+
+ /**
+ * Deserializes a DictionaryBatch
+ */
+ private static ArrowDictionaryBatch deserializeDictionaryBatch(ReadChannel in,
+ Message message,
+ BufferAllocator alloc) throws IOException {
+ DictionaryBatch dictionaryBatchFB = (DictionaryBatch) message.header(new DictionaryBatch());
+
+ int bodyLength = (int) message.bodyLength();
+
+ // Now read the record batch body
+ ArrowBuf body = alloc.buffer(bodyLength);
+ if (in.readFully(body, bodyLength) != bodyLength) {
+ throw new IOException("Unexpected end of input trying to read batch.");
+ }
+ ArrowRecordBatch recordBatch = deserializeRecordBatch(dictionaryBatchFB.data(), body);
+ return new ArrowDictionaryBatch(dictionaryBatchFB.id(), recordBatch);
+ }
+
+ /**
+ * Deserializes a DictionaryBatch knowing the size of the entire message up front. This
+ * minimizes the number of reads to the underlying stream.
+ */
+ public static ArrowDictionaryBatch deserializeDictionaryBatch(ReadChannel in,
+ ArrowBlock block,
+ BufferAllocator alloc) throws IOException {
+ // Metadata length contains integer prefix plus byte padding
+ long totalLen = block.getMetadataLength() + block.getBodyLength();
+
+ if (totalLen > Integer.MAX_VALUE) {
+ throw new IOException("Cannot currently deserialize record batches over 2GB");
+ }
+
+ ArrowBuf buffer = alloc.buffer((int) totalLen);
+ if (in.readFully(buffer, (int) totalLen) != totalLen) {
+ throw new IOException("Unexpected end of input trying to read batch.");
+ }
+
+ ArrowBuf metadataBuffer = buffer.slice(4, block.getMetadataLength() - 4);
+
+ Message messageFB =
+ Message.getRootAsMessage(metadataBuffer.nioBuffer().asReadOnlyBuffer());
+
+ DictionaryBatch dictionaryBatchFB = (DictionaryBatch) messageFB.header(new DictionaryBatch());
+
+ // Now read the body
+ final ArrowBuf body = buffer.slice(block.getMetadataLength(),
+ (int) totalLen - block.getMetadataLength());
+ ArrowRecordBatch recordBatch = deserializeRecordBatch(dictionaryBatchFB.data(), body);
+ return new ArrowDictionaryBatch(dictionaryBatchFB.id(), recordBatch);
+ }
+
+ public static ArrowMessage deserializeMessageBatch(ReadChannel in, BufferAllocator alloc) throws IOException {
+ Message message = deserializeMessage(in);
+ if (message == null) {
+ return null;
+ } else if (message.bodyLength() > Integer.MAX_VALUE) {
+ throw new IOException("Cannot currently deserialize record batches over 2GB");
+ }
+
+ switch (message.headerType()) {
+ case MessageHeader.RecordBatch: return deserializeRecordBatch(in, message, alloc);
+ case MessageHeader.DictionaryBatch: return deserializeDictionaryBatch(in, message, alloc);
+ default: throw new IOException("Unexpected message header type " + message.headerType());
+ }
+ }
+
+ /**
* Serializes a message header.
*/
private static ByteBuffer serializeMessage(FlatBufferBuilder builder, byte headerType,
@@ -232,7 +335,7 @@ public class MessageSerializer {
return builder.dataBuffer();
}
- private static Message deserializeMessage(ReadChannel in, byte headerType) throws IOException {
+ private static Message deserializeMessage(ReadChannel in) throws IOException {
// Read the message size. There is an i32 little endian prefix.
ByteBuffer buffer = ByteBuffer.allocate(4);
if (in.readFully(buffer) != 4) return null;
@@ -246,11 +349,6 @@ public class MessageSerializer {
}
buffer.rewind();
- Message message = Message.getRootAsMessage(buffer);
- if (message.headerType() != headerType) {
- throw new IOException("Invalid message: expecting " + headerType +
- ". Message contained: " + message.headerType());
- }
- return message;
+ return Message.getRootAsMessage(buffer);
}
}
http://git-wip-us.apache.org/repos/asf/arrow/blob/49f666e7/java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java
----------------------------------------------------------------------
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java
deleted file mode 100644
index fbe1345..0000000
--- a/java/vector/src/main/java/org/apache/arrow/vector/types/Dictionary.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*******************************************************************************
-
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- ******************************************************************************/
-package org.apache.arrow.vector.types;
-
-import org.apache.arrow.vector.ValueVector;
-
-public class Dictionary {
-
- private ValueVector dictionary;
- private boolean ordered;
-
- public Dictionary(ValueVector dictionary, boolean ordered) {
- this.dictionary = dictionary;
- this.ordered = ordered;
- }
-
- public ValueVector getDictionary() {
- return dictionary;
- }
-
- public boolean isOrdered() {
- return ordered;
- }
-}