Repository: systemml
Updated Branches:
  refs/heads/master 81b9248fc -> 54e809898


[SYSTEMML-1742] Fix label map while training from Caffe2DML


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/54e80989
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/54e80989
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/54e80989

Branch: refs/heads/master
Commit: 54e80989897a9cd85fbad48e0d57a42c10e09cbf
Parents: 81b9248
Author: Arvind Surve <[email protected]>
Authored: Mon Aug 14 13:13:18 2017 -0700
Committer: Arvind Surve <[email protected]>
Committed: Mon Aug 14 13:13:18 2017 -0700

----------------------------------------------------------------------
 src/main/python/systemml/mllearn/estimators.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/54e80989/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index 1ec3628..44c6125 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -392,13 +392,17 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator):
         """
         if self.model != None:
             self.model.save(self.sc._jsc, outputDir, format, sep)
-            if self.le is not None:
+
+            labelMapping = None
+            if hasattr(self, 'le') and self.le is not None:
                 labelMapping = dict(enumerate(list(self.le.classes_), 1))
-            else:
+            elif hasattr(self, 'labelMap') and self.labelMap is not None:
                 labelMapping = self.labelMap
-            lStr = [ [ int(k), str(labelMapping[k]) ] for k in labelMapping ]
-            df = self.sparkSession.createDataFrame(lStr)
-            df.write.csv(outputDir + sep + 'labels.txt', mode='overwrite', header=False)
+
+            if labelMapping is not None:
+                lStr = [ [ int(k), str(labelMapping[k]) ] for k in labelMapping ]
+                df = self.sparkSession.createDataFrame(lStr)
+                df.write.csv(outputDir + sep + 'labels.txt', mode='overwrite', header=False)
         else:
             raise Exception('Cannot save as you need to train the model first using fit')
         return self

Reply via email to