This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit dea38dc42c37a44034e351123110c5267ae57fd0 Author: baunsgaard <[email protected]> AuthorDate: Tue Nov 1 17:54:22 2022 +0100 [SYSTEMDS-3462] FrameBlock Iterators Factory Pattern This commit moves the row iterators out of the FrameBlock and behind a factory pattern. In the future this allows customized column allocations to be iterated efficiently, which requires specialized code that returns different iterator instances. To avoid bloating the code inside the FrameBlock, this logic is moved out. --- .../sysds/runtime/frame/data/FrameBlock.java | 267 +-------------------- .../apache/sysds/runtime/frame/data/FrameUtil.java | 27 ++- .../frame/data/iterators/IteratorFactory.java | 166 +++++++++++++ .../frame/data/iterators/ObjectRowIterator.java | 65 +++++ .../runtime/frame/data/iterators/RowIterator.java | 57 +++++ .../StringRowIterator.java} | 37 ++- .../functions/ConvertFrameBlockToIJVLines.java | 3 +- .../spark/utils/FrameRDDConverterUtils.java | 3 +- .../apache/sysds/runtime/io/FrameWriterJSONL.java | 19 +- .../apache/sysds/runtime/io/FrameWriterProto.java | 3 +- .../sysds/runtime/io/FrameWriterTextCSV.java | 3 +- .../sysds/runtime/io/FrameWriterTextCell.java | 3 +- .../apache/sysds/runtime/util/DataConverter.java | 5 +- .../EntityResolutionClusteringTest.java | 19 +- .../test/component/frame/FrameIterators.java} | 24 +- .../TransformFederatedEncodeDecodeTest.java | 3 +- .../test/functions/jmlc/FrameReadMetaTest.java | 9 +- .../transform/TransformEncodeDecodeTest.java | 9 +- 18 files changed, 407 insertions(+), 315 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 57e38f4939..6078ad5420 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -61,6 +61,7 @@ import org.apache.sysds.runtime.frame.data.columns.Array; import 
org.apache.sysds.runtime.frame.data.columns.ArrayFactory; import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.frame.data.columns.StringArray; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.functionobjects.ValueComparisonFunction; import org.apache.sysds.runtime.instructions.cp.BooleanObject; import org.apache.sysds.runtime.instructions.cp.DoubleObject; @@ -681,137 +682,6 @@ public class FrameBlock implements CacheBlock, Externalizable { _msize = -1; } - /** - * Get a row iterator over the frame where all fields are encoded - * as strings independent of their value types. - * - * @return string array iterator - */ - public Iterator<String[]> getStringRowIterator() { - return new StringRowIterator(0, _numRows); - } - - /** - * Get a row iterator over the frame where all selected fields are - * encoded as strings independent of their value types. - * - * @param cols column selection, 1-based - * @return string array iterator - */ - public Iterator<String[]> getStringRowIterator(int[] cols) { - return new StringRowIterator(0, _numRows, cols); - } - - /** - * Get a row iterator over the frame where all selected fields are encoded as strings independent of their value - * types. - * - * @param colID column selection, 1-based - * @return string array iterator - */ - public Iterator<String[]> getStringRowIterator(int colID) { - return new StringRowIterator(0, _numRows, new int[] {colID}); - } - - /** - * Get a row iterator over the frame where all fields are encoded - * as strings independent of their value types. - * - * @param rl lower row index - * @param ru upper row index - * @return string array iterator - */ - public Iterator<String[]> getStringRowIterator(int rl, int ru) { - return new StringRowIterator(rl, ru); - } - - /** - * Get a row iterator over the frame where all selected fields are - * encoded as strings independent of their value types. 
- * - * @param rl lower row index - * @param ru upper row index - * @param cols column selection, 1-based - * @return string array iterator - */ - public Iterator<String[]> getStringRowIterator(int rl, int ru, int[] cols) { - return new StringRowIterator(rl, ru, cols); - } - - /** - * Get a row iterator over the frame where all selected fields are - * encoded as strings independent of their value types. - * - * @param rl lower row index - * @param ru upper row index - * @param colID columnID, 1-based - * @return string array iterator - */ - public Iterator<String[]> getStringRowIterator(int rl, int ru, int colID) { - return new StringRowIterator(rl, ru, new int[] {colID}); - } - - - /** - * Get a row iterator over the frame where all fields are encoded - * as boxed objects according to their value types. - * - * @return object array iterator - */ - public Iterator<Object[]> getObjectRowIterator() { - return new ObjectRowIterator(0, _numRows); - } - - /** - * Get a row iterator over the frame where all fields are encoded - * as boxed objects according to the value types of the provided - * target schema. - * - * @param schema target schema of objects - * @return object array iterator - */ - public Iterator<Object[]> getObjectRowIterator(ValueType[] schema) { - ObjectRowIterator iter = new ObjectRowIterator(0, _numRows); - iter.setSchema(schema); - return iter; - } - - /** - * Get a row iterator over the frame where all selected fields are - * encoded as boxed objects according to their value types. - * - * @param cols column selection, 1-based - * @return object array iterator - */ - public Iterator<Object[]> getObjectRowIterator(int[] cols) { - return new ObjectRowIterator(0, _numRows, cols); - } - - /** - * Get a row iterator over the frame where all fields are encoded - * as boxed objects according to their value types. 
- * - * @param rl lower row index - * @param ru upper row index - * @return object array iterator - */ - public Iterator<Object[]> getObjectRowIterator(int rl, int ru) { - return new ObjectRowIterator(rl, ru); - } - - /** - * Get a row iterator over the frame where all selected fields are - * encoded as boxed objects according to their value types. - * - * @param rl lower row index - * @param ru upper row index - * @param cols column selection, 1-based - * @return object array iterator - */ - public Iterator<Object[]> getObjectRowIterator(int rl, int ru, int[] cols) { - return new ObjectRowIterator(rl, ru, cols); - } - /////// // serialization / deserialization (implementation of writable and externalizable) // FIXME for FrameBlock fix write and readFields, it does not work if the Arrays are not yet @@ -1310,7 +1180,7 @@ public class FrameBlock implements CacheBlock, Externalizable { ret._coldata = new Array[getNumColumns()]; for( int j=0; j<getNumColumns(); j++ ) ret._coldata[j] = _coldata[j].clone(); - Iterator<Object[]> iter = that.getObjectRowIterator(_schema); + Iterator<Object[]> iter = IteratorFactory.getObjectRowIterator(that, _schema); while( iter.hasNext() ) ret.appendRow(iter.next()); } @@ -1480,123 +1350,6 @@ public class FrameBlock implements CacheBlock, Externalizable { return fb; } - /////// - // row iterators (over strings and boxed objects) - - private abstract class RowIterator<T> implements Iterator<T[]> { - protected final int[] _cols; - protected final T[] _curRow; - protected final int _maxPos; - protected int _curPos = -1; - - protected RowIterator(int rl, int ru) { - this(rl, ru, UtilFunctions.getSeqArray(1, getNumColumns(), 1)); - } - - protected RowIterator(int rl, int ru, int[] cols) { - _curRow = createRow(cols.length); - _cols = cols; - _maxPos = ru; - _curPos = rl; - } - - @Override - public boolean hasNext() { - return (_curPos < _maxPos); - } - - @Override - public void remove() { - throw new RuntimeException("RowIterator.remove is 
unsupported!"); - } - - protected abstract T[] createRow(int size); - } - - private class StringRowIterator extends RowIterator<String> { - public StringRowIterator(int rl, int ru) { - super(rl, ru); - } - - public StringRowIterator(int rl, int ru, int[] cols) { - super(rl, ru, cols); - } - - @Override - protected String[] createRow(int size) { - return new String[size]; - } - - @Override - public String[] next( ) { - for( int j=0; j<_cols.length; j++ ) { - Object tmp = get(_curPos, _cols[j]-1); - _curRow[j] = (tmp!=null) ? tmp.toString() : null; - } - _curPos++; - return _curRow; - } - } - - private class ObjectRowIterator extends RowIterator<Object> { - private ValueType[] _tgtSchema = null; - - public ObjectRowIterator(int rl, int ru) { - super(rl, ru); - } - - public ObjectRowIterator(int rl, int ru, int[] cols) { - super(rl, ru, cols); - } - - public void setSchema(ValueType[] schema) { - _tgtSchema = schema; - } - - @Override - protected Object[] createRow(int size) { - return new Object[size]; - } - - @Override - public Object[] next( ) { - for( int j=0; j<_cols.length; j++ ) - _curRow[j] = getValue(_curPos, _cols[j]-1); - _curPos++; - return _curRow; - } - - private Object getValue(int i, int j) { - Object val = get(i, j); - if( _tgtSchema != null ) - val = UtilFunctions.objectToObject(_tgtSchema[j], val); - return val; - } - } - - private static ValueType isType(String val) { - val = val.trim().toLowerCase().replaceAll("\"", ""); - if (val.matches("(true|false|t|f|0|1)")) - return ValueType.BOOLEAN; - else if (val.matches("[-+]?\\d+")){ - long maxValue = Long.parseLong(val); - if ((maxValue >= Integer.MIN_VALUE) && (maxValue <= Integer.MAX_VALUE)) - return ValueType.INT32; - else - return ValueType.INT64; - } - else if (val.matches("[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?")){ - double maxValue = Double.parseDouble(val); - if ((maxValue >= (-Float.MAX_VALUE)) && (maxValue <= Float.MAX_VALUE)) - return ValueType.FP32; - else - return ValueType.FP64; - } - 
else if (val.equals("infinity") || val.equals("-infinity") || val.equals("nan")) - return ValueType.FP64; - else return ValueType.STRING; - } - public FrameBlock detectSchemaFromRow(double sampleFraction) { int rows = this.getNumRows(); int cols = this.getNumColumns(); @@ -1648,7 +1401,7 @@ public class FrameBlock implements CacheBlock, Externalizable { int randomIndex = ThreadLocalRandom.current().nextInt(0, _rows - 1); String dataValue = ((_obj.get(randomIndex) != null)?_obj.get(randomIndex).toString().trim().replace("\"", "").toLowerCase():null); if(dataValue != null){ - ValueType current = isType(dataValue); + ValueType current = FrameUtil.isType(dataValue); if (current == ValueType.STRING) { state = ValueType.STRING; break; @@ -1683,7 +1436,7 @@ public class FrameBlock implements CacheBlock, Externalizable { if(this.getNumColumns() != schema.getNumColumns()) throw new DMLException("mismatch in number of columns in frame and its schema "+this.getNumColumns()+" != "+schema.getNumColumns()); - String[] schemaString = schema.getStringRowIterator().next(); // extract the schema in String array + String[] schemaString = IteratorFactory.getStringRowIterator(this).next(); // extract the schema in String array for (int i = 0; i < this.getNumColumns(); i++) { Array obj = this.getColumn(i); String schemaCol = schemaString[i]; @@ -1705,7 +1458,7 @@ public class FrameBlock implements CacheBlock, Externalizable { continue; String dataValue = obj.get(j).toString().trim().replace("\"", "").toLowerCase() ; - ValueType dataType = isType(dataValue); + ValueType dataType = FrameUtil.isType(dataValue); if(!dataType.toString().contains(type) && !(dataType == ValueType.BOOLEAN && type.equals("INT")) && !(dataType == ValueType.BOOLEAN && type.equals("FP"))){ @@ -1752,8 +1505,8 @@ public class FrameBlock implements CacheBlock, Externalizable { } public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock temp2) { - String[] rowTemp1 = temp1.getStringRowIterator().next(); - 
String[] rowTemp2 = temp2.getStringRowIterator().next(); + String[] rowTemp1 = IteratorFactory.getStringRowIterator(temp1).next(); + String[] rowTemp2 = IteratorFactory.getStringRowIterator(temp2).next(); if(rowTemp1.length != rowTemp2.length) throw new DMLRuntimeException("Schema dimension " @@ -1821,7 +1574,7 @@ public class FrameBlock implements CacheBlock, Externalizable { } public FrameBlock valueSwap(FrameBlock schema) { - String[] schemaString = schema.getStringRowIterator().next(); + String[] schemaString = IteratorFactory.getStringRowIterator(schema).next(); String dataValue2 = null; double minSimScore = 0; int bestIdx = 0; @@ -1846,7 +1599,7 @@ public class FrameBlock implements CacheBlock, Externalizable { if(this.get(j, i) == null) continue; String dataValue = this.get(j, i).toString().trim().replace("\"", "").toLowerCase(); - ValueType dataType = isType(dataValue); + ValueType dataType = FrameUtil.isType(dataValue); String type = dataType.toString().replaceAll("\\d", ""); // get the avergae column length @@ -1861,7 +1614,7 @@ public class FrameBlock implements CacheBlock, Externalizable { Object item = this.get(j, w); String dataValueProb = (item != null) ? 
item.toString().trim().replace("\"", "") .toLowerCase() : "0"; - ValueType dataTypeProb = isType(dataValueProb); + ValueType dataTypeProb = FrameUtil.isType(dataValueProb); if(!dataTypeProb.toString().equals(schemaString[w])) { bestIdx = w; break; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java index f181e86b86..06fa42360c 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java @@ -19,10 +19,11 @@ package org.apache.sysds.runtime.frame.data; +import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.columns.Array; -@SuppressWarnings({"rawtypes"}) public interface FrameUtil { + @SuppressWarnings({"rawtypes"}) public static Array[] add(Array[] ar, Array e) { if(ar == null) return new Array[] {e}; @@ -31,4 +32,28 @@ public interface FrameUtil { ret[ar.length] = e; return ret; } + + public static ValueType isType(String val) { + val = val.trim().toLowerCase().replaceAll("\"", ""); + if(val.matches("(true|false|t|f|0|1)")) + return ValueType.BOOLEAN; + else if(val.matches("[-+]?\\d+")) { + long maxValue = Long.parseLong(val); + if((maxValue >= Integer.MIN_VALUE) && (maxValue <= Integer.MAX_VALUE)) + return ValueType.INT32; + else + return ValueType.INT64; + } + else if(val.matches("[-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?")) { + double maxValue = Double.parseDouble(val); + if((maxValue >= (-Float.MAX_VALUE)) && (maxValue <= Float.MAX_VALUE)) + return ValueType.FP32; + else + return ValueType.FP64; + } + else if(val.equals("infinity") || val.equals("-infinity") || val.equals("nan")) + return ValueType.FP64; + else + return ValueType.STRING; + } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/IteratorFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/IteratorFactory.java new file mode 100644 index 
0000000000..4560b1736e --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/IteratorFactory.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.frame.data.iterators; + +import java.util.Iterator; + +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; + +/** + * Factory pattern for construction of rowIterators of the FrameBlock. + */ +public interface IteratorFactory { + + /** + * Get a row iterator over the frame where all fields are encoded as strings independent of their value types. + * + * @param fb The frame to iterate through + * @return string array iterator + */ + public static Iterator<String[]> getStringRowIterator(FrameBlock fb) { + return new StringRowIterator(fb, 0, fb.getNumRows()); + } + + /** + * Get a row iterator over the frame where all selected fields are encoded as strings independent of their value + * types. 
+ * + * @param fb The frame to iterate through + * @param cols column selection, 1-based + * @return string array iterator + */ + public static Iterator<String[]> getStringRowIterator(FrameBlock fb, int[] cols) { + return new StringRowIterator(fb, 0, fb.getNumRows(), cols); + } + + /** + * Get a row iterator over the frame where all selected fields are encoded as strings independent of their value + * types. + * + * @param fb The frame to iterate through + * @param colID column selection, 1-based + * @return string array iterator + */ + public static Iterator<String[]> getStringRowIterator(FrameBlock fb, int colID) { + return new StringRowIterator(fb, 0, fb.getNumRows(), new int[] {colID}); + } + + /** + * Get a row iterator over the frame where all fields are encoded as strings independent of their value types. + * + * @param fb The frame to iterate through + * @param rl lower row index + * @param ru upper row index + * @return string array iterator + */ + public static Iterator<String[]> getStringRowIterator(FrameBlock fb, int rl, int ru) { + return new StringRowIterator(fb, rl, ru); + } + + /** + * Get a row iterator over the frame where all selected fields are encoded as strings independent of their value + * types. + * + * @param fb The frame to iterate through + * @param rl lower row index + * @param ru upper row index + * @param cols column selection, 1-based + * @return string array iterator + */ + public static Iterator<String[]> getStringRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { + return new StringRowIterator(fb, rl, ru, cols); + } + + /** + * Get a row iterator over the frame where all selected fields are encoded as strings independent of their value + * types. 
+ * + * @param fb The frame to iterate through + * @param rl lower row index + * @param ru upper row index + * @param colID columnID, 1-based + * @return string array iterator + */ + public static Iterator<String[]> getStringRowIterator(FrameBlock fb, int rl, int ru, int colID) { + return new StringRowIterator(fb, rl, ru, new int[] {colID}); + } + + /** + * Get a row iterator over the frame where all fields are encoded as boxed objects according to their value types. + * + * @param fb The frame to iterate through + * @return object array iterator + */ + public static Iterator<Object[]> getObjectRowIterator(FrameBlock fb) { + return new ObjectRowIterator(fb, 0, fb.getNumRows()); + } + + /** + * Get a row iterator over the frame where all fields are encoded as boxed objects according to the value types of + * the provided target schema. + * + * @param fb The frame to iterate through + * @param schema target schema of objects + * @return object array iterator + */ + public static Iterator<Object[]> getObjectRowIterator(FrameBlock fb, ValueType[] schema) { + return new ObjectRowIterator(fb, 0, fb.getNumRows(), schema); + } + + /** + * Get a row iterator over the frame where all selected fields are encoded as boxed objects according to their value + * types. + * + * @param fb The frame to iterate through + * @param cols column selection, 1-based + * @return object array iterator + */ + public static Iterator<Object[]> getObjectRowIterator(FrameBlock fb, int[] cols) { + return new ObjectRowIterator(fb, 0, fb.getNumRows(), cols); + } + + /** + * Get a row iterator over the frame where all fields are encoded as boxed objects according to their value types. 
+ * + * @param fb The frame to iterate through + * @param rl lower row index + * @param ru upper row index + * @return object array iterator + */ + public static Iterator<Object[]> getObjectRowIterator(FrameBlock fb, int rl, int ru) { + return new ObjectRowIterator(fb, rl, ru); + } + + /** + * Get a row iterator over the frame where all selected fields are encoded as boxed objects according to their value + * types. + * + * @param fb The frame to iterate through + * @param rl lower row index + * @param ru upper row index + * @param cols column selection, 1-based + * @return object array iterator + */ + public static Iterator<Object[]> getObjectRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { + return new ObjectRowIterator(fb, rl, ru, cols); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/ObjectRowIterator.java b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/ObjectRowIterator.java new file mode 100644 index 0000000000..584a6a3173 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/ObjectRowIterator.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.frame.data.iterators; + +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.util.UtilFunctions; + +public class ObjectRowIterator extends RowIterator<Object> { + private final ValueType[] _tgtSchema; + + public ObjectRowIterator(FrameBlock fb, int rl, int ru) { + this(fb, rl, ru, UtilFunctions.getSeqArray(1, fb.getNumColumns(), 1), null); + } + + public ObjectRowIterator(FrameBlock fb, int rl, int ru, ValueType[] schema) { + this(fb, rl, ru, UtilFunctions.getSeqArray(1, fb.getNumColumns(), 1), schema); + } + + public ObjectRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { + this(fb, rl, ru, cols, null); + } + + public ObjectRowIterator(FrameBlock fb, int rl, int ru, int[] cols, ValueType[] schema){ + super(fb, rl, ru, cols); + _tgtSchema = schema; + } + + @Override + protected Object[] createRow(int size) { + return new Object[size]; + } + + @Override + public Object[] next() { + for(int j = 0; j < _cols.length; j++) + _curRow[j] = getValue(_curPos, _cols[j] - 1); + _curPos++; + return _curRow; + } + + private Object getValue(int i, int j) { + Object val = _fb.get(i, j); + if(_tgtSchema != null) + val = UtilFunctions.objectToObject(_tgtSchema[j], val); + return val; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/iterators/RowIterator.java b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/RowIterator.java new file mode 100644 index 0000000000..d62a5d6cd9 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/RowIterator.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.frame.data.iterators; + +import java.util.Iterator; + +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.util.UtilFunctions; + +public abstract class RowIterator<T> implements Iterator<T[]> { + protected final FrameBlock _fb; + protected final int[] _cols; + protected final T[] _curRow; + protected final int _maxPos; + protected int _curPos = -1; + + protected RowIterator(FrameBlock fb, int rl, int ru) { + this(fb, rl, ru, UtilFunctions.getSeqArray(1, fb.getNumColumns(), 1)); + } + + protected RowIterator(FrameBlock fb, int rl, int ru, int[] cols) { + _fb = fb; + _curRow = createRow(cols.length); + _cols = cols; + _maxPos = ru; + _curPos = rl; + } + + @Override + public boolean hasNext() { + return(_curPos < _maxPos); + } + + @Override + public void remove() { + throw new RuntimeException("RowIterator.remove is unsupported!"); + } + + protected abstract T[] createRow(int size); +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/StringRowIterator.java similarity index 55% copy from src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java copy to src/main/java/org/apache/sysds/runtime/frame/data/iterators/StringRowIterator.java index f181e86b86..3647ce5106 100644 --- 
a/src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/iterators/StringRowIterator.java @@ -17,18 +17,31 @@ * under the License. */ -package org.apache.sysds.runtime.frame.data; +package org.apache.sysds.runtime.frame.data.iterators; -import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.FrameBlock; -@SuppressWarnings({"rawtypes"}) -public interface FrameUtil { - public static Array[] add(Array[] ar, Array e) { - if(ar == null) - return new Array[] {e}; - Array[] ret = new Array[ar.length + 1]; - System.arraycopy(ar, 0, ret, 0, ar.length); - ret[ar.length] = e; - return ret; +public class StringRowIterator extends RowIterator<String> { + public StringRowIterator(FrameBlock fb, int rl, int ru) { + super(fb, rl, ru); } -} + + public StringRowIterator(FrameBlock fb, int rl, int ru, int[] cols) { + super(fb, rl, ru, cols); + } + + @Override + protected String[] createRow(int size) { + return new String[size]; + } + + @Override + public String[] next( ) { + for( int j=0; j<_cols.length; j++ ) { + Object tmp = _fb.get(_curPos, _cols[j]-1); + _curRow[j] = (tmp!=null) ? 
tmp.toString() : null; + } + _curPos++; + return _curRow; + } +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ConvertFrameBlockToIJVLines.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ConvertFrameBlockToIJVLines.java index ec27faf32e..bb55c6451f 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ConvertFrameBlockToIJVLines.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/functions/ConvertFrameBlockToIJVLines.java @@ -23,6 +23,7 @@ import java.util.Iterator; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import scala.Tuple2; @@ -49,7 +50,7 @@ public class ConvertFrameBlockToIJVLines implements FlatMapFunction<Tuple2<Long, //convert frame block to list of ijv cell triples StringBuilder sb = new StringBuilder(); - Iterator<String[]> iter = block.getStringRowIterator(); + Iterator<String[]> iter = IteratorFactory.getStringRowIterator(block); for( int i=0; iter.hasNext(); i++ ) { //for all rows String rowIndex = Long.toString(rowoffset + i); String[] row = iter.next(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java index b93e1cfc0c..d26b2fcd9e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDConverterUtils.java @@ -52,6 +52,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.frame.data.FrameBlock; +import 
org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.instructions.spark.data.FrameReblockBuffer; import org.apache.sysds.runtime.instructions.spark.data.SerLongWritable; import org.apache.sysds.runtime.instructions.spark.data.SerText; @@ -685,7 +686,7 @@ public class FrameRDDConverterUtils } //handle Frame block data - Iterator<String[]> iter = blk.getStringRowIterator(); + Iterator<String[]> iter = IteratorFactory.getStringRowIterator(blk); while( iter.hasNext() ) { String[] row = iter.next(); for(int j=0; j<row.length; j++) { diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterJSONL.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterJSONL.java index 919613c7de..be20d7045c 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterJSONL.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterJSONL.java @@ -19,21 +19,22 @@ package org.apache.sysds.runtime.io; +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.util.Iterator; +import java.util.Map; + import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapred.JobConf; -import org.apache.wink.json4j.JSONException; -import org.apache.wink.json4j.JSONObject; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.util.HDFSTool; - -import java.io.BufferedWriter; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.util.Iterator; -import java.util.Map; +import org.apache.wink.json4j.JSONException; +import org.apache.wink.json4j.JSONObject; public class FrameWriterJSONL { @@ -73,7 +74,7 @@ public class FrameWriterJSONL BufferedWriter bufferedWriter = new BufferedWriter(new OutputStreamWriter(fileSystem.create(path, 
true))); try { - Iterator<String[]> stringRowIterator = src.getStringRowIterator(lowerRowBound, upperRowBound); + Iterator<String[]> stringRowIterator = IteratorFactory.getStringRowIterator(src, lowerRowBound, upperRowBound); while (stringRowIterator.hasNext()) { String[] row = stringRowIterator.next(); bufferedWriter.write(formatToJSONString(row, schemaMap) + "\n"); diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterProto.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterProto.java index 40af18b20e..c5efef6ba3 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterProto.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterProto.java @@ -30,6 +30,7 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.protobuf.SysdsProtos; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.util.HDFSTool; public class FrameWriterProto extends FrameWriter { @@ -66,7 +67,7 @@ public class FrameWriterProto extends FrameWriter { OutputStream outputStream = fileSystem.create(path, true); SysdsProtos.Frame.Builder frameBuilder = SysdsProtos.Frame.newBuilder(); try { - Iterator<String[]> stringRowIterator = src.getStringRowIterator(lowerRowBound, upperRowBound); + Iterator<String[]> stringRowIterator = IteratorFactory.getStringRowIterator(src, lowerRowBound, upperRowBound); while(stringRowIterator.hasNext()) { String[] row = stringRowIterator.next(); frameBuilder.addRowsBuilder().addAllColumnData(Arrays.asList(row)); diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java index 308fd0038c..92abe0932d 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java @@ -30,6 +30,7 @@ import 
org.apache.hadoop.mapred.JobConf; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.transform.TfUtils; import org.apache.sysds.runtime.util.HDFSTool; @@ -122,7 +123,7 @@ public class FrameWriterTextCSV extends FrameWriter } // Write data lines - Iterator<String[]> iter = src.getStringRowIterator(rl, ru); + Iterator<String[]> iter = IteratorFactory.getStringRowIterator(src, rl, ru); while( iter.hasNext() ) { //write row chunk-wise to prevent OOM on large number of columns String[] row = iter.next(); diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCell.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCell.java index 546bbe25dd..5e5e131987 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCell.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCell.java @@ -30,6 +30,7 @@ import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.util.HDFSTool; /** @@ -108,7 +109,7 @@ public class FrameWriterTextCell extends FrameWriter } //write frame row range to output - Iterator<String[]> iter = src.getStringRowIterator(rl, ru); + Iterator<String[]> iter = IteratorFactory.getStringRowIterator(src, rl, ru); for( int i=rl; iter.hasNext(); i++ ) { //for all rows String rowIndex = Integer.toString(i+1); String[] row = iter.next(); diff --git a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java index e11dfb2c0f..545ba78e38 100644 --- 
a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java +++ b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java @@ -50,6 +50,7 @@ import org.apache.sysds.runtime.data.DenseBlockFactory; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.instructions.cp.BooleanObject; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.Data; @@ -701,7 +702,7 @@ public class DataConverter { public static String[][] convertToStringFrame(FrameBlock frame) { String[][] ret = new String[frame.getNumRows()][]; - Iterator<String[]> iter = frame.getStringRowIterator(); + Iterator<String[]> iter = IteratorFactory.getStringRowIterator(frame); for( int i=0; iter.hasNext(); i++ ) { //deep copy output rows due to internal reuse ret[i] = iter.next().clone(); @@ -1276,7 +1277,7 @@ public class DataConverter { if (decimal >= 0) df.setMinimumFractionDigits(decimal); - Iterator<Object[]> iter = fb.getObjectRowIterator(0, rowLength); + Iterator<Object[]> iter = IteratorFactory.getObjectRowIterator(fb, 0, rowLength); while( iter.hasNext() ) { Object[] row = iter.next(); for( int j=0; j<colLength; j++ ) { diff --git a/src/test/java/org/apache/sysds/test/applications/EntityResolutionClusteringTest.java b/src/test/java/org/apache/sysds/test/applications/EntityResolutionClusteringTest.java index 84b7c69c26..fe2041e15c 100644 --- a/src/test/java/org/apache/sysds/test/applications/EntityResolutionClusteringTest.java +++ b/src/test/java/org/apache/sysds/test/applications/EntityResolutionClusteringTest.java @@ -19,13 +19,6 @@ package org.apache.sysds.test.applications; -import org.apache.sysds.common.Types; -import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.test.AutomatedTestBase; -import 
org.apache.sysds.test.TestConfiguration; -import org.junit.Assert; -import org.junit.Test; - import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; @@ -33,6 +26,14 @@ import java.nio.file.StandardCopyOption; import java.util.Arrays; import java.util.Iterator; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.junit.Assert; +import org.junit.Test; + public class EntityResolutionClusteringTest extends AutomatedTestBase { private final static String TEST_NAME = "EntityResolutionClustering"; private final static String TEST_DIR = "applications/entity_resolution/clustering/"; @@ -99,8 +100,8 @@ public class EntityResolutionClusteringTest extends AutomatedTestBase { FrameBlock predictedPairs = readDMLFrameFromHDFS("B", Types.FileFormat.CSV); - Iterator<Object[]> expectedIter = expectedPairs.getObjectRowIterator(); - Iterator<Object[]> predictedIter = predictedPairs.getObjectRowIterator(); + Iterator<Object[]> expectedIter = IteratorFactory.getObjectRowIterator(expectedPairs); + Iterator<Object[]> predictedIter = IteratorFactory.getObjectRowIterator(predictedPairs); int row = 0; while (expectedIter.hasNext()) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java b/src/test/java/org/apache/sysds/test/component/frame/FrameIterators.java similarity index 61% copy from src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java copy to src/test/java/org/apache/sysds/test/component/frame/FrameIterators.java index f181e86b86..48396f283e 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameIterators.java @@ -17,18 +17,20 @@ * under the License. 
*/ -package org.apache.sysds.runtime.frame.data; +package org.apache.sysds.test.component.frame; -import org.apache.sysds.runtime.frame.data.columns.Array; +import java.util.Iterator; -@SuppressWarnings({"rawtypes"}) -public interface FrameUtil { - public static Array[] add(Array[] ar, Array e) { - if(ar == null) - return new Array[] {e}; - Array[] ret = new Array[ar.length + 1]; - System.arraycopy(ar, 0, ret, 0, ar.length); - ret[ar.length] = e; - return ret; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class FrameIterators { + @Test(expected = RuntimeException.class) + public void testRemove() { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, 10, 32); + Iterator<Object[]> it = IteratorFactory.getObjectRowIterator(fb); + it.remove(); } } diff --git a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java index df46b3f5ee..f144e03984 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeDecodeTest.java @@ -28,6 +28,7 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.io.FrameReader; import org.apache.sysds.runtime.io.FrameReaderFactory; import org.apache.sysds.test.AutomatedTestBase; @@ -190,7 +191,7 @@ public class TransformFederatedEncodeDecodeTest extends AutomatedTestBase { // compare matrices (values recoded to identical codes) FrameBlock FO = 
reader.readFrameFromHDFS(output("FO1"), 15, 2); HashMap<String, Long> cFA = getCounts(A, B); - Iterator<String[]> iterFO = FO.getStringRowIterator(); + Iterator<String[]> iterFO = IteratorFactory.getStringRowIterator(FO); while(iterFO.hasNext()) { String[] row = iterFO.next(); Double expected = (double) cFA.get(row[1]); diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/FrameReadMetaTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/FrameReadMetaTest.java index c12fb420bd..77cd2aa202 100644 --- a/src/test/java/org/apache/sysds/test/functions/jmlc/FrameReadMetaTest.java +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/FrameReadMetaTest.java @@ -25,13 +25,13 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; -import org.junit.Assert; -import org.junit.Test; +import org.apache.commons.lang.ArrayUtils; import org.apache.sysds.api.jmlc.Connection; import org.apache.sysds.api.jmlc.PreparedScript; import org.apache.sysds.api.jmlc.ResultVariables; import org.apache.sysds.lops.Lop; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.transform.TfUtils.TfMethod; import org.apache.sysds.runtime.transform.meta.TfMetaUtils; @@ -39,7 +39,8 @@ import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; -import org.apache.commons.lang.ArrayUtils; +import org.junit.Assert; +import org.junit.Test; public class FrameReadMetaTest extends AutomatedTestBase { @@ -160,7 +161,7 @@ public class FrameReadMetaTest extends AutomatedTestBase List<Integer> collist = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(spec, M.getColumnNames(), TfMethod.RECODE.toString()))); HashMap<String,Long>[] ret = new HashMap[M.getNumColumns()]; - 
Iterator<Object[]> iter = M.getObjectRowIterator(); + Iterator<Object[]> iter = IteratorFactory.getObjectRowIterator(M); while( iter.hasNext() ) { Object[] tmp = iter.next(); for( int j=0; j<tmp.length; j++ ) diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeDecodeTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeDecodeTest.java index 6144aaf42c..b0b9a3a7b3 100644 --- a/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeDecodeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformEncodeDecodeTest.java @@ -22,12 +22,11 @@ package org.apache.sysds.test.functions.transform; import java.util.HashMap; import java.util.Iterator; -import org.junit.Assert; -import org.junit.Test; import org.apache.sysds.common.Types.ExecMode; -import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.io.FrameReader; import org.apache.sysds.runtime.io.FrameReaderFactory; import org.apache.sysds.runtime.io.FrameWriter; @@ -36,6 +35,8 @@ import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; /** * @@ -116,7 +117,7 @@ public class TransformEncodeDecodeTest extends AutomatedTestBase FrameReader reader = FrameReaderFactory.createFrameReader(FileFormat.safeValueOf(fmt)); FrameBlock FO = reader.readFrameFromHDFS(output("FO"), 16, 2); HashMap<String,Long> cFA = getCounts(FA, 1); - Iterator<String[]> iterFO = FO.getStringRowIterator(); + Iterator<String[]> iterFO = IteratorFactory.getStringRowIterator(FO); while( iterFO.hasNext() ) { String[] row = iterFO.next(); 
Double expected = (double)cFA.get(row[1]);
