Repository: incubator-systemml Updated Branches: refs/heads/master 69acc217d -> 6627b7824
[SYSTEMML-649] JMLC scalar outputs Closes #150. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/6627b782 Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/6627b782 Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/6627b782 Branch: refs/heads/master Commit: 6627b78240f1fa7efbf107f9508a41c6e849fd60 Parents: 69acc21 Author: Deron Eriksson <de...@us.ibm.com> Authored: Wed May 11 10:46:00 2016 -0700 Committer: Deron Eriksson <de...@us.ibm.com> Committed: Wed May 11 10:46:00 2016 -0700 ---------------------------------------------------------------------- .../apache/sysml/api/jmlc/ResultVariables.java | 77 +++++++++++++++++++- .../RewriteRemovePersistentReadWrite.java | 3 + .../functions/jmlc/JMLCInputOutputTest.java | 58 +++++++++++++++ 3 files changed, 136 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6627b782/src/main/java/org/apache/sysml/api/jmlc/ResultVariables.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/jmlc/ResultVariables.java b/src/main/java/org/apache/sysml/api/jmlc/ResultVariables.java index 29398e7..4b36af7 100644 --- a/src/main/java/org/apache/sysml/api/jmlc/ResultVariables.java +++ b/src/main/java/org/apache/sysml/api/jmlc/ResultVariables.java @@ -26,6 +26,7 @@ import org.apache.sysml.api.DMLException; import org.apache.sysml.runtime.controlprogram.caching.FrameObject; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; import org.apache.sysml.runtime.instructions.cp.Data; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.matrix.data.FrameBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.util.DataConverter; @@ -77,7 +78,7 @@ public class ResultVariables throws DMLException { if( !_out.containsKey(varname) ) - throw new DMLException("Non-existing output variable: "+varname); + throw new DMLException("Non-existent output variable: "+varname); double[][] ret = null; Data dat = _out.get(varname); @@ -106,7 +107,7 @@ public class ResultVariables throws DMLException { if( !_out.containsKey(varname) ) - throw new DMLException("Non-existing output variable: "+varname); + throw new DMLException("Non-existent output variable: "+varname); Data dat = _out.get(varname); @@ -124,6 +125,78 @@ public class ResultVariables } /** + * Obtain the double value represented by the given output variable. + * + * @param varname + * output variable name + * @return double value + * @throws DMLException + */ + public double getDouble(String varname) throws DMLException { + ScalarObject sObj = getScalarObject(varname); + return sObj.getDoubleValue(); + } + + /** + * Obtain the boolean value represented by the given output variable. + * + * @param varname + * output variable name + * @return boolean value + * @throws DMLException + */ + public boolean getBoolean(String varname) throws DMLException { + ScalarObject sObj = getScalarObject(varname); + return sObj.getBooleanValue(); + } + + /** + * Obtain the long value represented by the given output variable. + * + * @param varname + * output variable name + * @return long value + * @throws DMLException + */ + public long getLong(String varname) throws DMLException { + ScalarObject sObj = getScalarObject(varname); + return sObj.getLongValue(); + } + + /** + * Obtain the string value represented by the given output variable. + * + * @param varname + * output variable name + * @return string value + * @throws DMLException + */ + public String getString(String varname) throws DMLException { + ScalarObject sObj = getScalarObject(varname); + return sObj.getStringValue(); + } + + /** + * Obtain the ScalarObject represented by the given output variable. + * + * @param varname + * output variable name + * @return ScalarObject + * @throws DMLException + */ + public ScalarObject getScalarObject(String varname) throws DMLException { + if (!_out.containsKey(varname)) + throw new DMLException("Non-existent output variable: " + varname); + + Data dat = _out.get(varname); + + if (!(dat instanceof ScalarObject)) { + throw new DMLException("Expected scalar result '" + varname + "' not a scalar."); + } + return (ScalarObject) dat; + } + + /** * Add the output variable name and generated output data to the ResultVariable * object. Called during the execution of {@link PreparedScript}'s * {@link PreparedScript#executeScript executeScript} method. http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6627b782/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemovePersistentReadWrite.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemovePersistentReadWrite.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemovePersistentReadWrite.java index c71a935..89812f7 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemovePersistentReadWrite.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteRemovePersistentReadWrite.java @@ -116,6 +116,9 @@ public class RewriteRemovePersistentReadWrite extends HopRewriteRule case PERSISTENTWRITE: if( _outputs.contains(dop.getName()) ) dop.setDataOpType(DataOpTypes.TRANSIENTWRITE); + if (hop.getDataType() == DataType.SCALAR) { + dop.removeInput("iofilename"); + } else LOG.warn("Non-registered persistent write of variable '"+dop.getName()+"' (line "+dop.getBeginLine()+")."); break; http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/6627b782/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCInputOutputTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCInputOutputTest.java b/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCInputOutputTest.java index 44188eb..2bed8d7 100644 --- a/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCInputOutputTest.java +++ b/src/test/java/org/apache/sysml/test/integration/functions/jmlc/JMLCInputOutputTest.java @@ -25,7 +25,9 @@ import java.io.IOException; import org.apache.sysml.api.DMLException; import org.apache.sysml.api.jmlc.Connection; import org.apache.sysml.api.jmlc.PreparedScript; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.test.integration.AutomatedTestBase; +import org.junit.Assert; import org.junit.Test; /** @@ -142,4 +144,60 @@ public class JMLCInputOutputTest extends AutomatedTestBase { conn.close(); } + @Test + public void testScalarOutputLong() throws DMLException { + Connection conn = new Connection(); + String str = "outInteger = 5;\nwrite(outInteger, './tmp/outInteger');"; + PreparedScript script = conn.prepareScript(str, new String[] {}, new String[] { "outInteger" }, false); + + long result = script.executeScript().getLong("outInteger"); + Assert.assertEquals(5, result); + conn.close(); + } + + @Test + public void testScalarOutputDouble() throws DMLException { + Connection conn = new Connection(); + String str = "outDouble = 1.23;\nwrite(outDouble, './tmp/outDouble');"; + PreparedScript script = conn.prepareScript(str, new String[] {}, new String[] { "outDouble" }, false); + + double result = script.executeScript().getDouble("outDouble"); + Assert.assertEquals(1.23, result, 0); + conn.close(); + } + + @Test + public void testScalarOutputString() throws DMLException { + Connection conn = new Connection(); + String str = "outString = 'hello';\nwrite(outString, './tmp/outString');"; + PreparedScript script = conn.prepareScript(str, new String[] {}, new String[] { "outString" }, false); + + String result = script.executeScript().getString("outString"); + Assert.assertEquals("hello", result); + conn.close(); + } + + @Test + public void testScalarOutputBoolean() throws DMLException { + Connection conn = new Connection(); + String str = "outBoolean = FALSE;\nwrite(outBoolean, './tmp/outBoolean');"; + PreparedScript script = conn.prepareScript(str, new String[] {}, new String[] { "outBoolean" }, false); + + boolean result = script.executeScript().getBoolean("outBoolean"); + Assert.assertEquals(false, result); + conn.close(); + } + + @Test + public void testScalarOutputScalarObject() throws DMLException { + Connection conn = new Connection(); + String str = "outDouble = 1.23;\nwrite(outDouble, './tmp/outDouble');"; + PreparedScript script = conn.prepareScript(str, new String[] {}, new String[] { "outDouble" }, false); + + ScalarObject so = script.executeScript().getScalarObject("outDouble"); + double result = so.getDoubleValue(); + Assert.assertEquals(1.23, result, 0); + conn.close(); + } + } \ No newline at end of file