Repository: incubator-systemml Updated Branches: refs/heads/master b7657dbc3 -> d39865e9e
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java index 3841bc8..d8446a9 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLResults.java @@ -24,11 +24,7 @@ import java.util.Set; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.rdd.RDD; import org.apache.spark.sql.DataFrame; -import org.apache.sysml.hops.OptimizerUtils; -import org.apache.sysml.parser.Expression.ValueType; -import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.LocalVariableMap; -import org.apache.sysml.runtime.controlprogram.caching.CacheException; import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; @@ -39,11 +35,6 @@ import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.cp.IntObject; import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.cp.StringObject; -import org.apache.sysml.runtime.matrix.MatrixCharacteristics; -import org.apache.sysml.runtime.matrix.MatrixDimensionsMetaData; -import org.apache.sysml.runtime.matrix.data.FrameBlock; -import org.apache.sysml.runtime.matrix.data.MatrixBlock; -import org.apache.sysml.runtime.util.DataConverter; import scala.Tuple1; import scala.Tuple10; @@ -120,21 +111,7 @@ public class MLResults { */ public MatrixObject getMatrixObject(String outputName) { Data data = getData(outputName); - if(data instanceof ScalarObject) { - double val = getDouble(outputName); - MatrixObject one_X_one_mo = new MatrixObject(ValueType.DOUBLE, " ", new MatrixDimensionsMetaData(new MatrixCharacteristics(1, 1, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, 1))); - MatrixBlock mb = new MatrixBlock(1, 1, false); - mb.allocateDenseBlock(); - mb.setValue(0, 0, val); - try { - one_X_one_mo.acquireModify(mb); - one_X_one_mo.release(); - } catch (CacheException e) { - throw new RuntimeException(e); - } - return one_X_one_mo; - } - else if (!(data instanceof MatrixObject)) { + if (!(data instanceof MatrixObject)) { throw new MLContextException("Variable '" + outputName + "' not a matrix"); } MatrixObject mo = (MatrixObject) data; @@ -163,7 +140,7 @@ public class MLResults { * the name of the output * @return the output as a two-dimensional {@code double} array */ - public double[][] getDoubleMatrix(String outputName) { + public double[][] getMatrixAs2DDoubleArray(String outputName) { MatrixObject mo = getMatrixObject(outputName); double[][] doubleMatrix = MLContextConversionUtil.matrixObjectToDoubleMatrix(mo); return doubleMatrix; @@ -190,22 +167,16 @@ public class MLResults { * @return the output as a {@code JavaRDD<String>} in IJV format */ public JavaRDD<String> getJavaRDDStringIJV(String outputName) { - MatrixObject mo = getMatrixObject(outputName); - JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil.matrixObjectToJavaRDDStringIJV(mo); - return javaRDDStringIJV; - } - - /** - * Obtain an output as a {@code JavaRDD<String>} in IJV format. - * - * @param outputName - * the name of the output - * @return the output as a {@code JavaRDD<String>} in IJV format - */ - public JavaRDD<String> getFrameJavaRDDStringIJV(String outputName) { - FrameObject fo = getFrameObject(outputName); - JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil.frameObjectToJavaRDDStringIJV(fo); - return javaRDDStringIJV; + if (isMatrixObject(outputName)) { + MatrixObject mo = getMatrixObject(outputName); + JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil.matrixObjectToJavaRDDStringIJV(mo); + return javaRDDStringIJV; + } else if (isFrameObject(outputName)) { + FrameObject fo = getFrameObject(outputName); + JavaRDD<String> javaRDDStringIJV = MLContextConversionUtil.frameObjectToJavaRDDStringIJV(fo); + return javaRDDStringIJV; + } + return null; } /** @@ -227,22 +198,16 @@ public class MLResults { * @return the output as a {@code JavaRDD<String>} in CSV format */ public JavaRDD<String> getJavaRDDStringCSV(String outputName) { - MatrixObject mo = getMatrixObject(outputName); - JavaRDD<String> javaRDDStringCSV = MLContextConversionUtil.matrixObjectToJavaRDDStringCSV(mo); - return javaRDDStringCSV; - } - - /** - * Obtain an output as a {@code JavaRDD<String>} in CSV format. - * - * @param outputName - * the name of the output - * @return the output as a {@code JavaRDD<String>} in CSV format - */ - public JavaRDD<String> getFrameJavaRDDStringCSV(String outputName, String delimiter) { - FrameObject fo = getFrameObject(outputName); - JavaRDD<String> javaRDDStringCSV = MLContextConversionUtil.frameObjectToJavaRDDStringCSV(fo, delimiter); - return javaRDDStringCSV; + if (isMatrixObject(outputName)) { + MatrixObject mo = getMatrixObject(outputName); + JavaRDD<String> javaRDDStringCSV = MLContextConversionUtil.matrixObjectToJavaRDDStringCSV(mo); + return javaRDDStringCSV; + } else if (isFrameObject(outputName)) { + FrameObject fo = getFrameObject(outputName); + JavaRDD<String> javaRDDStringCSV = MLContextConversionUtil.frameObjectToJavaRDDStringCSV(fo, ","); + return javaRDDStringCSV; + } + return null; } /** @@ -264,23 +229,16 @@ public class MLResults { * @return the output as a {@code RDD<String>} in CSV format */ public RDD<String> getRDDStringCSV(String outputName) { - MatrixObject mo = getMatrixObject(outputName); - RDD<String> rddStringCSV = MLContextConversionUtil.matrixObjectToRDDStringCSV(mo); - return rddStringCSV; - } - - - /** - * Obtain an output as a {@code RDD<String>} in CSV format. - * - * @param outputName - * the name of the output - * @return the output as a {@code RDD<String>} in CSV format - */ - public RDD<String> getFrameRDDStringCSV(String outputName, String delimiter) { - FrameObject fo = getFrameObject(outputName); - RDD<String> rddStringCSV = MLContextConversionUtil.frameObjectToRDDStringCSV(fo, delimiter); - return rddStringCSV; + if (isMatrixObject(outputName)) { + MatrixObject mo = getMatrixObject(outputName); + RDD<String> rddStringCSV = MLContextConversionUtil.matrixObjectToRDDStringCSV(mo); + return rddStringCSV; + } else if (isFrameObject(outputName)) { + FrameObject fo = getFrameObject(outputName); + RDD<String> rddStringCSV = MLContextConversionUtil.frameObjectToRDDStringCSV(fo, ","); + return rddStringCSV; + } + return null; } /** @@ -304,26 +262,21 @@ public class MLResults { * @return the output as a {@code RDD<String>} in IJV format */ public RDD<String> getRDDStringIJV(String outputName) { - MatrixObject mo = getMatrixObject(outputName); - RDD<String> rddStringIJV = MLContextConversionUtil.matrixObjectToRDDStringIJV(mo); - return rddStringIJV; - } - - /** - * Obtain an output as a {@code RDD<String>} in IJV format. - * - * @param outputName - * the name of the output - * @return the output as a {@code RDD<String>} in IJV format - */ - public RDD<String> getFrameRDDStringIJV(String outputName) { - FrameObject fo = getFrameObject(outputName); - RDD<String> rddStringIJV = MLContextConversionUtil.frameObjectToRDDStringIJV(fo); - return rddStringIJV; + if (isMatrixObject(outputName)) { + MatrixObject mo = getMatrixObject(outputName); + RDD<String> rddStringIJV = MLContextConversionUtil.matrixObjectToRDDStringIJV(mo); + return rddStringIJV; + } else if (isFrameObject(outputName)) { + FrameObject fo = getFrameObject(outputName); + RDD<String> rddStringIJV = MLContextConversionUtil.frameObjectToRDDStringIJV(fo); + return rddStringIJV; + } + return null; } /** - * Obtain an output as a {@code DataFrame} of doubles with an ID column. + * Obtain an output as a {@code DataFrame}. If outputting a Matrix, this + * will be a DataFrame of doubles with an ID column. * <p> * The following matrix in DML: * </p> @@ -338,12 +291,53 @@ public class MLResults { * * @param outputName * the name of the output - * @return the output as a {@code DataFrame} of doubles with an ID column + * @return the output as a {@code DataFrame} */ public DataFrame getDataFrame(String outputName) { - MatrixObject mo = getMatrixObject(outputName); - DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false); - return df; + if (isMatrixObject(outputName)) { + MatrixObject mo = getMatrixObject(outputName); + DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false); + return df; + } else if (isFrameObject(outputName)) { + FrameObject mo = getFrameObject(outputName); + DataFrame df = MLContextConversionUtil.frameObjectToDataFrame(mo, sparkExecutionContext); + return df; + } + return null; + } + + /** + * Is the output a MatrixObject? + * + * @param outputName + * the name of the output + * @return {@code true} if the output is a MatrixObject, {@code false} + * otherwise. + */ + private boolean isMatrixObject(String outputName) { + Data data = getData(outputName); + if (data instanceof MatrixObject) { + return true; + } else { + return false; + } + } + + /** + * Is the output a FrameObject? + * + * @param outputName + * the name of the output + * @return {@code true} if the output is a FrameObject, {@code false} + * otherwise. + */ + private boolean isFrameObject(String outputName) { + Data data = getData(outputName); + if (data instanceof FrameObject) { + return true; + } else { + return false; + } } /** @@ -376,6 +370,9 @@ public class MLResults { * ID column */ public DataFrame getDataFrame(String outputName, boolean isVectorDF) { + if (isFrameObject(outputName)) { + throw new MLContextException("This method currently supports only matrices"); + } MatrixObject mo = getMatrixObject(outputName); DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, isVectorDF); return df; @@ -400,6 +397,9 @@ public class MLResults { * @return the output as a {@code DataFrame} of doubles with an ID column */ public DataFrame getDataFrameDoubleWithIDColumn(String outputName) { + if (isFrameObject(outputName)) { + throw new MLContextException("This method currently supports only matrices"); + } MatrixObject mo = getMatrixObject(outputName); DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false); return df; @@ -424,6 +424,9 @@ public class MLResults { * @return the output as a {@code DataFrame} of vectors with an ID column */ public DataFrame getDataFrameVectorWithIDColumn(String outputName) { + if (isFrameObject(outputName)) { + throw new MLContextException("This method currently supports only matrices"); + } MatrixObject mo = getMatrixObject(outputName); DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, true); return df; @@ -448,6 +451,9 @@ public class MLResults { * @return the output as a {@code DataFrame} of doubles with no ID column */ public DataFrame getDataFrameDoubleNoIDColumn(String outputName) { + if (isFrameObject(outputName)) { + throw new MLContextException("This method currently supports only matrices"); + } MatrixObject mo = getMatrixObject(outputName); DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, false); df = df.sort("ID").drop("ID"); @@ -473,6 +479,9 @@ public class MLResults { * @return the output as a {@code DataFrame} of vectors with no ID column */ public DataFrame getDataFrameVectorNoIDColumn(String outputName) { + if (isFrameObject(outputName)) { + throw new MLContextException("This method currently supports only matrices"); + } MatrixObject mo = getMatrixObject(outputName); DataFrame df = MLContextConversionUtil.matrixObjectToDataFrame(mo, sparkExecutionContext, true); df = df.sort("ID").drop("ID"); @@ -480,29 +489,29 @@ public class MLResults { } /** - * Obtain an output as a {@code DataFrame} without an ID column. + * Obtain an output as a {@code Matrix}. * * @param outputName * the name of the output - * @return the output as a {@code DataFrame} without an ID column + * @return the output as a {@code Matrix} */ - public DataFrame getFrameDataFrame(String outputName) { - FrameObject mo = getFrameObject(outputName); - DataFrame df = MLContextConversionUtil.frameObjectToDataFrame(mo, sparkExecutionContext); - return df; + public Matrix getMatrix(String outputName) { + MatrixObject mo = getMatrixObject(outputName); + Matrix matrix = new Matrix(mo, sparkExecutionContext); + return matrix; } /** - * Obtain an output as a {@code Matrix}. + * Obtain an output as a {@code Frame}. * * @param outputName * the name of the output - * @return the output as a {@code Matrix} + * @return the output as a {@code Frame} */ - public Matrix getMatrix(String outputName) { - MatrixObject mo = getMatrixObject(outputName); - Matrix matrix = new Matrix(mo, sparkExecutionContext); - return matrix; + public Frame getFrame(String outputName) { + FrameObject fo = getFrameObject(outputName); + Frame frame = new Frame(fo, sparkExecutionContext); + return frame; } /** @@ -526,22 +535,10 @@ public class MLResults { * the name of the output * @return the output as a two-dimensional {@code String} array */ - public String[][] getFrame(String outputName) { - try { - Data data = getData(outputName); - if (!(data instanceof FrameObject)) { - throw new MLContextException("Variable '" + outputName + "' not a frame"); - } - FrameObject fo = (FrameObject) data; - FrameBlock fb = fo.acquireRead(); - String[][] frame = DataConverter.convertToStringFrame(fb); - fo.release(); - return frame; - } catch (CacheException e) { - throw new MLContextException("Cache exception when reading frame", e); - } catch (DMLRuntimeException e) { - throw new MLContextException("DML runtime exception when reading frame", e); - } + public String[][] getFrameAs2DStringArray(String outputName) { + FrameObject frameObject = getFrameObject(outputName); + String[][] frame = MLContextConversionUtil.frameObjectTo2DStringArray(frameObject); + return frame; } /** @@ -569,8 +566,10 @@ public class MLResults { if (data instanceof ScalarObject) { ScalarObject so = (ScalarObject) data; return so.getValue(); - } else if(data instanceof MatrixObject) { + } else if (data instanceof MatrixObject) { return getMatrix(outputName); + } else if (data instanceof FrameObject) { + return getFrame(outputName); } else { return data; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java b/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java index 1ea3a10..513b74d 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MatrixMetadata.java @@ -27,7 +27,7 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics; * columns per block in the matrix. * */ -public class MatrixMetadata { +public class MatrixMetadata extends Metadata { private Long numRows = null; private Long numColumns = null; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/Metadata.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Metadata.java b/src/main/java/org/apache/sysml/api/mlcontext/Metadata.java new file mode 100644 index 0000000..c1c0a36 --- /dev/null +++ b/src/main/java/org/apache/sysml/api/mlcontext/Metadata.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.api.mlcontext; + +/** + * Abstract metadata class for MLContext API. Complex types such as SystemML + * matrices and frames typically require metadata, so this abstract class serves + * as a common parent class of these types. + * + */ +public abstract class Metadata { + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/Script.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Script.java b/src/main/java/org/apache/sysml/api/mlcontext/Script.java index bfa947c..17a3996 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/Script.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/Script.java @@ -68,13 +68,9 @@ public class Script { */ private Set<String> inputVariables = new LinkedHashSet<String>(); /** - * The input variable type (if its frame of matrix). + * The input matrix or frame metadata if present. */ - private Map<String, Boolean> inputVariablesType = new LinkedHashMap<String, Boolean>(); - /** - * The input matrix metadata if present. - */ - private Map<String, MatrixMetadata> inputMatrixMetadata = new LinkedHashMap<String, MatrixMetadata>(); + private Map<String, Metadata> inputMetadata = new LinkedHashMap<String, Metadata>(); /** * The output variables. */ @@ -186,15 +182,6 @@ public class Script { } /** - * Obtain the input variable type flag (if its frame or not) - * - * @return the input variable names - */ - public Map<String, Boolean> getInputVariablesType() { - return inputVariablesType; - } - - /** * Obtain the output variable names as an unmodifiable set of strings. * * @return the output variable names @@ -223,12 +210,12 @@ public class Script { } /** - * Obtain an unmodifiable map of input matrix metadata. + * Obtain an unmodifiable map of input matrix/frame metadata. * - * @return input matrix metadata + * @return input matrix/frame metadata */ - public Map<String, MatrixMetadata> getInputMatrixMetadata() { - return Collections.unmodifiableMap(inputMatrixMetadata); + public Map<String, Metadata> getInputMetadata() { + return Collections.unmodifiableMap(inputMetadata); } /** @@ -317,7 +304,7 @@ public class Script { * @return {@code this} Script object to allow chaining of methods */ public Script in(String name, Object value) { - return in(name, value, (MatrixMetadata) null); + return in(name, value, null); } /** @@ -328,45 +315,11 @@ public class Script { * name of the input * @param value * value of the input - * @param matrixFormat - * optional matrix format + * @param metadata + * optional matrix/frame metadata * @return {@code this} Script object to allow chaining of methods */ - public Script in(String name, Object value, MatrixFormat matrixFormat) { - MatrixMetadata matrixMetadata = new MatrixMetadata(matrixFormat); - return in(name, value, matrixMetadata); - } - - /** - * Register an input (parameter ($) or variable) with optional matrix - * metadata. - * - * @param name - * name of the input - * @param value - * value of the input - * @param matrixMetadata - * optional matrix metadata - * @return {@code this} Script object to allow chaining of methods - */ - public Script in(String name, Object value, MatrixMetadata matrixMetadata) { - return in(name, value, matrixMetadata, false); - } - /** - * Register an input (parameter ($) or variable) with optional matrix - * metadata. - * - * @param name - * name of the input - * @param value - * value of the input - * @param matrixMetadata - * optional matrix metadata - * @param bFrame - * if input is of type frame - * @return {@code this} Script object to allow chaining of methods - */ - public Script in(String name, Object value, MatrixMetadata matrixMetadata, boolean bFrame) { + public Script in(String name, Object value, Metadata metadata) { MLContextUtil.checkInputValueType(name, value); if (inputs == null) { inputs = new LinkedHashMap<String, Object>(); @@ -380,17 +333,13 @@ public class Script { } inputParameters.put(name, value); } else { - Data data = MLContextUtil.convertInputType(name, value, matrixMetadata, bFrame); + Data data = MLContextUtil.convertInputType(name, value, metadata); if (data != null) { symbolTable.put(name, data); inputVariables.add(name); - if (inputVariablesType == null) { - inputVariablesType = new LinkedHashMap<String, Boolean>(); - } - inputVariablesType.put(name, new Boolean(bFrame)); if (data instanceof MatrixObject || data instanceof FrameObject) { - if (matrixMetadata != null) { - inputMatrixMetadata.put(name, matrixMetadata); + if (metadata != null) { + inputMetadata.put(name, metadata); } } } @@ -454,8 +403,7 @@ public class Script { inputs.clear(); inputParameters.clear(); inputVariables.clear(); - inputVariablesType.clear(); - inputMatrixMetadata.clear(); + inputMetadata.clear(); } /** @@ -556,11 +504,10 @@ public class Script { sb.append(" = " + quotedString + ";\n"); } else if (MLContextUtil.isBasicType(inValue)) { sb.append(" = read('', data_type='scalar');\n"); + } else if (MLContextUtil.doesSymbolTableContainFrameObject(symbolTable, in)) { + sb.append(" = read('', data_type='frame');\n"); } else { - if(inputVariablesType.get(in).booleanValue()) - sb.append(" = read('', data_type='frame');\n"); - else - sb.append(" = read('');\n"); + sb.append(" = read('');\n"); } } else if (isPYDML()) { if (inValue instanceof String) { @@ -568,11 +515,10 @@ public class Script { sb.append(" = " + quotedString + "\n"); } else if (MLContextUtil.isBasicType(inValue)) { sb.append(" = load('', data_type='scalar')\n"); + } else if (MLContextUtil.doesSymbolTableContainFrameObject(symbolTable, in)) { + sb.append(" = load('', data_type='frame')\n"); } else { - if(inputVariablesType.get(in).booleanValue()) - sb.append(" = load('', data_type='frame')\n"); - else - sb.append(" = load('')\n"); + sb.append(" = load('')\n"); } } @@ -603,7 +549,7 @@ public class Script { public String toString() { StringBuilder sb = new StringBuilder(); - sb.append(MLContextUtil.displayInputs("Inputs", inputs)); + sb.append(MLContextUtil.displayInputs("Inputs", inputs, symbolTable)); sb.append("\n"); sb.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable)); return sb.toString(); @@ -623,7 +569,7 @@ public class Script { sb.append("Script Type: "); sb.append(scriptType); sb.append("\n\n"); - sb.append(MLContextUtil.displayInputs("Inputs", inputs)); + sb.append(MLContextUtil.displayInputs("Inputs", inputs, symbolTable)); sb.append("\n"); sb.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable)); sb.append("\n"); @@ -649,7 +595,7 @@ public class Script { * @return the script inputs */ public String displayInputs() { - return MLContextUtil.displayInputs("Inputs", inputs); + return MLContextUtil.displayInputs("Inputs", inputs, symbolTable); } /** http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java index cf0d09f..2973ed2 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java @@ -348,14 +348,14 @@ public class ScriptExecutor { */ protected void restoreInputsInSymbolTable() { Map<String, Object> inputs = script.getInputs(); - Map<String, MatrixMetadata> inputMatrixMetadata = script.getInputMatrixMetadata(); + Map<String, Metadata> inputMetadata = script.getInputMetadata(); LocalVariableMap symbolTable = script.getSymbolTable(); Set<String> inputVariables = script.getInputVariables(); for (String inputVariable : inputVariables) { if (symbolTable.get(inputVariable) == null) { // retrieve optional metadata if it exists - MatrixMetadata mm = inputMatrixMetadata.get(inputVariable); - script.in(inputVariable, inputs.get(inputVariable), mm, script.getInputVariablesType().get(inputVariable)); + Metadata m = inputMetadata.get(inputVariable); + script.in(inputVariable, inputs.get(inputVariable), m); } } } @@ -451,8 +451,8 @@ public class ScriptExecutor { if (symbolTable != null) { String[] inputs = (script.getInputVariables() == null) ? new String[0] : script.getInputVariables() .toArray(new String[0]); - String[] outputs = (script.getOutputVariables() == null) ? new String[0] : script.getOutputVariables() - .toArray(new String[0]); + String[] outputs = (script.getOutputVariables() == null) ? new String[0] + : script.getOutputVariables().toArray(new String[0]); RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs); ProgramRewriter programRewriter = new ProgramRewriter(rewrite); try { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java new file mode 100644 index 0000000..98c8b10 --- /dev/null +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java @@ -0,0 +1,557 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.test.integration.mlcontext; + +import static org.apache.sysml.api.mlcontext.ScriptFactory.dml; +import static org.apache.sysml.api.mlcontext.ScriptFactory.pydml; + +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.sysml.api.mlcontext.FrameFormat; +import org.apache.sysml.api.mlcontext.FrameMetadata; +import org.apache.sysml.api.mlcontext.MLContext; +import org.apache.sysml.api.mlcontext.MLResults; +import org.apache.sysml.api.mlcontext.MatrixFormat; +import org.apache.sysml.api.mlcontext.MatrixMetadata; +import org.apache.sysml.api.mlcontext.Script; +import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; +import org.apache.sysml.test.integration.AutomatedTestBase; +import org.apache.sysml.test.integration.mlcontext.MLContextTest.CommaSeparatedValueStringToRow; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import scala.collection.Iterator; + +public class MLContextFrameTest extends AutomatedTestBase { + protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext"; + protected final static String TEST_NAME = "MLContextFrame"; + + public static enum SCRIPT_TYPE { + DML, PYDML + }; + + public static enum IO_TYPE { + ANY, FILE, JAVA_RDD_STR_CSV, JAVA_RDD_STR_IJV, RDD_STR_CSV, RDD_STR_IJV, DATAFRAME + }; + + private static SparkConf conf; + private static JavaSparkContext sc; + private static MLContext ml; + + @BeforeClass + public static void setUpClass() { + if (conf == null) + conf = new SparkConf().setAppName("MLContextFrameTest").setMaster("local"); + if (sc == null) + sc = new JavaSparkContext(conf); + ml = new MLContext(sc); + } + + @Override + public void setUp() { + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + @Test + public void testFrameJavaRDD_CSV_DML() { + testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY); + } + + @Test + public void testFrameJavaRDD_CSV_DML_OutJavaRddCSV() { + testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.JAVA_RDD_STR_CSV); + } + + @Test + public void testFrameJavaRDD_CSV_PYDML() { + testFrame(FrameFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY); + } + + @Test + public void testFrameRDD_CSV_PYDML() { + testFrame(FrameFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.RDD_STR_CSV, IO_TYPE.ANY); + } + + @Test + public void testFrameJavaRDD_CSV_PYDML_OutRddIJV() { + testFrame(FrameFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.RDD_STR_IJV); + } + + @Test + public void testFrameJavaRDD_IJV_DML() { + testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.ANY); + } + + @Test + public void testFrameRDD_IJV_DML() { + testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.RDD_STR_IJV, IO_TYPE.ANY); + } + + @Test + public void testFrameJavaRDD_IJV_DML_OutRddCSV() { + testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.RDD_STR_CSV); + } + + @Test + public void testFrameJavaRDD_IJV_PYDML() { + testFrame(FrameFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.ANY); + } + + @Test + public void testFrameJavaRDD_IJV_PYDML_OutJavaRddIJV() { + testFrame(FrameFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.JAVA_RDD_STR_IJV); + } + + @Test + public void testFrameFile_CSV_DML() { + testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.FILE, IO_TYPE.ANY); + } + + @Test + public void testFrameFile_CSV_PYDML() { + testFrame(FrameFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.FILE, IO_TYPE.ANY); + } + + @Test + public void testFrameFile_IJV_DML() { + testFrame(FrameFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.FILE, IO_TYPE.ANY); + } + + @Test + public void testFrameFile_IJV_PYDML() { + testFrame(FrameFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.FILE, IO_TYPE.ANY); + } + + @Test + public void testFrameDataFrame_CSV_DML() { + testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.DATAFRAME, IO_TYPE.ANY); + } + + @Test + public void testFrameDataFrame_CSV_PYDML() { + testFrame(FrameFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.DATAFRAME, IO_TYPE.ANY); + } + + @Test + public void testFrameDataFrameOutDataFrame_CSV_DML() { + testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.DATAFRAME, IO_TYPE.DATAFRAME); + } + + public void testFrame(FrameFormat format, SCRIPT_TYPE script_type, IO_TYPE inputType, IO_TYPE outputType) { + + System.out.println("MLContextTest - Frame JavaRDD<String> for format: " + format + " Script: " + script_type); + + List<String> listA = new ArrayList<String>(); + List<String> listB = new ArrayList<String>(); + FrameMetadata fmA = null, fmB = null; + Script script = null; + + if (inputType != IO_TYPE.FILE) { + if (format == FrameFormat.CSV) { + listA.add("1,Str2,3.0,true"); + listA.add("4,Str5,6.0,false"); + listA.add("7,Str8,9.0,true"); + + listB.add("Str12,13.0,true"); + listB.add("Str25,26.0,false"); + + fmA = new FrameMetadata(FrameFormat.CSV, 3, 4); + fmB = new FrameMetadata(FrameFormat.CSV, 2, 3); + } else if (format == FrameFormat.IJV) { + listA.add("1 1 1"); + listA.add("1 2 Str2"); + listA.add("1 3 3.0"); + listA.add("1 4 true"); + listA.add("2 1 4"); + listA.add("2 2 Str5"); + listA.add("2 3 6.0"); + listA.add("2 4 false"); + listA.add("3 1 7"); + listA.add("3 2 Str8"); + listA.add("3 3 9.0"); + listA.add("3 4 true"); + + listB.add("1 1 Str12"); + listB.add("1 2 13.0"); + listB.add("1 3 true"); + listB.add("2 1 Str25"); + listB.add("2 2 26.0"); + listB.add("2 3 false"); + + fmA = new FrameMetadata(FrameFormat.IJV, 3, 4); + fmB = new FrameMetadata(FrameFormat.IJV, 2, 3); + } + JavaRDD<String> javaRDDA = sc.parallelize(listA); + JavaRDD<String> javaRDDB = sc.parallelize(listB); + + if (inputType == IO_TYPE.DATAFRAME) { + JavaRDD<Row> javaRddRowA = javaRDDA.map(new MLContextTest.CommaSeparatedValueStringToRow()); + JavaRDD<Row> javaRddRowB = javaRDDB.map(new MLContextTest.CommaSeparatedValueStringToRow()); + + ValueType[] schemaA = { ValueType.INT, ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN }; + List<ValueType> lschemaA = Arrays.asList(schemaA); + ValueType[] schemaB = { ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN }; + List<ValueType> lschemaB = Arrays.asList(schemaB); + + // Create DataFrame + SQLContext sqlContext = new SQLContext(sc); + StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(lschemaA); + DataFrame dataFrameA = sqlContext.createDataFrame(javaRddRowA, dfSchemaA); + StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(lschemaB); + DataFrame dataFrameB = sqlContext.createDataFrame(javaRddRowB, dfSchemaB); + if (script_type == SCRIPT_TYPE.DML) + script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", dataFrameA, fmA).in("B", dataFrameB, fmB).out("A") + .out("C"); + else if (script_type == SCRIPT_TYPE.PYDML) + // DO NOT USE ; at the end of any statment, it throws NPE + script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", dataFrameA, fmA) + .in("B", dataFrameB, fmB) + // Value for ROW index gets incremented at script + // level to adjust index in PyDML, but not for + // Column Index + .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C"); + } else { + if (inputType == IO_TYPE.JAVA_RDD_STR_CSV || inputType == IO_TYPE.JAVA_RDD_STR_IJV) { + if (script_type == SCRIPT_TYPE.DML) + script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", javaRDDA, fmA).in("B", javaRDDB, fmB).out("A") + .out("C"); + else if (script_type == SCRIPT_TYPE.PYDML) + // DO NOT USE ; at the end of any statment, it throws + // NPE + script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", javaRDDA, fmA) + .in("B", javaRDDB, fmB) + // Value for ROW index gets incremented at + // script level to adjust index in PyDML, but + // not for Column Index + .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C"); + } else if (inputType == IO_TYPE.RDD_STR_CSV || inputType == IO_TYPE.RDD_STR_IJV) { + RDD<String> rddA = JavaRDD.toRDD(javaRDDA); + RDD<String> rddB = JavaRDD.toRDD(javaRDDB); + + if (script_type == SCRIPT_TYPE.DML) + script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", rddA, fmA).in("B", rddB, fmB).out("A") + .out("C"); + else if (script_type == SCRIPT_TYPE.PYDML) + // DO NOT USE ; at the end of any statment, it throws + // NPE + script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", rddA, fmA).in("B", rddB, fmB) + // Value for ROW index gets incremented at + // script level to adjust index in PyDML, but + // not for Column Index + .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C"); + } + + } + + } else { // Input type is file + String fileA = null, fileB = null; + if (format == FrameFormat.CSV) { + fileA = baseDirectory + File.separator + "FrameA.csv"; + fileB = baseDirectory + File.separator + "FrameB.csv"; + } else if (format == FrameFormat.IJV) { + fileA = baseDirectory + File.separator + "FrameA.ijv"; + fileB = baseDirectory + File.separator + "FrameB.ijv"; + } + + if (script_type == SCRIPT_TYPE.DML) + script = dml("A=read($A); B=read($B);A[2:3,2:4]=B;C=A[2:3,2:3]").in("$A", fileA, fmA) + .in("$B", fileB, fmB).out("A").out("C"); + else if (script_type == SCRIPT_TYPE.PYDML) + // DO NOT USE ; at the end of any statment, it throws NPE + script = pydml("A=load($A)\nB=load($B)\nA[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("$A", fileA) + .in("$B", fileB) + // Value for ROW index gets incremented at script level + // to adjust index in PyDML, but not for Column Index + .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C"); + } + + MLResults mlResults = ml.execute(script); + + if (outputType == IO_TYPE.JAVA_RDD_STR_CSV) { + + JavaRDD<String> javaRDDStringCSVA = mlResults.getJavaRDDStringCSV("A"); + List<String> linesA = javaRDDStringCSVA.collect(); + Assert.assertEquals("1,Str2,3.0,true", linesA.get(0)); + Assert.assertEquals("4,Str12,13.0,true", linesA.get(1)); + Assert.assertEquals("7,Str25,26.0,false", linesA.get(2)); + + JavaRDD<String> javaRDDStringCSVC = mlResults.getJavaRDDStringCSV("C"); + List<String> linesC = javaRDDStringCSVC.collect(); + Assert.assertEquals("Str12,13.0", linesC.get(0)); + Assert.assertEquals("Str25,26.0", linesC.get(1)); + } else if (outputType == IO_TYPE.JAVA_RDD_STR_IJV) { + JavaRDD<String> javaRDDStringIJVA = mlResults.getJavaRDDStringIJV("A"); + List<String> linesA = javaRDDStringIJVA.collect(); + Assert.assertEquals("1 1 1", linesA.get(0)); + Assert.assertEquals("1 2 Str2", linesA.get(1)); + Assert.assertEquals("1 3 3.0", linesA.get(2)); + Assert.assertEquals("1 4 true", linesA.get(3)); + Assert.assertEquals("2 1 4", linesA.get(4)); + Assert.assertEquals("2 2 Str12", linesA.get(5)); + Assert.assertEquals("2 3 13.0", linesA.get(6)); + Assert.assertEquals("2 4 true", linesA.get(7)); + + JavaRDD<String> javaRDDStringIJVC = mlResults.getJavaRDDStringIJV("C"); + List<String> linesC = javaRDDStringIJVC.collect(); + Assert.assertEquals("1 1 Str12", linesC.get(0)); + Assert.assertEquals("1 2 13.0", linesC.get(1)); + Assert.assertEquals("2 1 Str25", linesC.get(2)); + Assert.assertEquals("2 2 26.0", linesC.get(3)); + } else if (outputType == IO_TYPE.RDD_STR_CSV) { + RDD<String> rddStringCSVA = mlResults.getRDDStringCSV("A"); + Iterator<String> iteratorA = rddStringCSVA.toLocalIterator(); + Assert.assertEquals("1,Str2,3.0,true", iteratorA.next()); + Assert.assertEquals("4,Str12,13.0,true", iteratorA.next()); + Assert.assertEquals("7,Str25,26.0,false", iteratorA.next()); + + RDD<String> rddStringCSVC = mlResults.getRDDStringCSV("C"); + Iterator<String> iteratorC = rddStringCSVC.toLocalIterator(); + Assert.assertEquals("Str12,13.0", iteratorC.next()); + Assert.assertEquals("Str25,26.0", iteratorC.next()); + } else if (outputType == IO_TYPE.RDD_STR_IJV) { + RDD<String> rddStringIJVA = mlResults.getRDDStringIJV("A"); + Iterator<String> iteratorA = rddStringIJVA.toLocalIterator(); + Assert.assertEquals("1 1 1", iteratorA.next()); + Assert.assertEquals("1 2 Str2", iteratorA.next()); + Assert.assertEquals("1 3 3.0", iteratorA.next()); + Assert.assertEquals("1 4 true", iteratorA.next()); + Assert.assertEquals("2 1 4", iteratorA.next()); + Assert.assertEquals("2 2 Str12", iteratorA.next()); + Assert.assertEquals("2 3 13.0", iteratorA.next()); + Assert.assertEquals("2 4 true", iteratorA.next()); + Assert.assertEquals("3 1 7", iteratorA.next()); + Assert.assertEquals("3 2 Str25", iteratorA.next()); + Assert.assertEquals("3 3 26.0", iteratorA.next()); + Assert.assertEquals("3 4 false", iteratorA.next()); + + RDD<String> rddStringIJVC = mlResults.getRDDStringIJV("C"); + Iterator<String> iteratorC = rddStringIJVC.toLocalIterator(); + Assert.assertEquals("1 1 Str12", iteratorC.next()); + Assert.assertEquals("1 2 13.0", iteratorC.next()); + Assert.assertEquals("2 1 Str25", iteratorC.next()); + Assert.assertEquals("2 2 26.0", iteratorC.next()); + + } else if (outputType == IO_TYPE.DATAFRAME) { + + DataFrame dataFrameA = mlResults.getDataFrame("A"); + List<Row> listAOut = dataFrameA.collectAsList(); + + Row row1 = listAOut.get(0); + Assert.assertEquals("Mistmatch with expected value", "1", row1.getString(0)); + Assert.assertEquals("Mistmatch with expected value", "Str2", row1.getString(1)); + Assert.assertEquals("Mistmatch with expected value", "3.0", row1.getString(2)); + Assert.assertEquals("Mistmatch with expected value", "true", row1.getString(3)); + + Row row2 = listAOut.get(1); + Assert.assertEquals("Mistmatch with expected value", "4", row2.getString(0)); + Assert.assertEquals("Mistmatch with expected value", "Str12", row2.getString(1)); + Assert.assertEquals("Mistmatch with expected value", "13.0", row2.getString(2)); + Assert.assertEquals("Mistmatch with expected value", "true", row2.getString(3)); + + DataFrame dataFrameC = mlResults.getDataFrame("C"); + List<Row> listCOut = dataFrameC.collectAsList(); + + Row row3 = listCOut.get(0); + Assert.assertEquals("Mistmatch with expected value", "Str12", row3.getString(0)); + Assert.assertEquals("Mistmatch with expected value", "13.0", row3.getString(1)); + + Row row4 = listCOut.get(1); + Assert.assertEquals("Mistmatch with expected value", "Str25", row4.getString(0)); + Assert.assertEquals("Mistmatch with expected value", "26.0", row4.getString(1)); + } else { + String[][] frameA = mlResults.getFrameAs2DStringArray("A"); + Assert.assertEquals("Str2", frameA[0][1]); + Assert.assertEquals("3.0", frameA[0][2]); + Assert.assertEquals("13.0", frameA[1][2]); + Assert.assertEquals("true", frameA[1][3]); + Assert.assertEquals("Str25", frameA[2][1]); + + String[][] frameC = mlResults.getFrameAs2DStringArray("C"); + Assert.assertEquals("Str12", frameC[0][0]); + Assert.assertEquals("Str25", frameC[1][0]); + Assert.assertEquals("13.0", frameC[0][1]); + Assert.assertEquals("26.0", frameC[1][1]); + } + } + + @Test + public void testOutputFrameDML() { + System.out.println("MLContextFrameTest - output frame DML"); + + String s = "M = read($Min, data_type='frame', format='csv');"; + String csvFile = baseDirectory + File.separator + "one-two-three-four.csv"; + Script script = dml(s).in("$Min", csvFile).out("M"); + String[][] frame = ml.execute(script).getFrameAs2DStringArray("M"); + Assert.assertEquals("one", frame[0][0]); + Assert.assertEquals("two", frame[0][1]); + Assert.assertEquals("three", frame[1][0]); + Assert.assertEquals("four", frame[1][1]); + } + + @Test + public void testOutputFramePYDML() { + System.out.println("MLContextFrameTest - output frame PYDML"); + + String s = "M = load($Min, data_type='frame', format='csv')"; + String csvFile = baseDirectory + File.separator + "one-two-three-four.csv"; + Script script = pydml(s).in("$Min", csvFile).out("M"); + String[][] frame = ml.execute(script).getFrameAs2DStringArray("M"); + Assert.assertEquals("one", frame[0][0]); + Assert.assertEquals("two", frame[0][1]); + Assert.assertEquals("three", frame[1][0]); + Assert.assertEquals("four", frame[1][1]); + } + + @Test + public void testInputFrameAndMatrixOutputMatrix() { + System.out.println("MLContextFrameTest - input frame and matrix, output matrix"); + + List<String> dataA = new ArrayList<String>(); + dataA.add("Test1,4.0"); + dataA.add("Test2,5.0"); + dataA.add("Test3,6.0"); + JavaRDD<String> javaRddStringA = sc.parallelize(dataA); + + List<String> dataB = new ArrayList<String>(); + dataB.add("1.0"); + dataB.add("2.0"); + JavaRDD<String> javaRddStringB = sc.parallelize(dataB); + + JavaRDD<Row> javaRddRowA = javaRddStringA.map(new CommaSeparatedValueStringToRow()); + JavaRDD<Row> javaRddRowB = javaRddStringB.map(new CommaSeparatedValueStringToRow()); + + SQLContext sqlContext = new SQLContext(sc); + + List<StructField> fieldsA = new ArrayList<StructField>(); + fieldsA.add(DataTypes.createStructField("1", DataTypes.StringType, true)); + fieldsA.add(DataTypes.createStructField("2", DataTypes.DoubleType, true)); + StructType schemaA = DataTypes.createStructType(fieldsA); + DataFrame dataFrameA = sqlContext.createDataFrame(javaRddRowA, schemaA); + + List<StructField> fieldsB = new ArrayList<StructField>(); + fieldsB.add(DataTypes.createStructField("1", DataTypes.DoubleType, true)); + StructType schemaB = DataTypes.createStructType(fieldsB); + DataFrame dataFrameB = sqlContext.createDataFrame(javaRddRowB, schemaB); + + String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: true ,recode: [ 1, 2 ]}\");\n" + + "C = tA %*% B;\n" + "M = s * C;"; + + Script script = dml(dmlString) + .in("A", dataFrameA, + new FrameMetadata(FrameFormat.CSV, dataFrameA.count(), (long) dataFrameA.columns().length)) + .in("B", dataFrameB, + new MatrixMetadata(MatrixFormat.CSV, dataFrameB.count(), (long) dataFrameB.columns().length)) + .in("s", 2).out("M"); + MLResults results = ml.execute(script); + double[][] matrix = results.getMatrixAs2DDoubleArray("M"); + Assert.assertEquals(6.0, matrix[0][0], 0.0); + Assert.assertEquals(12.0, matrix[1][0], 0.0); + Assert.assertEquals(18.0, matrix[2][0], 0.0); + } + + // NOTE: the ordering of the frame values seem to come out differently here + // than in the scala shell, + // so this should be investigated or explained. + // @Test + // public void testInputFrameOutputMatrixAndFrame() { + // System.out.println("MLContextFrameTest - input frame, output matrix and + // frame"); + // + // List<String> dataA = new ArrayList<String>(); + // dataA.add("Test1,Test4"); + // dataA.add("Test2,Test5"); + // dataA.add("Test3,Test6"); + // JavaRDD<String> javaRddStringA = sc.parallelize(dataA); + // + // JavaRDD<Row> javaRddRowA = javaRddStringA.map(new + // CommaSeparatedValueStringToRow()); + // + // SQLContext sqlContext = new SQLContext(sc); + // + // List<StructField> fieldsA = new ArrayList<StructField>(); + // fieldsA.add(DataTypes.createStructField("1", DataTypes.StringType, + // true)); + // fieldsA.add(DataTypes.createStructField("2", DataTypes.StringType, + // true)); + // StructType schemaA = DataTypes.createStructType(fieldsA); + // DataFrame dataFrameA = sqlContext.createDataFrame(javaRddRowA, schemaA); + // + // String dmlString = "[tA, tAM] = transformencode (target = A, spec = + // \"{ids: true ,recode: [ 1, 2 ]}\");\n"; + // + // Script script = dml(dmlString) + // .in("A", dataFrameA, + // new FrameMetadata(FrameFormat.CSV, dataFrameA.count(), (long) + // dataFrameA.columns().length)) + // .out("tA", "tAM"); + // MLResults results = ml.execute(script); + // double[][] matrix = results.getMatrixAs2DDoubleArray("tA"); + // Assert.assertEquals(1.0, matrix[0][0], 0.0); + // Assert.assertEquals(1.0, matrix[0][1], 0.0); + // Assert.assertEquals(2.0, matrix[1][0], 0.0); + // Assert.assertEquals(2.0, matrix[1][1], 0.0); + // Assert.assertEquals(3.0, matrix[2][0], 0.0); + // Assert.assertEquals(3.0, matrix[2][1], 0.0); + // + // TODO: Add asserts for frame if ordering is as expected + // String[][] frame = results.getFrameAs2DStringArray("tAM"); + // for (int i = 0; i < frame.length; i++) { + // for (int j = 0; j < frame[i].length; j++) { + // System.out.println("[" + i + "][" + j + "]:" + frame[i][j]); + // } + // } + // } + + @After + public void tearDown() { + super.tearDown(); + } + + @AfterClass + public static void tearDownClass() { + // stop spark context to allow single jvm tests (otherwise the + // next test that tries to create a SparkContext would fail) + sc.stop(); + sc = null; + conf = null; + + // clear status mlcontext and spark exec context + ml.close(); + ml = null; + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java index fd220d9..0252b50 100644 --- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java +++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java @@ -38,7 +38,6 @@ import java.io.InputStream; import java.net.MalformedURLException; import java.net.URL; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -67,9 +66,7 @@ import org.apache.sysml.api.mlcontext.MatrixFormat; import org.apache.sysml.api.mlcontext.MatrixMetadata; import org.apache.sysml.api.mlcontext.Script; import org.apache.sysml.api.mlcontext.ScriptExecutor; -import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils; import org.apache.sysml.test.integration.AutomatedTestBase; import org.junit.After; import org.junit.AfterClass; @@ -87,9 +84,6 @@ public class MLContextTest extends AutomatedTestBase { protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext"; protected final static String TEST_NAME = "MLContext"; - public static enum SCRIPT_TYPE {DML, PYDML, SCALA}; - public static enum IO_TYPE {ANY, FILE, JAVA_RDD_STR_CSV, JAVA_RDD_STR_IJV, RDD_STR_CSV, RDD_STR_IJV, DATAFRAME}; - private static SparkConf conf; private static JavaSparkContext sc; private static MLContext ml; @@ -927,7 +921,7 @@ public class MLContextTest extends AutomatedTestBase { public void testOutputDoubleArrayMatrixDML() { System.out.println("MLContextTest - output double array matrix DML"); String s = "M = matrix('1 2 3 4', rows=2, cols=2);"; - double[][] matrix = ml.execute(dml(s).out("M")).getDoubleMatrix("M"); + double[][] matrix = ml.execute(dml(s).out("M")).getMatrixAs2DDoubleArray("M"); Assert.assertEquals(1.0, matrix[0][0], 0); Assert.assertEquals(2.0, matrix[0][1], 0); Assert.assertEquals(3.0, matrix[1][0], 0); @@ -938,7 +932,7 @@ public class MLContextTest extends AutomatedTestBase { public void testOutputDoubleArrayMatrixPYDML() { System.out.println("MLContextTest - output double array matrix PYDML"); String s = "M = full('1 2 3 4', rows=2, cols=2)"; - double[][] matrix = ml.execute(pydml(s).out("M")).getDoubleMatrix("M"); + double[][] matrix = ml.execute(pydml(s).out("M")).getMatrixAs2DDoubleArray("M"); Assert.assertEquals(1.0, matrix[0][0], 0); Assert.assertEquals(2.0, matrix[0][1], 0); Assert.assertEquals(3.0, matrix[1][0], 0); @@ -1032,34 +1026,6 @@ public class MLContextTest extends AutomatedTestBase { } @Test - public void testOutputFrameDML() { - System.out.println("MLContextTest - output frame DML"); - - String s = "M = read($Min, data_type='frame', format='csv');"; - String csvFile = baseDirectory + File.separator + "one-two-three-four.csv"; - Script script = dml(s).in("$Min", csvFile).out("M"); - String[][] frame = ml.execute(script).getFrame("M"); - Assert.assertEquals("one", frame[0][0]); - Assert.assertEquals("two", frame[0][1]); - Assert.assertEquals("three", frame[1][0]); - Assert.assertEquals("four", frame[1][1]); - } - - @Test - public void testOutputFramePYDML() { - System.out.println("MLContextTest - output frame PYDML"); - - String s = "M = load($Min, data_type='frame', format='csv')"; - String csvFile = baseDirectory + File.separator + "one-two-three-four.csv"; - Script script = pydml(s).in("$Min", csvFile).out("M"); - String[][] frame = ml.execute(script).getFrame("M"); - Assert.assertEquals("one", frame[0][0]); - Assert.assertEquals("two", frame[0][1]); - Assert.assertEquals("three", frame[1][0]); - Assert.assertEquals("four", frame[1][1]); - } - - @Test public void testOutputJavaRDDStringIJVDML() { System.out.println("MLContextTest - output Java RDD String IJV DML"); @@ -1518,7 +1484,7 @@ public class MLContextTest extends AutomatedTestBase { String s = "M = matrix('1 2 3 4', rows=2, cols=2); N = sum(M)"; // alternative to .out("M").out("N") MLResults results = ml.execute(dml(s).out("M", "N")); - double[][] matrix = results.getDoubleMatrix("M"); + double[][] matrix = results.getMatrixAs2DDoubleArray("M"); double sum = results.getDouble("N"); Assert.assertEquals(1.0, matrix[0][0], 0); Assert.assertEquals(2.0, matrix[0][1], 0); @@ -1534,7 +1500,7 @@ public class MLContextTest extends AutomatedTestBase { String s = "M = full('1 2 3 4', rows=2, cols=2)\nN = sum(M)"; // alternative to .out("M").out("N") MLResults results = ml.execute(pydml(s).out("M", "N")); - double[][] matrix = results.getDoubleMatrix("M"); + double[][] matrix = results.getMatrixAs2DDoubleArray("M"); double sum = results.getDouble("N"); Assert.assertEquals(1.0, matrix[0][0], 0); Assert.assertEquals(2.0, matrix[0][1], 0); @@ -2262,6 +2228,7 @@ public class MLContextTest extends AutomatedTestBase { setExpectedStdOut("sum: 45.0"); ml.execute(script); } + // NOTE: Uncomment these tests once they work // @SuppressWarnings({ "rawtypes", "unchecked" }) @@ -2330,322 +2297,6 @@ public class MLContextTest extends AutomatedTestBase { // ml.execute(script); // } - //////////////////////////////////////////// - // SystemML Frame MLContext testset Begin - //////////////////////////////////////////// - @Test - public void testFrameJavaRDD_CSV_DML() { - testFrame(MatrixFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY); - } - - @Test - public void testFrameJavaRDD_CSV_DML_OutJavaRddCSV() { - testFrame(MatrixFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.JAVA_RDD_STR_CSV); - } - - @Test - public void testFrameJavaRDD_CSV_PYDML() { - testFrame(MatrixFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY); - } - - @Test - public void testFrameRDD_CSV_PYDML() { - testFrame(MatrixFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.RDD_STR_CSV, IO_TYPE.ANY); - } - - @Test - public void testFrameJavaRDD_CSV_PYDML_OutRddIJV() { - testFrame(MatrixFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.RDD_STR_IJV); - } - - @Test - public void testFrameJavaRDD_IJV_DML() { - testFrame(MatrixFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.ANY); - } - - @Test - public void testFrameRDD_IJV_DML() { - testFrame(MatrixFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.RDD_STR_IJV, IO_TYPE.ANY); - } - - @Test - public void testFrameJavaRDD_IJV_DML_OutRddCSV() { - testFrame(MatrixFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.RDD_STR_CSV); - } - - @Test - public void testFrameJavaRDD_IJV_PYDML() { - testFrame(MatrixFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.ANY); - } - - @Test - public void testFrameJavaRDD_IJV_PYDML_OutJavaRddIJV() { - testFrame(MatrixFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.JAVA_RDD_STR_IJV, IO_TYPE.JAVA_RDD_STR_IJV); - } - - @Test - public void testFrameFile_CSV_DML() { - testFrame(MatrixFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.FILE, IO_TYPE.ANY); - } - - @Test - public void testFrameFile_CSV_PYDML() { - testFrame(MatrixFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.FILE, IO_TYPE.ANY); - } - - @Test - public void testFrameFile_IJV_DML() { - testFrame(MatrixFormat.IJV, SCRIPT_TYPE.DML, IO_TYPE.FILE, IO_TYPE.ANY); - } - - @Test - public void testFrameFile_IJV_PYDML() { - testFrame(MatrixFormat.IJV, SCRIPT_TYPE.PYDML, IO_TYPE.FILE, IO_TYPE.ANY); - } - - @Test - public void testFrameDataFrame_CSV_DML() { - testFrame(MatrixFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.DATAFRAME, IO_TYPE.ANY); - } - - @Test - public void testFrameDataFrame_CSV_PYDML() { - testFrame(MatrixFormat.CSV, SCRIPT_TYPE.PYDML, IO_TYPE.DATAFRAME, IO_TYPE.ANY); - } - - @Test - public void testFrameDataFrameOutDataFrame_CSV_DML() { - testFrame(MatrixFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.DATAFRAME, IO_TYPE.DATAFRAME); - } - - - - - public void testFrame(MatrixFormat format, SCRIPT_TYPE script_type, IO_TYPE inputType, IO_TYPE outputType) { - - System.out.println("MLContextTest - Frame JavaRDD<String> for format: " + format + " Script: " + script_type); - - List<String> listA = new ArrayList<String>(); - List<String> listB = new ArrayList<String>(); - MatrixMetadata mmA = null, mmB = null; - Script script = null; - - - if(inputType != IO_TYPE.FILE) { - if(format == MatrixFormat.CSV) { - listA.add("1,Str2,3.0,true"); - listA.add("4,Str5,6.0,false"); - listA.add("7,Str8,9.0,true"); - - listB.add("Str12,13.0,true"); - listB.add("Str25,26.0,false"); - - mmA = new MatrixMetadata(MatrixFormat.CSV, 3, 4); - mmB = new MatrixMetadata(MatrixFormat.CSV, 2, 3); - } else if(format == MatrixFormat.IJV) { - listA.add("1 1 1"); - listA.add("1 2 Str2"); - listA.add("1 3 3.0"); - listA.add("1 4 true"); - listA.add("2 1 4"); - listA.add("2 2 Str5"); - listA.add("2 3 6.0"); - listA.add("2 4 false"); - listA.add("3 1 7"); - listA.add("3 2 Str8"); - listA.add("3 3 9.0"); - listA.add("3 4 true"); - - listB.add("1 1 Str12"); - listB.add("1 2 13.0"); - listB.add("1 3 true"); - listB.add("2 1 Str25"); - listB.add("2 2 26.0"); - listB.add("2 3 false"); - - mmA = new MatrixMetadata(MatrixFormat.IJV, 3, 4); - mmB = new MatrixMetadata(MatrixFormat.IJV, 2, 3); - } - JavaRDD<String> javaRDDA = sc.parallelize(listA); - JavaRDD<String> javaRDDB = sc.parallelize(listB); - - if(inputType == IO_TYPE.DATAFRAME) { - JavaRDD<Row> javaRddRowA = javaRDDA.map(new CommaSeparatedValueStringToRow()); - JavaRDD<Row> javaRddRowB = javaRDDB.map(new CommaSeparatedValueStringToRow()); - - ValueType[] schemaA = {ValueType.INT, ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN}; - List<ValueType> lschemaA = Arrays.asList(schemaA); - ValueType[] schemaB = {ValueType.STRING, ValueType.DOUBLE, ValueType.BOOLEAN}; - List<ValueType> lschemaB = Arrays.asList(schemaB); - - //Create DataFrame - SQLContext sqlContext = new SQLContext(sc); - StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(lschemaA); - DataFrame dataFrameA = sqlContext.createDataFrame(javaRddRowA, dfSchemaA); - StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(lschemaB); - DataFrame dataFrameB = sqlContext.createDataFrame(javaRddRowB, dfSchemaB); - if (script_type == SCRIPT_TYPE.DML) - script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", dataFrameA, mmA, true).in("B", dataFrameB, mmB, true).out("A").out("C"); - else if (script_type == SCRIPT_TYPE.PYDML) - // DO NOT USE ; at the end of any statment, it throws NPE - script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", dataFrameA, mmA, true).in("B", dataFrameB, mmB, true) - // Value for ROW index gets incremented at script level to adjust index in PyDML, but not for Column Index - .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C"); - } else { - if(inputType == IO_TYPE.JAVA_RDD_STR_CSV || inputType == IO_TYPE.JAVA_RDD_STR_IJV) { - if (script_type == SCRIPT_TYPE.DML) - script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", javaRDDA, mmA, true).in("B", javaRDDB, mmB, true).out("A").out("C"); - else if (script_type == SCRIPT_TYPE.PYDML) - // DO NOT USE ; at the end of any statment, it throws NPE - script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", javaRDDA, mmA, true).in("B", javaRDDB, mmB, true) - // Value for ROW index gets incremented at script level to adjust index in PyDML, but not for Column Index - .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C"); - } else if(inputType == IO_TYPE.RDD_STR_CSV || inputType == IO_TYPE.RDD_STR_IJV) { - RDD<String> rddA = JavaRDD.toRDD(javaRDDA); - RDD<String> rddB = JavaRDD.toRDD(javaRDDB); - - if (script_type == SCRIPT_TYPE.DML) - script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", rddA, mmA, true).in("B", rddB, mmB, true).out("A").out("C"); - else if (script_type == SCRIPT_TYPE.PYDML) - // DO NOT USE ; at the end of any statment, it throws NPE - script = pydml("A[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("A", rddA, mmA, true).in("B", rddB, mmB, true) - // Value for ROW index gets incremented at script level to adjust index in PyDML, but not for Column Index - .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C"); - } - - } - - } else { // Input type is file - String fileA = null, fileB = null; - if(format == MatrixFormat.CSV) { - fileA = baseDirectory + File.separator + "FrameA.csv"; - fileB = baseDirectory + File.separator + "FrameB.csv"; - } else if(format == MatrixFormat.IJV) { - fileA = baseDirectory + File.separator + "FrameA.ijv"; - fileB = baseDirectory + File.separator + "FrameB.ijv"; - } - - if (script_type == SCRIPT_TYPE.DML) - script = dml("A=read($A); B=read($B);A[2:3,2:4]=B;C=A[2:3,2:3]").in("$A", fileA, mmA, true).in("$B", fileB, mmB, true).out("A").out("C"); - else if (script_type == SCRIPT_TYPE.PYDML) - // DO NOT USE ; at the end of any statment, it throws NPE - script = pydml("A=load($A)\nB=load($B)\nA[$X:$Y,$X:$Z]=B\nC=A[$X:$Y,$X:$Y]").in("$A", fileA).in("$B", fileB) - // Value for ROW index gets incremented at script level to adjust index in PyDML, but not for Column Index - .in("$X", 1).in("$Y", 3).in("$Z", 4).out("A").out("C"); - } - - MLResults mlResults = ml.execute(script); - - if(outputType == IO_TYPE.JAVA_RDD_STR_CSV) { - - JavaRDD<String> javaRDDStringCSVA = mlResults.getFrameJavaRDDStringCSV("A", ","); - List<String> linesA = javaRDDStringCSVA.collect(); - Assert.assertEquals("1,Str2,3.0,true", linesA.get(0)); - Assert.assertEquals("4,Str12,13.0,true", linesA.get(1)); - Assert.assertEquals("7,Str25,26.0,false", linesA.get(2)); - - JavaRDD<String> javaRDDStringCSVC = mlResults.getFrameJavaRDDStringCSV("C", ","); - List<String> linesC = javaRDDStringCSVC.collect(); - Assert.assertEquals("Str12,13.0", linesC.get(0)); - Assert.assertEquals("Str25,26.0", linesC.get(1)); - } else if(outputType == IO_TYPE.JAVA_RDD_STR_IJV) { - JavaRDD<String> javaRDDStringIJVA = mlResults.getFrameJavaRDDStringIJV("A"); - List<String> linesA = javaRDDStringIJVA.collect(); - Assert.assertEquals("1 1 1", linesA.get(0)); - Assert.assertEquals("1 2 Str2", linesA.get(1)); - Assert.assertEquals("1 3 3.0", linesA.get(2)); - Assert.assertEquals("1 4 true", linesA.get(3)); - Assert.assertEquals("2 1 4", linesA.get(4)); - Assert.assertEquals("2 2 Str12", linesA.get(5)); - Assert.assertEquals("2 3 13.0", linesA.get(6)); - Assert.assertEquals("2 4 true", linesA.get(7)); - - JavaRDD<String> javaRDDStringIJVC = mlResults.getFrameJavaRDDStringIJV("C"); - List<String> linesC = javaRDDStringIJVC.collect(); - Assert.assertEquals("1 1 Str12", linesC.get(0)); - Assert.assertEquals("1 2 13.0", linesC.get(1)); - Assert.assertEquals("2 1 Str25", linesC.get(2)); - Assert.assertEquals("2 2 26.0", linesC.get(3)); - } else if(outputType == IO_TYPE.RDD_STR_CSV) { - RDD<String> rddStringCSVA = mlResults.getFrameRDDStringCSV("A", ","); //TODO fix delimiter - Iterator<String> iteratorA = rddStringCSVA.toLocalIterator(); - Assert.assertEquals("1,Str2,3.0,true", iteratorA.next()); - Assert.assertEquals("4,Str12,13.0,true", iteratorA.next()); - Assert.assertEquals("7,Str25,26.0,false", iteratorA.next()); - - RDD<String> rddStringCSVC = mlResults.getFrameRDDStringCSV("C", ","); //TODO fix delimiter - Iterator<String> iteratorC = rddStringCSVC.toLocalIterator(); - Assert.assertEquals("Str12,13.0", iteratorC.next()); - Assert.assertEquals("Str25,26.0", iteratorC.next()); - } else if(outputType == IO_TYPE.RDD_STR_IJV) { - RDD<String> rddStringIJVA = mlResults.getFrameRDDStringIJV("A"); - Iterator<String> iteratorA = rddStringIJVA.toLocalIterator(); - Assert.assertEquals("1 1 1", iteratorA.next()); - Assert.assertEquals("1 2 Str2", iteratorA.next()); - Assert.assertEquals("1 3 3.0", iteratorA.next()); - Assert.assertEquals("1 4 true", iteratorA.next()); - Assert.assertEquals("2 1 4", iteratorA.next()); - Assert.assertEquals("2 2 Str12", iteratorA.next()); - Assert.assertEquals("2 3 13.0", iteratorA.next()); - Assert.assertEquals("2 4 true", iteratorA.next()); - Assert.assertEquals("3 1 7", iteratorA.next()); - Assert.assertEquals("3 2 Str25", iteratorA.next()); - Assert.assertEquals("3 3 26.0", iteratorA.next()); - Assert.assertEquals("3 4 false", iteratorA.next()); - - RDD<String> rddStringIJVC = mlResults.getFrameRDDStringIJV("C"); - Iterator<String> iteratorC = rddStringIJVC.toLocalIterator(); - Assert.assertEquals("1 1 Str12", iteratorC.next()); - Assert.assertEquals("1 2 13.0", iteratorC.next()); - Assert.assertEquals("2 1 Str25", iteratorC.next()); - Assert.assertEquals("2 2 26.0", iteratorC.next()); - - } else if(outputType == IO_TYPE.DATAFRAME) { - - DataFrame dataFrameA = mlResults.getFrameDataFrame("A"); - List<Row> listAOut = dataFrameA.collectAsList(); - - Row row1 = listAOut.get(0); - Assert.assertEquals("Mistmatch with expected value", "1", row1.getString(0)); - Assert.assertEquals("Mistmatch with expected value", "Str2", row1.getString(1)); - Assert.assertEquals("Mistmatch with expected value", "3.0", row1.getString(2)); - Assert.assertEquals("Mistmatch with expected value", "true", row1.getString(3)); - - Row row2 = listAOut.get(1); - Assert.assertEquals("Mistmatch with expected value", "4", row2.getString(0)); - Assert.assertEquals("Mistmatch with expected value", "Str12", row2.getString(1)); - Assert.assertEquals("Mistmatch with expected value", "13.0", row2.getString(2)); - Assert.assertEquals("Mistmatch with expected value", "true", row2.getString(3)); - - DataFrame dataFrameC = mlResults.getFrameDataFrame("C"); - List<Row> listCOut = dataFrameC.collectAsList(); - - Row row3 = listCOut.get(0); - Assert.assertEquals("Mistmatch with expected value", "Str12", row3.getString(0)); - Assert.assertEquals("Mistmatch with expected value", "13.0", row3.getString(1)); - - Row row4 = listCOut.get(1); - Assert.assertEquals("Mistmatch with expected value", "Str25", row4.getString(0)); - Assert.assertEquals("Mistmatch with expected value", "26.0", row4.getString(1)); - } else { - String[][] frameA = mlResults.getFrame("A"); - Assert.assertEquals("Str2", frameA[0][1]); - Assert.assertEquals("3.0", frameA[0][2]); - Assert.assertEquals("13.0", frameA[1][2]); - Assert.assertEquals("true", frameA[1][3]); - Assert.assertEquals("Str25", frameA[2][1]); - - String[][] frameC = mlResults.getFrame("C"); - Assert.assertEquals("Str12", frameC[0][0]); - Assert.assertEquals("Str25", frameC[1][0]); - Assert.assertEquals("13.0", frameC[0][1]); - Assert.assertEquals("26.0", frameC[1][1]); - } - } - //////////////////////////////////////////// - // SystemML Frame MLContext testset End - //////////////////////////////////////////// - @After public void tearDown() { super.tearDown(); http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d39865e9/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java ---------------------------------------------------------------------- diff --git a/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java b/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java index 5687a55..387579f 100644 --- a/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java +++ b/src/test_suites/java/org/apache/sysml/test/integration/mlcontext/ZPackageSuite.java @@ -27,7 +27,8 @@ import org.junit.runners.Suite; * they should not be run in parallel. */ @RunWith(Suite.class) @Suite.SuiteClasses({ - org.apache.sysml.test.integration.mlcontext.MLContextTest.class + org.apache.sysml.test.integration.mlcontext.MLContextTest.class, + org.apache.sysml.test.integration.mlcontext.MLContextFrameTest.class })
