Repository: flink Updated Branches: refs/heads/master 0b3ca57b4 -> 501a9b085
[FLINK-2690] [api-breaking] [scala api] [java api] Adds functionality to the CsvInputFormat to find fields defined in a super class of a Pojo. Refactors CsvInputFormat to share code between this format and ScalaCsvInputFormat. This closes #1141. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/501a9b08 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/501a9b08 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/501a9b08 Branch: refs/heads/master Commit: 501a9b0854995b3de1c293b3391eeb3048daeef0 Parents: 0b3ca57 Author: Till Rohrmann <[email protected]> Authored: Thu Sep 17 15:28:25 2015 +0200 Committer: Till Rohrmann <[email protected]> Committed: Mon Sep 28 12:22:27 2015 +0200 ---------------------------------------------------------------------- .../wordcount/BoltTokenizerWordCountPojo.java | 4 +- .../BoltTokenizerWordCountWithNames.java | 4 +- .../api/common/io/GenericCsvInputFormat.java | 1 - .../flink/api/java/io/CommonCsvInputFormat.java | 258 +++++++++++++++++++ .../flink/api/java/io/CsvInputFormat.java | 211 +-------------- .../flink/api/java/io/CsvInputFormatTest.java | 114 +++++++- .../scala/operators/ScalaCsvInputFormat.java | 204 ++------------- .../scala/operators/ScalaCsvOutputFormat.java | 6 +- .../flink/api/scala/ExecutionEnvironment.scala | 21 +- .../api/java/common/PlanBinder.java | 11 +- .../flink/api/scala/io/CsvInputFormatTest.scala | 90 +++++-- 11 files changed, 497 insertions(+), 427 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-contrib/flink-storm-compatibility/flink-storm-compatibility-examples/src/main/java/org/apache/flink/stormcompatibility/wordcount/BoltTokenizerWordCountPojo.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-storm-compatibility/flink-storm-compatibility-examples/src/main/java/org/apache/flink/stormcompatibility/wordcount/BoltTokenizerWordCountPojo.java b/flink-contrib/flink-storm-compatibility/flink-storm-compatibility-examples/src/main/java/org/apache/flink/stormcompatibility/wordcount/BoltTokenizerWordCountPojo.java index befb18f..20e69db 100644 --- a/flink-contrib/flink-storm-compatibility/flink-storm-compatibility-examples/src/main/java/org/apache/flink/stormcompatibility/wordcount/BoltTokenizerWordCountPojo.java +++ b/flink-contrib/flink-storm-compatibility/flink-storm-compatibility-examples/src/main/java/org/apache/flink/stormcompatibility/wordcount/BoltTokenizerWordCountPojo.java @@ -19,9 +19,9 @@ package org.apache.flink.stormcompatibility.wordcount; import backtype.storm.topology.IRichBolt; -import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.io.CsvInputFormat; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.core.fs.Path; import org.apache.flink.examples.java.wordcount.util.WordCountData; @@ -121,7 +121,7 @@ public class BoltTokenizerWordCountPojo { private static DataStream<Sentence> getTextDataStream(final StreamExecutionEnvironment env) { if (fileOutput) { // read the text file from given input path - TypeInformation<Sentence> sourceType = TypeExtractor + PojoTypeInfo<Sentence> sourceType = (PojoTypeInfo)TypeExtractor .getForObject(new Sentence("")); return env.createInput(new CsvInputFormat<Sentence>(new Path( textPath), CsvInputFormat.DEFAULT_LINE_DELIMITER, http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-contrib/flink-storm-compatibility/flink-storm-compatibility-examples/src/main/java/org/apache/flink/stormcompatibility/wordcount/BoltTokenizerWordCountWithNames.java ---------------------------------------------------------------------- diff --git a/flink-contrib/flink-storm-compatibility/flink-storm-compatibility-examples/src/main/java/org/apache/flink/stormcompatibility/wordcount/BoltTokenizerWordCountWithNames.java b/flink-contrib/flink-storm-compatibility/flink-storm-compatibility-examples/src/main/java/org/apache/flink/stormcompatibility/wordcount/BoltTokenizerWordCountWithNames.java index 8483f48..e233da1 100644 --- a/flink-contrib/flink-storm-compatibility/flink-storm-compatibility-examples/src/main/java/org/apache/flink/stormcompatibility/wordcount/BoltTokenizerWordCountWithNames.java +++ b/flink-contrib/flink-storm-compatibility/flink-storm-compatibility-examples/src/main/java/org/apache/flink/stormcompatibility/wordcount/BoltTokenizerWordCountWithNames.java @@ -20,11 +20,11 @@ package org.apache.flink.stormcompatibility.wordcount; import backtype.storm.topology.IRichBolt; import backtype.storm.tuple.Fields; -import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.io.CsvInputFormat; import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.core.fs.Path; import org.apache.flink.examples.java.wordcount.util.WordCountData; @@ -124,7 +124,7 @@ public class BoltTokenizerWordCountWithNames { private static DataStream<Tuple1<String>> getTextDataStream(final StreamExecutionEnvironment env) { if (fileOutput) { // read the text file from given input path - TypeInformation<Tuple1<String>> sourceType = TypeExtractor + TupleTypeInfo<Tuple1<String>> sourceType = (TupleTypeInfo<Tuple1<String>>)TypeExtractor .getForObject(new Tuple1<String>("")); return env.createInput(new CsvInputFormat<Tuple1<String>>(new Path( textPath), CsvInputFormat.DEFAULT_LINE_DELIMITER, http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-core/src/main/java/org/apache/flink/api/common/io/GenericCsvInputFormat.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/common/io/GenericCsvInputFormat.java b/flink-core/src/main/java/org/apache/flink/api/common/io/GenericCsvInputFormat.java index 8d979bb..e68d271 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/io/GenericCsvInputFormat.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/io/GenericCsvInputFormat.java @@ -40,7 +40,6 @@ import java.util.ArrayList; import java.util.Map; import java.util.TreeMap; - public abstract class GenericCsvInputFormat<OT> extends DelimitedInputFormat<OT> { private static final Logger LOG = LoggerFactory.getLogger(GenericCsvInputFormat.class); http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-java/src/main/java/org/apache/flink/api/java/io/CommonCsvInputFormat.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/io/CommonCsvInputFormat.java b/flink-java/src/main/java/org/apache/flink/api/java/io/CommonCsvInputFormat.java new file mode 100644 index 0000000..444d151 --- /dev/null +++ b/flink-java/src/main/java/org/apache/flink/api/java/io/CommonCsvInputFormat.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.java.io; + +import com.google.common.base.Preconditions; +import org.apache.flink.api.common.io.GenericCsvInputFormat; +import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.java.typeutils.PojoTypeInfo; +import org.apache.flink.core.fs.FileInputSplit; +import org.apache.flink.core.fs.Path; +import org.apache.flink.types.parser.FieldParser; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +public abstract class CommonCsvInputFormat<OUT> extends GenericCsvInputFormat<OUT> { + + private static final long serialVersionUID = 1L; + + public static final String DEFAULT_LINE_DELIMITER = "\n"; + + public static final String DEFAULT_FIELD_DELIMITER = ","; + + protected transient Object[] parsedValues; + + private final Class<OUT> pojoTypeClass; + + private String[] pojoFieldNames; + + private transient PojoTypeInfo<OUT> pojoTypeInfo; + private transient Field[] pojoFields; + + public CommonCsvInputFormat(Path filePath, CompositeType<OUT> typeInformation) { + this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, typeInformation); + } + + public CommonCsvInputFormat( + Path filePath, + String lineDelimiter, + String fieldDelimiter, + CompositeType<OUT> compositeTypeInfo) { + super(filePath); + + setDelimiter(lineDelimiter); + setFieldDelimiter(fieldDelimiter); + + Class<?>[] classes = new Class<?>[compositeTypeInfo.getArity()]; + + for (int i = 0; i < compositeTypeInfo.getArity(); i++) { + classes[i] = compositeTypeInfo.getTypeAt(i).getTypeClass(); + } + + setFieldTypes(classes); + + if (compositeTypeInfo instanceof PojoTypeInfo) { + pojoTypeInfo = (PojoTypeInfo<OUT>) compositeTypeInfo; + + pojoTypeClass = compositeTypeInfo.getTypeClass(); + setOrderOfPOJOFields(compositeTypeInfo.getFieldNames()); + } else { + pojoTypeClass = null; + pojoFieldNames = null; + } + } + + public void setOrderOfPOJOFields(String[] fieldNames) { + Preconditions.checkNotNull(pojoTypeClass, "Field order can only be specified if output type is a POJO."); + Preconditions.checkNotNull(fieldNames); + + int includedCount = 0; + for (boolean isIncluded : fieldIncluded) { + if (isIncluded) { + includedCount++; + } + } + + Preconditions.checkArgument(includedCount == fieldNames.length, includedCount + + " CSV fields and " + fieldNames.length + " POJO fields selected. The number of selected CSV and POJO fields must be equal."); + + for (String field : fieldNames) { + Preconditions.checkNotNull(field, "The field name cannot be null."); + Preconditions.checkArgument(pojoTypeInfo.getFieldIndex(field) != -1, + "Field \""+ field + "\" is not a member of POJO class " + pojoTypeClass.getName()); + } + + pojoFieldNames = Arrays.copyOfRange(fieldNames, 0, fieldNames.length); + } + + public void setFieldTypes(Class<?>... fieldTypes) { + if (fieldTypes == null || fieldTypes.length == 0) { + throw new IllegalArgumentException("Field types must not be null or empty."); + } + + setFieldTypesGeneric(fieldTypes); + } + + public void setFields(int[] sourceFieldIndices, Class<?>[] fieldTypes) { + Preconditions.checkNotNull(sourceFieldIndices); + Preconditions.checkNotNull(fieldTypes); + + checkForMonotonousOrder(sourceFieldIndices, fieldTypes); + + setFieldsGeneric(sourceFieldIndices, fieldTypes); + } + + public void setFields(boolean[] sourceFieldMask, Class<?>[] fieldTypes) { + Preconditions.checkNotNull(sourceFieldMask); + Preconditions.checkNotNull(fieldTypes); + + setFieldsGeneric(sourceFieldMask, fieldTypes); + } + + public Class<?>[] getFieldTypes() { + return super.getGenericFieldTypes(); + } + + @Override + public void open(FileInputSplit split) throws IOException { + super.open(split); + + @SuppressWarnings("unchecked") + FieldParser<Object>[] fieldParsers = (FieldParser<Object>[]) getFieldParsers(); + + //throw exception if no field parsers are available + if (fieldParsers.length == 0) { + throw new IOException("CsvInputFormat.open(FileInputSplit split) - no field parsers to parse input"); + } + + // create the value holders + this.parsedValues = new Object[fieldParsers.length]; + for (int i = 0; i < fieldParsers.length; i++) { + this.parsedValues[i] = fieldParsers[i].createValue(); + } + + // left to right evaluation makes access [0] okay + // this marker is used to fasten up readRecord, so that it doesn't have to check each call if the line ending is set to default + if (this.getDelimiter().length == 1 && this.getDelimiter()[0] == '\n' ) { + this.lineDelimiterIsLinebreak = true; + } + + // for POJO type + if (pojoTypeClass != null) { + pojoFields = new Field[pojoFieldNames.length]; + + Map<String, Field> allFields = new HashMap<String, Field>(); + + findAllFields(pojoTypeClass, allFields); + + for (int i = 0; i < pojoFieldNames.length; i++) { + pojoFields[i] = allFields.get(pojoFieldNames[i]); + + if (pojoFields[i] != null) { + pojoFields[i].setAccessible(true); + } else { + throw new RuntimeException("There is no field called \"" + pojoFieldNames[i] + "\" in " + pojoTypeClass.getName()); + } + } + } + + this.commentCount = 0; + this.invalidLineCount = 0; + } + + /** + * Finds all declared fields in a class and all its super classes. + * + * @param clazz Class for which all declared fields are found + * @param allFields Map containing all found fields so far + */ + private void findAllFields(Class<?> clazz, Map<String, Field> allFields) { + for (Field field: clazz.getDeclaredFields()) { + allFields.put(field.getName(), field); + } + + if (clazz.getSuperclass() != null) { + findAllFields(clazz.getSuperclass(), allFields); + } + } + + @Override + public OUT nextRecord(OUT record) throws IOException { + OUT returnRecord = null; + do { + returnRecord = super.nextRecord(record); + } while (returnRecord == null && !reachedEnd()); + + return returnRecord; + } + + @Override + public OUT readRecord(OUT reuse, byte[] bytes, int offset, int numBytes) throws IOException { + /* + * Fix to support windows line endings in CSVInputFiles with standard delimiter setup = \n + */ + //Find windows end line, so find carriage return before the newline + if (this.lineDelimiterIsLinebreak == true && numBytes > 0 && bytes[offset + numBytes -1] == '\r' ) { + //reduce the number of bytes so that the Carriage return is not taken as data + numBytes--; + } + + if (commentPrefix != null && commentPrefix.length <= numBytes) { + //check record for comments + boolean isComment = true; + for (int i = 0; i < commentPrefix.length; i++) { + if (commentPrefix[i] != bytes[offset + i]) { + isComment = false; + break; + } + } + if (isComment) { + this.commentCount++; + return null; + } + } + + if (parseRecord(parsedValues, bytes, offset, numBytes)) { + if (pojoTypeClass == null) { + // result type is tuple + return createTuple(reuse); + } else { + // result type is POJO + for (int i = 0; i < parsedValues.length; i++) { + try { + pojoFields[i].set(reuse, parsedValues[i]); + } catch (IllegalAccessException e) { + throw new RuntimeException("Parsed value could not be set in POJO field \"" + pojoFieldNames[i] + "\"", e); + } + } + return reuse; + } + + } else { + this.invalidLineCount++; + return null; + } + } + + protected abstract OUT createTuple(OUT reuse); +} http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-java/src/main/java/org/apache/flink/api/java/io/CsvInputFormat.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/io/CsvInputFormat.java b/flink-java/src/main/java/org/apache/flink/api/java/io/CsvInputFormat.java index ee33484..7d86f39 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/io/CsvInputFormat.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/io/CsvInputFormat.java @@ -19,224 +19,33 @@ package org.apache.flink.api.java.io; -import com.google.common.base.Preconditions; -import org.apache.flink.api.common.io.GenericCsvInputFormat; -import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.java.tuple.Tuple; -import org.apache.flink.api.java.typeutils.PojoTypeInfo; -import org.apache.flink.core.fs.FileInputSplit; import org.apache.flink.core.fs.Path; -import org.apache.flink.types.parser.FieldParser; import org.apache.flink.util.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.lang.reflect.Field; -import java.util.Arrays; - - -public class CsvInputFormat<OUT> extends GenericCsvInputFormat<OUT> { +public class CsvInputFormat<OUT> extends CommonCsvInputFormat<OUT> { private static final long serialVersionUID = 1L; - /** - * The log. - */ - private static final Logger LOG = LoggerFactory.getLogger(CsvInputFormat.class); - - public static final String DEFAULT_LINE_DELIMITER = "\n"; - - public static final String DEFAULT_FIELD_DELIMITER = ","; - - private transient Object[] parsedValues; - - private Class<OUT> pojoTypeClass = null; - private String[] pojoFieldsName = null; - private transient Field[] pojoFields = null; - private transient PojoTypeInfo<OUT> pojoTypeInfo = null; - - public CsvInputFormat(Path filePath, TypeInformation<OUT> typeInformation) { - this(filePath, DEFAULT_LINE_DELIMITER, DEFAULT_FIELD_DELIMITER, typeInformation); + public CsvInputFormat(Path filePath, CompositeType<OUT> typeInformation) { + super(filePath, typeInformation); } - public CsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, TypeInformation<OUT> typeInformation) { - super(filePath); - - Preconditions.checkArgument(typeInformation instanceof CompositeType); - CompositeType<OUT> compositeType = (CompositeType<OUT>) typeInformation; - - setDelimiter(lineDelimiter); - setFieldDelimiter(fieldDelimiter); - - Class<?>[] classes = new Class<?>[typeInformation.getArity()]; - for (int i = 0, arity = typeInformation.getArity(); i < arity; i++) { - classes[i] = compositeType.getTypeAt(i).getTypeClass(); - } - setFieldTypes(classes); - - if (typeInformation instanceof PojoTypeInfo) { - pojoTypeInfo = (PojoTypeInfo<OUT>) typeInformation; - pojoTypeClass = typeInformation.getTypeClass(); - pojoFieldsName = compositeType.getFieldNames(); - setOrderOfPOJOFields(pojoFieldsName); - } - } - - public void setOrderOfPOJOFields(String[] fieldsOrder) { - Preconditions.checkNotNull(pojoTypeClass, "Field order can only be specified if output type is a POJO."); - Preconditions.checkNotNull(fieldsOrder); - - int includedCount = 0; - for (boolean isIncluded : fieldIncluded) { - if (isIncluded) { - includedCount++; - } - } - - Preconditions.checkArgument(includedCount == fieldsOrder.length, includedCount + - " CSV fields and " + fieldsOrder.length + " POJO fields selected. The number of selected CSV and POJO fields must be equal."); - - for (String field : fieldsOrder) { - Preconditions.checkNotNull(field, "The field name cannot be null."); - Preconditions.checkArgument(pojoTypeInfo.getFieldIndex(field) != -1, - "Field \""+ field + "\" is not a member of POJO class " + pojoTypeClass.getName()); - } - - pojoFieldsName = Arrays.copyOfRange(fieldsOrder, 0, fieldsOrder.length); - } - - public void setFieldTypes(Class<?>... fieldTypes) { - if (fieldTypes == null || fieldTypes.length == 0) { - throw new IllegalArgumentException("Field types must not be null or empty."); - } - - setFieldTypesGeneric(fieldTypes); - } - - public void setFields(int[] sourceFieldIndices, Class<?>[] fieldTypes) { - Preconditions.checkNotNull(sourceFieldIndices); - Preconditions.checkNotNull(fieldTypes); - - checkForMonotonousOrder(sourceFieldIndices, fieldTypes); - - setFieldsGeneric(sourceFieldIndices, fieldTypes); - } - - public void setFields(boolean[] sourceFieldMask, Class<?>[] fieldTypes) { - Preconditions.checkNotNull(sourceFieldMask); - Preconditions.checkNotNull(fieldTypes); - - setFieldsGeneric(sourceFieldMask, fieldTypes); + public CsvInputFormat(Path filePath, String lineDelimiter, String fieldDelimiter, CompositeType<OUT> typeInformation) { + super(filePath, lineDelimiter, fieldDelimiter, typeInformation); } - public Class<?>[] getFieldTypes() { - return super.getGenericFieldTypes(); - } - @Override - public void open(FileInputSplit split) throws IOException { - super.open(split); - - @SuppressWarnings("unchecked") - FieldParser<Object>[] fieldParsers = (FieldParser<Object>[]) getFieldParsers(); - - //throw exception if no field parsers are available - if (fieldParsers.length == 0) { - throw new IOException("CsvInputFormat.open(FileInputSplit split) - no field parsers to parse input"); - } - - // create the value holders - this.parsedValues = new Object[fieldParsers.length]; - for (int i = 0; i < fieldParsers.length; i++) { - this.parsedValues[i] = fieldParsers[i].createValue(); - } - - // left to right evaluation makes access [0] okay - // this marker is used to fasten up readRecord, so that it doesn't have to check each call if the line ending is set to default - if (this.getDelimiter().length == 1 && this.getDelimiter()[0] == '\n' ) { - this.lineDelimiterIsLinebreak = true; + protected OUT createTuple(OUT reuse) { + Tuple result = (Tuple) reuse; + for (int i = 0; i < parsedValues.length; i++) { + result.setField(parsedValues[i], i); } - // for POJO type - if (pojoTypeClass != null) { - pojoFields = new Field[pojoFieldsName.length]; - for (int i = 0; i < pojoFieldsName.length; i++) { - try { - pojoFields[i] = pojoTypeClass.getDeclaredField(pojoFieldsName[i]); - pojoFields[i].setAccessible(true); - } catch (NoSuchFieldException e) { - throw new RuntimeException("There is no field called \"" + pojoFieldsName[i] + "\" in " + pojoTypeClass.getName(), e); - } - } - } - - this.commentCount = 0; - this.invalidLineCount = 0; + return reuse; } - - @Override - public OUT nextRecord(OUT record) throws IOException { - OUT returnRecord = null; - do { - returnRecord = super.nextRecord(record); - } while (returnRecord == null && !reachedEnd()); - return returnRecord; - } - - @Override - public OUT readRecord(OUT reuse, byte[] bytes, int offset, int numBytes) throws IOException { - /* - * Fix to support windows line endings in CSVInputFiles with standard delimiter setup = \n - */ - //Find windows end line, so find carriage return before the newline - if (this.lineDelimiterIsLinebreak == true && numBytes > 0 && bytes[offset + numBytes -1] == '\r' ) { - //reduce the number of bytes so that the Carriage return is not taken as data - numBytes--; - } - - if (commentPrefix != null && commentPrefix.length <= numBytes) { - //check record for comments - boolean isComment = true; - for (int i = 0; i < commentPrefix.length; i++) { - if (commentPrefix[i] != bytes[offset + i]) { - isComment = false; - break; - } - } - if (isComment) { - this.commentCount++; - return null; - } - } - - if (parseRecord(parsedValues, bytes, offset, numBytes)) { - if (pojoTypeClass == null) { - // result type is tuple - Tuple result = (Tuple) reuse; - for (int i = 0; i < parsedValues.length; i++) { - result.setField(parsedValues[i], i); - } - } else { - // result type is POJO - for (int i = 0; i < parsedValues.length; i++) { - try { - pojoFields[i].set(reuse, parsedValues[i]); - } catch (IllegalAccessException e) { - throw new RuntimeException("Parsed value could not be set in POJO field \"" + pojoFieldsName[i] + "\"", e); - } - } - } - return reuse; - } else { - this.invalidLineCount++; - return null; - } - } - - @Override public String toString() { return "CSV Input (" + StringUtils.showControlCharacters(String.valueOf(getFieldDelimiter())) + ") " + getFilePath(); http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-java/src/test/java/org/apache/flink/api/java/io/CsvInputFormatTest.java ---------------------------------------------------------------------- diff --git a/flink-java/src/test/java/org/apache/flink/api/java/io/CsvInputFormatTest.java b/flink-java/src/test/java/org/apache/flink/api/java/io/CsvInputFormatTest.java index 4f1c339..2c2ffa5 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/io/CsvInputFormatTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/io/CsvInputFormatTest.java @@ -22,8 +22,9 @@ package org.apache.flink.api.java.io; import com.google.common.base.Charsets; import org.apache.flink.api.common.io.ParseException; -import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.java.tuple.*; +import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.typeutils.TypeExtractor; import org.apache.flink.configuration.Configuration; @@ -37,6 +38,8 @@ import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; +import java.util.ArrayList; +import java.util.List; import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertEquals; @@ -548,7 +551,7 @@ public class CsvInputFormatTest { format.setFieldDelimiter("|"); - format.setFields(new int[] {0, 3, 7}, new Class<?>[] {Integer.class, Integer.class, Integer.class}); + format.setFields(new int[]{0, 3, 7}, new Class<?>[]{Integer.class, Integer.class, Integer.class}); format.configure(new Configuration()); @@ -823,7 +826,7 @@ public class CsvInputFormatTest { wrt.close(); @SuppressWarnings("unchecked") - TypeInformation<PojoItem> typeInfo = (TypeInformation<PojoItem>) TypeExtractor.createTypeInfo(PojoItem.class); + PojoTypeInfo<PojoItem> typeInfo = (PojoTypeInfo<PojoItem>) TypeExtractor.createTypeInfo(PojoItem.class); CsvInputFormat<PojoItem> inputFormat = new CsvInputFormat<PojoItem>(new Path(tempFile.toURI().toString()), typeInfo); inputFormat.configure(new Configuration()); @@ -846,7 +849,7 @@ public class CsvInputFormatTest { wrt.close(); @SuppressWarnings("unchecked") - TypeInformation<PrivatePojoItem> typeInfo = (TypeInformation<PrivatePojoItem>) TypeExtractor.createTypeInfo(PrivatePojoItem.class); + PojoTypeInfo<PrivatePojoItem> typeInfo = (PojoTypeInfo<PrivatePojoItem>) TypeExtractor.createTypeInfo(PrivatePojoItem.class); CsvInputFormat<PrivatePojoItem> inputFormat = new CsvInputFormat<PrivatePojoItem>(new Path(tempFile.toURI().toString()), typeInfo); inputFormat.configure(new Configuration()); @@ -882,7 +885,7 @@ public class CsvInputFormatTest { wrt.close(); @SuppressWarnings("unchecked") - TypeInformation<PojoItem> typeInfo = (TypeInformation<PojoItem>) TypeExtractor.createTypeInfo(PojoItem.class); + PojoTypeInfo<PojoItem> typeInfo = (PojoTypeInfo<PojoItem>) TypeExtractor.createTypeInfo(PojoItem.class); CsvInputFormat<PojoItem> inputFormat = new CsvInputFormat<PojoItem>(new Path(tempFile.toURI().toString()), typeInfo); inputFormat.setFields(new boolean[]{true, true, true, true}, new Class<?>[]{Integer.class, Double.class, String.class, String.class}); inputFormat.setOrderOfPOJOFields(new String[]{"field1", "field3", "field2", "field4"}); @@ -907,7 +910,7 @@ public class CsvInputFormatTest { wrt.close(); @SuppressWarnings("unchecked") - TypeInformation<PojoItem> typeInfo = (TypeInformation<PojoItem>) TypeExtractor.createTypeInfo(PojoItem.class); + PojoTypeInfo<PojoItem> typeInfo = (PojoTypeInfo<PojoItem>) TypeExtractor.createTypeInfo(PojoItem.class); CsvInputFormat<PojoItem> inputFormat = new CsvInputFormat<PojoItem>(new Path(tempFile.toURI().toString()), typeInfo); inputFormat.setFields(new boolean[]{true, false, true, false, true, true}, new Class[]{Integer.class, String .class, Double.class, String.class}); @@ -932,7 +935,7 @@ public class CsvInputFormatTest { wrt.close(); @SuppressWarnings("unchecked") - TypeInformation<PojoItem> typeInfo = (TypeInformation<PojoItem>) TypeExtractor.createTypeInfo(PojoItem.class); + PojoTypeInfo<PojoItem> typeInfo = (PojoTypeInfo<PojoItem>) TypeExtractor.createTypeInfo(PojoItem.class); CsvInputFormat<PojoItem> inputFormat = new CsvInputFormat<PojoItem>(new Path(tempFile.toURI().toString()), typeInfo); inputFormat.setFields(new boolean[]{true, false, false, true}, new Class[]{Integer.class, String.class}); inputFormat.setOrderOfPOJOFields(new String[]{"field1", "field4"}); @@ -956,7 +959,7 @@ public class CsvInputFormatTest { tempFile.setWritable(true); @SuppressWarnings("unchecked") - TypeInformation<PojoItem> typeInfo = (TypeInformation<PojoItem>) TypeExtractor.createTypeInfo(PojoItem.class); + PojoTypeInfo<PojoItem> typeInfo = (PojoTypeInfo<PojoItem>) TypeExtractor.createTypeInfo(PojoItem.class); CsvInputFormat<PojoItem> inputFormat = new CsvInputFormat<PojoItem>(new Path(tempFile.toURI().toString()), typeInfo); try { @@ -994,7 +997,7 @@ public class CsvInputFormatTest { writer.write(fileContent); writer.close(); - TypeInformation<Tuple2<String, String>> typeInfo = TupleTypeInfo.getBasicTupleTypeInfo(String.class, String.class); + CompositeType<Tuple2<String, String>> typeInfo = TupleTypeInfo.getBasicTupleTypeInfo(String.class, String.class); CsvInputFormat<Tuple2<String, String>> inputFormat = new CsvInputFormat<Tuple2<String, String>>(new Path(tempFile.toURI().toString()), typeInfo); inputFormat.enableQuotedStringParsing('"'); @@ -1025,7 +1028,7 @@ public class CsvInputFormatTest { writer.write(fileContent); writer.close(); - TypeInformation<Tuple2<String, String>> typeInfo = TupleTypeInfo.getBasicTupleTypeInfo(String.class, String.class); + TupleTypeInfo<Tuple2<String, String>> typeInfo = TupleTypeInfo.getBasicTupleTypeInfo(String.class, String.class); CsvInputFormat<Tuple2<String, String>> inputFormat = new CsvInputFormat<>(new Path(tempFile.toURI().toString()), typeInfo); inputFormat.enableQuotedStringParsing('"'); @@ -1043,6 +1046,49 @@ public class CsvInputFormatTest { assertEquals("We are\\\" young", record.f1); } + /** + * Tests that the CSV input format can deal with POJOs which are subclasses. + * + * @throws Exception + */ + @Test + public void testPojoSubclassType() throws Exception { + final String fileContent = "t1,foobar,tweet2\nt2,barfoo,tweet2"; + + final File tempFile = File.createTempFile("CsvReaderPOJOSubclass", "tmp"); + tempFile.deleteOnExit(); + + OutputStreamWriter writer = new OutputStreamWriter(new FileOutputStream(tempFile)); + writer.write(fileContent); + writer.close(); + + @SuppressWarnings("unchecked") + PojoTypeInfo<TwitterPOJO> typeInfo = (PojoTypeInfo<TwitterPOJO>)TypeExtractor.createTypeInfo(TwitterPOJO.class); + CsvInputFormat<TwitterPOJO> inputFormat = new CsvInputFormat<>(new Path(tempFile.toURI().toString()), typeInfo); + + inputFormat.configure(new Configuration()); + FileInputSplit[] splits = inputFormat.createInputSplits(1); + + inputFormat.open(splits[0]); + + List<TwitterPOJO> expected = new ArrayList<>(); + + for (String line: fileContent.split("\n")) { + String[] elements = line.split(","); + expected.add(new TwitterPOJO(elements[0], elements[1], elements[2])); + } + + List<TwitterPOJO> actual = new ArrayList<>(); + + TwitterPOJO pojo; + + while((pojo = inputFormat.nextRecord(new TwitterPOJO())) != null) { + actual.add(pojo); + } + + assertEquals(expected, actual); + } + // -------------------------------------------------------------------------------------------- // Custom types for testing // -------------------------------------------------------------------------------------------- @@ -1093,4 +1139,52 @@ public class CsvInputFormatTest { } } + public static class POJO { + public String table; + public String time; + + public POJO() { + this("", ""); + } + + public POJO(String table, String time) { + this.table = table; + this.time = time; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof POJO) { + POJO other = (POJO) obj; + return table.equals(other.table) && time.equals(other.time); + } else { + return false; + } + } + } + + public static class TwitterPOJO extends POJO { + public String tweet; + + public TwitterPOJO() { + this("", "", ""); + } + + public TwitterPOJO(String table, String time, String tweet) { + super(table, time); + this.tweet = tweet; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof TwitterPOJO) { + TwitterPOJO other = (TwitterPOJO) obj; + + return super.equals(other) && tweet.equals(other.tweet); + } else { + return false; + } + } + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java index 9adbed8..522cc0c 100644 --- a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java +++ b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvInputFormat.java @@ -21,213 +21,45 @@ package org.apache.flink.api.scala.operators; import com.google.common.base.Preconditions; import org.apache.flink.api.common.ExecutionConfig; -import org.apache.flink.api.common.io.GenericCsvInputFormat; -import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.java.io.CommonCsvInputFormat; import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.TupleTypeInfoBase; import org.apache.flink.api.java.typeutils.runtime.TupleSerializerBase; -import org.apache.flink.core.fs.FileInputSplit; import org.apache.flink.core.fs.Path; +import org.apache.flink.util.StringUtils; -import org.apache.flink.types.parser.FieldParser; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.lang.reflect.Field; -import java.util.Arrays; - -public class ScalaCsvInputFormat<OUT> extends GenericCsvInputFormat<OUT> { - - private static final long serialVersionUID = 1L; - - private static final Logger LOG = LoggerFactory.getLogger(ScalaCsvInputFormat.class); - - private transient Object[] parsedValues; +public class ScalaCsvInputFormat<OUT> extends CommonCsvInputFormat<OUT> { + private static final long serialVersionUID = -7347888812778968640L; private final TupleSerializerBase<OUT> tupleSerializer; - private Class<OUT> pojoTypeClass = null; - private String[] pojoFieldsName = null; - private transient Field[] pojoFields = null; - private transient PojoTypeInfo<OUT> pojoTypeInfo = null; - - public ScalaCsvInputFormat(Path filePath, TypeInformation<OUT> typeInfo) { - super(filePath); + public ScalaCsvInputFormat(Path filePath, CompositeType<OUT> typeInfo) { + super(filePath, typeInfo); - Class<?>[] classes = new Class[typeInfo.getArity()]; + Preconditions.checkArgument(typeInfo instanceof PojoTypeInfo || typeInfo instanceof TupleTypeInfoBase, + "Only pojo types or tuple types are supported."); if (typeInfo instanceof TupleTypeInfoBase) { - TupleTypeInfoBase<OUT> tupleType = (TupleTypeInfoBase<OUT>) typeInfo; - // We can use an empty config here, since we only use the serializer to create - // the top-level case class - tupleSerializer = (TupleSerializerBase<OUT>) tupleType.createSerializer(new ExecutionConfig()); - - for (int i = 0; i < tupleType.getArity(); i++) { - classes[i] = tupleType.getTypeAt(i).getTypeClass(); - } + TupleTypeInfoBase<OUT> tupleTypeInfo = (TupleTypeInfoBase<OUT>) typeInfo; - setFieldTypes(classes); + tupleSerializer = (TupleSerializerBase<OUT>)tupleTypeInfo.createSerializer(new ExecutionConfig()); } else { tupleSerializer = null; - pojoTypeInfo = (PojoTypeInfo<OUT>) typeInfo; - pojoTypeClass = typeInfo.getTypeClass(); - pojoFieldsName = pojoTypeInfo.getFieldNames(); - - for (int i = 0, arity = pojoTypeInfo.getArity(); i < arity; i++) { - classes[i] = pojoTypeInfo.getTypeAt(i).getTypeClass(); - } - - setFieldTypes(classes); - setOrderOfPOJOFields(pojoFieldsName); - } - } - - public void setOrderOfPOJOFields(String[] fieldsOrder) { - Preconditions.checkNotNull(pojoTypeClass, "Field order can only be specified if output type is a POJO."); - Preconditions.checkNotNull(fieldsOrder); - - int includedCount = 0; - for (boolean isIncluded : fieldIncluded) { - if (isIncluded) { - includedCount++; - } - } - - Preconditions.checkArgument(includedCount == fieldsOrder.length, - "The number of selected POJO fields should be the same as that of CSV fields."); - - for (String field : fieldsOrder) { - Preconditions.checkNotNull(field, "The field name cannot be null."); - Preconditions.checkArgument(pojoTypeInfo.getFieldIndex(field) != -1, - "The given field name isn't matched to POJO fields."); - } - - pojoFieldsName = Arrays.copyOfRange(fieldsOrder, 0, fieldsOrder.length); - } - - public void setFieldTypes(Class<?>[] fieldTypes) { - if (fieldTypes == null || fieldTypes.length == 0) { - throw new IllegalArgumentException("Field types must not be null or empty."); - } - - setFieldTypesGeneric(fieldTypes); - } - - public void setFields(int[] sourceFieldIndices, Class<?>[] fieldTypes) { - Preconditions.checkNotNull(sourceFieldIndices); - Preconditions.checkNotNull(fieldTypes); - - checkForMonotonousOrder(sourceFieldIndices, fieldTypes); - - setFieldsGeneric(sourceFieldIndices, fieldTypes); - } - - public void setFields(boolean[] sourceFieldMask, Class<?>[] fieldTypes) { - Preconditions.checkNotNull(sourceFieldMask); - Preconditions.checkNotNull(fieldTypes); - - setFieldsGeneric(sourceFieldMask, fieldTypes); - } - - public Class<?>[] getFieldTypes() { - return super.getGenericFieldTypes(); - } - - @Override - public void open(FileInputSplit split) throws IOException { - super.open(split); - - @SuppressWarnings("unchecked") - FieldParser<Object>[] fieldParsers = (FieldParser<Object>[]) getFieldParsers(); - - //throw exception if no field parsers are available - if (fieldParsers.length == 0) { - throw new IOException("CsvInputFormat.open(FileInputSplit split) - no field parsers to parse input"); - } - - // create the value holders - this.parsedValues = new Object[fieldParsers.length]; - for (int i = 0; i < fieldParsers.length; i++) { - this.parsedValues[i] = fieldParsers[i].createValue(); } - - // left to right evaluation makes access [0] okay - // this marker is used to fasten up readRecord, so that it doesn't have to check each call if the line ending is set to default - if (this.getDelimiter().length == 1 && this.getDelimiter()[0] == '\n' ) { - this.lineDelimiterIsLinebreak = true; - } - - // for POJO type - if (pojoTypeClass != null) { - pojoFields = new Field[pojoFieldsName.length]; - for (int i = 0; i < pojoFieldsName.length; i++) { - try { - pojoFields[i] = pojoTypeClass.getDeclaredField(pojoFieldsName[i]); - pojoFields[i].setAccessible(true); - } catch (NoSuchFieldException e) { - throw new RuntimeException("There is no field called \"" + pojoFieldsName[i] + "\" in " + pojoTypeClass.getName(), e); - } - } - } - - this.commentCount = 0; - this.invalidLineCount = 0; } @Override - public OUT nextRecord(OUT record) throws IOException { - OUT returnRecord = null; - do { - returnRecord = super.nextRecord(record); - } while (returnRecord == null && !reachedEnd()); + protected OUT createTuple(OUT reuse) { + Preconditions.checkNotNull(tupleSerializer, "The tuple serializer must be initialised." + + " It is not initialized if the given type was not a " + + TupleTypeInfoBase.class.getName() + "."); - return returnRecord; + return tupleSerializer.createInstance(parsedValues); } @Override - public OUT readRecord(OUT reuse, byte[] bytes, int offset, int numBytes) { - /* - * Fix to support windows line endings in CSVInputFiles with standard delimiter setup = \n - */ - //Find windows end line, so find carriage return before the newline - if (this.lineDelimiterIsLinebreak == true && numBytes > 0 && bytes[offset + numBytes -1] == '\r' ) { - //reduce the number of bytes so that the Carriage return is not taken as data - numBytes--; - } - - if (commentPrefix != null && commentPrefix.length <= numBytes) { - //check record for comments - boolean isComment = true; - for (int i = 0; i < commentPrefix.length; i++) { - if (commentPrefix[i] != bytes[offset + i]) { - isComment = false; - break; - } - } - if (isComment) { - this.commentCount++; - return null; - } - } - - if (parseRecord(parsedValues, bytes, offset, numBytes)) { - if (tupleSerializer != null) { - return tupleSerializer.createInstance(parsedValues); - } else { - for (int i = 0; i < pojoFields.length; i++) { - try { - pojoFields[i].set(reuse, parsedValues[i]); - } catch (IllegalAccessException e) { - throw new RuntimeException("Parsed value could not be set in POJO field \"" + pojoFieldsName[i] + "\"", e); - } - } - - return reuse; - } - } else { - this.invalidLineCount++; - return null; - } + public String toString() { + return "Scala CSV Input (" + StringUtils.showControlCharacters(String.valueOf(getFieldDelimiter())) + ") " + getFilePath(); } } http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvOutputFormat.java ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvOutputFormat.java b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvOutputFormat.java index afcdc4a..e29a6e6 100644 --- a/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvOutputFormat.java +++ b/flink-scala/src/main/java/org/apache/flink/api/scala/operators/ScalaCsvOutputFormat.java @@ -23,7 +23,7 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.io.FileOutputFormat; import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.io.CsvInputFormat; +import org.apache.flink.api.java.io.CommonCsvInputFormat; import org.apache.flink.api.java.typeutils.InputTypeConfigurable; import org.apache.flink.core.fs.Path; import org.apache.flink.types.StringValue; @@ -51,9 +51,9 @@ public class ScalaCsvOutputFormat<T extends Product> extends FileOutputFormat<T> // -------------------------------------------------------------------------------------------- - public static final String DEFAULT_LINE_DELIMITER = CsvInputFormat.DEFAULT_LINE_DELIMITER; + public static final String DEFAULT_LINE_DELIMITER = CommonCsvInputFormat.DEFAULT_LINE_DELIMITER; - public static final String DEFAULT_FIELD_DELIMITER = String.valueOf(CsvInputFormat.DEFAULT_FIELD_DELIMITER); + public static final String DEFAULT_FIELD_DELIMITER = String.valueOf(CommonCsvInputFormat.DEFAULT_FIELD_DELIMITER); // -------------------------------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala index 745b0d3..85c5410 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/ExecutionEnvironment.scala @@ -20,9 +20,11 @@ package org.apache.flink.api.scala import java.util.UUID import com.esotericsoftware.kryo.Serializer +import com.google.common.base.Preconditions import org.apache.flink.api.common.io.{FileInputFormat, InputFormat} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.common.{JobID, ExecutionConfig, JobExecutionResult} +import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.io._ import org.apache.flink.api.java.operators.DataSource import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer @@ -53,11 +55,11 @@ import scala.reflect.ClassTag * * To get an execution environment use the methods on the companion object: * - * - [[ExecutionEnvironment.getExecutionEnvironment]] - * - [[ExecutionEnvironment.createLocalEnvironment]] - * - [[ExecutionEnvironment.createRemoteEnvironment]] + * - [[ExecutionEnvironment#getExecutionEnvironment]] + * - [[ExecutionEnvironment#createLocalEnvironment]] + * - [[ExecutionEnvironment#createRemoteEnvironment]] * - * Use [[ExecutionEnvironment.getExecutionEnvironment]] to get the correct environment depending + * Use [[ExecutionEnvironment#getExecutionEnvironment]] to get the correct environment depending * on where the program is executed. If it is run inside an IDE a loca environment will be * created. If the program is submitted to a cluster a remote execution environment will * be created. @@ -294,7 +296,14 @@ class ExecutionEnvironment(javaEnv: JavaEnv) { val typeInfo = implicitly[TypeInformation[T]] - val inputFormat = new ScalaCsvInputFormat[T](new Path(filePath), typeInfo) + Preconditions.checkArgument( + typeInfo.isInstanceOf[CompositeType[T]], + s"The type $typeInfo has to be a tuple or pojo type.", + null) + + val inputFormat = new ScalaCsvInputFormat[T]( + new Path(filePath), + typeInfo.asInstanceOf[CompositeType[T]]) inputFormat.setDelimiter(lineDelimiter) inputFormat.setFieldDelimiter(fieldDelimiter) inputFormat.setSkipFirstLineAsHeader(ignoreFirstLine) @@ -334,7 +343,7 @@ class ExecutionEnvironment(javaEnv: JavaEnv) { " included fields must match.") inputFormat.setFields(includedFields, classesBuf.toArray) } else { - inputFormat.setFieldTypes(classesBuf.toArray) + inputFormat.setFieldTypes(classesBuf: _*) } if (pojoFields != null) { http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-staging/flink-language-binding/flink-language-binding-generic/src/main/java/org/apache/flink/languagebinding/api/java/common/PlanBinder.java ---------------------------------------------------------------------- diff --git a/flink-staging/flink-language-binding/flink-language-binding-generic/src/main/java/org/apache/flink/languagebinding/api/java/common/PlanBinder.java b/flink-staging/flink-language-binding/flink-language-binding-generic/src/main/java/org/apache/flink/languagebinding/api/java/common/PlanBinder.java index 69cd501..8ca0405 100644 --- a/flink-staging/flink-language-binding/flink-language-binding-generic/src/main/java/org/apache/flink/languagebinding/api/java/common/PlanBinder.java +++ b/flink-staging/flink-language-binding/flink-language-binding-generic/src/main/java/org/apache/flink/languagebinding/api/java/common/PlanBinder.java @@ -14,6 +14,8 @@ package org.apache.flink.languagebinding.api.java.common; import java.io.IOException; import java.util.HashMap; + +import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.io.CsvInputFormat; @@ -256,7 +258,14 @@ public abstract class PlanBinder<INFO extends OperationInfo> { protected abstract INFO createOperationInfo(AbstractOperation operationIdentifier) throws IOException; private void createCsvSource(OperationInfo info) throws IOException { - sets.put(info.setID, env.createInput(new CsvInputFormat(new Path(info.path), info.lineDelimiter, info.fieldDelimiter, info.types), info.types).name("CsvSource")); + if (!(info.types instanceof CompositeType)) { + throw new RuntimeException("The output type of a csv source has to be a tuple or a " + + "pojo type. The derived type is " + info); + } + + sets.put(info.setID, env.createInput(new CsvInputFormat(new Path(info.path), + info.lineDelimiter, info.fieldDelimiter, (CompositeType)info.types), info.types) + .name("CsvSource")); } private void createTextSource(OperationInfo info) throws IOException { http://git-wip-us.apache.org/repos/asf/flink/blob/501a9b08/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala index 0d74515..e49f737 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala @@ -20,13 +20,19 @@ package org.apache.flink.api.scala.io import java.io.{File, FileOutputStream, FileWriter, OutputStreamWriter} import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.io.CsvInputFormatTest.TwitterPOJO +import org.apache.flink.api.java.typeutils.{TupleTypeInfo, PojoTypeInfo} import org.apache.flink.api.scala._ import org.apache.flink.api.scala.operators.ScalaCsvInputFormat +import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo import org.apache.flink.configuration.Configuration import org.apache.flink.core.fs.{FileInputSplit, Path} import org.junit.Assert.{assertEquals, assertNotNull, assertNull, assertTrue, fail} import org.junit.Test +import scala.collection.mutable.ArrayBuffer + class CsvInputFormatTest { private final val PATH: Path = new Path("an/ignored/file/") @@ -45,7 +51,9 @@ class CsvInputFormatTest { "#next|5|6.0|\n" val split = createTempFile(fileContent) val format = new ScalaCsvInputFormat[(String, Integer, Double)]( - PATH, createTypeInformation[(String, Integer, Double)]) + PATH, + createTypeInformation[(String, Integer, Double)] + .asInstanceOf[CaseClassTypeInfo[(String, Integer, Double)]]) format.setDelimiter("\n") format.setFieldDelimiter('|') format.setCommentPrefix("#") @@ -85,7 +93,9 @@ class CsvInputFormatTest { "//next|5|6.0|\n" val split = createTempFile(fileContent) val format = new ScalaCsvInputFormat[(String, Integer, Double)]( - PATH, createTypeInformation[(String, Integer, Double)]) + PATH, + createTypeInformation[(String, Integer, Double)] + .asInstanceOf[CaseClassTypeInfo[(String, Integer, Double)]]) format.setDelimiter("\n") format.setFieldDelimiter('|') format.setCommentPrefix("//") @@ -121,7 +131,9 @@ class CsvInputFormatTest { val fileContent = "abc|def|ghijk\nabc||hhg\n|||" val split = createTempFile(fileContent) val format = new ScalaCsvInputFormat[(String, String, String)]( - PATH, createTypeInformation[(String, String, String)]) + PATH, + createTypeInformation[(String, String, String)] + .asInstanceOf[CaseClassTypeInfo[(String, String, String)]]) format.setDelimiter("\n") format.setFieldDelimiter("|") val parameters = new Configuration @@ -161,7 +173,9 @@ class CsvInputFormatTest { val fileContent = "abc|\"de|f\"|ghijk\n\"a|bc\"||hhg\n|||" val split = createTempFile(fileContent) val format = new ScalaCsvInputFormat[(String, String, String)]( - PATH, createTypeInformation[(String, String, String)]) + PATH, + createTypeInformation[(String, String, String)] + .asInstanceOf[CaseClassTypeInfo[(String, String, String)]]) format.setDelimiter("\n") format.enableQuotedStringParsing('"') format.setFieldDelimiter("|") @@ -202,7 +216,9 @@ class CsvInputFormatTest { val fileContent = "abc|-def|-ghijk\nabc|-|-hhg\n|-|-|-\n" val split = createTempFile(fileContent) val format = new ScalaCsvInputFormat[(String, String, String)]( - PATH, createTypeInformation[(String, String, String)]) + PATH, + createTypeInformation[(String, String, String)] + .asInstanceOf[CaseClassTypeInfo[(String, String, String)]]) format.setDelimiter("\n") format.setFieldDelimiter("|-") val parameters = new Configuration @@ -241,7 +257,8 @@ class CsvInputFormatTest { val fileContent = "111|222|333|444|555\n666|777|888|999|000|\n" val split = createTempFile(fileContent) val format = new ScalaCsvInputFormat[(Int, Int, Int, Int, Int)]( - PATH, createTypeInformation[(Int, Int, Int, Int, Int)]) + PATH, createTypeInformation[(Int, Int, Int, Int, Int)]. + asInstanceOf[CaseClassTypeInfo[(Int, Int, Int, Int, Int)]]) format.setFieldDelimiter("|") format.configure(new Configuration) format.open(split) @@ -276,7 +293,9 @@ class CsvInputFormatTest { val fileContent = "111|x|222|x|333|x|444|x|555|x|\n" + "666|x|777|x|888|x|999|x|000|x|\n" val split = createTempFile(fileContent) - val format = new ScalaCsvInputFormat[(Int, Int)](PATH, createTypeInformation[(Int, Int)]) + val format = new ScalaCsvInputFormat[(Int, Int)]( + PATH, + createTypeInformation[(Int, Int)].asInstanceOf[CaseClassTypeInfo[(Int, Int)]]) format.setFieldDelimiter("|x|") format.configure(new Configuration) format.open(split) @@ -307,7 +326,7 @@ class CsvInputFormatTest { val split = createTempFile(fileContent) val format = new ScalaCsvInputFormat[(Int, Int, Int)]( PATH, - createTypeInformation[(Int, Int, Int)]) + createTypeInformation[(Int, Int, Int)].asInstanceOf[CaseClassTypeInfo[(Int, Int, Int)]]) format.setFieldDelimiter("|") format.setFields(Array(0, 3, 7), Array(classOf[Integer], classOf[Integer], classOf[Integer]): Array[Class[_]]) @@ -339,7 +358,7 @@ class CsvInputFormatTest { try { val format = new ScalaCsvInputFormat[(Int, Int, Int)]( PATH, - createTypeInformation[(Int, Int, Int)]) + createTypeInformation[(Int, Int, Int)].asInstanceOf[CaseClassTypeInfo[(Int, Int, Int)]]) format.setFieldDelimiter("|") try { format.setFields(Array(8, 1, 3), @@ -384,7 +403,7 @@ class CsvInputFormatTest { wrt.write(fileContent) wrt.close() val inputFormat = new ScalaCsvInputFormat[Tuple1[String]](new Path(tempFile.toURI.toString), - createTypeInformation[Tuple1[String]]) + createTypeInformation[Tuple1[String]].asInstanceOf[CaseClassTypeInfo[Tuple1[String]]]) val parameters = new Configuration inputFormat.configure(parameters) inputFormat.setDelimiter(lineBreakerSetup) @@ -442,7 +461,8 @@ class CsvInputFormatTest { def testPOJOType(): Unit = { val fileContent = "123,HELLO,3.123\n" + "456,ABC,1.234" val tempFile = createTempFile(fileContent) - val typeInfo: TypeInformation[POJOItem] = createTypeInformation[POJOItem] + val typeInfo: PojoTypeInfo[POJOItem] = createTypeInformation[POJOItem] + .asInstanceOf[PojoTypeInfo[POJOItem]] val format = new ScalaCsvInputFormat[POJOItem](PATH, typeInfo) format.setDelimiter('\n') @@ -457,7 +477,8 @@ class CsvInputFormatTest { def testCaseClass(): Unit = { val fileContent = "123,HELLO,3.123\n" + "456,ABC,1.234" val tempFile = createTempFile(fileContent) - val typeInfo: TypeInformation[CaseClassItem] = createTypeInformation[CaseClassItem] + val typeInfo: CompositeType[CaseClassItem] = createTypeInformation[CaseClassItem] + .asInstanceOf[CompositeType[CaseClassItem]] val format = new ScalaCsvInputFormat[CaseClassItem](PATH, typeInfo) format.setDelimiter('\n') @@ -472,12 +493,13 @@ class CsvInputFormatTest { def testPOJOTypeWithFieldMapping(): Unit = { val fileContent = "HELLO,123,3.123\n" + "ABC,456,1.234" val tempFile = createTempFile(fileContent) - val typeInfo: TypeInformation[POJOItem] = createTypeInformation[POJOItem] + val typeInfo: PojoTypeInfo[POJOItem] = createTypeInformation[POJOItem] + .asInstanceOf[PojoTypeInfo[POJOItem]] val format = new ScalaCsvInputFormat[POJOItem](PATH, typeInfo) format.setDelimiter('\n') format.setFieldDelimiter(',') - format.setFieldTypes(Array(classOf[String], classOf[Integer], classOf[java.lang.Double])) + format.setFieldTypes(classOf[String], classOf[Integer], classOf[java.lang.Double]) format.setOrderOfPOJOFields(Array("field2", "field1", "field3")) format.configure(new Configuration) format.open(tempFile) @@ -489,7 +511,8 @@ class CsvInputFormatTest { def testPOJOTypeWithFieldSubsetAndDataSubset(): Unit = { val fileContent = "HELLO,123,NODATA,3.123,NODATA\n" + "ABC,456,NODATA,1.234,NODATA" val tempFile = createTempFile(fileContent) - val typeInfo: TypeInformation[POJOItem] = createTypeInformation[POJOItem] + val typeInfo: PojoTypeInfo[POJOItem] = createTypeInformation[POJOItem] + .asInstanceOf[PojoTypeInfo[POJOItem]] val format = new ScalaCsvInputFormat[POJOItem](PATH, typeInfo) format.setDelimiter('\n') @@ -502,4 +525,41 @@ class CsvInputFormatTest { validatePOJOItem(format) } + + @Test + def testPOJOSubclassType(): Unit = { + val fileContent = "t1,foobar,tweet2\nt2,barfoo,tweet2" + val tempFile = createTempFile(fileContent) + val typeInfo: PojoTypeInfo[TwitterPOJO] = createTypeInformation[TwitterPOJO] + .asInstanceOf[PojoTypeInfo[TwitterPOJO]] + val format = new ScalaCsvInputFormat[TwitterPOJO](PATH, typeInfo) + + format.setDelimiter('\n') + format.setFieldDelimiter(',') + format.configure(new Configuration) + format.open(tempFile) + + val expected = for (line <- fileContent.split("\n")) yield { + val elements = line.split(",") + new TwitterPOJO(elements(0), elements(1), elements(2)) + } + + val actual = ArrayBuffer[TwitterPOJO]() + var readNextElement = true + + while (readNextElement) { + val element = format.nextRecord(new TwitterPOJO()) + + if (element != null) { + actual += element + } else { + readNextElement = false + } + } + + assert(expected.sameElements(actual)) + } + + + }
