This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new af937bdccf [SYSTEMDS-3520] Error if multiple encoders applied on one 
column
af937bdccf is described below

commit af937bdccf503520103e748d3fe520e6d0533801
Author: Arnab Phani <[email protected]>
AuthorDate: Wed Apr 12 14:54:56 2023 +0200

    [SYSTEMDS-3520] Error if multiple encoders applied on one column
    
    This patch adds a coherence check that detects overlaps among the
    first-level encoders (recode, binning, feature hashing) and errors
    out if more than one of them is applied to a single feature.
    Applying more than one encoder on a column can lead to incorrect
    results and crashes during multithreaded processing of those encoders.
    
    Closes #1805
---
 conf/SystemDS-config.xml.template                  |  2 +-
 .../runtime/transform/encode/EncoderFactory.java   | 13 +++++++++---
 .../apache/sysds/runtime/util/CollectionUtils.java | 23 ++++++++++++++++++++++
 .../TransformFrameBuildMultithreadedTest.java      |  9 +++++----
 .../datasets/homes3/homes.tfspec_hash_recode.json  |  2 +-
 .../datasets/homes3/homes.tfspec_hash_recode2.json |  2 +-
 6 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/conf/SystemDS-config.xml.template 
b/conf/SystemDS-config.xml.template
index d35f54dbaa..45073c349b 100644
--- a/conf/SystemDS-config.xml.template
+++ b/conf/SystemDS-config.xml.template
@@ -37,7 +37,7 @@
     <sysds.cp.parallel.io>true</sysds.cp.parallel.io>
 
     <!-- enalbe multi-threaded transformencode and apply -->
-    <sysds.parallel.encode>false</sysds.parallel.encode>
+    <sysds.parallel.encode>true</sysds.parallel.encode>
 
     <!-- synchronization barrier between transformencode build and apply -->
     <sysds.parallel.encode.staged>false</sysds.parallel.encode.staged>
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 99b241cb95..41e16d6e6e 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.transform.encode;
 
 import static org.apache.sysds.runtime.util.CollectionUtils.except;
+import static org.apache.sysds.runtime.util.CollectionUtils.intersect;
 import static org.apache.sysds.runtime.util.CollectionUtils.unionDistinct;
 
 import java.util.ArrayList;
@@ -87,8 +88,13 @@ public class EncoderFactory {
                        List<Integer> dcIDs = Arrays.asList(ArrayUtils
                                .toObject(TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol)));
                        List<Integer> binIDs = 
TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol);
-                       // note: any dummycode column requires recode as 
preparation, unless it follows binning
-                       rcIDs = except(unionDistinct(rcIDs, except(dcIDs, 
binIDs)), haIDs);
+                       // NOTE: any dummycode column requires recode as 
preparation, unless the dummycode
+                       // column follows binning or feature hashing
+                       rcIDs = unionDistinct(rcIDs, except(except(dcIDs, 
binIDs), haIDs));
+                       // Error out if the first level encoders have overlaps
+                       if (intersect(rcIDs, binIDs, haIDs))
+                               throw new DMLRuntimeException("More than one 
encoders (recode, binning, hashing) on one column is not allowed");
+
                        List<Integer> ptIDs = 
except(except(UtilFunctions.getSeqList(1, clen, 1), unionDistinct(rcIDs, 
haIDs)),
                                binIDs);
                        List<Integer> oIDs = Arrays.asList(ArrayUtils
@@ -96,7 +102,8 @@ public class EncoderFactory {
                        List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject(
                                TfMetaUtils.parseJsonObjectIDList(jSpec, 
colnames, TfMethod.IMPUTE.toString(), minCol, maxCol)));
                        List<Integer> udfIDs = 
TfMetaUtils.parseUDFColIDs(jSpec, colnames, minCol, maxCol);
-                       
+
+
                        // create individual encoders
                        if(!rcIDs.isEmpty())
                                for(Integer id : rcIDs)
diff --git a/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java 
b/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java
index a26f3b8e10..fbd35ee428 100644
--- a/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/util/CollectionUtils.java
@@ -20,7 +20,9 @@
 package org.apache.sysds.runtime.util;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
+import java.util.Comparator;
 import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
@@ -102,6 +104,27 @@ public class CollectionUtils {
                                return true;
                return false;
        }
+
+       @SafeVarargs
+       public static <T> boolean intersect(Collection<T>... inputs) {
+               //remove empty collections
+               Collection<T>[] nonEmpty = Arrays.stream(inputs).filter(l -> 
!l.isEmpty()).toArray(Collection[]::new);
+               if (nonEmpty.length == 0)
+                       return false;
+
+               //order the lists based on size (ascending)
+               Arrays.sort(nonEmpty, 
Comparator.comparingInt(Collection::size));
+               //maintain a central hash table of seen items
+               Set<T> probe = (nonEmpty[0] instanceof HashSet) ? (Set<T>) 
nonEmpty[0] : new HashSet<>(nonEmpty[0]);
+               for (int i=1; i<nonEmpty.length; i++) {
+                       for (T item : nonEmpty[i])
+                               //if the item is in the seen set, return true
+                               if (probe.contains(item))
+                                       return true;
+                       probe.addAll(inputs[i]);
+               }
+               return false;
+       }
        
        @SuppressWarnings("unchecked")
        public static <T> List<T> unionDistinct(List<T> a, List<T> b) {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameBuildMultithreadedTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameBuildMultithreadedTest.java
index ee06fc2ec7..c03212e8cf 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameBuildMultithreadedTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameBuildMultithreadedTest.java
@@ -95,10 +95,11 @@ public class TransformFrameBuildMultithreadedTest extends 
AutomatedTestBase {
                runTransformTest(Types.ExecMode.SINGLE_NODE, "csv", 
TransformType.RECODE_DUMMY, 0);
        }
 
-       @Test
-       public void testHomesBuildRecodeBinningSingleNodeCSV() {
-               runTransformTest(Types.ExecMode.SINGLE_NODE, "csv", 
TransformType.RECODE_BIN, 0);
-       }
+       // This test fails as column 1 exists in both recode and binning list.
+       //@Test
+       //public void testHomesBuildRecodeBinningSingleNodeCSV() {
+       //      runTransformTest(Types.ExecMode.SINGLE_NODE, "csv", 
TransformType.RECODE_BIN, 0);
+       //}
 
        @Test
        public void testHomesBuildBinSingleNodeCSV() {
diff --git a/src/test/resources/datasets/homes3/homes.tfspec_hash_recode.json 
b/src/test/resources/datasets/homes3/homes.tfspec_hash_recode.json
index dc0ca2efa4..fc794604a3 100644
--- a/src/test/resources/datasets/homes3/homes.tfspec_hash_recode.json
+++ b/src/test/resources/datasets/homes3/homes.tfspec_hash_recode.json
@@ -1,2 +1,2 @@
 {
-    "ids": true, "hash": [ 1, 2, 7 ], "K": 100, "recode": [ 2, 3, 6 ] }
\ No newline at end of file
+    "ids": true, "hash": [ 1, 2, 7 ], "K": 100, "recode": [ 3, 6 ] }
diff --git a/src/test/resources/datasets/homes3/homes.tfspec_hash_recode2.json 
b/src/test/resources/datasets/homes3/homes.tfspec_hash_recode2.json
index 10cf57da43..80fa04b530 100644
--- a/src/test/resources/datasets/homes3/homes.tfspec_hash_recode2.json
+++ b/src/test/resources/datasets/homes3/homes.tfspec_hash_recode2.json
@@ -1,2 +1,2 @@
 {
-    "hash": [ "zipcode", "district", "view" ], "K": 100, "recode": [ 
"zipcode", "district", "view" ] }
\ No newline at end of file
+    "hash": [ "district", "view" ], "K": 100, "recode": [ "zipcode" ] }

Reply via email to