This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 98ea24d  [SYSTEMDS-2601] Comparison operators for frame-frame ops (CP, Spark)
98ea24d is described below

commit 98ea24d9f7d713bdcc1c0898ee79eaa493b096b3
Author: Shafaq Siddiqi <[email protected]>
AuthorDate: Mon Aug 10 21:55:31 2020 +0200

    [SYSTEMDS-2601] Comparison operators for frame-frame ops (CP, Spark)
    
    Closes #1009.
---
 .../org/apache/sysds/parser/DMLTranslator.java     |   4 +
 .../apache/sysds/parser/RelationalExpression.java  |  13 +-
 .../cp/BinaryFrameFrameCPInstruction.java          |  23 ++-
 .../spark/BinaryFrameFrameSPInstruction.java       |  44 +++++-
 .../sysds/runtime/matrix/data/FrameBlock.java      |  79 +++++++++
 .../functions/binary/frame/FrameEqualTest.java     | 176 +++++++++++++++++++++
 .../functions/binary/frame/frameComparisonTest.R   |  46 ++++++
 .../functions/binary/frame/frameComparisonTest.dml |  43 +++++
 8 files changed, 415 insertions(+), 13 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index a9e5a6e..f84f469 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1812,6 +1812,10 @@ public class DMLTranslator
                                target.setDataType(DataType.MATRIX);
                                target.setValueType(ValueType.FP64);
                        }
+                       else if(left.getDataType() == DataType.FRAME || 
right.getDataType() == DataType.FRAME) {
+                               target.setDataType(DataType.FRAME);
+                               target.setValueType(ValueType.BOOLEAN);
+                       }
                        else {
                                // Added to support scalar relational comparison
                                target.setDataType(DataType.SCALAR);
diff --git a/src/main/java/org/apache/sysds/parser/RelationalExpression.java 
b/src/main/java/org/apache/sysds/parser/RelationalExpression.java
index 5c19a18..f0b4695 100644
--- a/src/main/java/org/apache/sysds/parser/RelationalExpression.java
+++ b/src/main/java/org/apache/sysds/parser/RelationalExpression.java
@@ -140,7 +140,9 @@ public class RelationalExpression extends Expression
                output.setParseInfo(this);
                
                boolean isLeftMatrix = (_left.getOutput() != null && 
_left.getOutput().getDataType() == DataType.MATRIX);
-               boolean isRightMatrix = (_right.getOutput() != null && 
_right.getOutput().getDataType() == DataType.MATRIX); 
+               boolean isRightMatrix = (_right.getOutput() != null && 
_right.getOutput().getDataType() == DataType.MATRIX);
+               boolean isLeftFrame = (_left.getOutput() != null && 
_left.getOutput().getDataType() == DataType.FRAME);
+               boolean isRightFrame = (_right.getOutput() != null && 
_right.getOutput().getDataType() == DataType.FRAME);
                if(isLeftMatrix || isRightMatrix) {
                        // Added to support matrix relational comparison
                        if(isLeftMatrix && isRightMatrix) {
@@ -155,6 +157,15 @@ public class RelationalExpression extends Expression
                        //double; once we support boolean matrices this needs 
to change
                        output.setValueType(ValueType.FP64);
                }
+               else if(isLeftFrame && isRightFrame) {
+                       output.setDataType(DataType.FRAME);
+                       output.setDimensions(_left.getOutput().getDim1(), 
_left.getOutput().getDim2());
+                       output.setValueType(ValueType.BOOLEAN);
+               }
+               else if( isLeftFrame || isRightFrame ) {
+                       raiseValidateError("Unsupported relational expression 
for mixed types "
+                               +_left.getOutput().getDataType().name()+" 
"+_right.getOutput().getDataType().name());
+               }
                else {
                        output.setBooleanProperties();
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
index 1116675..7968b18 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.cp;
 
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
 public class BinaryFrameFrameCPInstruction extends BinaryCPInstruction
@@ -32,16 +33,26 @@ public class BinaryFrameFrameCPInstruction extends 
BinaryCPInstruction
 
        @Override
        public void processInstruction(ExecutionContext ec) {
-               // Read input matrices
+               // get input frames
                FrameBlock inBlock1 = ec.getFrameInput(input1.getName());
                FrameBlock inBlock2 = ec.getFrameInput(input2.getName());
-
-               // Perform computation using input frames, and produce the 
result frame
-               FrameBlock retBlock = inBlock1.dropInvalid(inBlock2);
+               
+               if(getOpcode().equals("dropInvalidType")) {
+                       // Perform computation using input frames, and produce 
the result frame
+                       FrameBlock retBlock = inBlock1.dropInvalid(inBlock2);
+                       // Attach result frame with FrameBlock associated with 
output_name
+                       ec.setFrameOutput(output.getName(), retBlock);
+               }
+               else {
+                       // Execute binary operations
+                       BinaryOperator dop = (BinaryOperator) _optr;
+                       FrameBlock outBlock = inBlock1.binaryOperations(dop, 
inBlock2, null);
+                       // Attach result frame with FrameBlock associated with 
output_name
+                       ec.setFrameOutput(output.getName(), outBlock);
+               }
+               
                // Release the memory occupied by input frames
                ec.releaseFrameInput(input1.getName());
                ec.releaseFrameInput(input2.getName());
-               // Attach result frame with FrameBlock associated with 
output_name
-               ec.setFrameOutput(output.getName(), retBlock);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java
index 263abf3..021ac84 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/BinaryFrameFrameSPInstruction.java
@@ -29,7 +29,9 @@ import 
org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
+import scala.Tuple2;
 
 public class BinaryFrameFrameSPInstruction extends BinarySPInstruction {
        protected BinaryFrameFrameSPInstruction(Operator op, CPOperand in1, 
CPOperand in2, CPOperand out, String opcode, String istr) {
@@ -55,16 +57,33 @@ public class BinaryFrameFrameSPInstruction extends 
BinarySPInstruction {
        @Override
        public void processInstruction(ExecutionContext ec) {
                SparkExecutionContext sec = (SparkExecutionContext)ec;
+
                // Get input RDDs
                JavaPairRDD<Long, FrameBlock> in1 = 
sec.getFrameBinaryBlockRDDHandleForVariable(input1.getName());
-               // get schema frame-block
-               Broadcast<FrameBlock> fb = 
sec.getSparkContext().broadcast(sec.getFrameInput(input2.getName()));
-               JavaPairRDD<Long, FrameBlock> out = in1.mapValues(new 
isCorrectbySchema(fb.getValue()));
-               //release input frame
-               sec.releaseFrameInput(input2.getName());
-               //set output RDD
+               JavaPairRDD<Long, FrameBlock> out = null;
+               
+               if(getOpcode().equals("dropInvalidType")) {
+                       // get schema frame-block
+                       Broadcast<FrameBlock> fb = 
sec.getSparkContext().broadcast(sec.getFrameInput(input2.getName()));
+                       out = in1.mapValues(new 
isCorrectbySchema(fb.getValue()));
+                       //release input frame
+                       sec.releaseFrameInput(input2.getName());
+               }
+               else {
+                       JavaPairRDD<Long, FrameBlock> in2 = 
sec.getFrameBinaryBlockRDDHandleForVariable(input2.getName());
+                       // get the binary operator of this instruction
+                       BinaryOperator dop = (BinaryOperator) _optr;
+                       // join both inputs on their keys and compare the paired frame blocks
+                       out = in1.join(in2).mapValues(new FrameComparison(dop));
+               }
+               
+               //set output RDD and maintain dependencies
                sec.setRDDHandleForVariable(output.getName(), out);
                sec.addLineageRDD(output.getName(), input1.getName());
+               if( getOpcode().equals("dropInvalidType") )
+                       sec.addLineageBroadcast(output.getName(), 
input2.getName());
+               else
+                       sec.addLineageRDD(output.getName(), input2.getName());
        }
 
        private static class isCorrectbySchema implements 
Function<FrameBlock,FrameBlock> {
@@ -81,4 +100,17 @@ public class BinaryFrameFrameSPInstruction extends 
BinarySPInstruction {
                        return arg0.dropInvalid(schema_frame);
                }
        }
+
+       private static class FrameComparison implements 
Function<Tuple2<FrameBlock, FrameBlock>, FrameBlock> {
+               private static final long serialVersionUID = 
5850400295183766401L;
+               private final BinaryOperator bop;
+               public FrameComparison(BinaryOperator op){
+                       bop = op;
+               }
+
+               @Override
+               public FrameBlock call(Tuple2<FrameBlock, FrameBlock> arg0) 
throws Exception {
+                       return arg0._1().binaryOperations(bop, arg0._2(), null);
+               }
+       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 325819b..e473acd 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -41,7 +41,10 @@ import org.apache.sysds.api.DMLException;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.functionobjects.ValueComparisonFunction;
+import org.apache.sysds.runtime.instructions.cp.*;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.transform.encode.EncoderRecode;
 import org.apache.sysds.runtime.util.IndexRange;
 import org.apache.sysds.runtime.util.UtilFunctions;
@@ -277,6 +280,7 @@ public class FrameBlock implements CacheBlock, 
Externalizable
                                case BOOLEAN: _coldata[j] = new 
BooleanArray(new boolean[numRows]); break;
                                case INT32:   _coldata[j] = new 
IntegerArray(new int[numRows]); break;
                                case INT64:   _coldata[j] = new LongArray(new 
long[numRows]); break;
+                               case FP32:   _coldata[j] = new FloatArray(new 
float[numRows]); break;
                                case FP64:   _coldata[j] = new DoubleArray(new 
double[numRows]); break;
                                default: throw new 
RuntimeException("Unsupported value type: "+_schema[j]);
                        }
@@ -702,6 +706,8 @@ public class FrameBlock implements CacheBlock, 
Externalizable
                                case BOOLEAN: arr = new BooleanArray(new 
boolean[_numRows]); break;
                                case INT64:     arr = new LongArray(new 
long[_numRows]); break;
                                case FP64:  arr = new DoubleArray(new 
double[_numRows]); break;
+                               case INT32: arr = new IntegerArray(new 
int[_numRows]); break;
+                               case FP32:  arr = new FloatArray(new 
float[_numRows]); break;
                                default: throw new IOException("Unsupported 
value type: "+vt);
                        }
                        arr.readFields(in);
@@ -837,6 +843,79 @@ public class FrameBlock implements CacheBlock, 
Externalizable
                        + 32 + value.length();     //char array 
        }
        
+       /**
+        *  Performs a cell-wise value comparison of this frame with another frame.
+        *  Supported comparisons are equal, not equal, less than, greater than, less than or equal to, and greater than or equal to.
+        *  The output frame stores the boolean result of each comparison.
+        *
+        *  @param bop binary operator
+        *  @param that frame block of rhs of m * n dimensions
+        *  @param out output frame block (currently unused; a new frame block is always returned)
+        *  @return a boolean frameBlock
+        */
+       public FrameBlock binaryOperations(BinaryOperator bop, FrameBlock that, 
FrameBlock out) {
+               if(getNumColumns() != that.getNumColumns() && getNumRows() != 
that.getNumColumns())
+                       throw new DMLRuntimeException("Frame dimension mismatch 
"+getNumRows()+" * "+getNumColumns()+
+                               " != "+that.getNumRows()+" * 
"+that.getNumColumns());
+               String[][] outputData = new 
String[getNumRows()][getNumColumns()];
+
+               //compare output value, incl implicit type promotion if 
necessary
+               if( !(bop.fn instanceof ValueComparisonFunction) )
+                       throw new DMLRuntimeException("Unsupported binary 
operation on frames (only comparisons supported)");
+               ValueComparisonFunction vcomp = (ValueComparisonFunction) 
bop.fn;
+
+               for (int i = 0; i < getNumColumns(); i++) {
+                       if (getSchema()[i] == ValueType.STRING || 
that.getSchema()[i] == ValueType.STRING) {
+                               for (int j = 0; j < getNumRows(); j++) {
+                                       if(checkAndSetEmpty(this, that, 
outputData, j, i))
+                                               continue;
+                                       String v1 = 
UtilFunctions.objectToString(get(j, i));
+                                       String v2 = 
UtilFunctions.objectToString(that.get(j, i));
+                                       outputData[j][i] = 
String.valueOf(vcomp.compare(v1, v2));
+                               }
+                       }
+                       else if (getSchema()[i] == ValueType.FP64 || 
that.getSchema()[i] == ValueType.FP64 ||
+                                       getSchema()[i] == ValueType.FP32 || 
that.getSchema()[i] == ValueType.FP32) {
+                               for (int j = 0; j < getNumRows(); j++) {
+                                       if(checkAndSetEmpty(this, that, 
outputData, j, i))
+                                               continue;
+                                       ScalarObject so1 = new 
DoubleObject(Double.parseDouble(get(j, i).toString()));
+                                       ScalarObject so2 = new 
DoubleObject(Double.parseDouble(that.get(j, i).toString()));
+                                       outputData[j][i] = 
String.valueOf(vcomp.compare(so1.getDoubleValue(), so2.getDoubleValue()));
+                               }
+                       }
+                       else if (getSchema()[i] == ValueType.INT64 || 
that.getSchema()[i] == ValueType.INT64 ||
+                                       getSchema()[i] == ValueType.INT32 || 
that.getSchema()[i] == ValueType.INT32) {
+                               for (int j = 0; j < this.getNumRows(); j++) {
+                                       if(checkAndSetEmpty(this, that, 
outputData, j, i))
+                                               continue;
+                                       ScalarObject so1 = new 
IntObject(Integer.parseInt(get(j, i).toString()));
+                                       ScalarObject so2 = new 
IntObject(Integer.parseInt(that.get(j, i).toString()));
+                                       outputData[j][i]  = 
String.valueOf(vcomp.compare(so1.getLongValue(), so2.getLongValue()));
+                               }
+                       }
+                       else {
+                               for (int j = 0; j < getNumRows(); j++) {
+                                       if(checkAndSetEmpty(this, that, 
outputData, j, i))
+                                               continue;
+                                       ScalarObject so1 = new BooleanObject( 
Boolean.parseBoolean(get(j, i).toString()));
+                                       ScalarObject so2 = new BooleanObject( 
Boolean.parseBoolean(that.get(j, i).toString()));
+                                       outputData[j][i] = 
String.valueOf(vcomp.compare(so1.getBooleanValue(), so2.getBooleanValue()));
+                               }
+                       }
+               }
+
+               return new 
FrameBlock(UtilFunctions.nCopies(this.getNumColumns(), ValueType.BOOLEAN), 
outputData);
+       }
+       
+       private static boolean checkAndSetEmpty(FrameBlock fb1, FrameBlock fb2, 
String[][] out, int r, int c) {
+               if(fb1.get(r, c) == null || fb2.get(r, c) == null) {
+                       out[r][c] = (fb1.get(r, c) == null && fb2.get(r, c) == 
null) ? "true" : "false";
+                       return true;
+               }
+               return false;
+       }
+       
        ///////
        // indexing and append operations
        
diff --git 
a/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameEqualTest.java
 
b/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameEqualTest.java
new file mode 100644
index 0000000..cdb8999
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/binary/frame/FrameEqualTest.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.binary.frame;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.LopProperties.ExecType;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class FrameEqualTest extends AutomatedTestBase {
+       private final static String TEST_NAME = "frameComparisonTest";
+       private final static String TEST_DIR = "functions/binary/frame/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
FrameEqualTest.class.getSimpleName() + "/";
+
+       private final static int rows = 100;
+       private final static Types.ValueType[] schemaStrings1 = 
{Types.ValueType.FP64, Types.ValueType.BOOLEAN, Types.ValueType.INT64, 
Types.ValueType.STRING, Types.ValueType.STRING, Types.ValueType.FP64};
+       private final static Types.ValueType[] schemaStrings2 = 
{Types.ValueType.INT64, Types.ValueType.BOOLEAN, Types.ValueType.FP32, 
Types.ValueType.FP64, Types.ValueType.STRING, Types.ValueType.FP32};
+
+       public enum TestType {
+               GREATER, LESS, EQUALS, NOT_EQUALS, GREATER_EQUALS, LESS_EQUALS,
+       }
+
+       @BeforeClass
+       public static void init() {
+               TestUtils.clearDirectory(TEST_DATA_DIR + TEST_CLASS_DIR);
+       }
+
+       @AfterClass
+       public static void cleanUp() {
+               if (TEST_CACHE_ENABLED) {
+                       TestUtils.clearDirectory(TEST_DATA_DIR + 
TEST_CLASS_DIR);
+               }
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"D"}));
+               if (TEST_CACHE_ENABLED) {
+                       setOutAndExpectedDeletionDisabled(true);
+               }
+       }
+
+       @Test
+       public void testFrameEqualCP() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.EQUALS, ExecType.CP);
+       }
+
+       @Test
+       public void testFrameEqualSpark() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.EQUALS, ExecType.SPARK);
+       }
+
+       @Test
+       public void testFrameNotEqualCP() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.NOT_EQUALS, ExecType.CP);
+       }
+
+       @Test
+       public void testFrameNotEqualSpark() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.NOT_EQUALS, ExecType.SPARK);
+       }
+
+       @Test
+       public void testFrameLessThanCP() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.LESS, ExecType.CP);
+       }
+
+       @Test
+       public void testFrameLessThanSpark() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.LESS, ExecType.SPARK);
+       }
+
+       @Test
+       public void testFrameGreaterEqualsCP() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.GREATER_EQUALS, ExecType.CP);
+       }
+
+       @Test
+       public void testFrameGreaterEqualsSpark() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.GREATER_EQUALS, ExecType.SPARK);
+       }
+
+       @Test 
+       public void testFrameLessEqualsCP() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.LESS_EQUALS, ExecType.CP);
+       }
+
+       @Test 
+       public void testFrameLessEqualsSpark() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.LESS_EQUALS, ExecType.SPARK);
+       }
+
+       @Test
+       public void testFrameGreaterThanCP() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.GREATER, ExecType.CP);
+       }
+
+       @Test
+       public void testFrameGreaterThanSpark() {
+               runComparisonTest(schemaStrings1, schemaStrings2, rows, 
schemaStrings1.length, TestType.GREATER, ExecType.SPARK);
+       }
+
+       private void runComparisonTest(Types.ValueType[] schema1, 
Types.ValueType[] schema2, int rows, int cols,
+                       TestType type, ExecType et)
+       {
+               Types.ExecMode platformOld = setExecMode(et);
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME);
+
+                       double[][] A = getRandomMatrix(rows, cols, 2, 3, 1, 2);
+                       double[][] B = getRandomMatrix(rows, cols, 10, 20, 1, 
0);
+
+                       writeInputFrameWithMTD("A", A, true, schemaStrings1, 
FileFormat.BINARY);
+                       writeInputFrameWithMTD("B", B, true, schemaStrings2, 
FileFormat.BINARY);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-explain", 
"recompile_runtime", "-nvargs", "A=" + input("A"), "B=" + input("B"),
+                                       "rows=" + String.valueOf(rows), "cols=" 
+ Integer.toString(cols), "type=" + String.valueOf(type), "C=" + output("C")};
+
+                       fullRScriptName = HOME + TEST_NAME + ".R";
+                       rCmd = "Rscript" + " " + fullRScriptName + " " + 
inputDir() + " " + String.valueOf(type) + " " + expectedDir();
+
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       //compare matrices
+                       HashMap<MatrixValue.CellIndex, Double> dmlfile = 
readDMLMatrixFromHDFS("C");
+                       HashMap<MatrixValue.CellIndex, Double> rfile = 
readRMatrixFromFS("C");
+
+                       double eps = 0.0001;
+                       TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
+               }
+               catch (Exception ex) {
+                       throw new RuntimeException(ex);
+               }
+               finally {
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+                       OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true;
+                       OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/binary/frame/frameComparisonTest.R 
b/src/test/scripts/functions/binary/frame/frameComparisonTest.R
new file mode 100644
index 0000000..c931d7e
--- /dev/null
+++ b/src/test/scripts/functions/binary/frame/frameComparisonTest.R
@@ -0,0 +1,46 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+A=read.csv(paste(args[1], "A.csv", sep=""), header = FALSE, 
stringsAsFactors=FALSE)
+B=read.csv(paste(args[1], "B.csv", sep=""), header = FALSE, 
stringsAsFactors=FALSE)
+
+test = args[2]
+
+if(test == "GREATER")
+{ 
+C = A > B 
+} else if (test == "LESS") {
+C = A < B
+} else if (test == "EQUALS") {
+C = A == B
+} else if (test == "NOT_EQUALS") {
+C =  A != B
+} else if(test == "GREATER_EQUALS") {
+C = A >= B
+} else if(test == "LESS_EQUALS") {
+C = A <= B
+}
+
+writeMM(as(C, "CsparseMatrix"), paste(args[3], "C", sep="")); 
\ No newline at end of file
diff --git a/src/test/scripts/functions/binary/frame/frameComparisonTest.dml 
b/src/test/scripts/functions/binary/frame/frameComparisonTest.dml
new file mode 100644
index 0000000..c43a614
--- /dev/null
+++ b/src/test/scripts/functions/binary/frame/frameComparisonTest.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A = read($A, rows=$rows, cols=$cols, data_type="frame", format="binary", 
header=FALSE);
+B = read($B, rows=$rows, cols=$cols, data_type="frame", format="binary", 
header=FALSE);
+
+test = $type
+
+if(test == "GREATER")
+  C = A > B
+else if (test == "LESS")
+  C = A < B
+else if (test == "EQUALS")
+  C = A == B
+else if (test == "NOT_EQUALS")   
+  C =  A != B
+else if (test == "GREATER_EQUALS")   
+  C =  A >= B
+else if (test == "LESS_EQUALS")   
+  C =  A <= B
+
+C = as.matrix(C)
+# print("this is C "+toString(C))
+
+write(C, $C);
\ No newline at end of file

Reply via email to