This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 38ec722a53 [MINOR] Adding a factory method for MatrixSketch
38ec722a53 is described below
commit 38ec722a53557037b79eb6259ce17cf2b850af4e
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Fri Nov 25 11:02:30 2022 -0800
[MINOR] Adding a factory method for MatrixSketch
This patch introduces a factory method for sketches.
The factory centralizes the creation of all sketches in one place and
prevents duplication of operator-switching and validation logic.
Closes #1738
---
.../spark/AggregateUnarySketchSPInstruction.java | 16 ++++----
.../matrix/data/LibMatrixCountDistinct.java | 40 ++++++--------------
.../runtime/matrix/data/sketch/SketchFactory.java | 44 ++++++++++++++++++++++
3 files changed, 63 insertions(+), 37 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
index 767e4b0c0b..bfdecc635a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySketchSPInstruction.java
@@ -117,7 +117,7 @@ public class AggregateUnarySketchSPInstruction extends
UnarySPInstruction {
out1.fold(new CorrMatrixBlock(new MatrixBlock()),
new
AggregateUnarySketchUnionAllFunction(this.op));
- MatrixBlock out3 =
LibMatrixCountDistinct.countDistinctValuesFromSketch(out2, this.op);
+ MatrixBlock out3 =
LibMatrixCountDistinct.countDistinctValuesFromSketch(this.op, out2);
// put output block into symbol table (no lineage because single
block)
// this also includes implicit maintenance of matrix
characteristics
@@ -180,7 +180,7 @@ public class AggregateUnarySketchSPInstruction extends
UnarySPInstruction {
MatrixIndexes ixOut = new MatrixIndexes();
this.op.indexFn.execute(ixIn, ixOut);
- return LibMatrixCountDistinct.createSketch(blkIn, this.op);
+ return LibMatrixCountDistinct.createSketch(this.op, blkIn);
}
}
@@ -207,7 +207,7 @@ public class AggregateUnarySketchSPInstruction extends
UnarySPInstruction {
return arg0;
}
- return LibMatrixCountDistinct.unionSketch(arg0, arg1, this.op);
+ return LibMatrixCountDistinct.unionSketch(this.op, arg0, arg1);
}
}
@@ -246,7 +246,7 @@ public class AggregateUnarySketchSPInstruction extends
UnarySPInstruction {
public CorrMatrixBlock call(MatrixBlock arg0)
throws Exception {
- return LibMatrixCountDistinct.createSketch(arg0, this.op);
+ return LibMatrixCountDistinct.createSketch(this.op, arg0);
}
}
@@ -261,8 +261,8 @@ public class AggregateUnarySketchSPInstruction extends
UnarySPInstruction {
@Override
public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1)
throws Exception {
- CorrMatrixBlock arg1WithCorr =
LibMatrixCountDistinct.createSketch(arg1, this.op);
- return LibMatrixCountDistinct.unionSketch(arg0, arg1WithCorr,
this.op);
+ CorrMatrixBlock arg1WithCorr =
LibMatrixCountDistinct.createSketch(this.op, arg1);
+ return LibMatrixCountDistinct.unionSketch(this.op, arg0,
arg1WithCorr);
}
}
@@ -277,7 +277,7 @@ public class AggregateUnarySketchSPInstruction extends
UnarySPInstruction {
@Override
public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock
arg1) throws Exception {
- return LibMatrixCountDistinct.unionSketch(arg0, arg1, this.op);
+ return LibMatrixCountDistinct.unionSketch(this.op, arg0, arg1);
}
}
@@ -292,7 +292,7 @@ public class AggregateUnarySketchSPInstruction extends
UnarySPInstruction {
@Override
public MatrixBlock call(CorrMatrixBlock arg0) throws Exception {
- return LibMatrixCountDistinct.countDistinctValuesFromSketch(arg0,
this.op);
+ return
LibMatrixCountDistinct.countDistinctValuesFromSketch(this.op, arg0);
}
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
index ccddb4db80..72bcd64b43 100644
---
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
+++
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixCountDistinct.java
@@ -31,8 +31,8 @@ import org.apache.sysds.api.DMLException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.data.*;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
-import org.apache.sysds.runtime.matrix.data.sketch.CountDistinctSketch;
-import
org.apache.sysds.runtime.matrix.data.sketch.countdistinct.CountDistinctFunctionSketch;
+import org.apache.sysds.runtime.matrix.data.sketch.MatrixSketch;
+import org.apache.sysds.runtime.matrix.data.sketch.SketchFactory;
import
org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
@@ -356,36 +356,18 @@ public interface LibMatrixCountDistinct {
return distinct.size();
}
- static MatrixBlock countDistinctValuesFromSketch(CorrMatrixBlock arg0,
CountDistinctOperator op) {
- if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
- return new
CountDistinctFunctionSketch(op).getValueFromSketch(arg0);
- else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
- return new KMVSketch(op).getValueFromSketch(arg0);
- else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
- throw new NotImplementedException("Not implemented
yet");
- else
- throw new NotImplementedException("Not implemented
yet");
+ static MatrixBlock countDistinctValuesFromSketch(CountDistinctOperator
op, CorrMatrixBlock corrBlkIn) {
+ MatrixSketch sketch = SketchFactory.get(op);
+ return sketch.getValueFromSketch(corrBlkIn);
}
- static CorrMatrixBlock createSketch(MatrixBlock blkIn,
CountDistinctOperator op) {
- if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
- return new
CountDistinctFunctionSketch(op).create(blkIn);
- else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
- return new KMVSketch(op).create(blkIn);
- else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
- throw new NotImplementedException("Not implemented
yet");
- else
- throw new NotImplementedException("Not implemented
yet");
+ static CorrMatrixBlock createSketch(CountDistinctOperator op,
MatrixBlock blkIn) {
+ MatrixSketch sketch = SketchFactory.get(op);
+ return sketch.create(blkIn);
}
- static CorrMatrixBlock unionSketch(CorrMatrixBlock arg0,
CorrMatrixBlock arg1, CountDistinctOperator op) {
- if(op.getOperatorType() == CountDistinctOperatorTypes.COUNT)
- return new CountDistinctFunctionSketch(op).union(arg0,
arg1);
- else if(op.getOperatorType() == CountDistinctOperatorTypes.KMV)
- return new KMVSketch(op).union(arg0, arg1);
- else if(op.getOperatorType() == CountDistinctOperatorTypes.HLL)
- throw new NotImplementedException("Not implemented
yet");
- else
- throw new NotImplementedException("Not implemented
yet");
+ static CorrMatrixBlock unionSketch(CountDistinctOperator op,
CorrMatrixBlock corrBlkIn0, CorrMatrixBlock corrBlkIn1) {
+ MatrixSketch sketch = SketchFactory.get(op);
+ return sketch.union(corrBlkIn0, corrBlkIn1);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/SketchFactory.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/SketchFactory.java
new file mode 100644
index 0000000000..434582374d
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/matrix/data/sketch/SketchFactory.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data.sketch;
+
+import org.apache.commons.lang.NotImplementedException;
+import
org.apache.sysds.runtime.matrix.data.sketch.countdistinct.CountDistinctFunctionSketch;
+import
org.apache.sysds.runtime.matrix.data.sketch.countdistinctapprox.KMVSketch;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
+import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class SketchFactory {
+ public static MatrixSketch get(Operator op) {
+ if (op instanceof CountDistinctOperator) {
+ CountDistinctOperator cdop = (CountDistinctOperator) op;
+ if (cdop.getOperatorType() ==
CountDistinctOperatorTypes.COUNT) {
+ return new CountDistinctFunctionSketch(op);
+ } else if (cdop.getOperatorType() ==
CountDistinctOperatorTypes.KMV) {
+ return new KMVSketch(op);
+ } else {
+ throw new NotImplementedException("Only COUNT
and KMV count distinct sketches are supported for now");
+ }
+ } else {
+ throw new IllegalArgumentException("Only sketches for
count distinct operators are supported for now");
+ }
+ }
+}