This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 40f2f7a0536aca988a889bdebbef6f219a38aebe
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Apr 14 20:12:05 2023 +0200

    [SYSTEMDS-3522] Add missing binning decoder to transformdecode
    
    Since binning is a lossy transformation, we did not support decoding
    binned features yet. However, this creates an unnecessary burden for
    users who need to decode predictions of decisionTree/randomForest,
    which are averages of the binned representations of leaf buckets. This
    patch adds such a decoder, which maps bin codes to bin centers and
    interpolates linearly within a bin for non-integer bin codes (e.g.,
    averages of binned labels).
---
 .../sysds/runtime/transform/decode/DecoderBin.java | 133 +++++++++++++++++++++
 .../runtime/transform/decode/DecoderFactory.java   |  12 +-
 2 files changed, 142 insertions(+), 3 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java
new file mode 100644
index 0000000000..5eb43e8bde
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.decode;
+
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.UtilFunctions;
+
+/**
+ * Simple atomic decoder for binned columns. This decoder builds internally
+ * arrays of lower/upper bin boundaries from the transform meta data,
+ * accesses these boundaries in constant time for incoming bin codes, and
+ * maps each code back to the center of its bin. Non-integer bin codes
+ * (e.g., averages of binned labels) are interpolated linearly within the
+ * bin of the nearest integer code. Since binning is lossy, the decoded
+ * values only approximate the original inputs.
+ */
+public class DecoderBin extends Decoder
+{
+       private static final long serialVersionUID = -3784249774608228805L;
+
+       // a) column bin boundaries, indexed by position in _colList
+       private int[] _numBins;              // number of bins per column
+       private double[][] _binMins = null;  // lower boundary of each bin
+       private double[][] _binMaxs = null;  // upper boundary of each bin
+
+       /** Creates a binning decoder without schema or column list. */
+       public DecoderBin() {
+               super(null, null);
+       }
+
+       /**
+        * Creates a binning decoder for the given columns.
+        *
+        * @param schema  output schema (passed on to {@link Decoder})
+        * @param binCols 1-based indexes of the binned columns
+        */
+       protected DecoderBin(ValueType[] schema, int[] binCols) {
+               super(schema, binCols);
+       }
+
+       /**
+        * Decodes the binned columns of the given matrix into the output frame.
+        * For a bin code {@code val} with {@code key=round(val)} (1-based bin),
+        * the output is {@code bmin + (bmax-bmin)/2 + (val-key)*(bmax-bmin)}:
+        * integer codes yield the bin center, fractional codes stay within the
+        * bin. NaN codes are passed through as NaN.
+        *
+        * @param in  encoded matrix holding the bin codes
+        * @param out output frame, allocated to the number of input rows
+        * @return the output frame
+        * @throws DMLRuntimeException if a rounded bin code is outside [1, numBins]
+        */
+       @Override
+       public FrameBlock decode(MatrixBlock in, FrameBlock out) {
+               out.ensureAllocatedColumns(in.getNumRows());
+               for( int i=0; i<in.getNumRows(); i++ ) {
+                       for( int j=0; j<_colList.length; j++ ) {
+                               int col = _colList[j] - 1;
+                               double val = in.quickGetValue(i, col);
+                               out.set(i, col, decodeValue(val, j));
+                       }
+               }
+               return out;
+       }
+
+       /**
+        * Maps a single bin code of the j-th bin column back to a value
+        * (bin center plus the fractional offset within the bin).
+        */
+       private double decodeValue(double val, int j) {
+               if( Double.isNaN(val) )
+                       return val; // missing values stay missing
+               // range check on the long result, so the int index cannot wrap around
+               long key = Math.round(val);
+               int numBins = _binMins[j].length;
+               if( key < 1 || key > numBins )
+                       throw new DMLRuntimeException("Invalid bin code "+val+" for column "
+                               +_colList[j]+", expected a value in [1, "+numBins+"].");
+               int ix = (int) key - 1;
+               double bmin = _binMins[j][ix];
+               double bmax = _binMaxs[j][ix];
+               double width = bmax - bmin;
+               return bmin + width/2           // bin center
+                       + (val-key) * width;    // bin fractions
+       }
+
+       @Override
+       public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) {
+               // federated not supported yet
+               throw new NotImplementedException(
+                       "DecoderBin does not support sub-range (federated) decoding yet.");
+       }
+
+       /**
+        * Initializes the bin boundaries from the transform meta data frame.
+        * For each bin column, the number of bins is taken from the column
+        * metadata (numDistinct), and row i of that column is split via
+        * {@link UtilFunctions#splitRecodeEntry} into the lower and upper
+        * boundary of bin i+1.
+        *
+        * @param meta transform meta data frame
+        * @throws DMLRuntimeException if a null entry is found before the last
+        *         expected bin
+        */
+       @Override
+       public void initMetaData(FrameBlock meta) {
+               //initialize bin boundaries
+               _numBins = new int[_colList.length];
+               _binMins = new double[_colList.length][];
+               _binMaxs = new double[_colList.length][];
+
+               //parse and insert bin boundaries
+               for( int j=0; j<_colList.length; j++ ) {
+                       int col = _colList[j] - 1;
+                       int numBins = (int)meta.getColumnMetadata(col).getNumDistinct();
+                       _numBins[j] = numBins; // required by writeExternal
+                       _binMins[j] = new double[numBins];
+                       _binMaxs[j] = new double[numBins];
+                       for( int i=0; i<meta.getNumRows() && i<numBins; i++ ) {
+                               Object entry = meta.get(i, col);
+                               if( entry == null ) {
+                                       // NOTE(review): a missing last bin (or a meta frame with fewer
+                                       // rows than numBins) leaves those boundaries at 0 -- confirm
+                                       if( i+1 < numBins )
+                                               throw new DMLRuntimeException("Did not reach number of bins: "+(i+1)+"/"+numBins);
+                                       break; //reached end of bins
+                               }
+                               String[] parts = UtilFunctions.splitRecodeEntry(entry.toString());
+                               _binMins[j][i] = Double.parseDouble(parts[0]);
+                               _binMaxs[j][i] = Double.parseDouble(parts[1]);
+                       }
+               }
+       }
+
+       /**
+        * Serializes the base decoder state followed by, for each bin column,
+        * the number of bins and the interleaved (min, max) boundaries.
+        */
+       @Override
+       public void writeExternal(ObjectOutput out) throws IOException {
+               super.writeExternal(out);
+               for( int i=0; i<_colList.length; i++ ) {
+                       int len = _numBins[i];
+                       out.writeInt(len);
+                       for(int j=0; j<len; j++) {
+                               out.writeDouble(_binMins[i][j]);
+                               out.writeDouble(_binMaxs[i][j]);
+                       }
+               }
+       }
+
+       /**
+        * Reads the state written by {@link #writeExternal}, allocating the
+        * boundary arrays of each bin column before filling them.
+        */
+       @Override
+       public void readExternal(ObjectInput in) throws IOException {
+               super.readExternal(in);
+               _numBins = new int[_colList.length];
+               _binMins = new double[_colList.length][];
+               _binMaxs = new double[_colList.length][];
+               for( int i=0; i<_colList.length; i++ ) {
+                       int len = in.readInt();
+                       _numBins[i] = len;
+                       _binMins[i] = new double[len];
+                       _binMaxs[i] = new double[len];
+                       for(int j=0; j<len; j++) {
+                               _binMins[i][j] = in.readDouble();
+                               _binMaxs[i][j] = in.readDouble();
+                       }
+               }
+       }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
index a2fe499459..df1d4381c9 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
@@ -37,9 +37,10 @@ import static org.apache.sysds.runtime.util.CollectionUtils.unionDistinct;
 public class DecoderFactory 
 {
        public enum DecoderType {
+               Bin,
                Dummycode, 
                PassThrough,
-               Recode
+               Recode,
        };
        
        public static Decoder createDecoder(String spec, String[] colnames, 
ValueType[] schema, FrameBlock meta) {
@@ -65,14 +66,15 @@ public class DecoderFactory
                        JSONObject jSpec = new JSONObject(spec);
                        List<Decoder> ldecoders = new ArrayList<>();
                        
-                       //create decoders 'recode', 'dummy' and 'pass-through'
+                       //create decoders 'bin', 'recode', 'dummy' and 
'pass-through'
+                       List<Integer> binIDs = 
TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol);
                        List<Integer> rcIDs = Arrays.asList(ArrayUtils.toObject(
                                        TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.RECODE.toString(), minCol, maxCol)));
                        List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject(
                                        TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol)));
                        rcIDs = unionDistinct(rcIDs, dcIDs);
                        int len = dcIDs.isEmpty() ? 
Math.min(meta.getNumColumns(), clen) : meta.getNumColumns();
-                       List<Integer> ptIDs = 
except(UtilFunctions.getSeqList(1, len, 1), rcIDs);
+                       List<Integer> ptIDs = 
except(except(UtilFunctions.getSeqList(1, len, 1), rcIDs), binIDs);
                        
                        //create default schema if unspecified (with double 
columns for pass-through)
                        if( schema == null ) {
@@ -81,6 +83,10 @@ public class DecoderFactory
                                        schema[col-1] = ValueType.FP64;
                        }
                        
+                       if( !binIDs.isEmpty() ) {
+                               ldecoders.add(new DecoderBin(schema, 
+                                       
ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0]))));
+                       }
                        if( !dcIDs.isEmpty() ) {
                                ldecoders.add(new DecoderDummycode(schema, 
                                        
ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0]))));

Reply via email to