Repository: incubator-systemml Updated Branches: refs/heads/master 5ac32d6be -> 6df0d2348
[SYSTEMML-860] SparkR/HydraR integration with SystemML Closes #212. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/6df0d234 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/6df0d234 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/6df0d234 Branch: refs/heads/master Commit: 6df0d2348e77d583ef02974e5a1f1120a959270a Parents: 5ac32d6 Author: Alok Singh <[email protected]> Authored: Mon Aug 15 14:49:44 2016 -0700 Committer: Deron Eriksson <[email protected]> Committed: Mon Aug 15 14:49:44 2016 -0700 ---------------------------------------------------------------------- .../java/org/apache/sysml/api/MLContext.java | 76 +++++++++++++++++++- .../spark/utils/RDDConverterUtilsExt.java | 67 ++++++++++++++++- 2 files changed, 141 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6df0d234/src/main/java/org/apache/sysml/api/MLContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java index d8a290d..405478f 100644 --- a/src/main/java/org/apache/sysml/api/MLContext.java +++ b/src/main/java/org/apache/sysml/api/MLContext.java @@ -837,7 +837,52 @@ public class MLContext { argsArr = args.toArray(argsArr); return execute(dmlScriptFilePath, argsArr, parsePyDML, configFilePath); } - + + /* + @NOTE: when called from SparkR, passing a Map from R to Java does not work, + so the named arguments are passed in as two lists: one of keys + and one of values, matched by position + */ + /** + * Execute DML script by passing named arguments (given as parallel name/value lists) using specified config file + * @param dmlScriptFilePath path to the DML script file + * @param argsName argument names; must have the same size as argsValues (a repeated name keeps its last value) + * @param argsValues argument values; argsValues.get(i) is bound to argsName.get(i) + * @param configFilePath path to the config file, forwarded to the map-based execute overload + * @throws IOException + * @throws DMLException if argsName and argsValues have different sizes + * @throws 
ParseException + */ + public MLOutput execute(String dmlScriptFilePath, ArrayList<String> argsName, + ArrayList<String> argsValues, String configFilePath) + throws IOException, DMLException, ParseException { + HashMap<String, String> newNamedArgs = new HashMap<String, String>(); + if (argsName.size() != argsValues.size()) { + throw new DMLException("size of argsName " + argsName.size() + + " is different than size of argsValues " + argsValues.size()); + } + for (int i = 0; i < argsName.size(); i++) { + String k = argsName.get(i); + String v = argsValues.get(i); + newNamedArgs.put(k, v); + } + return execute(dmlScriptFilePath, newNamedArgs, configFilePath); + } + /** + * Execute DML script by passing named arguments (given as parallel name/value lists) without a config file (configFilePath = null) + * @param dmlScriptFilePath path to the DML script file + * @param argsName argument names; must have the same size as argsValues (a repeated name keeps its last value) + * @param argsValues argument values; argsValues.get(i) is bound to argsName.get(i) + * @throws IOException + * @throws DMLException if argsName and argsValues have different sizes + * @throws ParseException + */ + public MLOutput execute(String dmlScriptFilePath, ArrayList<String> argsName, + ArrayList<String> argsValues) + throws IOException, DMLException, ParseException { + return execute(dmlScriptFilePath, argsName, argsValues, null); + } + /** * Experimental: Execute DML script by passing positional arguments if parsePyDML=true, using specified config file.
* @param dmlScriptFilePath @@ -1163,11 +1208,40 @@ public class MLContext { return executeScript(dmlScript, false, configFilePath); } + public MLOutput executeScript(String dmlScript, boolean isPyDML, String configFilePath) throws IOException, DMLException { return compileAndExecuteScript(dmlScript, null, false, false, isPyDML, configFilePath); } + /** + * Execute DML script string with named arguments: argsValues.get(i) is bound to argsName.get(i); configFilePath is forwarded to the map-based executeScript. + * NOTE: passing a HashMap from SparkR to Java does not work, so keys and values are passed as two separate lists. + * @throws DMLException if argsName and argsValues have different sizes + */ + public MLOutput executeScript(String dmlScript, ArrayList<String> argsName, + ArrayList<String> argsValues, String configFilePath) + throws IOException, DMLException, ParseException { + HashMap<String, String> newNamedArgs = new HashMap<String, String>(); + if (argsName.size() != argsValues.size()) { + throw new DMLException("size of argsName " + argsName.size() + + " is different than size of argsValues " + argsValues.size()); + } + for (int i = 0; i < argsName.size(); i++) { + String k = argsName.get(i); + String v = argsValues.get(i); + newNamedArgs.put(k, v); + } + return executeScript(dmlScript, newNamedArgs, configFilePath); + } + /** Execute DML script string with named arguments (parallel name/value lists) and no config file (configFilePath = null). */ + public MLOutput executeScript(String dmlScript, ArrayList<String> argsName, + ArrayList<String> argsValues) + throws IOException, DMLException, ParseException { + return executeScript(dmlScript, argsName, argsValues, null); + } + + public MLOutput executeScript(String dmlScript, scala.collection.immutable.Map<String, String> namedArgs) throws IOException, DMLException { return executeScript(dmlScript, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), null); 
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6df0d234/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java index 72ab230..88dd44c 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java @@ -37,6 +37,7 @@ import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; import org.apache.spark.mllib.linalg.distributed.MatrixEntry; import org.apache.spark.sql.DataFrame; @@ -141,7 +142,71 @@ public class RDDConverterUtilsExt throw new DMLRuntimeException("The output format:" + format + " is not implemented yet."); } } - + + + /** Converts a DataFrame whose String cells hold vectors such as "(1.2, 3.4)", "[[1.2,3.4]]" or "[1.2 3.4]" into a DataFrame of MLlib Vectors; each column keeps its name and becomes a nullable VectorUDT column. + * Rows with more than one column, or non-String cells, are rejected with a DMLRuntimeException thrown from the row converter. */ + public static DataFrame stringDataFrameToVectorDataFrame(SQLContext sqlContext, DataFrame inputDF) + throws DMLRuntimeException { + + StructField[] oldSchema = inputDF.schema().fields(); + //create the new schema + StructField[] newSchema = new StructField[oldSchema.length]; + for(int i = 0; i < oldSchema.length; i++) { + String colName = oldSchema[i].name(); + newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true); + } + + //converter + class StringToVector implements Function<Tuple2<Row, Long>, Row> { + private static final long serialVersionUID = -4733816995375745659L; + @Override + public Row call(Tuple2<Row, Long> arg0) throws Exception { + Row oldRow = arg0._1; + int oldNumCols = oldRow.length(); + if (oldNumCols > 1) { + throw new DMLRuntimeException("The row must have at most one 
column"); + } + + // parse the various strings.
i.e + // ((1.2,4.3, 3.4)) or (1.2, 3.4, 2.2) or (1.2 3.4) + // [[1.2,34.3, 1.2, 1.2]] or [1.2, 3.4] or [1.3 1.2] + // one parsed Vector is collected per (String) input column + ArrayList<Object> fieldsArr = new ArrayList<Object>(oldNumCols); + for (int i = 0; i < oldRow.length(); i++) { + Object ci=oldRow.get(i); + if (ci instanceof String) { + String cis = (String)ci; + StringBuilder sb = new StringBuilder(cis.trim()); + for (int nid = 0; nid < 2; nid++) { //strip up to two levels of enclosing (...) or [...]; must use nid, not the column index i + if (sb.length() >= 2 && ((sb.charAt(0) == '(' && sb.charAt(sb.length() - 1) == ')') || + (sb.charAt(0) == '[' && sb.charAt(sb.length() - 1) == ']') + )) { + sb.deleteCharAt(0); + sb.setLength(sb.length() - 1); + } + } + //Vectors.parse expects comma-separated values: turn each comma (with surrounding whitespace) or whitespace run into a single comma + String ncis = "[" + sb.toString().trim().replaceAll("\\s*,\\s*|\\s+", ",") + "]"; + Vector v = Vectors.parse(ncis); + fieldsArr.add(v); + } else { + throw new DMLRuntimeException("Only String is supported"); + } + } + Row row = RowFactory.create(fieldsArr.toArray()); + return row; + } + } + + //output DF + JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().zipWithIndex().map(new StringToVector()); + // DataFrame outDF = sqlContext.createDataFrame(newRows, new StructType(newSchema)); //TODO investigate why it doesn't work + DataFrame outDF = sqlContext.createDataFrame(newRows.rdd(), + DataTypes.createStructType(newSchema)); + + return outDF; + } + public static JavaPairRDD<MatrixIndexes, MatrixBlock> vectorDataFrameToBinaryBlock(SparkContext sc, DataFrame inputDF, MatrixCharacteristics mcOut, boolean containsID, String vectorColumnName) throws DMLRuntimeException { return vectorDataFrameToBinaryBlock(new JavaSparkContext(sc), inputDF, mcOut, containsID, vectorColumnName);
