This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 74b9d86f1f86ccc5636e433fa51acf17e5c007a4
Author: anuunchin <88698977+anuunc...@users.noreply.github.com>
AuthorDate: Thu May 15 11:51:00 2025 +0200

    [SYSTEMDS-3780] Compression-fused Quantization
    
    This commit adds a new fused operator that both quantizes an input and
    compresses the result. The operator does not allocate the intermediate
    quantized matrix.
    
    Closes #2226
---
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../org/apache/sysds/common/InstructionType.java   |   1 +
 src/main/java/org/apache/sysds/common/Opcodes.java |   1 +
 src/main/java/org/apache/sysds/common/Types.java   |   5 +-
 .../java/org/apache/sysds/hops/OptimizerUtils.java |   9 +
 .../apache/sysds/hops/rewrite/ProgramRewriter.java |   4 +-
 .../RewriteQuantizationFusedCompression.java       | 125 ++++++++
 .../sysds/parser/BuiltinFunctionExpression.java    |  30 +-
 .../org/apache/sysds/parser/DMLTranslator.java     |   3 +
 .../compress/CompressedMatrixBlockFactory.java     |  43 ++-
 .../runtime/compress/CompressionSettings.java      |  10 +-
 .../compress/CompressionSettingsBuilder.java       |  16 +-
 .../runtime/compress/bitmap/BitmapEncoder.java     | 194 +++++++++---
 .../runtime/compress/colgroup/ColGroupFactory.java | 306 +++++++++++++++---
 .../compress/colgroup/ColGroupUncompressed.java    | 112 ++++++-
 .../sysds/runtime/compress/estim/ComEstExact.java  |   4 +-
 .../sysds/runtime/compress/estim/ComEstSample.java |   2 +-
 .../compress/estim/CompressedSizeInfoColGroup.java |   4 +
 .../compress/estim/encoding/EncodingFactory.java   | 315 ++++++++++++++++++-
 .../compress/readers/ReaderColumnSelection.java    |  63 +++-
 ...erColumnSelectionDenseSingleBlockQuantized.java |  50 +++
 .../runtime/instructions/CPInstructionParser.java  |   6 +-
 .../runtime/instructions/InstructionParser.java    |   3 +-
 .../runtime/instructions/cp/CPInstruction.java     |   1 +
 .../instructions/cp/CompressionCPInstruction.java  |  77 +++++
 .../qcompress/CompareCompressionTypeTest.java      | 149 +++++++++
 .../QuantizationFusedCompressionTest.java          | 148 +++++++++
 ...uantizationFusedForcedCompressionTypesTest.java | 348 +++++++++++++++++++++
 .../component/compress/readers/ReadersTest.java    |  44 +++
 .../RewriteQuantizationFusedCompressionTest.java   | 149 +++++++++
 .../RewriteQuantizationFusedCompressionMatrix.dml  |  37 +++
 .../RewriteQuantizationFusedCompressionScalar.dml  |  37 +++
 32 files changed, 2183 insertions(+), 114 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 971e43a14b..423679d038 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -89,6 +89,7 @@ public enum Builtins { COLVAR("colVars", false), COMPONENTS("components", true), COMPRESS("compress", false, ReturnType.MULTI_RETURN), + QUANTIZE_COMPRESS("quantize_compress", false, ReturnType.MULTI_RETURN), CONFUSIONMATRIX("confusionMatrix", true), CONV2D("conv2d", false), CONV2D_BACKWARD_FILTER("conv2d_backward_filter", false), diff --git a/src/main/java/org/apache/sysds/common/InstructionType.java b/src/main/java/org/apache/sysds/common/InstructionType.java index 41c12f5c2c..bbf55f2b6d 100644 --- a/src/main/java/org/apache/sysds/common/InstructionType.java +++ b/src/main/java/org/apache/sysds/common/InstructionType.java @@ -50,6 +50,7 @@ public enum InstructionType { Partition, Compression, DeCompression, + QuantizeCompression, SpoofFused, Prefetch, EvictLineageCache, diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index dafd771c6b..f0bccef77b
100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -296,6 +296,7 @@ public enum Opcodes { PARTITION("partition", InstructionType.Partition), COMPRESS(Compression.OPCODE, InstructionType.Compression, InstructionType.Compression), DECOMPRESS(DeCompression.OPCODE, InstructionType.DeCompression, InstructionType.DeCompression), + QUANTIZE_COMPRESS("quantize_compress", InstructionType.QuantizeCompression), SPOOF("spoof", InstructionType.SpoofFused), PREFETCH("prefetch", InstructionType.Prefetch), EVICT("_evict", InstructionType.EvictLineageCache), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index e6d5f8e945..e69ad375b2 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -634,8 +634,9 @@ public interface Types { //fused ML-specific operators for performance MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=)) LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5) - MINUS1_MULT(false); //1-X*Y - + MINUS1_MULT(false), //1-X*Y + QUANTIZE_COMPRESS(false); //quantization-fused compression + private final boolean _validOuter; private OpOp2(boolean outer) { diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java index 73dd4a80a3..9ba3ea3ed7 100644 --- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java @@ -280,6 +280,15 @@ public class OptimizerUtils */ public static boolean ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true; + /** + * This variable allows for insertion of Quantize and compress in the dml script from the user. + */ + public static boolean ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND = true; + + /** + * Boolean specifying if quantization-fused compression rewrite is allowed. + */ + public static boolean ALLOW_QUANTIZE_COMPRESS_REWRITE = true; /** * Boolean specifying if compression rewrites is allowed. This is disabled at run time if the IPA for Workload aware compression diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index b08d836efe..874ddae034 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -90,7 +90,9 @@ public class ProgramRewriter{ if( OptimizerUtils.ALLOW_AUTO_VECTORIZATION ) _dagRuleSet.add( new RewriteIndexingVectorization() ); //dependency: cse, simplifications _dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock - + if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE ) + _dagRuleSet.add( new RewriteQuantizationFusedCompression() ); + //add statement block rewrite rules if( OptimizerUtils.ALLOW_BRANCH_REMOVAL ) _sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java new file mode 100644 index 0000000000..f29d1dce81 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteQuantizationFusedCompression.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.rewrite; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map.Entry; + +import org.apache.sysds.common.Types.OpOp1; +import org.apache.sysds.common.Types.OpOp2; +import org.apache.sysds.hops.UnaryOp; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.hops.BinaryOp; + +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.ValueType; + +import org.apache.sysds.hops.Hop; + +/** + * Rule: RewriteQuantizationFusedCompression. Detects the sequence `M2 = floor(M * S)` followed by `C = compress(M2)` and fuses it + * into a single quantize_compress operation. This rewrite improves performance by avoiding the intermediate + * quantized matrix. + */ +public class RewriteQuantizationFusedCompression extends HopRewriteRule { + @Override + public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) { + if(roots == null) + return null; + + // traverse the HOP DAG + HashMap<String, Hop> floors = new HashMap<>(); + HashMap<String, Hop> compresses = new HashMap<>(); + for(Hop h : roots) + collectFloorCompressSequences(h, floors, compresses); + + Hop.resetVisitStatus(roots); + + // check compresses for compress-after-floor pattern + for(Entry<String, Hop> e : compresses.entrySet()) { + String inputname = e.getKey(); + Hop compresshop = e.getValue(); + + if(floors.containsKey(inputname) // floors same name + && ((floors.get(inputname).getBeginLine() < compresshop.getBeginLine()) || + (floors.get(inputname).getEndLine() < compresshop.getEndLine()) || + (floors.get(inputname).getBeginLine() == compresshop.getBeginLine() && + floors.get(inputname).getEndLine() == compresshop.getEndLine() && + floors.get(inputname).getBeginColumn() < compresshop.getBeginColumn()))) { + + // retrieve the floor hop and inputs + Hop floorhop = floors.get(inputname); + Hop floorInput = floorhop.getInput().get(0); + + // check if the input of the floor operation is a matrix + if(floorInput.getDataType() == DataType.MATRIX) { + + // Check if the input of the floor operation involves a multiplication operation + if(floorInput instanceof BinaryOp && ((BinaryOp) floorInput).getOp() == OpOp2.MULT) { + Hop initialMatrix = floorInput.getInput().get(0); + Hop sf = floorInput.getInput().get(1); + + // create fused hop + BinaryOp fusedhop = new BinaryOp("test", DataType.MATRIX, ValueType.FP64, + OpOp2.QUANTIZE_COMPRESS, initialMatrix, sf); + + // rewire compress consumers to fusedHop + List<Hop> parents = new ArrayList<>(compresshop.getParent()); + for(Hop p : parents) { + HopRewriteUtils.replaceChildReference(p, compresshop, fusedhop); + } + } + } + } + } + return roots; + } + +
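+	// Illustrative DML sketch of the pattern this rule matches (variable names are hypothetical;
+	// the commit's test scripts are RewriteQuantizationFusedCompressionMatrix.dml and
+	// RewriteQuantizationFusedCompressionScalar.dml):
+	//   M2 = floor(M * S);  # S: a scalar, or a column vector with nrow(M) entries
+	//   C = compress(M2);   # rewritten into a single QUANTIZE_COMPRESS hop, so M2 is never allocated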
@Override + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { + // do nothing, floor/compress do not occur in predicates + return root; + } + + private void collectFloorCompressSequences(Hop hop, HashMap<String, Hop> floors, HashMap<String, Hop> compresses) { + if(hop.isVisited()) + return; + + // process children + if(!hop.getInput().isEmpty()) + for(Hop c : hop.getInput()) + collectFloorCompressSequences(c, floors, compresses); + + // process current hop + if(hop instanceof UnaryOp) { + UnaryOp uop = (UnaryOp) hop; + if(uop.getOp() == OpOp1.FLOOR) { + floors.put(uop.getName(), uop); + } + else if(uop.getOp() == OpOp1.COMPRESS) { + compresses.put(uop.getInput(0).getName(), uop); + } + } + hop.setVisited(); + } +} diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 6a68f867f9..70245de070 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -751,7 +751,7 @@ public class BuiltinFunctionExpression extends DataIdentifier { else raiseValidateError("Compress/DeCompress instruction not allowed in dml script"); break; - + default: //always unconditional raiseValidateError("Unknown Builtin Function opcode: " + _opcode, false); } @@ -2013,6 +2013,34 @@ public class BuiltinFunctionExpression extends DataIdentifier { else raiseValidateError("Compress/DeCompress instruction not allowed in dml script"); break; + case QUANTIZE_COMPRESS: + if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) { + checkNumParameters(2); + Expression firstExpr = getFirstExpr(); + Expression secondExpr = getSecondExpr(); + + checkMatrixParam(getFirstExpr()); + + if(secondExpr != null) { + // check if scale factor is a scalar, vector or matrix + checkMatrixScalarParam(secondExpr); + // if scale factor is a vector or matrix, make sure it has an appropriate shape + if(secondExpr.getOutput().getDataType() != DataType.SCALAR) { + if(is1DMatrix(secondExpr)) { + long vectorLength = secondExpr.getOutput().getDim1(); + if(vectorLength != firstExpr.getOutput().getDim1()) { + raiseValidateError( + "The length of the row-wise scale factor vector must match the number of rows in the matrix."); + } + } + else { + checkMatchingDimensions(firstExpr, secondExpr); + } + } + } + } + break; + case ROW_COUNT_DISTINCT: checkNumParameters(1); checkMatrixParam(getFirstExpr()); diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index 4884de4c74..c4f7f672ab 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2585,6 +2585,9 @@ public class DMLTranslator case DECOMPRESS: currBuiltinOp = new UnaryOp(target.getName(), target.getDataType(), ValueType.FP64, OpOp1.DECOMPRESS, expr); break; + case QUANTIZE_COMPRESS: + currBuiltinOp = new BinaryOp(target.getName(), target.getDataType(), target.getValueType(), OpOp2.valueOf(source.getOpCode().name()), expr, expr2); + break; // Boolean binary case XOR: diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index 7ea2cc3966..4c48effb4d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -50,6 +50,7 @@ import org.apache.sysds.runtime.compress.workload.WTreeRoot; import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -137,6 +138,21 @@ public class CompressedMatrixBlockFactory { return compress(mb, k, new CompressionSettingsBuilder(), root); } + public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, MatrixBlock sf, int k, WTreeRoot root) { + // Handle only row vectors, as column-wise quantization is not allowed. + // The restriction is handled upstream + double[] scaleFactors = sf.getDenseBlockValues(); + CompressionSettingsBuilder builder = new CompressionSettingsBuilder().setScaleFactor(scaleFactors); + return compress(mb, k, builder, root); + } + + public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, ScalarObject sf, int k, WTreeRoot root) { + double[] scaleFactors = new double[1]; + scaleFactors[0] = sf.getDoubleValue(); + CompressionSettingsBuilder builder = new CompressionSettingsBuilder().setScaleFactor(scaleFactors); + return compress(mb, k, builder, root); + } + public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CostEstimatorBuilder csb) { return compress(mb, k, new CompressionSettingsBuilder(), csb); } @@ -285,7 +301,7 @@ public class CompressedMatrixBlockFactory { return new ImmutablePair<>(mb, null); } - _stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0); + _stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0); _stats.sparseSize = MatrixBlock.estimateSizeSparseInMemory(mb.getNumRows(), mb.getNumColumns(), mb.getSparsity()); _stats.originalSize = mb.getInMemorySize(); _stats.originalCost = costEstimator.getCost(mb); @@ -300,8 +316,10 @@ public class CompressedMatrixBlockFactory { res = new CompressedMatrixBlock(mb); // copy metadata and allocate soft reference logInit(); + classifyPhase(); - if(compressionGroups == null) + + if(compressionGroups == null) return abortCompression(); // clear extra data from analysis @@ -491,7 +509,26 @@ public class CompressedMatrixBlockFactory { MatrixBlock ucmb = ((CompressedMatrixBlock) mb).getUncompressed("Decompressing for abort: ", k); return new ImmutablePair<>(ucmb, _stats); } - return new ImmutablePair<>(mb, _stats); + if(compSettings.scaleFactors == null) { + LOG.warn("Scale factors are null - returning original matrix."); + return new ImmutablePair<>(mb, _stats); + } else { + LOG.warn("Scale factors are present - returning scaled matrix."); + MatrixBlock scaledMb = new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.isInSparseFormat()); + scaledMb.copy(mb); + + // Apply scaling and flooring + // TODO: Use internal matrix prod + for(int r = 0; r < mb.getNumRows(); r++) { + double scaleFactor = compSettings.scaleFactors.length == 1 ? 
compSettings.scaleFactors[0] : compSettings.scaleFactors[r]; + for(int c = 0; c < mb.getNumColumns(); c++) { + double newValue = Math.floor(mb.get(r, c) * scaleFactor); + scaledMb.set(r, c, newValue); + } + } + scaledMb.recomputeNonZeros(); + return new ImmutablePair<>(scaledMb, _stats); + } } private void logInit() { diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java index 31c034ef4d..f6321bc1b6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettings.java @@ -46,8 +46,8 @@ public class CompressionSettings { public static final int BITMAP_BLOCK_SZ = Character.MAX_VALUE; /** - * Sorting of values by physical length helps by 10-20%, especially for serial, while slight performance decrease for - * parallel incl multi-threaded, hence not applied for distributed operations (also because compression time + + * Sorting of values by physical length helps by 10-20%, especially for serial, while slight performance decrease + * for parallel incl multi-threaded, hence not applied for distributed operations (also because compression time + * garbage collection increases) */ public final boolean sortTuplesByFrequency; @@ -131,11 +131,13 @@ public class CompressionSettings { /** if the settings have been logged already. */ public static boolean printedStatus = false; + public final double[] scaleFactors; + protected CompressionSettings(double samplingRatio, double samplePower, boolean allowSharedDictionary, String transposeInput, int seed, boolean lossy, EnumSet<CompressionType> validCompressions, boolean sortValuesByLength, PartitionerType columnPartitioner, int maxColGroupCoCode, double coCodePercentage, int minimumSampleSize, int maxSampleSize, EstimationType estimationType, CostType costComputationType, - double minimumCompressionRatio, boolean isInSparkInstruction, SORT_TYPE sdcSortType) { + double minimumCompressionRatio, boolean isInSparkInstruction, SORT_TYPE sdcSortType, double[] scaleFactors) { this.samplingRatio = samplingRatio; this.samplePower = samplePower; this.allowSharedDictionary = allowSharedDictionary; @@ -154,6 +156,8 @@ public class CompressionSettings { this.minimumCompressionRatio = minimumCompressionRatio; this.isInSparkInstruction = isInSparkInstruction; this.sdcSortType = sdcSortType; + this.scaleFactors = scaleFactors; + if(!printedStatus && LOG.isDebugEnabled()) { printedStatus = true; LOG.debug(this.toString()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java index dc0908dc9b..ae6a0b2d23 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionSettingsBuilder.java @@ -52,6 +52,7 @@ public class CompressionSettingsBuilder { private double minimumCompressionRatio = 1.0; private boolean isInSparkInstruction = false; private SORT_TYPE sdcSortType = SORT_TYPE.MATERIALIZE; + private double[] scaleFactors = null; public CompressionSettingsBuilder() { @@ -69,6 +70,19 @@ public class CompressionSettingsBuilder { } + /** + * Sets the scale factors for compression, enabling quantization-fused compression. + * + * @param scaleFactors An array of scale factors applied during compression. 
+ * - If row-wise scaling is used, this should be an array where each value corresponds to a row. + * - If a single scalar is provided, it is applied uniformly to the entire matrix. + * @return The CompressionSettingsBuilder instance with the updated scale factors. + */ + public CompressionSettingsBuilder setScaleFactor(double[] scaleFactors) { + this.scaleFactors = scaleFactors; + return this; + } + /** * Copy the settings from another CompressionSettings Builder, modifies this, not that. * @@ -331,6 +345,6 @@ public class CompressionSettingsBuilder { return new CompressionSettings(samplingRatio, samplePower, allowSharedDictionary, transposeInput, seed, lossy, validCompressions, sortValuesByLength, columnPartitioner, maxColGroupCoCode, coCodePercentage, minimumSampleSize, maxSampleSize, estimationType, costType, minimumCompressionRatio, isInSparkInstruction, - sdcSortType); + sdcSortType, scaleFactors); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java b/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java index 70fd4da263..7be7ac4b93 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/compress/bitmap/BitmapEncoder.java @@ -24,6 +24,7 @@ import java.util.Collections; import java.util.Comparator; import java.util.List; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressionSettings; @@ -48,7 +49,8 @@ public class BitmapEncoder { public static ABitmap extractBitmap(IColIndex colIndices, MatrixBlock rawBlock, int estimatedNumberOfUniques, CompressionSettings cs) { - return extractBitmap(colIndices, rawBlock, cs.transposed, estimatedNumberOfUniques, cs.sortTuplesByFrequency); + return extractBitmap(colIndices, rawBlock, cs.transposed, estimatedNumberOfUniques, cs.sortTuplesByFrequency, + cs.scaleFactors); } /** @@ -61,75 +63,177 @@ public class BitmapEncoder { * @param rawBlock An uncompressed matrix block; can be dense, sparse, empty, or null (not * Compressed!) * @param transposed Boolean specifying if the rawBlock was transposed. - * @param estimatedNumberOfUniques The number of estimated uniques inside this group. Used to allocated the HashMaps. + * @param estimatedNumberOfUniques The number of estimated uniques inside this group. Used to allocated the + * HashMaps. * @param sortedEntries Boolean specifying if the entries should be sorted based on frequency of tuples * @return Uncompressed bitmap representation of the columns specified */ public static ABitmap extractBitmap(IColIndex colIndices, MatrixBlock rawBlock, boolean transposed, int estimatedNumberOfUniques, boolean sortedEntries) { + // Overloaded method with scaleFactors defaulted to null + return extractBitmap(colIndices, rawBlock, transposed, estimatedNumberOfUniques, sortedEntries, null); + } + + /** + * Generate quantization-fused uncompressed bitmaps for a set of columns in an uncompressed matrix block. + * + * if the rawBlock is transposed and sparse it should be guaranteed that the rows specified are not empty, aka all + * zero. + * + * @param colIndices Indexes (within the block) of the columns to extract + * @param rawBlock An uncompressed matrix block; can be dense, sparse, empty, or null (not + * Compressed!) + * @param transposed Boolean specifying if the rawBlock was transposed. 
+ * @param estimatedNumberOfUniques The number of estimated uniques inside this group. Used to allocated the + * HashMaps. + * @param sortedEntries Boolean specifying if the entries should be sorted based on frequency of tuples + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for + * entire matrix + * @return Uncompressed bitmap representation of the columns specified + */ + public static ABitmap extractBitmap(IColIndex colIndices, MatrixBlock rawBlock, boolean transposed, + int estimatedNumberOfUniques, boolean sortedEntries, double[] scaleFactors) { if(rawBlock == null || rawBlock.isEmpty()) return null; final int numRows = transposed ? rawBlock.getNumColumns() : rawBlock.getNumRows(); final int estimatedNumber = Math.max(estimatedNumberOfUniques, 8); - if(colIndices.size() == 1) + if(colIndices.size() == 1) { return extractBitmapSingleColumn(colIndices.get(0), rawBlock, numRows, transposed, estimatedNumber, - sortedEntries); - else - return extractBitmapMultiColumns(colIndices, rawBlock, numRows, transposed, estimatedNumber, sortedEntries); + sortedEntries, scaleFactors); + } + else { + return extractBitmapMultiColumns(colIndices, rawBlock, numRows, transposed, estimatedNumber, sortedEntries, + scaleFactors); + } } - private static ABitmap extractBitmapSingleColumn(int colIndex, MatrixBlock rawBlock, int numRows, boolean transposed, - int est, boolean sort) { + private static ABitmap extractBitmapSingleColumn(int colIndex, MatrixBlock rawBlock, int numRows, + boolean transposed, int est, boolean sort, double[] scaleFactors) { if(transposed) { if(rawBlock.isInSparseFormat() && rawBlock.getSparseBlock().isEmpty(colIndex)) return null; - return makeSingleColBitmap(extractSingleColT(colIndex, rawBlock, est), rawBlock.getNumColumns(), sort); + return makeSingleColBitmap(extractSingleColT(colIndex, rawBlock, est, scaleFactors), rawBlock.getNumColumns(), sort); } else - return makeSingleColBitmap(extractSingleCol(colIndex, rawBlock, est), rawBlock.getNumRows(), sort); + return makeSingleColBitmap(extractSingleCol(colIndex, rawBlock, est, scaleFactors), rawBlock.getNumRows(), + sort); } - private static DoubleIntListHashMap extractSingleCol(int colIndex, MatrixBlock rawBlock, int estimatedUnique) { + private static DoubleIntListHashMap extractSingleCol(int colIndex, MatrixBlock rawBlock, int estimatedUnique, + double[] scaleFactors) { final DoubleIntListHashMap distinctVals = new DoubleIntListHashMap(estimatedUnique); final int nRows = rawBlock.getNumRows(); final int nCols = rawBlock.getNumColumns(); final boolean sparse = rawBlock.isInSparseFormat(); - if(sparse) { - final SparseBlock sb = rawBlock.getSparseBlock(); - for(int r = 0; r < nRows; r++) { - if(sb.isEmpty(r)) - continue; - final int apos = sb.pos(r); - final int alen = sb.size(r) + apos; - final int[] aix = sb.indexes(r); - final int idx = Arrays.binarySearch(aix, apos, alen, colIndex); - if(idx >= 0) - distinctVals.appendValue(sb.values(r)[idx], r); + if(scaleFactors == null) { + if(sparse) { + final SparseBlock sb = rawBlock.getSparseBlock(); + for(int r = 0; r < nRows; r++) { + if(sb.isEmpty(r)) + continue; + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, colIndex); + if(idx >= 0) + distinctVals.appendValue(sb.values(r)[idx], r); + } + } + else if(rawBlock.getDenseBlock().isContiguous()) { + final double[] values = rawBlock.getDenseBlockValues(); + if(nCols == 1) + // 
Since the only values contained is in this column index. simply extract it continuously. + for(int i = 0; i < values.length; i++) + distinctVals.appendValue(values[i], i); + else + // For loop down through the rows skipping all other values than the ones in the specified column + // index. + for(int i = 0, off = colIndex; off < nRows * nCols; i++, off += nCols) + distinctVals.appendValue(values[off], i); + } + else { // GENERAL CASE + // This case is slow, because it does a binary search in each row of the sparse input. (if sparse) + // and does get value in dense cases with multi blocks. + for(int i = 0; i < nRows; i++) + distinctVals.appendValue(rawBlock.get(i, colIndex), i); } } - else if(rawBlock.getDenseBlock().isContiguous()) { - final double[] values = rawBlock.getDenseBlockValues(); - if(nCols == 1) - // Since the only values contained is in this column index. simply extract it continuously. - for(int i = 0; i < values.length; i++) - distinctVals.appendValue(values[i], i); - else - // For loop down through the rows skipping all other values than the ones in the specified column index. - for(int i = 0, off = colIndex; off < nRows * nCols; i++, off += nCols) - distinctVals.appendValue(values[off], i); - } - else { // GENERAL CASE - // This case is slow, because it does a binary search in each row of the sparse input. (if sparse) - // and does get value in dense cases with multi blocks. - for(int i = 0; i < nRows; i++) - distinctVals.appendValue(rawBlock.get(i, colIndex), i); + else { + // Apply single scale factor + if(scaleFactors.length == 1) { + final double scaleFactor = scaleFactors[0]; + + if(sparse) { + final SparseBlock sb = rawBlock.getSparseBlock(); + for(int r = 0; r < nRows; r++) { + if(sb.isEmpty(r)) + continue; + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, colIndex); + if(idx >= 0) + distinctVals.appendValue(Math.floor(sb.values(r)[idx] * scaleFactor), r); + } + } + else if(rawBlock.getDenseBlock().isContiguous()) { + final double[] values = rawBlock.getDenseBlockValues(); + if(nCols == 1) { + for(int i = 0; i < values.length; i++) + distinctVals.appendValue(Math.floor(values[i] * scaleFactor), i); + } + else { + for(int i = 0, off = colIndex; off < nRows * nCols; i++, off += nCols) + distinctVals.appendValue(Math.floor(values[off] * scaleFactor), i); + } + } + else { // GENERAL CASE + for(int i = 0; i < nRows; i++) + distinctVals.appendValue(Math.floor(rawBlock.get(i, colIndex) * scaleFactor), i); + } + } + else { + // Apply scale factor row-wise. The shape of scale factor is handled upstream. 
+ if(sparse) { + final SparseBlock sb = rawBlock.getSparseBlock(); + for(int r = 0; r < nRows; r++) { + if(sb.isEmpty(r)) + continue; + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, colIndex); + if(idx >= 0) + distinctVals.appendValue(Math.floor(sb.values(r)[idx] * scaleFactors[r]), r); + } + } + else if(rawBlock.getDenseBlock().isContiguous()) { + final double[] values = rawBlock.getDenseBlockValues(); + if(nCols == 1) { + for(int i = 0; i < values.length; i++) + distinctVals.appendValue(Math.floor(values[i] * scaleFactors[i]), i); + } + else { + for(int i = 0, off = colIndex; off < nRows * nCols; i++, off += nCols) + distinctVals.appendValue(Math.floor(values[off] * scaleFactors[i]), i); + } + } + else { // GENERAL CASE + for(int i = 0; i < nRows; i++) + distinctVals.appendValue(Math.floor(rawBlock.get(i, colIndex) * scaleFactors[i]), i); + } + } } return distinctVals; } - private static DoubleIntListHashMap extractSingleColT(int colIndex, MatrixBlock rawBlock, int estimatedUnique) { + private static DoubleIntListHashMap extractSingleColT(int colIndex, MatrixBlock rawBlock, int estimatedUnique, double[] scaleFactors) { + if (scaleFactors != null) { + throw new NotImplementedException(); + } + // probe map for distinct items (for value or value groups) final DoubleIntListHashMap distinctVals = new DoubleIntListHashMap(estimatedUnique); @@ -163,15 +267,20 @@ public class BitmapEncoder { } private static ABitmap extractBitmapMultiColumns(IColIndex colIndices, MatrixBlock rawBlock, int numRows, - boolean transposed, int estimatedUnique, boolean sort) { + boolean transposed, int estimatedUnique, boolean sort, double[] scaleFactors) { final DblArrayIntListHashMap map = new DblArrayIntListHashMap(estimatedUnique); - final ReaderColumnSelection reader = ReaderColumnSelection.createReader(rawBlock, colIndices, transposed); + + final ReaderColumnSelection reader = (scaleFactors == null) ? 
ReaderColumnSelection.createReader(rawBlock, + colIndices, + transposed) : ReaderColumnSelection.createQuantizedReader(rawBlock, colIndices, transposed, scaleFactors); + DblArray cellVals = null; try { DblArray empty = new DblArray(new double[colIndices.size()]); while((cellVals = reader.nextRow()) != null) { - if(!cellVals.equals(empty)) + if(!cellVals.equals(empty)) { map.appendValue(cellVals, reader.getCurrentRowIndex()); + } } } @@ -195,6 +304,7 @@ public class BitmapEncoder { values[bitmapIx] = val.key.getData(); offsetsLists[bitmapIx++] = val.value; } + return new MultiColBitmap(offsetsLists, values, numRows); } else diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java index a759e99939..ef8b83c3b8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java @@ -29,6 +29,7 @@ import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; @@ -226,12 +227,12 @@ public class ColGroupFactory { if(estC < actC * 0.75) { String warning = "The estimate cost is significantly off : " + est; LOG.debug( - String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s\n\t\t%s", time, - retType, estC, actC, act.getNumValues(), cols, wanted, warning)); + String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s\n\t\t%s", + time, retType, estC, actC, act.getNumValues(), cols, wanted, warning)); } else { - LOG.debug(String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s", time, - retType, estC, actC, act.getNumValues(), cols, wanted)); + LOG.debug(String.format("time[ms]: %10.2f %25s est %10.0f -- act %10.0f distinct:%5d cols:%s wanted:%s", + time, retType, estC, actC, act.getNumValues(), cols, wanted)); } } @@ -261,19 +262,40 @@ public class ColGroupFactory { if((ct == CompressionType.EMPTY && !t) || // (t && colIndexes.size() == 1 && in.isInSparseFormat() // Empty Column && in.getSparseBlock().isEmpty(colIndexes.get(0)))) + // TODO: handle quantization-fused compression if deemed necessary, + // but if the matrix reaches here, it likely doesn't need quantization. 
return new ColGroupEmpty(colIndexes); - else if(ct == CompressionType.UNCOMPRESSED) // don't construct mapping if uncompressed - return ColGroupUncompressed.create(colIndexes, in, t); + else if(ct == CompressionType.UNCOMPRESSED) { // don't construct mapping if uncompressed + if(cs.scaleFactors != null) { + return ColGroupUncompressed.createQuantized(colIndexes, in, t, cs.scaleFactors); + } + else { + return ColGroupUncompressed.create(colIndexes, in, t); + } + } else if((ct == CompressionType.SDC || ct == CompressionType.CONST) // && in.isInSparseFormat() // && t && (// (colIndexes.size() > 1 && cg.getNumOffs() < 0.3 * nRow) // - || colIndexes.size() == 1)) - return compressSDCFromSparseTransposedBlock(colIndexes, cg.getNumVals(), cg.getTupleSparsity()); - else if(ct == CompressionType.DDC) + || colIndexes.size() == 1)) { + if(cs.scaleFactors != null) { + throw new NotImplementedException(); // TODO: handle quantization-fused compression + } + else { + return compressSDCFromSparseTransposedBlock(colIndexes, cg.getNumVals(), cg.getTupleSparsity()); + } + } + else if(ct == CompressionType.DDC) { return directCompressDDC(colIndexes, cg); - else if(ct == CompressionType.LinearFunctional) - return compressLinearFunctional(colIndexes, in, cs); + } + else if(ct == CompressionType.LinearFunctional) { + if(cs.scaleFactors != null) { + throw new NotImplementedException(); // quantization-fused compression NOT allowed + } + else { + return compressLinearFunctional(colIndexes, in, cs); + } + } else if(ct == CompressionType.DDCFOR) { AColGroup g = directCompressDDC(colIndexes, cg); if(g instanceof ColGroupDDC) @@ -285,12 +307,13 @@ public class ColGroupFactory { } final ABitmap ubm = BitmapEncoder.extractBitmap(colIndexes, in, cg.getNumVals(), cs); - if(ubm == null) // no values ... therefore empty + if(ubm == null) {// no values ... therefore empty return new ColGroupEmpty(colIndexes); - + } final IntArrayList[] of = ubm.getOffsetList(); - if(of.length == 1 && of[0].size() == nRow) // If this always constant + if(of.length == 1 && of[0].size() == nRow) { // If this always constant return ColGroupConst.create(colIndexes, DictionaryFactory.create(ubm)); + } final double tupleSparsity = colIndexes.size() > 4 ? cg.getTupleSparsity() : 1.0; @@ -330,19 +353,100 @@ public class ColGroupFactory { IDictionary dict = Dictionary.create(cMap.getDictionary(dictSize)); IntArrayList offs = new IntArrayList(nRow - defCount); AMapToData map = MapToFactory.create(nRow - defCount, dictSize); - getOffsets(offs, map, cMap, col, def); - + if(cs.scaleFactors != null) { + getOffsetsScaled(offs, map, cMap, col, def); + } + else { + getOffsets(offs, map, cMap, col, def); + } AOffset aoff = OffsetFactory.createOffset(offs); return ColGroupSDC.create(colIndexes, nRow, dict, new double[] {def}, aoff, map, null); + } + private void getOffsetsScaled(IntArrayList offs, AMapToData map, DoubleCountHashMap cMap, int col, double def) { + final double scaleFactor = cs.scaleFactors[0]; // Single column, thus single scalar value. 
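+		// Note: values are quantized on the fly below, i.e., mapped through floor(value * scaleFactor)
+		// before the id lookup in cMap, so the quantized intermediate matrix is never allocated.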
+ + if(in.isInSparseFormat()) { + final SparseBlock sb = in.getSparseBlock(); + + if(def == 0) { // If zero is the default value + for(int r = 0; r < nRow; r++) { + if(sb.isEmpty(r)) + continue; // Skip explicitly storing zero values + + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + + if(idx >= 0) { + double v = Math.floor(sb.values(r)[idx] * scaleFactor); + map.set(offs.size(), cMap.getId(v)); + offs.appendValue(r); + } + } + } + + else { // If zero is NOT the default value, track missing values explicitly + for(int r = 0; r < nRow; r++) { + if(sb.isEmpty(r)) { + map.set(offs.size(), cMap.getId(0.0)); + offs.appendValue(r); + } + else { + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + + if(idx < 0) { // Missing entry + map.set(offs.size(), cMap.getId(0.0)); + offs.appendValue(r); + } + else { + double v = Math.floor(sb.values(r)[idx] * scaleFactor); + if(!Util.eq(v, def)) { + map.set(offs.size(), cMap.getId(v)); + offs.appendValue(r); + } + } + } + } + } + + } + else if(in.getDenseBlock().isContiguous()) { + final double[] dv = in.getDenseBlockValues(); + int off = col; + + for(int r = 0; r < nRow; r++, off += nCol) { + double scaledValue = Math.floor(dv[off] * scaleFactor); + if(!Util.eq(scaledValue, def)) { + map.set(offs.size(), cMap.getId(scaledValue)); + offs.appendValue(r); + } + } + } + else { + final DenseBlock db = in.getDenseBlock(); + for(int r = 0; r < nRow; r++) { + final double[] dv = db.values(r); + int off = db.pos(r) + col; + double scaledValue = Math.floor(dv[off] * scaleFactor); + if(!Util.eq(scaledValue, def)) { + map.set(offs.size(), cMap.getId(scaledValue)); + offs.appendValue(r); + } + } + } } private void getOffsets(IntArrayList offs, AMapToData map, DoubleCountHashMap cMap, int col, double def) { if(in.isInSparseFormat()) { + final SparseBlock sb = in.getSparseBlock(); if(def == 0) { - final SparseBlock sb = in.getSparseBlock(); for(int r = 0; r < nRow; r++) { if(sb.isEmpty(r)) continue; @@ -358,11 +462,8 @@ public class ColGroupFactory { } } else { - - final SparseBlock sb = in.getSparseBlock(); for(int r = 0; r < nRow; r++) { if(sb.isEmpty(r)) { - map.set(offs.size(), cMap.getId(0.0)); offs.appendValue(r); } @@ -384,11 +485,13 @@ public class ColGroupFactory { } } } + } } else if(in.getDenseBlock().isContiguous()) { final double[] dv = in.getDenseBlockValues(); int off = col; + for(int r = 0; r < nRow; r++, off += nCol) if(!Util.eq(dv[off], def)) { map.set(offs.size(), cMap.getId(dv[off])); @@ -409,16 +512,55 @@ public class ColGroupFactory { } private void countElements(DoubleCountHashMap map, int col) { - if(in.isInSparseFormat()) - countElementsSparse(map, col); - else if(in.getDenseBlock().isContiguous()) - countElementsDenseContiguous(map, col); - else - countElementsDenseGeneric(map, col); + if(cs.scaleFactors != null) { + if(in.isInSparseFormat()) { + countElementsSparseScaled(map, col); + } + else if(in.getDenseBlock().isContiguous()) { + countElementsDenseContiguousScaled(map, col); + } + else { + countElementsDenseGenericScaled(map, col); + } + } + else { + if(in.isInSparseFormat()) { + countElementsSparse(map, col); + } + else if(in.getDenseBlock().isContiguous()) { + countElementsDenseContiguous(map, col); + } + else { + countElementsDenseGeneric(map, col); + } + } + } + + private void 
countElementsSparseScaled(DoubleCountHashMap map, int col) { + final SparseBlock sb = in.getSparseBlock(); + + double scaleFactor = cs.scaleFactors[0]; + for(int r = 0; r < nRow; r++) { + if(sb.isEmpty(r)) + map.increment(0.0); + else { + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + if(idx < 0) { + map.increment(0.0); + } + else { + map.increment(Math.floor(sb.values(r)[idx] * scaleFactor)); + } + } + } } private void countElementsSparse(DoubleCountHashMap map, int col) { final SparseBlock sb = in.getSparseBlock(); + for(int r = 0; r < nRow; r++) { if(sb.isEmpty(r)) map.increment(0.0); @@ -435,13 +577,34 @@ public class ColGroupFactory { } } + private void countElementsDenseContiguousScaled(DoubleCountHashMap map, int col) { + final double[] dv = in.getDenseBlockValues(); + int off = col; + + double scaleFactor = cs.scaleFactors[0]; + for(int r = 0; r < nRow; r++, off += nCol) { + map.increment(Math.floor(dv[off] * scaleFactor)); + } + } + private void countElementsDenseContiguous(DoubleCountHashMap map, int col) { final double[] dv = in.getDenseBlockValues(); int off = col; + for(int r = 0; r < nRow; r++, off += nCol) map.increment(dv[off]); } + private void countElementsDenseGenericScaled(DoubleCountHashMap map, int col) { + final DenseBlock db = in.getDenseBlock(); + double scaleFactor = cs.scaleFactors[0]; + for(int r = 0; r < nRow; r++) { + final double[] dv = db.values(r); + int off = db.pos(r) + col; + map.increment(Math.floor(dv[off] * scaleFactor)); + } + } + private void countElementsDenseGeneric(DoubleCountHashMap map, int col) { final DenseBlock db = in.getDenseBlock(); for(int r = 0; r < nRow; r++) { @@ -452,10 +615,15 @@ public class ColGroupFactory { } private AColGroup directCompressDDC(IColIndex colIndexes, CompressedSizeInfoColGroup cg) throws Exception { - if(colIndexes.size() > 1) + // testing multicol + if(colIndexes.size() > 1) { + LOG.debug("DDC multi column"); return directCompressDDCMultiCol(colIndexes, cg); - else + } + else { + LOG.debug("DDC single column"); return directCompressDDCSingleCol(colIndexes, cg); + } } private AColGroup directCompressDDCSingleCol(IColIndex colIndexes, CompressedSizeInfoColGroup cg) { @@ -465,16 +633,27 @@ public class ColGroupFactory { // unlike multi-col no special handling of zero entries are needed. 
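+		// Quantization-fused path: when scale factors are set, values are floored and scaled
+		// while being read into the map (readToMapDDCScaled); the transposed variant is not
+		// implemented yet and throws NotImplementedException.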
if(cs.transposed) - readToMapDDCTransposed(col, map, d); - else - readToMapDDC(col, map, d); + if(cs.scaleFactors != null) { + throw new NotImplementedException(); // TODO: Handle scaled transposed columns + } + else { + readToMapDDCTransposed(col, map, d); + } + else { + if(cs.scaleFactors != null) { + readToMapDDCScaled(col, map, d); + } + else { + readToMapDDC(col, map, d); + } + } if(map.size() == 0) return new ColGroupEmpty(colIndexes); IDictionary dict = DictionaryFactory.create(map); final int nUnique = map.size(); - final AMapToData resData = d.resize( nUnique); + final AMapToData resData = d.resize(nUnique); return ColGroupDDC.create(colIndexes, dict, resData, null); } @@ -485,10 +664,14 @@ public class ColGroupFactory { final DblArrayCountHashMap map = new DblArrayCountHashMap(Math.max(cg.getNumVals(), 64)); boolean extra; - if(nRow < CompressionSettings.PAR_DDC_THRESHOLD || k < csi.getNumberColGroups() || pool == null ) + if(nRow < CompressionSettings.PAR_DDC_THRESHOLD || k < csi.getNumberColGroups() || pool == null) { + LOG.debug("Non parallel"); extra = readToMapDDC(colIndexes, map, d, 0, nRow, fill); - else + } + else { + LOG.debug("Parallel"); extra = parallelReadToMapDDC(colIndexes, map, d, nRow, fill, k); + } if(map.size() == 0) // If the column was empty. @@ -507,7 +690,11 @@ public class ColGroupFactory { private boolean readToMapDDC(IColIndex colIndexes, DblArrayCountHashMap map, AMapToData data, int rl, int ru, int fill) { - ReaderColumnSelection reader = ReaderColumnSelection.createReader(in, colIndexes, cs.transposed, rl, ru); + + ReaderColumnSelection reader = (cs.scaleFactors == null) ? ReaderColumnSelection.createReader(in, colIndexes, + cs.transposed, rl, + ru) : ReaderColumnSelection.createQuantizedReader(in, colIndexes, cs.transposed, rl, ru, cs.scaleFactors); + DblArray cellVals = reader.nextRow(); boolean extra = false; int r = rl; @@ -531,6 +718,49 @@ public class ColGroupFactory { return extra; } + // TODO: Merge logic to readToMapDDC. 
This should be done for other scaled methods + private void readToMapDDCScaled(int col, DoubleCountHashMap map, AMapToData data) { + double scaleFactor = cs.scaleFactors[0]; + + if(in.isInSparseFormat()) { + // not good but could happen + final SparseBlock sb = in.getSparseBlock(); + for(int r = 0; r < nRow; r++) { + if(sb.isEmpty(r)) + data.set(r, map.increment(0.0)); + else { + final int apos = sb.pos(r); + final int alen = sb.size(r) + apos; + final int[] aix = sb.indexes(r); + final int idx = Arrays.binarySearch(aix, apos, alen, col); + if(idx < 0) + data.set(r, map.increment(0.0)); + else { + double scaledValue = Math.floor(sb.values(r)[idx] * scaleFactor); + data.set(r, map.increment(scaledValue)); + } + } + } + } + else if(in.getDenseBlock().isContiguous()) { + final double[] dv = in.getDenseBlockValues(); + int off = col; + for(int r = 0; r < nRow; r++, off += nCol) { + double scaledValue = Math.floor(dv[off] * scaleFactor); + data.set(r, map.increment(scaledValue)); + } + } + else { + final DenseBlock db = in.getDenseBlock(); + for(int r = 0; r < nRow; r++) { + final double[] dv = db.values(r); + int off = db.pos(r) + col; + double scaledValue = Math.floor(dv[off] * scaleFactor); + data.set(r, map.increment(scaledValue)); + } + } + } + private void readToMapDDC(int col, DoubleCountHashMap map, AMapToData data) { if(in.isInSparseFormat()) { // not good but could happen @@ -673,7 +903,7 @@ public class ColGroupFactory { cs.sdcSortType); AOffset indexes = OffsetFactory.createOffset(s.getIndexes()); AMapToData _data = s.getData(); - _data = _data.resize( dict.getNumberOfValues(colIndexes.size())); + _data = _data.resize(dict.getNumberOfValues(colIndexes.size())); return ColGroupSDC.create(colIndexes, rlen, dict, defaultTuple, indexes, _data, null); } @@ -764,7 +994,9 @@ public class ColGroupFactory { } } IColIndex subCols = ColIndexFactory.create(cols.size()); - ReaderColumnSelection reader = ReaderColumnSelection.createReader(sub, subCols, false); + ReaderColumnSelection reader = (cs.scaleFactors == null) ? ReaderColumnSelection.createReader(sub, subCols, + false) : ReaderColumnSelection.createQuantizedReader(sub, subCols, false, cs.scaleFactors); + final int mapStartSize = Math.min(nrUniqueEstimate, offsetsInt.length / 2); DblArrayCountHashMap map = new DblArrayCountHashMap(mapStartSize); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index cf0959bba7..1c3bce2e16 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -81,7 +81,7 @@ public class ColGroupUncompressed extends AColGroup { private final MatrixBlock _data; /** - * Do not use this constructor of column group uncompressed, instead uce the create constructor. + * Do not use this constructor of column group uncompressed, instead use the create constructor. * @param mb The contained data. * @param colIndexes Column indexes for this Columngroup */ @@ -90,6 +90,25 @@ public class ColGroupUncompressed extends AColGroup { _data = mb; } + /** + * Do not use this constructor of column group quantization-fused uncompressed, instead use the create constructor. + * @param mb The contained data. 
+ * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param colIndexes Column indexes for this Columngroup + */ + protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes, double[] scaleFactors) { + super(colIndexes); + // Apply scaling and flooring + // TODO: Use internal matrix prod + for(int r = 0; r < mb.getNumRows(); r++) { + double scaleFactor = scaleFactors.length == 1 ? scaleFactors[0] : scaleFactors[r]; + for(int c = 0; c < mb.getNumColumns(); c++) { + double newValue = Math.floor(mb.get(r, c) * scaleFactor); + mb.set(r, c, newValue); + } + } + _data = mb; + } /** * Create an Uncompressed Matrix Block, where the columns are offset by col indexes. * @@ -106,6 +125,97 @@ public class ColGroupUncompressed extends AColGroup { return new ColGroupUncompressed(mb, colIndexes); } + /** + * Create a quantization-fused uncompressed Matrix Block, where the columns are offset by col indexes. + * + * It is assumed that the size of colIndexes matches the number of columns in mb. + * + * @param mb The MB / data to contain in the uncompressed column + * @param colIndexes The column indexes for the group + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @return An Uncompressed Column group + */ + public static AColGroup createQuantized(MatrixBlock mb, IColIndex colIndexes, double[] scaleFactors) { + if(mb == null || mb.isEmpty()) + // TODO: handle quantization-fused compression if deemed necessary, + // but if the matrix reaches here, it likely doesn't need quantization. + return new ColGroupEmpty(colIndexes); + else + return new ColGroupUncompressed(mb, colIndexes, scaleFactors); + } + + /** + * Main constructor for a quantization-fused uncompressed ColGroup. + * + * @param colIndexes Indices (relative to the current block) of the columns that this column group represents. + * @param rawBlock The uncompressed block; uncompressed data must be present at the time that the constructor is + * called + * @param transposed Says if the input matrix raw block has been transposed. + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @return AColGroup. + */ + public static AColGroup createQuantized(IColIndex colIndexes, MatrixBlock rawBlock, boolean transposed, double[] scaleFactors) { + + // special cases + if(rawBlock.isEmptyBlock(false)) // empty input + // TODO: handle quantization-fused compression if deemed necessary, + // but if the matrix reaches here, it likely doesn't need quantization. + return new ColGroupEmpty(colIndexes); + else if(!transposed && colIndexes.size() == rawBlock.getNumColumns()) + // full input to uncompressedColumnGroup + return new ColGroupUncompressed(rawBlock, colIndexes, scaleFactors); + + MatrixBlock mb; + final int _numRows = transposed ?
rawBlock.getNumColumns() : rawBlock.getNumRows(); + + if(colIndexes.size() == 1) { + final int col = colIndexes.get(0); + if(transposed) { + mb = rawBlock.slice(col, col, 0, rawBlock.getNumColumns() - 1); + mb = LibMatrixReorg.transposeInPlace(mb, InfrastructureAnalyzer.getLocalParallelism()); + } + else + mb = rawBlock.slice(0, rawBlock.getNumRows() - 1, col, col); + + return createQuantized(mb, colIndexes, scaleFactors); + } + + // Create a matrix with just the requested rows of the original block + mb = new MatrixBlock(_numRows, colIndexes.size(), rawBlock.isInSparseFormat()); + + final int m = _numRows; + final int n = colIndexes.size(); + + if(transposed) { + if (scaleFactors.length == 1) { + for(int i = 0; i < m; i++) + for(int j = 0; j < n; j++) + mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[0])); + } else { + for(int i = 0; i < m; i++) + for(int j = 0; j < n; j++) + mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[j])); + } + } + else { + if (scaleFactors.length == 1) { + for(int i = 0; i < m; i++) + for(int j = 0; j < n; j++) + mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[0])); + } else { + for(int i = 0; i < m; i++) + for(int j = 0; j < n; j++) + mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[i])); + } + } + + mb.recomputeNonZeros(); + mb.examSparsity(); + + return create(mb, colIndexes); + + } + /** * Main constructor for Uncompressed ColGroup. * diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java index 6483eba104..48dd245c4e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java @@ -38,7 +38,7 @@ public class ComEstExact extends AComEst { @Override public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) { - final IEncode map = EncodingFactory.createFromMatrixBlock(_data, _cs.transposed, colIndexes); + final IEncode map = EncodingFactory.createFromMatrixBlock(_data, _cs.transposed, colIndexes, _cs.scaleFactors); if(map instanceof EmptyEncoding) return new CompressedSizeInfoColGroup(colIndexes, getNumRows(), CompressionType.EMPTY); return getFacts(map, colIndexes); @@ -59,7 +59,7 @@ public class ComEstExact extends AComEst { protected CompressedSizeInfoColGroup getFacts(IEncode map, IColIndex colIndexes) { final int _numRows = getNumRows(); - final EstimationFactors em = map.extractFacts(_numRows, _data.getSparsity(), _data.getSparsity(), _cs); + final EstimationFactors em = map.extractFacts(_numRows, _data.getSparsity(), _data.getSparsity(), _cs); return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java index 97b451daee..8a28a0ca49 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java @@ -90,7 +90,7 @@ public class ComEstSample extends AComEst { _data.getSparseBlock().isEmpty(colIndexes.get(0)))) return new CompressedSizeInfoColGroup(colIndexes, getNumRows(), CompressionType.EMPTY); - final IEncode map = EncodingFactory.createFromMatrixBlock(_sample, _transposed, colIndexes); + final 
IEncode map = EncodingFactory.createFromMatrixBlock(_sample, _transposed, colIndexes, _cs.scaleFactors); return extractInfo(map, colIndexes, maxDistinct); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java index 6cc882cc2f..963a044d14 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java @@ -206,6 +206,10 @@ public class CompressedSizeInfoColGroup { return _facts.tupleSparsity; } + public EstimationFactors getFacts() { + return _facts; + } + public IEncode getMap() { return _map; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java index 257ddf6f3c..b196da658c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java @@ -56,10 +56,24 @@ public interface EncodingFactory { public static IEncode createFromMatrixBlock(MatrixBlock m, boolean transposed, IColIndex rowCols) { if(m.isEmpty()) return new EmptyEncoding(); - else if(rowCols.size() == 1) - return createFromMatrixBlock(m, transposed, rowCols.get(0)); - else - return createWithReader(m, rowCols, transposed); + else if(rowCols.size() == 1) { + return createFromMatrixBlock(m, transposed, rowCols.get(0), null); + } + else { + return createWithReader(m, rowCols, transposed, null); + } + } + + public static IEncode createFromMatrixBlock(MatrixBlock m, boolean transposed, IColIndex rowCols, + double[] scaleFactors) { + if(m.isEmpty()) + return new EmptyEncoding(); + else if(rowCols.size() == 1) { + return createFromMatrixBlock(m, transposed, rowCols.get(0), scaleFactors); + } + else { + return createWithReader(m, rowCols, transposed, scaleFactors); + } } /** @@ -115,8 +129,54 @@ public interface EncodingFactory { } else if(m.isInSparseFormat()) return createFromSparse(m, rowCol); - else + else { return createFromDense(m, rowCol); + } + } + + /** + * Create encoding of a single specific column inside the matrix input. + * + * @param m The Matrix to encode a column from + * @param transposed If the matrix is in transposed format. + * @param rowCol The column index to encode + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire + * matrix + * @return An encoded format of the information of this column. 
+ */ + public static IEncode createFromMatrixBlock(MatrixBlock m, boolean transposed, int rowCol, double[] scaleFactors) { + if(m.isEmpty()) + return new EmptyEncoding(); + else if(transposed) { + if(scaleFactors != null) { + if(m.isInSparseFormat()) + throw new NotImplementedException(); + else + return createFromDenseTransposedQuantized(m, rowCol, scaleFactors); + } + else { + if(m.isInSparseFormat()) + return createFromSparseTransposed(m, rowCol); + else + return createFromDenseTransposed(m, rowCol); + } + } + else if(m.isInSparseFormat()) { + if(scaleFactors != null) { + throw new NotImplementedException(); // TODO: handle quantization-fused compression + } + else { + return createFromSparse(m, rowCol); + } + } + else { + if(scaleFactors != null) { + return createFromDenseQuantized(m, rowCol, scaleFactors); + } + else { + return createFromDense(m, rowCol); + } + } } public static IEncode create(ColGroupConst c) { @@ -229,13 +289,13 @@ public interface EncodingFactory { // Iteration 3 of non zero indexes, make a Offset Encoding to know what cells are zero and not. // not done yet - try{ + try { final AOffset o = OffsetFactory.createOffset(aix, apos, alen); return new SparseEncoding(d, o, m.getNumColumns()); } - catch(Exception e){ - String mes = Arrays.toString(Arrays.copyOfRange(aix, apos, alen)) + "\n" + apos + " " + alen; + catch(Exception e) { + String mes = Arrays.toString(Arrays.copyOfRange(aix, apos, alen)) + "\n" + apos + " " + alen; mes += Arrays.toString(Arrays.copyOfRange(avals, apos, alen)); throw new DMLRuntimeException(mes, e); } @@ -341,8 +401,229 @@ public interface EncodingFactory { return new SparseEncoding(d, o, m.getNumRows()); } - private static IEncode createWithReader(MatrixBlock m, IColIndex rowCols, boolean transposed) { - final ReaderColumnSelection reader1 = ReaderColumnSelection.createReader(m, rowCols, transposed); + private static IEncode createFromDenseTransposedQuantized(MatrixBlock m, int row, double[] scaleFactors) { + final DenseBlock db = m.getDenseBlock(); + if(!db.isContiguous()) + throw new NotImplementedException("Not Implemented non contiguous dense matrix encoding for sample"); + final DoubleCountHashMap map = new DoubleCountHashMap(); + final int off = db.pos(row); + final int nCol = m.getNumColumns(); + final int end = off + nCol; + final double[] vals = db.values(row); + + // A single scale factor applies to the entire matrix; otherwise one factor per row. + final boolean useSingleScalar = scaleFactors != null && scaleFactors.length == 1; + + if(useSingleScalar) { + // Iteration 1: Apply scaling & quantization, then populate the HashMap + for(int i = off; i < end; i++) // sequential access + map.increment(Math.floor(vals[i] * scaleFactors[0])); + + final int nUnique = map.size(); + + if(nUnique == 1) + return new ConstEncoding(m.getNumColumns()); + else if(nUnique == 0) + return new EmptyEncoding(); + else if(map.getOrDefault(0.0, -1) > nCol / 4) { + map.replaceWithUIDsNoZero(); + final int zeroCount = map.get(0.0); + final int nV = nCol - zeroCount; + final IntArrayList offsets = new IntArrayList(nV); + + final AMapToData d = MapToFactory.create(nV, nUnique - 1); + int di = 0; + for(int i = off, r = 0; i < end; i++, r++) { + double value = Math.floor(vals[i] * scaleFactors[0]); + if(value != 0) { + offsets.appendValue(r); + d.set(di++, map.getId(value)); + } + } + if(di != nV) + throw new RuntimeException("Did not find equal number of elements " + di + " vs " + nV); + + final AOffset o = OffsetFactory.createOffset(offsets); + return new
SparseEncoding(d, o, nCol); + } + else { + // Create output map + final AMapToData d = MapToFactory.create(nCol, nUnique); + + // Iteration 2, make final map + for(int i = off, r = 0; i < end; i++, r++) { + double value = Math.floor(vals[i] * scaleFactors[0]); + d.set(r, map.getId(value)); + } + + return new DenseEncoding(d); + } + } + else { + // Iteration 1: Apply per-row scaling & quantization, then populate the HashMap + for(int i = off, r = 0; i < end; i++, r++) // sequential access + map.increment(Math.floor(vals[i] * scaleFactors[r])); + + final int nUnique = map.size(); + + if(nUnique == 1) + return new ConstEncoding(m.getNumColumns()); + else if(nUnique == 0) + return new EmptyEncoding(); + else if(map.getOrDefault(0.0, -1) > nCol / 4) { + map.replaceWithUIDsNoZero(); + final int zeroCount = map.get(0.0); + final int nV = nCol - zeroCount; + final IntArrayList offsets = new IntArrayList(nV); + + final AMapToData d = MapToFactory.create(nV, nUnique - 1); + int di = 0; + for(int i = off, r = 0; i < end; i++, r++) { + double value = Math.floor(vals[i] * scaleFactors[r]); + if(value != 0) { + offsets.appendValue(r); + d.set(di++, map.getId(value)); + } + } + if(di != nV) + throw new RuntimeException("Did not find equal number of elements " + di + " vs " + nV); + + final AOffset o = OffsetFactory.createOffset(offsets); + return new SparseEncoding(d, o, nCol); + } + else { + // Create output map + final AMapToData d = MapToFactory.create(nCol, nUnique); + + // Iteration 2, make final map + for(int i = off, r = 0; i < end; i++, r++) { + double value = Math.floor(vals[i] * scaleFactors[r]); + d.set(r, map.getId(value)); + } + + return new DenseEncoding(d); + } + } + } + + private static IEncode createFromDenseQuantized(MatrixBlock m, int col, double[] scaleFactors) { + final DenseBlock db = m.getDenseBlock(); + if(!db.isContiguous()) + throw new NotImplementedException("Not Implemented non contiguous dense matrix encoding for sample"); + final DoubleCountHashMap map = new DoubleCountHashMap(16); + final int off = col; + final int nCol = m.getNumColumns(); + final int nRow = m.getNumRows(); + final int end = off + nRow * nCol; + final double[] vals = m.getDenseBlockValues(); + + // A single scale factor applies to the entire matrix; otherwise one factor per row. + final boolean useSingleScalar = scaleFactors != null && scaleFactors.length == 1; + + if(useSingleScalar) { + // Iteration 1, make Count HashMap with quantized values + for(int i = off; i < end; i += nCol) {// jump down through rows.
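+ // For example, in a row-major 3x4 dense block, col = 1 visits vals[1], vals[5], vals[9]: + // each step of nCol advances one row within the same column.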
+ map.increment(Math.floor(vals[i] * scaleFactors[0])); + } + final int nUnique = map.size(); + if(nUnique == 1) + return new ConstEncoding(m.getNumColumns()); + + if(map.getOrDefault(0.0, -1) > nRow / 4) { + map.replaceWithUIDsNoZero(); + final int zeroCount = map.get(0.0); + final int nV = m.getNumRows() - zeroCount; + final IntArrayList offsets = new IntArrayList(nV); + + final AMapToData d = MapToFactory.create(nV, nUnique - 1); + int di = 0; + for(int i = off, r = 0; i < end; i += nCol, r++) { + double value = Math.floor(vals[i] * scaleFactors[0]); + if(value != 0) { + offsets.appendValue(r); + d.set(di++, map.getId(value)); + } + } + if(di != nV) + throw new DMLRuntimeException("Invalid number of zeros."); + + final AOffset o = OffsetFactory.createOffset(offsets); + + return new SparseEncoding(d, o, nRow); + } + else { + // Allocate counts, and iterate once to replace counts with unique IDs + + final AMapToData d = MapToFactory.create(nRow, nUnique); + // Iteration 2, make final map with quantized values + for(int i = off, r = 0; i < end; i += nCol, r++) { + double value = Math.floor(vals[i] * scaleFactors[0]); + d.set(r, map.getId(value)); + } + return new DenseEncoding(d); + } + } + else { + // Iteration 1, make Count HashMap with row-wise quantized values + for(int i = off, r = 0; i < end; i += nCol, r++) {// jump down through rows. + map.increment(Math.floor(vals[i] * scaleFactors[r])); + } + final int nUnique = map.size(); + if(nUnique == 1) + return new ConstEncoding(m.getNumColumns()); + + if(map.getOrDefault(0.0, -1) > nRow / 4) { + map.replaceWithUIDsNoZero(); + final int zeroCount = map.get(0.0); + final int nV = m.getNumRows() - zeroCount; + final IntArrayList offsets = new IntArrayList(nV); + + final AMapToData d = MapToFactory.create(nV, nUnique - 1); + int di = 0; + for(int i = off, r = 0; i < end; i += nCol, r++) { + double value = Math.floor(vals[i] * scaleFactors[r]); + if(value != 0) { + offsets.appendValue(r); + d.set(di++, map.getId(value)); + } + } + if(di != nV) + throw new DMLRuntimeException("Invalid number of zeros."); + + final AOffset o = OffsetFactory.createOffset(offsets); + + return new SparseEncoding(d, o, nRow); + } + else { + // Allocate counts, and iterate once to replace counts with unique IDs + + final AMapToData d = MapToFactory.create(nRow, nUnique); + // Iteration 2, make final map with row-wise quantized values + for(int i = off, r = 0; i < end; i += nCol, r++) { + double value = Math.floor(vals[i] * scaleFactors[r]); + d.set(r, map.getId(value)); + } + return new DenseEncoding(d); + } + } + } + + private static IEncode createWithReader(MatrixBlock m, IColIndex rowCols, boolean transposed, + double[] scaleFactors) { + final ReaderColumnSelection reader1 = (scaleFactors == null) ? ReaderColumnSelection.createReader(m, rowCols, + transposed) : ReaderColumnSelection.createQuantizedReader(m, rowCols, transposed, scaleFactors); final int nRows = transposed ? m.getNumColumns() : m.getNumRows(); final DblArrayCountHashMap map = new DblArrayCountHashMap(); final IntArrayList offsets = new IntArrayList(); @@ -362,17 +643,18 @@ public interface EncodingFactory { if(offsets.size() < nRows / 4) // Output encoded sparse since there is very empty.
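// For example, with nRows = 1000 the threshold is 250: if fewer than 250 rows contain non-zero values, the sparse encoding path is chosen.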
- return createWithReaderSparse(m, map, rowCols, offsets, nRows, transposed); + return createWithReaderSparse(m, map, rowCols, offsets, nRows, transposed, scaleFactors); else - return createWithReaderDense(m, map, rowCols, nRows, transposed, offsets.size() < nRows); + return createWithReaderDense(m, map, rowCols, nRows, transposed, offsets.size() < nRows, scaleFactors); } private static IEncode createWithReaderDense(MatrixBlock m, DblArrayCountHashMap map, IColIndex rowCols, int nRows, - boolean transposed, boolean zero) { + boolean transposed, boolean zero, double[] scaleFactors) { // Iteration 2, final int unique = map.size() + (zero ? 1 : 0); - final ReaderColumnSelection reader2 = ReaderColumnSelection.createReader(m, rowCols, transposed); + final ReaderColumnSelection reader2 = (scaleFactors == null) ? ReaderColumnSelection.createReader(m, rowCols, + transposed) : ReaderColumnSelection.createQuantizedReader(m, rowCols, transposed, scaleFactors); final AMapToData d = MapToFactory.create(nRows, unique); DblArray cellVals; @@ -387,8 +669,9 @@ public interface EncodingFactory { } private static IEncode createWithReaderSparse(MatrixBlock m, DblArrayCountHashMap map, IColIndex rowCols, - IntArrayList offsets, int nRows, boolean transposed) { - final ReaderColumnSelection reader2 = ReaderColumnSelection.createReader(m, rowCols, transposed); + IntArrayList offsets, int nRows, boolean transposed, double[] scaleFactors) { + final ReaderColumnSelection reader2 = (scaleFactors == null) ? ReaderColumnSelection.createReader(m, rowCols, + transposed) : ReaderColumnSelection.createQuantizedReader(m, rowCols, transposed, scaleFactors); DblArray cellVals = reader2.nextRow(); final AMapToData d = MapToFactory.create(offsets.size(), map.size()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java index e087525bbb..d6ec60336f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java +++ b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelection.java @@ -19,6 +19,8 @@ package org.apache.sysds.runtime.compress.readers; +import org.apache.commons.lang3.NotImplementedException; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.DMLCompressionException; @@ -83,7 +85,7 @@ public abstract class ReaderColumnSelection { } /** - * Create an reader of the matrix block that is able to iterate though all the rows and return as dense double + * Create a reader of the matrix block that is able to iterate through all the rows and return as dense double * arrays. * * Note the reader reuse the return, therefore if needed for something please copy the returned rows. @@ -100,7 +102,30 @@ public abstract class ReaderColumnSelection { } /** - * Create an reader of the matrix block that is able to iterate though all the rows and return as dense double + * Create a reader of the matrix block that directly reads quantized values using scale factors. + * + * Note the reader reuses the returned array, therefore copy the returned rows if they are needed later. + * + * @param rawBlock The block to iterate through + * @param colIndices The column indexes to extract and insert into the double array + * @param transposed If the raw block should be treated as transposed + * @param scaleFactors An array of scale factors applied.
+ * - If row-wise scaling is used, this should be an array where each value corresponds to a row. + * - If a single scalar is provided, it is applied uniformly to the entire matrix. + * @return A reader of the columns specified + */ + public static ReaderColumnSelection createQuantizedReader(MatrixBlock rawBlock, IColIndex colIndices, boolean transposed, double[] scaleFactors) { + if(transposed) { + throw new NotImplementedException(); + } + final int rl = 0; + final int ru = rawBlock.getNumRows(); + return createQuantizedReader(rawBlock, colIndices, transposed, rl, ru, scaleFactors); + } + + /** + * Create a reader of the matrix block that is able to iterate through all the rows and return as dense double * arrays. * * Note the reader reuse the return, therefore if needed for something please copy the returned rows. @@ -136,6 +161,40 @@ public abstract class ReaderColumnSelection { return new ReaderColumnSelectionDenseSingleBlock(rawBlock, colIndices, rl, ru); } + /** + * Create a reader of the matrix block that directly reads quantized values using scale factors. + * + * Note the reader reuses the returned array, therefore copy the returned rows if they are needed later. + * + * @param rawBlock The block to iterate through + * @param colIndices The column indexes to extract and insert into the double array + * @param transposed If the raw block should be treated as transposed + * @param rl The row to start at + * @param ru The row to end at (not inclusive) + * @param scaleFactors An array of scale factors applied. + * - If row-wise scaling is used, this should be an array where each value corresponds to a row. + * - If a single scalar is provided, it is applied uniformly to the entire matrix. + * @return A reader of the columns specified + */ + public static ReaderColumnSelection createQuantizedReader(MatrixBlock rawBlock, IColIndex colIndices, boolean transposed, + int rl, int ru, double[] scaleFactors) { + checkInput(rawBlock, colIndices, rl, ru, transposed); + rl = rl - 1; + if(rawBlock.isEmpty()) { + LOG.warn("Reading an empty block is likely an error, but it is supported!"); + return new ReaderColumnSelectionEmpty(rawBlock, colIndices, rl, ru, transposed); + } + else if(transposed) { + throw new NotImplementedException(); + } + else if(rawBlock.isInSparseFormat()) { + throw new NotImplementedException(); + } + else { + return new ReaderColumnSelectionDenseSingleBlockQuantized(rawBlock, colIndices, rl, ru, scaleFactors); + } + } + private static void checkInput(final MatrixBlock rawBlock, final IColIndex colIndices, final int rl, final int ru, final boolean transposed) { if(colIndices.size() <= 1) diff --git a/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockQuantized.java b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockQuantized.java new file mode 100644 index 0000000000..645e694bb4 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/readers/ReaderColumnSelectionDenseSingleBlockQuantized.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.readers; + +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.DblArray; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +public class ReaderColumnSelectionDenseSingleBlockQuantized extends ReaderColumnSelection { + private final double[] _data; + private final int _numCols; + private final double[] _scaleFactors; + + protected ReaderColumnSelectionDenseSingleBlockQuantized(MatrixBlock data, IColIndex colIndices, int rl, int ru, + double[] scaleFactors) { + super(colIndices, rl, Math.min(ru, data.getNumRows()) - 1); + _data = data.getDenseBlockValues(); + _numCols = data.getNumColumns(); + _scaleFactors = scaleFactors; + } + + protected DblArray getNextRow() { + + _rl++; + final int indexOff = _rl * _numCols; + double scaleFactor = _scaleFactors.length == 1 ? _scaleFactors[0] : _scaleFactors[_rl]; + + for(int i = 0; i < _colIndexes.size(); i++) + reusableArr[i] = Math.floor(_data[indexOff + _colIndexes.get(i)] * scaleFactor); + + return reusableReturn; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index cfa67f3c3c..a76dd6aaca 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -196,7 +196,11 @@ public class CPInstructionParser extends InstructionParser { case DeCompression: return DeCompressionCPInstruction.parseInstruction(str); - + + case QuantizeCompression: + LOG.debug("Parsing Quantize Compress instruction"); + return CompressionCPInstruction.parseQuantizationFusedInstruction(str); + case Local: return LocalCPInstruction.parseInstruction(str); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java index fbe7c1d757..ee25cfaaa2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionParser.java @@ -63,8 +63,9 @@ public class InstructionParser return null; String[] strlist = str.split(Instruction.INSTRUCTION_DELIM); Instruction[] inst = new Instruction[strlist.length]; - for ( int i=0; i < inst.length; i++ ) + for ( int i=0; i < inst.length; i++ ) { inst[i] = parseSingleInstruction ( strlist[i] ); + } return inst; } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java index f8527276a7..b0b502f8a0 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java @@ -46,6 +46,7 @@ public abstract class CPInstruction extends Instruction { StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, Sql, Prefetch, Broadcast, TrigRemote, EvictLineageCache, NoOp, + QuantizeCompression } 
protected final CPType _cptype; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java index efc8e21777..52e0ad8154 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java @@ -44,6 +44,9 @@ public class CompressionCPInstruction extends ComputationCPInstruction { private final int _singletonLookupID; private final int _numThreads; + /** This is set to true only for quantization-fused compression */ + private final boolean _quantizationFused; + /** This is only for binned compression with 2 outputs */ protected final List<CPOperand> _outputs; @@ -53,6 +56,7 @@ public class CompressionCPInstruction extends ComputationCPInstruction { _outputs = null; this._singletonLookupID = singletonLookupID; this._numThreads = numThreads; + this._quantizationFused = false; } private CompressionCPInstruction(Operator op, CPOperand in1, CPOperand in2, List<CPOperand> out, String opcode, @@ -61,8 +65,18 @@ public class CompressionCPInstruction extends ComputationCPInstruction { _outputs = out; this._singletonLookupID = singletonLookupID; this._numThreads = numThreads; + this._quantizationFused = false; } + private CompressionCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, + String istr, int singletonLookupID, int numThreads) { + super(CPType.QuantizeCompression, op, in1, in2, null, out, opcode, istr); + _outputs = null; + this._singletonLookupID = singletonLookupID; + this._numThreads = numThreads; + this._quantizationFused = true; + } + public static CompressionCPInstruction parseInstruction(String str) { InstructionUtils.checkNumFields(str, 3, 4, 5); String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); @@ -89,10 +103,23 @@ public class CompressionCPInstruction extends ComputationCPInstruction { } } + public static CompressionCPInstruction parseQuantizationFusedInstruction(String str) { + InstructionUtils.checkNumFields(str, 3, 4, 5); + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand out = new CPOperand(parts[3]); + int numThreads = Integer.parseInt(parts[4]); + return new CompressionCPInstruction(null, in1, in2, out, opcode, str, 0, numThreads); + } + @Override public void processInstruction(ExecutionContext ec) { if(input2 == null) processSimpleCompressInstruction(ec); + else if (this._quantizationFused == true) + processSimpleQuantizationFusedCompressInstruction(ec); else processCompressByBinInstruction(ec); } @@ -143,6 +170,28 @@ public class CompressionCPInstruction extends ComputationCPInstruction { } } + private void processSimpleQuantizationFusedCompressInstruction(ExecutionContext ec) { + // final MatrixBlock in = ec.getMatrixInput(input1.getName()); + final SingletonLookupHashMap m = SingletonLookupHashMap.getMap(); + + // Get and clear workload tree entry for this compression instruction. + final WTreeRoot root = (_singletonLookupID != 0) ? (WTreeRoot) m.get(_singletonLookupID) : null; + // We used to remove the key from the hash map, + // however this is not correct since the compression statement + // can be reused in multiple for loops. 
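+ // The second operand determines the quantization mode: a scalar scale factor applies + // uniformly to all cells, while a column vector of factors scales each row i by its own + // factor before flooring, i.e. floor(X[i,j] * sf[i]).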
+ + if(input2.isScalar()) { + final ScalarObject scalarIn2 = ec.getScalarInput(input2); + processMatrixBlockQuantizationFusedCompression(ec, ec.getMatrixInput(input1.getName()), scalarIn2, _numThreads, root); + } + else if(input2.isMatrix()) { + final MatrixBlock matrixIn2 = ec.getMatrixInput(input2.getName()); + processMatrixBlockQuantizationFusedCompression(ec, ec.getMatrixInput(input1.getName()), matrixIn2, _numThreads, root); + } + } + private void processMatrixBlockCompression(ExecutionContext ec, MatrixBlock in, int k, WTreeRoot root) { Pair<MatrixBlock, CompressionStatistics> compResult = CompressedMatrixBlockFactory.compress(in, k, root); if(LOG.isTraceEnabled()) @@ -161,4 +210,32 @@ public class CompressionCPInstruction extends ComputationCPInstruction { ec.releaseFrameInput(input1.getName()); ec.setFrameOutput(output.getName(), compResult); } + + private void processMatrixBlockQuantizationFusedCompression(ExecutionContext ec, MatrixBlock in1, MatrixBlock in2, int k, WTreeRoot root) { + Pair<MatrixBlock, CompressionStatistics> compResult = CompressedMatrixBlockFactory.compress(in1, in2, k, root); + if(LOG.isTraceEnabled()) + LOG.trace(compResult.getRight()); + MatrixBlock out = compResult.getLeft(); + if(LOG.isInfoEnabled()) + LOG.info("Compression output class: " + out.getClass().getSimpleName()); + // Set output and release inputs + ec.releaseMatrixInput(input1.getName()); + ec.releaseMatrixInput(input2.getName()); + ec.setMatrixOutput(output.getName(), out); + } + + private void processMatrixBlockQuantizationFusedCompression(ExecutionContext ec, MatrixBlock in1, ScalarObject in2, int k, WTreeRoot root) { + Pair<MatrixBlock, CompressionStatistics> compResult = CompressedMatrixBlockFactory.compress(in1, in2, k, root); + if(LOG.isTraceEnabled()) + LOG.trace(compResult.getRight()); + MatrixBlock out = compResult.getLeft(); + if(LOG.isInfoEnabled()) + LOG.info("Compression output class: " + out.getClass().getSimpleName()); + // Set output and release the matrix input; the scalar scale factor needs no release + ec.releaseMatrixInput(input1.getName()); + ec.setMatrixOutput(output.getName(), out); + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/qcompress/CompareCompressionTypeTest.java b/src/test/java/org/apache/sysds/test/component/compress/qcompress/CompareCompressionTypeTest.java new file mode 100644 index 0000000000..f3a30a53cf --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/qcompress/CompareCompressionTypeTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysds.test.component.compress.qcompress; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import java.util.List; +import java.util.Random; + +import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; +import org.apache.sysds.runtime.compress.cocode.CoCoderFactory.PartitionerType; +import org.apache.sysds.runtime.compress.estim.AComEst; +import org.apache.sysds.runtime.compress.estim.ComEstFactory; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class CompareCompressionTypeTest { + + /** + * Test 1: Compare the best compression types of two matrices, m0 and m1: DDC. + * + * - m0 is generated as a floored matrix. - m1 is generated as a full-precision matrix, but will be internally + * multiplied by 1.0 and floored. - Since m1 undergoes an equivalent transformation (scaling by 1.0 and flooring), + * the best compression types determined by the estimator should match elementwise for both matrices. - This + * validates that the estimator correctly handles explicit flooring vs. internal scaling and flooring during + * quantization-fused compression. + */ + @Test + public void testCompareBestCompressionTypeForTwoMatricesDDC() { + try { + Random r = new Random(1234); + int k = 4; + + // Generate first floored matrix and compute compression info + MatrixBlock m0 = generateTestMatrix(10000, 500, 1, 100, 1.0, r, true); + CompressionSettings cs0 = new CompressionSettingsBuilder().setColumnPartitioner(PartitionerType.GREEDY) + .setSeed(1234).create(); + AComEst estimator0 = ComEstFactory.createEstimator(m0, cs0, k); + CompressedSizeInfo compressedGroups0 = estimator0.computeCompressedSizeInfos(k); + + // Generate the second, full-precision matrix that will be internally scaled by 1.0 and floored, + // and compute its compression info + MatrixBlock m1 = generateTestMatrix(10000, 500, 1, 100, 1.0, r, false); + double[] scaleFactor = {1.0}; + CompressionSettings cs1 = new CompressionSettingsBuilder().setColumnPartitioner(PartitionerType.GREEDY) + .setScaleFactor(scaleFactor).setSeed(1234).create(); + AComEst estimator1 = ComEstFactory.createEstimator(m1, cs1, k); + CompressedSizeInfo compressedGroups1 = estimator1.computeCompressedSizeInfos(k); + + List<CompressedSizeInfoColGroup> groups0 = compressedGroups0.getInfo(); + List<CompressedSizeInfoColGroup> groups1 = compressedGroups1.getInfo(); + + assertEquals("Mismatch in number of compressed groups", groups0.size(), groups1.size()); + + for(int i = 0; i < groups0.size(); i++) { + assertEquals("Best compression type mismatch at index " + i, groups0.get(i).getBestCompressionType(), + groups1.get(i).getBestCompressionType()); + } + + } + catch(Exception e) { + e.printStackTrace(); + fail("Compression extraction failed: " + e.getMessage()); + } + } + + /** + * Test 2: Compare the best compression types of two matrices, m0 and m1: CONST. + * + * - m0 is generated as a floored matrix. - m1 is generated as a full-precision matrix, but will be internally + * multiplied by 1.0 and floored. - Since m1 undergoes an equivalent transformation (scaling by 1.0 and flooring), + * the best compression types determined by the estimator should match elementwise for both matrices.
- This + * validates that the estimator correctly handles explicit flooring vs. internal scaling and flooring during + * quantization-fused compression. + */ + @Test + public void testCompareBestCompressionTypeForTwoMatricesConst() { + try { + Random r = new Random(1234); + int k = 4; + + // Generate first floored matrix and compute compression info + MatrixBlock m0 = generateTestMatrix(10000, 500, 1, 1, 1.0, r, true); + CompressionSettings cs0 = new CompressionSettingsBuilder().setColumnPartitioner(PartitionerType.GREEDY) + .setSeed(1234).create(); + AComEst estimator0 = ComEstFactory.createEstimator(m0, cs0, k); + CompressedSizeInfo compressedGroups0 = estimator0.computeCompressedSizeInfos(k); + + // Generate the second, full-precision matrix that will be internally scaled by 1.0 and floored, + // and compute its compression info + MatrixBlock m1 = generateTestMatrix(10000, 500, 1, 1, 1.0, r, false); + double[] scaleFactor = {1.0}; + CompressionSettings cs1 = new CompressionSettingsBuilder().setColumnPartitioner(PartitionerType.GREEDY) + .setScaleFactor(scaleFactor).setSeed(1234).create(); + AComEst estimator1 = ComEstFactory.createEstimator(m1, cs1, k); + CompressedSizeInfo compressedGroups1 = estimator1.computeCompressedSizeInfos(k); + + List<CompressedSizeInfoColGroup> groups0 = compressedGroups0.getInfo(); + List<CompressedSizeInfoColGroup> groups1 = compressedGroups1.getInfo(); + + assertEquals("Mismatch in number of compressed groups", groups0.size(), groups1.size()); + + for(int i = 0; i < groups0.size(); i++) { + assertEquals("Best compression type mismatch at index " + i, groups0.get(i).getBestCompressionType(), + groups1.get(i).getBestCompressionType()); + } + + } + catch(Exception e) { + e.printStackTrace(); + fail("Compression extraction failed: " + e.getMessage()); + } + } + + /** + * Generate a test matrix with specified dimensions, value range, and sparsity. + */ + private static MatrixBlock generateTestMatrix(int nRow, int nCol, int min, int max, double s, Random r, + boolean floored) { + final int m = Integer.MAX_VALUE; + MatrixBlock mb = TestUtils.generateTestMatrixBlock(nRow, nCol, min, max, s, r.nextInt(m)); + return floored ? TestUtils.floor(mb) : mb; + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedCompressionTest.java b/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedCompressionTest.java new file mode 100644 index 0000000000..916c22abfd --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedCompressionTest.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysds.test.component.compress.qcompress; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; +import org.apache.sysds.runtime.compress.CompressionStatistics; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.apache.commons.lang3.tuple.Pair; + +/** + * This class tests the quantization-fused compression in SystemDS. + */ +public class QuantizationFusedCompressionTest { + + /** + * Test 1: Quantization-fused Compression with a scalar scaling factor. + */ + @Test + public void testQuantizationCompressionWithScalar() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(4, 4, 1, 10, 1.0, 1234); + ScalarObject sf = new DoubleObject(2.5); + Pair<MatrixBlock, CompressionStatistics> result = CompressedMatrixBlockFactory.compress(mb, sf, 1, null); + MatrixBlock qmb = result.getLeft(); + for(int i = 0; i < mb.getNumRows(); i++) { + for(int j = 0; j < mb.getNumColumns(); j++) { + double expected = Math.floor(mb.get(i, j) * sf.getDoubleValue()); + assertEquals("Quantized compression mismatch!", expected, qmb.get(i, j), 0.0); + } + } + } + + /** + * Test 2: Quantization-fused compression with row-wise vector scaling. + */ + @Test + public void testQuantizationCompressionWithRowwiseVectorScale() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(5, 4, 1, 10, 1.0, 5678); + MatrixBlock sf = new MatrixBlock(5, 1, false); + sf.set(0, 0, 1.5); + sf.set(1, 0, 2.0); + sf.set(2, 0, 2.5); + sf.set(3, 0, 3.0); + sf.set(4, 0, 3.5); + Pair<MatrixBlock, CompressionStatistics> result = CompressedMatrixBlockFactory.compress(mb, sf, 1, null); + MatrixBlock qmb = result.getLeft(); + for(int i = 0; i < mb.getNumRows(); i++) { + for(int j = 0; j < mb.getNumColumns(); j++) { + double expected = Math.floor(mb.get(i, j) * sf.get(i, 0)); + assertEquals("Quantized compression mismatch!", expected, qmb.get(i, j), 0.0); + } + } + } + + /** + * Test 3: Compare compression statistics of two matrices, m0 and m1, where m0 is derived as m0 = floor(m1 * sf) with + * sf = 0.5. + * + * - Compression for m0 is aborted at phase 1 (before co-code). - Compression for m1 should also be aborted at the + * same phase. - The resulting compression statistics for both matrices should match. 
+ */ + @Test + public void testQuantizationFusedCompressionAbortedBeforeCoCodeStats() { + double[][] values0 = {{0, 1, 1, 2, 2}, {3, 3, 4, 4, 5}, {5, 6, 6, 7, 7}, {8, 8, 9, 9, 10}, {10, 11, 11, 12, 12}, + {13, 13, 14, 14, 15}}; + MatrixBlock m0 = DataConverter.convertToMatrixBlock(values0); + m0.recomputeNonZeros(); + + Pair<MatrixBlock, CompressionStatistics> cm0 = CompressedMatrixBlockFactory.compress(m0); + CompressionStatistics stats0 = cm0.getRight(); + + MatrixBlock m1 = new MatrixBlock(6, 5, false); + int val = 1; + for(int i = 0; i < 6; i++) { + for(int j = 0; j < 5; j++) { + m1.set(i, j, val++); + } + } + m1.recomputeNonZeros(); + + DoubleObject sf = new DoubleObject(0.5); + Pair<MatrixBlock, CompressionStatistics> cm1 = CompressedMatrixBlockFactory.compress(m1, sf, 1, null); + CompressionStatistics stats1 = cm1.getRight(); + + assertTrue("Compression statistics must match", stats0.toString().equals(stats1.toString())); + // Since m0 and m1 have different values their number of non-zero values is different + // assertEquals("Non-zero count should match", m0.getNonZeros(), m1.getNonZeros(), 0.1); + } + + /** + * Test 4: Compare compression statistics of two matrices, m0 and m1, where m0 is derived as m0 = floor(m1 * sf) with + * sf = 0.3. + * + * - Compression for m0 is aborted at phase 2 (after co-code). - Compression for m1 should also be aborted at the + * same phase. - The resulting compression statistics for both matrices should match. + */ + @Test + public void testQuantizationFusedCompressionAbortedAfterCoCodeStats() { + double[][] values1 = {{1, 8, 3, 4, 5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}, {2, 3, 4, 5, 6}, {2, 3, 4, 5, 6}, + {3, 4, 5, 6, 7}}; + MatrixBlock m1 = DataConverter.convertToMatrixBlock(values1); + m1.recomputeNonZeros(); + + double scaleFactor = 0.3; + MatrixBlock m0 = new MatrixBlock(m1.getNumRows(), m1.getNumColumns(), false); + for(int i = 0; i < m1.getNumRows(); i++) { + for(int j = 0; j < m1.getNumColumns(); j++) { + m0.set(i, j, Math.floor(m1.get(i, j) * scaleFactor)); + } + } + m0.recomputeNonZeros(); + + Pair<MatrixBlock, CompressionStatistics> cm0 = CompressedMatrixBlockFactory.compress(m0); + CompressionStatistics stats0 = cm0.getRight(); + + DoubleObject sf = new DoubleObject(scaleFactor); + Pair<MatrixBlock, CompressionStatistics> cm1 = CompressedMatrixBlockFactory.compress(m1, sf, 1, null); + CompressionStatistics stats1 = cm1.getRight(); + + assertTrue("Compression statistics must match", stats0.toString().equals(stats1.toString())); + // Since m0 and m1 have different values their number of non-zero values is different + // assertEquals("Non-zero count should match", m0.getNonZeros(), m1.getNonZeros(), 0.1); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedForcedCompressionTypesTest.java b/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedForcedCompressionTypesTest.java new file mode 100644 index 0000000000..3b07cc62f5 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/qcompress/QuantizationFusedForcedCompressionTypesTest.java @@ -0,0 +1,348 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information regarding copyright ownership. +* The ASF licenses this file to you under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software distributed under the License +* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, +* either express or implied. See the License for the specific language governing permissions and limitations +* under the License. +*/ + +package org.apache.sysds.test.component.compress.qcompress; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; +import org.apache.sysds.runtime.compress.cocode.CoCoderFactory; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; +import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; +import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory; +import org.apache.sysds.runtime.compress.colgroup.ColGroupOLE; +import org.apache.sysds.runtime.compress.colgroup.ColGroupRLE; +import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC; +import org.apache.sysds.runtime.compress.colgroup.ColGroupSDCSingle; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; +import org.apache.sysds.runtime.compress.cost.ACostEstimate; +import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory; +import org.apache.sysds.runtime.compress.estim.AComEst; +import org.apache.sysds.runtime.compress.estim.ComEstFactory; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class QuantizationFusedForcedCompressionTypesTest { + + private static final int K = 4; + private static final long SEED = 1234; + + /** + * Test 1: Test the Uncompressed column group by directly calling the create method. + * + * m0 is generated as a floored matrix. m1 is generated as a full-precision matrix, but will be internally multiplied + * by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). The best compression types for both matrices are + * DDC, but we explicitly create UNCOMPRESSED columns. 
+ * + */ + @Test + public void testForcedUncompressed() { + try { + MatrixBlock m0 = generateTestMatrix(10000, 500, -100, 100, 1.0, SEED, true); + MatrixBlock m1 = generateTestMatrix(10000, 500, -100, 100, 1.0, SEED, false); + + CompressionSettings cs0 = createCompressionSettings(null); + CompressionSettings cs1 = createCompressionSettings(new double[] {1.0}); + + Pair<CompressedSizeInfo, AComEst> compressedGroupsResult0 = generateCompressedGroups(m0, cs0); + CompressedSizeInfo compressedGroups0 = compressedGroupsResult0.getLeft(); + + Pair<CompressedSizeInfo, AComEst> compressedGroupsResult1 = generateCompressedGroups(m1, cs1); + CompressedSizeInfo compressedGroups1 = compressedGroupsResult1.getLeft(); + + assertEquals("Mismatch in number of compressed groups", compressedGroups0.getInfo().size(), + compressedGroups1.getInfo().size(), 0.0); + + for(int i = 0; i < compressedGroups0.getInfo().size(); i++) { + AColGroup colGroup0 = ColGroupUncompressed.create(compressedGroups0.getInfo().get(i).getColumns(), m0, + cs0.transposed); + AColGroup colGroup1 = ColGroupUncompressed.createQuantized(compressedGroups1.getInfo().get(i).getColumns(), + m1, cs1.transposed, cs1.scaleFactors); + + assertEquals("Mismatch in column group sum", colGroup0.getSum(m0.getNumRows()), + colGroup1.getSum(m1.getNumRows()), 0.0); + } + } + catch(Exception e) { + e.printStackTrace(); + fail("Compression extraction failed: " + e.getMessage()); + } + } + + /** + * Test 2: Test the RLE compression type by forcing RLE in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored column matrix. m1 is generated as a full-precision column matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). Reaches + * extractBitmapSingleColumn(). + */ + @Test + public void testForcedRLETypeSingleColumn() { + testForcedCompressionTypeSingleColumn(CompressionType.RLE, ColGroupRLE.class); + } + + /** + * Test 3: Test the RLE compression type by forcing RLE in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored matrix. m1 is generated as a full-precision matrix, but will be internally multiplied + * by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). Reaches extractBitmapMultiColumns(). + * + */ + @Test + public void testForcedRLETypeMultiColumn() { + testForcedCompressionTypeMultiColumn(CompressionType.RLE, ColGroupRLE.class); + } + + /** + * Test 4: Test the OLE compression type by forcing OLE in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored column matrix. m1 is generated as a full-precision column matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). Reaches + * extractBitmapSingleColumn(). + */ + @Test + public void testForcedOLETypeSingleColumn() { + testForcedCompressionTypeSingleColumn(CompressionType.OLE, ColGroupOLE.class); + } + + /** + * Test 5: Test the OLE compression type by forcing OLE in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored matrix. m1 is generated as a full-precision matrix, but will be internally multiplied + * by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). Reaches extractBitmapMultiColumn(). + */ + @Test + public void testForcedOLETypeMultiColumn() { + testForcedCompressionTypeMultiColumn(CompressionType.OLE, ColGroupOLE.class); + } + + /** + * Test 6: Test the SDC compression type by forcing SDC in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored column matrix. 
m1 is generated as a full-precision column matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). Reaches + * extractBitmapSingleColumn(). This should also cover CONST, EMPTY, SDCFOR. + */ + @Test + public void testForcedSDCTypeSingleColumn() { + testForcedCompressionTypeSingleColumn(CompressionType.SDC, ColGroupSDC.class); + } + + /** + * Test 7: Test the SDC compression type by forcing SDC in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored matrix. m1 is generated as a full-precision matrix, but will be internally multiplied + * by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). Reaches extractBitmapMultiColumn(). This should also + * cover CONST, EMPTY, SDCFOR. + */ + @Test + public void testForcedSDCTypeMultiColumn() { + testForcedCompressionTypeMultiColumn(CompressionType.SDC, ColGroupSDCSingle.class); + } + + /** + * Test 8: Test the DDC compression type by forcing DDC in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored column matrix. m1 is generated as a full-precision column matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). Reaches + * directCompressDDCSingleCol(). This should also cover DDCFOR. + */ + @Test + public void testForcedDDCTypeSingleColumn() { + testForcedCompressionTypeSingleColumn(CompressionType.DDC, ColGroupDDC.class); + } + + /** + * Test 9: Test the DDC compression type by forcing DDC in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored matrix. m1 is generated as a full-precision matrix, but will be internally multiplied + * by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). Reaches directCompressDDCMultiCol(). This should + * also cover DDCFOR. + */ + @Test + public void testForcedDDCTypeMultiColumn() { + testForcedCompressionTypeMultiColumn(CompressionType.DDC, ColGroupDDC.class); + } + + /** + * Test the given compression type by forcing it in each CompressedSizeInfoColGroup. + * + * m0 is generated as a floored column matrix. m1 is generated as a full-precision column matrix, but will be + * internally multiplied by 1.0 and floored. Essentially m0 = floor(m1 * scaleFactor). Reaches + * extractBitmapSingleColumn(). + */ + private void testForcedCompressionTypeSingleColumn(CompressionType compressionType, + Class<? 
extends AColGroup> expectedGroupClass) { + try { + int nRow = 100; + int nCol = 1; + int max = 50; + int min = -50; + double s = 1.0; + + MatrixBlock m0 = generateTestMatrix(nRow, nCol, min, max, s, SEED, true); + MatrixBlock m1 = generateTestMatrix(nRow, nCol, min, max, s, SEED, false); + + CompressionSettings cs0 = createCompressionSettings(null); + CompressionSettings cs1 = createCompressionSettings(new double[] {1.0}); + + List<AColGroup> results0 = compressWithForcedTypeNoCoCode(m0, cs0, compressionType); + List<AColGroup> results1 = compressWithForcedTypeNoCoCode(m1, cs1, compressionType); + + assertEquals("Mismatch in number of resulting column groups", results0.size(), results1.size(), 0.0); + + for(int i = 0; i < results0.size(); i++) { + assertInstanceOf(expectedGroupClass, results0.get(i), "Mismatch in forced compression type"); + assertInstanceOf(expectedGroupClass, results1.get(i), "Mismatch in forced compression type"); + + assertEquals("Mismatch in sum of values in column group", results0.get(i).getSum(nRow), + results1.get(i).getSum(nRow), 0.0); + } + } + catch(Exception e) { + e.printStackTrace(); + fail("Compression extraction failed: " + e.getMessage()); + } + } + + private void testForcedCompressionTypeMultiColumn(CompressionType compressionType, + Class<? extends AColGroup> expectedGroupClass) { + try { + double[][] values = {{1.5, 2.5, 3.5, 4.5, 5.5}, {1.5, 2.5, 3.5, 4.5, 5.5}, {1.5, 2.5, 3.5, 4.5, 5.5}, + {2.5, 3.5, 4.5, 5.5, 6.5}, {2.5, 3.5, 4.5, 5.5, 6.5}, {2.5, 3.5, 4.5, 5.5, 6.5},}; + + int nRow = values.length; + + MatrixBlock m0 = DataConverter.convertToMatrixBlock(values); + m0 = TestUtils.floor(m0); + m0.recomputeNonZeros(); + + MatrixBlock m1 = DataConverter.convertToMatrixBlock(values); + + CompressionSettings cs0 = createCompressionSettings(null); + CompressionSettings cs1 = createCompressionSettings(new double[] {1.0}); + + List<AColGroup> results0 = compressWithForcedTypeCoCode(m0, cs0, compressionType); + List<AColGroup> results1 = compressWithForcedTypeCoCode(m1, cs1, compressionType); + + assertEquals("Mismatch in number of resulting column groups", results0.size(), results1.size(), 0.0); + + for(int i = 0; i < results0.size(); i++) { + assertInstanceOf(expectedGroupClass, results0.get(i), "Mismatch in forced compression type"); + assertInstanceOf(expectedGroupClass, results1.get(i), "Mismatch in forced compression type"); + assertEquals("Mismatch in sum of values in column group", results0.get(i).getSum(nRow), + results1.get(i).getSum(nRow), 0.0); + } + } + catch(Exception e) { + e.printStackTrace(); + fail("Compression extraction failed: " + e.getMessage()); + } + } + + private static void assertInstanceOf(Class<?> expected, Object obj, String message) { + if(!expected.isInstance(obj)) { + fail(message + ": Expected " + expected.getSimpleName() + ", but got " + obj.getClass().getSimpleName()); + } + } + + /** + * Generate compressed groups with an estimator. + */ + private static Pair<CompressedSizeInfo, AComEst> generateCompressedGroups(MatrixBlock matrix, + CompressionSettings cs) { + AComEst estimator = ComEstFactory.createEstimator(matrix, cs, K); + CompressedSizeInfo sizeInfo = estimator.computeCompressedSizeInfos(K); + return Pair.of(sizeInfo, estimator); + } + + /** + * Force a specific compression type (e.g., RLE) on a set of compressed groups. 
+ */ + private static List<AColGroup> compressWithForcedTypeNoCoCode(MatrixBlock matrix, CompressionSettings cs, + CompressionType type) { + Pair<CompressedSizeInfo, AComEst> result = generateCompressedGroups(matrix, cs); + CompressedSizeInfo originalGroups = result.getLeft(); + List<CompressedSizeInfoColGroup> modifiedGroups = forceCompressionType(originalGroups, type); + CompressedSizeInfo compressedGroupsNew = new CompressedSizeInfo(modifiedGroups); + return ColGroupFactory.compressColGroups(matrix, compressedGroupsNew, cs, K); + } + + /** + * Force a specific compression type (e.g., RLE) on a set of compressed groups with CoCode. + */ + private static List<AColGroup> compressWithForcedTypeCoCode(MatrixBlock matrix, CompressionSettings cs, + CompressionType type) { + Pair<CompressedSizeInfo, AComEst> result = generateCompressedGroups(matrix, cs); + CompressedSizeInfo originalGroups = result.getLeft(); + AComEst estimator = result.getRight(); + ACostEstimate ice = CostEstimatorFactory.create(cs, null, matrix.getNumRows(), matrix.getNumColumns(), + matrix.getSparsity()); + originalGroups = CoCoderFactory.findCoCodesByPartitioning(estimator, originalGroups, K, ice, cs); + List<CompressedSizeInfoColGroup> modifiedGroups = forceCompressionType(originalGroups, type); + CompressedSizeInfo compressedGroupsNew = new CompressedSizeInfo(modifiedGroups); + return ColGroupFactory.compressColGroups(matrix, compressedGroupsNew, cs, K); + } + + /** + * Modify the compression type of each group to a specific type. + */ + private static List<CompressedSizeInfoColGroup> forceCompressionType(CompressedSizeInfo originalGroups, + CompressionType type) { + List<CompressedSizeInfoColGroup> modifiedGroups = new ArrayList<>(); + for(CompressedSizeInfoColGroup cg : originalGroups.getInfo()) { + Set<CompressionType> compressionTypes = new HashSet<>(); + compressionTypes.add(type); + modifiedGroups + .add(new CompressedSizeInfoColGroup(cg.getColumns(), cg.getFacts(), compressionTypes, cg.getMap())); + } + return modifiedGroups; + } + + /** + * Generate a test matrix with specified dimensions, value range, and sparsity. + */ + private static MatrixBlock generateTestMatrix(int nRow, int nCol, int min, int max, double s, long seed, + boolean floored) { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(nRow, nCol, min, max, s, seed); + return floored ? TestUtils.floor(mb) : mb; + } + + /** + * Create compression settings with an optional scale factor. 
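+ * For example, createCompressionSettings(new double[] {0.5}) yields settings whose scaleFactors quantize every + * cell as floor(value * 0.5) during compression, while createCompressionSettings(null) leaves quantization-fused + * compression disabled.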
+ */ + private static CompressionSettings createCompressionSettings(double[] scaleFactor) { + CompressionSettingsBuilder builder = new CompressionSettingsBuilder(); + // .setColumnPartitioner(PartitionerType.GREEDY).setSeed((int) SEED); + if(scaleFactor != null) { + builder.setScaleFactor(scaleFactor); + } + return builder.create(); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java b/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java index b4054169af..ae92d3a431 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/readers/ReadersTest.java @@ -19,7 +19,9 @@ package org.apache.sysds.test.component.compress.readers; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -31,6 +33,7 @@ import org.apache.sysds.runtime.compress.utils.DblArrayCountHashMap; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.test.TestUtils; import org.junit.Test; +import java.util.Arrays; public class ReadersTest { @@ -83,4 +86,45 @@ public class ReadersTest { mb.allocateDenseBlock(); ReaderColumnSelection.createReader(mb, ColIndexFactory.create(2), false, 10, 9); } + + @Test + public void testReaderColumnSelectionQuantized() { + + // 4.0 0.0 + // 3.0 0.0 + // 0.0 5.0 + + MatrixBlock mb = new MatrixBlock(3, 2, false); + mb.allocateDenseBlock(); + mb.set(0, 0, 4); + mb.set(1, 0, 3); + mb.set(2, 1, 5); + + double[][] scaleFactorCases = { + {0.3}, // Scalar case + {0.3, 0.4, 0.5} // Per-row scale factor + }; + + for (double[] scaleFactors : scaleFactorCases) { + ReaderColumnSelection r = ReaderColumnSelection.createQuantizedReader( + mb, ColIndexFactory.create(2), false, scaleFactors); + + double[][] expectedValues = { + { Math.floor(4 * (scaleFactors.length > 1 ? scaleFactors[0] : scaleFactors[0])), Math.floor(0.0 * (scaleFactors.length > 1 ? scaleFactors[0] : scaleFactors[0])) }, + { Math.floor(3 * (scaleFactors.length > 1 ? scaleFactors[1] : scaleFactors[0])), Math.floor(0.0 * (scaleFactors.length > 1 ? scaleFactors[1] : scaleFactors[0])) }, + { Math.floor(0.0 * (scaleFactors.length > 1 ? scaleFactors[2] : scaleFactors[0])), Math.floor(5 * (scaleFactors.length > 1 ? scaleFactors[2] : scaleFactors[0])) } + }; + + DblArray d; + int rowIndex = 0; + while ((d = r.nextRow()) != null) { + assertNotNull("Row " + rowIndex + " should not be null", d); + assertArrayEquals("Mismatch for scaleFactors " + Arrays.toString(scaleFactors), + expectedValues[rowIndex], d.getData(), 0.0); + rowIndex++; + } + } + } + + } diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteQuantizationFusedCompressionTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteQuantizationFusedCompressionTest.java new file mode 100644 index 0000000000..3a9dfa48dd --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteQuantizationFusedCompressionTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.utils.Statistics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.hops.OptimizerUtils;
+import java.util.Arrays;
+
+/**
+ * Test for the rewrite that replaces the sequence X = floor(M * sf); Y = compress(X)
+ * with a fused Y = quantize_compress(M, sf).
+ */
+public class RewriteQuantizationFusedCompressionTest extends AutomatedTestBase {
+	private static final String TEST_NAME1 = "RewriteQuantizationFusedCompressionScalar";
+	private static final String TEST_NAME2 = "RewriteQuantizationFusedCompressionMatrix";
+	private static final String TEST_DIR = "functions/rewrite/";
+	private static final String TEST_CLASS_DIR = TEST_DIR
+		+ RewriteQuantizationFusedCompressionTest.class.getSimpleName() + "/";
+
+	private static final int rows = 500;
+	private static final int cols = 500;
+	private static final double sfValue = 0.5; // value used to fill the scale factor matrix, or passed as the scalar
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
+		addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}));
+	}
+
+	@Test
+	public void testRewriteQuantizationFusedCompressionScalar() {
+		testRewriteQuantizationFusedCompression(TEST_NAME1, true, true);
+	}
+
+	@Test
+	public void testRewriteQuantizationFusedCompressionNoRewriteScalar() {
+		testRewriteQuantizationFusedCompression(TEST_NAME1, false, true);
+	}
+
+	@Test
+	public void testRewriteQuantizationFusedCompression() {
+		testRewriteQuantizationFusedCompression(TEST_NAME2, true, false);
+	}
+
+	@Test
+	public void testRewriteQuantizationFusedCompressionNoRewrite() {
+		testRewriteQuantizationFusedCompression(TEST_NAME2, false, false);
+	}
+
+	/**
+	 * Unified method to test both scalar and matrix scale factors.
+	 *
+	 * @param testname Test name
+	 * @param rewrites Whether to enable fusion rewrites
+	 * @param isScalar Whether the scale factor is a scalar or a matrix
+	 */
+	private void testRewriteQuantizationFusedCompression(String testname, boolean rewrites, boolean isScalar) {
+		boolean oldRewriteFlag = OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE;
+		OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE = rewrites;
+
+		try {
+			TestConfiguration config = getTestConfiguration(testname);
+			loadTestConfiguration(config);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+
+			double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.70d, 7);
+
+			String[] programArgs;
+			if(isScalar) {
+				// Scalar case: pass sfValue as a string argument
+				String s = Double.toString(sfValue);
+				programArgs = new String[] {"-stats", "-args", input("A"), s, output("R")};
+				writeInputMatrixWithMTD("A", A, 174522, false);
+			}
+			else {
+				// Matrix case: pass S as a separate column vector
+				double[][] S = new double[rows][1];
+				for(int i = 0; i < rows; i++) {
+					S[i][0] = sfValue;
+				}
+				programArgs = new String[] {"-stats", "-args", input("A"), input("S"), output("R")};
+				writeInputMatrixWithMTD("A", A, 174522, false);
+				writeInputMatrixWithMTD("S", S, 500, false);
+			}
+
+			this.programArgs = programArgs;
+			runTest(true, false, null, -1);
+
+			// Simple check whether quantization indeed occurred, by computing the expected sum.
+			// Even if compression is aborted, the quantization step should still take effect.
+			double expectedR = Arrays.stream(A).flatMapToDouble(Arrays::stream).map(x -> Math.floor(x * sfValue)).sum();
+			double actualR = TestUtils.readDMLScalar(output("R"));
+
+			Assert.assertEquals("Mismatch in expected sum after quantization and compression", expectedR, actualR, 0.0);
+
+			// Check if fusion occurred
+			if(rewrites) {
+				Assert.assertEquals("Expected fused operation count mismatch", 1,
+					Statistics.getCPHeavyHitterCount(Opcodes.QUANTIZE_COMPRESS.toString()));
+				Assert.assertEquals("Expected no separate floor op", 0,
+					Statistics.getCPHeavyHitterCount(Opcodes.FLOOR.toString()));
+				Assert.assertEquals("Expected no separate compress op", 0,
+					Statistics.getCPHeavyHitterCount(Opcodes.COMPRESS.toString()));
+				Assert.assertEquals("Expected no separate multiplication op", 0,
+					Statistics.getCPHeavyHitterCount(Opcodes.MULT.toString()));
+			}
+			else {
+				Assert.assertEquals("Expected no fused op", 0,
+					Statistics.getCPHeavyHitterCount(Opcodes.QUANTIZE_COMPRESS.toString()));
+				Assert.assertEquals("Expected separate floor op", 1,
+					Statistics.getCPHeavyHitterCount(Opcodes.FLOOR.toString()));
+				Assert.assertEquals("Expected separate compress op", 1,
+					Statistics.getCPHeavyHitterCount(Opcodes.COMPRESS.toString()));
+				Assert.assertEquals("Expected separate multiplication op", 1,
+					Statistics.getCPHeavyHitterCount(Opcodes.MULT.toString()));
+			}
+		}
+		finally {
+			OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE = oldRewriteFlag;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionMatrix.dml b/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionMatrix.dml
new file mode 100644
index 0000000000..38c6d712c6
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionMatrix.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Load matrix A
+A = read($1);
+
+# Load vector/matrix scale factor S
+S = read($2);
+
+# Quantize
+B = floor(A * S);
+
+# Compress
+C = compress(B);
+
+# Write the sum, as writing a compressed matrix is complicated
+R = sum(C);
+write(R, $3);
diff --git a/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionScalar.dml b/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionScalar.dml
new file mode 100644
index 0000000000..c6ccfb2a21
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteQuantizationFusedCompressionScalar.dml
@@ -0,0 +1,37 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Load matrix A
+A = read($1);
+
+# Load scalar scale factor S
+s = as.double($2);
+
+# Quantize
+B = floor(A * s);
+
+# Compress
+C = compress(B);
+
+# Write the sum, as writing a compressed matrix is complicated
+R = sum(C);
+write(R, $3);
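
Editor's note: for readers who want to exercise the new quantized code path outside the test suite, here is a minimal standalone sketch. It is not part of this commit; the class name QuantizedReaderExample and the input values are illustrative. It uses only APIs that appear in this patch (CompressionSettingsBuilder.setScaleFactor, ReaderColumnSelection.createQuantizedReader, ColIndexFactory.create, DblArray.getData).

import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.readers.ReaderColumnSelection;
import org.apache.sysds.runtime.compress.utils.DblArray;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

public class QuantizedReaderExample {
	public static void main(String[] args) {
		// 2x2 dense input; every row has a nonzero entry so no row is skipped.
		MatrixBlock mb = new MatrixBlock(2, 2, false);
		mb.allocateDenseBlock();
		mb.set(0, 0, 4.2);
		mb.set(1, 1, 5.7);

		// One factor per row; a single-element array would act as a scalar factor.
		double[] scale = new double[] {0.5, 2.0};

		// This is how a scale factor enters the fused compression pipeline;
		// shown here for reference, not used by the reader below.
		CompressionSettings cs = new CompressionSettingsBuilder().setScaleFactor(scale).create();

		// The quantized reader yields floor(value * scale[row]) per entry,
		// without materializing the intermediate quantized matrix.
		ReaderColumnSelection r = ReaderColumnSelection.createQuantizedReader(
			mb, ColIndexFactory.create(2), false, scale);

		DblArray row;
		while((row = r.nextRow()) != null)
			System.out.println(java.util.Arrays.toString(row.getData()));
		// Per the semantics exercised in ReadersTest, this should print
		// [2.0, 0.0] followed by [0.0, 11.0].
	}
}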