Repository: systemml Updated Branches: refs/heads/master 1a58946a0 -> 7019f3bc8
[SYSTEMML-2505] Generate the DML for Caffe and Keras models Here is an example: ``` from keras.applications.vgg16 import VGG16 keras_model = VGG16(weights="imagenet", pooling="max") from systemml.mllearn import Keras2DML sysml_model = Keras2DML(spark, keras_model, input_shape=(3,224,224), weights='weights_dir') sysml_model.set(test_algo='batch', train_algo='minibatch') print(sysml_model.get_training_script()) print(sysml_model.get_prediction_script()) ``` Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/7019f3bc Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/7019f3bc Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/7019f3bc Branch: refs/heads/master Commit: 7019f3bc805aaae67ef32e281cf99e26cbd26b29 Parents: 1a58946 Author: Niketan Pansare <npan...@us.ibm.com> Authored: Sat Dec 8 11:20:09 2018 -0800 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Sat Dec 8 11:20:09 2018 -0800 ---------------------------------------------------------------------- src/main/python/systemml/mllearn/estimators.py | 12 ++++++++++++ src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala | 3 +++ 2 files changed, 15 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/7019f3bc/src/main/python/systemml/mllearn/estimators.py ---------------------------------------------------------------------- diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py index 8a100b4..a2647f1 100644 --- a/src/main/python/systemml/mllearn/estimators.py +++ b/src/main/python/systemml/mllearn/estimators.py @@ -973,6 +973,18 @@ class Caffe2DML(BaseSystemMLClassifier): raise TypeError("parfor_parameters should be a dictionary") return self + def get_training_script(self): + """ + Return the training DML script + """ + return self.estimator.get_training_script() + + def 
get_prediction_script(self): + """ + Return the prediction DML script + """ + return self.estimator.get_prediction_script() + def summary(self): """ Print the summary of the network http://git-wip-us.apache.org/repos/asf/systemml/blob/7019f3bc/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala ---------------------------------------------------------------------- diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala index 8ddb1fe..13f8a65 100644 --- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala +++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala @@ -221,6 +221,9 @@ class Caffe2DML(val sc: SparkContext, mloutput = baseFit(df, sc) new Caffe2DMLModel(this) } + // Public methods to be called from the Python APIs: + def get_training_script():String = getTrainingScript(true)._1.getScriptString + def get_prediction_script():String = new Caffe2DMLModel(this).getPredictionScript(true)._1.getScriptString // -------------------------------------------------------------- // Returns true if last 2 of 4 dimensions are 1. // The first dimension refers to number of input datapoints.