This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c745550397 [SYSTEMDS-3778] New determinant function, kernels, rewrites
c745550397 is described below
commit c7455503978eab8808fb19ca7e77bfc9f6cc510d
Author: mike0609king <[email protected]>
AuthorDate: Sat Feb 1 18:38:41 2025 +0100
[SYSTEMDS-3778] New determinant function, kernels, rewrites
DIA WiSe 24/25 project
Closes #2196.
Co-authored-by: Lou Frizzi Maria Wagner
<[email protected]>
Co-authored-by: Laurits Sartorius <[email protected]>
---
.../java/org/apache/sysds/common/Builtins.java | 1 +
src/main/java/org/apache/sysds/common/Opcodes.java | 1 +
src/main/java/org/apache/sysds/common/Types.java | 4 +-
src/main/java/org/apache/sysds/hops/UnaryOp.java | 3 +-
.../RewriteAlgebraicSimplificationStatic.java | 85 ++++++++
.../sysds/parser/BuiltinFunctionExpression.java | 11 +
.../org/apache/sysds/parser/DMLTranslator.java | 1 +
.../apache/sysds/resource/cost/CPCostUtils.java | 1 +
.../instructions/cp/UnaryMatrixCPInstruction.java | 9 +-
.../sysds/runtime/matrix/data/LibCommonsMath.java | 162 ++++++++++++++-
.../test/functions/rewrite/RewriteDetTest.java | 225 +++++++++++++++++++++
.../sysds/test/functions/unary/matrix/DetTest.java | 115 +++++++++++
.../scripts/functions/rewrite/RewriteDetMixed.R | 34 ++++
.../scripts/functions/rewrite/RewriteDetMixed.dml | 28 +++
.../scripts/functions/rewrite/RewriteDetMult.R | 32 +++
.../scripts/functions/rewrite/RewriteDetMult.dml | 27 +++
.../functions/rewrite/RewriteDetScalarMatrixMult.R | 33 +++
.../rewrite/RewriteDetScalarMatrixMult.dml | 27 +++
.../functions/rewrite/RewriteDetTranspose.R | 31 +++
.../functions/rewrite/RewriteDetTranspose.dml | 26 +++
src/test/scripts/functions/unary/matrix/DetTest.R | 31 +++
.../scripts/functions/unary/matrix/DetTest.dml | 26 +++
22 files changed, 909 insertions(+), 4 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index ab7400df44..1a7ba207b8 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -113,6 +113,7 @@ public enum Builtins {
DECISIONTREEPREDICT("decisionTreePredict", true),
DECOMPRESS("decompress", false),
DEEPWALK("deepWalk", true),
+ DET("det", false),
DETECTSCHEMA("detectSchema", false),
DENIALCONSTRAINTS("denialConstraints", true),
DIFFERENCESTATISTICS("differenceStatistics", true),
diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java
b/src/main/java/org/apache/sysds/common/Opcodes.java
index 77813e1bfa..a878d3f0ac 100644
--- a/src/main/java/org/apache/sysds/common/Opcodes.java
+++ b/src/main/java/org/apache/sysds/common/Opcodes.java
@@ -165,6 +165,7 @@ public enum Opcodes {
STOP("stop", CPType.Unary),
INVERSE("inverse", CPType.Unary),
CHOLESKY("cholesky", CPType.Unary),
+ DET("det", CPType.Unary),
SPROP("sprop", CPType.Unary),
SIGMOID("sigmoid", CPType.Unary),
TYPEOF("typeOf", CPType.Unary),
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index 5b56afe02f..6e61f44615 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -539,7 +539,7 @@ public interface Types {
CAST_AS_FRAME, CAST_AS_LIST, CAST_AS_MATRIX, CAST_AS_SCALAR,
CAST_AS_BOOLEAN, CAST_AS_DOUBLE, CAST_AS_INT,
CEIL, CHOLESKY, COS, COSH, CUMMAX, CUMMIN, CUMPROD, CUMSUM,
- CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR,
INVERSE,
+ CUMSUMPROD, DET, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP,
FLOOR, INVERSE,
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT,
STOP, _EVICT,
SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
@@ -558,6 +558,7 @@ public interface Types {
public boolean isScalarOutput() {
return this == CAST_AS_SCALAR
+ || this == DET
|| this == NROW || this == NCOL
|| this == LENGTH || this == EXISTS
|| this == IQM || this == LINEAGE
@@ -579,6 +580,7 @@ public interface Types {
case CUMPROD: return
Opcodes.UCUMM.toString();
case CUMSUM: return
Opcodes.UCUMKP.toString();
case CUMSUMPROD: return
Opcodes.UCUMKPM.toString();
+ case DET: return
Opcodes.DET.toString();
case DETECTSCHEMA: return
Opcodes.DETECTSCHEMA.toString();
case MULT2: return
Opcodes.MULT2.toString();
case NOT: return
Opcodes.NOT.toString();
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java
b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index 1bda77530b..91f3a5ec58 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -512,7 +512,8 @@ public class UnaryOp extends MultiThreadedHop
//ensure cp exec type for single-node operations
if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op ==
OpOp1.STOP || _op == OpOp1.TYPEOF
- || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op ==
OpOp1.CHOLESKY || _op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA
+ || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op ==
OpOp1.CHOLESKY || _op == OpOp1.DET
+ ||_op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA
|| getInput().get(0).getDataType() == DataType.LIST ||
isMetadataOperation() )
{
_etype = ExecType.CP;
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index 6988ee5839..de14b5c5ec 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -159,12 +159,15 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
if(OptimizerUtils.ALLOW_OPERATOR_FUSION)
hi = simplifyMultiBinaryToBinaryOperation(hi);
//e.g., 1-X*Y -> X 1-* Y
hi = simplifyDistributiveBinaryOperation(hop, hi,
i);//e.g., (X-Y*X) -> (1-Y)*X
+ hi = simplifyTransposeInDetOperation(hop, hi, i);
//e.g., det(t(X)) -> det(X)
hi = simplifyBushyBinaryOperation(hop, hi, i);
//e.g., (X*(Y*(Z%*%v))) -> (X*Y)*(Z%*%v)
hi = simplifyUnaryAggReorgOperation(hop, hi, i);
//e.g., sum(t(X)) -> sum(X)
hi = removeUnnecessaryAggregates(hi);
//e.g., sum(rowSums(X)) -> sum(X)
hi = simplifyBinaryMatrixScalarOperation(hop, hi,
i);//e.g., as.scalar(X*s) -> as.scalar(X)*s;
hi = pushdownUnaryAggTransposeOperation(hop, hi, i);
//e.g., colSums(t(X)) -> t(rowSums(X))
hi = pushdownCSETransposeScalarOperation(hop, hi,
i);//e.g., a=t(X), b=t(X^2) -> a=t(X), b=t(X)^2 for CSE t(X)
+ hi = pushdownDetMultOperation(hop, hi, i);
//e.g., det(X%*%Y) -> det(X)*det(Y)
+ hi = pushdownDetScalarMatrixMultOperation(hop, hi, i);
//e.g., det(lambda*X) -> lambda^nrow(X)*det(X)
hi = pushdownSumBinaryMult(hop, hi, i);
//e.g., sum(lambda*X) -> lambda*sum(X)
hi = pullupAbs(hop, hi, i);
//e.g., abs(X)*abs(Y) --> abs(X*Y)
hi = simplifyUnaryPPredOperation(hop, hi, i);
//e.g., abs(ppred()) -> ppred(), others: round, ceil, floor
@@ -922,6 +925,29 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
return hi;
}
+	/**
+	 * det(t(X)) -> det(X)
+	 *
+	 * A square matrix and its transpose have the same determinant, so a
+	 * transpose that directly feeds a det can be bypassed.
+	 *
+	 * @param parent parent high-level operator
+	 * @param hi high-level operator
+	 * @param pos position of hi among the inputs of parent
+	 * @return the new det operator if the rewrite applied, otherwise hi
+	 */
+	private static Hop simplifyTransposeInDetOperation(Hop parent, Hop hi, int pos)
+	{
+		// only det(t(X)) qualifies
+		if( !HopRewriteUtils.isUnary(hi, OpOp1.DET) )
+			return hi;
+		Hop transpose = hi.getInput(0);
+		if( !HopRewriteUtils.isReorg(transpose, ReOrgOp.TRANS) )
+			return hi;
+
+		// compute det directly on the untransposed input X
+		Hop detOfInput = HopRewriteUtils.createUnary(transpose.getInput(0), OpOp1.DET);
+		HopRewriteUtils.replaceChildReference(parent, hi, detOfInput, pos);
+
+		LOG.debug("Applied simplifyTransposeInDetOperation.");
+		return detOfInput;
+	}
+
/**
* t(Z)%*%(X*(Y*(Z%*%v))) -> t(Z)%*%(X*Y)*(Z%*%v)
* t(Z)%*%(X+(Y+(Z%*%v))) -> t(Z)%*%((X+Y)+(Z%*%v))
@@ -1163,6 +1189,65 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
return hi;
}
+	/**
+	 * det(X%*%Y) -> det(X)*det(Y)
+	 *
+	 * The identity requires both X and Y to be square. For non-square factors
+	 * (X: n x k, Y: k x n with k != n) det(X%*%Y) is well-defined while det(X)
+	 * and det(Y) are not (and would fail at runtime), so the rewrite only fires
+	 * when both operands have known, square dimensions.
+	 *
+	 * @param parent parent high-level operator
+	 * @param hi high-level operator
+	 * @param pos position of hi among the inputs of parent
+	 * @return the new multiply operator if the rewrite applied, otherwise hi
+	 */
+	private static Hop pushdownDetMultOperation(Hop parent, Hop hi, int pos) {
+		if( HopRewriteUtils.isUnary(hi, OpOp1.DET)
+			&& HopRewriteUtils.isMatrixMultiply(hi.getInput(0)) )
+		{
+			Hop operand1 = hi.getInput(0).getInput(0);
+			Hop operand2 = hi.getInput(0).getInput(1);
+			// det(X)*det(Y) is only defined for square operands
+			boolean squareOperands = operand1.isMatrix() && operand2.isMatrix()
+				&& operand1.dimsKnown() && operand2.dimsKnown()
+				&& operand1.getDim1() == operand1.getDim2()
+				&& operand2.getDim1() == operand2.getDim2();
+			if( squareOperands ) {
+				Hop uop1 = HopRewriteUtils.createUnary(operand1, OpOp1.DET);
+				Hop uop2 = HopRewriteUtils.createUnary(operand2, OpOp1.DET);
+				Hop bop = HopRewriteUtils.createBinary(uop1, uop2, OpOp2.MULT);
+				HopRewriteUtils.replaceChildReference(parent, hi, bop, pos);
+
+				LOG.debug("Applied pushdownDetMultOperation.");
+				return bop;
+			}
+		}
+		return hi;
+	}
+
+	/**
+	 * det(lambda*X) -> lambda^nrow(X)*det(X)
+	 *
+	 * Scaling every row of an n x n matrix by lambda scales its determinant by
+	 * lambda^n; the scalar may appear on either side of the multiplication.
+	 *
+	 * @param parent parent high-level operator
+	 * @param hi high-level operator
+	 * @param pos position of hi among the inputs of parent
+	 * @return the new multiply operator if the rewrite applied, otherwise hi
+	 */
+	private static Hop pushdownDetScalarMatrixMultOperation(Hop parent, Hop hi, int pos) {
+		if( !HopRewriteUtils.isUnary(hi, OpOp1.DET)
+			|| !HopRewriteUtils.isBinary(hi.getInput(0), OpOp2.MULT) )
+			return hi;
+
+		Hop mult = hi.getInput(0);
+		Hop left = mult.getInput(0);
+		Hop right = mult.getInput(1);
+
+		// identify which side is the scalar and which the matrix
+		Hop lambda;
+		Hop matrix;
+		if( left.isMatrix() && right.isScalar() ) {
+			lambda = right;
+			matrix = left;
+		}
+		else if( left.isScalar() && right.isMatrix() ) {
+			lambda = left;
+			matrix = right;
+		}
+		else {
+			return hi;
+		}
+
+		Hop detOfMatrix = HopRewriteUtils.createUnary(matrix, OpOp1.DET);
+		Hop numRows = HopRewriteUtils.createUnary(matrix, OpOp1.NROW);
+		Hop scale = HopRewriteUtils.createBinary(lambda, numRows, OpOp2.POW);
+		Hop product = HopRewriteUtils.createBinary(scale, detOfMatrix, OpOp2.MULT);
+		HopRewriteUtils.replaceChildReference(parent, hi, product, pos);
+
+		LOG.debug("Applied pushdownDetScalarMatrixMultOperation.");
+		return product;
+	}
+
private static Hop pushdownSumBinaryMult(Hop parent, Hop hi, int pos ) {
//pattern: sum(lamda*X) -> lamda*sum(X)
if( hi instanceof AggUnaryOp &&
((AggUnaryOp)hi).getDirection()==Direction.RowCol
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 2fd1afd4a3..84e0fe079b 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1302,6 +1302,17 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
output.setBlocksize(id.getBlocksize());
output.setValueType(id.getValueType());
break;
+ case DET:
+ checkNumParameters(1);
+ checkMatrixParam(getFirstExpr());
+ if ( id.getDim2() == -1 || id.getDim1() != id.getDim2()
) {
+ raiseValidateError("det requires a square
matrix as first argument.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
+ }
+ output.setDataType(DataType.SCALAR);
+ output.setDimensions(0, 0);
+ output.setBlocksize(0);
+ output.setValueType(ValueType.FP64);
+ break;
case NROW:
case NCOL:
case LENGTH:
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index fb9457252e..60483d12e3 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2755,6 +2755,7 @@ public class DMLTranslator
case SQRT_MATRIX_JAVA:
case CHOLESKY:
case TYPEOF:
+ case DET:
case DETECTSCHEMA:
case COLNAMES:
currBuiltinOp = new UnaryOp(target.getName(),
target.getDataType(),
diff --git a/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
b/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
index 0ae3bc933d..23bb2dc640 100644
--- a/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
+++ b/src/main/java/org/apache/sysds/resource/cost/CPCostUtils.java
@@ -501,6 +501,7 @@ public class CPCostUtils {
case "cholesky":
costs = (1.0 / 3.0) *
output.getCellsWithSparsity() * output.getCellsWithSparsity();
break;
+ case "det":
case "detectschema":
case "colnames":
throw new
RuntimeException("Specific Frame operation with opcode '" + opcode + "' is not
supported yet");
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryMatrixCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryMatrixCPInstruction.java
index ddbdd70b7d..7724e48660 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryMatrixCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryMatrixCPInstruction.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.instructions.cp;
+import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -57,7 +58,13 @@ public class UnaryMatrixCPInstruction extends
UnaryCPInstruction {
LineageItem lin = (!inObj.hasValidLineage() ||
!inObj.getCacheLineage().isLeaf() ||
CacheableData.isBelowCachingThreshold(retBlock)) ? null
:
getCacheLineageItem(inObj.getCacheLineage());
- ec.setMatrixOutputAndLineage(output, retBlock, lin);
+ if (getOpcode().equals("det")){
+ var temp =
ScalarObjectFactory.createScalarObject(ValueType.FP64, retBlock.get(0,0));
+ ec.setVariable(output.getName(), temp);
+ }
+ else {
+ ec.setMatrixOutputAndLineage(output, retBlock, lin);
+ }
}
public LineageItem getCacheLineageItem(LineageItem input) {
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
index e7643648e0..a96ded9fbe 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
@@ -82,7 +82,8 @@ public class LibCommonsMath
public static boolean isSupportedUnaryOperation( String opcode ) {
return opcode.equals(Opcodes.INVERSE.toString())
- || opcode.equals(Opcodes.CHOLESKY.toString())
+ || opcode.equals(Opcodes.CHOLESKY.toString())
+ || opcode.equals(Opcodes.DET.toString())
|| opcode.equals(Opcodes.SQRT_MATRIX_JAVA.toString());
}
@@ -117,6 +118,8 @@ public class LibCommonsMath
return computeMatrixInverse(matrixInput);
else if (opcode.equals(Opcodes.CHOLESKY.toString()))
return computeCholesky(matrixInput);
+ else if (opcode.equals(Opcodes.DET.toString()))
+ return computeDeterminant(matrixInput);
}
return null;
}
@@ -567,6 +570,163 @@ public class LibCommonsMath
return DataConverter.convertToMatrixBlock(rmL.getData());
}
+	/**
+	 * Computes the determinant of a square matrix.
+	 *
+	 * The kernel is chosen by the local strategy constant: LU decomposition
+	 * from commons-math is the active default; the Gaussian, Bareiss and
+	 * Laplace kernels are alternatives selectable by changing that constant.
+	 *
+	 * @param in square input matrix
+	 * @return determinant of the matrix as a 1x1 MatrixBlock
+	 * @throws DMLRuntimeException if the input matrix is not square
+	 */
+	private static MatrixBlock computeDeterminant(Array2DRowRealMatrix in) {
+		if( !in.isSquare() )
+			throw new DMLRuntimeException("Determinant can only be computed for a square matrix. Input matrix is rectangular");
+
+		final int strategyLU = 0;
+		final int strategyGaussian = 1;
+		final int strategyBareiss = 2;
+		final int strategyLaplace = 3;
+		final int strategy = strategyLU;
+
+		double det;
+		if( strategy == strategyGaussian )
+			det = computeDetGaussian(in);
+		else if( strategy == strategyLaplace )
+			det = computeDetLaplace(in);
+		else if( strategy == strategyBareiss )
+			det = computeDetBareiss(in);
+		else
+			det = new LUDecomposition(in).getDeterminant();
+
+		MatrixBlock result = new MatrixBlock(1, 1, false);
+		result.set(0, 0, det);
+		return result;
+	}
+
+	/**
+	 * Computes the determinant via Gaussian elimination with partial pivoting.
+	 *
+	 * For each column the row with the largest-magnitude candidate is swapped
+	 * onto the diagonal (each swap flips the sign). This bounds the elimination
+	 * factors by 1 in magnitude and avoids the scale-dependent absolute epsilon
+	 * that would declare small-scaled but regular matrices singular. If the
+	 * remaining column is entirely zero the matrix is singular and 0 is returned.
+	 *
+	 * @param in square input matrix (not modified; getData() returns a copy)
+	 * @return determinant of in
+	 */
+	private static double computeDetGaussian(Array2DRowRealMatrix in) {
+		double[][] matrix = in.getData();
+		int size = in.getRowDimension();
+		double determinant = 1.0;
+
+		for( int pivotRow = 0; pivotRow < size; pivotRow++ ) {
+			// partial pivoting: pick the largest-magnitude entry in this column
+			int maxRow = pivotRow;
+			for( int row = pivotRow + 1; row < size; row++ )
+				if( Math.abs(matrix[row][pivotRow]) > Math.abs(matrix[maxRow][pivotRow]) )
+					maxRow = row;
+
+			double pivot = matrix[maxRow][pivotRow];
+			if( pivot == 0 )
+				return 0; // column is zero on and below the diagonal -> singular
+
+			if( maxRow != pivotRow ) {
+				double[] tmpRow = matrix[maxRow];
+				matrix[maxRow] = matrix[pivotRow];
+				matrix[pivotRow] = tmpRow;
+				determinant = -determinant; // a row swap flips the sign
+			}
+			determinant *= pivot;
+
+			// eliminate entries below the pivot
+			for( int row = pivotRow + 1; row < size; row++ ) {
+				double factor = matrix[row][pivotRow] / pivot;
+				if( factor == 0 )
+					continue;
+				for( int col = pivotRow; col < size; col++ )
+					matrix[row][col] -= factor * matrix[pivotRow][col];
+			}
+		}
+		return determinant;
+	}
+
+	/**
+	 * Computes the determinant by Laplace (cofactor) expansion along the first row.
+	 *
+	 * Minors are evaluated through computeDeterminant, i.e., with whatever kernel
+	 * that dispatcher currently selects. Zero entries of the first row are
+	 * skipped because their cofactor term vanishes.
+	 *
+	 * @param in square input matrix
+	 * @return determinant of in
+	 */
+	private static double computeDetLaplace(Array2DRowRealMatrix in) {
+		int length = in.getRowDimension();
+
+		// base cases: a 1x1 matrix has no proper minor (a 0x0
+		// Array2DRowRealMatrix cannot be constructed); 2x2 in closed form
+		if( length == 1 )
+			return in.getEntry(0, 0);
+		if( length == 2 )
+			return in.getEntry(0, 0) * in.getEntry(1, 1) - in.getEntry(0, 1) * in.getEntry(1, 0);
+
+		double determinant = 0;
+		for( int col = 0; col < length; col++ ) {
+			double entry = in.getEntry(0, col);
+			if( entry == 0 )
+				continue; // multiplication with zero results in zero
+
+			// build the minor by dropping row 0 and the current column
+			Array2DRowRealMatrix subMatrix = new Array2DRowRealMatrix(length - 1, length - 1);
+			for( int i = 1; i < length; i++ ) {
+				int subCol = 0;
+				for( int j = 0; j < length; j++ ) {
+					if( j == col )
+						continue;
+					subMatrix.setEntry(i - 1, subCol, in.getEntry(i, j));
+					subCol++;
+				}
+			}
+			// recursive determinant calculation of the minor
+			int sign = (col % 2 == 0) ? 1 : -1;
+			double subDeterminant = computeDeterminant(subMatrix).get(0, 0);
+			determinant = determinant + sign * entry * subDeterminant;
+		}
+		return determinant;
+	}
+
+	/**
+	 * Computes the determinant with the fraction-free Bareiss algorithm.
+	 *
+	 * Each step k updates the trailing submatrix as
+	 * a[i][j] = (a[i][j]*a[k][k] - a[i][k]*a[k][j]) / a[k-1][k-1],
+	 * so after n-1 steps a[n-1][n-1] holds the determinant up to the sign of
+	 * the performed row swaps.
+	 *
+	 * NOTE: the input matrix is overwritten in place (rows swapped, entries
+	 * updated); callers must pass a scratch copy.
+	 *
+	 * @param in square input matrix, modified in place
+	 * @return determinant of the original in
+	 */
+	private static double computeDetBareiss(Array2DRowRealMatrix in) {
+		int n = in.getRowDimension();
+		int sign = 1;
+		for( int k = 0; k < n - 1; k++ ) {
+			if( in.getEntry(k, k) == 0 ) {
+				// find a row below with a nonzero entry in column k
+				int swapRow = -1;
+				for( int m = k + 1; m < n; m++ ) {
+					if( in.getEntry(m, k) != 0 ) {
+						swapRow = m;
+						break;
+					}
+				}
+				if( swapRow < 0 )
+					return 0; // column k is zero on and below the diagonal -> singular
+				double[] tmp = in.getRow(swapRow);
+				in.setRow(swapRow, in.getRow(k));
+				in.setRow(k, tmp);
+				sign = -sign;
+			}
+
+			// pivot and previous pivot are not touched by the update below
+			double pivot = in.getEntry(k, k);
+			double den = (k == 0) ? 1 : in.getEntry(k - 1, k - 1);
+			for( int i = k + 1; i < n; i++ ) {
+				for( int j = k + 1; j < n; j++ ) {
+					double num = in.getEntry(i, j) * pivot - in.getEntry(i, k) * in.getEntry(k, j);
+					in.setEntry(i, j, num / den);
+				}
+			}
+		}
+		return sign * in.getEntry(n - 1, n - 1);
+	}
+
/**
* Creates a random normalized vector with dim elements.
*
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDetTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDetTest.java
new file mode 100644
index 0000000000..288ff0e44e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteDetTest.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+
+import java.util.HashMap;
+
+public class RewriteDetTest extends AutomatedTestBase
+{
+	private static final String TEST_NAME_MIXED = "RewriteDetMixed";
+	private static final String TEST_NAME_MULT = "RewriteDetMult";
+	private static final String TEST_NAME_TRANSPOSE = "RewriteDetTranspose";
+	private static final String TEST_NAME_SCALAR_MATRIX_MULT = "RewriteDetScalarMatrixMult";
+
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + RewriteDetTest.class.getSimpleName() + "/";
+
+	private final static int rows = 23;
+	private final static double _sparsityDense = 0.7;
+	private final static double _sparsitySparse = 0.2;
+	// absolute tolerance for comparing the DML and R determinants
+	private final static double eps = 1e-8;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		// combination of the transpose, matrix-product and scalar-multiply det rewrites
+		addTestConfiguration(TEST_NAME_MIXED, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_MIXED, new String[] { "d" }));
+		// det(A%*%B) -> det(A)*det(B)
+		addTestConfiguration(TEST_NAME_MULT, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_MULT, new String[] { "d" }));
+		// det(t(A)) -> det(A)
+		addTestConfiguration(TEST_NAME_TRANSPOSE, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_TRANSPOSE, new String[] { "d" }));
+		// det(lambda*A) -> lambda^nrow(A)*det(A)
+		// This is cheaper, because the scalar power lambda^nrow(A) needs only
+		// O(log(nrow(A))) multiplications, whereas lambda*A multiplies each of
+		// the nrow(A)^2 cells of A.
+		addTestConfiguration(TEST_NAME_SCALAR_MATRIX_MULT, new TestConfiguration(
+			TEST_CLASS_DIR, TEST_NAME_SCALAR_MATRIX_MULT, new String[] { "d" }));
+	}
+
+	@Test
+	public void testRewriteDetMixedSparseNoRewrite() {
+		runRewriteDetTest(TEST_NAME_MIXED, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetMixedSparseRewrite() {
+		runRewriteDetTest(TEST_NAME_MIXED, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetMixedDenseNoRewrite() {
+		runRewriteDetTest(TEST_NAME_MIXED, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetMixedDenseRewrite() {
+		runRewriteDetTest(TEST_NAME_MIXED, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetMultDenseRewrite() {
+		runRewriteDetTest(TEST_NAME_MULT, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetMultSparseRewrite() {
+		runRewriteDetTest(TEST_NAME_MULT, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetMultDenseNoRewrite() {
+		runRewriteDetTest(TEST_NAME_MULT, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetMultSparseNoRewrite() {
+		runRewriteDetTest(TEST_NAME_MULT, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetTransposeSparseRewrite() {
+		runRewriteDetTest(TEST_NAME_TRANSPOSE, true, true, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetTransposeDenseRewrite() {
+		runRewriteDetTest(TEST_NAME_TRANSPOSE, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetTransposeSparseNoRewrite() {
+		runRewriteDetTest(TEST_NAME_TRANSPOSE, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetTransposeDenseNoRewrite() {
+		runRewriteDetTest(TEST_NAME_TRANSPOSE, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetScalarMatrixMultDenseNoRewrite() {
+		runRewriteDetTest(TEST_NAME_SCALAR_MATRIX_MULT, false, false, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetScalarMatrixMultSparseNoRewrite() {
+		runRewriteDetTest(TEST_NAME_SCALAR_MATRIX_MULT, false, true, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetScalarMatrixMultDenseRewrite() {
+		runRewriteDetTest(TEST_NAME_SCALAR_MATRIX_MULT, true, false, ExecType.CP);
+	}
+
+	@Test
+	public void testRewriteDetScalarMatrixMultSparseRewrite() {
+		runRewriteDetTest(TEST_NAME_SCALAR_MATRIX_MULT, true, true, ExecType.CP);
+	}
+
+	private void runRewriteDetTest(String testScriptName, boolean rewrites, boolean sparse, ExecType et) {
+		// NOTE The sparsity of the matrix is considered, because rewrite
+		// simplifications could be made if the matrix contains a lot of zeros.
+		// Furthermore, some det-algorithms perform optimizations
+		// (early termination, less recursions, ...) when the matrix is sparse.
+		// Therefore dense and sparse matrices are part of the rewrite tests.
+
+		ExecMode platformOld = setExecMode(et);
+		boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+
+		try {
+			double sparsity = sparse ? _sparsitySparse : _sparsityDense;
+			getAndLoadTestConfiguration(testScriptName);
+
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testScriptName + ".dml";
+
+			boolean twoMatrixArg = testScriptName.equals(TEST_NAME_MULT);
+			boolean oneMatrixOneScalarArg = testScriptName.equals(TEST_NAME_SCALAR_MATRIX_MULT);
+			boolean twoMatrixOneScalarArg = testScriptName.equals(TEST_NAME_MIXED);
+			if( twoMatrixArg )
+				programArgs = new String[]{"-stats", "-args", input("A"), input("B"), output("d")};
+			else if( oneMatrixOneScalarArg )
+				programArgs = new String[]{"-stats", "-args", input("A"), input("lambda"), output("d")};
+			else if( twoMatrixOneScalarArg )
+				programArgs = new String[]{"-stats", "-args", input("A"), input("B"), input("lambda"), output("d")};
+			else
+				programArgs = new String[]{"-stats", "-args", input("A"), output("d")};
+
+			fullRScriptName = HOME + testScriptName + ".R";
+			rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+
+			double[][] A = getRandomMatrix(rows, rows, -1, 1, sparsity, 21332);
+			writeInputMatrixWithMTD("A", A, true);
+			if( twoMatrixArg ) {
+				double[][] B = getRandomMatrix(rows, rows, -1, 1, sparsity, 42422);
+				writeInputMatrixWithMTD("B", B, true);
+			}
+			else if( twoMatrixOneScalarArg ) {
+				double[][] B = getRandomMatrix(rows, rows, -1, 1, sparsity, 4242);
+				writeInputMatrixWithMTD("B", B, true);
+				writeLambda();
+			}
+			else if( oneMatrixOneScalarArg ) {
+				writeLambda();
+			}
+
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			HashMap<CellIndex, Double> dmlfile = readDMLScalarFromOutputDir("d");
+			HashMap<CellIndex, Double> rfile = readRScalarFromExpectedDir("d");
+			TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
+
+			if( rewrites )
+				assertDetRewritesApplied(testScriptName);
+		}
+		finally {
+			OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+			resetExecMode(platformOld);
+		}
+	}
+
+	/**
+	 * Writes the scalar input lambda as a dense 1x1 matrix. It is always
+	 * generated dense: with the sparse setting the single cell would be zero
+	 * with high probability, making det(lambda*...) trivially 0 and the
+	 * comparison against R vacuous.
+	 */
+	private void writeLambda() {
+		double[][] lambda = getRandomMatrix(1, 1, -1, 1, 1.0, 121);
+		writeInputMatrixWithMTD("lambda", lambda, true);
+	}
+
+	/**
+	 * Checks via the heavy hitter statistics that the det rewrites exercised
+	 * by the given script were applied; one assertion per rewrite so that a
+	 * failure names the rewrite that did not fire.
+	 */
+	private void assertDetRewritesApplied(String testScriptName) {
+		boolean mixed = testScriptName.equals(TEST_NAME_MIXED);
+		if( mixed || testScriptName.equals(TEST_NAME_TRANSPOSE) )
+			Assert.assertFalse("det(t(X)) -> det(X) not applied: transpose still executed",
+				heavyHittersContainsString(Opcodes.TRANSPOSE.toString()));
+		if( mixed || testScriptName.equals(TEST_NAME_MULT) )
+			Assert.assertEquals("det(X%*%Y) -> det(X)*det(Y) not applied: expected two det operations",
+				2, Statistics.getCPHeavyHitterCount(Opcodes.DET.toString()));
+		if( mixed || testScriptName.equals(TEST_NAME_SCALAR_MATRIX_MULT) )
+			Assert.assertTrue("det(lambda*X) -> lambda^nrow(X)*det(X) not applied: no pow executed",
+				heavyHittersContainsString(Opcodes.POW.toString()));
+	}
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/unary/matrix/DetTest.java
b/src/test/java/org/apache/sysds/test/functions/unary/matrix/DetTest.java
new file mode 100644
index 0000000000..2f4c049ed2
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/unary/matrix/DetTest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.unary.matrix;
+
+import org.junit.Test;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+import java.util.HashMap;
+
+
+public class DetTest extends AutomatedTestBase {
+
+	private static final String TEST_DIR = "functions/unary/matrix/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + DetTest.class.getSimpleName() + "/";
+	private static final String DML_SCRIPT_NAME = "DetTest";
+	private static final String R_SCRIPT_NAME = "DetTest";
+
+	private static final String TEST_NAME_WRONG_DIM = "WrongDimensionsTest";
+	private static final String TEST_NAME_DET_TEST = "DetTest";
+
+	// Keep the matrix size moderate: a determinant accumulates many
+	// multiplications and additions, so floating point errors grow with the
+	// number of rows and would eventually exceed the comparison tolerance.
+	private final static int rows = 23;
+	private final static double _sparsityDense = 0.7;
+	private final static double _sparsitySparse = 0.2;
+	private final static double eps = 1e-8;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME_WRONG_DIM,
+			new TestConfiguration(TEST_CLASS_DIR, DML_SCRIPT_NAME, new String[] { "d" }));
+		addTestConfiguration(TEST_NAME_DET_TEST,
+			new TestConfiguration(TEST_CLASS_DIR, DML_SCRIPT_NAME, new String[] { "d" }));
+	}
+
+	@Test
+	public void testWrongDimensions() {
+		// a 10 x 9 (non-square) input is expected to raise a LanguageException
+		loadTestConfiguration(availableTestConfigurations.get(TEST_NAME_WRONG_DIM));
+
+		fullDMLScriptName = scriptPath(DML_SCRIPT_NAME, ".dml");
+		programArgs = new String[] { "-args", input("A"), output("d") };
+
+		writeInputMatrixWithMTD("A", getRandomMatrix(10, 9, -1, 1, 0.5, 3), true);
+		runTest(true, true, LanguageException.class, -1);
+	}
+
+	@Test
+	public void testDetMatrixDense() {
+		runDetTest(false);
+	}
+
+	@Test
+	public void testDetMatrixSparse() {
+		runDetTest(true);
+	}
+
+	/** Builds the path of a test script located in this test's script directory. */
+	private String scriptPath(String scriptName, String extension) {
+		return SCRIPT_DIR + TEST_DIR + scriptName + extension;
+	}
+
+	/**
+	 * Runs det on a random square matrix in HYBRID mode and compares the DML
+	 * result with the R reference value.
+	 */
+	private void runDetTest(boolean sparse) {
+		ExecMode platformOld = rtplatform;
+		rtplatform = ExecMode.HYBRID;
+
+		try {
+			getAndLoadTestConfiguration(TEST_NAME_DET_TEST);
+
+			fullDMLScriptName = scriptPath(DML_SCRIPT_NAME, ".dml");
+			programArgs = new String[] { "-args", input("A"), output("d") };
+			fullRScriptName = scriptPath(R_SCRIPT_NAME, ".R");
+			rCmd = "Rscript" + " " + fullRScriptName + " " + inputDir() + " " + expectedDir();
+
+			double sparsity = sparse ? _sparsitySparse : _sparsityDense;
+			double[][] A = getRandomMatrix(rows, rows, -1, 1, sparsity, 7);
+			writeInputMatrixWithMTD("A", A, true);
+
+			runTest(true, false, null, -1);
+			runRScript(true);
+
+			HashMap<CellIndex, Double> dmlResult = readDMLScalarFromOutputDir("d");
+			HashMap<CellIndex, Double> rResult = readRScalarFromExpectedDir("d");
+			TestUtils.compareMatrices(dmlResult, rResult, eps, "Stat-DML", "Stat-R");
+		}
+		finally {
+			rtplatform = platformOld;
+		}
+	}
+}
+
diff --git a/src/test/scripts/functions/rewrite/RewriteDetMixed.R
b/src/test/scripts/functions/rewrite/RewriteDetMixed.R
new file mode 100644
index 0000000000..2fc8771fb1
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDetMixed.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.rewrite.RewriteDetTest.java
+
+args <- commandArgs(TRUE) # args[1]: input dir (A.mtx, B.mtx, lambda.mtx); args[2]: directory that receives d
+options(digits=22) # maximal printed precision so write() does not round d
+
+library("Matrix") # provides readMM
+
+A <- readMM(paste(args[1], "A.mtx", sep=""));
+B <- readMM(paste(args[1], "B.mtx", sep=""));
+tmp_lambda <- readMM(paste(args[1], "lambda.mtx", sep=""));
+lambda <- as.double(tmp_lambda[1, 1]); # lambda is stored as the first cell of a matrix file
+d = det(lambda*t(t(A) %*% t(B))); # same expression as RewriteDetMixed.dml; equals det(lambda * B %*% A)
+write(d, paste(args[2], "d", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteDetMixed.dml b/src/test/scripts/functions/rewrite/RewriteDetMixed.dml
new file mode 100644
index 0000000000..1363b8c774
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDetMixed.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.rewrite.RewriteDetTest.java
+
+A = read($1); # $1, $2: input matrices A and B
+B = read($2);
+lambda = as.scalar(read($3)); # $3: 1x1 matrix holding the scalar lambda
+d = det(lambda*t(t(A)%*%t(B))); # t(t(A)%*%t(B)) == B%*%A; combines scalar mult, transpose and matmult under det
+write(d, $4); # $4: output path for the scalar d
diff --git a/src/test/scripts/functions/rewrite/RewriteDetMult.R b/src/test/scripts/functions/rewrite/RewriteDetMult.R
new file mode 100644
index 0000000000..9e4e0efc70
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDetMult.R
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.rewrite.RewriteDetTest.java
+
+args <- commandArgs(TRUE) # args[1]: input dir (A.mtx, B.mtx); args[2]: directory that receives d
+options(digits=22) # maximal printed precision so write() does not round d
+
+library("Matrix") # provides readMM
+
+A <- readMM(paste(args[1], "A.mtx", sep=""))
+B <- readMM(paste(args[1], "B.mtx", sep=""))
+d = det(A) * det(B); # reference for det(A%*%B): det(A %*% B) == det(A) * det(B) for square A, B
+write(d, paste(args[2], "d", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteDetMult.dml b/src/test/scripts/functions/rewrite/RewriteDetMult.dml
new file mode 100644
index 0000000000..bd0d48f7c5
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDetMult.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.rewrite.RewriteDetTest.java
+
+A = read($1); # $1, $2: input matrices A and B
+B = read($2);
+d = det(A%*%B); # mathematically equal to det(A) * det(B)
+write(d, $3); # $3: output path for the scalar d
diff --git a/src/test/scripts/functions/rewrite/RewriteDetScalarMatrixMult.R b/src/test/scripts/functions/rewrite/RewriteDetScalarMatrixMult.R
new file mode 100644
index 0000000000..c4ccf03136
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDetScalarMatrixMult.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.rewrite.RewriteDetTest.java
+
+args <- commandArgs(TRUE) # args[1]: input dir (A.mtx, lambda.mtx); args[2]: directory that receives d
+options(digits=22) # maximal printed precision so write() does not round d
+
+library("Matrix") # provides readMM
+
+A <- readMM(paste(args[1], "A.mtx", sep=""));
+tmp_lambda <- readMM(paste(args[1], "lambda.mtx", sep=""));
+lambda <- as.double(tmp_lambda[1, 1]); # lambda is stored as the first cell of a matrix file
+d = det(lambda*A); # for an n x n A this equals lambda^n * det(A)
+write(d, paste(args[2], "d", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteDetScalarMatrixMult.dml b/src/test/scripts/functions/rewrite/RewriteDetScalarMatrixMult.dml
new file mode 100644
index 0000000000..7c234173aa
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDetScalarMatrixMult.dml
@@ -0,0 +1,27 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.rewrite.RewriteDetTest.java
+
+A = read($1); # $1: input matrix A
+lambda = as.scalar(read($2)); # $2: 1x1 matrix holding the scalar lambda
+d = det(lambda*A); # for an n x n A this equals lambda^n * det(A)
+write(d, $3); # $3: output path for the scalar d
diff --git a/src/test/scripts/functions/rewrite/RewriteDetTranspose.R b/src/test/scripts/functions/rewrite/RewriteDetTranspose.R
new file mode 100644
index 0000000000..4d2fdc6ed3
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDetTranspose.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.rewrite.RewriteDetTest.java
+
+args <- commandArgs(TRUE) # args[1]: input dir (A.mtx); args[2]: directory that receives d
+options(digits=22) # maximal printed precision so write() does not round d
+
+library("Matrix") # provides readMM
+
+A <- readMM(paste(args[1], "A.mtx", sep=""))
+d = det(t(A)) # det(t(A)) == det(A); reference for RewriteDetTranspose.dml
+write(d, paste(args[2], "d", sep=""));
diff --git a/src/test/scripts/functions/rewrite/RewriteDetTranspose.dml b/src/test/scripts/functions/rewrite/RewriteDetTranspose.dml
new file mode 100644
index 0000000000..5de949b489
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteDetTranspose.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.rewrite.RewriteDetTest.java
+
+A = read($1); # $1: input matrix A
+d = det(t(A)); # mathematically equal to det(A)
+write(d, $2); # $2: output path for the scalar d
diff --git a/src/test/scripts/functions/unary/matrix/DetTest.R b/src/test/scripts/functions/unary/matrix/DetTest.R
new file mode 100644
index 0000000000..175a32684a
--- /dev/null
+++ b/src/test/scripts/functions/unary/matrix/DetTest.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.unary.matrix.DetTest.java
+
+args <- commandArgs(TRUE) # args[1]: input dir holding A.mtx; args[2]: expected-output dir receiving d
+options(digits=22) # maximal printed precision so write() does not round d
+
+library("Matrix") # provides readMM
+
+A <- readMM(paste(args[1], "A.mtx", sep=""))
+d = det(A); # reference determinant compared against the DML result
+write(d, paste(args[2], "d", sep=""));
diff --git a/src/test/scripts/functions/unary/matrix/DetTest.dml b/src/test/scripts/functions/unary/matrix/DetTest.dml
new file mode 100644
index 0000000000..b372fdbad3
--- /dev/null
+++ b/src/test/scripts/functions/unary/matrix/DetTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# junit test class: org.apache.sysds.test.functions.unary.matrix.DetTest.java
+
+A = read($1); # $1: input matrix A (must be square for det)
+d = det(A);
+write(d, $2); # $2: output path for the scalar d