This is an automated email from the ASF dual-hosted git repository. kaspersor pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/metamodel.git
commit 9cdccc82b9e544c04b4cee796dd59a23d0282b89 Author: Kasper Sørensen <[email protected]> AuthorDate: Tue Mar 12 23:57:54 2019 -0700 METAMODEL-1210: Implemented ARFF data reading. --- arff/pom.xml | 4 + .../org/apache/metamodel/arff/ArffDataContext.java | 67 +++++++++++---- .../org/apache/metamodel/arff/ArffDataSet.java | 98 ++++++++++++++++++++++ .../apache/metamodel/arff/ArffDataContextTest.java | 40 ++++++++- .../java/org/apache/metamodel/schema/Column.java | 5 +- csv/pom.xml | 1 - pom.xml | 5 ++ 7 files changed, 201 insertions(+), 19 deletions(-) diff --git a/arff/pom.xml b/arff/pom.xml index 19f3f5f..102d8ef 100644 --- a/arff/pom.xml +++ b/arff/pom.xml @@ -39,6 +39,10 @@ under the License. <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> </dependency> + <dependency> + <groupId>com.opencsv</groupId> + <artifactId>opencsv</artifactId> + </dependency> <!-- Test dependencies --> <dependency> diff --git a/arff/src/main/java/org/apache/metamodel/arff/ArffDataContext.java b/arff/src/main/java/org/apache/metamodel/arff/ArffDataContext.java index 5926f36..6d4f5ed 100644 --- a/arff/src/main/java/org/apache/metamodel/arff/ArffDataContext.java +++ b/arff/src/main/java/org/apache/metamodel/arff/ArffDataContext.java @@ -29,6 +29,7 @@ import java.util.regex.Pattern; import org.apache.metamodel.MetaModelException; import org.apache.metamodel.QueryPostprocessDataContext; import org.apache.metamodel.data.DataSet; +import org.apache.metamodel.data.MaxRowsDataSet; import org.apache.metamodel.schema.Column; import org.apache.metamodel.schema.ColumnType; import org.apache.metamodel.schema.MutableColumn; @@ -49,9 +50,12 @@ public class ArffDataContext extends QueryPostprocessDataContext { private static final Logger logger = LoggerFactory.getLogger(ArffDataContext.class); + private static final String SECTION_ANNOTATION_RELATION = "@relation"; + private static final String SECTION_ANNOTATION_ATTRIBUTE = "@attribute"; + private static final String SECTION_ANNOTATION_DATA = "@data"; private static final Charset CHARSET = FileHelper.UTF_8_CHARSET; private static final Pattern ATTRIBUTE_DEF_W_DATATYPE_PARAM = - Pattern.compile("\\'?(.+)\\'? ([a-zA-Z]+) \\'?(.+)\\'?"); + Pattern.compile("\\'?(.+)\\'?\\s+([a-zA-Z]+)\\s+\\'?(.+)\\'?"); private final Splitter whitespaceSplitter = Splitter.on(CharMatcher.whitespace()).trimResults().omitEmptyStrings(); @@ -70,36 +74,36 @@ public class ArffDataContext extends QueryPostprocessDataContext { try (BufferedReader reader = createReader()) { boolean inHeader = true; for (String line = reader.readLine(); inHeader && line != null; line = reader.readLine()) { - if (line.startsWith("%")) { - continue; // comment + if (isIgnoreLine(line)) { + continue; } final List<String> split = whitespaceSplitter.limit(2).splitToList(line); - if (split.isEmpty()) { - continue; // empty line - } switch (split.get(0).toLowerCase()) { - case "@relation": + case SECTION_ANNOTATION_RELATION: // table name final String tableName = trimString(split.get(1)); table.setName(tableName); break; - case "@attribute": + case SECTION_ANNOTATION_ATTRIBUTE: // column(s) final String attributeDef = split.get(1).trim(); final String attributeName; final String attributeType; + final String attributeParam; final ColumnType columnType; final int indexOfCurly = attributeDef.indexOf('{'); if (indexOfCurly != -1) { - attributeName = trimString(attributeDef.substring(0, indexOfCurly)); + attributeName = trimString(attributeDef.substring(0, indexOfCurly).trim()); attributeType = attributeDef.substring(indexOfCurly); + attributeParam = null; } else { final Matcher matcher = ATTRIBUTE_DEF_W_DATATYPE_PARAM.matcher(attributeDef); if (matcher.find()) { attributeName = matcher.group(1); - attributeType = matcher.group(2) + ' ' + matcher.group(3); + attributeType = matcher.group(2); + attributeParam = matcher.group(3); } else { // simple attribute definition "[name] [type]" final List<String> attributeDefSplit = whitespaceSplitter.splitToList(attributeDef); @@ -109,6 +113,7 @@ public class ArffDataContext extends QueryPostprocessDataContext { } attributeName = trimString(attributeDefSplit.get(0)); attributeType = attributeDefSplit.get(1); + attributeParam = null; } } switch (attributeType.toLowerCase()) { @@ -128,6 +133,9 @@ public class ArffDataContext extends QueryPostprocessDataContext { case "string": columnType = ColumnType.STRING; break; + case "date": + columnType = ColumnType.DATE; + break; default: if (indexOfCurly == -1) { logger.info( @@ -139,10 +147,11 @@ public class ArffDataContext extends QueryPostprocessDataContext { } final MutableColumn column = new MutableColumn(attributeName, columnType, table); - column.setRemarks(attributeType); + column.setRemarks(attributeParam == null ? attributeType : attributeType + " " + attributeParam); + column.setColumnNumber(table.getColumnCount()); table.addColumn(column); break; - case "@data": + case SECTION_ANNOTATION_DATA: // the header part of the file is done, no more schema to build up inHeader = false; break; @@ -155,7 +164,6 @@ public class ArffDataContext extends QueryPostprocessDataContext { } private String trimString(String string) { - string = string.trim(); if (string.startsWith("'") && string.endsWith("'")) { string = string.substring(1, string.length() - 1); } @@ -173,8 +181,37 @@ public class ArffDataContext extends QueryPostprocessDataContext { @Override protected DataSet materializeMainSchemaTable(Table table, List<Column> columns, int maxRows) { - // TODO Auto-generated method stub - return null; + BufferedReader reader = createReader(); + try { + for (String line = reader.readLine(); line != null; line = reader.readLine()) { + line = line.trim(); + if (isIgnoreLine(line)) { + continue; + } + if (line.equals(SECTION_ANNOTATION_DATA)) { + // start of the data + break; + } + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + final ArffDataSet dataSet = new ArffDataSet(columns, reader); + if (maxRows > -1) { + return new MaxRowsDataSet(dataSet, maxRows); + } else { + return dataSet; + } + } + + protected static boolean isIgnoreLine(String line) { + if (line.trim().isEmpty()) { + return true; + } + if (line.startsWith("%")) { + return true; // comment + } + return false; } } diff --git a/arff/src/main/java/org/apache/metamodel/arff/ArffDataSet.java b/arff/src/main/java/org/apache/metamodel/arff/ArffDataSet.java new file mode 100644 index 0000000..abf1e5d --- /dev/null +++ b/arff/src/main/java/org/apache/metamodel/arff/ArffDataSet.java @@ -0,0 +1,98 @@ +package org.apache.metamodel.arff; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.metamodel.data.AbstractDataSet; +import org.apache.metamodel.data.DefaultRow; +import org.apache.metamodel.data.Row; +import org.apache.metamodel.query.SelectItem; +import org.apache.metamodel.schema.Column; +import org.apache.metamodel.schema.ColumnType; +import org.apache.metamodel.util.NumberComparator; + +import com.opencsv.CSVParser; +import com.opencsv.ICSVParser; + +public class ArffDataSet extends AbstractDataSet { + + private final ICSVParser csvParser = new CSVParser(',', '\''); + private final BufferedReader reader; + private final int[] valueIndices; + private final ColumnType[] valueTypes; + + private String line; + + public ArffDataSet(List<Column> columns, BufferedReader reader) { + super(columns.stream().map(c -> new SelectItem(c)).collect(Collectors.toList())); + this.valueIndices = columns.stream().mapToInt(Column::getColumnNumber).toArray(); + this.valueTypes = columns.stream().map(Column::getType).toArray(ColumnType[]::new); + this.reader = reader; + } + + @Override + public boolean next() { + try { + line = reader.readLine(); + while (line != null && ArffDataContext.isIgnoreLine(line)) { + line = reader.readLine(); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + return line != null; + } + + @Override + public Row getRow() { + if (line == null) { + return null; + } + final String[] stringValues; + try { + stringValues = csvParser.parseLine(line); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + final Object[] values = new Object[valueIndices.length]; + for (int i = 0; i < valueIndices.length; i++) { + final int index = valueIndices[i]; + final String stringValue = stringValues[index]; + final ColumnType type = valueTypes[i]; + if (type.isNumber()) { + if (stringValue.isEmpty() || "?".equals(stringValue)) { + values[i] = null; + } else { + final Number n = NumberComparator.toNumber(stringValue); + if (type == ColumnType.INTEGER) { + values[i] = n.intValue(); + } else { + values[i] = n; + } + } + } else if (type.isTimeBased()) { + // TODO: extract format from column remarks + try { + values[i] = new SimpleDateFormat("yyyy-MM-dd").parse(stringValue); + } catch (ParseException e) { + throw new IllegalStateException(e); + } + } else { + values[i] = stringValue; + } + } + + return new DefaultRow(getHeader(), values); + } + + @Override + public void close() { + super.close(); + } +} diff --git a/arff/src/test/java/org/apache/metamodel/arff/ArffDataContextTest.java b/arff/src/test/java/org/apache/metamodel/arff/ArffDataContextTest.java index d874526..5cd8b92 100644 --- a/arff/src/test/java/org/apache/metamodel/arff/ArffDataContextTest.java +++ b/arff/src/test/java/org/apache/metamodel/arff/ArffDataContextTest.java @@ -28,7 +28,9 @@ import java.io.FilenameFilter; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.metamodel.data.DataSet; import org.apache.metamodel.schema.Column; import org.apache.metamodel.schema.ColumnType; import org.apache.metamodel.schema.ColumnTypeImpl; @@ -54,6 +56,7 @@ public class ArffDataContextTest { assertTrue(files.length > 1); + final AtomicInteger rowCounter = new AtomicInteger(); final Set<ColumnType> observedColumnTypes = new HashSet<>(); for (File file : files) { @@ -67,16 +70,24 @@ public class ArffDataContextTest { for (Column column : columns) { observedColumnTypes.add(column.getType()); } + + try (DataSet dataSet = dc.query().from(tables.get(0)).selectAll().execute()) { + while (dataSet.next()) { + assertNotNull(dataSet.getRow()); + rowCounter.incrementAndGet(); + } + } } assertTrue(observedColumnTypes.size() > 1); assertTrue(observedColumnTypes.contains(ColumnTypeImpl.STRING)); assertTrue(observedColumnTypes.contains(ColumnTypeImpl.NUMBER)); + assertTrue(rowCounter.get() > 10_000); } // test case that checks our ability to parse column types and names from a specific file. @Test - public void testReadTableOfHypothyroid() { + public void testReadStructureOfHypothyroid() { final File file = new File(wekaDataDir, "hypothyroid.arff"); final ArffDataContext dc = new ArffDataContext(new FileResource(file)); final Schema schema = dc.getDefaultSchema(); @@ -92,5 +103,32 @@ public class ArffDataContextTest { assertNotNull(sexColumn); assertEquals(ColumnType.STRING, sexColumn.getType()); assertEquals("{ F, M}", sexColumn.getRemarks()); + + final Column tshColumn = table.getColumnByName("TSH measured"); + assertNotNull(tshColumn); + assertEquals(ColumnType.STRING, tshColumn.getType()); + assertEquals("{ t, f}", tshColumn.getRemarks()); + } + + @Test + public void testReadDataOfHypothyroid() { + final File file = new File(wekaDataDir, "hypothyroid.arff"); + final ArffDataContext dc = new ArffDataContext(new FileResource(file)); + + final DataSet dataSet = + dc.query().from("hypothyroid").select("Class", "age", "sex", "TSH measured").limit(3).execute(); + try { + assertTrue(dataSet.next()); + assertEquals("Row[values=[negative, 41, F, t]]", dataSet.getRow().toString()); + assertTrue(dataSet.next()); + assertEquals("Row[values=[negative, 23, F, t]]", dataSet.getRow().toString()); + assertTrue(dataSet.next()); + assertEquals("Row[values=[negative, 46, M, t]]", dataSet.getRow().toString()); + final Object ageValue = dataSet.getRow().getValue(1); + assertEquals(Integer.class, ageValue.getClass()); + assertFalse(dataSet.next()); + } finally { + dataSet.close(); + } } } diff --git a/core/src/main/java/org/apache/metamodel/schema/Column.java b/core/src/main/java/org/apache/metamodel/schema/Column.java index 6874990..d27e791 100644 --- a/core/src/main/java/org/apache/metamodel/schema/Column.java +++ b/core/src/main/java/org/apache/metamodel/schema/Column.java @@ -40,8 +40,9 @@ public interface Column extends Comparable<Column>, Serializable, NamedStructure public String getName(); /** - * Returns the column number or index. Note: This column number is 0-based - * whereas the JDBC is 1-based. + * Returns the column number or index. + * + * Note: This column number is 0-based whereas JDBC's column numbers are 1-based. * * @return the number of this column. */ diff --git a/csv/pom.xml b/csv/pom.xml index 543c655..09c494a 100644 --- a/csv/pom.xml +++ b/csv/pom.xml @@ -40,7 +40,6 @@ under the License. <dependency> <groupId>com.opencsv</groupId> <artifactId>opencsv</artifactId> - <version>3.9</version> </dependency> <dependency> <groupId>junit</groupId> diff --git a/pom.xml b/pom.xml index 4215be1..6c810a6 100644 --- a/pom.xml +++ b/pom.xml @@ -480,6 +480,11 @@ under the License. <version>${guava.version}</version> </dependency> <dependency> + <groupId>com.opencsv</groupId> + <artifactId>opencsv</artifactId> + <version>3.9</version> + </dependency> + <dependency> <groupId>org.apache.httpcomponents</groupId> <artifactId>httpclient</artifactId> <version>${httpcomponents.version}</version>
