This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new acd878413d [MINOR] Changes to internal metadata and Sparse safe Ceil
acd878413d is described below
commit acd878413dd6bf333ec628c732d38a709a460642
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Wed Sep 6 16:09:46 2023 +0200
[MINOR] Changes to internal metadata and Sparse safe Ceil
This commit changes some of the internal metadata handling to make
MatrixObject allocation more consistent. I had previously encountered
a bug where federated workers sometimes would not work because metadata
was allocated incorrectly.
This commit makes the MatrixObject perform an acquire (modify) and a
release on instantiation if it is given a MatrixBlock.
The bug surfaced after making ceil (and floor) sparse-safe.
Closes #1900
---
.../controlprogram/caching/MatrixObject.java | 16 +++--
.../controlprogram/context/ExecutionContext.java | 24 +++----
.../federated/FederatedResponse.java | 84 +++++++++++++---------
.../paramserv/FederatedPSControlThread.java | 9 +--
.../runtime/matrix/operators/UnaryOperator.java | 3 +-
.../sysds/runtime/meta/MatrixCharacteristics.java | 30 ++++++--
.../org/apache/sysds/runtime/meta/MetaDataAll.java | 2 +-
.../org/apache/sysds/test/AutomatedTestBase.java | 21 ++++--
src/test/java/org/apache/sysds/test/TestUtils.java | 20 ++++--
9 files changed, 127 insertions(+), 82 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 7d7a7743c3..697dfb6719 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -96,7 +96,7 @@ public class MatrixObject extends CacheableData<MatrixBlock> {
* @param file file name
*/
public MatrixObject(ValueType vt, String file) {
- this(vt, file, null); // HDFS file path
+ this(vt, file, null, null); // HDFS file path
}
/**
@@ -107,11 +107,7 @@ public class MatrixObject extends
CacheableData<MatrixBlock> {
* @param mtd metadata
*/
public MatrixObject(ValueType vt, String file, MetaData mtd) {
- super(DataType.MATRIX, vt);
- _metaData = mtd;
- _hdfsFileName = file;
- _cache = null;
- _data = null;
+ this(vt, file, mtd, null);
}
/**
@@ -128,7 +124,13 @@ public class MatrixObject extends
CacheableData<MatrixBlock> {
_metaData = mtd;
_hdfsFileName = file;
_cache = null;
- _data = data;
+ if(data != null) {
+ acquireModify(data); // hand the block over via acquireModify/release instead of a raw _data assignment
+ release();
+ }
+ else {
+ _data = null; // no initial block: assign the field (assigning the parameter 'data' would be a no-op)
+ }
}
/**
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
index 1a3a0ecaed..d98827a24e 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java
@@ -668,25 +668,17 @@ public class ExecutionContext {
}
public static MatrixObject createMatrixObject(MatrixBlock mb) {
- MatrixObject ret = new MatrixObject(Types.ValueType.FP64,
- OptimizerUtils.getUniqueTempFileName());
- ret.acquireModify(mb);
- ret.setMetaData(new MetaDataFormat(new MatrixCharacteristics(
- mb.getNumRows(), mb.getNumColumns()),
FileFormat.BINARY));
- ret.getMetaData().getDataCharacteristics()
- .setBlocksize(ConfigurationManager.getBlocksize());
- ret.release();
- return ret;
+ final long nRow = mb.getNumRows(), nCol = mb.getNumColumns();
+ final int bz = ConfigurationManager.getBlocksize();
+ MetaData md = new MetaDataFormat(new
MatrixCharacteristics(nRow, nCol, bz), FileFormat.BINARY);
+ return new MatrixObject(Types.ValueType.FP64,
OptimizerUtils.getUniqueTempFileName(), md, mb);
}
public static MatrixObject createMatrixObject(DataCharacteristics dc) {
- MatrixObject ret = new MatrixObject(Types.ValueType.FP64,
- OptimizerUtils.getUniqueTempFileName());
- ret.setMetaData(new MetaDataFormat(new MatrixCharacteristics(
- dc.getRows(), dc.getCols()), FileFormat.BINARY));
- ret.getMetaData().getDataCharacteristics()
- .setBlocksize(ConfigurationManager.getBlocksize());
- return ret;
+ final long nRow = dc.getRows(), nCol = dc.getCols();
+ final int bz = dc.getBlocksize() > 0 ? dc.getBlocksize() : ConfigurationManager.getBlocksize(); // any non-positive (unset/invalid) blocksize falls back to the configured one
+ MetaData md = new MetaDataFormat(new
MatrixCharacteristics(nRow, nCol, bz), FileFormat.BINARY);
+ return new MatrixObject(Types.ValueType.FP64,
OptimizerUtils.getUniqueTempFileName(), md);
}
public static FrameObject createFrameObject(DataCharacteristics dc) {
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
index 89c98377bb..d87fdc9d2a 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
@@ -20,36 +20,35 @@
package org.apache.sysds.runtime.controlprogram.federated;
import java.io.Serializable;
+import java.util.Arrays;
import java.util.EnumMap;
import java.util.Map;
import java.util.concurrent.atomic.LongAdder;
import org.apache.commons.lang3.exception.ExceptionUtils;
-import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
public class FederatedResponse implements Serializable {
private static final long serialVersionUID = 3142180026498695091L;
-
+
public enum ResponseType {
- SUCCESS,
- SUCCESS_EMPTY,
- ERROR,
+ SUCCESS, SUCCESS_EMPTY, ERROR,
}
-
+
private ResponseType _status;
private Object[] _data;
- private Map<PrivacyLevel,LongAdder> checkedConstraints;
+ private Map<PrivacyLevel, LongAdder> checkedConstraints;
private transient LineageItem _linItem = null; // not included in
serialized object
-
+
public FederatedResponse(ResponseType status) {
this(status, null, null);
}
-
+
public FederatedResponse(ResponseType status, Object[] data) {
this(status, data, null);
}
@@ -57,7 +56,7 @@ public class FederatedResponse implements Serializable {
public FederatedResponse(ResponseType status, Object[] data,
LineageItem linItem) {
_status = status;
_data = data;
- if( _status == ResponseType.SUCCESS && data == null )
+ if(_status == ResponseType.SUCCESS && data == null)
_status = ResponseType.SUCCESS_EMPTY;
_linItem = linItem;
}
@@ -73,23 +72,24 @@ public class FederatedResponse implements Serializable {
_status = ResponseType.SUCCESS_EMPTY;
_linItem = linItem;
}
-
+
public boolean isSuccessful() {
return _status != ResponseType.ERROR;
}
-
+
public String getErrorMessage() {
- if (_data[0] instanceof Throwable )
- return ExceptionUtils.getStackTrace( (Throwable)
_data[0] );
- else if (_data[0] instanceof String)
+ if(_data[0] instanceof Throwable)
+ return ExceptionUtils.getStackTrace((Throwable)
_data[0]);
+ else if(_data[0] instanceof String)
return (String) _data[0];
- else return "No readable error message";
+ else
+ return "No readable error message";
}
-
+
public Object[] getData() throws Exception {
updateCheckedConstraintsLog();
- if ( !isSuccessful() )
- throwExceptionFromResponse();
+ if(!isSuccessful())
+ throwExceptionFromResponse();
return _data;
}
@@ -98,49 +98,65 @@ public class FederatedResponse implements Serializable {
if(_data != null) {
for(Object obj : _data) {
if(obj instanceof CacheBlock)
- minBufferSize +=
((CacheBlock<?>)obj).getExactSerializedSize();
+ minBufferSize += ((CacheBlock<?>)
obj).getExactSerializedSize();
}
}
return minBufferSize;
}
/**
- * Checks the data object array for exceptions that occurred in the
federated worker
- * during handling of request.
- * @throws Exception the exception retrieved from the data object array
- * or DMLRuntimeException if no exception is provided by the federated
worker.
+ * Checks the data object array for exceptions that occurred in the
federated worker during handling of request.
+ *
+ * @throws Exception the exception retrieved from the data object array
or DMLRuntimeException if no exception is
+ * provided by the federated worker.
*/
public void throwExceptionFromResponse() throws Exception {
- for ( Object potentialException : _data){
- if (potentialException != null && (potentialException
instanceof Exception) ){
+ for(Object potentialException : _data) {
+ if(potentialException != null && (potentialException
instanceof Exception)) {
throw (Exception) potentialException;
}
}
String errorMessage = getErrorMessage();
- if (getErrorMessage() != "No readable error message")
+ if(!"No readable error message".equals(errorMessage)) // compare content, not references; reuse the computed message
throw new DMLRuntimeException(errorMessage);
else
- throw new DMLRuntimeException("Unknown runtime
exception in handling of federated request by federated worker.");
+ throw new DMLRuntimeException(
+ "Unknown runtime exception in handling of
federated request by federated worker.");
}
/**
- * Set checked privacy constraints in response if the provided map is
not empty.
- * If the map is empty, it means that no privacy constraints were found.
+ * Set checked privacy constraints in response if the provided map is
not empty. If the map is empty, it means that
+ * no privacy constraints were found.
+ *
* @param checkedConstraints map of checked constraints from the
PrivacyMonitor
*/
- public void setCheckedConstraints(Map<PrivacyLevel,LongAdder>
checkedConstraints){
- if ( checkedConstraints != null &&
!checkedConstraints.isEmpty() ){
+ public void setCheckedConstraints(Map<PrivacyLevel, LongAdder>
checkedConstraints) {
+ if(checkedConstraints != null && !checkedConstraints.isEmpty())
{
this.checkedConstraints = new
EnumMap<>(PrivacyLevel.class);
this.checkedConstraints.putAll(checkedConstraints);
}
}
- public void updateCheckedConstraintsLog(){
- if ( checkedConstraints != null &&
!checkedConstraints.isEmpty() )
+ public void updateCheckedConstraintsLog() {
+ if(checkedConstraints != null && !checkedConstraints.isEmpty())
CheckedConstraintsLog.addCheckedConstraints(checkedConstraints);
}
public LineageItem getLineageItem() {
return _linItem;
}
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(getClass().getSimpleName());
+ sb.append(" response:").append(_status);
+ sb.append("\ndata:\n").append(Arrays.toString(_data));
+ if(checkedConstraints != null) {
+ sb.append("\ncheckedConstraints:\n");
+ sb.append(checkedConstraints);
+ }
+
+ return sb.toString();
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
index c7cb14f392..2a9bfb190d 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/FederatedPSControlThread.java
@@ -559,7 +559,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
try {
Object[] responseData = udfResponse.get().getData();
- if(DMLScript.STATISTICS) {
+ if(tFedCommunication != null) {
long total = (long) tFedCommunication.stop();
long workerComputing = ((DoubleObject)
responseData[1]).getLongValue();
ParamServStatistics.accFedWorkerComputing(workerComputing);
@@ -569,7 +569,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
return (ListObject) responseData[0];
}
catch(Exception e) {
- if(DMLScript.STATISTICS)
+ if(tFedCommunication != null)
tFedCommunication.stop();
throw new DMLRuntimeException("FederatedLocalPSThread:
failed to execute UDF" + e.getMessage(), e);
}
@@ -624,7 +624,8 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
// recreate aggregation instruction and output if needed
Instruction aggregationInstruction = null;
DataIdentifier aggregationOutput = null;
- if(_localUpdate && _numBatchesToCompute > 1 | modelAvg)
{
+ boolean localAggregation = _localUpdate && (_numBatchesToCompute > 1 || modelAvg); // same as the old '_localUpdate && ((n > 1) | modelAvg)', with explicit grouping
+ if(localAggregation) {
func =
ec.getProgram().getFunctionProgramBlock(namespace, aggFunc, opt);
inputs = func.getInputParams();
outputs = func.getOutputParams();
@@ -666,7 +667,7 @@ public class FederatedPSControlThread extends PSWorker
implements Callable<Void>
// update the local model with gradients if
needed
// FIXME ensure that with modelAvg we always
update the model
// (current fails due to missing aggregation
instruction)
- if(_localUpdate && (batchCounter <
_numBatchesToCompute - 1 | modelAvg) ) {
+ if(localAggregation && aggregationInstruction != null && aggregationOutput != null) { // NOTE(review): the old 'batchCounter < _numBatchesToCompute - 1' check is gone, so the last batch now also updates the local model when !modelAvg; confirm intended
// Invoke the aggregate function
aggregationInstruction.processInstruction(ec);
// Get the new model
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
index 8f5e2ff6d6..776d87c45c 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java
@@ -46,7 +46,8 @@ public class UnaryOperator extends MultiThreadedOperator
|| ((Builtin)p).bFunc==Builtin.BuiltinCode.SINH ||
((Builtin)p).bFunc==Builtin.BuiltinCode.TANH
|| ((Builtin)p).bFunc==Builtin.BuiltinCode.ROUND ||
((Builtin)p).bFunc==Builtin.BuiltinCode.ABS
|| ((Builtin)p).bFunc==Builtin.BuiltinCode.SQRT ||
((Builtin)p).bFunc==Builtin.BuiltinCode.SPROP
- || ((Builtin)p).bFunc==Builtin.BuiltinCode.LOG_NZ ||
((Builtin)p).bFunc==Builtin.BuiltinCode.SIGN) );
+ || ((Builtin)p).bFunc==Builtin.BuiltinCode.LOG_NZ ||
((Builtin)p).bFunc==Builtin.BuiltinCode.SIGN
+ || ((Builtin)p).bFunc==Builtin.BuiltinCode.CEIL ||
((Builtin)p).bFunc==Builtin.BuiltinCode.FLOOR )); // ceil(0) == floor(0) == 0, so zeros stay zero and both are sparse-safe
fn = p;
_numThreads = numThreads;
inplace = inPlace;
diff --git
a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
index bdc4b2111e..aeacc1c064 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.meta;
+import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -30,9 +31,13 @@ public class MatrixCharacteristics extends
DataCharacteristics
{
private static final long serialVersionUID = 8300479822915546000L;
+ /** Number of rows */
private long numRows = -1;
+ /** Number of columns */
private long numColumns = -1;
+ /** Number of non-zero values; -1 if unknown */
private long nonZero = -1;
+ /** Whether nonZero is only an upper bound rather than the exact count */
private boolean ubNnz = false;
public MatrixCharacteristics() {}
@@ -124,12 +129,6 @@ public class MatrixCharacteristics extends
DataCharacteristics
return Math.max((long) Math.ceil((double)getCols() /
getBlocksize()), 1);
}
- @Override
- public String toString() {
- return "["+numRows+" x "+numColumns+", nnz="+nonZero+"
("+ubNnz+")"
- +", blocks ("+_blocksize+" x "+_blocksize+")]";
- }
-
@Override
public DataCharacteristics setDimension(long nr, long nc) {
numRows = nr;
@@ -256,4 +255,23 @@ public class MatrixCharacteristics extends
DataCharacteristics
return Arrays.hashCode(new long[]{
numRows, numColumns, _blocksize, nonZero});
}
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder(30);
+ sb.append("[");
+ sb.append(numRows);
+ sb.append(" x ");
+ sb.append(numColumns);
+ sb.append(", nnz=");
+ sb.append(nonZero);
+ sb.append(" (");
+ sb.append(ubNnz);
+ sb.append("), blocks (");
+ sb.append(_blocksize);
+ sb.append(" x ");
+ sb.append(_blocksize);
+ sb.append(")]");
+ return sb.toString();
+ }
}
diff --git a/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
b/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
index 8a139edf5e..43d8ac3840 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
@@ -153,7 +153,7 @@ public class MetaDataAll extends DataIdentifier {
boolean isValidName =
DataExpression.READ_VALID_MTD_PARAM_NAMES.contains(key);
if (!isValidName){ //wrong parameters always rejected
- raiseValidateError("MTD file " + " contains
invalid parameter name: " + key, false);
+ raiseValidateError("MTD file contains invalid
parameter name: " + key, false);
}
parseMetaDataParam(key, val);
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 5e64ddea07..9283fa15db 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -38,6 +38,7 @@ import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
@@ -1451,6 +1452,7 @@ public abstract class AutomatedTestBase {
throw new RuntimeException("Our tests should run faster
than 1000 sec each",e);
}
catch(Exception e){
+ fail(org.apache.commons.lang3.exception.ExceptionUtils.getStackTrace(e)); // keep the trace; e.getMessage() may be null and drops the cause
throw new RuntimeException(e);
}
finally{
@@ -1541,18 +1543,23 @@ public abstract class AutomatedTestBase {
}
if(!exceptionExpected || (expectedException != null &&
!(e.getClass().equals(expectedException)))) {
StringBuilder errorMessage = new
StringBuilder();
- errorMessage.append("\nfailed to run script: "
+ executionFile);
+ String base = "\nfailed to run script: " +
executionFile;
+ errorMessage.append(base);
errorMessage.append("\nStandard Out:");
if(outputBuffering)
errorMessage.append("\n" + buff);
- errorMessage.append("\nStackTrace:");
- errorMessage.append(getStackTraceString(e, 0));
- fail(errorMessage.toString());
+ // the stack trace is printed to stderr by e.printStackTrace() below,
+ // so the assertion message only carries the failing script name (base)
+ LOG.error(errorMessage);
+ e.printStackTrace();
+ fail(base);
}
}
- if(outputBuffering) {
- System.out.flush();
- System.setOut(old);
+ finally{
+ if(outputBuffering) {
+ System.out.flush();
+ System.setOut(old);
+ }
}
return buff;
}
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java
b/src/test/java/org/apache/sysds/test/TestUtils.java
index 1b6703b29b..c90318b448 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -800,16 +800,16 @@ public class TestUtils {
fail("Invalid number of rows and cols in actual");
int countErrors = 0;
- for (int i = 0; i < rows && countErrors < 50; i++) {
- for (int j = 0; j < cols && countErrors < 50; j++) {
+ for (int i = 0; i < rows && countErrors < 10; i++) {
+ for (int j = 0; j < cols && countErrors < 10; j++) {
if (!compareCellValue(expectedMatrix[i][j],
actualMatrix[i][j], epsilon, true)) {
message += ("\n Expected: "
+expectedMatrix[i][j] +" vs actual: "+actualMatrix[i][j]+" at "+i+" "+j);
countErrors++;
}
}
}
- if(countErrors == 50){
- assertTrue(message+" \n More than 50 values are not
equal using epsilon " + epsilon, countErrors == 0);
+ if(countErrors == 10){
+ assertTrue(message+" \n At least 10 values are not equal using epsilon " + epsilon, countErrors == 0);
}else{
assertTrue(message+" \n" + countErrors + " values are
not in equal using epsilon " + epsilon, countErrors == 0);
}
@@ -1505,6 +1505,14 @@ public class TestUtils {
compareMatrices(m1, m2, tolerance, null);
}
+ /**
+ * Compare two matrices and fail if their sizes differ or values differ by more than the tolerance.
+ *
+ * @param m1 expected matrix
+ * @param m2 actual matrix
+ * @param tolerance tolerance
+ * @param message error message (may be null)
+ */
public static void compareMatrices(MatrixBlock m1, MatrixBlock m2,
double tolerance, String message) {
if(m1.getNumRows() != m2.getNumRows() || m1.getNumColumns() !=
m2.getNumColumns())
fail("Matrices are different sizes " + m1.getNumRows()
+ "," + m1.getNumColumns() + " vs " + m2.getNumRows()
@@ -3253,11 +3261,11 @@ public class TestUtils {
}
public static MatrixBlock round(MatrixBlock data) {
- return data.unaryOperations(new
UnaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.ROUND)), null);
+ return data.unaryOperations(new
UnaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.ROUND),2, true), null); // NOTE(review): 3rd arg appears to be inPlace=true; confirm callers accept 'data' being modified
}
public static MatrixBlock ceil(MatrixBlock data){
- return data.unaryOperations(new
UnaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.CEIL)), null);
+ return data.unaryOperations(new
UnaryOperator(Builtin.getBuiltinFnObject(BuiltinCode.CEIL),2, true), null); // NOTE(review): 3rd arg appears to be inPlace=true; confirm callers accept 'data' being modified
}
public static double[][] floor(double[][] data) {