Repository: incubator-systemml
Updated Branches:
  refs/heads/master 77363c0c6 -> d7b9cc467
[SYSTEMML-847] Remove LogisticRegression and LogisticRegressionModel in api/javaml

Removed Java classes in api/javaml since corresponding Scala versions exist in api/ml.

Closes #205.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/d7b9cc46
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/d7b9cc46
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/d7b9cc46

Branch: refs/heads/master
Commit: d7b9cc4676e04992235a63669ee06f31fa1d14b5
Parents: 77363c0
Author: Glenn Weidner <[email protected]>
Authored: Fri Aug 5 16:25:33 2016 -0700
Committer: Glenn Weidner <[email protected]>
Committed: Fri Aug 5 16:25:33 2016 -0700

----------------------------------------------------------------------
 .../sysml/api/javaml/LogisticRegression.java   | 473 -------------------
 .../api/javaml/LogisticRegressionModel.java    | 179 -------
 2 files changed, 652 deletions(-)
----------------------------------------------------------------------
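
Note: the removed functionality is superseded by the Scala class org.apache.sysml.api.ml.LogisticRegression. Below is a minimal migration sketch against that retained class, assuming it keeps the (SparkContext, SQLContext) constructor and MLPipeline integration shown in the Javadoc of the deleted file further down; training and test are the case-class RDDs from that same example.

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
    import org.apache.sysml.api.ml.LogisticRegression  // retained Scala class in api/ml

    // Same pipeline as the removed Javadoc example, now using the Scala estimator.
    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol(tokenizer.getOutputCol).setOutputCol("features")
    val lr = new LogisticRegression(sc, sqlContext)
    val model = new Pipeline().setStages(Array(tokenizer, hashingTF, lr)).fit(training.toDF)
    model.transform(test.toDF).show
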
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7b9cc46/src/main/java/org/apache/sysml/api/javaml/LogisticRegression.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/javaml/LogisticRegression.java b/src/main/java/org/apache/sysml/api/javaml/LogisticRegression.java
deleted file mode 100644
index dbcc118..0000000
--- a/src/main/java/org/apache/sysml/api/javaml/LogisticRegression.java
+++ /dev/null
@@ -1,473 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.api.javaml;
-
-import java.io.File;
-import java.util.HashMap;
-
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.classification.LogisticRegressionParams;
-import org.apache.spark.ml.classification.ProbabilisticClassifier;
-import org.apache.spark.ml.param.BooleanParam;
-import org.apache.spark.ml.param.DoubleParam;
-import org.apache.spark.ml.param.IntParam;
-import org.apache.spark.ml.param.StringArrayParam;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
-import org.apache.sysml.api.MLContext;
-import org.apache.sysml.api.MLOutput;
-import org.apache.sysml.api.javaml.LogisticRegressionModel;
-import org.apache.sysml.api.ml.functions.ConvertSingleColumnToString;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
-import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-
-/**
- *
- * This class shows how SystemML can be integrated into an MLPipeline. Note that it has not been optimized for performance
- * and is implemented as a proof of concept. An optimized pipeline can be constructed using DML's 'parfor' construct.
- *
- * TODO:
- * - Please note that this class expects 1-based labels. To run the example below,
- *   please set the environment variable 'SYSTEMML_HOME', create a folder 'algorithms',
- *   and place at least two scripts in that folder: 'MultiLogReg.dml' and 'GLM-predict.dml'.
- * - It is not yet optimized for performance.
- * - Also, it needs to be extended to surface all the parameters of MultiLogReg.dml
- *
- * Example usage:
- * <pre><code>
- * // Code to demonstrate usage of pipeline
- * import org.apache.spark.ml.Pipeline
- * import org.apache.sysml.api.ml.LogisticRegression
- * import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
- * import org.apache.spark.mllib.linalg.Vector
- * case class LabeledDocument(id: Long, text: String, label: Double)
- * case class Document(id: Long, text: String)
- * val training = sc.parallelize(Seq(
- *   LabeledDocument(0L, "a b c d e spark", 1.0),
- *   LabeledDocument(1L, "b d", 2.0),
- *   LabeledDocument(2L, "spark f g h", 1.0),
- *   LabeledDocument(3L, "hadoop mapreduce", 2.0),
- *   LabeledDocument(4L, "b spark who", 1.0),
- *   LabeledDocument(5L, "g d a y", 2.0),
- *   LabeledDocument(6L, "spark fly", 1.0),
- *   LabeledDocument(7L, "was mapreduce", 2.0),
- *   LabeledDocument(8L, "e spark program", 1.0),
- *   LabeledDocument(9L, "a e c l", 2.0),
- *   LabeledDocument(10L, "spark compile", 1.0),
- *   LabeledDocument(11L, "hadoop software", 2.0)))
- * val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
- * val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol(tokenizer.getOutputCol).setOutputCol("features")
- * val lr = new LogisticRegression(sc, sqlContext)
- * val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
- * val model = pipeline.fit(training.toDF)
- * val test = sc.parallelize(Seq(
- *   Document(12L, "spark i j k"),
- *   Document(13L, "l m n"),
- *   Document(14L, "mapreduce spark"),
- *   Document(15L, "apache hadoop")))
- * model.transform(test.toDF).show
- *
- * // Code to demonstrate usage of cross-validation
- * import org.apache.spark.ml.Pipeline
- * import org.apache.sysml.api.ml.LogisticRegression
- * import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
- * import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
- * import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
- * import org.apache.spark.mllib.linalg.Vector
- * case class LabeledDocument(id: Long, text: String, label: Double)
- * case class Document(id: Long, text: String)
- * val training = sc.parallelize(Seq(
- *   LabeledDocument(0L, "a b c d e spark", 1.0),
- *   LabeledDocument(1L, "b d", 2.0),
- *   LabeledDocument(2L, "spark f g h", 1.0),
- *   LabeledDocument(3L, "hadoop mapreduce", 2.0),
- *   LabeledDocument(4L, "b spark who", 1.0),
- *   LabeledDocument(5L, "g d a y", 2.0),
- *   LabeledDocument(6L, "spark fly", 1.0),
- *   LabeledDocument(7L, "was mapreduce", 2.0),
- *   LabeledDocument(8L, "e spark program", 1.0),
- *   LabeledDocument(9L, "a e c l", 2.0),
- *   LabeledDocument(10L, "spark compile", 1.0),
- *   LabeledDocument(11L, "hadoop software", 2.0)))
- * val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
- * val hashingTF = new HashingTF().setNumFeatures(1000).setInputCol(tokenizer.getOutputCol).setOutputCol("features")
- * val lr = new LogisticRegression(sc, sqlContext)
- * val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))
- * val crossval = new CrossValidator().setEstimator(pipeline).setEvaluator(new BinaryClassificationEvaluator)
- * val paramGrid = new ParamGridBuilder().addGrid(hashingTF.numFeatures, Array(10, 100, 1000)).addGrid(lr.regParam, Array(0.1, 0.01)).build()
- * crossval.setEstimatorParamMaps(paramGrid)
- * crossval.setNumFolds(2)
- * val cvModel = crossval.fit(training.toDF)
- * val test = sc.parallelize(Seq(
- *   Document(12L, "spark i j k"),
- *   Document(13L, "l m n"),
- *   Document(14L, "mapreduce spark"),
- *   Document(15L, "apache hadoop")))
- * cvModel.transform(test.toDF).show
- * </code></pre>
- *
- */
-public class LogisticRegression extends ProbabilisticClassifier<Vector, LogisticRegression, LogisticRegressionModel>
-	implements LogisticRegressionParams {
-
-	private static final long serialVersionUID = 7763813395635870734L;
-
-	private SparkContext sc = null;
-	private SQLContext sqlContext = null;
-	private HashMap<String, String> cmdLineParams = new HashMap<String, String>();
-
-	private IntParam icpt = new IntParam(this, "icpt", "Value of intercept");
-	private DoubleParam reg = new DoubleParam(this, "reg", "Value of regularization parameter");
-	private DoubleParam tol = new DoubleParam(this, "tol", "Value of tolerance");
-	private IntParam moi = new IntParam(this, "moi", "Max outer iterations");
-	private IntParam mii = new IntParam(this, "mii", "Max inner iterations");
-	private IntParam labelIndex = new IntParam(this, "li", "Index of the label column");
-	private StringArrayParam inputCol = new StringArrayParam(this, "icname", "Feature column name");
-	private StringArrayParam outputCol = new StringArrayParam(this, "ocname", "Label column name");
-	private int intMin = Integer.MIN_VALUE;
-	@SuppressWarnings("unused")
-	private int li = 0;
-	private String[] icname = new String[1];
-	private String[] ocname = new String[1];
-
-	public LogisticRegression() {
-	}
-
-	public LogisticRegression(String uid) {
-	}
-
-	@Override
-	public LogisticRegression copy(org.apache.spark.ml.param.ParamMap paramMap) {
-		try {
-			// Copy deals with command-line parameters of script MultiLogReg.dml
-			LogisticRegression lr = new LogisticRegression(sc, sqlContext);
-			lr.cmdLineParams.put(icpt.name(), paramMap.getOrElse(icpt, 0).toString());
-			lr.cmdLineParams.put(reg.name(), paramMap.getOrElse(reg, 0.0f).toString());
-			lr.cmdLineParams.put(tol.name(), paramMap.getOrElse(tol, 0.000001f).toString());
-			lr.cmdLineParams.put(moi.name(), paramMap.getOrElse(moi, 100).toString());
-			lr.cmdLineParams.put(mii.name(), paramMap.getOrElse(mii, 0).toString());
-
-			return lr;
-		} catch (DMLRuntimeException e) {
-			e.printStackTrace();
-		}
-		return null;
-
-	}
-
-	public LogisticRegression(SparkContext sc, SQLContext sqlContext) throws DMLRuntimeException {
-		this.sc = sc;
-		this.sqlContext = sqlContext;
-
-		setDefault(intercept(), 0);
-		cmdLineParams.put(icpt.name(), "0");
-		setDefault(regParam(), 0.0f);
-		cmdLineParams.put(reg.name(), "0.0f");
-		setDefault(tol(), 0.000001f);
-		cmdLineParams.put(tol.name(), "0.000001f");
-		setDefault(maxOuterIter(), 100);
-		cmdLineParams.put(moi.name(), "100");
-		setDefault(maxInnerIter(), 0);
-		cmdLineParams.put(mii.name(), "0");
-		setDefault(labelIdx(), intMin);
-		li = intMin;
-		setDefault(inputCol(), icname);
-		icname[0] = "";
-		setDefault(outputCol(), ocname);
-		ocname[0] = "";
-	}
-
-	public LogisticRegression(SparkContext sc, SQLContext sqlContext, int icpt, double reg, double tol, int moi, int mii) throws DMLRuntimeException {
-		this.sc = sc;
-		this.sqlContext = sqlContext;
-
-		setDefault(intercept(), icpt);
-		cmdLineParams.put(this.icpt.name(), Integer.toString(icpt));
-		setDefault(regParam(), reg);
-		cmdLineParams.put(this.reg.name(), Double.toString(reg));
-		setDefault(tol(), tol);
-		cmdLineParams.put(this.tol.name(), Double.toString(tol));
-		setDefault(maxOuterIter(), moi);
-		cmdLineParams.put(this.moi.name(), Integer.toString(moi));
-		setDefault(maxInnerIter(), mii);
-		cmdLineParams.put(this.mii.name(), Integer.toString(mii));
-
-		setDefault(labelIdx(), intMin);
-		li = intMin;
-		setDefault(inputCol(), icname);
-		icname[0] = "";
-		setDefault(outputCol(), ocname);
-		ocname[0] = "";
-	}
-
-	@Override
-	public String uid() {
-		return Long.toString(LogisticRegression.serialVersionUID);
-	}
-
-	public LogisticRegression setRegParam(double value) {
-		cmdLineParams.put(reg.name(), Double.toString(value));
-		return (LogisticRegression) setDefault(reg, value);
-	}
-
-	@Override
-	public org.apache.spark.sql.types.StructType validateAndTransformSchema(org.apache.spark.sql.types.StructType arg0, boolean arg1, org.apache.spark.sql.types.DataType arg2) {
-		return null;
-	}
-
-	@Override
-	public double getRegParam() {
-		return Double.parseDouble(cmdLineParams.get(reg.name()));
-	}
-
-	@Override
-	public void org$apache$spark$ml$param$shared$HasRegParam$_setter_$regParam_$eq(DoubleParam arg0) {
-
-	}
-
-	@Override
-	public DoubleParam regParam() {
-		return reg;
-	}
-
-	@Override
-	public DoubleParam elasticNetParam() {
-		return null;
-	}
-
-	@Override
-	public double getElasticNetParam() {
-		return 0.0f;
-	}
-
-	@Override
-	public void org$apache$spark$ml$param$shared$HasElasticNetParam$_setter_$elasticNetParam_$eq(DoubleParam arg0) {
-
-	}
-
-	@Override
-	public int getMaxIter() {
-		return 0;
-	}
-
-	@Override
-	public IntParam maxIter() {
-		return null;
-	}
-
-	public LogisticRegression setMaxOuterIter(int value) {
-		cmdLineParams.put(moi.name(), Integer.toString(value));
-		return (LogisticRegression) setDefault(moi, value);
-	}
-
-	public int getMaxOuterIter() {
-		return Integer.parseInt(cmdLineParams.get(moi.name()));
-	}
-
-	public IntParam maxOuterIter() {
-		return this.moi;
-	}
-
-	public LogisticRegression setMaxInnerIter(int value) {
-		cmdLineParams.put(mii.name(), Integer.toString(value));
-		return (LogisticRegression) setDefault(mii, value);
-	}
-
-	public int getMaxInnerIter() {
-		return Integer.parseInt(cmdLineParams.get(mii.name()));
-	}
-
-	public IntParam maxInnerIter() {
-		return mii;
-	}
-
-	@Override
-	public void org$apache$spark$ml$param$shared$HasMaxIter$_setter_$maxIter_$eq(IntParam arg0) {
-
-	}
-
-	public LogisticRegression setIntercept(int value) {
-		cmdLineParams.put(icpt.name(), Integer.toString(value));
-		return (LogisticRegression) setDefault(icpt, value);
-	}
-
-	public int getIntercept() {
-		return Integer.parseInt(cmdLineParams.get(icpt.name()));
-	}
-
-	public IntParam intercept() {
-		return icpt;
-	}
-
-	@Override
-	public BooleanParam fitIntercept() {
-		return null;
-	}
-
-	@Override
-	public boolean getFitIntercept() {
-		return false;
-	}
-
-	@Override
-	public void org$apache$spark$ml$param$shared$HasFitIntercept$_setter_$fitIntercept_$eq(BooleanParam arg0) {
-
-	}
-
-	public LogisticRegression setTol(double value) {
-		cmdLineParams.put(tol.name(), Double.toString(value));
-		return (LogisticRegression) setDefault(tol, value);
-	}
-
-	@Override
-	public double getTol() {
-		return Double.parseDouble(cmdLineParams.get(tol.name()));
-	}
-
-	@Override
-	public void org$apache$spark$ml$param$shared$HasTol$_setter_$tol_$eq(DoubleParam arg0) {
-
-	}
-
-	@Override
-	public DoubleParam tol() {
-		return tol;
-	}
-
-	@Override
-	public double getThreshold() {
-		return 0;
-	}
-
-	@Override
-	public void org$apache$spark$ml$param$shared$HasThreshold$_setter_$threshold_$eq(DoubleParam arg0) {
-
-	}
-
-	@Override
-	public DoubleParam threshold() {
-		return null;
-	}
-
-	public LogisticRegression setLabelIndex(int value) {
-		li = value;
-		return (LogisticRegression) setDefault(labelIndex, value);
-	}
-
-	public int getLabelIndex() {
-		return Integer.parseInt(cmdLineParams.get(labelIndex.name()));
-	}
-
-	public IntParam labelIdx() {
-		return labelIndex;
-	}
-
-	public LogisticRegression setInputCol(String[] value) {
-		icname[0] = value[0];
-		return (LogisticRegression) setDefault(inputCol, value);
-	}
-
-	public String getInputCol() {
-		return icname[0];
-	}
-
-	public StringArrayParam inputCol() {
-		return inputCol;
-	}
-
-	public LogisticRegression setOutputCol(String[] value) {
-		ocname[0] = value[0];
-		return (LogisticRegression) setDefault(outputCol, value);
-	}
-
-	public String getOutputCol() {
-		return ocname[0];
-	}
-
-	public StringArrayParam outputCol() {
-		return outputCol;
-	}
-
-	@Override
-	public LogisticRegressionModel train(DataFrame df) {
-		MLContext ml = null;
-		MLOutput out = null;
-
-		try {
-			ml = new MLContext(this.sc);
-		} catch (DMLRuntimeException e1) {
-			e1.printStackTrace();
-			return null;
-		}
-
-		// Convert input data to format that SystemML accepts
-		MatrixCharacteristics mcXin = new MatrixCharacteristics();
-		JavaPairRDD<MatrixIndexes, MatrixBlock> Xin;
-		try {
-			Xin = RDDConverterUtilsExt.vectorDataFrameToBinaryBlock(new JavaSparkContext(this.sc), df, mcXin, false, "features");
-		} catch (DMLRuntimeException e1) {
-			e1.printStackTrace();
-			return null;
-		}
-
-		JavaRDD<String> yin = df.select("label").rdd().toJavaRDD().map(new ConvertSingleColumnToString());
-
-		try {
-			// Register the input/output variables of script 'MultiLogReg.dml'
-			ml.registerInput("X", Xin, mcXin);
-			ml.registerInput("Y_vec", yin, "csv");
-			ml.registerOutput("B_out");
-
-			// Or add ifdef in MultiLogReg.dml
-			cmdLineParams.put("X", " ");
-			cmdLineParams.put("Y", " ");
-			cmdLineParams.put("B", " ");
-
-			// ------------------------------------------------------------------------------------
-			// Please note that this logic is subject to change and is put as a placeholder
-			String systemmlHome = System.getenv("SYSTEMML_HOME");
-			if(systemmlHome == null) {
-				System.err.println("ERROR: The environment variable SYSTEMML_HOME is not set.");
-				return null;
-			}
-
-			String dmlFilePath = systemmlHome + File.separator + "algorithms" + File.separator + "MultiLogReg.dml";
-			// ------------------------------------------------------------------------------------
-
-			synchronized(MLContext.class) {
-				// static synchronization is necessary before execute call
-				out = ml.execute(dmlFilePath, cmdLineParams);
-			}
-
-			JavaPairRDD<MatrixIndexes, MatrixBlock> b_out = out.getBinaryBlockedRDD("B_out");
-			MatrixCharacteristics b_outMC = out.getMatrixCharacteristics("B_out");
-			return new LogisticRegressionModel(b_out, b_outMC, sc).setParent(this);
-		} catch (Exception e) {
-			throw new RuntimeException(e);
-		}
-	}
-}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/d7b9cc46/src/main/java/org/apache/sysml/api/javaml/LogisticRegressionModel.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/javaml/LogisticRegressionModel.java b/src/main/java/org/apache/sysml/api/javaml/LogisticRegressionModel.java
deleted file mode 100644
index 819380c..0000000
--- a/src/main/java/org/apache/sysml/api/javaml/LogisticRegressionModel.java
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.api.javaml;
-
-import java.io.File;
-import java.util.HashMap;
-
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.ml.classification.ProbabilisticClassificationModel;
-import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
-
-import org.apache.sysml.api.MLContext;
-import org.apache.sysml.api.MLOutput;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
-import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-
-public class LogisticRegressionModel extends ProbabilisticClassificationModel<Vector, LogisticRegressionModel> {
-
-	private static final long serialVersionUID = -6464693773946415027L;
-
-	private JavaPairRDD<MatrixIndexes, MatrixBlock> b_out;
-	private SparkContext sc;
-	private MatrixCharacteristics b_outMC;
-
-	@Override
-	public LogisticRegressionModel copy(ParamMap paramMap) {
-		return this;
-	}
-
-	public LogisticRegressionModel(JavaPairRDD<MatrixIndexes, MatrixBlock> b_out2, MatrixCharacteristics b_outMC, SparkContext sc) {
-		this.b_out = b_out2;
-		this.b_outMC = b_outMC;
-		this.sc = sc;
-		//this.cmdLineParams = cmdLineParams;
-	}
-
-	public LogisticRegressionModel() {
-	}
-
-	public LogisticRegressionModel(String uid) {
-	}
-
-	@Override
-	public String uid() {
-		return Long.toString(LogisticRegressionModel.serialVersionUID);
-	}
-
-	@Override
-	public Vector raw2probabilityInPlace(Vector arg0) {
-		return arg0;
-	}
-
-	@Override
-	public int numClasses() {
-		return 2;
-	}
-
-	@Override
-	public Vector predictRaw(Vector arg0) {
-		return arg0;
-	}
-
-	@Override
-	public double predict(Vector features) {
-		return super.predict(features);
-	}
-
-	@Override
-	public double raw2prediction(Vector rawPrediction) {
-		return super.raw2prediction(rawPrediction);
-	}
-
-	@Override
-	public double probability2prediction(Vector probability) {
-		return super.probability2prediction(probability);
-	}
-
-	public static class ConvertIntToRow implements Function<Integer, Row> {
-
-		private static final long serialVersionUID = -3480953015655773622L;
-
-		@Override
-		public Row call(Integer arg0) throws Exception {
-			Object[] row_fields = new Object[1];
-			row_fields[0] = new Double(arg0);
-			return RowFactory.create(row_fields);
-		}
-	}
-
-	@Override
-	public DataFrame transform(DataFrame dataset) {
-		try {
-			MatrixCharacteristics mcXin = new MatrixCharacteristics();
-			JavaPairRDD<MatrixIndexes, MatrixBlock> Xin;
-			try {
-				Xin = RDDConverterUtilsExt.vectorDataFrameToBinaryBlock(new JavaSparkContext(this.sc), dataset, mcXin, false, "features");
-			} catch (DMLRuntimeException e1) {
-				e1.printStackTrace();
-				return null;
-			}
-			MLContext ml = new MLContext(sc);
-			ml.registerInput("X", Xin, mcXin);
-			ml.registerInput("B_full", b_out, b_outMC); // Changed MLContext for this method
-			ml.registerOutput("means");
-			HashMap<String, String> param = new HashMap<String, String>();
-			param.put("dfam", "3");
-
-			// ------------------------------------------------------------------------------------
-			// Please note that this logic is subject to change and is put as a placeholder
-			String systemmlHome = System.getenv("SYSTEMML_HOME");
-			if(systemmlHome == null) {
-				System.err.println("ERROR: The environment variable SYSTEMML_HOME is not set.");
-				return null;
-			}
-			// Or add ifdef in GLM-predict.dml
-			param.put("X", " ");
-			param.put("B", " ");
-
-			String dmlFilePath = systemmlHome + File.separator + "algorithms" + File.separator + "GLM-predict.dml";
-			// ------------------------------------------------------------------------------------
-			MLOutput out = ml.execute(dmlFilePath, param);
-
-			SQLContext sqlContext = new SQLContext(sc);
-			DataFrame prob = out.getDF(sqlContext, "means", true).withColumnRenamed("C1", "probability");
-
-			MLContext mlNew = new MLContext(sc);
-			mlNew.registerInput("X", Xin, mcXin);
-			mlNew.registerInput("B_full", b_out, b_outMC);
-			mlNew.registerInput("Prob", out.getBinaryBlockedRDD("means"), out.getMatrixCharacteristics("means"));
-			mlNew.registerOutput("Prediction");
-			mlNew.registerOutput("rawPred");
-			MLOutput outNew = mlNew.executeScript("Prob = read(\"temp1\"); "
-					+ "Prediction = rowIndexMax(Prob); "
-					+ "write(Prediction, \"tempOut\", \"csv\")"
-					+ "X = read(\"temp2\");"
-					+ "B_full = read(\"temp3\");"
-					+ "rawPred = 1 / (1 + exp(- X * t(B_full)) );" // Raw prediction logic
-					+ "write(rawPred, \"tempOut1\", \"csv\")");
-
-			// TODO: Perform joins in the DML
-			DataFrame pred = outNew.getDF(sqlContext, "Prediction").withColumnRenamed("C1", "prediction").withColumnRenamed("ID", "ID1");
-			DataFrame rawPred = outNew.getDF(sqlContext, "rawPred", true).withColumnRenamed("C1", "rawPrediction").withColumnRenamed("ID", "ID2");
-			DataFrame predictionsNProb = prob.join(pred, prob.col("ID").equalTo(pred.col("ID1"))).select("ID", "probability", "prediction");
-			predictionsNProb = predictionsNProb.join(rawPred, predictionsNProb.col("ID").equalTo(rawPred.col("ID2"))).select("ID", "probability", "prediction", "rawPrediction");
-			DataFrame dataset1 = RDDConverterUtilsExt.addIDToDataFrame(dataset, sqlContext, "ID");
-			return dataset1.join(predictionsNProb, dataset1.col("ID").equalTo(predictionsNProb.col("ID"))).orderBy("id");
-		} catch (Exception e) {
-			throw new RuntimeException(e);
-		}
-	}
-}
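
For reference, the raw-prediction line in the deleted transform() above, rawPred = 1 / (1 + exp(- X * t(B_full))), is the row-wise logistic (sigmoid) link. Below is a self-contained Scala sketch of the same computation for a single dense row (an illustrative helper, not part of the SystemML API):

    object LogisticLink {
      // Sigmoid: maps a linear score z to a value in (0, 1).
      def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

      // Raw prediction for one feature row x against one coefficient vector b.
      def rawPrediction(x: Array[Double], b: Array[Double]): Double = {
        require(x.length == b.length, "feature and coefficient lengths must match")
        sigmoid(x.zip(b).map { case (xi, bi) => xi * bi }.sum)
      }
    }

    // Example: rawPrediction(Array(1.0, 2.0), Array(0.5, -0.25)) == sigmoid(0.0) == 0.5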
