This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c60cc5e35 [MINOR] Improve parameter validation of countDistinct
1c60cc5e35 is described below

commit 1c60cc5e35f472fefbe1df8cf3c6877fa2fe4ba8
Author: Badrul Chowdhury <[email protected]>
AuthorDate: Sat Nov 5 19:27:06 2022 -0700

    [MINOR] Improve parameter validation of countDistinct
    
    This patch improves the validation of parameters to aliases for
    countDistinct() and countDistinctApprox(). The aliases have also been
    renamed for consistency with other builtin functions:
    
    - countDistinctRow() -> rowCountDistinct()
    - countDistinctCol() -> colCountDistinct()
    - countDistinctApproxRow() -> rowCountDistinctApprox()
    - countDistinctApproxCol() -> colCountDistinctApprox()
    
    rowCountDistinctApprox() and colCountDistinctApprox() accept only a
    single additional, optional parameter: type (default=KMV).
    rowCountDistinct() and colCountDistinct() have been converted from
    parameterized builtin functions to non-parameterized builtins, as the
    aliases specify their respective directions implicitly.
    
    Closes #1722
---
 .../java/org/apache/sysds/common/Builtins.java     |  8 +--
 src/main/java/org/apache/sysds/common/Types.java   |  2 +-
 .../org/apache/sysds/lops/PartialAggregate.java    |  4 +-
 .../sysds/parser/BuiltinFunctionExpression.java    | 22 +++++-
 .../org/apache/sysds/parser/DMLTranslator.java     | 13 +++-
 .../ParameterizedBuiltinFunctionExpression.java    | 54 +++++++++------
 .../functions/countDistinct/CountDistinctBase.java |  6 +-
 .../CountDistinctColAliasException.java            | 77 +++++++++++++++++++++
 .../CountDistinctRowAliasException.java            | 77 +++++++++++++++++++++
 .../CountDistinctApproxColAliasException.java      | 78 ++++++++++++++++++++++
 .../CountDistinctApproxRowAliasException.java      | 78 ++++++++++++++++++++++
 .../countDistinct/countDistinctColAlias.dml        |  2 +-
 .../countDistinctColAliasException.dml}            |  2 +-
 .../countDistinct/countDistinctRowAlias.dml        |  2 +-
 .../countDistinctRowAliasException.dml}            |  2 +-
 .../countDistinctApproxColAlias.dml                |  2 +-
 ...ml => countDistinctApproxColAliasException.dml} |  2 +-
 .../countDistinctApproxRowAlias.dml                |  2 +-
 ...ml => countDistinctApproxRowAliasException.dml} |  2 +-
 19 files changed, 395 insertions(+), 40 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 262212570e..5afef9c308 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -74,6 +74,7 @@ public enum Builtins {
        CBIND("cbind", "append", false),
        CEIL("ceil", "ceiling", false),
        CHOLESKY("cholesky", false),
+       COL_COUNT_DISTINCT("colCountDistinct",false),
        COLMAX("colMaxs", false),
        COLMEAN("colMeans", false),
        COLMIN("colMins", false),
@@ -245,6 +246,7 @@ public enum Builtins {
        REMOVE("remove", false, ReturnType.MULTI_RETURN),
        REV("rev", false),
        ROUND("round", false),
+       ROW_COUNT_DISTINCT("rowCountDistinct",false),
        ROWINDEXMAX("rowIndexMax", false),
        ROWINDEXMIN("rowIndexMin", false),
        ROWMAX("rowMaxs", false),
@@ -308,11 +310,9 @@ public enum Builtins {
        AUTODIFF("autoDiff", false, true),
        CDF("cdf", false, true),
        COUNT_DISTINCT("countDistinct",false, true),
-       COUNT_DISTINCT_ROW("countDistinctRow",false, true),
-       COUNT_DISTINCT_COL("countDistinctCol",false, true),
        COUNT_DISTINCT_APPROX("countDistinctApprox", false, true),
-       COUNT_DISTINCT_APPROX_ROW("countDistinctApproxRow", false, true),
-       COUNT_DISTINCT_APPROX_COL("countDistinctApproxCol", false, true),
+       COUNT_DISTINCT_APPROX_ROW("rowCountDistinctApprox", false, true),
+       COUNT_DISTINCT_APPROX_COL("colCountDistinctApprox", false, true),
        CVLM("cvlm", true, false),
        GROUPEDAGG("aggregate", "groupedAggregate", false, true),
        INVCDF("icdf", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 991284c13e..7c3a3f1e53 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -198,7 +198,7 @@ public class Types
                PROD(4), SUM_PROD(5),
                TRACE(6), MEAN(7), VAR(8),
                MAXINDEX(9), MININDEX(10),
-               COUNT_DISTINCT(11), COUNT_DISTINCT_ROW(12), 
COUNT_DISTINCT_COL(13),
+               COUNT_DISTINCT(11), ROW_COUNT_DISTINCT(12), 
COL_COUNT_DISTINCT(13),
                COUNT_DISTINCT_APPROX(14), COUNT_DISTINCT_APPROX_ROW(15), 
COUNT_DISTINCT_APPROX_COL(16);
 
                @Override
diff --git a/src/main/java/org/apache/sysds/lops/PartialAggregate.java 
b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
index 0481c7373a..1a7d22b989 100644
--- a/src/main/java/org/apache/sysds/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysds/lops/PartialAggregate.java
@@ -352,10 +352,10 @@ public class PartialAggregate extends Lop
                                }
                        }
 
-                       case COUNT_DISTINCT_ROW:
+                       case ROW_COUNT_DISTINCT:
                                return "uacdr";
 
-                       case COUNT_DISTINCT_COL:
+                       case COL_COUNT_DISTINCT:
                                return "uacdc";
 
                        case COUNT_DISTINCT_APPROX: {
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index c3aca47d38..5634e02be9 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1606,6 +1606,26 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        else
                                raiseValidateError("Compress/DeCompress 
instruction not allowed in dml script");
                        break;
+               case ROW_COUNT_DISTINCT:
+                       checkNumParameters(1);
+                       checkMatrixParam(getFirstExpr());
+                       output.setDataType(DataType.MATRIX);
+                       output.setDimensions(id.getDim1(), 1);
+                       output.setBlocksize (id.getBlocksize());
+                       output.setValueType(ValueType.INT64);
+                       output.setNnz(id.getDim1());
+                       break;
+
+               case COL_COUNT_DISTINCT:
+                       checkNumParameters(1);
+                       checkMatrixParam(getFirstExpr());
+                       output.setDataType(DataType.MATRIX);
+                       output.setDimensions(1, id.getDim2());
+                       output.setBlocksize (id.getBlocksize());
+                       output.setValueType(ValueType.INT64);
+                       output.setNnz(id.getDim2());
+                       break;
+
                default:
                        if( isMathFunction() ) {
                                checkMathFunctionParam();
@@ -1637,7 +1657,7 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        }
                }
        }
-       
+
        private void setBinaryOutputProperties(DataIdentifier output) {
                DataType dt1 = getFirstExpr().getOutput().getDataType();
                DataType dt2 = getSecondExpr().getOutput().getDataType();
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 553bf56fc5..0c3a6dfd8f 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2064,13 +2064,11 @@ public class DMLTranslator
                                                
AggOp.valueOf(source.getOpCode().name()), dir, paramHops.get("data"));
                                break;
 
-                       case COUNT_DISTINCT_ROW:
                        case COUNT_DISTINCT_APPROX_ROW:
                                currBuiltinOp = new 
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
                                                
AggOp.valueOf(source.getOpCode().name()), Direction.Row, paramHops.get("data"));
                                break;
 
-                       case COUNT_DISTINCT_COL:
                        case COUNT_DISTINCT_APPROX_COL:
                                currBuiltinOp = new 
AggUnaryOp(target.getName(), DataType.MATRIX, target.getValueType(),
                                                
AggOp.valueOf(source.getOpCode().name()), Direction.Col, paramHops.get("data"));
@@ -2758,6 +2756,17 @@ public class DMLTranslator
                        setBlockSizeAndRefreshSizeInfo(expr, currBuiltinOp);
                        break;
                }
+
+               case ROW_COUNT_DISTINCT:
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.MATRIX, target.getValueType(),
+                                       
AggOp.valueOf(source.getOpCode().name()), Direction.Row, expr);
+                       break;
+
+               case COL_COUNT_DISTINCT:
+                       currBuiltinOp = new AggUnaryOp(target.getName(), 
DataType.MATRIX, target.getValueType(),
+                                       
AggOp.valueOf(source.getOpCode().name()), Direction.Col, expr);
+                       break;
+
                default:
                        throw new ParseException("Unsupported builtin function 
type: "+source.getOpCode());
                }
diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index bdfd38c5a4..7ef19badde 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -247,15 +247,16 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                        break;
 
                case COUNT_DISTINCT:
-               case COUNT_DISTINCT_ROW:
-               case COUNT_DISTINCT_COL:
                        validateCountDistinct(output, conditional);
                        break;
 
                case COUNT_DISTINCT_APPROX:
+                       validateCountDistinctApprox(output, conditional, false);
+                       break;
+
                case COUNT_DISTINCT_APPROX_ROW:
                case COUNT_DISTINCT_APPROX_COL:
-                       validateCountDistinctApprox(output, conditional);
+                       validateCountDistinctApprox(output, conditional, true);
                        break;
 
                default: //always unconditional (because unsupported operation)
@@ -400,7 +401,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                validateAggregationDirection(dataId, output);
        }
 
-       private void validateCountDistinctApprox(DataIdentifier output, boolean 
conditional) {
+       private void validateCountDistinctApprox(DataIdentifier output, boolean 
conditional, boolean isDirectionAlias) {
                Set<String> validTypeNames = CollectionUtils.asSet("KMV");
                HashMap<String, Expression> varParams = getVarParams();
 
@@ -411,13 +412,26 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
 
                // Validate the number of parameters
                String fname = getOpCode().getName();
-               String usageMessage = "function " + fname + " takes at least 1 
and at most 3 parameters";
-               if (varParams.size() < 1) {
-                       raiseValidateError("Too few parameters: " + 
usageMessage, conditional);
-               }
+               if (!isDirectionAlias) {
+                       // Function is not an alias, so we have to check for 
all 3 permissible parameters
+                       String usageMessage = "function " + fname + " takes at 
least 1 and at most 3 parameters";
+                       if (varParams.size() < 1) {
+                               raiseValidateError("Too few parameters: " + 
usageMessage, conditional);
+                       }
 
-               if (varParams.size() > 3) {
-                       raiseValidateError("Too many parameters: " + 
usageMessage, conditional);
+                       if (varParams.size() > 3) {
+                               raiseValidateError("Too many parameters: " + 
usageMessage, conditional);
+                       }
+               } else {
+                       // The direction is fixed for function aliases
+                       String usageMessage = "function " + fname + " takes at 
least 1 and at most 2 parameters";
+                       if (varParams.size() < 1) {
+                               raiseValidateError("Too few parameters: " + 
usageMessage, conditional);
+                       }
+
+                       if (varParams.size() > 2) {
+                               raiseValidateError("Too many parameters: " + 
usageMessage, conditional);
+                       }
                }
 
                // Check parameter names are valid
@@ -447,20 +461,22 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                        addVarParam("type", new StringIdentifier("KMV", this));
                }
 
-               checkStringParam(true, fname, "dir", conditional);
-               // Check data value of "dir" parameter
-               validateAggregationDirection(dataId, output);
+               if (!isDirectionAlias) {
+                       checkStringParam(true, fname, "dir", conditional);
+                       // Check data value of "dir" parameter
+                       validateAggregationDirection(dataId, output);
+               }
        }
 
        private void validateAggregationDirection(Identifier dataId, 
DataIdentifier output) {
                HashMap<String, Expression> varParams = getVarParams();
                if (varParams.containsKey("dir")) {
-                       String directionString = 
varParams.get("dir").toString().toUpperCase();
+                       String inputDirectionString = 
varParams.get("dir").toString().toUpperCase();
 
                        // Set output type and dimensions based on direction
 
                        // "r" -> count across all rows, resulting in a Mx1 
matrix
-                       if 
(directionString.equals(Types.Direction.Row.toString())) {
+                       if 
(inputDirectionString.equals(Types.Direction.Row.toString())) {
                                output.setDataType(DataType.MATRIX);
                                output.setDimensions(dataId.getDim1(), 1);
                                output.setBlocksize(dataId.getBlocksize());
@@ -468,7 +484,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                                output.setNnz(dataId.getDim1());
 
                        // "c" -> count across all cols, resulting in a 1xN 
matrix
-                       } else if 
(directionString.equals(Types.Direction.Col.toString())) {
+                       } else if 
(inputDirectionString.equals(Types.Direction.Col.toString())) {
                                output.setDataType(DataType.MATRIX);
                                output.setDimensions(1, dataId.getDim2());
                                output.setBlocksize(dataId.getBlocksize());
@@ -476,16 +492,16 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                                output.setNnz(dataId.getDim2());
 
                        // "rc" -> count across all rows and cols in input 
matrix, resulting in a single value
-                       } else if 
(directionString.equals(Types.Direction.RowCol.toString())) {
+                       } else if 
(inputDirectionString.equals(Types.Direction.RowCol.toString())) {
                                output.setDataType(DataType.SCALAR);
                                output.setDimensions(0, 0);
                                output.setBlocksize(0);
                                output.setValueType(ValueType.INT64);
                                output.setNnz(1);
 
-                       // unrecognized value for "dir" parameter, should "cr" 
be valid?
+                       // unrecognized value for "dir" parameter
                        } else {
-                               raiseValidateError("Invalid argument: " + 
directionString + " is not recognized");
+                               raiseValidateError("Invalid argument: " + 
inputDirectionString + " is not recognized");
                        }
                } else {  // default to dir="rc"
                        output.setDataType(DataType.SCALAR);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
index 5bf850d49a..72838b63c9 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctBase.java
@@ -46,17 +46,17 @@ public abstract class CountDistinctBase extends 
AutomatedTestBase {
        public abstract void setUp();
 
        public void countDistinctScalarTest(long numberDistinct, int cols, int 
rows, double sparsity,
-               Types.ExecType instType, double tolerance) {
+                                                                               
Types.ExecType instType, double tolerance) {
                countDistinctTest(Types.Direction.RowCol, numberDistinct, cols, 
rows, sparsity, instType, tolerance);
        }
 
        public void countDistinctMatrixTest(Types.Direction dir, long 
numberDistinct, int cols, int rows, double sparsity,
-               Types.ExecType instType, double tolerance) {
+                                                                               
Types.ExecType instType, double tolerance) {
                countDistinctTest(dir, numberDistinct, cols, rows, sparsity, 
instType, tolerance);
        }
 
        public void countDistinctTest(Types.Direction dir, long numberDistinct, 
int cols, int rows, double sparsity,
-               Types.ExecType instType, double tolerance) {
+                                                                 
Types.ExecType instType, double tolerance) {
 
                Types.ExecMode platformOld = setExecMode(instType);
                try {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAliasException.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAliasException.java
new file mode 100644
index 0000000000..8af98ea790
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctColAliasException.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+public class CountDistinctColAliasException extends CountDistinctBase {
+
+       @Rule
+       public ExpectedException exceptionRule = ExpectedException.none();
+
+       private final static String TEST_NAME = 
"countDistinctColAliasException";
+       private final static String TEST_DIR = "functions/countDistinct/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctColAliasException.class.getSimpleName() + "/";
+
+       private final Types.Direction DIRECTION = Types.Direction.Row;
+
+       @Override
+       protected String getTestClassDir() {
+               return TEST_CLASS_DIR;
+       }
+
+       @Override
+       protected String getTestName() {
+               return TEST_NAME;
+       }
+
+       @Override
+       protected String getTestDir() {
+               return TEST_DIR;
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(getTestName(), new 
TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
+
+               this.percentTolerance = 0.2;
+       }
+
+       @Test
+       public void testCPSparseSmall() {
+               exceptionRule.expect(AssertionError.class);
+               exceptionRule.expectMessage("Invalid number of arguments for 
function col_count_distinct(). " +
+                               "This function only takes 1 or 2 arguments.");
+
+               Types.ExecType execType = Types.ExecType.CP;
+
+               int actualDistinctCount = 10;
+               int rows = 1000, cols = 1000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+               countDistinctMatrixTest(DIRECTION, actualDistinctCount, cols, 
rows, sparsity, execType, tolerance);
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAliasException.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAliasException.java
new file mode 100644
index 0000000000..dd5c4c2a05
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinct/CountDistinctRowAliasException.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinct;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+public class CountDistinctRowAliasException extends CountDistinctBase {
+
+       @Rule
+       public ExpectedException exceptionRule = ExpectedException.none();
+
+       private final static String TEST_NAME = 
"countDistinctRowAliasException";
+       private final static String TEST_DIR = "functions/countDistinct/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctRowAliasException.class.getSimpleName() + "/";
+
+       private final Types.Direction DIRECTION = Types.Direction.Row;
+
+       @Override
+       protected String getTestClassDir() {
+               return TEST_CLASS_DIR;
+       }
+
+       @Override
+       protected String getTestName() {
+               return TEST_NAME;
+       }
+
+       @Override
+       protected String getTestDir() {
+               return TEST_DIR;
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(getTestName(), new 
TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
+
+               this.percentTolerance = 0.2;
+       }
+
+       @Test
+       public void testCPSparseSmall() {
+               exceptionRule.expect(AssertionError.class);
+               exceptionRule.expectMessage("Invalid number of arguments for 
function row_count_distinct(). " +
+                               "This function only takes 1 or 2 arguments.");
+
+               Types.ExecType execType = Types.ExecType.CP;
+
+               int actualDistinctCount = 10;
+               int rows = 1000, cols = 1000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+               countDistinctMatrixTest(DIRECTION, actualDistinctCount, cols, 
rows, sparsity, execType, tolerance);
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAliasException.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAliasException.java
new file mode 100644
index 0000000000..8ea94a3a88
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxColAliasException.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinctApprox;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.countDistinct.CountDistinctBase;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+public class CountDistinctApproxColAliasException extends CountDistinctBase {
+
+       @Rule
+       public ExpectedException exceptionRule = ExpectedException.none();
+
+       private final static String TEST_NAME = 
"countDistinctApproxColAliasException";
+       private final static String TEST_DIR = "functions/countDistinctApprox/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxColAliasException.class.getSimpleName() + "/";
+
+       private final Types.Direction DIRECTION = Types.Direction.Row;
+
+       @Override
+       protected String getTestClassDir() {
+               return TEST_CLASS_DIR;
+       }
+
+       @Override
+       protected String getTestName() {
+               return TEST_NAME;
+       }
+
+       @Override
+       protected String getTestDir() {
+               return TEST_DIR;
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(getTestName(), new 
TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
+
+               this.percentTolerance = 0.2;
+       }
+
+       @Test
+       public void testCPSparseSmall() {
+               exceptionRule.expect(AssertionError.class);
+               exceptionRule.expectMessage("Too many parameters: function 
colCountDistinctApprox takes at least 1" +
+                               " and at most 2 parameters");
+
+               Types.ExecType execType = Types.ExecType.CP;
+
+               int actualDistinctCount = 10;
+               int rows = 1000, cols = 1000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+               countDistinctMatrixTest(DIRECTION, actualDistinctCount, cols, 
rows, sparsity, execType, tolerance);
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAliasException.java
 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAliasException.java
new file mode 100644
index 0000000000..5693985a91
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/countDistinctApprox/CountDistinctApproxRowAliasException.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.countDistinctApprox;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.countDistinct.CountDistinctBase;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+public class CountDistinctApproxRowAliasException extends CountDistinctBase {
+
+       @Rule
+       public ExpectedException exceptionRule = ExpectedException.none();
+
+       private final static String TEST_NAME = 
"countDistinctApproxRowAliasException";
+       private final static String TEST_DIR = "functions/countDistinctApprox/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
CountDistinctApproxRowAliasException.class.getSimpleName() + "/";
+
+       private final Types.Direction DIRECTION = Types.Direction.Row;
+
+       @Override
+       protected String getTestClassDir() {
+               return TEST_CLASS_DIR;
+       }
+
+       @Override
+       protected String getTestName() {
+               return TEST_NAME;
+       }
+
+       @Override
+       protected String getTestDir() {
+               return TEST_DIR;
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(getTestName(), new 
TestConfiguration(getTestClassDir(), getTestName(), new String[] {"A"}));
+
+               this.percentTolerance = 0.2;
+       }
+
+       @Test
+       public void testCPSparseSmall() {
+               exceptionRule.expect(AssertionError.class);
+               exceptionRule.expectMessage("Too many parameters: function 
rowCountDistinctApprox takes at least 1" +
+                               " and at most 2 parameters");
+
+               Types.ExecType execType = Types.ExecType.CP;
+
+               int actualDistinctCount = 10;
+               int rows = 1000, cols = 1000;
+               double sparsity = 0.1;
+               double tolerance = actualDistinctCount * this.percentTolerance;
+               countDistinctMatrixTest(DIRECTION, actualDistinctCount, cols, 
rows, sparsity, execType, tolerance);
+       }
+}
diff --git a/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml 
b/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml
index 3eeb8ed54a..2522fbd1a5 100644
--- a/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctColAlias.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
-res = countDistinctCol(input, dir="c")
+res = colCountDistinct(input)
 write(res, $5, format="text")
diff --git 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
 b/src/test/scripts/functions/countDistinct/countDistinctColAliasException.dml
similarity index 94%
copy from 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
copy to 
src/test/scripts/functions/countDistinct/countDistinctColAliasException.dml
index 83a9f5070c..45caeb85af 100644
--- 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
+++ 
b/src/test/scripts/functions/countDistinct/countDistinctColAliasException.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
-res = countDistinctApproxCol(input, dir="c", type="KMV")
+res = colCountDistinct(input, dir="x")
 write(res, $5, format="text")
diff --git a/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml 
b/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml
index 62d7196ce1..685221ffbe 100644
--- a/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml
+++ b/src/test/scripts/functions/countDistinct/countDistinctRowAlias.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4,  
seed = 7))
-res = countDistinctRow(input, dir="r")
+res = rowCountDistinct(input)
 write(res, $5, format="text")
diff --git 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
 b/src/test/scripts/functions/countDistinct/countDistinctRowAliasException.dml
similarity index 94%
copy from 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
copy to 
src/test/scripts/functions/countDistinct/countDistinctRowAliasException.dml
index 83a9f5070c..3b1cabfe98 100644
--- 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
+++ 
b/src/test/scripts/functions/countDistinct/countDistinctRowAliasException.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
-res = countDistinctApproxCol(input, dir="c", type="KMV")
+res = rowCountDistinct(input, dir="x")
 write(res, $5, format="text")
diff --git 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
index 83a9f5070c..0eda3fb989 100644
--- 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
+++ 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
-res = countDistinctApproxCol(input, dir="c", type="KMV")
+res = colCountDistinctApprox(input, type="KMV")
 write(res, $5, format="text")
diff --git 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAliasException.dml
similarity index 94%
copy from 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
copy to 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAliasException.dml
index 83a9f5070c..8428cd061b 100644
--- 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
+++ 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAliasException.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
-res = countDistinctApproxCol(input, dir="c", type="KMV")
+res = colCountDistinctApprox(input, dir="x", type="KMV")
 write(res, $5, format="text")
diff --git 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
index f4be480156..f2c226e62e 100644
--- 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
+++ 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAlias.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
-res = countDistinctApproxRow(input, dir="r", type="KMV")
+res = rowCountDistinctApprox(input, type="KMV")
 write(res, $5, format="text")
diff --git 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAliasException.dml
similarity index 94%
copy from 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
copy to 
src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAliasException.dml
index 83a9f5070c..05526c9ce8 100644
--- 
a/src/test/scripts/functions/countDistinctApprox/countDistinctApproxColAlias.dml
+++ 
b/src/test/scripts/functions/countDistinctApprox/countDistinctApproxRowAliasException.dml
@@ -20,5 +20,5 @@
 #-------------------------------------------------------------
 
 input = round(rand(rows = $2, cols = $3, min = 0, max = $1 -1, sparsity= $4, 
seed = 7))
-res = countDistinctApproxCol(input, dir="c", type="KMV")
+res = rowCountDistinctApprox(input, dir="x", type="KMV")
 write(res, $5, format="text")

Reply via email to