This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 06ad1b5a82 [SYSTEMDS-3258] Builtin for matrix square root (multiple
strategies)
06ad1b5a82 is described below
commit 06ad1b5a825dc1fddbeea317e89ac7550366c2d8
Author: trp-ex <[email protected]>
AuthorDate: Thu Jan 16 08:07:19 2025 +0100
[SYSTEMDS-3258] Builtin for matrix square root (multiple strategies)
Closes #2178.
Co-authored-by: Florian Hoffmann <[email protected]>
Co-authored-by: Melisa Akbaydar <[email protected]>
---
scripts/builtin/sqrtMatrix.dml | 114 ++++++++++
.../java/org/apache/sysds/common/Builtins.java | 2 +
src/main/java/org/apache/sysds/common/Types.java | 2 +-
src/main/java/org/apache/sysds/hops/UnaryOp.java | 2 +-
.../sysds/parser/BuiltinFunctionExpression.java | 13 ++
.../org/apache/sysds/parser/DMLTranslator.java | 1 +
.../runtime/instructions/CPInstructionParser.java | 1 +
.../sysds/runtime/matrix/data/LibCommonsMath.java | 18 +-
.../builtin/part2/BuiltinSQRTMatrixTest.java | 235 +++++++++++++++++++++
src/test/scripts/functions/builtin/SQRTMatrix.dml | 32 +++
10 files changed, 416 insertions(+), 4 deletions(-)
diff --git a/scripts/builtin/sqrtMatrix.dml b/scripts/builtin/sqrtMatrix.dml
new file mode 100644
index 0000000000..12e1829ad4
--- /dev/null
+++ b/scripts/builtin/sqrtMatrix.dml
@@ -0,0 +1,114 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Computes the matrix square root B of a square matrix A, such that
+# A = B %*% B.
+#
+# INPUT:
+# ------------------------------------------------------------------------------
+# A Input Matrix A (must be square)
+# S Strategy ("COMMON": Java-based, using Apache Commons Math; "DML": script-based)
+# ------------------------------------------------------------------------------
+#
+# OUTPUT:
+# ------------------------------------------------------------------------------
+# B Output Matrix B
+# ------------------------------------------------------------------------------
+
+
+m_sqrtMatrix = function(Matrix[Double] A, String S)
+  return(Matrix[Double] B)
+{
+  if (S == "COMMON") {
+    # delegate to the Java-based builtin
+    B = sqrtMatrixJava(A)
+  } else if (S == "DML") {
+    N = nrow(A);
+    D = ncol(A);
+
+    # check that matrix is square
+    if (D != N) {
+      stop("matrixSqrt Input Error: matrix not square!")
+    }
+
+    if (isDiagonal(A)) {
+      # The eigenvalues of a diagonal matrix are its diagonal entries, so a
+      # negative entry is rejected here just like a negative eigenvalue in
+      # the general case, instead of silently returning NaN from sqrt().
+      if (min(diag(A)) < 0) {
+        stop("matrixSqrt exec Error: matrix has imaginary square root");
+      }
+      B = sqrtDiagMatrix(A);
+    } else {
+      [eValues, eVectors] = eigen(A);
+
+      # a negative eigenvalue would require an imaginary square root
+      hasNonNegativeEigenValues = (sum(eValues >= 0) == length(eValues));
+
+      if (!hasNonNegativeEigenValues) {
+        stop("matrixSqrt exec Error: matrix has imaginary square root");
+      }
+
+      isSymmetric = sum(A == t(A)) == length(A);
+      allEigenValuesUnique = length(eValues) == length(unique(eValues));
+
+      if (allEigenValuesUnique | isSymmetric) {
+        # diagonalization: A = V D V^(-1)  =>  B = V sqrt(D) V^(-1)
+        sqrtD = sqrtDiagMatrix(diag(eValues));
+        V_Inv = inv(eVectors);
+        B = eVectors %*% sqrtD %*% V_Inv;
+      } else {
+        # Denman-Beavers iteration: Y is returned as the square root of A
+        Y = A
+        # identity matrix
+        Z = diag(matrix(1.0, rows=N, cols=1))
+
+        # fixed number of iterations, no convergence check
+        for (x in 1:100) {
+          Y_new = (1 / 2) * (Y + inv(Z))
+          Z_new = (1 / 2) * (Z + inv(Y))
+          Y = Y_new
+          Z = Z_new
+        }
+        B = Y
+      }
+    }
+  } else {
+    stop("Error: Unknown strategy for matrix square root.")
+  }
+}
+
+# Square root of a diagonal matrix: takes the element-wise square root of the
+# diagonal of X and rebuilds a diagonal matrix from it.
+# Assumes X is square and diagonal; off-diagonal cells are not read.
+sqrtDiagMatrix = function(Matrix[Double] X)
+  return(Matrix[Double] sqrt_x)
+{
+  # no special case for the identity matrix is needed, since sqrt(1) = 1
+  sqrt_x = diag(sqrt(diag(X)));
+}
+
+isDiagonal = function (Matrix[Double] X)
+  return(boolean diagonal)
+{
+  # rebuild a matrix from only the diagonal of X and count the cells in
+  # which it agrees with X; X is reported diagonal if every cell agrees
+  onlyDiag = diag(diag(X));
+  numEqualCells = sum(onlyDiag == X);
+  diagonal = (numEqualCells == length(X));
+}
+
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index a6331905ac..5429cb287c 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -325,6 +325,8 @@ public enum Builtins {
STEPLM("steplm",true, ReturnType.MULTI_RETURN),
STFT("stft", false, ReturnType.MULTI_RETURN),
SQRT("sqrt", false),
+ SQRT_MATRIX("sqrtMatrix", true),
+ SQRT_MATRIX_JAVA("sqrtMatrixJava", false, ReturnType.SINGLE_RETURN),
SUM("sum", false),
SVD("svd", false, ReturnType.MULTI_RETURN),
TABLE("table", "ctable", false),
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index dd351ae894..21595efd03 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -542,7 +542,7 @@ public interface Types {
CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR,
INVERSE,
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT,
STOP, _EVICT,
- SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
+ SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
//fused ML-specific operators for performance
SPROP, //sample proportion: P * (1 - P)
SIGMOID, //sigmoid function: 1 / (1 + exp(-X))
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java
b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index 2c0cd4a61b..1bda77530b 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -512,7 +512,7 @@ public class UnaryOp extends MultiThreadedHop
//ensure cp exec type for single-node operations
if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op ==
OpOp1.STOP || _op == OpOp1.TYPEOF
- || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op ==
OpOp1.CHOLESKY || _op == OpOp1.SVD
+ || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op ==
OpOp1.CHOLESKY || _op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA
|| getInput().get(0).getDataType() == DataType.LIST ||
isMetadataOperation() )
{
_etype = ExecType.CP;
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 1de3442dd9..c12e4c4705 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1759,6 +1759,19 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
output.setDimensions(in.getDim1(), in.getDim2());
output.setBlocksize(in.getBlocksize());
break;
+
+ case SQRT_MATRIX_JAVA:
+
+ checkNumParameters(1);
+ checkMatrixParam(getFirstExpr());
+ output.setDataType(DataType.MATRIX);
+ output.setValueType(ValueType.FP64);
+ Identifier sqrt = getFirstExpr().getOutput();
+ if(sqrt.dimsKnown() && sqrt.getDim1() != sqrt.getDim2())
+ raiseValidateError("Input to sqrtMatrix() must
be square matrix -- given: a " + sqrt.getDim1() + "x" + sqrt.getDim2() + "
matrix.", conditional);
+ output.setDimensions( sqrt.getDim1(), sqrt.getDim2());
+ output.setBlocksize( sqrt.getBlocksize());
+ break;
case CHOLESKY:
{
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 6121711933..b0673be092 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2749,6 +2749,7 @@ public class DMLTranslator
break;
case INVERSE:
+ case SQRT_MATRIX_JAVA:
case CHOLESKY:
case TYPEOF:
case DETECTSCHEMA:
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index a0270f6b20..2d19b39f8a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -208,6 +208,7 @@ public class CPInstructionParser extends InstructionParser {
String2CPInstructionType.put( "ucummax", CPType.Unary);
String2CPInstructionType.put( "stop" , CPType.Unary);
String2CPInstructionType.put( "inverse", CPType.Unary);
+ String2CPInstructionType.put( "sqrt_matrix_java", CPType.Unary);
String2CPInstructionType.put( "cholesky",CPType.Unary);
String2CPInstructionType.put( "sprop", CPType.Unary);
String2CPInstructionType.put( "sigmoid", CPType.Unary);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
index 61a5f0d784..5365944a3b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java
@@ -80,7 +80,7 @@ public class LibCommonsMath
}
public static boolean isSupportedUnaryOperation( String opcode ) {
- return ( opcode.equals("inverse") || opcode.equals("cholesky")
);
+ return ( opcode.equals("inverse") || opcode.equals("cholesky")
|| opcode.equals("sqrt_matrix_java") );
}
public static boolean isSupportedMultiReturnOperation( String opcode ) {
@@ -111,6 +111,8 @@ public class LibCommonsMath
return computeMatrixInverse(matrixInput);
else if (opcode.equals("cholesky"))
return computeCholesky(matrixInput);
+ else if (opcode.equals("sqrt_matrix_java"))
+ return computeSqrt(inj);
return null;
}
@@ -512,7 +514,19 @@ public class LibCommonsMath
return new MatrixBlock[] { U, Sigma, V };
}
-
+
+ /**
+  * Computes the square root of a matrix via the Apache Commons Math
+  * {@link EigenDecomposition} of the input.
+  *
+  * @param in Input matrix
+  * @return matrix block holding the square root reported by the decomposition
+  */
+ private static MatrixBlock computeSqrt(MatrixBlock in) {
+     // convert -> decompose -> take the square root -> convert back
+     return DataConverter.convertToMatrixBlock(
+         new EigenDecomposition(DataConverter.convertToArray2DRowRealMatrix(in)).getSquareRoot());
+ }
+
/**
* Function to compute matrix inverse via matrix decomposition.
*
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSQRTMatrixTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSQRTMatrixTest.java
new file mode 100644
index 0000000000..a86a6892a8
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSQRTMatrixTest.java
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part2;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Tests the sqrtMatrix builtin for the strategies "COMMON" and "DML". Each test
+ * runs the SQRTMatrix.dml script on a fixed input matrix X and compares the
+ * script output Y against X itself, since Y is expected to be
+ * sqrtMatrix(X) %*% sqrtMatrix(X).
+ */
+public class BuiltinSQRTMatrixTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "SQRTMatrix";
+	private final static String TEST_DIR = "functions/builtin/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinSQRTMatrixTest.class.getSimpleName() + "/";
+
+	// maximum tolerated difference when comparing the output against the input
+	private final static double eps = 1e-8;
+
+	@Override
+	public void setUp() {
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
+	}
+
+	// tests for strategy "COMMON"
+	@Test
+	public void testSQRTMatrixJavaSize1x1() {
+		runSQRTMatrix(true, ExecType.CP, "COMMON", 1);
+	}
+
+	@Test
+	public void testSQRTMatrixJavaUpperTriangularMatrixSize2x2() {
+		runSQRTMatrix(true, ExecType.CP, "COMMON", 2);
+	}
+
+	@Test
+	public void testSQRTMatrixJavaDiagonalMatrixSize2x2() {
+		runSQRTMatrix(true, ExecType.CP, "COMMON", 3);
+	}
+
+	@Test
+	public void testSQRTMatrixJavaPSDMatrixSize2x2() {
+		runSQRTMatrix(true, ExecType.CP, "COMMON", 4);
+	}
+
+	@Test
+	public void testSQRTMatrixJavaPSDMatrixSize3x3() {
+		runSQRTMatrix(true, ExecType.CP, "COMMON", 5);
+	}
+
+	@Test
+	public void testSQRTMatrixJavaPSDMatrixSize4x4() {
+		runSQRTMatrix(true, ExecType.CP, "COMMON", 6);
+	}
+
+	@Test
+	public void testSQRTMatrixJavaPSDMatrixSize8x8() {
+		runSQRTMatrix(true, ExecType.CP, "COMMON", 7);
+	}
+
+	// tests for strategy "DML"
+	@Test
+	public void testSQRTMatrixDMLSize1x1() {
+		runSQRTMatrix(true, ExecType.CP, "DML", 1);
+	}
+
+	@Test
+	public void testSQRTMatrixDMLUpperTriangularMatrixSize2x2() {
+		runSQRTMatrix(true, ExecType.CP, "DML", 2);
+	}
+
+	@Test
+	public void testSQRTMatrixDMLDiagonalMatrixSize2x2() {
+		runSQRTMatrix(true, ExecType.CP, "DML", 3);
+	}
+
+	@Test
+	public void testSQRTMatrixDMLPSDMatrixSize2x2() {
+		runSQRTMatrix(true, ExecType.CP, "DML", 4);
+	}
+
+	@Test
+	public void testSQRTMatrixDMLPSDMatrixSize3x3() {
+		runSQRTMatrix(true, ExecType.CP, "DML", 5);
+	}
+
+	@Test
+	public void testSQRTMatrixDMLPSDMatrixSize4x4() {
+		runSQRTMatrix(true, ExecType.CP, "DML", 6);
+	}
+
+	@Test
+	public void testSQRTMatrixDMLPSDMatrixSize8x8() {
+		runSQRTMatrix(true, ExecType.CP, "DML", 7);
+	}
+
+	/**
+	 * Runs the SQRTMatrix.dml script for one input matrix and one strategy and
+	 * checks that the squared result reproduces the input.
+	 *
+	 * @param defaultProb currently unused
+	 * @param instType    execution type to run the script with
+	 * @param strategy    strategy string passed to sqrtMatrix ("COMMON" or "DML")
+	 * @param test_case   number (1-7) selecting the input matrix
+	 */
+	private void runSQRTMatrix(boolean defaultProb, ExecType instType, String strategy, int test_case) {
+		Types.ExecMode platformOld = setExecMode(instType);
+
+		try {
+			loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+			// find path to associated dml script and define parameters
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + TEST_NAME + ".dml";
+			programArgs = new String[] {"-args", input("X"), strategy, output("Y")};
+
+			// define input matrix for the matrix sqrt function according to test case
+			double[][] X = getInputMatrix(test_case);
+
+			// write the input matrix for the matrix sqrt function to the dml script
+			writeInputMatrixWithMTD("X", X, true);
+
+			// run the test dml script
+			runTest(true, false, null, -1);
+
+			// read the result matrix from the dml script output Y
+			HashMap<MatrixValue.CellIndex, Double> actual_Y = readDMLMatrixFromOutputDir("Y");
+
+			// create a HashMap with the values of the input matrix X to compare to the received output matrix
+			HashMap<MatrixValue.CellIndex, Double> expected_Y = new HashMap<>();
+			for (int r = 0; r < X.length; r++) {
+				for (int c = 0; c < X[0].length; c++) {
+					expected_Y.put(new MatrixValue.CellIndex(r + 1, c + 1), X[r][c]);
+				}
+			}
+
+			// compare the expected matrix (the input matrix X) with the received output matrix Y,
+			// which should be (SQRT_MATRIX(X))^2 = X again
+			TestUtils.compareMatrices(expected_Y, actual_Y, eps, "Expected-DML", "Actual-DML");
+		}
+		finally {
+			resetExecMode(platformOld);
+		}
+	}
+
+	/**
+	 * Returns the fixed input matrix for the given test case.
+	 *
+	 * @param testCase number (1-7) selecting the input matrix
+	 * @return the input matrix as a dense row-major array
+	 * @throws IllegalArgumentException if the test case number is unknown
+	 */
+	private static double[][] getInputMatrix(int testCase) {
+		switch(testCase) {
+			case 1: // square matrix of dimension 1x1 (PSD)
+				return new double[][] {
+					{4}
+				};
+			case 2: // symmetric all-ones matrix of dimension 2x2 (PSD, singular);
+				// note: not upper triangular, despite the name of the tests using it
+				return new double[][] {
+					{1, 1},
+					{1, 1},
+				};
+			case 3: // identity matrix (diagonal, PSD) of dimension 2x2
+				return new double[][] {
+					{1, 0},
+					{0, 1},
+				};
+			case 4: // PSD matrix of dimension 2x2
+				// generated by taking A(A^T) of matrix A = [[1, 0], [2, 3]]
+				return new double[][] {
+					{1, 2},
+					{2, 13}
+				};
+			case 5: // PSD matrix of dimension 3x3
+				// generated by taking A(A^T) of matrix A =
+				// [[1.5, 0, 1.2],
+				//  [2.2, 3.8, 4.4],
+				//  [4.2, 6.1, 0.2]]
+				return new double[][] {
+					{3.69, 8.58, 6.54},
+					{8.58, 38.64, 33.30},
+					{6.54, 33.3, 54.89}
+				};
+			case 6: // PSD matrix of dimension 4x4
+				// generated by taking A(A^T) of matrix A =
+				// [[1, 0, 5, 6],
+				//  [2, 3, 0, 2],
+				//  [5, 0, 1, 1],
+				//  [2, 3, 4, 8]]
+				return new double[][] {
+					{62, 14, 16, 70},
+					{14, 17, 12, 29},
+					{16, 12, 27, 22},
+					{70, 29, 22, 93}
+				};
+			case 7: // PSD matrix of dimension 8x8
+				// corresponds (up to rounding) to (A^T)A of the symmetric matrix A =
+				// [[ 8.41557894,  3.44748042, 1.44911908,  4.95381036,  4.42875187, 4.14710712, -0.42719386, 6.1366026 ],
+				//  [ 3.44748042, 11.38083039, 4.99475137,  3.36734826,  4.08943809, 4.23308448,  4.50030176, 3.92552912],
+				//  [ 1.44911908,  4.99475137, 9.78651357,  4.00347878,  4.60244914, 4.24468227,  3.62945751, 6.54033601],
+				//  [ 4.95381036,  3.36734826, 4.00347878, 12.75936071,  3.78643598, 1.99998784,  5.41689723, 7.9756991 ],
+				//  [ 4.42875187,  4.08943809, 4.60244914,  3.78643598, 12.49158813, 6.69560056,  3.87176913, 5.5028702 ],
+				//  [ 4.14710712,  4.23308448, 4.24468227,  1.99998784,  6.69560056, 7.66015758,  4.21792513, 4.53489207],
+				//  [-0.42719386,  4.50030176, 3.62945751,  5.41689723,  3.87176913, 4.21792513,  9.07079513, 2.64352781],
+				//  [ 6.1366026 ,  3.92552912, 6.54033601,  7.9756991 ,  5.5028702 , 4.53489207,  2.64352781, 8.92801728]]
+				return new double[][] {
+					{184, 150, 140, 194, 192, 153, 91, 211},
+					{150, 248, 203, 198, 216, 187, 171, 214},
+					{140, 203, 234, 212, 223, 185, 165, 237},
+					{194, 198, 212, 326, 228, 177, 190, 287},
+					{192, 216, 223, 228, 318, 239, 180, 262},
+					{153, 187, 185, 177, 239, 199, 152, 209},
+					{ 91, 171, 165, 190, 180, 152, 185, 170},
+					{211, 214, 237, 287, 262, 209, 170, 297}
+				};
+			default:
+				// fail explicitly instead of relying on a JVM assertion (-ea) to catch a null matrix
+				throw new IllegalArgumentException("Unknown test case: " + testCase);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/builtin/SQRTMatrix.dml
b/src/test/scripts/functions/builtin/SQRTMatrix.dml
new file mode 100644
index 0000000000..f14d72e6a3
--- /dev/null
+++ b/src/test/scripts/functions/builtin/SQRTMatrix.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Test driver for the matrix square root builtin: reads the input matrix ($1),
+# computes its square root with the given strategy ($2), and writes the square
+# of that root ($3). The result is correct if the written matrix equals the input.
+
+input = read($1)
+strategy = $2
+
+root = sqrtMatrix(input, strategy)
+squared = root %*% root
+
+write(squared, $3);
+