This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 40f2f7a0536aca988a889bdebbef6f219a38aebe Author: Matthias Boehm <[email protected]> AuthorDate: Fri Apr 14 20:12:05 2023 +0200 [SYSTEMDS-3522] Add missing binning decoder to transformdecode Since binning is a lossy transformation, we did not support decoding binned features yet. However, this creates an unnecessary burden for users having to decode predictions of decisionTree/randomForest which are averaged binned representations of leaf buckets. This patch adds this decoder with smooth interpolation of bin centers for non-integer bin codes (e.g., average of binned labels). --- .../sysds/runtime/transform/decode/DecoderBin.java | 133 +++++++++++++++++++++ .../runtime/transform/decode/DecoderFactory.java | 12 +- 2 files changed, 142 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java new file mode 100644 index 0000000000..5eb43e8bde --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.transform.decode; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.UtilFunctions; + +/** + * Simple atomic decoder for binned columns. This decoder builds internally + * arrays of lower/upper bin boundaries, accesses these boundaries in + * constant time for incoming values and + * + */ +public class DecoderBin extends Decoder +{ + private static final long serialVersionUID = -3784249774608228805L; + + // a) column bin boundaries + private int[] _numBins; + private double[][] _binMins = null; + private double[][] _binMaxs = null; + + public DecoderBin() { + super(null, null); + } + + protected DecoderBin(ValueType[] schema, int[] binCols) { + super(schema, binCols); + } + + @Override + public FrameBlock decode(MatrixBlock in, FrameBlock out) { + out.ensureAllocatedColumns(in.getNumRows()); + for( int i=0; i<in.getNumRows(); i++ ) { + for( int j=0; j<_colList.length; j++ ) { + double val = in.quickGetValue(i, _colList[j]-1); + int key = (int) Math.round(val); + double bmin = _binMins[j][key-1]; + double bmax = _binMaxs[j][key-1]; + double oval = bmin + (bmax-bmin)/2 // bin center + + (val-key) * (bmax-bmin); // bin fractions + out.set(i, _colList[j]-1, oval); + } + } + return out; + } + + @Override + public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { + // federated not supported yet + throw new NotImplementedException(); + } + + @Override + public void initMetaData(FrameBlock meta) { + //initialize bin boundaries + _numBins = new int[_colList.length]; + _binMins = new double[_colList.length][]; + _binMaxs = new double[_colList.length][]; + + //parse and insert bin boundaries + for( int j=0; j<_colList.length; j++ ) { + int numBins = (int)meta.getColumnMetadata(_colList[j]-1).getNumDistinct(); + _binMins[j] = new double[numBins]; + _binMaxs[j] = new double[numBins]; + for( int i=0; i<meta.getNumRows() & i<numBins; i++ ) { + if( meta.get(i, _colList[j]-1)==null ) { + if( i+1 < numBins ) + throw new DMLRuntimeException("Did not reach number of bins: "+(i+1)+"/"+numBins); + break; //reached end of bins + } + String[] parts = UtilFunctions.splitRecodeEntry( + meta.get(i, _colList[j]-1).toString()); + _binMins[j][i] = Double.parseDouble(parts[0]); + _binMaxs[j][i] = Double.parseDouble(parts[1]); + } + } + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + super.writeExternal(out); + for( int i=0; i<_colList.length; i++ ) { + int len = _numBins[i]; + out.writeInt(len); + for(int j=0; j<len; j++) { + out.writeDouble(_binMins[i][j]); + out.writeDouble(_binMaxs[i][j]); + } + } + } + + @Override + public void readExternal(ObjectInput in) throws IOException { + super.readExternal(in); + _numBins = new int[_colList.length]; + _binMins = new double[_colList.length][]; + _binMaxs = new double[_colList.length][]; + for( int i=0; i<_colList.length; i++ ) { + int len = in.readInt(); + _numBins[i] = len; + for(int j=0; j<len; j++) { + _binMins[i][j] = in.readDouble(); + _binMaxs[i][j] = in.readDouble(); + } + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java index a2fe499459..df1d4381c9 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java @@ -37,9 +37,10 @@ import static org.apache.sysds.runtime.util.CollectionUtils.unionDistinct; public class DecoderFactory { public enum DecoderType { + Bin, Dummycode, PassThrough, - Recode + Recode, }; public static Decoder createDecoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta) { @@ -65,14 +66,15 @@ public class DecoderFactory JSONObject jSpec = new JSONObject(spec); List<Decoder> ldecoders = new ArrayList<>(); - //create decoders 'recode', 'dummy' and 'pass-through' + //create decoders 'bin', 'recode', 'dummy' and 'pass-through' + List<Integer> binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol); List<Integer> rcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol))); List<Integer> dcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol))); rcIDs = unionDistinct(rcIDs, dcIDs); int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns(); - List<Integer> ptIDs = except(UtilFunctions.getSeqList(1, len, 1), rcIDs); + List<Integer> ptIDs = except(except(UtilFunctions.getSeqList(1, len, 1), rcIDs), binIDs); //create default schema if unspecified (with double columns for pass-through) if( schema == null ) { @@ -81,6 +83,10 @@ public class DecoderFactory schema[col-1] = ValueType.FP64; } + if( !binIDs.isEmpty() ) { + ldecoders.add(new DecoderBin(schema, + ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])))); + } if( !dcIDs.isEmpty() ) { ldecoders.add(new DecoderDummycode(schema, ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0]))));
