Repository: systemml
Updated Branches:
  refs/heads/master 81b9248fc -> 54e809898
[SYSTEMML-1742] Fix label map while training from Caffe2DML Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/54e80989 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/54e80989 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/54e80989 Branch: refs/heads/master Commit: 54e80989897a9cd85fbad48e0d57a42c10e09cbf Parents: 81b9248 Author: Arvind Surve <[email protected]> Authored: Mon Aug 14 13:13:18 2017 -0700 Committer: Arvind Surve <[email protected]> Committed: Mon Aug 14 13:13:18 2017 -0700 ---------------------------------------------------------------------- src/main/python/systemml/mllearn/estimators.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/54e80989/src/main/python/systemml/mllearn/estimators.py ---------------------------------------------------------------------- diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py index 1ec3628..44c6125 100644 --- a/src/main/python/systemml/mllearn/estimators.py +++ b/src/main/python/systemml/mllearn/estimators.py @@ -392,13 +392,17 @@ class BaseSystemMLClassifier(BaseSystemMLEstimator): """ if self.model != None: self.model.save(self.sc._jsc, outputDir, format, sep) - if self.le is not None: + + labelMapping = None + if hasattr(self, 'le') and self.le is not None: labelMapping = dict(enumerate(list(self.le.classes_), 1)) - else: + elif hasattr(self, 'labelMap') and self.labelMap is not None: labelMapping = self.labelMap - lStr = [ [ int(k), str(labelMapping[k]) ] for k in labelMapping ] - df = self.sparkSession.createDataFrame(lStr) - df.write.csv(outputDir + sep + 'labels.txt', mode='overwrite', header=False) + + if labelMapping is not None: + lStr = [ [ int(k), str(labelMapping[k]) ] for k in labelMapping ] + df = 
self.sparkSession.createDataFrame(lStr) + df.write.csv(outputDir + sep + 'labels.txt', mode='overwrite', header=False) else: raise Exception('Cannot save as you need to train the model first using fit') return self
