This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 589f574e22 [SYSTEMDS-3708] New builtin function for relational algebra - selection
589f574e22 is described below

commit 589f574e22cfb4b7bda8f8d5fb167f7fc70b0260
Author: gghsu <[email protected]>
AuthorDate: Sun Jun 2 12:20:27 2024 +0200

    [SYSTEMDS-3708] New builtin function for relational algebra - selection
    
    LDE project SoSe'24
    Closes #2027.
---
 scripts/builtin/raSelection.dml                    |  52 ++++++
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../builtin/part2/BuiltinRaSelectionTest.java      | 176 +++++++++++++++++++++
 src/test/scripts/functions/builtin/raSelection.R   |  51 ++++++
 src/test/scripts/functions/builtin/raSelection.dml |  30 ++++
 5 files changed, 310 insertions(+)

diff --git a/scripts/builtin/raSelection.dml b/scripts/builtin/raSelection.dml
new file mode 100644
index 0000000000..70d05b5cc8
--- /dev/null
+++ b/scripts/builtin/raSelection.dml
@@ -0,0 +1,52 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
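+# This raSelection-function takes a matrix data set as input and performs the
+# relational algebra operation selection on it, i.e., it keeps only the rows
+# whose value in a given column satisfies a comparison with a constant.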
+#
+# INPUT:
+# ------------------------------------------------------------------------------
+# X         Matrix of input data [shape: N x M]
+# col       Integer indicating the column index to execute selection command
+# op        String specifying the comparison operator: "==", "!=", "<", ">", "<=", or ">="
+# val       Constant value the column values are compared with (X[,col] op val)
+# ------------------------------------------------------------------------------
+#
+# OUTPUT:
+# ------------------------------------------------------------------------------
+# Y         Matrix of selected data [shape N' x M] with N' <= N
+# ------------------------------------------------------------------------------
+
+m_raSelection = function (Matrix[Double] X, Integer col, String op, Double val)
+  return (Matrix[Double] Y)
+{
+  # Reject unsupported operators explicitly; without this check the nested
+  # ifelse below would silently treat any unknown op string as ">="
+  if( !(op == "==" | op == "!=" | op == "<" | op == ">" | op == "<=" | op == ">=") ) {
+    stop("raSelection: unsupported comparison operator '" + op + "'");
+  }
+
+  # Determine the row indicator I (1 where X[,col] op val holds, 0 otherwise)
+  I = ifelse(op == "==", X[,col] == val,
+        ifelse(op == "!=", X[,col] != val,
+        ifelse(op == "<",  X[,col] <  val,
+        ifelse(op == ">",  X[,col] >  val,
+        ifelse(op == "<=", X[,col] <= val,
+        X[,col] >= val)))))
+
+  # Perform actual selection: keep only the rows of X selected by I
+  # NOTE(review): if no row matches, removeEmpty's default empty.return
+  # behavior applies instead of a 0-row result -- confirm this is intended
+  Y = removeEmpty(target=X, margin="rows", select=I);
+}
+
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 9a70cd50db..59d41fb227 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -274,6 +274,7 @@ public enum Builtins {
        RANDOM_FOREST("randomForest", true),
        RANDOM_FOREST_PREDICT("randomForestPredict", true),
        RANGE("range", false),
+       RASELECTION("raSelection", true),
        RBIND("rbind", false),
        REMOVE("remove", false, ReturnType.MULTI_RETURN),
        REV("rev", false),
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaSelectionTest.java
 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaSelectionTest.java
new file mode 100644
index 0000000000..95deb0cbfc
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinRaSelectionTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.builtin.part2;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Tests for the {@code raSelection} builtin. Each case runs the DML test script
+ * and the R reference script on the same input and compares their outputs.
+ */
+public class BuiltinRaSelectionTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME = "raSelection";
+       private final static String TEST_DIR = "functions/builtin/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinRaSelectionTest.class.getSimpleName() + "/";
+       private final static double eps = 1e-8;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"result"}));
+       }
+
+       @Test
+       public void testRaSelectionTestGreaterThan() {
+               // Generate actual dataset and variables
+               double[][] X = {
+                               {1.0, 2.0},
+                               {3.0, 4.0},
+                               {5.0, 6.0},
+                               {7.0, 8.0},
+                               {9.0, 10.0}};
+               int select_col = 1;
+               String op = ">";
+               double val = 4.0;
+
+               runRaSelectionTest(X, select_col, op, val);
+       }
+
+       @Test
+       public void testRaSelectionGreaterThanOrEqul() {
+               // Generate actual dataset and variables
+               double[][] X = {
+                               {1.0, 2.0, 3.0},
+                               {4.0, 5.0, 6.0},
+                               {7.0, 8.0, 9.0}
+               };
+               int select_col = 1;
+               String op = ">=";
+               double val = 4.0;
+
+               runRaSelectionTest(X, select_col, op, val);
+       }
+
+       @Test
+       public void testRaSelectionTestLessThan() {
+               // Generate actual dataset and variables
+               double[][] X = {
+                               {1.0, 2.0, 3.0, 4.0},
+                               {5.0, 6.0, 7.0, 8.0}
+               };
+               int select_col = 2;
+               String op = "<";
+               double val = 7.0;
+
+               runRaSelectionTest(X, select_col, op, val);
+       }
+
+       @Test
+       public void testRaSelectionTestLessThanOrEqual() {
+               // Generate actual dataset and variables
+               double[][] X = {
+                               {5.0, 1.0, 3.0},
+                               {2.0, 4.0, 6.0},
+                               {7.0, 8.0, 9.0},
+                               {3.0, 5.0, 7.0},
+                               {1.0, 6.0, 8.0}
+               };
+               int select_col = 1;
+               String op = "<=";
+               double val = 4.0;
+
+               runRaSelectionTest(X, select_col, op, val);
+       }
+
+       @Test
+       public void testRaSelectionTestEqual() {
+               // Generate actual dataset and variables
+               double[][] X = {
+                               {1.0, 2.0, 3.0, 4.0},
+                               {5.0, 6.0, 7.0, 8.0},
+                               {9.0, 10.0, 11.0, 12.0}
+               };
+               int select_col = 4;
+               String op = "==";
+               double val = 8.0;
+
+               runRaSelectionTest(X, select_col, op, val);
+       }
+
+       @Test
+       public void testRaSelectionTestNotEqual() {
+               // Generate actual dataset and variables
+               double[][] X = {
+                               {1.0, 2.0, 3.0, 4.0},
+                               {5.0, 6.0, 7.0, 8.0},
+                               {9.0, 10.0, 11.0, 12.0},
+                               {13.0, 14.0, 15.0, 16.0}
+               };
+               int select_col = 2;
+               String op = "!=";
+               double val = 10.0;
+
+               runRaSelectionTest(X, select_col, op, val);
+       }
+
+       /**
+        * Runs the raSelection DML script and the R reference script on the same
+        * input, compares both results, and checks the compiled operators via
+        * heavy-hitter statistics.
+        *
+        * @param X   input matrix, written as test input "X"
+        * @param col 1-based column index the selection is applied to
+        * @param op  comparison operator passed to both scripts
+        * @param val constant the selected column is compared with
+        */
+       private void runRaSelectionTest(double [][] X, int col, String op, double val)
+       {
+               ExecMode platformOld = setExecMode(ExecMode.SINGLE_NODE);
+
+               try
+               {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       // -stats so heavy-hitter statistics are available for the assertions below
+                       programArgs = new String[]{"-stats", "-args",
+                               input("X"), String.valueOf(col), op, String.valueOf(val), output("result") };
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+                       rCmd = "Rscript" + " " + fullRScriptName + " "
+                               + inputDir() + " " + col + " " + op + " " + val + " " + expectedDir();
+
+                       writeInputMatrixWithMTD("X", X, true);
+
+                       // run DML script and R reference script
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       // compare matrices
+                       HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("result");
+                       HashMap<CellIndex, Double> rfile = readRMatrixFromExpectedDir("result");
+                       TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Expected");
+
+                       // additional assertions: the requested operator is executed exactly once and
+                       // another comparison operator does not appear among the heavy hitters
+                       // NOTE(review): "==" is excluded from the count check; reason not visible here
+                       if( !op.equals("==") )
+                               Assert.assertEquals(1, Statistics.getCPHeavyHitterCount(op));
+                       String otherOp = op.equals("!=") ? ">" : "!=";
+                       Assert.assertFalse(heavyHittersContainsString(otherOp));
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/builtin/raSelection.R 
b/src/test/scripts/functions/builtin/raSelection.R
new file mode 100644
index 0000000000..e0f94c7796
--- /dev/null
+++ b/src/test/scripts/functions/builtin/raSelection.R
@@ -0,0 +1,51 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Function definition for raSelection in R
+raSelection <- function(X, col, op, val) {
+  # Map the operator string to the matching R comparison function;
+  # any other string aborts with an error
+  cmp <- switch(op,
+                "==" = `==`,
+                "!=" = `!=`,
+                "<"  = `<`,
+                ">"  = `>`,
+                "<=" = `<=`,
+                ">=" = `>=`,
+                stop("Invalid operator"))
+
+  # Logical row mask: TRUE where the value in column `col` satisfies the predicate
+  keep <- cmp(X[, col], val)
+
+  # Keep the matching rows; drop = FALSE preserves the matrix shape for 0 or 1 rows
+  X[keep, , drop = FALSE]
+}
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+# Positional arguments: input dir, column index, operator, comparison value, output dir
+inputDir  <- args[1]
+outputDir <- args[5]
+
+X   <- as.matrix(readMM(paste(inputDir, "X.mtx", sep="")))
+col <- as.integer(args[2])
+op  <- args[3]
+val <- as.numeric(args[4])
+
+result <- raSelection(X, col, op, val)
+writeMM(as(result, "CsparseMatrix"), paste(outputDir, "result", sep=""))
+ 
\ No newline at end of file
diff --git a/src/test/scripts/functions/builtin/raSelection.dml 
b/src/test/scripts/functions/builtin/raSelection.dml
new file mode 100644
index 0000000000..6d17146ed1
--- /dev/null
+++ b/src/test/scripts/functions/builtin/raSelection.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Arguments: $1 input matrix, $2 column index, $3 operator, $4 comparison value, $5 output path
+X = read($1)
+col = as.integer($2)
+op = $3
+val = as.double($4)
+
+# Apply the raSelection builtin and write the selected rows for comparison with R
+result = raSelection(X, col, op, val);
+write(result, $5);
+

Reply via email to