[SYSTEMML-927] Fix frame schema handling in spark cast/write instruction

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/81d2b641
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/81d2b641
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/81d2b641

Branch: refs/heads/master
Commit: 81d2b641d99743ab54528a214659a5166e65aabe
Parents: 69a7858
Author: Matthias Boehm <mbo...@us.ibm.com>
Authored: Sat Sep 17 05:39:56 2016 +0200
Committer: Matthias Boehm <mbo...@us.ibm.com>
Committed: Sat Sep 17 00:25:22 2016 -0700

----------------------------------------------------------------------
 .../sysml/runtime/instructions/spark/CastSPInstruction.java | 9 +++++++++
 .../runtime/instructions/spark/WriteSPInstruction.java      | 9 ++++++---
 2 files changed, 15 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/81d2b641/src/main/java/org/apache/sysml/runtime/instructions/spark/CastSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/CastSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/CastSPInstruction.java
index d869f11..4487b20 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/CastSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/CastSPInstruction.java
@@ -19,9 +19,12 @@
 
 package org.apache.sysml.runtime.instructions.spark;
 
+import java.util.Collections;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.lops.UnaryCP;
+import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
@@ -88,5 +91,11 @@ public class CastSPInstruction extends UnarySPInstruction
                sec.setRDDHandleForVariable(output.getName(), out);
                updateUnaryOutputMatrixCharacteristics(sec, input1.getName(), output.getName());
                sec.addLineageRDD(output.getName(), input1.getName());
+               
+               //update schema information for output frame
+               if( opcode.equals(UnaryCP.CAST_AS_FRAME_OPCODE) ) {
+                       sec.getFrameObject(output.getName()).setSchema(
+                               Collections.nCopies((int)mcIn.getCols(), ValueType.DOUBLE));
+               }
        }
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/81d2b641/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
index e4e2606..1b974f9 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/WriteSPInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysml.runtime.instructions.spark;
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.List;
 import java.util.Random;
 
 import org.apache.hadoop.io.LongWritable;
@@ -136,6 +137,8 @@ public class WriteSPInstruction extends SPInstruction
 
                //get filename (literal or variable expression)
                String fname = ec.getScalarInput(input2.getName(), ValueType.STRING, input2.isLiteral()).getStringValue();
+               List<ValueType> schema = (input1.getDataType()==DataType.FRAME) ? 
+                               sec.getFrameObject(input1.getName()).getSchema() : null;
                
                try
                {
@@ -150,7 +153,7 @@ public class WriteSPInstruction extends SPInstruction
                        if( input1.getDataType()==DataType.MATRIX )
                                processMatrixWriteInstruction(sec, fname, oi);
                        else
-                               processFrameWriteInstruction(sec, fname, oi);
+                               processFrameWriteInstruction(sec, fname, oi, schema);
                }
                catch(IOException ex)
                {
@@ -279,7 +282,7 @@ public class WriteSPInstruction extends SPInstruction
         * @throws IOException 
         */
        @SuppressWarnings("unchecked")
-       protected void processFrameWriteInstruction(SparkExecutionContext sec, String fname, OutputInfo oi) 
+       protected void processFrameWriteInstruction(SparkExecutionContext sec, String fname, OutputInfo oi, List<ValueType> schema) 
                throws DMLRuntimeException, IOException
        {
                //get input rdd
@@ -310,7 +313,7 @@ public class WriteSPInstruction extends SPInstruction
                }
                
                // write meta data file
-               MapReduceTool.writeMetaDataFile(fname + ".mtd", input1.getValueType(), null, DataType.FRAME, mc, oi, formatProperties); 
+               MapReduceTool.writeMetaDataFile(fname + ".mtd", input1.getValueType(), schema, DataType.FRAME, mc, oi, formatProperties);       
        }
        
        /**

Reply via email to