[SYSTEMML-562] New spark frame/matrix casting instructions, tests

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9607a376
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9607a376
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9607a376

Branch: refs/heads/master
Commit: 9607a376ae2346b103c97efe1c4afc9a75a62c08
Parents: 507172a
Author: Matthias Boehm <[email protected]>
Authored: Wed Jun 1 01:34:18 2016 -0700
Committer: Matthias Boehm <[email protected]>
Committed: Wed Jun 1 09:48:50 2016 -0700

----------------------------------------------------------------------
 src/main/java/org/apache/sysml/hops/Hop.java    |   2 +
 .../java/org/apache/sysml/hops/MemoTable.java   |   6 +
 .../sysml/hops/ParameterizedBuiltinOp.java      |  14 +-
 .../java/org/apache/sysml/hops/UnaryOp.java     |   4 +-
 src/main/java/org/apache/sysml/lops/Unary.java  |   7 +
 .../context/SparkExecutionContext.java          |  16 +-
 .../instructions/SPInstructionParser.java       |   7 +
 .../instructions/spark/CastSPInstruction.java   |  99 ++++++++
 .../spark/CheckpointSPInstruction.java          |   2 +-
 .../instructions/spark/SPInstruction.java       |   4 +-
 .../functions/frame/FrameMatrixCastingTest.java | 251 +++++++++++++++++++
 .../functions/frame/Frame2MatrixCast.dml        |  26 ++
 .../functions/frame/Matrix2FrameCast.dml        |  26 ++
 .../functions/frame/ZPackageSuite.java          |   1 +
 14 files changed, 450 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java 
b/src/main/java/org/apache/sysml/hops/Hop.java
index 8519055..e0a06e6 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1288,6 +1288,8 @@ public abstract class Hop
                HopsOpOp1LopsU.put(OpOp1.SIGMOID, 
org.apache.sysml.lops.Unary.OperationTypes.SIGMOID);
                HopsOpOp1LopsU.put(OpOp1.SELP, 
org.apache.sysml.lops.Unary.OperationTypes.SELP);
                HopsOpOp1LopsU.put(OpOp1.LOG_NZ, 
org.apache.sysml.lops.Unary.OperationTypes.LOG_NZ);
+               HopsOpOp1LopsU.put(OpOp1.CAST_AS_MATRIX, 
org.apache.sysml.lops.Unary.OperationTypes.CAST_AS_MATRIX);
+               HopsOpOp1LopsU.put(OpOp1.CAST_AS_FRAME, 
org.apache.sysml.lops.Unary.OperationTypes.CAST_AS_FRAME);
        }
 
        protected static final HashMap<Hop.OpOp1, 
org.apache.sysml.lops.UnaryCP.OperationTypes> HopsOpOp1LopsUS;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/main/java/org/apache/sysml/hops/MemoTable.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/MemoTable.java 
b/src/main/java/org/apache/sysml/hops/MemoTable.java
index ff41f09..cae8507 100644
--- a/src/main/java/org/apache/sysml/hops/MemoTable.java
+++ b/src/main/java/org/apache/sysml/hops/MemoTable.java
@@ -127,6 +127,9 @@ public class MemoTable
         */
        public MatrixCharacteristics[] getAllInputStats( ArrayList<Hop> inputs )
        {
+               if( inputs == null )
+                       return null;
+               
                MatrixCharacteristics[] ret = new 
MatrixCharacteristics[inputs.size()];
                for( int i=0; i<inputs.size(); i++ )
                {
@@ -164,6 +167,9 @@ public class MemoTable
         */
        public MatrixCharacteristics getAllInputStats( Hop input )
        {
+               if( input == null )
+                       return null;
+               
                MatrixCharacteristics ret = null;
                        
                long dim1 = input.getDim1();

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
index a793fb1..b2a1807 100644
--- a/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysml/hops/ParameterizedBuiltinOp.java
@@ -143,11 +143,9 @@ public class ParameterizedBuiltinOp extends Hop implements 
MultiThreadedHop
                _outputPermutationMatrix = flag;
        }
        
-       public Hop getTargetHop()
-       {
-               Hop targetHop = getInput().get(_paramIndexMap.get("target"));
-               
-               return targetHop;
+       /**
+        * Returns the input hop bound to the "target" parameter, or null
+        * if this operator has no "target" parameter (callers must handle null).
+        */
+       public Hop getTargetHop() {
+               return _paramIndexMap.containsKey("target") ?   
+                       getInput().get(_paramIndexMap.get("target")) : null;
        }
        
        @Override
@@ -985,13 +983,11 @@ public class ParameterizedBuiltinOp extends Hop 
implements MultiThreadedHop
        @Override
        protected long[] inferOutputCharacteristics( MemoTable memo )
        {
-               //CDF always known because 
-               
-               // TOSTRING outputs a string
+               //Notes: CDF, TOSTRING always known because scalar outputs
                
                long[] ret = null;
        
-               Hop input = getInput().get(_paramIndexMap.get("target"));       
+               Hop input = getTargetHop();     
                MatrixCharacteristics mc = memo.getAllInputStats(input);
 
                if( _op == ParamBuiltinOp.GROUPEDAGG ) 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/main/java/org/apache/sysml/hops/UnaryOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java 
b/src/main/java/org/apache/sysml/hops/UnaryOp.java
index 4ed0225..f0c8482 100644
--- a/src/main/java/org/apache/sysml/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java
@@ -121,8 +121,8 @@ public class UnaryOp extends Hop implements MultiThreadedHop
                {
                        Hop input = getInput().get(0);
                        
-                       if (getDataType() == DataType.SCALAR 
-                               || _op == OpOp1.CAST_AS_MATRIX || _op == 
OpOp1.CAST_AS_FRAME ) //TODO generalize frames to distributed ops 
+                       if(    getDataType() == DataType.SCALAR //value type 
casts or matrix to scalar
+                               || (_op == OpOp1.CAST_AS_MATRIX && 
getInput().get(0).getDataType()==DataType.SCALAR) )
                        {
                                if (_op == Hop.OpOp1.IQM)  //special handling 
IQM
                                {

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/main/java/org/apache/sysml/lops/Unary.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/lops/Unary.java 
b/src/main/java/org/apache/sysml/lops/Unary.java
index 8ef3e7b..bea7352 100644
--- a/src/main/java/org/apache/sysml/lops/Unary.java
+++ b/src/main/java/org/apache/sysml/lops/Unary.java
@@ -43,6 +43,7 @@ public class Unary extends Lop
                ROUND, CEIL, FLOOR, MR_IQM, INVERSE, CHOLESKY,
                CUMSUM, CUMPROD, CUMMIN, CUMMAX,
                SPROP, SIGMOID, SELP, SUBTRACT_NZ, LOG_NZ,
+               CAST_AS_MATRIX, CAST_AS_FRAME,
                NOTSUPPORTED
        };
 
@@ -326,6 +327,12 @@ public class Unary extends Lop
                case SELP:
                        return "sel+";
                
+               case CAST_AS_MATRIX:
+                       return UnaryCP.CAST_AS_MATRIX_OPCODE;
+
+               case CAST_AS_FRAME:
+                       return UnaryCP.CAST_AS_FRAME_OPCODE;
+                       
                default:
                        throw new LopsException(
                                        "Instruction not defined for Unary 
operation: " + op);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index f282579..956ea4e 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -540,7 +540,7 @@ public class SparkExecutionContext extends ExecutionContext
         * @param rdd
         * @throws DMLRuntimeException 
         */
-       public void setRDDHandleForVariable(String varname, 
JavaPairRDD<MatrixIndexes,?> rdd) 
+       public void setRDDHandleForVariable(String varname, JavaPairRDD<?,?> 
rdd) 
                throws DMLRuntimeException
        {
                MatrixObject mo = getMatrixObject(varname);
@@ -549,6 +549,20 @@ public class SparkExecutionContext extends ExecutionContext
        }
        
        /**
+        * Attaches an RDD as the distributed handle of the frame variable
+        * {@code varname}; this is the frame counterpart of setRDDHandleForVariable.
+        * 
+        * @param varname name of the frame variable in this execution context
+        * @param rdd frame RDD, wrapped into an RDDObject before being attached
+        * @throws DMLRuntimeException propagated from the frame lookup or handle assignment
+        */
+       public void setFrameRDDHandleForVariable(String varname, JavaPairRDD<?,?> rdd) 
+               throws DMLRuntimeException
+       {
+               //look up the frame object first, then wrap and bind the rdd
+               getFrameObject(varname).setRDDHandle( new RDDObject(rdd, varname) );
+       }
+       
+       /**
         * Utility method for creating an RDD out of an in-memory matrix block.
         * 
         * @param sc

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
index fa1a630..430aaf4 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/SPInstructionParser.java
@@ -45,6 +45,7 @@ import 
org.apache.sysml.runtime.instructions.spark.BinUaggChainSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.BuiltinBinarySPInstruction;
 import org.apache.sysml.runtime.instructions.spark.BuiltinUnarySPInstruction;
 import org.apache.sysml.runtime.instructions.spark.CSVReblockSPInstruction;
+import org.apache.sysml.runtime.instructions.spark.CastSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.CentralMomentSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.CheckpointSPInstruction;
 import org.apache.sysml.runtime.instructions.spark.CovarianceSPInstruction;
@@ -256,6 +257,9 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "binuaggchain", 
SPINSTRUCTION_TYPE.BinUaggChain);
                
                String2SPInstructionType.put( "write"   , 
SPINSTRUCTION_TYPE.Write);
+       
+               String2SPInstructionType.put( "castdtm"   , 
SPINSTRUCTION_TYPE.Cast);
+               String2SPInstructionType.put( "castdtf"   , 
SPINSTRUCTION_TYPE.Cast);
        }
 
        public static SPInstruction parseSingleInstruction (String str ) 
@@ -406,6 +410,9 @@ public class SPInstructionParser extends InstructionParser
                                
                        case Checkpoint:
                                return 
CheckpointSPInstruction.parseInstruction(str);
+                       
+                       case Cast:
+                               return CastSPInstruction.parseInstruction(str);
                                
                        case INVALID:
                        default:

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/main/java/org/apache/sysml/runtime/instructions/spark/CastSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CastSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CastSPInstruction.java
new file mode 100644
index 0000000..8160bd0
--- /dev/null
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CastSPInstruction.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.runtime.instructions.spark;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.lops.UnaryCP;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.instructions.InstructionUtils;
+import org.apache.sysml.runtime.instructions.cp.CPOperand;
+import 
org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
+import 
org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysml.runtime.matrix.operators.Operator;
+
+
+/**
+ * Spark instruction for casting between frames and matrices: frame-to-matrix
+ * (UnaryCP.CAST_AS_MATRIX_OPCODE) and matrix-to-frame (UnaryCP.CAST_AS_FRAME_OPCODE).
+ * The input is obtained as a binary-block RDD; frame RDDs are keyed by Long row
+ * index and matrix RDDs by MatrixIndexes.
+ */
+public class CastSPInstruction extends UnarySPInstruction
+{
+       /**
+        * @param op operator (parseInstruction passes null)
+        * @param in input operand
+        * @param out output operand
+        * @param opcode cast opcode
+        * @param istr original instruction string
+        */
+       public CastSPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
+               super(op, in, out, opcode, istr);
+               _sptype = SPINSTRUCTION_TYPE.Cast;
+       }
+       
+       /**
+        * Parses a cast instruction consisting of opcode, input operand and output operand.
+        * 
+        * @param str instruction string
+        * @return parsed cast instruction (without operator)
+        * @throws DMLRuntimeException if parsing or the operand-count check fails
+        */
+       public static CastSPInstruction parseInstruction ( String str ) 
+               throws DMLRuntimeException 
+       {
+               String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
+               InstructionUtils.checkNumFields(tokens, 2);
+               
+               //tokens: [0] opcode, [1] input operand, [2] output operand
+               return new CastSPInstruction(null, new CPOperand(tokens[1]), 
+                       new CPOperand(tokens[2]), tokens[0], str);
+       }
+       
+       @Override
+       public void processInstruction(ExecutionContext ec)
+                       throws DMLRuntimeException 
+       {
+               SparkExecutionContext sec = (SparkExecutionContext)ec;
+               String opcode = getOpcode();
+               String inName = input1.getName();
+               String outName = output.getName();
+               
+               //binary-block input rdd and its characteristics
+               JavaPairRDD<?,?> rddIn = sec.getRDDHandleForVariable( inName, InputInfo.BinaryBlockInputInfo );
+               MatrixCharacteristics mcIn = sec.getMatrixCharacteristics( inName );
+               
+               //dispatch on cast direction and bind the converted rdd to the output
+               if( opcode.equals(UnaryCP.CAST_AS_MATRIX_OPCODE) )
+                       sec.setRDDHandleForVariable(outName, castFrameToMatrix(rddIn, mcIn));
+               else if( opcode.equals(UnaryCP.CAST_AS_FRAME_OPCODE) )
+                       sec.setFrameRDDHandleForVariable(outName, castMatrixToFrame(sec, rddIn, mcIn));
+               else
+                       throw new DMLRuntimeException("Unsupported spark cast operation: "+opcode);
+               
+               //propagate output statistics and register lineage
+               updateUnaryOutputMatrixCharacteristics(sec, inName, outName);
+               sec.addLineageRDD(outName, inName);
+       }
+       
+       /**
+        * Converts a frame rdd (Long row keys) into matrix blocks using the
+        * configured default block size for the output.
+        */
+       @SuppressWarnings("unchecked")
+       private static JavaPairRDD<?,?> castFrameToMatrix(JavaPairRDD<?,?> rddIn, MatrixCharacteristics mcIn) 
+               throws DMLRuntimeException
+       {
+               int blen = ConfigurationManager.getBlocksize();
+               MatrixCharacteristics mcOut = new MatrixCharacteristics(mcIn);
+               mcOut.setBlockSize(blen, blen);
+               
+               //TODO: simplify converter api to allow long indexes to be passed in
+               JavaPairRDD<?,?> rddWritable = ((JavaPairRDD<Long, FrameBlock>)rddIn)
+                       .mapToPair(new LongFrameToLongWritableFrameFunction());
+               return FrameRDDConverterUtils.binaryBlockToMatrixBlock(
+                       (JavaPairRDD<LongWritable, FrameBlock>)rddWritable, mcIn, mcOut);
+       }
+       
+       /**
+        * Converts a binary-block matrix rdd into a frame rdd keyed by Long row index.
+        */
+       @SuppressWarnings("unchecked")
+       private static JavaPairRDD<?,?> castMatrixToFrame(SparkExecutionContext sec, 
+                       JavaPairRDD<?,?> rddIn, MatrixCharacteristics mcIn) 
+               throws DMLRuntimeException
+       {
+               return FrameRDDConverterUtils.matrixBlockToBinaryBlockLongIndex(sec.getSparkContext(), 
+                       (JavaPairRDD<MatrixIndexes, MatrixBlock>)rddIn, mcIn);
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/main/java/org/apache/sysml/runtime/instructions/spark/CheckpointSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CheckpointSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CheckpointSPInstruction.java
index 24c614a..03edadd 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/CheckpointSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/CheckpointSPInstruction.java
@@ -51,7 +51,7 @@ public class CheckpointSPInstruction extends 
UnarySPInstruction
        
        public CheckpointSPInstruction(Operator op, CPOperand in, CPOperand 
out, StorageLevel level, String opcode, String istr) {
                super(op, in, out, opcode, istr);
-               _sptype = SPINSTRUCTION_TYPE.Reorg;
+               //use the dedicated Checkpoint type rather than Reorg
+               _sptype = SPINSTRUCTION_TYPE.Checkpoint;
                
                _level = level;
        }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
index ae16075..acc2ff6 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java
@@ -40,8 +40,8 @@ public abstract class SPInstruction extends Instruction
        public enum SPINSTRUCTION_TYPE { 
                MAPMM, MAPMMCHAIN, CPMM, RMM, TSMM, PMM, ZIPMM, PMAPMM, 
//matrix multiplication instructions  
                MatrixIndexing, Reorg, ArithmeticBinary, RelationalBinary, 
AggregateUnary, AggregateTernary, Reblock, CSVReblock, 
-               Builtin, BuiltinUnary, BuiltinBinary, Checkpoint, 
-               CentralMoment, Covariance, QSort, QPick,
+               Builtin, BuiltinUnary, BuiltinBinary, Checkpoint, Cast,
+               CentralMoment, Covariance, QSort, QPick, 
                ParameterizedBuiltin, MAppend, RAppend, GAppend, 
GAlignedAppend, Rand, 
                MatrixReshape, Ternary, Quaternary, CumsumAggregate, 
CumsumOffset, BinUaggChain, UaggOuterChain, 
                Write, INVALID, 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameMatrixCastingTest.java
----------------------------------------------------------------------
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameMatrixCastingTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameMatrixCastingTest.java
new file mode 100644
index 0000000..0baf9c2
--- /dev/null
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/frame/FrameMatrixCastingTest.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.functions.frame;
+
+import java.io.IOException;
+
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.lops.LopProperties.ExecType;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.parser.Expression.ValueType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.io.FrameReader;
+import org.apache.sysml.runtime.io.FrameReaderFactory;
+import org.apache.sysml.runtime.io.FrameWriter;
+import org.apache.sysml.runtime.io.FrameWriterFactory;
+import org.apache.sysml.runtime.io.MatrixReader;
+import org.apache.sysml.runtime.io.MatrixReaderFactory;
+import org.apache.sysml.runtime.io.MatrixWriter;
+import org.apache.sysml.runtime.io.MatrixWriterFactory;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.FrameBlock;
+import org.apache.sysml.runtime.matrix.data.InputInfo;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.OutputInfo;
+import org.apache.sysml.runtime.util.DataConverter;
+import org.apache.sysml.runtime.util.MapReduceTool;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.utils.TestUtils;
+import org.junit.Test;
+
+/**
+ * Integration tests for frame/matrix casting: as.matrix on a frame
+ * (Frame2MatrixCast) and as.frame on a matrix (Matrix2FrameCast).
+ * Frame-to-matrix is tested in CP and Spark; matrix-to-frame in CP only
+ * (its Spark variants are disabled until distributed frame write exists).
+ * Every test compares the cast result against the input with zero tolerance.
+ */
+public class FrameMatrixCastingTest extends AutomatedTestBase
+{
+       private final static String TEST_DIR = "functions/frame/";
+       private final static String TEST_NAME1 = "Frame2MatrixCast";
+       private final static String TEST_NAME2 = "Matrix2FrameCast";
+       private final static String TEST_CLASS_DIR = TEST_DIR + FrameMatrixCastingTest.class.getSimpleName() + "/";
+
+       private final static int rows = 2593;
+       private final static int cols1 = 372;  //columns for the *Single* tests
+       private final static int cols2 = 1102; //columns for the *Multi* tests
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"B"}));
+               addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"B"}));
+       }
+       
+       @Test
+       public void testStringFrame2MatrixCastSingleCP() {
+               runFrameCastingTest(TEST_NAME1, false, ValueType.STRING, ExecType.CP);
+       }
+       
+       @Test
+       public void testStringFrame2MatrixCastMultiCP() {
+               runFrameCastingTest(TEST_NAME1, true, ValueType.STRING, ExecType.CP);
+       }
+       
+       @Test
+       public void testDoubleFrame2MatrixCastSingleCP() {
+               runFrameCastingTest(TEST_NAME1, false, ValueType.DOUBLE, ExecType.CP);
+       }
+       
+       @Test
+       public void testDoubleFrame2MatrixCastMultiCP() {
+               runFrameCastingTest(TEST_NAME1, true, ValueType.DOUBLE, ExecType.CP);
+       }
+
+       @Test
+       public void testMatrix2FrameCastSingleCP() {
+               runFrameCastingTest(TEST_NAME2, false, null, ExecType.CP);
+       }
+       
+       @Test
+       public void testMatrix2FrameCastMultiCP() {
+               runFrameCastingTest(TEST_NAME2, true, null, ExecType.CP);
+       }
+       
+       @Test
+       public void testStringFrame2MatrixCastSingleSpark() {
+               runFrameCastingTest(TEST_NAME1, false, ValueType.STRING, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testStringFrame2MatrixCastMultiSpark() {
+               runFrameCastingTest(TEST_NAME1, true, ValueType.STRING, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testDoubleFrame2MatrixCastSingleSpark() {
+               runFrameCastingTest(TEST_NAME1, false, ValueType.DOUBLE, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testDoubleFrame2MatrixCastMultiSpark() {
+               runFrameCastingTest(TEST_NAME1, true, ValueType.DOUBLE, ExecType.SPARK);
+       }
+
+       /*TODO write distributed frame missing
+       @Test
+       public void testMatrix2FrameCastSingleSpark() {
+               runFrameCastingTest(TEST_NAME2, false, null, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testMatrix2FrameCastMultiSpark() {
+               runFrameCastingTest(TEST_NAME2, true, null, ExecType.SPARK);
+       }
+       */
+       
+       /**
+        * Writes a random binary-block input, runs the given cast script, and
+        * compares the written result against the input with zero tolerance.
+        * 
+        * @param testname TEST_NAME1 (frame to matrix) or TEST_NAME2 (matrix to frame)
+        * @param multColBlks if true the input has cols2 columns, otherwise cols1
+        * @param vt value type of the input frame; ignored for TEST_NAME2, whose
+        *           input is always a DOUBLE matrix
+        * @param et execution type, mapped to the runtime platform for the run
+        */
+       private void runFrameCastingTest( String testname, boolean multColBlks, ValueType vt, ExecType et)
+       {
+               //set runtime platform according to the requested exec type
+               RUNTIME_PLATFORM platformOld = rtplatform;
+               switch( et ){
+                       case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
+                       case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+                       default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
+               }
+       
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if( rtplatform == RUNTIME_PLATFORM.SPARK )
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               
+               try
+               {
+                       int cols = multColBlks ? cols2 : cols1;
+                       
+                       TestConfiguration config = getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[]{"-explain","-args", input("A"), output("B") };
+                       
+                       //data generation
+                       double[][] A = getRandomMatrix(rows, cols, -1, 1, 0.9, 7); 
+                       DataType dtin = testname.equals(TEST_NAME1) ? DataType.FRAME : DataType.MATRIX;
+                       ValueType vtin = testname.equals(TEST_NAME1) ? vt : ValueType.DOUBLE;
+                       writeMatrixOrFrameInput(input("A"), A, rows, cols, dtin, vtin);
+                       
+                       //run testcase
+                       runTest(true, false, null, -1);
+                       
+                       //compare cast result against the original input
+                       DataType dtout = testname.equals(TEST_NAME1) ? DataType.MATRIX : DataType.FRAME;
+                       double[][] B = readMatrixOrFrameInput(output("B"), rows, cols, dtout);
+                       TestUtils.compareMatrices(A, B, rows, cols, 0);
+               }
+               catch(Exception ex) {
+                       throw new RuntimeException(ex);
+               }
+               finally {
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+       
+       /**
+        * Writes A in binary-block format as a frame (with value type vt) or as a
+        * matrix, followed by the metadata file fname.mtd.
+        * 
+        * @param fname output file name
+        * @param A cell values
+        * @param rows number of rows
+        * @param cols number of columns
+        * @param dt FRAME or MATRIX
+        * @param vt value type recorded for the data (frame schema / metadata)
+        * @throws DMLRuntimeException
+        * @throws IOException
+        */
+       private void writeMatrixOrFrameInput(String fname, double[][] A, int rows, int cols, DataType dt, ValueType vt) 
+               throws DMLRuntimeException, IOException 
+       {
+               int blksize = ConfigurationManager.getBlocksize();
+               
+               //write input data
+               if( dt == DataType.FRAME ) {
+                       FrameBlock fb = DataConverter.convertToFrameBlock(DataConverter.convertToMatrixBlock(A), vt);
+                       FrameWriter writer = FrameWriterFactory.createFrameWriter(OutputInfo.BinaryBlockOutputInfo);
+                       writer.writeFrameToHDFS(fb, fname, rows, cols);
+               }
+               else {
+                       MatrixBlock mb = DataConverter.convertToMatrixBlock(A);
+                       MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(OutputInfo.BinaryBlockOutputInfo);
+                       writer.writeMatrixToHDFS(mb, fname, (long)rows, (long)cols, blksize, blksize, -1);
+               }
+               
+               //write meta data
+               MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, blksize, blksize);
+               MapReduceTool.writeMetaDataFile(fname+".mtd", vt, dt, mc, OutputInfo.BinaryBlockOutputInfo);
+       }
+       
+       /**
+        * Reads a binary-block frame or matrix of the given size from fname and
+        * returns its cells as a dense double array.
+        * 
+        * @param fname input file name
+        * @param rows number of rows
+        * @param cols number of columns
+        * @param dt FRAME or MATRIX
+        * @return cell values
+        * @throws DMLRuntimeException
+        * @throws IOException
+        */
+       private double[][] readMatrixOrFrameInput(String fname, int rows, int cols, DataType dt) 
+               throws DMLRuntimeException, IOException 
+       {
+               MatrixBlock ret = null;
+               
+               //read input data
+               if( dt == DataType.FRAME ) {
+                       FrameReader reader = FrameReaderFactory.createFrameReader(InputInfo.BinaryBlockInputInfo);
+                       FrameBlock fb = reader.readFrameFromHDFS(fname, rows, cols);
+                       ret = DataConverter.convertToMatrixBlock(fb);
+               }
+               else {
+                       int blksize = ConfigurationManager.getBlocksize();
+                       MatrixReader reader = MatrixReaderFactory.createMatrixReader(InputInfo.BinaryBlockInputInfo);
+                       ret = reader.readMatrixFromHDFS(fname, rows, cols, blksize, blksize, -1);
+               }
+               
+               return DataConverter.convertToDoubleMatrix(ret);
+       }
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/test/scripts/functions/frame/Frame2MatrixCast.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/frame/Frame2MatrixCast.dml 
b/src/test/scripts/functions/frame/Frame2MatrixCast.dml
new file mode 100644
index 0000000..61d2ad4
--- /dev/null
+++ b/src/test/scripts/functions/frame/Frame2MatrixCast.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# A: input frame, path given as the first script argument
+A = read($1);
+
+# cast the frame to a matrix
+B = as.matrix(A);
+
+# write the resulting matrix in binary format to the second argument path
+write(B, $2, format="binary");

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/test/scripts/functions/frame/Matrix2FrameCast.dml
----------------------------------------------------------------------
diff --git a/src/test/scripts/functions/frame/Matrix2FrameCast.dml 
b/src/test/scripts/functions/frame/Matrix2FrameCast.dml
new file mode 100644
index 0000000..23009d6
--- /dev/null
+++ b/src/test/scripts/functions/frame/Matrix2FrameCast.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# A: input matrix, path given as the first script argument
+A = read($1);
+
+# cast the matrix to a frame
+B = as.frame(A);
+
+# write the resulting frame in binary format to the second argument path
+write(B, $2, format="binary");

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9607a376/src/test_suites/java/org/apache/sysml/test/integration/functions/frame/ZPackageSuite.java
----------------------------------------------------------------------
diff --git 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/frame/ZPackageSuite.java
 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/frame/ZPackageSuite.java
index d810534..819225c 100644
--- 
a/src/test_suites/java/org/apache/sysml/test/integration/functions/frame/ZPackageSuite.java
+++ 
b/src/test_suites/java/org/apache/sysml/test/integration/functions/frame/ZPackageSuite.java
@@ -32,6 +32,7 @@ import org.junit.runners.Suite;
        FrameCopyTest.class,
        FrameGetSetTest.class,
        FrameIndexingTest.class,
+       FrameMatrixCastingTest.class,
        FrameReadWriteTest.class,
        FrameSchemaReadTest.class,
        FrameSerializationTest.class,

Reply via email to