This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new d3d3911126 [SYSTEMDS-2944] Fix incorrect nnz propagation of unary hops
d3d3911126 is described below

commit d3d3911126496342ae1daa12d93c7d1c09644025
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Jul 21 20:44:24 2023 +0200

    [SYSTEMDS-2944] Fix incorrect nnz propagation of unary hops
    
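    This patch fixes an issue identified in #1862, where, despite the
    absence of NaNs, the results of sum(X!=0) and sum(X==0) did not add
    up to the number of cells. The cause was incorrect nnz propagation
    for unary operations such as round: they were treated as
    sparsity-preserving (i.e., the input nnz was copied to the output),
    although round can turn non-zero inputs into zeros. As a result,
    sum(X!=0) was rewritten to read the (wrong) nnz from the metadata,
    while sum(X==0) was actually computed.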
---
 src/main/java/org/apache/sysds/hops/UnaryOp.java   |  7 +--
 .../test/functions/misc/NNZPropagationTest.java    | 67 ++++++++++++++++++++++
 src/test/scripts/functions/misc/nnzUnary.dml       | 25 ++++++++
 3 files changed, 94 insertions(+), 5 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index b0fb1c24f8..72f63e99ed 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -558,13 +558,10 @@ public class UnaryOp extends MultiThreadedHop
                {
                        // If output is a Matrix then this operation is of type 
(B = op(A))
                        // Dimensions of B are same as that of A, and sparsity 
may/maynot change
+                       // note: round, sin, cos can introduce new zeros for 
non-zero inputs
                        setDim1( input.getDim1() );
                        setDim2( input.getDim2() );
-                       // cosh(0)=cos(0)=1, acos(0)=1.5707963267948966
-                       if( _op==OpOp1.ABS || _op==OpOp1.SIN || _op==OpOp1.TAN  
-                               || _op==OpOp1.SINH || _op==OpOp1.TANH
-                               || _op==OpOp1.ASIN || _op==OpOp1.ATAN
-                               || _op==OpOp1.SQRT || _op==OpOp1.ROUND || 
_op==OpOp1.SPROP
+                       if( _op==OpOp1.ABS || _op==OpOp1.SQRT || 
_op==OpOp1.SPROP
                                || _op==OpOp1.COMPRESS || _op==OpOp1.DECOMPRESS 
|| _op==OpOp1.LOCAL) //sparsity preserving
                        {
                                setNnz( input.getNnz() );
diff --git a/src/test/java/org/apache/sysds/test/functions/misc/NNZPropagationTest.java b/src/test/java/org/apache/sysds/test/functions/misc/NNZPropagationTest.java
new file mode 100644
index 0000000000..73b5e255ae
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/misc/NNZPropagationTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.misc;
+
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Map;
+
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+
+public class NNZPropagationTest extends AutomatedTestBase 
+{
+       private final static String TEST_NAME1 = "nnzUnary";
+       private final static String TEST_DIR = "functions/misc/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
NNZPropagationTest.class.getSimpleName() + "/";
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"R"}));
+       }
+
+       @Test
+       public void testNnzUnary() {
+               runExistsTest(TEST_NAME1, 14967, 5033);
+       }
+       
+       private void runExistsTest(String testName, int expNNZ, int expNZ) {
+               TestConfiguration config = getTestConfiguration(testName);
+               loadTestConfiguration(config);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+               fullDMLScriptName = HOME + testName + ".dml";
+               programArgs = new String[]{"-explain","-args", output("R") };
+               
+               //run script and compare output
+               runTest(true, false, null, -1); 
+               
+               //compare results
+               Map<CellIndex, Double> ret = readDMLMatrixFromOutputDir("R");
+               Double nnonzero = ret.get(new CellIndex(1,1));
+               Double nzero = ret.get(new CellIndex(2,1));
+               Assert.assertEquals(expNNZ, nnonzero, 1e-14);
+               Assert.assertEquals(expNZ, nzero, 1e-14);
+       }
+}
diff --git a/src/test/scripts/functions/misc/nnzUnary.dml b/src/test/scripts/functions/misc/nnzUnary.dml
new file mode 100644
index 0000000000..6eb801c6dd
--- /dev/null
+++ b/src/test/scripts/functions/misc/nnzUnary.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Dense random input in [-2, 2]; round() maps values close to zero to 0,
+# so X has fewer non-zeros than its input and the input nnz must not be
+# propagated to X as-is.
+X = round(rand(rows=10000, cols=2, min=-2, max=2, seed=7))
+# R[1,1] = number of non-zeros, R[2,1] = number of zeros; the two must add
+# up to the 10000*2 = 20000 cells of X (sum(X!=0) may be answered from nnz
+# metadata, while sum(X==0) is actually computed).
+R = as.matrix(list(sum(X!=0), sum(X==0)))
+write(R, $1);
+

Reply via email to