GitHub user yinxusen commented on a diff in the pull request:
https://github.com/apache/spark/pull/7842#discussion_r43213545
--- Diff:
mllib/src/test/scala/org/apache/spark/mllib/pmml/export/DecisionTreePMMLModelExportSuite.scala
---
@@ -0,0 +1,431 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.pmml.export
+
+import scala.collection.JavaConverters._
+
+import org.dmg.pmml._
+import org.dmg.pmml.CompoundPredicate.BooleanOperator
+import org.scalatest.PrivateMethodTester
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node,
Predict, Split}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class DecisionTreePMMLModelExportSuite extends SparkFunSuite
+with MLlibTestSparkContext
+with PrivateMethodTester {
+
+ test("PMML export should work as expected for DecisionTree model with
regressor") {
+ // Export a 3-node regression tree and verify the generated PMML document.
+ // instantiate an MLlib DecisionTreeModel for regression with 3
nodes (one root and two leaves) split on a continuous
+ // feature type
+ val mlLeftNode = new Node(2, new Predict(0.5, 0.5), 0.2, true, None,
None, None, None)
+ val mlRightNode = new Node(3, new Predict(1.0, 0.5), 0.2, true, None,
None, None, None)
+ val split = new Split(100, 10.00, FeatureType.Continuous, Nil) // checked below as field_100 with threshold "10.0"
+ val mlTopNode = new Node(1, new Predict(0.0, 0.1), 0.2, false,
+ Some(split), Some(mlLeftNode), Some(mlRightNode), None)
+
+ val decisionTreeModel = new DecisionTreeModel(mlTopNode,
Algo.Regression)
+
+ // get the PMML exporter for the DT and verify it is the right exporter
+ val pmmlExporterForDT =
PMMLModelExportFactory.createPMMLModelExport(decisionTreeModel)
+ assert(pmmlExporterForDT.isInstanceOf[DecisionTreePMMLModelExport])
+
+ // get the PMML wrapper object for the DT and verify the inner model is of
type TreeModel
+ // and basic fields are populated as expected
+ val pmmlWrapperForDT = pmmlExporterForDT.getPmml
+ assert(pmmlWrapperForDT.getHeader.getDescription == "decision tree")
+ assert(!pmmlWrapperForDT.getModels.isEmpty)
+ assert(pmmlWrapperForDT.getModels.size() == 1)
+ val pmmlModelForDT = pmmlWrapperForDT.getModels.get(0)
+ assert(pmmlModelForDT.isInstanceOf[TreeModel])
+
+ // validate the inner tree model is a REGRESSION tree model
+ val pmmlTreeModel = pmmlModelForDT.asInstanceOf[TreeModel]
+ assert(pmmlTreeModel.getFunctionName == MiningFunctionType.REGRESSION)
+
+ // validate the root PMML node has id "1" and exactly two children
+ val pmmlRootNode = pmmlTreeModel.getNode
+ assert(pmmlRootNode != null)
+ assert(pmmlRootNode.getNodes != null && pmmlRootNode.getNodes.size()
== 2)
+ assert(pmmlRootNode.getId === "1")
+ // validate the root node carries PMML's <True/> predicate
+ val predicate = pmmlRootNode.getPredicate()
+ assert(predicate != null)
+ assert(predicate.isInstanceOf[True])
+
+ // validate the left leaf: id "2", score "0.5", predicate field_100 <= 10.0
+ val pmmlLeftNode = pmmlRootNode.getNodes.get(0)
+ assert(pmmlLeftNode != null)
+ assert(!pmmlLeftNode.hasNodes)
+ assert(pmmlLeftNode.getId === "2")
+ assert(pmmlLeftNode.getScore == "0.5")
+ val predicate1 = pmmlLeftNode.getPredicate
+ assert(predicate1 != null)
+ assert(predicate1.isInstanceOf[SimplePredicate])
+ assert(predicate1.asInstanceOf[SimplePredicate].getField.getValue ===
"field_100")
+ assert(predicate1.asInstanceOf[SimplePredicate].getValue === "10.0")
+ assert(predicate1.asInstanceOf[SimplePredicate].getOperator ==
SimplePredicate.Operator
+ .LESS_OR_EQUAL)
+
+ // validate the right leaf: id "3", score "1.0", predicate field_100 > 10.0
+ val pmmlRightNode = pmmlRootNode.getNodes.get(1)
+ assert(pmmlRightNode != null)
+ assert(!pmmlRightNode.hasNodes)
+ assert(pmmlRightNode.getId === "3")
+ assert(pmmlRightNode.getScore == "1.0")
+
+ val predicate2 = pmmlRightNode.getPredicate
+ assert(predicate2 != null)
+ assert(predicate2.isInstanceOf[SimplePredicate])
+ assert(predicate2.asInstanceOf[SimplePredicate].getField.getValue ===
"field_100")
+ assert(predicate2.asInstanceOf[SimplePredicate].getValue === "10.0")
+ assert(predicate2.asInstanceOf[SimplePredicate].getOperator ==
SimplePredicate.Operator
+ .GREATER_THAN)
+
+ // validate the mining schema lists the split feature and then the target
+ assert(pmmlModelForDT.getMiningSchema != null)
+ val miningSchema = pmmlModelForDT.getMiningSchema
+ assert(miningSchema.getMiningFields != null &&
miningSchema.getMiningFields.size() == 2)
+ val miningFields = miningSchema.getMiningFields
+ assert(miningFields.get(0).getName.getValue == "field_100")
+ assert(miningFields.get(1).getName.getValue == "target")
+
+ // validate the data dictionary: feature and target fields, both continuous
+ val dataDictionary = pmmlWrapperForDT.getDataDictionary
+ assert(dataDictionary != null)
+ val dataFields = dataDictionary.getDataFields
+ assert(dataFields != null && dataFields.size() == 2)
+ assert(dataFields.get(0).getName.getValue == "field_100")
+ assert(dataFields.get(0).getOpType == OpType.CONTINUOUS)
+ assert(dataFields.get(1).getName.getValue == "target")
+ assert(dataFields.get(1).getOpType == OpType.CONTINUOUS)
+ }
+
+ test("PMML export should work as expected for DecisionTree model with
classifier") {
+
+ // instantiate MLLIb DecisionTreeModel with Classification algo ,5
nodes, 2 levels
+ val mlLeftNode_L2 = new Node(4, new Predict(1.0, 0.5), 0.2, true,
None, None, None, None)
+ val mlRightNode_L2 = new Node(5, new Predict(2.0, 0.5), 0.2, true,
None, None, None, None)
+ val splitForL2 = new Split(100, 10.00, FeatureType.Categorical,
List(1, 4))
+ val mlLeftNode_L1 = new Node(2, new Predict(3.0, 0.5), 0.2, false,
+ Some(splitForL2), Some(mlLeftNode_L2), Some(mlRightNode_L2), None)
+ val mlRightNode_L1 = new Node(3, new Predict(4.0, 0.5), 0.2, true,
None, None, None, None)
+ val split = new Split(200, 10.00, FeatureType.Categorical, List(10,
20))
+ val mlTopNode = new Node(1, new Predict(5.0, 0.1), 0.2, false,
Some(split),
+ Some(mlLeftNode_L1), Some(mlRightNode_L1), None)
+ val decisionTreeModel = new DecisionTreeModel(mlTopNode,
Algo.Classification)
+
+ // get the pmml exporter for the DT and verify its the right exporter
+ val pmmlExporterForDT =
PMMLModelExportFactory.createPMMLModelExport(decisionTreeModel)
+ assert(pmmlExporterForDT.isInstanceOf[DecisionTreePMMLModelExport])
+
+ // get the pmmlwrapper object for DT and verify the inner model is of
type TreeModel
+ // and basic fields are populated as expected
+ val pmmlWrapperForDT = pmmlExporterForDT.getPmml
+ assert(pmmlWrapperForDT.getHeader.getDescription == "decision tree")
+ assert(!pmmlWrapperForDT.getModels.isEmpty)
+ assert(pmmlWrapperForDT.getModels.size() == 1)
+
+ // validate the inner tree model fields are populated as expected
+ val pmmlModelForDT = pmmlWrapperForDT.getModels.get(0)
+ assert(pmmlModelForDT.isInstanceOf[TreeModel])
+ val pmmlTreeModel = pmmlModelForDT.asInstanceOf[TreeModel]
+ assert(pmmlTreeModel.getFunctionName ==
MiningFunctionType.CLASSIFICATION)
+
+ // validate the pmml root node fields are populated as expected
+ val pmmlRootNode = pmmlTreeModel.getNode
+ assert(pmmlRootNode != null)
+ assert(pmmlRootNode.getNodes != null && pmmlRootNode.getNodes.size()
== 2)
+ assert(pmmlRootNode.getId === "1")
+
+ // validate the pmml root node predicate is a true predicate since its
root node
+
+ val predicate = pmmlRootNode.getPredicate()
+ assert(predicate != null)
+ assert(predicate.isInstanceOf[True])
+
+ // validate level 1 left node is populated properly
+ val pmmlLeftNode_L1 = pmmlRootNode.getNodes.get(0)
+ assert(pmmlLeftNode_L1 != null)
+ assert(pmmlLeftNode_L1.hasNodes)
+ assert(pmmlLeftNode_L1.getId === "2")
+ assert(pmmlLeftNode_L1.getScore == "3.0")
+ // left node to the root node should have compound predicate, since
its condition is on multiple
+ // categories
+ val predicateL1 = pmmlLeftNode_L1.getPredicate
+ assert(predicateL1 != null)
+ assert(predicateL1.isInstanceOf[CompoundPredicate])
+ val cPredicate1 = predicateL1.asInstanceOf[CompoundPredicate]
+ assert(cPredicate1.getBooleanOperator == BooleanOperator.OR)
+ assert(cPredicate1.getPredicates != null &&
cPredicate1.getPredicates.size() == 2)
+ val predicatesList1 = cPredicate1.getPredicates
+ assert(predicatesList1.get(0).isInstanceOf[SimplePredicate])
+
assert(predicatesList1.get(0).asInstanceOf[SimplePredicate].getField.getValue
=== "field_200")
+ assert(predicatesList1.get(0).asInstanceOf[SimplePredicate].getValue
=== "10.0")
+
assert(predicatesList1.get(0).asInstanceOf[SimplePredicate].getOperator ==
SimplePredicate
+ .Operator.EQUAL)
+
+ assert(predicatesList1.get(1).isInstanceOf[SimplePredicate])
+
assert(predicatesList1.get(1).asInstanceOf[SimplePredicate].getField.getValue
=== "field_200")
+ assert(predicatesList1.get(1).asInstanceOf[SimplePredicate].getValue
=== "20.0")
+
assert(predicatesList1.get(1).asInstanceOf[SimplePredicate].getOperator ==
SimplePredicate
+ .Operator.EQUAL)
+
+ // validate level 1 right node is populated properly
+ val pmmlRightNode_L1 = pmmlRootNode.getNodes.get(1)
+ assert(pmmlRightNode_L1 != null)
+ assert(!pmmlRightNode_L1.hasNodes)
+ assert(pmmlRightNode_L1.getId === "3")
+ assert(pmmlRightNode_L1.getScore == "4.0")
+ // right node at level 1 should have True Predicate since the left
node is the list of
+ // categories predicate
+ val predicateR1 = pmmlRightNode_L1.getPredicate
+ assert(predicateR1 != null)
+ assert(predicateR1.isInstanceOf[True])
+
--- End diff --
Too many empty lines here; please remove the extra blank lines.
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to use it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]