Close #70: [HIVEMALL-75-2] Add tree_export UDF and update RandomForest tutorial
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/9876d063
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/9876d063
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/9876d063

Branch: refs/heads/master
Commit: 9876d06316ad6e4ef5b62511c0806d1c3d8c03ce
Parents: 9f01ebf
Author: Makoto Yui <[email protected]>
Authored: Fri Jun 30 21:15:54 2017 +0900
Committer: Makoto Yui <[email protected]>
Committed: Fri Jun 30 21:15:54 2017 +0900

----------------------------------------------------------------------
 .../smile/classification/DecisionTree.java      | 145 ++++--
 .../GradientTreeBoostingClassifierUDTF.java     |   2 +-
 .../RandomForestClassifierUDTF.java             |   2 +-
 .../regression/RandomForestRegressionUDTF.java  |   2 +-
 .../smile/regression/RegressionTree.java        | 124 +++++--
 .../hivemall/smile/tools/TreeExportUDF.java     | 241 +++++++++++++
 .../hivemall/smile/tools/TreePredictUDF.java    |   4 +-
 .../hivemall/smile/utils/SmileExtUtils.java     |  41 +++
 .../java/hivemall/utils/hadoop/HiveUtils.java   |  19 ++
 .../smile/classification/DecisionTreeTest.java  |  79 ++++-
 .../RandomForestClassifierUDTFTest.java         |   4 +-
 .../smile/regression/RegressionTreeTest.java    |  60 +++-
 .../smile/tools/TreePredictUDFTest.java         |   4 +-
 docs/gitbook/SUMMARY.md                         |   3 +-
 docs/gitbook/binaryclass/news20_rf.md           |  90 +++++
 docs/gitbook/binaryclass/titanic_rf.md          |  93 +++++-
 docs/gitbook/ft_engineering/hashing.md          |  45 ++-
 docs/gitbook/multiclass/iris_dataset.md         |  65 +---
 docs/gitbook/multiclass/iris_randomforest.md    | 259 ++++++++------
 docs/gitbook/multiclass/iris_scw.md             | 334 +++----------------
 docs/gitbook/resources/images/iris.png          | Bin 0 -> 92872 bytes
 .../ddl/define-all-as-permanent.deprecated.hive |   6 -
 resources/ddl/define-all-as-permanent.hive      |   3 +
 resources/ddl/define-all.deprecated.hive        |   6 -
 resources/ddl/define-all.hive                   |   3 +
 resources/ddl/define-all.spark                  |   3 +
 resources/ddl/define-udfs.td.hql                |   1 +
 27 files changed, 1076 insertions(+), 562 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/classification/DecisionTree.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
index 2d086b9..fa97dba 100644
--- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java
+++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
@@ -33,6 +33,8 @@
  */
 package hivemall.smile.classification;
 
+import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName;
+import static hivemall.smile.utils.SmileExtUtils.resolveName;
 import hivemall.annotations.VisibleForTesting;
 import hivemall.math.matrix.Matrix;
 import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
@@ -47,6 +49,7 @@ import hivemall.smile.data.Attribute.AttributeType;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.utils.collections.lists.IntArrayList;
 import hivemall.utils.lang.ObjectUtils;
+import hivemall.utils.lang.mutable.MutableInt;
 import hivemall.utils.sampling.IntReservoirSampler;
 
 import java.io.Externalizable;
@@ -292,41 +295,114 @@ public final class DecisionTree implements Classifier<Vector> {
         }
     }
 
-    public void jsCodegen(@Nonnull final StringBuilder builder, final int depth) {
+    public void exportJavascript(@Nonnull
final StringBuilder builder, + @Nullable final String[] featureNames, @Nullable final String[] classNames, + final int depth) { if (trueChild == null && falseChild == null) { indent(builder, depth); - builder.append("").append(output).append(";\n"); + builder.append("").append(resolveName(output, classNames)).append(";\n"); } else { + indent(builder, depth); if (splitFeatureType == AttributeType.NOMINAL) { - indent(builder, depth); - builder.append("if(x[") - .append(splitFeature) - .append("] == ") - .append(splitValue) - .append(") {\n"); - trueChild.jsCodegen(builder, depth + 1); - indent(builder, depth); - builder.append("} else {\n"); - falseChild.jsCodegen(builder, depth + 1); - indent(builder, depth); - builder.append("}\n"); + if (featureNames == null) { + builder.append("if( x[") + .append(splitFeature) + .append("] == ") + .append(splitValue) + .append(" ) {\n"); + } else { + builder.append("if( ") + .append(resolveFeatureName(splitFeature, featureNames)) + .append(" == ") + .append(splitValue) + .append(" ) {\n"); + } } else if (splitFeatureType == AttributeType.NUMERIC) { - indent(builder, depth); - builder.append("if(x[") - .append(splitFeature) - .append("] <= ") - .append(splitValue) - .append(") {\n"); - trueChild.jsCodegen(builder, depth + 1); - indent(builder, depth); - builder.append("} else {\n"); - falseChild.jsCodegen(builder, depth + 1); - indent(builder, depth); - builder.append("}\n"); + if (featureNames == null) { + builder.append("if( x[") + .append(splitFeature) + .append("] <= ") + .append(splitValue) + .append(" ) {\n"); + } else { + builder.append("if( ") + .append(resolveFeatureName(splitFeature, featureNames)) + .append(" <= ") + .append(splitValue) + .append(" ) {\n"); + } } else { throw new IllegalStateException("Unsupported attribute type: " + splitFeatureType); } + trueChild.exportJavascript(builder, featureNames, classNames, depth + 1); + indent(builder, depth); + builder.append("} else {\n"); + falseChild.exportJavascript(builder, featureNames, classNames, depth + 1); + indent(builder, depth); + builder.append("}\n"); + } + } + + public void exportGraphviz(@Nonnull final StringBuilder builder, + @Nullable final String[] featureNames, @Nullable final String[] classNames, + @Nonnull final String outputName, @Nullable double[] colorBrew, + final @Nonnull MutableInt nodeIdGenerator, final int parentNodeId) { + final int myNodeId = nodeIdGenerator.getValue(); + + if (trueChild == null && falseChild == null) { + // fillcolor=h,s,v + // https://en.wikipedia.org/wiki/HSL_and_HSV + // http://www.graphviz.org/doc/info/attrs.html#k:colorList + String hsvColor = (colorBrew == null || output >= colorBrew.length) ? 
"#00000000" + : String.format("%.4f,1.000,1.000", colorBrew[output]); + builder.append(String.format( + " %d [label=<%s = %s>, fillcolor=\"%s\", shape=ellipse];\n", myNodeId, + outputName, resolveName(output, classNames), hsvColor)); + + if (myNodeId != parentNodeId) { + builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId); + if (parentNodeId == 0) { + if (myNodeId == 1) { + builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]"); + } else { + builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]"); + } + } + builder.append(";\n"); + } + } else { + if (splitFeatureType == AttributeType.NOMINAL) { + builder.append(String.format( + " %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n", myNodeId, + resolveFeatureName(splitFeature, featureNames), Double.toString(splitValue))); + } else if (splitFeatureType == AttributeType.NUMERIC) { + builder.append(String.format( + " %d [label=<%s ≤ %s>, fillcolor=\"#00000000\"];\n", myNodeId, + resolveFeatureName(splitFeature, featureNames), Double.toString(splitValue))); + } else { + throw new IllegalStateException("Unsupported attribute type: " + + splitFeatureType); + } + + if (myNodeId != parentNodeId) { + builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId); + if (parentNodeId == 0) {//only draw edge label on top + if (myNodeId == 1) { + builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]"); + } else { + builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]"); + } + } + builder.append(";\n"); + } + + nodeIdGenerator.addValue(1); + trueChild.exportGraphviz(builder, featureNames, classNames, outputName, colorBrew, + nodeIdGenerator, myNodeId); + nodeIdGenerator.addValue(1); + falseChild.exportGraphviz(builder, featureNames, classNames, outputName, colorBrew, + nodeIdGenerator, myNodeId); } } @@ -910,6 +986,11 @@ public final class DecisionTree implements Classifier<Vector> { } } + @VisibleForTesting + Node getRootNode() { + return _root; + } + private static void checkArgument(@Nonnull Matrix x, @Nonnull int[] y, int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize) { if (x.numRows() != y.length) { @@ -965,14 +1046,15 @@ public final class DecisionTree implements Classifier<Vector> { throw new UnsupportedOperationException("Not supported."); } - public String predictJsCodegen() { + public String predictJsCodegen(@Nonnull final String[] featureNames, + @Nonnull final String[] classNames) { StringBuilder buf = new StringBuilder(1024); - _root.jsCodegen(buf, 0); + _root.exportJavascript(buf, featureNames, classNames, 0); return buf.toString(); } @Nonnull - public byte[] predictSerCodegen(boolean compress) throws HiveException { + public byte[] serialize(boolean compress) throws HiveException { try { if (compress) { return ObjectUtils.toCompressedBytes(_root); @@ -986,7 +1068,8 @@ public final class DecisionTree implements Classifier<Vector> { } } - public static Node deserializeNode(final byte[] serializedObj, final int length, + @Nonnull + public static Node deserialize(@Nonnull final byte[] serializedObj, final int length, final boolean compressed) throws HiveException { final Node root = new Node(); try { @@ -1006,7 +1089,7 @@ public final class DecisionTree implements Classifier<Vector> { @Override public String toString() { - return _root == null ? "" : predictJsCodegen(); + return _root == null ? 
"" : predictJsCodegen(null, null); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java index a380a11..adb405f 100644 --- a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java +++ b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java @@ -579,7 +579,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { final int m = trees.length; final Text[] models = new Text[m]; for (int i = 0; i < m; i++) { - byte[] b = trees[i].predictSerCodegen(true); + byte[] b = trees[i].serialize(true); b = Base91.encode(b); models[i] = new Text(b); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java index 5a831df..59f52d3 100644 --- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java +++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java @@ -603,7 +603,7 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { @Nonnull private static Text getModel(@Nonnull final DecisionTree tree) throws HiveException { - byte[] b = tree.predictSerCodegen(true); + byte[] b = tree.serialize(true); b = Base91.encode(b); return new Text(b); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java index 557df21..58151e4 100644 --- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java +++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java @@ -499,7 +499,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { @Nonnull private static Text getModel(@Nonnull final RegressionTree tree) throws HiveException { - byte[] b = tree.predictSerCodegen(true); + byte[] b = tree.serialize(true); b = Base91.encode(b); return new Text(b); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/regression/RegressionTree.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java index da7e80b..81b9ba8 100755 --- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java +++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java @@ -33,6 +33,8 @@ */ package hivemall.smile.regression; +import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName; +import static hivemall.smile.utils.SmileExtUtils.resolveName; import 
hivemall.annotations.VisibleForTesting; import hivemall.math.matrix.Matrix; import hivemall.math.matrix.ints.ColumnMajorIntMatrix; @@ -48,6 +50,7 @@ import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.collections.sets.IntArraySet; import hivemall.utils.collections.sets.IntSet; import hivemall.utils.lang.ObjectUtils; +import hivemall.utils.lang.mutable.MutableInt; import hivemall.utils.math.MathUtils; import java.io.Externalizable; @@ -246,35 +249,52 @@ public final class RegressionTree implements Regression<Vector> { } } - public void jsCodegen(@Nonnull final StringBuilder builder, final int depth) { + public void exportJavascript(@Nonnull final StringBuilder builder, + @Nullable final String[] featureNames, final int depth) { if (trueChild == null && falseChild == null) { indent(builder, depth); - builder.append("").append(output).append(";\n"); + builder.append(output).append(";\n"); } else { if (splitFeatureType == AttributeType.NOMINAL) { indent(builder, depth); - builder.append("if(x[") - .append(splitFeature) - .append("] == ") - .append(splitValue) - .append(") {\n"); - trueChild.jsCodegen(builder, depth + 1); + if (featureNames == null) { + builder.append("if( x[") + .append(splitFeature) + .append("] == ") + .append(splitValue) + .append(") {\n"); + } else { + builder.append("if( ") + .append(resolveFeatureName(splitFeature, featureNames)) + .append(" == ") + .append(splitValue) + .append(") {\n"); + } + trueChild.exportJavascript(builder, featureNames, depth + 1); indent(builder, depth); builder.append("} else {\n"); - falseChild.jsCodegen(builder, depth + 1); + falseChild.exportJavascript(builder, featureNames, depth + 1); indent(builder, depth); builder.append("}\n"); } else if (splitFeatureType == AttributeType.NUMERIC) { indent(builder, depth); - builder.append("if(x[") - .append(splitFeature) - .append("] <= ") - .append(splitValue) - .append(") {\n"); - trueChild.jsCodegen(builder, depth + 1); + if (featureNames == null) { + builder.append("if( x[") + .append(splitFeature) + .append("] <= ") + .append(splitValue) + .append(") {\n"); + } else { + builder.append("if( ") + .append(resolveFeatureName(splitFeature, featureNames)) + .append(" <= ") + .append(splitValue) + .append(") {\n"); + } + trueChild.exportJavascript(builder, featureNames, depth + 1); indent(builder, depth); - builder.append("} else {\n"); - falseChild.jsCodegen(builder, depth + 1); + builder.append("} else {\n"); + falseChild.exportJavascript(builder, featureNames, depth + 1); indent(builder, depth); builder.append("}\n"); } else { @@ -284,6 +304,63 @@ public final class RegressionTree implements Regression<Vector> { } } + public void exportGraphviz(@Nonnull final StringBuilder builder, + @Nullable final String[] featureNames, @Nonnull final String outputName, + final @Nonnull MutableInt nodeIdGenerator, final int parentNodeId) { + final int myNodeId = nodeIdGenerator.getValue(); + + if (trueChild == null && falseChild == null) { + builder.append(String.format( + " %d [label=<%s = %s>, fillcolor=\"#00000000\", shape=ellipse];\n", myNodeId, + outputName, Double.toString(output))); + + if (myNodeId != parentNodeId) { + builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId); + if (parentNodeId == 0) { + if (myNodeId == 1) { + builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]"); + } else { + builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]"); + } + } + builder.append(";\n"); + } + } else { + if (splitFeatureType == 
AttributeType.NOMINAL) { + builder.append(String.format( + " %d [label=<%s = %s>, fillcolor=\"#00000000\"];\n", myNodeId, + resolveFeatureName(splitFeature, featureNames), Double.toString(splitValue))); + } else if (splitFeatureType == AttributeType.NUMERIC) { + builder.append(String.format( + " %d [label=<%s ≤ %s>, fillcolor=\"#00000000\"];\n", myNodeId, + resolveFeatureName(splitFeature, featureNames), Double.toString(splitValue))); + } else { + throw new IllegalStateException("Unsupported attribute type: " + + splitFeatureType); + } + + if (myNodeId != parentNodeId) { + builder.append(' ').append(parentNodeId).append(" -> ").append(myNodeId); + if (parentNodeId == 0) {//only draw edge label on top + if (myNodeId == 1) { + builder.append(" [labeldistance=2.5, labelangle=45, headlabel=\"True\"]"); + } else { + builder.append(" [labeldistance=2.5, labelangle=-45, headlabel=\"False\"]"); + } + } + builder.append(";\n"); + } + + nodeIdGenerator.addValue(1); + trueChild.exportGraphviz(builder, featureNames, outputName, nodeIdGenerator, + myNodeId); + nodeIdGenerator.addValue(1); + falseChild.exportGraphviz(builder, featureNames, outputName, nodeIdGenerator, + myNodeId); + } + } + + @Override public void writeExternal(ObjectOutput out) throws IOException { out.writeInt(splitFeature); @@ -837,14 +914,14 @@ public final class RegressionTree implements Regression<Vector> { return _root.predict(x); } - public String predictJsCodegen() { + public String predictJsCodegen(@Nonnull final String[] featureNames) { StringBuilder buf = new StringBuilder(1024); - _root.jsCodegen(buf, 0); + _root.exportJavascript(buf, featureNames, 0); return buf.toString(); } @Nonnull - public byte[] predictSerCodegen(boolean compress) throws HiveException { + public byte[] serialize(boolean compress) throws HiveException { try { if (compress) { return ObjectUtils.toCompressedBytes(_root); @@ -858,7 +935,8 @@ public final class RegressionTree implements Regression<Vector> { } } - public static Node deserializeNode(final byte[] serializedObj, final int length, + @Nonnull + public static Node deserialize(@Nonnull final byte[] serializedObj, final int length, final boolean compressed) throws HiveException { final Node root = new Node(); try { @@ -876,4 +954,8 @@ public final class RegressionTree implements Regression<Vector> { return root; } + @Override + public String toString() { + return _root == null ? "" : predictJsCodegen(null); + } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java b/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java new file mode 100644 index 0000000..7d509ad --- /dev/null +++ b/core/src/main/java/hivemall/smile/tools/TreeExportUDF.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.smile.tools; + +import hivemall.UDFWithOptions; +import hivemall.smile.classification.DecisionTree; +import hivemall.smile.regression.RegressionTree; +import hivemall.smile.utils.SmileExtUtils; +import hivemall.utils.codec.Base91; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.mutable.MutableInt; + +import java.util.Arrays; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.UDFType; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; +import org.apache.hadoop.io.Text; + +@Description( + name = "tree_export", + value = "_FUNC_(string model, const string options, optional array<string> featureNames=null, optional array<string> classNames=null)" + + " - exports a Decision Tree model as javascript/dot]") +@UDFType(deterministic = true, stateful = false) +public final class TreeExportUDF extends UDFWithOptions { + + private transient Evaluator evaluator; + + private transient StringObjectInspector modelOI; + @Nullable + private transient ListObjectInspector featureNamesOI; + @Nullable + private transient ListObjectInspector classNamesOI; + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("t", "type", true, + "Type of output [default: js, javascript/js, graphvis/dot"); + opts.addOption("r", "regression", false, "Is regression tree or not"); + opts.addOption("output_name", "outputName", true, "output name [default: predicted]"); + return opts; + } + + @Override + protected CommandLine processOptions(@Nonnull String opts) throws UDFArgumentException { + CommandLine cl = parseOptions(opts); + + OutputType outputType = OutputType.resolve(cl.getOptionValue("type")); + boolean regression = cl.hasOption("regression"); + String outputName = cl.getOptionValue("output_name", "predicted"); + this.evaluator = new Evaluator(outputType, outputName, regression); + + return cl; + } + + @Override + public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + final int argLen = argOIs.length; + if (argLen < 2 || argLen > 4) { + throw new UDFArgumentException("_FUNC_ takes 2~4 arguments: " + argLen); + } + + this.modelOI = HiveUtils.asStringOI(argOIs[0]); + + String options = HiveUtils.getConstString(argOIs[1]); + processOptions(options); + + if (argLen >= 3) { + this.featureNamesOI = HiveUtils.asListOI(argOIs[2]); + if (!HiveUtils.isStringOI(featureNamesOI.getListElementObjectInspector())) { + throw new UDFArgumentException("_FUNC_ expected array<string> for featureNames: " + + 
featureNamesOI.getTypeName()); + } + if (argLen == 4) { + this.classNamesOI = HiveUtils.asListOI(argOIs[3]); + if (!HiveUtils.isStringOI(classNamesOI.getListElementObjectInspector())) { + throw new UDFArgumentException("_FUNC_ expected array<string> for classNames: " + + classNamesOI.getTypeName()); + } + } + } + + return PrimitiveObjectInspectorFactory.writableStringObjectInspector; + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + Object arg0 = arguments[0].get(); + if (arg0 == null) { + return null; + } + Text model = modelOI.getPrimitiveWritableObject(arg0); + + String[] featureNames = null, classNames = null; + if (arguments.length >= 3) { + featureNames = HiveUtils.asStringArray(arguments[2], featureNamesOI); + if (arguments.length >= 4) { + classNames = HiveUtils.asStringArray(arguments[3], classNamesOI); + } + } + + try { + return evaluator.export(model, featureNames, classNames); + } catch (HiveException he) { + throw he; + } catch (Throwable e) { + throw new HiveException(e); + } + } + + @Override + public String getDisplayString(String[] children) { + return "tree_export(" + Arrays.toString(children) + ")"; + } + + public enum OutputType { + javascript, graphvis; + + @Nonnull + public static OutputType resolve(@Nonnull String name) throws UDFArgumentException { + if ("js".equalsIgnoreCase(name) || "javascript".equalsIgnoreCase(name)) { + return javascript; + } else if ("dot".equalsIgnoreCase(name) || "graphvis".equalsIgnoreCase(name)) { + return graphvis; + } else { + throw new UDFArgumentException( + "Please provide a valid `-type` option from [javascript, graphvis]: " + name); + } + } + } + + public static class Evaluator { + + @Nonnull + private final OutputType outputType; + @Nonnull + private final String outputName; + private final boolean regression; + + public Evaluator(@Nonnull OutputType outputType, @Nonnull String outputName, + boolean regression) { + this.outputType = outputType; + this.outputName = outputName; + this.regression = regression; + } + + @Nonnull + public Text export(@Nonnull Text model, @Nullable String[] featureNames, + @Nullable String[] classNames) throws HiveException { + int length = model.getLength(); + byte[] b = model.getBytes(); + b = Base91.decode(b, 0, length); + + final String exported; + if (regression) { + exported = exportRegressor(b, featureNames); + } else { + exported = exportClassifier(b, featureNames, classNames); + } + return new Text(exported); + } + + @Nonnull + private String exportClassifier(@Nonnull byte[] b, @Nullable String[] featureNames, + @Nullable String[] classNames) throws HiveException { + final DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true); + + final StringBuilder buf = new StringBuilder(8192); + switch (outputType) { + case javascript: { + node.exportJavascript(buf, featureNames, classNames, 0); + break; + } + case graphvis: { + buf.append("digraph Tree {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n"); + double[] colorBrew = (classNames == null) ? 
null + : SmileExtUtils.getColorBrew(classNames.length); + node.exportGraphviz(buf, featureNames, classNames, outputName, colorBrew, + new MutableInt(0), 0); + buf.append("}"); + break; + } + default: + throw new HiveException("Unsupported outputType: " + outputType); + } + return buf.toString(); + } + + @Nonnull + private String exportRegressor(@Nonnull byte[] b, @Nullable String[] featureNames) + throws HiveException { + final RegressionTree.Node node = RegressionTree.deserialize(b, b.length, true); + + final StringBuilder buf = new StringBuilder(8192); + switch (outputType) { + case javascript: { + node.exportJavascript(buf, featureNames, 0); + break; + } + case graphvis: { + buf.append("digraph Tree {\n node [shape=box, style=\"filled, rounded\", color=\"black\", fontname=helvetica];\n edge [fontname=helvetica];\n"); + node.exportGraphviz(buf, featureNames, outputName, new MutableInt(0), 0); + buf.append("}"); + break; + } + default: + throw new HiveException("Unsupported outputType: " + outputType); + } + return buf.toString(); + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java index dc544ae..46b8758 100644 --- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java +++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java @@ -249,7 +249,7 @@ public final class TreePredictUDF extends GenericUDF { int length = script.getLength(); byte[] b = script.getBytes(); b = Base91.decode(b, 0, length); - this.cNode = DecisionTree.deserializeNode(b, b.length, true); + this.cNode = DecisionTree.deserialize(b, b.length, true); } Arrays.fill(result, null); @@ -287,7 +287,7 @@ public final class TreePredictUDF extends GenericUDF { int length = script.getLength(); byte[] b = script.getBytes(); b = Base91.decode(b, 0, length); - this.rNode = RegressionTree.deserializeNode(b, b.length, true); + this.rNode = RegressionTree.deserialize(b, b.length, true); } Preconditions.checkNotNull(rNode); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java index 74a3032..5e27e12 100644 --- a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java +++ b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java @@ -33,11 +33,13 @@ import hivemall.smile.data.Attribute.NominalAttribute; import hivemall.smile.data.Attribute.NumericAttribute; import hivemall.utils.collections.lists.DoubleArrayList; import hivemall.utils.collections.lists.IntArrayList; +import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.mutable.MutableInt; import hivemall.utils.math.MathUtils; import java.util.Arrays; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -394,4 +396,43 @@ public final class SmileExtUtils { return false; } + @Nonnull + public static String resolveFeatureName(final int index, @Nullable final String[] names) { + if (names == null) { + return "feature#" + index; + } + if (index >= names.length) { + return "feature#" + index; + } + return names[index]; + } + + @Nonnull + 
public static String resolveName(final int index, @Nullable final String[] names) { + if (names == null) { + return String.valueOf(index); + } + if (index >= names.length) { + return String.valueOf(index); + } + return names[index]; + } + + /** + * Generates an evenly distributed range of hue values in the HSV color scale. + * + * @return colors + */ + public static double[] getColorBrew(@Nonnegative int n) { + Preconditions.checkArgument(n >= 1); + + final double hue_step = 360.d / n; + + final double[] colors = new double[n]; + for (int i = 0; i < n; i++) { + colors[i] = i * hue_step / 360.d; + } + return colors; + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 6c1b0d1..4ed1f12 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -187,6 +187,25 @@ public final class HiveUtils { return Arrays.asList(ary); } + @Nullable + public static String[] asStringArray(@Nonnull final DeferredObject arg, + @Nonnull final ListObjectInspector listOI) throws HiveException { + Object argObj = arg.get(); + if (argObj == null) { + return null; + } + List<?> data = listOI.getList(argObj); + final int size = data.size(); + final String[] arr = new String[size]; + for (int i = 0; i < size; i++) { + Object o = data.get(i); + if (o != null) { + arr[i] = o.toString(); + } + } + return arr; + } + @Nonnull public static StructObjectInspector asStructOI(@Nonnull final ObjectInspector oi) throws UDFArgumentException { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java index bb6de6b..897da0c 100644 --- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java +++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java @@ -25,7 +25,10 @@ import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; import hivemall.math.random.RandomNumberGeneratorFactory; import hivemall.smile.classification.DecisionTree.Node; import hivemall.smile.data.Attribute; +import hivemall.smile.tools.TreeExportUDF.Evaluator; +import hivemall.smile.tools.TreeExportUDF.OutputType; import hivemall.smile.utils.SmileExtUtils; +import hivemall.utils.codec.Base91; import java.io.BufferedInputStream; import java.io.IOException; @@ -36,6 +39,7 @@ import java.text.ParseException; import javax.annotation.Nonnull; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.io.Text; import org.junit.Assert; import org.junit.Test; @@ -106,6 +110,71 @@ public class DecisionTreeTest { assertEquals(7, error); } + @Test + public void testGraphvisOutputIris() throws IOException, ParseException, HiveException { + String datasetUrl = "https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff"; + int responseIndex = 4; + int numLeafs = 4; + boolean dense = true; + String outputName = "class"; + String[] featureNames = new String[] {"sepallength", "sepalwidth", "petallength", + 
"petalwidth"}; + String[] classNames = new String[] {"setosa", "versicolor", "virginica"}; + + debugPrint(graphvisOutput(datasetUrl, responseIndex, numLeafs, dense, featureNames, + classNames, outputName)); + + featureNames = null; + classNames = null; + outputName = null; + debugPrint(graphvisOutput(datasetUrl, responseIndex, numLeafs, dense, featureNames, + classNames, outputName)); + } + + @Test + public void testGraphvisOutputWeather() throws IOException, ParseException, HiveException { + String datasetUrl = "https://gist.githubusercontent.com/myui/2c9df50db3de93a71b92/raw/3f6b4ecfd4045008059e1a2d1c4064fb8a3d372a/weather.nominal.arff"; + int responseIndex = 4; + int numLeafs = 3; + boolean dense = true; + String[] featureNames = new String[] {"outlook", "temperature", "humidity", "windy"}; + String[] classNames = new String[] {"yes", "no"}; + String outputName = "play"; + + debugPrint(graphvisOutput(datasetUrl, responseIndex, numLeafs, dense, featureNames, + classNames, outputName)); + + featureNames = null; + classNames = null; + debugPrint(graphvisOutput(datasetUrl, responseIndex, numLeafs, dense, featureNames, + classNames, outputName)); + } + + private static String graphvisOutput(String datasetUrl, int responseIndex, int numLeafs, + boolean dense, String[] featureNames, String[] classNames, String outputName) + throws IOException, HiveException, ParseException { + URL url = new URL(datasetUrl); + InputStream is = new BufferedInputStream(url.openStream()); + + ArffParser arffParser = new ArffParser(); + arffParser.setResponseIndex(responseIndex); + + AttributeDataset ds = arffParser.parse(is); + double[][] x = ds.toArray(new double[ds.size()][]); + int[] y = ds.toArray(new int[ds.size()]); + + Attribute[] attrs = SmileExtUtils.convertAttributeTypes(ds.attributes()); + DecisionTree tree = new DecisionTree(attrs, matrix(x, dense), y, numLeafs, + RandomNumberGeneratorFactory.createPRNG(31)); + + Text model = new Text(Base91.encode(tree.serialize(true))); + + Evaluator eval = new Evaluator(OutputType.graphvis, outputName, false); + Text exported = eval.export(model, featureNames, classNames); + + return exported.toString(); + } + private static int run(String datasetUrl, int responseIndex, int numLeafs, boolean dense) throws IOException, ParseException { URL url = new URL(datasetUrl); @@ -185,8 +254,8 @@ public class DecisionTreeTest { Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4); - byte[] b = tree.predictSerCodegen(false); - Node node = DecisionTree.deserializeNode(b, b.length, false); + byte[] b = tree.serialize(false); + Node node = DecisionTree.deserialize(b, b.length, false); assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]])); } } @@ -212,11 +281,11 @@ public class DecisionTreeTest { Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), trainy, 4); - byte[] b1 = tree.predictSerCodegen(true); - byte[] b2 = tree.predictSerCodegen(false); + byte[] b1 = tree.serialize(true); + byte[] b2 = tree.serialize(false); Assert.assertTrue("b1.length = " + b1.length + ", b2.length = " + b2.length, b1.length < b2.length); - Node node = DecisionTree.deserializeNode(b1, b1.length, true); + Node node = DecisionTree.deserialize(b1, b1.length, true); assertEquals(tree.predict(x[loocv.test[i]]), node.predict(x[loocv.test[i]])); } } 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java index d682093..578689c 100644 --- a/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java +++ b/core/src/test/java/hivemall/smile/classification/RandomForestClassifierUDTFTest.java @@ -208,7 +208,7 @@ public class RandomForestClassifierUDTFTest { Assert.assertNotNull(modelTxt); byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength()); - DecisionTree.Node node = DecisionTree.deserializeNode(b, b.length, true); + DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true); return node; } @@ -257,7 +257,7 @@ public class RandomForestClassifierUDTFTest { Assert.assertNotNull(modelTxt); byte[] b = Base91.decode(modelTxt.getBytes(), 0, modelTxt.getLength()); - DecisionTree.Node node = DecisionTree.deserializeNode(b, b.length, true); + DecisionTree.Node node = DecisionTree.deserialize(b, b.length, true); return node; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java index eae625d..f3eb5e5 100644 --- a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java +++ b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java @@ -24,12 +24,18 @@ import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; import hivemall.math.random.RandomNumberGeneratorFactory; import hivemall.smile.data.Attribute; import hivemall.smile.data.Attribute.NumericAttribute; +import hivemall.smile.tools.TreeExportUDF.Evaluator; +import hivemall.smile.tools.TreeExportUDF.OutputType; +import hivemall.utils.codec.Base91; +import java.io.IOException; +import java.text.ParseException; import java.util.Arrays; import javax.annotation.Nonnull; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.io.Text; import org.junit.Assert; import org.junit.Test; @@ -37,6 +43,7 @@ import smile.math.Math; import smile.validation.LOOCV; public class RegressionTreeTest { + private static final boolean DEBUG = false; @Test public void testPredictDense() { @@ -158,8 +165,8 @@ public class RegressionTreeTest { int maxLeafs = Integer.MAX_VALUE; RegressionTree tree = new RegressionTree(attrs, matrix(trainx, true), trainy, maxLeafs); - byte[] b = tree.predictSerCodegen(true); - RegressionTree.Node node = RegressionTree.deserializeNode(b, b.length, true); + byte[] b = tree.serialize(true); + RegressionTree.Node node = RegressionTree.deserialize(b, b.length, true); double expected = tree.predict(longley[loocv.test[i]]); double actual = node.predict(longley[loocv.test[i]]); @@ -168,6 +175,49 @@ public class RegressionTreeTest { } } + @Test + public void testGraphvizOutput() throws HiveException, IOException, ParseException { + int maxLeafts = 10; + String outputName = "predicted"; + + double[][] x = { {234.289, 235.6, 159.0, 107.608, 1947, 60.323}, + {259.426, 232.5, 145.6, 108.632, 1948, 61.122}, + {258.054, 368.2, 161.6, 109.773, 1949, 
60.171}, + {284.599, 335.1, 165.0, 110.929, 1950, 61.187}, + {328.975, 209.9, 309.9, 112.075, 1951, 63.221}, + {346.999, 193.2, 359.4, 113.270, 1952, 63.639}, + {365.385, 187.0, 354.7, 115.094, 1953, 64.989}, + {363.112, 357.8, 335.0, 116.219, 1954, 63.761}, + {397.469, 290.4, 304.8, 117.388, 1955, 66.019}, + {419.180, 282.2, 285.7, 118.734, 1956, 67.857}, + {442.769, 293.6, 279.8, 120.445, 1957, 68.169}, + {444.546, 468.1, 263.7, 121.950, 1958, 66.513}, + {482.704, 381.3, 255.2, 123.366, 1959, 68.655}, + {502.601, 393.1, 251.4, 125.368, 1960, 69.564}, + {518.173, 480.6, 257.2, 127.852, 1961, 69.331}, + {554.894, 400.7, 282.7, 130.081, 1962, 70.551}}; + + double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 104.6, 108.4, 110.8, + 112.6, 114.2, 115.7, 116.9}; + + debugPrint(graphvisOutput(x, y, maxLeafts, true, null, outputName)); + } + + private static String graphvisOutput(double[][] x, double[] y, int maxLeafts, boolean dense, + String[] featureNames, String outputName) throws IOException, HiveException, + ParseException { + Attribute[] attrs = new Attribute[x[0].length]; + Arrays.fill(attrs, new NumericAttribute()); + RegressionTree tree = new RegressionTree(attrs, matrix(x, dense), y, maxLeafts); + + Text model = new Text(Base91.encode(tree.serialize(true))); + + Evaluator eval = new Evaluator(OutputType.graphvis, outputName, true); + Text exported = eval.export(model, featureNames, null); + + return exported.toString(); + } + @Nonnull private static Matrix matrix(@Nonnull final double[][] x, boolean dense) { if (dense) { @@ -182,4 +232,10 @@ public class RegressionTreeTest { } } + private static void debugPrint(String msg) { + if (DEBUG) { + System.out.println(msg); + } + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java index 65feeeb..31713d9 100644 --- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java +++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java @@ -164,7 +164,7 @@ public class TreePredictUDFTest { } private static int evalPredict(DecisionTree tree, double[] x) throws HiveException, IOException { - byte[] b = tree.predictSerCodegen(true); + byte[] b = tree.serialize(true); byte[] encoded = Base91.encode(b); Text model = new Text(encoded); @@ -186,7 +186,7 @@ public class TreePredictUDFTest { private static double evalPredict(RegressionTree tree, double[] x) throws HiveException, IOException { - byte[] b = tree.predictSerCodegen(true); + byte[] b = tree.serialize(true); byte[] encoded = Base91.encode(b); Text model = new Text(encoded); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/SUMMARY.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md index 32b0150..cc7f622 100644 --- a/docs/gitbook/SUMMARY.md +++ b/docs/gitbook/SUMMARY.md @@ -92,6 +92,7 @@ * [Perceptron, Passive Aggressive](binaryclass/news20_pa.md) * [CW, AROW, SCW](binaryclass/news20_scw.md) * [AdaGradRDA, AdaGrad, AdaDelta](binaryclass/news20_adagrad.md) + * [Random Forest](binaryclass/news20_rf.md) * [KDD2010a tutorial](binaryclass/kdd2010a.md) * [Data preparation](binaryclass/kdd2010a_dataset.md) @@ -121,7 +122,7 @@ * [Iris 
tutorial](multiclass/iris.md)
     * [Data preparation](multiclass/iris_dataset.md)
     * [SCW](multiclass/iris_scw.md)
-    * [RandomForest](multiclass/iris_randomforest.md)
+    * [Random Forest](multiclass/iris_randomforest.md)
 
 ## Part VIII - Regression
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/binaryclass/news20_rf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/news20_rf.md b/docs/gitbook/binaryclass/news20_rf.md
new file mode 100644
index 0000000..fd0b475
--- /dev/null
+++ b/docs/gitbook/binaryclass/news20_rf.md
@@ -0,0 +1,90 @@
+<!--
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied.  See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+
+Hivemall Random Forest supports libsvm-like sparse inputs.
+
+> #### Note
+> Sparse input support in Random Forest is available in Hivemall v0.5-rc.1 or later.
+> The [`feature_hashing`](http://hivemall.incubator.apache.org/userguide/ft_engineering/hashing.html#featurehashing-function) function is useful to prepare feature vectors for Random Forest.
+
+<!-- toc -->
+
+## Training
+
+```sql
+drop table rf_model;
+create table rf_model
+as
+select
+  train_randomforest_classifier(
+    features,
+    convert_label(label), -- convert -1/1 to 0/1
+    '-trees 50 -seed 71'  -- hyperparameters
+  )
+from
+  train;
+```
+
+> #### Caution
+> label must be in `[0, k)` where `k` is the number of classes.
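+
+The queries on this page rely on a `convert_label` helper that maps the news20 `-1`/`1` labels to `0`/`1`. One possible definition (a sketch; only the name has to match the queries here, and any equivalent expression works) is:
+
+```sql
+-- hypothetical helper: maps -1/+1 labels to 0/1; other values pass through
+create temporary macro convert_label(label int)
+  if(label = -1, 0, label);
+```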
+
+## Prediction
+
+```sql
+SET hivevar:classification=true;
+
+drop table rf_predicted;
+create table rf_predicted
+as
+SELECT
+  rowid,
+  rf_ensemble(predicted.value, predicted.posteriori, model_weight) as predicted
+  -- rf_ensemble(predicted.value, predicted.posteriori) as predicted -- avoid OOB accuracy (i.e., model_weight)
+FROM (
+  SELECT
+    rowid,
+    m.model_weight,
+    tree_predict(m.model_id, m.model, t.features, ${classification}) as predicted
+  FROM
+    rf_model m
+    LEFT OUTER JOIN -- CROSS JOIN
+    test t
+) t1
+group by
+  rowid
+;
+```
+
+## Evaluation
+
+```sql
+WITH submit as (
+  select
+    convert_label(t.label) as actual,
+    p.predicted.label as predicted
+  from
+    test t
+    JOIN rf_predicted p on (t.rowid = p.rowid)
+)
+select count(1) / 4996.0
+from submit
+where actual = predicted;
+```
+
+> 0.8112489991993594


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/binaryclass/titanic_rf.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/binaryclass/titanic_rf.md b/docs/gitbook/binaryclass/titanic_rf.md
index 1a9786e..64502b9 100644
--- a/docs/gitbook/binaryclass/titanic_rf.md
+++ b/docs/gitbook/binaryclass/titanic_rf.md
@@ -47,10 +47,14 @@ ROW FORMAT DELIMITED FIELDS TERMINATED BY '|'
 LINES TERMINATED BY '\n'
 STORED AS TEXTFILE LOCATION '/dataset/titanic/train';
+```
 
+```sh
 hadoop fs -rm /dataset/titanic/train/train.csv
 awk '{ FPAT="([^,]*)|(\"[^\"]+\")";OFS="|"; } NR >1 {$1=$1;$4=substr($4,2,length($4)-2);print $0}' train.csv | hadoop fs -put - /dataset/titanic/train/train.csv
+```
 
+```sql
 drop table test_raw;
 create external table test_raw (
   passengerid int,
@@ -69,7 +73,9 @@ ROW FORMAT DELIMITED FIELDS TERMINATED BY '|'
 LINES TERMINATED BY '\n'
 STORED AS TEXTFILE LOCATION '/dataset/titanic/test_raw';
+```
 
+```sh
 hadoop fs -rm /dataset/titanic/test_raw/test.csv
 awk '{ FPAT="([^,]*)|(\"[^\"]+\")";OFS="|"; } NR >1 {$1=$1;$3=substr($3,2,length($3)-2);print $0}' test.csv | hadoop fs -put - /dataset/titanic/test_raw/test.csv
 ```
@@ -163,9 +169,8 @@ select
   sum(oob_errors) / sum(oob_tests) as oob_err_rate
 from
   model_rf;
-
-> [137.00242639169272,1194.2140119834373,328.78017188176966,628.2568660509628,200.31275032394072,160.12876797647078,1083.5987543408116,664.1234312561456,422.89449844090393,130.72019667694784] 0.18742985409652077
 ```
+> [137.00242639169272,1194.2140119834373,328.78017188176966,628.2568660509628,200.31275032394072,160.12876797647078,1083.5987543408116,664.1234312561456,422.89449844090393,130.72019667694784] 0.18742985409652077
 
 # Prediction
 
@@ -186,16 +191,27 @@ SELECT
 FROM (
   SELECT
     passengerid,
-    rf_ensemble(predicted) as predicted
+    -- rf_ensemble(predicted) as predicted
+    -- hivemall v0.5-rc.1 or later
+    rf_ensemble(predicted.value, predicted.posteriori, model_weight) as predicted
+    -- rf_ensemble(predicted.value, predicted.posteriori) as predicted -- avoid OOB accuracy (i.e., model_weight)
   FROM (
     SELECT
       t.passengerid,
       -- hivemall v0.4.1-alpha.2 or before
       -- tree_predict(p.model, t.features, ${classification}) as predicted
-      -- hivemall v0.4.1-alpha.3 or later
-      tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+      -- hivemall v0.4.1-alpha.3 or later
+      -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+      -- hivemall v0.5-rc.1 or later
+      p.model_weight,
+      tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
     FROM (
-      SELECT model_id, model_type, pred_model FROM model_rf
+      SELECT
+        -- model_id, pred_model
+        -- hivemall v0.5-rc.1 or later
+        model_id, model_weight, model
+      FROM
+        model_rf
       DISTRIBUTE BY rand(1)
     ) p
     LEFT OUTER JOIN test_rf t
@@ -223,12 +239,49 @@ ORDER BY passengerid ASC;
 
 ```sh
 hadoop fs -getmerge /user/hive/warehouse/titanic.db/predicted_rf_submit predicted_rf_submit.csv
-
 sed -i -e "1i PassengerId,Survived" predicted_rf_submit.csv
 ```
 
 Accuracy would give `0.76555` for a Kaggle submission.
 
+# Graphviz export
+
+> #### Note
+> The `tree_export` feature is supported in Hivemall v0.5-rc.1 or later.
+> It is better to limit the tree depth with the `-depth` option on training when you plan to plot a decision tree.
+
+Hivemall provides `tree_export` to export a decision tree into [Graphviz](http://www.graphviz.org/) or human-readable Javascript format. You can find the usage by issuing the following query:
+
+```
+> select tree_export("","-help");
+
+usage: tree_export(string model, const string options, optional
+       array<string> featureNames=null, optional array<string>
+       classNames=null) - exports a Decision Tree model as javascript/dot]
+       [-help] [-output_name <arg>] [-r] [-t <arg>]
+ -help                             Show function help
+ -output_name,--outputName <arg>   output name [default: predicted]
+ -r,--regression                   Is regression tree or not
+ -t,--type <arg>                   Type of output [default: js,
+                                   javascript/js, graphvis/dot
+```
+
+```sql
+CREATE TABLE model_exported
+  STORED AS ORC tblproperties("orc.compress"="SNAPPY")
+AS
+select
+  model_id,
+  tree_export(model, "-type javascript -output_name survived", array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as js,
+  tree_export(model, "-type graphvis -output_name survived", array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'), array('no','yes')) as dot
+from
+  model_rf
+-- limit 1
+;
+```
+
+[Here is an example](https://gist.github.com/myui/a83ba3795bad9b278cf8bcc59f946e2c#file-titanic-dot) plotting a decision tree using Graphviz or [Vis.js](http://viz-js.com/).
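+
+For example, assuming a local Graphviz installation (the `dot` command) and the `model_exported` table created above, a single exported tree can be rendered to PNG roughly as follows (file names are illustrative):
+
+```sh
+# fetch one exported tree and render it with Graphviz
+hive -e "select dot from model_exported limit 1" > tree.dot
+dot -Tpng tree.dot -o tree.png
+```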
+
 ---
 # Test by dividing training dataset
 
@@ -259,8 +312,10 @@ select
   sum(oob_errors) / sum(oob_tests) as oob_err_rate
 from
   model_rf_07;
+```
 
 > [116.12055542977338,960.8569891444097,291.08765260103837,469.74671636586226,163.721292772701,120.784769882858,847.9769298113661,554.4617571355476,346.3500941757221,97.42593940113392]
 > 0.1838351822503962
 
+```sql
 SET hivevar:classification=true;
 SET hive.mapjoin.optimized.hashtable=false;
 SET mapred.reduce.tasks=16;
@@ -276,16 +331,27 @@ SELECT
 FROM (
   SELECT
     passengerid,
-    rf_ensemble(predicted) as predicted
+    -- rf_ensemble(predicted) as predicted
+    -- hivemall v0.5-rc.1 or later
+    rf_ensemble(predicted.value, predicted.posteriori, model_weight) as predicted
+    -- rf_ensemble(predicted.value, predicted.posteriori) as predicted -- avoid OOB accuracy (i.e., model_weight)
  FROM (
    SELECT
      t.passengerid,
      -- hivemall v0.4.1-alpha.2 or before
      -- tree_predict(p.model, t.features, ${classification}) as predicted
      -- hivemall v0.4.1-alpha.3 or later
-     tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+     -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+     -- hivemall v0.5-rc.1 or later
+     p.model_weight,
+     tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
    FROM (
-     SELECT model_id, model_type, pred_model FROM model_rf_07
+     SELECT
+       -- model_id, model_type, pred_model
+       -- hivemall v0.5-rc.1 or later
+       model_id, model_weight, model
+     FROM
+       model_rf_07
      DISTRIBUTE BY rand(1)
    ) p
    LEFT OUTER JOIN test_rf_03 t
@@ -306,13 +372,16 @@ from
 ;
 
 select count(1) from test_rf_03;
+```
 > 260
+
+```sql
 set hivevar:testcnt=260;
 
 select count(1)/${testcnt} as accuracy
 from rf_submit_03
 where actual = predicted;
-
-> 0.8
 ```
+> 0.8153846153846154
+

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/ft_engineering/hashing.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/ft_engineering/hashing.md b/docs/gitbook/ft_engineering/hashing.md
index baa4cd4..8a08b8c 100644
--- a/docs/gitbook/ft_engineering/hashing.md
+++ b/docs/gitbook/ft_engineering/hashing.md
@@ -28,40 +28,54 @@ Find the differences in the following examples.
 
 ```sql
 select feature_hashing('aaa');
+```
 > 4063537
 
+```sql
 select feature_hashing('aaa','-features 3');
+```
 > 2
 
+```sql
 select feature_hashing(array('aaa','bbb'));
+```
 > ["4063537","8459207"]
 
+```sql
 select feature_hashing(array('aaa','bbb'),'-features 10');
+```
 > ["7","1"]
 
+```sql
 select feature_hashing(array('aaa:1.0','aaa','bbb:2.0'));
+```
 > ["4063537:1.0","4063537","8459207:2.0"]
 
+```sql
 select feature_hashing(array(1,2,3));
+```
 > ["11293631","3322224","4331412"]
 
+```sql
 select feature_hashing(array('1','2','3'));
+```
 > ["11293631","3322224","4331412"]
 
+```sql
 select feature_hashing(array('1:0.1','2:0.2','3:0.3'));
+```
 > ["11293631:0.1","3322224:0.2","4331412:0.3"]
 
+```sql
 select feature_hashing(features), features from training_fm limit 2;
-
+```
 > ["1803454","6630176"] ["userid#5689","movieid#3072"]
 > ["1828616","6238429"] ["userid#4505","movieid#2331"]
 
+```sql
 select feature_hashing(array("userid#4505:3.3","movieid#2331:4.999", "movieid#2331"));
-
-> ["1828616:3.3","6238429:4.999","6238429"]
 ```
-
-_Note: The hash value is starting from 1 and 0 is system reserved for a bias clause. The default number of features are 16777217 (2^24). You can control the number of features by `-num_features` (or `-features`) option._
+> ["1828616:3.3","6238429:4.999","6238429"]
 
 ```sql
 select feature_hashing(null,'-help');
@@ -74,49 +88,50 @@ usage: feature_hashing(array<string> features [, const string options]) -
        -help                        Show function help
 ```
 
+> #### Note
+> The hash value is starting from 1 and 0 is system reserved for a bias clause. The default number of features are 16777217 (2^24).
+> You can control the number of features by `-num_features` (or `-features`) option.
+
 ## `mhash` function
 
 ```sql
 describe function extended mhash;
-> mhash(string word) returns a murmurhash3 INT value starting from 1
 ```
+> mhash(string word) returns a murmurhash3 INT value starting from 1
 
 ```sql
-
 select mhash('aaa');
-> 4063537
 ```
+> 4063537
 
 _Note: The default number of features are `16777216 (2^24)`._
 
 ```sql
 set hivevar:num_features=16777216;
-
 select mhash('aaa',${num_features});
->4063537
 ```
+>4063537
 
 _Note: `mhash` returns a `+1'd` murmurhash3 value starting from 1. Never returns 0 (It's a system reserved number)._
 
 ```sql
 set hivevar:num_features=1;
-
 select mhash('aaa',${num_features});
-> 1
 ```
+> 1
 
 _Note: `mhash` does not consider feature values._
 
 ```sql
 select mhash('aaa:2.0');
-> 2746618
 ```
+> 2746618
 
 _Note: `mhash` always returns a scalar INT value._
 
 ```sql
 select mhash(array('aaa','bbb'));
-> 9566153
 ```
+> 9566153
 
 _Note: `mhash` value of an array is element order-sensitive._
 
 ```sql
 select mhash(array('bbb','aaa'));
+```
 > 3874068
-```
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/multiclass/iris_dataset.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/multiclass/iris_dataset.md b/docs/gitbook/multiclass/iris_dataset.md
index e67737e..8dae7c9 100644
--- a/docs/gitbook/multiclass/iris_dataset.md
+++ b/docs/gitbook/multiclass/iris_dataset.md
@@ -126,13 +126,13 @@ select rand(${rand_seed}) as rnd, * from iris_scaled;
 
 -- 80% for training
 create table train80p as
-select * from iris_shuffled 
+select * from iris_shuffled
 order by rnd DESC
 limit 120;
 
 -- 20% for testing
 create table test20p as
-select * from iris_shuffled 
+select * from iris_shuffled
 order by rnd ASC
 limit 30;
@@ -159,64 +159,3 @@ select
 from
   train80p;
 ```
-
-# Training (multiclass classification)
-
-```sql
-create table model_scw1 as
-select
-  label,
-  feature,
-  argmin_kld(weight, covar) as weight
-from
-  (select
-     train_multiclass_scw(features, label) as (label, feature, weight, covar)
-   from
-     training_x10
-  ) t
-group by label, feature;
-```
-
-# Predict
-
-```sql
-create or replace view predict_scw1
-as
-select
-  rowid,
-  m.col0 as score,
-  m.col1 as label
-from (
-select
-  rowid,
-  maxrow(score, label) as m
-from (
-  select
-    t.rowid,
-    m.label,
-    sum(m.weight * t.value) as score
-  from
-    test20p_exploded t LEFT OUTER JOIN
-    model_scw1 m ON (t.feature = m.feature)
-  group by
-    t.rowid, m.label
-) t1
-group by rowid
-) t2;
-```
-
-# Evaluation
-
-```sql
-create or replace view eval_scw1 as
-select
-  t.label as actual,
-  p.label as predicted
-from
-  test20p t JOIN predict_scw1 p
-  on (t.rowid = p.rowid);
-
-select count(1)/30 from eval_scw1
-where actual = predicted;
-```
-> 0.9666666666666667

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/multiclass/iris_randomforest.md
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/9876d063/docs/gitbook/multiclass/iris_randomforest.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/multiclass/iris_randomforest.md b/docs/gitbook/multiclass/iris_randomforest.md
index 4b0750c..d0e8e8c 100644
--- a/docs/gitbook/multiclass/iris_randomforest.md
+++ b/docs/gitbook/multiclass/iris_randomforest.md
@@ -89,7 +89,7 @@ from
 
 ```sql
 CREATE TABLE model
-STORED AS SEQUENCEFILE
+  STORED AS SEQUENCEFILE
 AS
 select
   train_randomforest_classifier(features, label)
@@ -100,60 +100,72 @@ select
 from
   training;
 ```
-*Note: The default TEXTFILE should not be used for model table when using Javascript output through "-output javascript" option.*
+> #### Caution
+> The default `TEXTFILE` should not be used for the model table when using Javascript output through the `-output javascript` option.
+
+```sql
+hive> desc extended model;
 ```
-hive> desc model;
-model_id        int
-model_type      int
-pred_model      string
-var_importance  array<double>
-oob_errors      int
-oob_tests       int
-```
+
+| col_name | data_type |
+|:-:|:-:|
+| model_id | string |
+| model_weight | double |
+| model | string |
+| var_importance | array<double> |
+| oob_errors | int |
+| oob_tests | int |
+
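Given this schema, it can be useful to sanity-check the trained models before running prediction. The following is a minimal sketch, not part of the original tutorial (column names are taken from the table above; `oob_errors / oob_tests` is each tree's out-of-bag error rate):

```sql
select
  model_id,
  model_weight,                             -- per-tree weight used by rf_ensemble below
  oob_errors / oob_tests as oob_error_rate  -- per-tree out-of-bag error
from
  model
limit 3;
```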
e.g., + [Q,C,Q,C]) + -depth,--max_depth <arg> The maximum number of the tree depth + [default: Integer.MAX_VALUE] + -help Show function help + -leafs,--max_leaf_nodes <arg> The maximum number of leaf nodes + [default: Integer.MAX_VALUE] + -min_samples_leaf <arg> The minimum number of samples in a + leaf node [default: 1] + -rule,--split_rule <arg> Split algorithm [default: GINI, + ENTROPY] + -seed <arg> seed value in long [default: -1 + (random)] + -splits,--min_split <arg> A node that has greater than or + equals to `min_split` examples will + split [default: 2] + -stratified,--stratified_sampling Enable Stratified sampling for + unbalanced data + -subsample <arg> Sampling rate in range (0.0,1.0] + -trees,--num_trees <arg> The number of trees for each task + [default: 50] + -vars,--num_variables <arg> The number of random selected + features [default: + ceil(sqrt(x[0].length))]. + int(num_variables * x[0].length) is + considered if num_variable is (0,1 ``` -*Caution: "-num_trees" controls the number of trees for each task, not the total number of trees.* + +> #### Caution +> `-num_trees` controls the number of trees for each task, not the total number of trees. ### Parallelize Training @@ -161,7 +173,8 @@ To parallelize RandomForest training, you can use UNION ALL as follows: ```sql CREATE TABLE model -STORED AS SEQUENCEFILE + STORED AS ORC tblproperties("orc.compress"="SNAPPY") + -- STORED AS SEQUENCEFILE AS select train_randomforest_classifier(features, label, '-trees 25') @@ -186,51 +199,7 @@ select from model; ``` -> [2.81010338879605,0.4970357753626371,23.790369091407698,14.315316390235273] 0.05333333333333334 - -### Output prediction model by Javascipt - -```sql -CREATE TABLE model_javascript -STORED AS SEQUENCEFILE -AS -select train_randomforest_classifier(features, label, "-output_type js -disable_compression") -from training; - -select model from model_javascript limit 1; -``` - -```js -if(x[3] <= 0.5) { - 0; -} else { - if(x[2] <= 4.5) { - if(x[3] <= 1.5) { - if(x[0] <= 4.5) { - 1; - } else { - if(x[0] <= 5.5) { - 1; - } else { - if(x[1] <= 2.5) { - 1; - } else { - 1; - } - } - } - } else { - 2; - } - } else { - if(x[3] <= 1.5) { - 2; - } else { - 2; - } - } -} -``` +> [6.837674865013268,4.1317115752776665,24.331571871930226,25.677497925673062] 0.056666666666666664 # Prediction @@ -239,18 +208,24 @@ set hivevar:classification=true; set hive.auto.convert.join=true; set hive.mapjoin.optimized.hashtable=false; -create table predicted_vm +create table predicted as SELECT rowid, - rf_ensemble(predicted) as predicted + -- rf_ensemble(predicted) as predicted + -- hivemall v0.5-rc.1 or later + rf_ensemble(predicted.value, predicted.posteriori, model_weight) as predicted + -- rf_ensemble(predicted.value, predicted.posteriori) as predicted -- avoid OOB accuracy (i.e., model_weight) FROM ( SELECT rowid, -- hivemall v0.4.1-alpha.2 and before -- tree_predict(p.model, t.features, ${classification}) as predicted -- hivemall v0.4.1 and later - tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted + -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted + -- hivemall v0.5-rc.1 or later + p.model_weight, + tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted FROM model p LEFT OUTER JOIN -- CROSS JOIN @@ -260,7 +235,6 @@ group by rowid ; ``` -_Note: Javascript outputs can be evaluated by `js_tree_predict`._ ### Parallelize Prediction @@ -272,20 +246,29 @@ set 
 ### Parallelize Prediction
 
@@ -272,20 +246,29 @@ set hive.auto.convert.join=true;
 SET hive.mapjoin.optimized.hashtable=false;
 SET mapred.reduce.tasks=8;
 
-create table predicted_vm
+create table predicted
 as
 SELECT
   rowid,
-  rf_ensemble(predicted) as predicted
+  -- rf_ensemble(predicted) as predicted
+  -- hivemall v0.5-rc.1 or later
+  rf_ensemble(predicted.value, predicted.posteriori, model_weight) as predicted
+  -- rf_ensemble(predicted.value, predicted.posteriori) as predicted -- use this variant to avoid weighting by OOB accuracy (i.e., model_weight)
 FROM (
   SELECT
     t.rowid,
     -- hivemall v0.4.1-alpha.2 and before
     -- tree_predict(p.pred_model, t.features, ${classification}) as predicted
     -- hivemall v0.4.1 and later
-    tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+    -- tree_predict(p.model_id, p.model_type, p.pred_model, t.features, ${classification}) as predicted
+    -- hivemall v0.5-rc.1 or later
+    p.model_weight,
+    tree_predict(p.model_id, p.model, t.features, ${classification}) as predicted
  FROM (
-    SELECT model_id, model_type, pred_model
+    SELECT
+      -- model_id, model_type, pred_model
+      -- hivemall v0.5-rc.1 or later
+      model_id, model_weight, model
    FROM
      model
    DISTRIBUTE BY rand(1)
  ) p
@@ -300,8 +283,10 @@ group by
 
 ```sql
 select count(1) from training;
+```
 
 > 150
 
+```sql
 set hivevar:total_cnt=150;
 
 WITH t1 as (
@@ -310,7 +295,7 @@ SELECT
   t.label as actual,
   p.predicted.label as predicted
 FROM
-  predicted_vm p
+  predicted p
   LEFT OUTER JOIN training t ON (t.rowid = p.rowid)
 )
 SELECT
@@ -321,4 +306,76 @@ WHERE
   actual = predicted
 ;
 ```
-> 0.9533333333333334
+> 0.98
+
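Accuracy alone can hide per-class mistakes. As a follow-up (a sketch that reuses the same join as the evaluation query above; not part of the original tutorial), you can tabulate a confusion matrix:

```sql
WITH t1 as (
  SELECT
    t.label as actual,
    p.predicted.label as predicted
  FROM
    predicted p
    LEFT OUTER JOIN training t ON (t.rowid = p.rowid)
)
SELECT
  actual, predicted, count(1) as cnt
FROM t1
GROUP BY actual, predicted;
```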
+# Graphviz export
+
+> #### Note
+> The `tree_export` feature is supported in Hivemall v0.5-rc.1 or later.
+> It is better to limit the tree depth at training time with the `-depth` option when plotting a decision tree.
+
+Hivemall provides `tree_export` to export a decision tree into [Graphviz](http://www.graphviz.org/) dot or human-readable Javascript format. You can find the usage by issuing the following query:
+
+```
+> select tree_export("","-help");
+
+usage: tree_export(string model, const string options, optional
+       array<string> featureNames=null, optional array<string>
+       classNames=null) - exports a Decision Tree model as javascript/dot]
+       [-help] [-output_name <arg>] [-r] [-t <arg>]
+ -help                            Show function help
+ -output_name,--outputName <arg>  output name [default: predicted]
+ -r,--regression                  Is regression tree or not
+ -t,--type <arg>                  Type of output [default: js,
+                                  javascript/js, graphvis/dot
+```
+
+```sql
+CREATE TABLE model_exported
+  STORED AS ORC tblproperties("orc.compress"="SNAPPY")
+AS
+select
+  model_id,
+  tree_export(model, "-type javascript", array('sepal_length','sepal_width','petal_length','petal_width'), array('Setosa','Versicolour','Virginica')) as js,
+  tree_export(model, "-type graphvis", array('sepal_length','sepal_width','petal_length','petal_width'), array('Setosa','Versicolour','Virginica')) as dot
+from
+  model
+-- limit 1
+;
+```
+
+```
+digraph Tree {
+ node [shape=box, style="filled, rounded", color="black", fontname=helvetica];
+ edge [fontname=helvetica];
+ 0 [label=<petal_length ≤ 2.599999964237213>, fillcolor="#00000000"];
+ 1 [label=<predicted = Setosa>, fillcolor="0.0000,1.000,1.000", shape=ellipse];
+ 0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"];
+ 2 [label=<petal_length ≤ 4.950000047683716>, fillcolor="#00000000"];
+ 0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"];
+ 3 [label=<petal_width ≤ 1.6500000357627869>, fillcolor="#00000000"];
+ 2 -> 3;
+ 4 [label=<predicted = Versicolour>, fillcolor="0.3333,1.000,1.000", shape=ellipse];
+ 3 -> 4;
+ 5 [label=<sepal_width ≤ 3.100000023841858>, fillcolor="#00000000"];
+ 3 -> 5;
+ 6 [label=<predicted = Virginica>, fillcolor="0.6667,1.000,1.000", shape=ellipse];
+ 5 -> 6;
+ 7 [label=<predicted = Versicolour>, fillcolor="0.3333,1.000,1.000", shape=ellipse];
+ 5 -> 7;
+ 8 [label=<petal_width ≤ 1.75>, fillcolor="#00000000"];
+ 2 -> 8;
+ 9 [label=<petal_length ≤ 5.299999952316284>, fillcolor="#00000000"];
+ 8 -> 9;
+ 10 [label=<predicted = Versicolour>, fillcolor="0.3333,1.000,1.000", shape=ellipse];
+ 9 -> 10;
+ 11 [label=<predicted = Virginica>, fillcolor="0.6667,1.000,1.000", shape=ellipse];
+ 9 -> 11;
+ 12 [label=<predicted = Virginica>, fillcolor="0.6667,1.000,1.000", shape=ellipse];
+ 8 -> 12;
+}
+```
+
+<img src="../resources/images/iris.png" alt="Iris Graphviz output"/>
+
+You can draw a graph by `dot -Tpng iris.dot -o iris.png` or by using [Viz.js](http://viz-js.com/).
\ No newline at end of file
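To get the dot source out of Hive for rendering, one option (a sketch using standard HiveQL; the local path is arbitrary and not part of the original tutorial) is to write it to a local directory and then run `dot` on the resulting file:

```sql
-- writes the dot source of one tree under /tmp/iris_dot;
-- render it afterwards with: dot -Tpng <file> -o iris.png
INSERT OVERWRITE LOCAL DIRECTORY '/tmp/iris_dot'
SELECT dot FROM model_exported LIMIT 1;
```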
