Posted to commits@flink.apache.org by fh...@apache.org on 2015/03/25 20:45:15 UTC
[2/5] flink git commit: [FLINK-1512] [scala api] Add CsvReader for reading into POJOs
[FLINK-1512] [scala api] Add CsvReader for reading into POJOs
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/7a6f2960
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/7a6f2960
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/7a6f2960
Branch: refs/heads/master
Commit: 7a6f296094b26b940f9f9f66f64e5e2a0f700cb1
Parents: 7b1c19c
Author: Chiwan Park <ch...@icloud.com>
Authored: Fri Feb 20 02:23:56 2015 +0900
Committer: Fabian Hueske <fh...@apache.org>
Committed: Wed Mar 25 20:38:59 2015 +0100
----------------------------------------------------------------------
.../scala/operators/ScalaCsvInputFormat.java | 270 ++++++++-----------
.../flink/api/scala/ExecutionEnvironment.scala | 47 +++-
.../flink/api/scala/io/CsvInputFormatTest.scala | 125 ++++++++-
.../scala/io/ScalaCsvReaderWithPOJOITCase.scala | 124 +++++++++
4 files changed, 378 insertions(+), 188 deletions(-)
----------------------------------------------------------------------
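For context: this change lets the Scala API's readCsvFile read CSV records directly
into POJO types, in addition to tuple and case class types. A minimal usage sketch of
the new parameter (the Person class, object name, and paths below are illustrative,
not part of the commit):

    import org.apache.flink.api.scala._

    // A Flink POJO: mutable public fields plus a no-argument constructor
    class Person(var name: String, var age: Int) {
      def this() = this("", 0)
    }

    object ReadCsvIntoPojo {
      def main(args: Array[String]): Unit = {
        val env = ExecutionEnvironment.getExecutionEnvironment
        // pojoFields maps CSV columns, left to right, onto the named POJO fields
        val people = env.readCsvFile[Person]("/path/to/people.csv",
          pojoFields = Array("name", "age"))
        people.writeAsText("/path/to/out")
        env.execute()
      }
    }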
http://git-wip-us.apache.org/repos/asf/flink/blob/7a6f2960/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java
index 79c6659..9adbed8 100644
--- a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java
+++ b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java
@@ -19,66 +19,91 @@
package org.apache.flink.api.scala.operators;
-import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;
-
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.io.GenericCsvInputFormat;
import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.PojoTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
import org.apache.flink.api.java.typeutils.runtime.TupleSerializerBase;
import org.apache.flink.core.fs.FileInputSplit;
import org.apache.flink.core.fs.Path;
-import org.apache.flink.types.parser.FieldParser;
-import org.apache.flink.util.StringUtils;
+import org.apache.flink.types.parser.FieldParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
-import java.nio.charset.Charset;
-import java.nio.charset.IllegalCharsetNameException;
-import java.nio.charset.UnsupportedCharsetException;
-import java.util.Map;
-import java.util.TreeMap;
+import java.lang.reflect.Field;
+import java.util.Arrays;
-import scala.Product;
-
-public class ScalaCsvInputFormat<OUT extends Product> extends GenericCsvInputFormat<OUT> {
+public class ScalaCsvInputFormat<OUT> extends GenericCsvInputFormat<OUT> {
private static final long serialVersionUID = 1L;
private static final Logger LOG = LoggerFactory.getLogger(ScalaCsvInputFormat.class);
-
- private transient Object[] parsedValues;
-
- // To speed up readRecord processing. Used to find windows line endings.
- // It is set when open so that readRecord does not have to evaluate it
- private boolean lineDelimiterIsLinebreak = false;
- private final TupleSerializerBase<OUT> serializer;
+ private transient Object[] parsedValues;
- private byte[] commentPrefix = null;
+ private final TupleSerializerBase<OUT> tupleSerializer;
- private transient int commentCount;
- private transient int invalidLineCount;
+ private Class<OUT> pojoTypeClass = null;
+ private String[] pojoFieldsName = null;
+ private transient Field[] pojoFields = null;
+ private transient PojoTypeInfo<OUT> pojoTypeInfo = null;
public ScalaCsvInputFormat(Path filePath, TypeInformation<OUT> typeInfo) {
super(filePath);
- if (!(typeInfo.isTupleType())) {
- throw new UnsupportedOperationException("This only works on tuple types.");
+ Class<?>[] classes = new Class[typeInfo.getArity()];
+
+ if (typeInfo instanceof TupleTypeInfoBase) {
+ TupleTypeInfoBase<OUT> tupleType = (TupleTypeInfoBase<OUT>) typeInfo;
+ // We can use an empty config here, since we only use the serializer to create
+ // the top-level case class
+ tupleSerializer = (TupleSerializerBase<OUT>) tupleType.createSerializer(new ExecutionConfig());
+
+ for (int i = 0; i < tupleType.getArity(); i++) {
+ classes[i] = tupleType.getTypeAt(i).getTypeClass();
+ }
+
+ setFieldTypes(classes);
+ } else {
+ tupleSerializer = null;
+ pojoTypeInfo = (PojoTypeInfo<OUT>) typeInfo;
+ pojoTypeClass = typeInfo.getTypeClass();
+ pojoFieldsName = pojoTypeInfo.getFieldNames();
+
+ for (int i = 0, arity = pojoTypeInfo.getArity(); i < arity; i++) {
+ classes[i] = pojoTypeInfo.getTypeAt(i).getTypeClass();
+ }
+
+ setFieldTypes(classes);
+ setOrderOfPOJOFields(pojoFieldsName);
+ }
+ }
+
+ public void setOrderOfPOJOFields(String[] fieldsOrder) {
+ Preconditions.checkNotNull(pojoTypeClass, "Field order can only be specified if output type is a POJO.");
+ Preconditions.checkNotNull(fieldsOrder);
+
+ int includedCount = 0;
+ for (boolean isIncluded : fieldIncluded) {
+ if (isIncluded) {
+ includedCount++;
+ }
}
- TupleTypeInfoBase<OUT> tupleType = (TupleTypeInfoBase<OUT>) typeInfo;
- // We can use an empty config here, since we only use the serializer to create
- // the top-level case class
- serializer = (TupleSerializerBase<OUT>) tupleType.createSerializer(new ExecutionConfig());
-
- Class<?>[] classes = new Class[tupleType.getArity()];
- for (int i = 0; i < tupleType.getArity(); i++) {
- classes[i] = tupleType.getTypeAt(i).getTypeClass();
+
+ Preconditions.checkArgument(includedCount == fieldsOrder.length,
+ "The number of selected POJO fields should be the same as that of CSV fields.");
+
+ for (String field : fieldsOrder) {
+ Preconditions.checkNotNull(field, "The field name cannot be null.");
+ Preconditions.checkArgument(pojoTypeInfo.getFieldIndex(field) != -1,
+ "The given field name isn't matched to POJO fields.");
}
- setFieldTypes(classes);
+
+ pojoFieldsName = Arrays.copyOfRange(fieldsOrder, 0, fieldsOrder.length);
}
public void setFieldTypes(Class<?>[] fieldTypes) {
@@ -98,98 +123,66 @@ public class ScalaCsvInputFormat<OUT extends Product> extends GenericCsvInputFor
setFieldsGeneric(sourceFieldIndices, fieldTypes);
}
- public byte[] getCommentPrefix() {
- return commentPrefix;
- }
-
- public void setCommentPrefix(byte[] commentPrefix) {
- this.commentPrefix = commentPrefix;
- }
-
- public void setCommentPrefix(char commentPrefix) {
- setCommentPrefix(String.valueOf(commentPrefix));
- }
+ public void setFields(boolean[] sourceFieldMask, Class<?>[] fieldTypes) {
+ Preconditions.checkNotNull(sourceFieldMask);
+ Preconditions.checkNotNull(fieldTypes);
- public void setCommentPrefix(String commentPrefix) {
- setCommentPrefix(commentPrefix, Charsets.UTF_8);
+ setFieldsGeneric(sourceFieldMask, fieldTypes);
}
- public void setCommentPrefix(String commentPrefix, String charsetName) throws IllegalCharsetNameException, UnsupportedCharsetException {
- if (charsetName == null) {
- throw new IllegalArgumentException("Charset name must not be null");
- }
-
- if (commentPrefix != null) {
- Charset charset = Charset.forName(charsetName);
- setCommentPrefix(commentPrefix, charset);
- } else {
- this.commentPrefix = null;
- }
+ public Class<?>[] getFieldTypes() {
+ return super.getGenericFieldTypes();
}
- public void setCommentPrefix(String commentPrefix, Charset charset) {
- if (charset == null) {
- throw new IllegalArgumentException("Charset must not be null");
- }
- if (commentPrefix != null) {
- this.commentPrefix = commentPrefix.getBytes(charset);
- } else {
- this.commentPrefix = null;
- }
- }
-
- @Override
- public void close() throws IOException {
- if (this.invalidLineCount > 0) {
- if (LOG.isWarnEnabled()) {
- LOG.warn("In file \""+ this.filePath + "\" (split start: " + this.splitStart + ") " + this.invalidLineCount +" invalid line(s) were skipped.");
- }
- }
-
- if (this.commentCount > 0) {
- if (LOG.isInfoEnabled()) {
- LOG.info("In file \""+ this.filePath + "\" (split start: " + this.splitStart + ") " + this.commentCount +" comment line(s) were skipped.");
- }
- }
- super.close();
- }
-
- @Override
- public OUT nextRecord(OUT record) throws IOException {
- OUT returnRecord = null;
- do {
- returnRecord = super.nextRecord(record);
- } while (returnRecord == null && !reachedEnd());
-
- return returnRecord;
- }
-
@Override
public void open(FileInputSplit split) throws IOException {
super.open(split);
-
+
@SuppressWarnings("unchecked")
FieldParser<Object>[] fieldParsers = (FieldParser<Object>[]) getFieldParsers();
-
+
//throw exception if no field parsers are available
if (fieldParsers.length == 0) {
throw new IOException("CsvInputFormat.open(FileInputSplit split) - no field parsers to parse input");
}
-
+
// create the value holders
this.parsedValues = new Object[fieldParsers.length];
for (int i = 0; i < fieldParsers.length; i++) {
this.parsedValues[i] = fieldParsers[i].createValue();
}
- this.commentCount = 0;
- this.invalidLineCount = 0;
-
// left to right evaluation makes access [0] okay
// this marker is used to fasten up readRecord, so that it doesn't have to check each call if the line ending is set to default
if (this.getDelimiter().length == 1 && this.getDelimiter()[0] == '\n' ) {
this.lineDelimiterIsLinebreak = true;
}
+
+ // for POJO type
+ if (pojoTypeClass != null) {
+ pojoFields = new Field[pojoFieldsName.length];
+ for (int i = 0; i < pojoFieldsName.length; i++) {
+ try {
+ pojoFields[i] = pojoTypeClass.getDeclaredField(pojoFieldsName[i]);
+ pojoFields[i].setAccessible(true);
+ } catch (NoSuchFieldException e) {
+ throw new RuntimeException("There is no field called \"" + pojoFieldsName[i] + "\" in " + pojoTypeClass.getName(), e);
+ }
+ }
+ }
+
+ this.commentCount = 0;
+ this.invalidLineCount = 0;
+ }
+
+ @Override
+ public OUT nextRecord(OUT record) throws IOException {
+ OUT returnRecord = null;
+ do {
+ returnRecord = super.nextRecord(record);
+ } while (returnRecord == null && !reachedEnd());
+
+ return returnRecord;
}
@Override
@@ -219,73 +212,22 @@ public class ScalaCsvInputFormat<OUT extends Product> extends GenericCsvInputFor
}
if (parseRecord(parsedValues, bytes, offset, numBytes)) {
- OUT result = serializer.createInstance(parsedValues);
- return result;
+ if (tupleSerializer != null) {
+ return tupleSerializer.createInstance(parsedValues);
+ } else {
+ for (int i = 0; i < pojoFields.length; i++) {
+ try {
+ pojoFields[i].set(reuse, parsedValues[i]);
+ } catch (IllegalAccessException e) {
+ throw new RuntimeException("Parsed value could not be set in POJO field \"" + pojoFieldsName[i] + "\"", e);
+ }
+ }
+
+ return reuse;
+ }
} else {
this.invalidLineCount++;
return null;
}
}
-
-
- @Override
- public String toString() {
- return "CSV Input (" + StringUtils.showControlCharacters(String.valueOf(getFieldDelimiter())) + ") " + getFilePath();
- }
-
- // --------------------------------------------------------------------------------------------
-
- @SuppressWarnings("unused")
- private static void checkAndCoSort(int[] positions, Class<?>[] types) {
- if (positions.length != types.length) {
- throw new IllegalArgumentException("The positions and types must be of the same length");
- }
-
- TreeMap<Integer, Class<?>> map = new TreeMap<Integer, Class<?>>();
-
- for (int i = 0; i < positions.length; i++) {
- if (positions[i] < 0) {
- throw new IllegalArgumentException("The field " + " (" + positions[i] + ") is invalid.");
- }
- if (types[i] == null) {
- throw new IllegalArgumentException("The type " + i + " is invalid (null)");
- }
-
- if (map.containsKey(positions[i])) {
- throw new IllegalArgumentException("The position " + positions[i] + " occurs multiple times.");
- }
-
- map.put(positions[i], types[i]);
- }
-
- int i = 0;
- for (Map.Entry<Integer, Class<?>> entry : map.entrySet()) {
- positions[i] = entry.getKey();
- types[i] = entry.getValue();
- i++;
- }
- }
-
- private static void checkForMonotonousOrder(int[] positions, Class<?>[] types) {
- if (positions.length != types.length) {
- throw new IllegalArgumentException("The positions and types must be of the same length");
- }
-
- int lastPos = -1;
-
- for (int i = 0; i < positions.length; i++) {
- if (positions[i] < 0) {
- throw new IllegalArgumentException("The field " + " (" + positions[i] + ") is invalid.");
- }
- if (types[i] == null) {
- throw new IllegalArgumentException("The type " + i + " is invalid (null)");
- }
-
- if (positions[i] <= lastPos) {
- throw new IllegalArgumentException("The positions must be strictly increasing (no permutations are supported).");
- }
-
- lastPos = positions[i];
- }
- }
}
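The reworked input format can also be driven directly, as the unit tests below do:
open(...) resolves the POJO fields via reflection, and nextRecord(...) then writes each
parsed value into the reused instance. A condensed sketch of that flow, assuming an
Item class and path of our own (createInputSplits comes from the inherited
FileInputFormat, not from this diff):

    import org.apache.flink.api.scala._
    import org.apache.flink.api.scala.operators.ScalaCsvInputFormat
    import org.apache.flink.configuration.Configuration
    import org.apache.flink.core.fs.Path

    class Item(var field1: Int, var field2: String, var field3: Double) {
      def this() = this(-1, "", -1)
    }

    object DirectFormatSketch {
      def main(args: Array[String]): Unit = {
        val format = new ScalaCsvInputFormat[Item](
          new Path("/path/to/data.csv"), createTypeInformation[Item])
        format.setFieldDelimiter(',')
        format.configure(new Configuration)
        val splits = format.createInputSplits(1)
        format.open(splits(0))
        // parses one line and fills the reused POJO via the reflective
        // Field handles that open(...) resolved
        val first = format.nextRecord(new Item())
        format.close()
      }
    }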
http://git-wip-us.apache.org/repos/asf/flink/blob/7a6f2960/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala
index 4c1e627..7073f07 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala
@@ -26,7 +26,7 @@ import org.apache.flink.api.java.io._
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer
-import org.apache.flink.api.java.typeutils.{ValueTypeInfo, TupleTypeInfoBase}
+import org.apache.flink.api.java.typeutils.{PojoTypeInfo, ValueTypeInfo, TupleTypeInfoBase}
import org.apache.flink.api.scala.hadoop.mapred
import org.apache.flink.api.scala.hadoop.mapreduce
import org.apache.flink.api.scala.operators.ScalaCsvInputFormat
@@ -46,6 +46,7 @@ import org.apache.hadoop.fs.{Path => HadoopPath}
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
@@ -243,8 +244,9 @@ class ExecutionEnvironment(javaEnv: JavaEnv) {
* @param lenient Whether the parser should silently ignore malformed lines.
* @param includedFields The fields in the file that should be read. Per default all fields
* are read.
+ * @param pojoFields The fields of the POJO which are mapped to CSV fields.
*/
- def readCsvFile[T <: Product : ClassTag : TypeInformation](
+ def readCsvFile[T : ClassTag : TypeInformation](
filePath: String,
lineDelimiter: String = "\n",
fieldDelimiter: String = ",",
@@ -252,9 +254,10 @@ class ExecutionEnvironment(javaEnv: JavaEnv) {
ignoreFirstLine: Boolean = false,
ignoreComments: String = null,
lenient: Boolean = false,
- includedFields: Array[Int] = null): DataSet[T] = {
+ includedFields: Array[Int] = null,
+ pojoFields: Array[String] = null): DataSet[T] = {
- val typeInfo = implicitly[TypeInformation[T]].asInstanceOf[TupleTypeInfoBase[T]]
+ val typeInfo = implicitly[TypeInformation[T]]
val inputFormat = new ScalaCsvInputFormat[T](new Path(filePath), typeInfo)
inputFormat.setDelimiter(lineDelimiter)
@@ -267,16 +270,40 @@ class ExecutionEnvironment(javaEnv: JavaEnv) {
inputFormat.enableQuotedStringParsing(quoteCharacter);
}
- val classes: Array[Class[_]] = new Array[Class[_]](typeInfo.getArity)
- for (i <- 0 until typeInfo.getArity) {
- classes(i) = typeInfo.getTypeAt(i).getTypeClass
+ val classesBuf: ArrayBuffer[Class[_]] = new ArrayBuffer[Class[_]]
+ typeInfo match {
+ case info: TupleTypeInfoBase[T] =>
+ for (i <- 0 until info.getArity) {
+ classesBuf += info.getTypeAt(i).getTypeClass()
+ }
+ case info: PojoTypeInfo[T] =>
+ if (pojoFields == null) {
+ throw new IllegalArgumentException(
+ "POJO fields must be specified (not null) if output type is a POJO.")
+ } else {
+ for (i <- 0 until pojoFields.length) {
+ val pos = info.getFieldIndex(pojoFields(i))
+ if (pos < 0) {
+ throw new IllegalArgumentException(
+ "Field \"" + pojoFields(i) + "\" not part of POJO type " +
+ info.getTypeClass.getCanonicalName);
+ }
+ classesBuf += info.getPojoFieldAt(pos).`type`.getTypeClass
+ }
+ }
+ case _ => throw new IllegalArgumentException("Type information is not valid.")
}
+
if (includedFields != null) {
- Validate.isTrue(typeInfo.getArity == includedFields.length, "Number of tuple fields and" +
+ Validate.isTrue(classesBuf.size == includedFields.length, "Number of tuple fields and" +
" included fields must match.")
- inputFormat.setFields(includedFields, classes)
+ inputFormat.setFields(includedFields, classesBuf.toArray)
} else {
- inputFormat.setFieldTypes(classes)
+ inputFormat.setFieldTypes(classesBuf.toArray)
+ }
+
+ if (pojoFields != null) {
+ inputFormat.setOrderOfPOJOFields(pojoFields)
}
wrap(new DataSource[T](javaEnv, inputFormat, typeInfo, getCallLocationName()))
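Note how includedFields and pojoFields compose in the new readCsvFile: the index array
selects CSV columns, and pojoFields names the target POJO fields for exactly those
columns, in order, so the two arrays must have the same length. A sketch mirroring the
ITCase at the end of this commit (class, paths, and indices are illustrative):

    import org.apache.flink.api.scala._

    class POJOItem(var f1: String, var f2: Double, var f3: Int) {
      def this() = this("", 0.0, 0)
    }

    object SubsetSketch {
      def main(args: Array[String]): Unit = {
        val env = ExecutionEnvironment.getExecutionEnvironment
        // Select CSV columns 1 and 2 (zero-based) and map them, in order, onto
        // POJO fields f3 and f1; unselected fields keep their constructor defaults
        val data = env.readCsvFile[POJOItem]("/path/to/data.csv",
          includedFields = Array(1, 2),
          pojoFields = Array("f3", "f1"))
        data.writeAsText("/path/to/out")
        env.execute()
      }
    }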
http://git-wip-us.apache.org/repos/asf/flink/blob/7a6f2960/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala
index 9964a9d..4bcd35a 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala
@@ -17,21 +17,15 @@
*/
package org.apache.flink.api.scala.io
+import java.io.{File, FileOutputStream, FileWriter, OutputStreamWriter}
+
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.scala._
import org.apache.flink.api.scala.operators.ScalaCsvInputFormat
-import org.junit.Assert._
-import org.junit.Assert.assertEquals
-import org.junit.Assert.assertNotNull
-import org.junit.Assert.assertNull
-import org.junit.Assert.assertTrue
-import org.junit.Assert.fail
-import java.io.File
-import java.io.FileOutputStream
-import java.io.FileWriter
-import java.io.OutputStreamWriter
import org.apache.flink.configuration.Configuration
import org.apache.flink.core.fs.{FileInputSplit, Path}
+import org.junit.Assert.{assertEquals, assertNotNull, assertNull, assertTrue, fail}
import org.junit.Test
-import org.apache.flink.api.scala._
class CsvInputFormatTest {
@@ -315,7 +309,8 @@ class CsvInputFormatTest {
PATH,
createTypeInformation[(Int, Int, Int)])
format.setFieldDelimiter("|")
- format.setFields(Array(0, 3, 7), Array(classOf[Integer], classOf[Integer], classOf[Integer]))
+ format.setFields(Array(0, 3, 7),
+ Array(classOf[Integer], classOf[Integer], classOf[Integer]): Array[Class[_]])
format.configure(new Configuration)
format.open(split)
var result: (Int, Int, Int) = null
@@ -347,7 +342,8 @@ class CsvInputFormatTest {
createTypeInformation[(Int, Int, Int)])
format.setFieldDelimiter("|")
try {
- format.setFields(Array(8, 1, 3), Array(classOf[Integer],classOf[Integer],classOf[Integer]))
+ format.setFields(Array(8, 1, 3),
+ Array(classOf[Integer], classOf[Integer], classOf[Integer]): Array[Class[_]])
fail("Input sequence should have been rejected.")
}
catch {
@@ -408,5 +404,106 @@ class CsvInputFormatTest {
fail("Test erroneous")
}
}
-}
+ class POJOItem(var field1: Int, var field2: String, var field3: Double) {
+ def this() {
+ this(-1, "", -1)
+ }
+ }
+
+ case class CaseClassItem(field1: Int, field2: String, field3: Double)
+
+ @Test
+ def testPOJOType(): Unit = {
+ val fileContent = "123,HELLO,3.123\n" + "456,ABC,1.234"
+ val tempFile = createTempFile(fileContent)
+ val typeInfo: TypeInformation[POJOItem] = createTypeInformation[POJOItem]
+ val format = new ScalaCsvInputFormat[POJOItem](PATH, typeInfo)
+
+ format.setDelimiter('\n')
+ format.setFieldDelimiter(',')
+ format.configure(new Configuration)
+ format.open(tempFile)
+
+ var result = new POJOItem()
+ result = format.nextRecord(result)
+ assertEquals(123, result.field1)
+ assertEquals("HELLO", result.field2)
+ assertEquals(3.123, result.field3, 0.001)
+
+ result = format.nextRecord(result)
+ assertEquals(456, result.field1)
+ assertEquals("ABC", result.field2)
+ assertEquals(1.234, result.field3, 0.001)
+ }
+
+ @Test
+ def testCaseClass(): Unit = {
+ val fileContent = "123,HELLO,3.123\n" + "456,ABC,1.234"
+ val tempFile = createTempFile(fileContent)
+ val typeInfo: TypeInformation[CaseClassItem] = createTypeInformation[CaseClassItem]
+ val format = new ScalaCsvInputFormat[CaseClassItem](PATH, typeInfo)
+
+ format.setDelimiter('\n')
+ format.setFieldDelimiter(',')
+ format.configure(new Configuration)
+ format.open(tempFile)
+
+ var result = format.nextRecord(null)
+ assertEquals(123, result.field1)
+ assertEquals("HELLO", result.field2)
+ assertEquals(3.123, result.field3, 0.001)
+
+ result = format.nextRecord(null)
+ assertEquals(456, result.field1)
+ assertEquals("ABC", result.field2)
+ assertEquals(1.234, result.field3, 0.001)
+ }
+
+ @Test
+ def testPOJOTypeWithFieldMapping(): Unit = {
+ val fileContent = "HELLO,123,3.123\n" + "ABC,456,1.234"
+ val tempFile = createTempFile(fileContent)
+ val typeInfo: TypeInformation[POJOItem] = createTypeInformation[POJOItem]
+ val format = new ScalaCsvInputFormat[POJOItem](PATH, typeInfo)
+
+ format.setDelimiter('\n')
+ format.setFieldDelimiter(',')
+ format.setFieldTypes(Array(classOf[String], classOf[Integer], classOf[java.lang.Double]))
+ format.setOrderOfPOJOFields(Array("field2", "field1", "field3"))
+ format.configure(new Configuration)
+ format.open(tempFile)
+
+ var result = new POJOItem()
+ result = format.nextRecord(result)
+ assertEquals(123, result.field1)
+ assertEquals("HELLO", result.field2)
+ assertEquals(3.123, result.field3, 0.001)
+
+ result = format.nextRecord(result)
+ assertEquals(456, result.field1)
+ assertEquals("ABC", result.field2)
+ assertEquals(1.234, result.field3, 0.001)
+ }
+
+ @Test
+ def testPOJOTypeWithFieldSubsetAndDataSubset(): Unit = {
+ val fileContent = "123,HELLO,3.123\n" + "456,ABC,1.234"
+ val tempFile = createTempFile(fileContent)
+ val typeInfo: TypeInformation[POJOItem] = createTypeInformation[POJOItem]
+ val format = new ScalaCsvInputFormat[POJOItem](PATH, typeInfo)
+
+ format.setDelimiter('\n')
+ format.setFieldDelimiter(',')
+ format.setFields(Array(false, true), Array(classOf[String]): Array[Class[_]])
+ format.setOrderOfPOJOFields(Array("field2", "field1", "field3"))
+ format.configure(new Configuration)
+ format.open(tempFile)
+
+ var result = format.nextRecord(new POJOItem())
+ assertEquals("HELLO", result.field2)
+
+ result = format.nextRecord(result)
+ assertEquals("ABC", result.field2)
+ }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/7a6f2960/flink-tests/src/test/scala/org/apache/flink/api/scala/io/ScalaCsvReaderWithPOJOITCase.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/io/ScalaCsvReaderWithPOJOITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/io/ScalaCsvReaderWithPOJOITCase.scala
new file mode 100644
index 0000000..21aa93d
--- /dev/null
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/io/ScalaCsvReaderWithPOJOITCase.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.io
+
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import org.apache.flink.api.scala._
+import org.apache.flink.core.fs.FileSystem.WriteMode
+import org.apache.flink.test.util.MultipleProgramsTestBase
+import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
+import org.junit.Assert._
+import org.junit.rules.TemporaryFolder
+import org.junit.runner.RunWith
+import org.junit.runners.Parameterized
+import org.junit.{After, Before, Rule, Test}
+
+@RunWith(classOf[Parameterized])
+class ScalaCsvReaderWithPOJOITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) {
+ private val _tempFolder = new TemporaryFolder()
+ private var resultPath: String = null
+ private var expected: String = null
+
+ @Rule
+ def tempFolder = _tempFolder
+
+ @Before
+ def before(): Unit = {
+ resultPath = tempFolder.newFile("result").toURI.toString
+ }
+
+ @After
+ def after(): Unit = {
+ compareResultsByLinesInMemory(expected, resultPath)
+ }
+
+ def createInputData(data: String): String = {
+ val dataFile = tempFolder.newFile("data")
+ Files.write(data, dataFile, Charsets.UTF_8)
+ dataFile.toURI.toString
+ }
+
+ @Test
+ def testPOJOType(): Unit = {
+ val dataPath = createInputData("ABC,2.20,3\nDEF,5.1,5\nDEF,3.30,1\nGHI,3.30,10")
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val data = env.readCsvFile[POJOItem](dataPath, pojoFields = Array("f1", "f2", "f3"))
+
+ implicit val typeInfo = createTypeInformation[(String, Int)]
+ data.writeAsText(resultPath, WriteMode.OVERWRITE)
+
+ env.execute()
+
+ expected = "ABC,2.20,3\nDEF,5.10,5\nDEF,3.30,1\nGHI,3.30,10"
+ }
+
+ @Test
+ def testPOJOTypeWithFieldsOrder(): Unit = {
+ val dataPath = createInputData("2.20,ABC,3\n5.1,DEF,5\n3.30,DEF,1\n3.30,GHI,10")
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val data = env.readCsvFile[POJOItem](dataPath, pojoFields = Array("f2", "f1", "f3"))
+
+ implicit val typeInfo = createTypeInformation[(String, Int)]
+ data.writeAsText(resultPath, WriteMode.OVERWRITE)
+
+ env.execute()
+
+ expected = "ABC,2.20,3\nDEF,5.10,5\nDEF,3.30,1\nGHI,3.30,10"
+ }
+
+ @Test
+ def testPOJOTypeWithoutFieldsOrder(): Unit = {
+ val dataPath = createInputData("")
+ val env = ExecutionEnvironment.getExecutionEnvironment
+
+ try {
+ val data = env.readCsvFile[POJOItem](dataPath)
+ fail("POJO type without fields order must raise IllegalArgumentException!")
+ } catch {
+ case _: IllegalArgumentException => // success
+ }
+
+ expected = ""
+ resultPath = dataPath
+ }
+
+ @Test
+ def testPOJOTypeWithFieldsOrderAndFieldsSelection(): Unit = {
+ val dataPath = createInputData("2.20,3,ABC\n5.1,5,DEF\n3.30,1,DEF\n3.30,10,GHI")
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ val data = env.readCsvFile[POJOItem](dataPath, includedFields = Array(1, 2),
+ pojoFields = Array("f3", "f1"))
+
+ implicit val typeInfo = createTypeInformation[(String, Int)]
+ data.writeAsText(resultPath, WriteMode.OVERWRITE)
+
+ env.execute()
+
+ expected = "ABC,0.00,3\nDEF,0.00,5\nDEF,0.00,1\nGHI,0.00,10"
+ }
+}
+
+class POJOItem(var f1: String, var f2: Double, var f3: Int) {
+ def this() {
+ this("", 0.0, 0)
+ }
+
+ override def toString: String = "%s,%.02f,%d".format(f1, f2, f3)
+}