Repository: incubator-systemml
Updated Branches:
  refs/heads/master 5ac32d6be -> 6df0d2348


[SYSTEMML-860] SparkR/HydraR integration with SystemML

Closes #212.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/6df0d234
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/6df0d234
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/6df0d234

Branch: refs/heads/master
Commit: 6df0d2348e77d583ef02974e5a1f1120a959270a
Parents: 5ac32d6
Author: Alok Singh <[email protected]>
Authored: Mon Aug 15 14:49:44 2016 -0700
Committer: Deron Eriksson <[email protected]>
Committed: Mon Aug 15 14:49:44 2016 -0700

----------------------------------------------------------------------
 .../java/org/apache/sysml/api/MLContext.java    | 76 +++++++++++++++++++-
 .../spark/utils/RDDConverterUtilsExt.java       | 67 ++++++++++++++++-
 2 files changed, 141 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6df0d234/src/main/java/org/apache/sysml/api/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java
index d8a290d..405478f 100644
--- a/src/main/java/org/apache/sysml/api/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/MLContext.java
@@ -837,7 +837,52 @@ public class MLContext {
                argsArr = args.toArray(argsArr);
                return execute(dmlScriptFilePath, argsArr, parsePyDML, 
configFilePath);
        }
-       
+
+       /**
+        * Execute DML script by passing named arguments using specified config file.
+        * <p>
+        * The named arguments are given as two parallel lists: the i-th entry of
+        * {@code argsName} is bound to the i-th entry of {@code argsValues}. This form
+        * exists because, when calling from SparkR, passing a Map from R to Java does
+        * not work.
+        *
+        * @param dmlScriptFilePath path to the DML script
+        * @param argsName names of the script arguments
+        * @param argsValues values of the script arguments, parallel to {@code argsName}
+        * @param configFilePath path to the config file (may be {@code null})
+        * @return the output of executing the script
+        * @throws IOException if thrown while executing the script
+        * @throws DMLException if {@code argsName} and {@code argsValues} differ in size,
+        *         or if thrown while executing the script
+        * @throws ParseException if thrown while parsing the script
+        */
+       public MLOutput execute(String dmlScriptFilePath, ArrayList<String> argsName,
+                       ArrayList<String> argsValues, String configFilePath)
+                       throws IOException, DMLException, ParseException  {
+               if (argsName.size() != argsValues.size()) {
+                       throw new DMLException("size of argsName " + argsName.size()
+                                       + " differs from size of argsValues " + argsValues.size());
+               }
+               // declared as HashMap so the existing HashMap-based overload is selected
+               HashMap<String, String> newNamedArgs = new HashMap<String, String>();
+               for (int i = 0; i < argsName.size(); i++) {
+                       newNamedArgs.put(argsName.get(i), argsValues.get(i));
+               }
+               return execute(dmlScriptFilePath, newNamedArgs, configFilePath);
+       }
+       /**
+        * Execute DML script by passing named arguments, with a {@code null} config file path.
+        * <p>
+        * The i-th entry of {@code argsName} is bound to the i-th entry of
+        * {@code argsValues}; this is equivalent to calling the four-argument
+        * list-based overload with {@code configFilePath == null}.
+        *
+        * @param dmlScriptFilePath path to the DML script
+        * @param argsName names of the script arguments
+        * @param argsValues values of the script arguments, parallel to {@code argsName}
+        * @return the output of executing the script
+        * @throws IOException if thrown while executing the script
+        * @throws DMLException if {@code argsName} and {@code argsValues} differ in size,
+        *         or if thrown while executing the script
+        * @throws ParseException if thrown while parsing the script
+        */
+       public MLOutput execute(String dmlScriptFilePath, ArrayList<String> argsName,
+                       ArrayList<String> argsValues)
+                       throws IOException, DMLException, ParseException  {
+               return execute(dmlScriptFilePath, argsName, argsValues, null);
+       }
+
        /**
         * Experimental: Execute DML script by passing positional arguments if 
parsePyDML=true, using specified config file.
         * @param dmlScriptFilePath
@@ -1163,11 +1208,40 @@ public class MLContext {
                return executeScript(dmlScript, false, configFilePath);
        }
 
+
        /**
         * Execute the given script string using the specified config file.
         * Delegates to compileAndExecuteScript, passing {@code null} as its second
         * argument and {@code false} for the two boolean arguments that follow;
         * {@code isPyDML} and {@code configFilePath} are passed through unchanged.
         *
         * @param dmlScript the script to execute, as a string
         * @param isPyDML presumably selects PyDML rather than DML syntax (passed through as-is)
         * @param configFilePath path to the config file (passed through as-is)
         * @return the result of compileAndExecuteScript
         * @throws IOException if thrown by compileAndExecuteScript
         * @throws DMLException if thrown by compileAndExecuteScript
         */
        public MLOutput executeScript(String dmlScript, boolean isPyDML, String 
configFilePath)
                        throws IOException, DMLException {
                return compileAndExecuteScript(dmlScript, null, false, false, 
isPyDML, configFilePath);
        }
 
+       /**
+        * Execute a DML script string with named arguments using the specified config file.
+        * <p>
+        * The named arguments are given as two parallel lists: the i-th entry of
+        * {@code argsName} is bound to the i-th entry of {@code argsValues}. This form
+        * exists because, when calling from SparkR, passing a HashMap from R to Java
+        * does not work.
+        *
+        * @param dmlScript the DML script to execute, as a string
+        * @param argsName names of the script arguments
+        * @param argsValues values of the script arguments, parallel to {@code argsName}
+        * @param configFilePath path to the config file (may be {@code null})
+        * @return the output of executing the script
+        * @throws IOException if thrown while executing the script
+        * @throws DMLException if {@code argsName} and {@code argsValues} differ in size,
+        *         or if thrown while executing the script
+        * @throws ParseException declared for callers; not thrown by this method directly
+        */
+       public MLOutput executeScript(String dmlScript, ArrayList<String> argsName,
+                       ArrayList<String> argsValues, String configFilePath)
+                       throws IOException, DMLException, ParseException  {
+               if (argsName.size() != argsValues.size()) {
+                       throw new DMLException("size of argsName " + argsName.size()
+                                       + " differs from size of argsValues " + argsValues.size());
+               }
+               // declared as HashMap so the existing HashMap-based overload is selected
+               HashMap<String, String> newNamedArgs = new HashMap<String, String>();
+               for (int i = 0; i < argsName.size(); i++) {
+                       newNamedArgs.put(argsName.get(i), argsValues.get(i));
+               }
+               return executeScript(dmlScript, newNamedArgs, configFilePath);
+       }
+
+       /**
+        * Execute a DML script string with named arguments, with a {@code null} config file path.
+        * <p>
+        * The i-th entry of {@code argsName} is bound to the i-th entry of
+        * {@code argsValues}; this is equivalent to calling the four-argument
+        * list-based overload with {@code configFilePath == null}.
+        *
+        * @param dmlScript the DML script to execute, as a string
+        * @param argsName names of the script arguments
+        * @param argsValues values of the script arguments, parallel to {@code argsName}
+        * @return the output of executing the script
+        * @throws IOException if thrown while executing the script
+        * @throws DMLException if {@code argsName} and {@code argsValues} differ in size,
+        *         or if thrown while executing the script
+        * @throws ParseException declared for callers of the list-based overloads
+        */
+       public MLOutput executeScript(String dmlScript, ArrayList<String> argsName,
+                       ArrayList<String> argsValues)
+                       throws IOException, DMLException, ParseException  {
+               return executeScript(dmlScript, argsName, argsValues, null);
+       }
+
+
        public MLOutput executeScript(String dmlScript, 
scala.collection.immutable.Map<String, String> namedArgs)
                        throws IOException, DMLException {
                return executeScript(dmlScript, new HashMap<String, 
String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), null);

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6df0d234/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
index 72ab230..88dd44c 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
@@ -37,6 +37,7 @@ import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
 import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
 import org.apache.spark.sql.DataFrame;
@@ -141,7 +142,71 @@ public class RDDConverterUtilsExt
                        throw new DMLRuntimeException("The output format:" + 
format + " is not implemented yet.");
                }
        }
-       
+
+
+
+       /**
+        * Convert a DataFrame whose column holds string-encoded vectors into a DataFrame
+        * whose column holds {@link Vector} values.
+        * <p>
+        * Each cell is parsed after stripping up to two levels of enclosing parentheses or
+        * square brackets; elements may be separated by commas and/or whitespace, e.g.
+        * "((1.2,4.3, 3.4))", "(1.2, 3.4, 2.2)", "(1.2 3.4)", "[[1.2,34.3, 1.2, 1.2]]",
+        * "[1.2, 3.4]" or "[1.3 1.2]".
+        *
+        * @param sqlContext SQL context used to create the output DataFrame
+        * @param inputDF input DataFrame; every row must have at most one column, holding a String
+        * @return a DataFrame with the same column names, each typed as a nullable VectorUDT
+        * @throws DMLRuntimeException declared for API compatibility; row-level failures
+        *         (more than one column, non-String cell, unparsable vector) are raised by
+        *         the Spark tasks that evaluate the returned DataFrame
+        */
+       public static DataFrame stringDataFrameToVectorDataFrame(SQLContext sqlContext, DataFrame inputDF)
+                       throws DMLRuntimeException {
+
+               StructField[] oldSchema = inputDF.schema().fields();
+               //create the new schema: same column names, every column a nullable vector
+               StructField[] newSchema = new StructField[oldSchema.length];
+               for(int i = 0; i < oldSchema.length; i++) {
+                       String colName = oldSchema[i].name();
+                       newSchema[i] = DataTypes.createStructField(colName, new VectorUDT(), true);
+               }
+
+               //converter: a local class declared in a static method captures no enclosing
+               //instance, so only this small object is serialized to the executors
+               class StringToVector implements Function<Row, Row> {
+                       private static final long serialVersionUID = -4733816995375745659L;
+                       @Override
+                       public Row call(Row oldRow) throws Exception {
+                               int oldNumCols = oldRow.length();
+                               if (oldNumCols > 1) {
+                                       throw new DMLRuntimeException("The row must have at most one column");
+                               }
+
+                               Object[] fields = new Object[oldNumCols];
+                               for (int i = 0; i < oldNumCols; i++) {
+                                       Object ci = oldRow.get(i);
+                                       if (!(ci instanceof String)) {
+                                               throw new DMLRuntimeException("Only String is supported");
+                                       }
+                                       fields[i] = parseStringToVector((String) ci);
+                               }
+                               return RowFactory.create(fields);
+                       }
+               }
+
+               //output DF; a plain map preserves row order without the extra Spark job
+               //that zipWithIndex would trigger to compute partition offsets
+               JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().map(new StringToVector());
+               DataFrame outDF = sqlContext.createDataFrame(newRows.rdd(),
+                               DataTypes.createStructType(newSchema));
+
+               return outDF;
+       }
+
+       /**
+        * Parse one string-encoded vector, e.g. "((1.2,4.3, 3.4))" or "[1.3 1.2]", into a
+        * {@link Vector}. Up to two levels of matching enclosing "()" or "[]" are stripped
+        * (trimming whitespace at each level); the remaining comma- and/or
+        * whitespace-separated elements are handed to {@link Vectors#parse(String)} in
+        * its dense "[a,b,c]" form.
+        *
+        * @param str the string-encoded vector
+        * @return the parsed vector
+        */
+       private static Vector parseStringToVector(String str) {
+               String body = str.trim();
+               for (int level = 0; level < 2; level++) { //remove at most two levels of nesting
+                       if (body.length() < 2) {
+                               break; //too short to carry an enclosing pair; avoids index errors
+                       }
+                       char first = body.charAt(0);
+                       char last = body.charAt(body.length() - 1);
+                       if ((first == '(' && last == ')') || (first == '[' && last == ']')) {
+                               body = body.substring(1, body.length() - 1).trim();
+                       } else {
+                               break;
+                       }
+               }
+               //normalize separators: "a , b" and "a b" both become "a,b"; Vectors.parse
+               //does not accept whitespace as a separator
+               String csv = body.replaceAll("\\s*,\\s*|\\s+", ",");
+               return Vectors.parse("[" + csv + "]");
+       }
+
        public static JavaPairRDD<MatrixIndexes, MatrixBlock> 
vectorDataFrameToBinaryBlock(SparkContext sc,
                        DataFrame inputDF, MatrixCharacteristics mcOut, boolean 
containsID, String vectorColumnName) throws DMLRuntimeException {
                return vectorDataFrameToBinaryBlock(new JavaSparkContext(sc), 
inputDF, mcOut, containsID, vectorColumnName);

Reply via email to