This is an automated email from the ASF dual-hosted git repository.

estrauss pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 8f4dba14f6 [SYSTEMDS-3801] Fix missing method implementations in 
ColGroupSDCZeros
8f4dba14f6 is described below

commit 8f4dba14f6f4f94ad34de559d2d72168868fcc2d
Author: e-strauss <[email protected]>
AuthorDate: Tue Dec 3 00:58:31 2024 +0100

    [SYSTEMDS-3801] Fix missing method implementations in ColGroupSDCZeros
    
    The previous master version broke the AWARE experiment for the kmeans+ 
algorithm. This patch fixes that and adds the missing method implementations for 
DenseBlocks in ColGroupSDCZeros.
    
    With these changes, the runtime of the kmeans+ algorithm on the US Census 
dataset also decreased from 40s to 32s.
    
    Closes #2149.
---
 .../runtime/compress/colgroup/ColGroupSDC.java     |  4 +--
 .../compress/colgroup/ColGroupSDCZeros.java        | 38 +++++++++++++++++++---
 .../compress/colgroup/dictionary/ADictionary.java  | 12 ++++++-
 .../compress/colgroup/dictionary/IDictionary.java  | 15 ++++++++-
 .../colgroup/dictionary/PlaceHolderDict.java       |  8 ++++-
 .../compress/dictionary/PlaceHolderDictTest.java   |  9 +++--
 6 files changed, 74 insertions(+), 12 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
index ea4f2fb581..e78bea93a2 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java
@@ -683,7 +683,7 @@ public class ColGroupSDC extends ASDC implements 
IMapToDataGroup {
                        }
                        else {
                                while(c < points.length && points[c].o == of) {
-                                       _dict.put(sr, 
_data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
+                                       _dict.putSparse(sr, 
_data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
                                        c++;
                                }
                                of = it.next();
@@ -696,7 +696,7 @@ public class ColGroupSDC extends ASDC implements 
IMapToDataGroup {
                }
 
                while(of == last && c < points.length && points[c].o == of) {
-                       _dict.put(sr, _data.getIndex(it.getDataIndex()), 
points[c].r, nCol, _colIndexes);
+                       _dict.putSparse(sr, _data.getIndex(it.getDataIndex()), 
points[c].r, nCol, _colIndexes);
                        c++;
                }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
index c1e081f253..d250969a6a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java
@@ -836,7 +836,7 @@ public class ColGroupSDCZeros extends ASDCZero implements 
IMapToDataGroup {
 
                while(of < last && c < points.length) {
                        if(points[c].o == of) {
-                               c = processRow(points, sr, nCol, c, of, 
_data.getIndex(it.getDataIndex()));
+                               c = processRowSparse(points, sr, nCol, c, of, 
_data.getIndex(it.getDataIndex()));
                                of = it.next();
                        }
                        else if(points[c].o < of)
@@ -848,18 +848,46 @@ public class ColGroupSDCZeros extends ASDCZero implements 
IMapToDataGroup {
                while(c < points.length && points[c].o < last)
                        c++;
 
-               c = processRow(points, sr, nCol, c, of, 
_data.getIndex(it.getDataIndex()));
+               c = processRowSparse(points, sr, nCol, c, of, 
_data.getIndex(it.getDataIndex()));
 
        }
 
        @Override
        protected void denseSelection(MatrixBlock selection, P[] points, 
MatrixBlock ret, int rl, int ru) {
-               throw new NotImplementedException();
+               final DenseBlock dr = ret.getDenseBlock();
+               final int nCol = _colIndexes.size();
+               final AIterator it = _indexes.getIterator();
+               final int last = _indexes.getOffsetToLast();
+               int c = 0;
+               int of = it.value();
+
+               while(of < last && c < points.length) {
+                       if(points[c].o == of) {
+                               c = processRowDense(points, dr, nCol, c, of, 
_data.getIndex(it.getDataIndex()));
+                               of = it.next();
+                       }
+                       else if(points[c].o < of)
+                                       c++;
+                       else
+                               of = it.next();
+                       }
+                       // increment the c pointer until it is pointing at 
least to last point or is done.
+                       while(c < points.length && points[c].o < last)
+                               c++;
+                       c = processRowDense(points, dr, nCol, c, of, 
_data.getIndex(it.getDataIndex()));
+       }
+
+       private int processRowSparse(P[] points, final SparseBlock sr, final 
int nCol, int c, int of, final int did) {
+               while(c < points.length && points[c].o == of) {
+                       _dict.putSparse(sr, did, points[c].r, nCol, 
_colIndexes);
+                       c++;
+               }
+               return c;
        }
 
-       private int processRow(P[] points, final SparseBlock sr, final int 
nCol, int c, int of, final int did) {
+       private int processRowDense(P[] points, final DenseBlock dr, final int 
nCol, int c, int of, final int did) {
                while(c < points.length && points[c].o == of) {
-                       _dict.put(sr, did, points[c].r, nCol, _colIndexes);
+                       _dict.putDense(dr, did, points[c].r, nCol, _colIndexes);
                        c++;
                }
                return c;
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
index d41e2675f5..7d88573e3a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.compress.colgroup.dictionary;
 import java.io.Serializable;
 
 import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
@@ -87,8 +88,17 @@ public abstract class ADictionary implements IDictionary, 
Serializable {
        }
 
        @Override
-       public void put(SparseBlock sb, int idx, int rowOut, int nCol, 
IColIndex columns) {
+       public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, 
IColIndex columns) {
                for(int i = 0; i < nCol; i++)
                        sb.append(rowOut, columns.get(i), getValue(idx, i, 
nCol));
        }
+
+       @Override
+       public void putDense(DenseBlock dr, int idx, int rowOut, int nCol, 
IColIndex columns) {
+               double[] dv = dr.values(rowOut);
+               int off = dr.pos(rowOut);
+               for(int i = 0; i < nCol; i++)
+                       dv[off + columns.get(i)] += getValue(idx, i, nCol);
+       }
+
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
index a7a74775be..bfe4ef23c3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
@@ -25,6 +25,7 @@ import java.io.IOException;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -989,6 +990,18 @@ public interface IDictionary {
         * @param nCol    The number of columns in the dictionary
         * @param columns The columns to output into.
         */
-       public void put(SparseBlock sb, int idx, int rowOut, int nCol, 
IColIndex columns);
+       public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, 
IColIndex columns);
+
+       /**
+        * Put the row specified into the dense block, adding the values into the cells of the output row.
+        *
+        * @param db      The dense block to put into
+        * @param idx     The dictionary index to put in.
+        * @param rowOut  The row in the dense block to put it into
+        * @param nCol    The number of columns in the dictionary
+        * @param columns The columns to output into.
+        */
+       public void putDense(DenseBlock db, int idx, int rowOut, int nCol, 
IColIndex columns);
+
 
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
index 88a7be2619..68a3fb3fac 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
@@ -25,6 +25,7 @@ import java.io.IOException;
 import java.io.Serializable;
 
 import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.data.DenseBlock;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -526,7 +527,12 @@ public class PlaceHolderDict implements IDictionary, 
Serializable {
        }
        
        @Override
-       public void put(SparseBlock sb, int idx, int rowOut, int nCol, 
IColIndex columns) {
+       public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, 
IColIndex columns) {
+               throw new RuntimeException(errMessage);
+       }
+
+       @Override
+       public void putDense(DenseBlock sb, int idx, int rowOut, int nCol, 
IColIndex columns) {
                throw new RuntimeException(errMessage);
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java
 
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java
index 88e5d8adcc..5a112a800c 100644
--- 
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java
@@ -490,8 +490,13 @@ public class PlaceHolderDictTest {
        }
 
        @Test(expected = Exception.class)
-       public void put() {
-               d.put(null, 1, 1, 1, null);
+       public void putDense() {
+               d.putDense(null, 1, 1, 1, null);
+       }
+
+       @Test(expected = Exception.class)
+       public void putSparse() {
+               d.putSparse(null, 1, 1, 1, null);
        }
 
        @Test

Reply via email to