Repository: systemml
Updated Branches:
  refs/heads/master ab251f6ee -> 2fc26b3dc


[MINOR] Allow non-literal values in parameterized built-in functions


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/2fc26b3d
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/2fc26b3d
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/2fc26b3d

Branch: refs/heads/master
Commit: 2fc26b3dced89a473055828b08550ed6e6a8d7be
Parents: ab251f6
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Mon Sep 10 15:05:05 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Mon Sep 10 15:05:05 2018 -0700

----------------------------------------------------------------------
 .../gpu/GPUDenseInputPointerFetcher.java        |  1 -
 .../gpu/MatrixReshapeGPUInstruction.java        |  3 +-
 .../ParameterizedBuiltinSPInstruction.java      | 75 ++++++++++++++++----
 3 files changed, 63 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/2fc26b3d/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
index 8fcaec3..1ab3420 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/GPUDenseInputPointerFetcher.java
@@ -20,7 +20,6 @@ package org.apache.sysml.runtime.instructions.gpu;
 
 import java.util.HashMap;
 
-import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;

http://git-wip-us.apache.org/repos/asf/systemml/blob/2fc26b3d/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixReshapeGPUInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixReshapeGPUInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixReshapeGPUInstruction.java
index 61cb643..ee2166e 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixReshapeGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/gpu/MatrixReshapeGPUInstruction.java
@@ -79,7 +79,8 @@ public class MatrixReshapeGPUInstruction extends 
GPUInstruction {
                GPUContext gCtx = ec.getGPUContext(0); 
                MatrixObject mat = getMatrixInputForGPUInstruction(ec, 
_input.getName());
                if(rows*cols != mat.getNumRows()*mat.getNumColumns()) {
-                       throw new DMLRuntimeException("Incorrect number of rows 
and cols in rshape instruction");
+                       throw new DMLRuntimeException("Cannot reshape a matrix 
of dimensions: [" + mat.getNumRows() + ", " + mat.getNumColumns() + "] to a 
matrix of"
+                                       + " dimensions [" + rows + ", " + cols 
+ "]");
                }
                // We currently support only dense rshape
                Pointer inPtr = LibMatrixCUDA.getDensePointer(gCtx, mat, 
instName);

http://git-wip-us.apache.org/repos/asf/systemml/blob/2fc26b3d/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index f9d7ef3..4a1c710 100644
--- 
a/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ 
b/src/main/java/org/apache/sysml/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -174,6 +174,54 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                }
        }
        
+       private double getDoubleParam(ExecutionContext ec, String key) {
+               String val = params.get(key);
+               try {
+                       if(val != null)
+                               return Double.parseDouble( val );
+                       else
+                               throw new RuntimeException("Expected parameter 
" + key);
+               } catch(NumberFormatException e) {
+                       return ec.getScalarInput(val, ValueType.DOUBLE, 
false).getDoubleValue();
+               } 
+       }
+       
+       private boolean getBooleanParam(ExecutionContext ec, String key) {
+               String val = params.get(key);
+               try {
+                       if(val != null)
+                               return Boolean.parseBoolean( val.toLowerCase() 
);
+                       else
+                               throw new RuntimeException("Expected parameter 
" + key);
+               } catch(NumberFormatException e) {
+                       return ec.getScalarInput(val, ValueType.BOOLEAN, 
false).getBooleanValue();
+               } 
+       }
+       
+       private long getLongParam(ExecutionContext ec, String key) {
+               String val = params.get(key);
+               try {
+                       if(val != null)
+                               return Long.parseLong( val );
+                       else
+                               throw new RuntimeException("Expected parameter 
" + key);
+               } catch(NumberFormatException e) {
+                       return ec.getScalarInput(val, ValueType.INT, 
false).getLongValue(); 
+               } 
+       }
+       
+       private long getLongParam(ExecutionContext ec, String key, long 
defaultValue) {
+               String val = params.get(key);
+               try {
+                       if(val != null)
+                               return Long.parseLong( val );
+                       else
+                               return defaultValue;
+               } catch(NumberFormatException e) {
+                       return ec.getScalarInput(val, ValueType.INT, 
false).getLongValue(); 
+               } 
+       }
+       
 
        @Override 
        @SuppressWarnings("unchecked")
@@ -191,8 +239,8 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        PartitionedBroadcast<MatrixBlock> groups = 
sec.getBroadcastForVariable(groupsVar);
                        MatrixCharacteristics mc1 = 
sec.getMatrixCharacteristics( targetVar );
                        MatrixCharacteristics mcOut = 
sec.getMatrixCharacteristics(output.getName());
-                       CPOperand ngrpOp = new 
CPOperand(params.get(Statement.GAGG_NUM_GROUPS));
-                       int ngroups = (int)sec.getScalarInput(ngrpOp.getName(), 
ngrpOp.getValueType(), ngrpOp.isLiteral()).getLongValue();
+               
+                       int ngroups = (int) getLongParam(ec, 
Statement.GAGG_NUM_GROUPS);
                        
                        //single-block aggregation
                        if( ngroups <= mc1.getRowsPerBlock() && mc1.getCols() 
<= mc1.getColsPerBlock() ) {
@@ -222,7 +270,7 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                }
                else if ( opcode.equalsIgnoreCase("groupedagg") )
                {
-                       boolean broadcastGroups = 
Boolean.parseBoolean(params.get("broadcast"));
+                       boolean broadcastGroups = getBooleanParam(ec, 
"broadcast");
                        
                        //get input rdd handle
                        String groupsVar = params.get(Statement.GAGG_GROUPS);
@@ -253,8 +301,7 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        }
                        else //input vector or matrix
                        {
-                               String ngroupsStr = 
params.get(Statement.GAGG_NUM_GROUPS);
-                               long ngroups = (ngroupsStr != null) ? (long) 
Double.parseDouble(ngroupsStr) : -1;
+                               long ngroups = getLongParam(ec, 
Statement.GAGG_NUM_GROUPS, -1);
                                
                                //execute basic grouped aggregate (extract and 
preagg)
                                if( broadcastGroups ) {
@@ -312,8 +359,8 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        String rddOffVar = params.get("offset");
                        
                        boolean rows = sec.getScalarInput(params.get("margin"), 
ValueType.STRING, true).getStringValue().equals("rows");
-                       boolean emptyReturn = 
Boolean.parseBoolean(params.get("empty.return").toLowerCase());
-                       long maxDim = sec.getScalarInput(params.get("maxdim"), 
ValueType.DOUBLE, false).getLongValue();
+                       boolean emptyReturn = getBooleanParam(ec, 
"empty.return");
+                       long maxDim = getLongParam(ec, "maxdim");
                        MatrixCharacteristics mcIn = 
sec.getMatrixCharacteristics(rddInVar);
                        
                        if( maxDim > 0 ) //default case
@@ -369,8 +416,8 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        MatrixCharacteristics mcIn = 
sec.getMatrixCharacteristics(params.get("target"));
                        
                        //execute replace operation
-                       double pattern = Double.parseDouble( 
params.get("pattern") );
-                       double replacement = Double.parseDouble( 
params.get("replacement") );
+                       double pattern = getDoubleParam(ec, "pattern"); 
+                       double replacement = getDoubleParam(ec, "replacement");
                        JavaPairRDD<MatrixIndexes,MatrixBlock> out = 
                                in1.mapValues(new RDDReplaceFunction(pattern, 
replacement));
                        
@@ -388,8 +435,8 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        JavaPairRDD<MatrixIndexes,MatrixBlock> in1 = 
sec.getBinaryBlockRDDHandleForVariable(params.get("target"));
                        MatrixCharacteristics mcIn = 
sec.getMatrixCharacteristics(params.get("target"));
                        boolean lower = opcode.equalsIgnoreCase("lowertri");
-                       boolean diag = Boolean.parseBoolean(params.get("diag"));
-                       boolean values = 
Boolean.parseBoolean(params.get("values"));
+                       boolean diag = getBooleanParam(ec, "diag"); 
+                       boolean values = getBooleanParam(ec, "values");
                        
                        JavaPairRDD<MatrixIndexes,MatrixBlock> out = 
in1.mapPartitionsToPair(
                                new RDDExtractTriangularFunction(lower, diag, 
values), true);
@@ -408,11 +455,11 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        //get input rdd handle
                        JavaPairRDD<MatrixIndexes,MatrixBlock> in = 
sec.getBinaryBlockRDDHandleForVariable( rddInVar );
                        MatrixCharacteristics mcIn = 
sec.getMatrixCharacteristics(rddInVar);
-                       double maxVal = Double.parseDouble( params.get("max") );
+                       double maxVal = getDoubleParam(ec, "max"); 
                        long lmaxVal = UtilFunctions.toLong(maxVal);
                        boolean dirRows = params.get("dir").equals("rows");
-                       boolean cast = Boolean.parseBoolean(params.get("cast"));
-                       boolean ignore = 
Boolean.parseBoolean(params.get("ignore"));
+                       boolean cast = getBooleanParam(ec, "cast");
+                       boolean ignore = getBooleanParam(ec, "ignore");
                        long brlen = mcIn.getRowsPerBlock();
                        long bclen = mcIn.getColsPerBlock();
                        

Reply via email to