This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit e2e560a2d6382d9b290f0e8f25e41772ad6c91ec Author: baunsgaard <[email protected]> AuthorDate: Wed Dec 21 13:18:27 2022 +0100 [SYSTEMDS-3272] applySchema FrameBlock parallel This commit improves the performance of applySchema through parallelization, from 0.8-0.9 sec to 0.169 sec on a 64k x 2k FrameBlock. Also included are tests that give 100% test coverage of applySchema. --- .../sysds/runtime/frame/data/FrameBlock.java | 30 ++--- .../frame/data/lib/FrameLibApplySchema.java | 106 +++++++++++++++++ .../frame/data/lib/FrameLibDetectSchema.java | 2 +- .../test/component/frame/FrameApplySchema.java | 129 +++++++++++++++++++-- 4 files changed, 239 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 7d2fbb5797..80c3508fea 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -51,12 +51,14 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.codegen.CodegenUtils; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; +import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.frame.data.lib.FrameFromMatrixBlock; +import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema; import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema; import org.apache.sysds.runtime.functionobjects.ValueComparisonFunction; import 
org.apache.sysds.runtime.instructions.cp.BooleanObject; @@ -151,6 +153,14 @@ public class FrameBlock implements CacheBlock<FrameBlock>, Externalizable { appendRow(data[i]); } + public FrameBlock(ValueType[] schema, String[] colNames, ColumnMetadata[] meta, Array<?>[] data ){ + _numRows = data[0].size(); + _schema = schema; + _colnames = colNames; + _colmeta = meta; + _coldata = data; + } + /** * Get the number of rows of the frame block. * @@ -279,6 +289,10 @@ public class FrameBlock implements CacheBlock<FrameBlock>, Externalizable { return _colmeta[c]; } + public Array<?>[] getColumns(){ + return _coldata; + } + public boolean isColumnMetadataDefault() { boolean ret = true; for( int j=0; j<getNumColumns() && ret; j++ ) @@ -1733,21 +1747,7 @@ public class FrameBlock implements CacheBlock<FrameBlock>, Externalizable { * @return A new FrameBlock with the schema applied. */ public FrameBlock applySchema(ValueType[] schema) { - if(schema.length != _schema.length) - throw new DMLRuntimeException(// - "Invalid apply schema with different number of columns expected: " + _schema.length + " got: " - + schema.length); - FrameBlock ret = new FrameBlock(); - final int nCol = getNumColumns(); - ret._numRows = getNumRows(); - ret._schema = schema; - ret._colnames = _colnames; - ret._colmeta = _colmeta; - ret._coldata = new Array[nCol]; - for(int i = 0; i < nCol; i++) - ret._coldata[i] = _coldata[i].changeType(schema[i]); - ret._msize = -1; - return ret; + return FrameLibApplySchema.applySchema(this, schema, InfrastructureAnalyzer.getLocalParallelism()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java new file mode 100644 index 0000000000..57c79a49a9 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.frame.data.lib; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.stream.IntStream; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; +import org.apache.sysds.runtime.util.CommonThreadPool; + +public class FrameLibApplySchema { + + protected static final Log LOG = LogFactory.getLog(FrameLibApplySchema.class.getName()); + + private final FrameBlock fb; + private final ValueType[] schema; + private final int nCol; + private final Array<?>[] columnsIn; + private final Array<?>[] columnsOut; + + /** + * Method to create a new FrameBlock where the given schema is applied, k is parallelization degree. + * + * @param fb The input block to apply schema to + * @param schema The schema to apply + * @param k The parallelization degree + * @return A new FrameBlock allocated with new arrays. 
+ */ + public static FrameBlock applySchema(FrameBlock fb, ValueType[] schema, int k) { + return new FrameLibApplySchema(fb, schema).apply(k); + } + + private FrameLibApplySchema(FrameBlock fb, ValueType[] schema) { + this.fb = fb; + this.schema = schema; + verifySize(); + nCol = fb.getNumColumns(); + columnsIn = fb.getColumns(); + columnsOut = new Array<?>[nCol]; + + } + + private FrameBlock apply(int k) { + if(k <= 1 || nCol == 1) + applySingleThread(); + else + applyMultiThread(k); + + final String[] colNames = fb.getColumnNames(false); + final ColumnMetadata[] meta = fb.getColumnMetadata(); + return new FrameBlock(schema, colNames, meta, columnsOut); + } + + private void applySingleThread() { + for(int i = 0; i < nCol; i++) + columnsOut[i] = columnsIn[i].changeType(schema[i]); + } + + private void applyMultiThread(int k) { + final ExecutorService pool = CommonThreadPool.get(k); + try { + + pool.submit(() -> { + IntStream.rangeClosed(0, nCol - 1).parallel() // parallel columns + .forEach(x -> columnsOut[x] = columnsIn[x].changeType(schema[x])); + }).get(); + + pool.shutdown(); + } + catch(InterruptedException | ExecutionException e) { + pool.shutdown(); + throw new DMLRuntimeException("Failed to combine column groups", e); + } + } + + private void verifySize() { + if(schema.length != fb.getSchema().length) + throw new DMLRuntimeException(// + "Invalid apply schema with different number of columns expected: " + fb.getSchema().length + " got: " + + schema.length); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java index 8a37ac8f4d..3219617f27 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java @@ -34,7 +34,7 @@ import org.apache.sysds.runtime.util.CommonThreadPool; import 
org.apache.sysds.runtime.util.UtilFunctions; public final class FrameLibDetectSchema { - // private static final Log LOG = LogFactory.getLog(FrameBlock.class.getName()); + // private static final Log LOG = LogFactory.getLog(FrameLibDetectSchema.class.getName()); private FrameLibDetectSchema() { // private constructor diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java b/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java index a843d9f916..fb0262daa3 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java @@ -25,37 +25,142 @@ import static org.junit.Assert.fail; import java.util.Random; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.runtime.frame.data.columns.BooleanArray; +import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema; import org.junit.Test; public class FrameApplySchema { + @Test + public void testApplySchemaStringToBoolean() { + try { + + FrameBlock fb = genStringContainingBoolean(10, 2); + ValueType[] schema = new ValueType[] {ValueType.BOOLEAN, ValueType.BOOLEAN}; + FrameBlock ret = fb.applySchema(schema); + assertTrue(ret.getColumn(0).getValueType() == ValueType.BOOLEAN); + assertTrue(ret.getColumn(1).getValueType() == ValueType.BOOLEAN); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } @Test - public void testApplySchema(){ - try{ + public void testApplySchemaStringToInt() { + try { + FrameBlock fb = genStringContainingInteger(10, 2); + ValueType[] schema = new ValueType[] {ValueType.INT32, ValueType.INT32}; + FrameBlock ret = fb.applySchema(schema); + assertTrue(ret.getColumn(0).getValueType() == ValueType.INT32); + assertTrue(ret.getColumn(1).getValueType() == ValueType.INT32); + } + catch(Exception e) { + 
e.printStackTrace(); + fail(e.getMessage()); + } + } - FrameBlock fb = genBoolean(10, 2); - ValueType[] schema = new ValueType[]{ValueType.BOOLEAN,ValueType.BOOLEAN}; + @Test + public void testApplySchemaStringToIntSingleCol() { + try { + FrameBlock fb = genStringContainingInteger(10, 1); + ValueType[] schema = new ValueType[] {ValueType.INT32}; FrameBlock ret = fb.applySchema(schema); - assertTrue(ret.getColumn(0) instanceof BooleanArray); - assertTrue(ret.getColumn(1) instanceof BooleanArray); + assertTrue(ret.getColumn(0).getValueType() == ValueType.INT32); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testApplySchemaStringToIntDirectCallSingleThread() { + try { + FrameBlock fb = genStringContainingInteger(10, 3); + ValueType[] schema = new ValueType[] {ValueType.INT32, ValueType.INT32, ValueType.INT32}; + FrameBlock ret = FrameLibApplySchema.applySchema(fb, schema, 1); + for(int i = 0; i < ret.getNumColumns(); i++) + assertTrue(ret.getColumn(i).getValueType() == ValueType.INT32); + + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testApplySchemaStringToIntDirectCallMultiThread() { + try { + FrameBlock fb = genStringContainingInteger(10, 3); + ValueType[] schema = new ValueType[] {ValueType.INT32, ValueType.INT32, ValueType.INT32}; + FrameBlock ret = FrameLibApplySchema.applySchema(fb, schema, 3); + for(int i = 0; i < ret.getNumColumns(); i++) + assertTrue(ret.getColumn(i).getValueType() == ValueType.INT32); + + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); } - catch(Exception e){ + } + + + @Test + public void testApplySchemaStringToIntDirectCallMultiThreadSingleCol() { + try { + FrameBlock fb = genStringContainingInteger(10, 1); + ValueType[] schema = new ValueType[] {ValueType.INT32}; + FrameBlock ret = FrameLibApplySchema.applySchema(fb, schema, 3); + for(int i = 0; i < ret.getNumColumns(); i++) + 
assertTrue(ret.getColumn(i).getValueType() == ValueType.INT32); + + } + catch(Exception e) { e.printStackTrace(); fail(e.getMessage()); } } - private FrameBlock genBoolean(int row, int col){ + @Test(expected = DMLRuntimeException.class) + public void testInvalidInput() { + FrameBlock fb = genStringContainingInteger(10, 10); + ValueType[] schema = new ValueType[] {ValueType.INT32, ValueType.INT32, ValueType.INT32}; + FrameLibApplySchema.applySchema(fb, schema, 3); + } + + @Test(expected = DMLRuntimeException.class) + public void testInvalidInput2() { + FrameBlock fb = genStringContainingInteger(10, 3); + ValueType[] schema = new ValueType[] {ValueType.UNKNOWN, ValueType.INT32, ValueType.INT32}; + FrameLibApplySchema.applySchema(fb, schema, 3); + } + + private FrameBlock genStringContainingInteger(int row, int col) { FrameBlock ret = new FrameBlock(); Random r = new Random(31); - for(int c = 0; c < col; c ++){ + for(int c = 0; c < col; c++) { String[] column = new String[row]; - for(int i = 0; i < row; i ++) + for(int i = 0; i < row; i++) + column[i] = "" + r.nextInt(); + + ret.appendColumn(column); + } + return ret; + } + + private FrameBlock genStringContainingBoolean(int row, int col) { + FrameBlock ret = new FrameBlock(); + Random r = new Random(31); + for(int c = 0; c < col; c++) { + String[] column = new String[row]; + for(int i = 0; i < row; i++) column[i] = "" + r.nextBoolean(); - + ret.appendColumn(column); } return ret;
