This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 2e68ad3b1a [SYSTEMDS-3860] Extended codegen row template by var 
aggregates
2e68ad3b1a is described below

commit 2e68ad3b1acd4892ed782d01dbfccffc61d2f680
Author: Frxms <tomoki.men...@gmail.com>
AuthorDate: Fri Apr 18 11:13:02 2025 +0200

    [SYSTEMDS-3860] Extended codegen row template with row-variance (rowVars) aggregates
    
    Closes #2244.
---
 .../sysds/hops/codegen/cplan/CNodeUnary.java       |  4 ++-
 .../sysds/hops/codegen/cplan/java/Unary.java       |  2 +-
 .../sysds/hops/codegen/template/TemplateRow.java   |  2 +-
 .../sysds/runtime/codegen/LibSpoofPrimitives.java  | 14 +++++++-
 .../sysds/test/component/misc/DMLScriptTest.java   |  1 -
 .../functions/builtin/part2/BuiltinMDTest.java     |  2 --
 .../test/functions/codegen/RowAggTmplTest.java     | 38 ++++++++++++++++++++--
 .../scripts/functions/codegen/rowAggPattern47.R    | 36 ++++++++++++++++++++
 .../scripts/functions/codegen/rowAggPattern47.dml  | 29 +++++++++++++++++
 .../scripts/functions/codegen/rowAggPattern48.R    | 36 ++++++++++++++++++++
 .../scripts/functions/codegen/rowAggPattern48.dml  | 30 +++++++++++++++++
 11 files changed, 184 insertions(+), 10 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java
index 93cdb2f661..fe67995b6b 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeUnary.java
@@ -33,7 +33,7 @@ public class CNodeUnary extends CNode
        public enum UnaryType {
                LOOKUP_R, LOOKUP_C, LOOKUP_RC, LOOKUP0, //codegen specific
                ROW_SUMS, ROW_SUMSQS, ROW_COUNTNNZS, //codegen specific
-               ROW_MEANS, ROW_MINS, ROW_MAXS,
+               ROW_MEANS, ROW_MINS, ROW_MAXS, ROW_VARS,
                VECT_EXP, VECT_POW2, VECT_MULT2, VECT_SQRT, VECT_LOG,
                VECT_ABS, VECT_ROUND, VECT_CEIL, VECT_FLOOR, VECT_SIGN, 
                VECT_SIN, VECT_COS, VECT_TAN, VECT_ASIN, VECT_ACOS, VECT_ATAN, 
@@ -139,6 +139,7 @@ public class CNodeUnary extends CNode
                        case ROW_MINS:   return "u(Rmin)";
                        case ROW_MAXS:   return "u(Rmax)";
                        case ROW_MEANS:  return "u(Rmean)";
+                       case ROW_VARS:   return "u(Rvar)";
                        case ROW_COUNTNNZS: return "u(Rnnz)";
                        case VECT_EXP:
                        case VECT_POW2:
@@ -210,6 +211,7 @@ public class CNodeUnary extends CNode
                        case ROW_MINS:
                        case ROW_MAXS:
                        case ROW_MEANS:
+                       case ROW_VARS:
                        case ROW_COUNTNNZS:
                        case EXP:
                        case LOOKUP_R:
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
index 50ea2bace8..d8a1085df5 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/java/Unary.java
@@ -32,12 +32,12 @@ public class Unary extends CodeTemplate {
                        case ROW_MINS:
                        case ROW_MAXS:
                        case ROW_MEANS:
+                       case ROW_VARS:
                        case ROW_COUNTNNZS: {
                                String vectName = 
StringUtils.capitalize(type.name().substring(4, 
type.name().length()-1).toLowerCase());
                                return sparse ? "    double %TMP% = 
LibSpoofPrimitives.vect"+vectName+"(%IN1v%, %IN1i%, %POS1%, alen, len);\n":
                                                "    double %TMP% = 
LibSpoofPrimitives.vect"+vectName+"(%IN1%, %POS1%, %LEN%);\n";
                        }
-
                        case VECT_EXP:
                        case VECT_POW2:
                        case VECT_MULT2:
diff --git 
a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java 
b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
index c42ea6c858..955bf778b8 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateRow.java
@@ -67,7 +67,7 @@ import org.apache.sysds.runtime.matrix.data.Pair;
 
 public class TemplateRow extends TemplateBase 
 {
-       private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, 
AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.PROD};
+       private static final AggOp[] SUPPORTED_ROW_AGG = new AggOp[]{AggOp.SUM, 
AggOp.MIN, AggOp.MAX, AggOp.MEAN, AggOp.PROD, AggOp.VAR};
        private static final OpOp1[] SUPPORTED_VECT_UNARY = new OpOp1[]{
                OpOp1.EXP, OpOp1.SQRT, OpOp1.LOG, OpOp1.ABS, OpOp1.ROUND, 
OpOp1.CEIL, OpOp1.FLOOR, OpOp1.SIGN,
                OpOp1.SIN, OpOp1.COS, OpOp1.TAN, OpOp1.ASIN, OpOp1.ACOS, 
OpOp1.ATAN, OpOp1.SINH, OpOp1.COSH, OpOp1.TANH,
diff --git 
a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java 
b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
index 6497b6f321..6c0dc395c3 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/LibSpoofPrimitives.java
@@ -2151,7 +2151,36 @@ public class LibSpoofPrimitives
                        new DenseBlockFP64(new int[]{K, PQ}, c), PQ, CRS, 0, K, 
0, PQ);
                return c;
        } 
-       
+
+       //Sample variance (len-1 denominator) via two passes (mean, then sum of
+       //squared deviations): avoids the cancellation of sumsq - len*mean^2 and
+       //allocates no scratch vector that could alias live row intermediates.
+       //NOTE(review): len<=1 returns 0; confirm vs non-codegen rowVars.
+       public static double vectVar(double[] a, int ai, int len) {
+               if( len <= 1 )
+                       return 0;
+               double mean = vectMean(a, ai, len);
+               double sqDiff = 0;
+               for( int i = ai; i < ai+len; i++ ) {
+                       double d = a[i] - mean;
+                       sqDiff += d * d;
+               }
+               return sqDiff / (len-1);
+       }
+
+       public static double vectVar(double[] avals, int[] aix, int ai, int alen, int len) {
+               if( len <= 1 )
+                       return 0;
+               double mean = vectMean(avals, aix, ai, alen, len);
+               //each of the len-alen non-stored zeros contributes (0-mean)^2
+               double sqDiff = (len-alen) * mean * mean;
+               for( int j = ai; j < ai+alen; j++ ) {
+                       double d = avals[j] - mean;
+                       sqDiff += d * d;
+               }
+               return sqDiff / (len-1);
+       }
+
        //complex builtin functions that are not directly generated
        //(included here in order to reduce the number of imports)
        
diff --git 
a/src/test/java/org/apache/sysds/test/component/misc/DMLScriptTest.java 
b/src/test/java/org/apache/sysds/test/component/misc/DMLScriptTest.java
index 5b5483823a..4244ce7421 100644
--- a/src/test/java/org/apache/sysds/test/component/misc/DMLScriptTest.java
+++ b/src/test/java/org/apache/sysds/test/component/misc/DMLScriptTest.java
@@ -24,7 +24,6 @@ package org.apache.sysds.test.component.misc;
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
 import org.apache.log4j.spi.LoggingEvent;
-import org.apache.sysds.api.DMLOptions;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.parser.LanguageException;
 import org.apache.sysds.test.LoggingUtils;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMDTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMDTest.java
index b04d476d06..4c51602058 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMDTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinMDTest.java
@@ -90,8 +90,6 @@ public class BuiltinMDTest extends AutomatedTestBase {
        }
 
        @Test
-       //@Ignore
-       // https://issues.apache.org/jira/browse/SYSTEMDS-3716
        public void testMDSP() {
                double[][] D =  {
                        {7567, 231, 1231, 1232, 122, 321},
diff --git 
a/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java 
b/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java
index d3c9edf8e8..b89f3007b4 100644
--- a/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/codegen/RowAggTmplTest.java
@@ -87,7 +87,9 @@ public class RowAggTmplTest extends AutomatedTestBase
        private static final String TEST_NAME44 = TEST_NAME+"44"; //maxpool(X - 
mean(X)) + 7;
        private static final String TEST_NAME45 = TEST_NAME+"45"; //vector 
allocation;
        private static final String TEST_NAME46 = TEST_NAME+"46"; //conv2d(X - 
mean(X), F1) + conv2d(X - mean(X), F2);
-       
+       private static final String TEST_NAME47 = TEST_NAME+"47"; //sum(X + 
rowVars(X))
+       private static final String TEST_NAME48 = TEST_NAME+"48"; 
//sum(rowVars(X))
+
        private static final String TEST_DIR = "functions/codegen/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
RowAggTmplTest.class.getSimpleName() + "/";
        private final static String TEST_CONF = "SystemDS-config-codegen.xml";
@@ -98,7 +100,7 @@ public class RowAggTmplTest extends AutomatedTestBase
        @Override
        public void setUp() {
                TestUtils.clearAssertionInformation();
-               for(int i=1; i<=46; i++)
+               for(int i=1; i<=48; i++)
                        addTestConfiguration( TEST_NAME+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME+i, new String[] { String.valueOf(i) 
}) );
        }
        
@@ -795,6 +797,36 @@ public class RowAggTmplTest extends AutomatedTestBase
                testCodegenIntegration( TEST_NAME46, false, ExecType.SPARK );
        }
 
+       @Test
+       public void testCodegenRowAggRewrite47CP() {
+               testCodegenIntegration( TEST_NAME47, true, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenRowAgg47CP() {
+               testCodegenIntegration( TEST_NAME47, false, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenRowAgg47SP() {
+               testCodegenIntegration( TEST_NAME47, false, ExecType.SPARK );
+       }
+
+       @Test
+       public void testCodegenRowAggRewrite48CP() {
+               testCodegenIntegration( TEST_NAME48, true, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenRowAgg48CP() {
+               testCodegenIntegration( TEST_NAME48, false, ExecType.CP );
+       }
+
+       @Test
+       public void testCodegenRowAgg48SP() {
+               testCodegenIntegration( TEST_NAME48, false, ExecType.SPARK );
+       }
+
        private void testCodegenIntegration( String testname, boolean rewrites, 
ExecType instType )
        {
                boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
@@ -807,7 +839,7 @@ public class RowAggTmplTest extends AutomatedTestBase
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
                        fullDMLScriptName = HOME + testname + ".dml";
-                       programArgs = new String[]{"-stats", "-args", 
output("S") };
+                       programArgs = new String[]{"-explain", "codegen", 
"-stats", "-args", output("S") };
                        
                        fullRScriptName = HOME + testname + ".R";
                        rCmd = getRCmd(inputDir(), expectedDir());
diff --git a/src/test/scripts/functions/codegen/rowAggPattern47.R 
b/src/test/scripts/functions/codegen/rowAggPattern47.R
new file mode 100644
index 0000000000..9d6d4bc9f6
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern47.R
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+# matrixStats::rowVars computes the sample variance (n-1 denominator)
+
+X = matrix(seq(7, 50*10+6), 50, 10, byrow=TRUE);
+
+R = as.matrix(sum(X + rowVars(X)));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
diff --git a/src/test/scripts/functions/codegen/rowAggPattern47.dml 
b/src/test/scripts/functions/codegen/rowAggPattern47.dml
new file mode 100644
index 0000000000..e3ee077fc1
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern47.dml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = matrix(seq(7, 50*10+6), 50, 10);
+z = seq(1,50)
+
+while(FALSE){}
+
+R = as.matrix(sum(X + rowVars(X)));
+
+write(R, $1)
diff --git a/src/test/scripts/functions/codegen/rowAggPattern48.R 
b/src/test/scripts/functions/codegen/rowAggPattern48.R
new file mode 100644
index 0000000000..bec1427d61
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern48.R
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+args<-commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+library("matrixStats")
+
+# matrixStats::rowVars computes the sample variance (n-1 denominator)
+
+Z = matrix(seq(1,10), 1, 10)
+Y = matrix(0, 10, 10)
+X = rbind(Y, Z, Y)
+
+R = as.matrix(sum(rowVars(X)));
+
+writeMM(as(R, "CsparseMatrix"), paste(args[2], "S", sep=""));
diff --git a/src/test/scripts/functions/codegen/rowAggPattern48.dml 
b/src/test/scripts/functions/codegen/rowAggPattern48.dml
new file mode 100644
index 0000000000..c367a359cd
--- /dev/null
+++ b/src/test/scripts/functions/codegen/rowAggPattern48.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+Z = matrix(seq(1,10), 1, 10)
+Y = matrix(0, 10, 10)
+X = rbind(Y, Z, Y)
+
+while(FALSE){}
+
+R = as.matrix(sum(rowVars(X)));
+
+write(R, $1)
\ No newline at end of file

Reply via email to