Repository: systemml
Updated Branches:
  refs/heads/master 1a58946a0 -> 7019f3bc8


[SYSTEMML-2505] Generate the DML for Caffe and Keras models

Here is an example:

```
from keras.applications.vgg16 import VGG16
keras_model = VGG16(weights="imagenet", pooling="max")
from systemml.mllearn import Keras2DML
sysml_model = Keras2DML(spark, keras_model, input_shape=(3,224,224), weights='weights_dir')
sysml_model.set(test_algo='batch', train_algo='minibatch')
print(sysml_model.get_training_script())
print(sysml_model.get_prediction_script())
```


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7019f3bc
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7019f3bc
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7019f3bc

Branch: refs/heads/master
Commit: 7019f3bc805aaae67ef32e281cf99e26cbd26b29
Parents: 1a58946
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Sat Dec 8 11:20:09 2018 -0800
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Sat Dec 8 11:20:09 2018 -0800

----------------------------------------------------------------------
 src/main/python/systemml/mllearn/estimators.py         | 12 ++++++++++++
 src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala |  3 +++
 2 files changed, 15 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/7019f3bc/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index 8a100b4..a2647f1 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -973,6 +973,18 @@ class Caffe2DML(BaseSystemMLClassifier):
                 raise TypeError("parfor_parameters should be a dictionary")
         return self
 
+    def get_training_script(self):
+        """
+        Return the training DML script
+        """
+        return self.estimator.get_training_script()
+        
+    def get_prediction_script(self):
+        """
+        Return the prediction DML script
+        """
+        return self.estimator.get_prediction_script()
+    
     def summary(self):
         """
         Print the summary of the network

http://git-wip-us.apache.org/repos/asf/systemml/blob/7019f3bc/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index 8ddb1fe..13f8a65 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -221,6 +221,9 @@ class Caffe2DML(val sc: SparkContext,
     mloutput = baseFit(df, sc)
     new Caffe2DMLModel(this)
   }
+  // Public methods to be called from the Python APIs:
+  def get_training_script():String = getTrainingScript(true)._1.getScriptString
+  def get_prediction_script():String = new Caffe2DMLModel(this).getPredictionScript(true)._1.getScriptString
   // --------------------------------------------------------------
   // Returns true if last 2 of 4 dimensions are 1.
   // The first dimension refers to number of input datapoints.

Reply via email to