This is an automated email from the ASF dual-hosted git repository.
myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git
The following commit(s) were added to refs/heads/master by this push:
new 176fa07 [HIVEMALL-171] Tracing functionality for prediction of
DecisionTrees
176fa07 is described below
commit 176fa070c1e2ea3b0737c8150a1302e4cb643816
Author: Makoto Yui <[email protected]>
AuthorDate: Sat Sep 28 03:39:01 2019 +0900
[HIVEMALL-171] Tracing functionality for prediction of DecisionTrees
## What changes were proposed in this pull request?
Introduce `decision_path` UDF providing tracing of decision tree prediction
paths
## What type of PR is it?
Feature
## What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-171
## How was this patch tested?
Unit tests, plus manual tests on EMR.
## How to use this feature?
To be described in the user guide.
## Checklist
- [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`,
for your commit?
- [x] Did you run system tests on Hive (or Spark)?
Author: Makoto Yui <[email protected]>
Closes #199 from myui/HIVEMALL-171.
---
core/src/main/java/hivemall/annotations/Cite.java | 1 +
.../smile/classification/DecisionTree.java | 19 +-
.../smile/classification/PredictionHandler.java | 35 +-
.../hivemall/smile/regression/RegressionTree.java | 31 +
.../java/hivemall/smile/tools/DecisionPathUDF.java | 659 +++++++++++++++++++++
.../java/hivemall/smile/tools/TreePredictUDF.java | 2 +-
.../main/java/hivemall/utils/lang/ArrayUtils.java | 20 +-
.../smile/classification/DecisionTreeTest.java | 80 +++
docs/gitbook/misc/funcs.md | 33 ++
resources/ddl/define-all-as-permanent.hive | 3 +
resources/ddl/define-all.hive | 4 +
resources/ddl/define-all.spark | 3 +
12 files changed, 879 insertions(+), 11 deletions(-)
diff --git a/core/src/main/java/hivemall/annotations/Cite.java
b/core/src/main/java/hivemall/annotations/Cite.java
index 2b93cd6..7d09320 100644
--- a/core/src/main/java/hivemall/annotations/Cite.java
+++ b/core/src/main/java/hivemall/annotations/Cite.java
@@ -30,6 +30,7 @@ import javax.annotation.Nullable;
public @interface Cite {
@Nonnull
String description();
+
@Nullable
String url();
}
diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java
b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
index 95b4b2a..74a99ad 100644
--- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java
+++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
@@ -17,6 +17,10 @@
//
https://github.com/haifengl/smile/blob/master/core/src/main/java/smile/classification/DecisionTree.java
package hivemall.smile.classification;
+import static hivemall.smile.classification.PredictionHandler.Operator.EQ;
+import static hivemall.smile.classification.PredictionHandler.Operator.GT;
+import static hivemall.smile.classification.PredictionHandler.Operator.LE;
+import static hivemall.smile.classification.PredictionHandler.Operator.NE;
import static hivemall.smile.utils.SmileExtUtils.NOMINAL;
import static hivemall.smile.utils.SmileExtUtils.NUMERIC;
import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName;
@@ -319,18 +323,23 @@ public class DecisionTree implements Classifier<Vector> {
*/
public void predict(@Nonnull final Vector x, @Nonnull final
PredictionHandler handler) {
if (isLeaf()) {
- handler.handle(output, posteriori);
+ handler.visitLeaf(output, posteriori);
} else {
+ final double feature = x.get(splitFeature, Double.NaN);
if (quantitativeFeature) {
- if (x.get(splitFeature, Double.NaN) <= splitValue) {
+ if (feature <= splitValue) {
+ handler.visitBranch(LE, splitFeature, feature,
splitValue);
trueChild.predict(x, handler);
} else {
+ handler.visitBranch(GT, splitFeature, feature,
splitValue);
falseChild.predict(x, handler);
}
} else {
- if (x.get(splitFeature, Double.NaN) == splitValue) {
+ if (feature == splitValue) {
+ handler.visitBranch(EQ, splitFeature, feature,
splitValue);
trueChild.predict(x, handler);
} else {
+ handler.visitBranch(NE, splitFeature, feature,
splitValue);
falseChild.predict(x, handler);
}
}
@@ -1359,6 +1368,10 @@ public class DecisionTree implements Classifier<Vector> {
return _root.predict(x);
}
+ public void predict(@Nonnull final Vector x, @Nonnull final
PredictionHandler handler) {
+ _root.predict(x, handler);
+ }
+
/**
* Predicts the class label of an instance and also calculate a posteriori
probabilities. Not
* supported.
diff --git
a/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
index 84ef244..6c19641 100644
--- a/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
+++ b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java
@@ -20,8 +20,39 @@ package hivemall.smile.classification;
import javax.annotation.Nonnull;
-public interface PredictionHandler {
+public abstract class PredictionHandler {
- void handle(int output, @Nonnull double[] posteriori);
+ public enum Operator {
+ /* = */ EQ, /* != */ NE, /* <= */ LE, /* > */ GT;
+
+ @Override
+ public String toString() {
+ switch (this) {
+ case EQ:
+ return "=";
+ case NE:
+ return "!=";
+ case LE:
+ return "<=";
+ case GT:
+ return ">";
+ default:
+ throw new IllegalStateException("Unexpected operator: " +
this);
+ }
+ }
+ }
+
+ public void init() {};
+
+ public void visitBranch(@Nonnull Operator op, int splitFeatureIndex,
double splitFeature,
+ double splitValue) {}
+
+ public void visitLeaf(double output) {}
+
+ public void visitLeaf(int output, @Nonnull double[] posteriori) {}
+
+ public <T> T getResult() {
+ throw new UnsupportedOperationException();
+ }
}
diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
index 764c352..ab2f25f 100755
--- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
+++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -17,6 +17,10 @@
//
https://github.com/haifengl/smile/blob/master/core/src/main/java/smile/regression/RegressionTree.java
package hivemall.smile.regression;
+import static hivemall.smile.classification.PredictionHandler.Operator.EQ;
+import static hivemall.smile.classification.PredictionHandler.Operator.GT;
+import static hivemall.smile.classification.PredictionHandler.Operator.LE;
+import static hivemall.smile.classification.PredictionHandler.Operator.NE;
import static hivemall.smile.utils.SmileExtUtils.NOMINAL;
import static hivemall.smile.utils.SmileExtUtils.NUMERIC;
import static hivemall.smile.utils.SmileExtUtils.resolveFeatureName;
@@ -29,6 +33,7 @@ import hivemall.math.vector.DenseVector;
import hivemall.math.vector.SparseVector;
import hivemall.math.vector.Vector;
import hivemall.math.vector.VectorProcedure;
+import hivemall.smile.classification.PredictionHandler;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.smile.utils.VariableOrder;
import hivemall.utils.collections.arrays.SparseIntArray;
@@ -274,6 +279,32 @@ public final class RegressionTree implements
Regression<Vector> {
}
}
+ public double predict(@Nonnull final Vector x, @Nonnull final
PredictionHandler handler) {
+ if (isLeaf()) {
+ handler.visitLeaf(output);
+ return output;
+ } else {
+ final double feature = x.get(splitFeature, Double.NaN);
+ if (quantitativeFeature) {
+ if (feature <= splitValue) {
+ handler.visitBranch(LE, splitFeature, feature,
splitValue);
+ return trueChild.predict(x);
+ } else {
+ handler.visitBranch(GT, splitFeature, feature,
splitValue);
+ return falseChild.predict(x);
+ }
+ } else {
+ if (feature == splitValue) {
+ handler.visitBranch(EQ, splitFeature, feature,
splitValue);
+ return trueChild.predict(x);
+ } else {
+ handler.visitBranch(NE, splitFeature, feature,
splitValue);
+ return falseChild.predict(x);
+ }
+ }
+ }
+ }
+
/**
* Evaluate the regression tree over an instance.
*/
diff --git a/core/src/main/java/hivemall/smile/tools/DecisionPathUDF.java
b/core/src/main/java/hivemall/smile/tools/DecisionPathUDF.java
new file mode 100644
index 0000000..11a05da
--- /dev/null
+++ b/core/src/main/java/hivemall/smile/tools/DecisionPathUDF.java
@@ -0,0 +1,659 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.tools;
+
+import hivemall.UDFWithOptions;
+import hivemall.math.vector.DenseVector;
+import hivemall.math.vector.SparseVector;
+import hivemall.math.vector.Vector;
+import hivemall.smile.classification.DecisionTree;
+import hivemall.smile.classification.PredictionHandler;
+import hivemall.smile.regression.RegressionTree;
+import hivemall.utils.codec.Base91;
+import hivemall.utils.hadoop.HiveUtils;
+import hivemall.utils.lang.Preconditions;
+import hivemall.utils.lang.StringUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Options;
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.UDFType;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
+import
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
+import
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import
org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
+import org.apache.hadoop.io.Text;
+
+// @formatter:off
+@Description(name = "decision_path",
+ value = "_FUNC_(string modelId, string model, array<double|string>
features [, const string options] [, optional array<string> featureNames=null,
optional array<string> classNames=null])"
+ + " - Returns a decision path for each prediction in
array<string>",
+ extended = "SELECT\n" +
+ " t.passengerid,\n" +
+ " decision_path(m.model_id, m.model, t.features,
'-classification')\n" +
+ "FROM\n" +
+ " model_rf m\n" +
+ " LEFT OUTER JOIN\n" +
+ " test_rf t;\n" +
+ "> | 892 | [\"2 [0.0] = 0.0\",\"0 [3.0] = 3.0\",\"1 [696.0] !=
107.0\",\"7 [7.8292] <= 7.9104\",\"1 [696.0] != 828.0\",\"1 [696.0] !=
391.0\",\"0 [0.961038961038961, 0.03896103896103896]\"] |\n\n" +
+ "-- Show 100 frequent branches\n" +
+ "WITH tmp as (\n" +
+ " SELECT\n" +
+ " decision_path(m.model_id, m.model, t.features,
'-classification -no_verbose -no_leaf',
array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'),
array('no','yes')) as path\n" +
+ " FROM\n" +
+ " model_rf m\n" +
+ " LEFT OUTER JOIN -- CROSS JOIN\n" +
+ " test_rf t\n" +
+ ")\n" +
+ "select\n" +
+ " r.branch,\n" +
+ " count(1) as cnt\n" +
+ "from\n" +
+ " tmp l\n" +
+ " LATERAL VIEW explode(l.path) r as branch\n" +
+ "group by\n" +
+ " r.branch\n" +
+ "order by\n" +
+ " cnt desc\n" +
+ "limit 100;")
+// @formatter:on
+@UDFType(deterministic = true, stateful = false)
+public final class DecisionPathUDF extends UDFWithOptions {
+
+ private StringObjectInspector modelOI;
+ private ListObjectInspector featureListOI;
+ private PrimitiveObjectInspector featureElemOI;
+ private boolean denseInput;
+
+ // options
+ private boolean classification = false;
+ private boolean summarize = true;
+ private boolean verbose = true;
+ private boolean noLeaf = false;
+
+ @Nullable
+ private String[] featureNames;
+ @Nullable
+ private String[] classNames;
+
+ @Nullable
+ private transient Vector featuresProbe;
+
+ @Nullable
+ private transient Evaluator evaluator;
+
+ @Override
+ protected Options getOptions() {
+ Options opts = new Options();
+ opts.addOption("c", "classification", false,
+ "Predict as classification [default: not enabled]");
+ opts.addOption("no_sumarize", "disable_summarization", false,
+ "Do not summarize decision paths");
+ opts.addOption("no_verbose", "disable_verbose_output", false,
+ "Disable verbose output [default: verbose]");
+ opts.addOption("no_leaf", "disable_leaf_output", false,
+ "Show leaf value [default: not enabled]");
+ return opts;
+ }
+
+ @Override
+ protected CommandLine processOptions(@Nonnull String optionValue) throws
UDFArgumentException {
+ CommandLine cl = parseOptions(optionValue);
+
+ this.classification = cl.hasOption("classification");
+ this.summarize = !cl.hasOption("no_sumarize");
+ this.verbose = !cl.hasOption("disable_verbose_output");
+ this.noLeaf = cl.hasOption("disable_leaf_output");
+
+ return cl;
+ }
+
+ @Override
+ public ObjectInspector initialize(ObjectInspector[] argOIs) throws
UDFArgumentException {
+ if (argOIs.length < 3 || argOIs.length > 6) {
+ showHelp("tree_predict takes 3 ~ 6 arguments");
+ }
+
+ this.modelOI = HiveUtils.asStringOI(argOIs[1]);
+
+ ListObjectInspector listOI = HiveUtils.asListOI(argOIs[2]);
+ this.featureListOI = listOI;
+ ObjectInspector elemOI = listOI.getListElementObjectInspector();
+ if (HiveUtils.isNumberOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
+ this.denseInput = true;
+ } else if (HiveUtils.isStringOI(elemOI)) {
+ this.featureElemOI = HiveUtils.asStringOI(elemOI);
+ this.denseInput = false;
+ } else {
+ throw new UDFArgumentException(
+ "tree_predict takes array<double> or array<string> for the 3rd
argument: "
+ + listOI.getTypeName());
+ }
+
+ if (argOIs.length >= 4) {
+ ObjectInspector argOI3 = argOIs[3];
+ if (HiveUtils.isConstString(argOI3)) {
+ String opts = HiveUtils.getConstString(argOI3);
+ processOptions(opts);
+ if (argOIs.length >= 5) {
+ ObjectInspector argOI4 = argOIs[4];
+ if (HiveUtils.isConstStringListOI(argOI4)) {
+ this.featureNames =
HiveUtils.getConstStringArray(argOI4);
+ if (argOIs.length >= 6) {
+ ObjectInspector argOI5 = argOIs[5];
+ if (HiveUtils.isConstStringListOI(argOI5)) {
+ if (!classification) {
+ throw new UDFArgumentException(
+ "classNames should not be provided for
regression");
+ }
+ this.classNames =
HiveUtils.getConstStringArray(argOI5);
+ } else {
+ throw new UDFArgumentException(
+ "decision_path expects 'const
array<string> classNames' for the 6th argument: "
+ + argOI5.getTypeName());
+ }
+ }
+ } else {
+ throw new UDFArgumentException(
+ "decision_path expects 'const array<string>
featureNames' for the 5th argument: "
+ + argOI4.getTypeName());
+ }
+ }
+ } else if (HiveUtils.isConstStringListOI(argOI3)) {
+ this.featureNames = HiveUtils.getConstStringArray(argOI3);
+ if (argOIs.length >= 5) {
+ ObjectInspector argOI4 = argOIs[4];
+ if (HiveUtils.isConstStringListOI(argOI4)) {
+ if (!classification) {
+ throw new UDFArgumentException(
+ "classNames should not be provided for
regression");
+ }
+ this.classNames =
HiveUtils.getConstStringArray(argOI4);
+ } else {
+ throw new UDFArgumentException(
+ "decision_path expects 'const array<string>
classNames' for the 5th argument: "
+ + argOI4.getTypeName());
+ }
+ }
+ } else {
+ throw new UDFArgumentException(
+ "decision_path expects 'const array<string> options' or
'const array<string> featureNames' for the 4th argument: "
+ + argOI3.getTypeName());
+ }
+ }
+
+ return ObjectInspectorFactory.getStandardListObjectInspector(
+ PrimitiveObjectInspectorFactory.javaStringObjectInspector);
+ }
+
+ @Override
+ public List<String> evaluate(@Nonnull DeferredObject[] arguments) throws
HiveException {
+ Object arg0 = arguments[0].get();
+ if (arg0 == null) {
+ throw new HiveException("modelId should not be null");
+ }
+ // Not using string OI for backward compatibilities
+ String modelId = arg0.toString();
+
+ Object arg1 = arguments[1].get();
+ if (arg1 == null) {
+ return null;
+ }
+ Text model = modelOI.getPrimitiveWritableObject(arg1);
+
+ Object arg2 = arguments[2].get();
+ if (arg2 == null) {
+ throw new HiveException("features was null");
+ }
+ this.featuresProbe = parseFeatures(arg2, featuresProbe);
+
+ if (evaluator == null) {
+ this.evaluator = classification ? new ClassificationEvaluator(this)
+ : new RegressionEvaluator(this);
+ }
+ return evaluator.evaluate(modelId, model, featuresProbe);
+ }
+
+ @Nonnull
+ private Vector parseFeatures(@Nonnull final Object argObj, @Nullable
Vector probe)
+ throws UDFArgumentException {
+ if (denseInput) {
+ final int length = featureListOI.getListLength(argObj);
+ if (probe == null) {
+ probe = new DenseVector(length);
+ } else if (length != probe.size()) {
+ probe = new DenseVector(length);
+ }
+
+ for (int i = 0; i < length; i++) {
+ final Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ probe.set(i, 0.d);
+ } else {
+ double v = PrimitiveObjectInspectorUtils.getDouble(o,
featureElemOI);
+ probe.set(i, v);
+ }
+ }
+ } else {
+ if (probe == null) {
+ probe = new SparseVector();
+ } else {
+ probe.clear();
+ }
+
+ final int length = featureListOI.getListLength(argObj);
+ for (int i = 0; i < length; i++) {
+ Object o = featureListOI.getListElement(argObj, i);
+ if (o == null) {
+ continue;
+ }
+ String col = o.toString();
+
+ final int pos = col.indexOf(':');
+ if (pos == 0) {
+ throw new UDFArgumentException("Invalid feature value
representation: " + col);
+ }
+
+ final String feature;
+ final double value;
+ if (pos > 0) {
+ feature = col.substring(0, pos);
+ String s2 = col.substring(pos + 1);
+ value = Double.parseDouble(s2);
+ } else {
+ feature = col;
+ value = 1.d;
+ }
+
+ if (feature.indexOf(':') != -1) {
+ throw new UDFArgumentException(
+ "Invalid feature format `<index>:<value>`: " + col);
+ }
+
+ final int colIndex = Integer.parseInt(feature);
+ if (colIndex < 0) {
+ throw new UDFArgumentException(
+ "Col index MUST be greater than or equals to 0: " +
colIndex);
+ }
+ probe.set(colIndex, value);
+ }
+ }
+ return probe;
+ }
+
+ @Override
+ public void close() throws IOException {
+ this.modelOI = null;
+ this.featureElemOI = null;
+ this.featureListOI = null;
+ this.featureNames = null;
+ this.classNames = null;
+ this.featuresProbe = null;
+ this.evaluator = null;
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ return "decision_path(" + StringUtils.join(children, ',') + ")";
+ }
+
+ interface Evaluator {
+
+ @Nonnull
+ List<String> evaluate(@Nonnull String modelId, @Nonnull Text model,
+ @Nonnull Vector features) throws HiveException;
+
+ }
+
+ static final class ClassificationEvaluator implements Evaluator {
+
+ @Nullable
+ private final String[] featureNames;
+ @Nullable
+ private final String[] classNames;
+
+ @Nonnull
+ private final List<String> result;
+ @Nonnull
+ private final PredictionHandler handler;
+
+ @Nullable
+ private String prevModelId = null;
+ private DecisionTree.Node cNode = null;
+
+ ClassificationEvaluator(@Nonnull final DecisionPathUDF udf) {
+ this.featureNames = udf.featureNames;
+ this.classNames = udf.classNames;
+
+ final StringBuilder buf = new StringBuilder();
+ final ArrayList<String> result = new ArrayList<>();
+ this.result = result;
+
+ if (udf.summarize) {
+ final LinkedHashMap<String, Double> map = new
LinkedHashMap<>();
+
+ this.handler = new PredictionHandler() {
+
+ @Override
+ public void init() {
+ map.clear();
+ result.clear();
+ }
+
+ @Override
+ public void visitBranch(Operator op, int
splitFeatureIndex, double splitFeature,
+ double splitValue) {
+ buf.append(resolveFeatureName(splitFeatureIndex));
+ if (udf.verbose) {
+ buf.append(" [" + splitFeature + "] ");
+ } else {
+ buf.append(' ');
+ }
+ buf.append(op);
+ if (op == Operator.EQ || op == Operator.NE) {
+ buf.append(' ');
+ buf.append(splitValue);
+ }
+ String key = buf.toString();
+ map.put(key, splitValue);
+ StringUtils.clear(buf);
+ }
+
+ @Override
+ public void visitLeaf(int output, double[] posteriori) {
+ for (Map.Entry<String, Double> e : map.entrySet()) {
+ final String key = e.getKey();
+ if (key.indexOf('<') == -1 && key.indexOf('>') ==
-1) {
+ result.add(key);
+ } else {
+ double value = e.getValue().doubleValue();
+ result.add(key + ' ' + value);
+ }
+ }
+ if (udf.noLeaf) {
+ return;
+ }
+
+ if (udf.verbose) {
+ buf.append(resolveClassName(output));
+ buf.append(' ');
+ buf.append(Arrays.toString(posteriori));
+ result.add(buf.toString());
+ StringUtils.clear(buf);
+ } else {
+ result.add(resolveClassName(output));
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public ArrayList<String> getResult() {
+ return result;
+ }
+
+ };
+ } else {
+ this.handler = new PredictionHandler() {
+
+ @Override
+ public void init() {
+ result.clear();
+ }
+
+ @Override
+ public void visitBranch(Operator op, int
splitFeatureIndex, double splitFeature,
+ double splitValue) {
+ buf.append(resolveFeatureName(splitFeatureIndex));
+ if (udf.verbose) {
+ buf.append(" [" + splitFeature + "] ");
+ } else {
+ buf.append(' ');
+ }
+ buf.append(op);
+ buf.append(' ');
+ buf.append(splitValue);
+ result.add(buf.toString());
+ StringUtils.clear(buf);
+ }
+
+ @Override
+ public void visitLeaf(int output, double[] posteriori) {
+ if (udf.noLeaf) {
+ return;
+ }
+
+ if (udf.verbose) {
+ buf.append(resolveClassName(output));
+ buf.append(' ');
+ buf.append(Arrays.toString(posteriori));
+ result.add(buf.toString());
+ StringUtils.clear(buf);
+ } else {
+ result.add(resolveClassName(output));
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public ArrayList<String> getResult() {
+ return result;
+ }
+
+ };
+ }
+ }
+
+ @Nonnull
+ private String resolveFeatureName(final int splitFeatureIndex) {
+ if (featureNames == null) {
+ return Integer.toString(splitFeatureIndex);
+ } else {
+ return featureNames[splitFeatureIndex];
+ }
+ }
+
+ @Nonnull
+ private String resolveClassName(final int classLabel) {
+ if (classNames == null) {
+ return Integer.toString(classLabel);
+ } else {
+ return classNames[classLabel];
+ }
+ }
+
+ @Nonnull
+ public List<String> evaluate(@Nonnull final String modelId, @Nonnull
final Text script,
+ @Nonnull final Vector features) throws HiveException {
+ if (!modelId.equals(prevModelId)) {
+ this.prevModelId = modelId;
+ int length = script.getLength();
+ byte[] b = script.getBytes();
+ b = Base91.decode(b, 0, length);
+ this.cNode = DecisionTree.deserialize(b, b.length, true);
+ }
+ Preconditions.checkNotNull(cNode);
+
+ handler.init();
+ cNode.predict(features, handler);
+ return handler.getResult();
+ }
+
+ }
+
+ static final class RegressionEvaluator implements Evaluator {
+
+ @Nullable
+ private final String[] featureNames;
+
+ @Nonnull
+ private final List<String> result;
+ @Nonnull
+ private final PredictionHandler handler;
+
+ @Nullable
+ private String prevModelId = null;
+ private RegressionTree.Node rNode = null;
+
+ RegressionEvaluator(@Nonnull final DecisionPathUDF udf) {
+ this.featureNames = udf.featureNames;
+
+ final StringBuilder buf = new StringBuilder();
+ final ArrayList<String> result = new ArrayList<>();
+ this.result = result;
+
+ if (udf.summarize) {
+ final LinkedHashMap<String, Double> map = new
LinkedHashMap<>();
+
+ this.handler = new PredictionHandler() {
+
+ @Override
+ public void init() {
+ map.clear();
+ result.clear();
+ }
+
+ @Override
+ public void visitBranch(Operator op, int
splitFeatureIndex, double splitFeature,
+ double splitValue) {
+ buf.append(resolveFeatureName(splitFeatureIndex));
+ if (udf.verbose) {
+ buf.append(" [" + splitFeature + "] ");
+ } else {
+ buf.append(' ');
+ }
+ buf.append(op);
+ if (op == Operator.EQ || op == Operator.NE) {
+ buf.append(' ');
+ buf.append(splitValue);
+ }
+ String key = buf.toString();
+ map.put(key, splitValue);
+ StringUtils.clear(buf);
+ }
+
+ @Override
+ public void visitLeaf(double output) {
+ for (Map.Entry<String, Double> e : map.entrySet()) {
+ final String key = e.getKey();
+ if (key.indexOf('<') == -1 && key.indexOf('>') ==
-1) {
+ result.add(key);
+ } else {
+ double value = e.getValue().doubleValue();
+ result.add(key + ' ' + value);
+ }
+ }
+ if (udf.noLeaf) {
+ return;
+ }
+
+ result.add(Double.toString(output));
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public ArrayList<String> getResult() {
+ return result;
+ }
+
+ };
+ } else {
+ this.handler = new PredictionHandler() {
+
+ @Override
+ public void init() {
+ result.clear();
+ }
+
+ @Override
+ public void visitBranch(Operator op, int
splitFeatureIndex, double splitFeature,
+ double splitValue) {
+ buf.append(resolveFeatureName(splitFeatureIndex));
+ if (udf.verbose) {
+ buf.append(" [" + splitFeature + "] ");
+ }
+ buf.append(op);
+ buf.append(' ');
+ buf.append(splitValue);
+ result.add(buf.toString());
+ StringUtils.clear(buf);
+ }
+
+ @Override
+ public void visitLeaf(double output) {
+ if (udf.noLeaf) {
+ return;
+ }
+
+ result.add(Double.toString(output));
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public ArrayList<String> getResult() {
+ return result;
+ }
+
+ };
+ }
+ }
+
+ @Nonnull
+ private String resolveFeatureName(final int splitFeatureIndex) {
+ if (featureNames == null) {
+ return Integer.toString(splitFeatureIndex);
+ } else {
+ return featureNames[splitFeatureIndex];
+ }
+ }
+
+ @Nonnull
+ public List<String> evaluate(@Nonnull final String modelId, @Nonnull
final Text script,
+ @Nonnull final Vector features) throws HiveException {
+ if (!modelId.equals(prevModelId)) {
+ this.prevModelId = modelId;
+ int length = script.getLength();
+ byte[] b = script.getBytes();
+ b = Base91.decode(b, 0, length);
+ this.rNode = RegressionTree.deserialize(b, b.length, true);
+ }
+ Preconditions.checkNotNull(rNode);
+
+ handler.init();
+ rNode.predict(features, handler);
+ return handler.getResult();
+ }
+ }
+
+}
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
index 511944c..262a28d 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDF.java
@@ -284,7 +284,7 @@ public final class TreePredictUDF extends UDFWithOptions {
Arrays.fill(result, null);
Preconditions.checkNotNull(cNode);
cNode.predict(features, new PredictionHandler() {
- public void handle(int output, double[] posteriori) {
+ public void visitLeaf(int output, double[] posteriori) {
result[0] = new IntWritable(output);
result[1] = WritableUtils.toWritableList(posteriori);
}
diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
index 4e73ebc..caf21d3 100644
--- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
+++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java
@@ -148,17 +148,23 @@ public final class ArrayUtils {
return Arrays.asList(v);
}
- public static <T> void shuffle(@Nonnull final T[] array) {
+ @Nonnull
+ public static <T> T[] shuffle(@Nonnull final T[] array) {
shuffle(array, array.length);
+ return array;
}
- public static <T> void shuffle(@Nonnull final T[] array, final Random rnd)
{
+ @Nonnull
+ public static <T> T[] shuffle(@Nonnull final T[] array, final Random rnd) {
shuffle(array, array.length, rnd);
+ return array;
}
- public static <T> void shuffle(@Nonnull final T[] array, final int size) {
+ @Nonnull
+ public static <T> T[] shuffle(@Nonnull final T[] array, final int size) {
Random rnd = new Random();
shuffle(array, size, rnd);
+ return array;
}
/**
@@ -166,19 +172,23 @@ public final class ArrayUtils {
*
* @link http://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle
*/
- public static <T> void shuffle(@Nonnull final T[] array, final int size,
+ @Nonnull
+ public static <T> T[] shuffle(@Nonnull final T[] array, final int size,
@Nonnull final Random rnd) {
for (int i = size; i > 1; i--) {
int randomPosition = rnd.nextInt(i);
swap(array, i - 1, randomPosition);
}
+ return array;
}
- public static void shuffle(@Nonnull final int[] array, @Nonnull final
Random rnd) {
+ @Nonnull
+ public static int[] shuffle(@Nonnull final int[] array, @Nonnull final
Random rnd) {
for (int i = array.length; i > 1; i--) {
int randomPosition = rnd.nextInt(i);
swap(array, i - 1, randomPosition);
}
+ return array;
}
public static void swap(@Nonnull final Object[] arr, final int i, final
int j) {
diff --git
a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
index c3601eb..9e5ee9a 100644
--- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
@@ -25,12 +25,17 @@ import hivemall.math.matrix.builders.CSRMatrixBuilder;
import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
import hivemall.math.random.PRNG;
import hivemall.math.random.RandomNumberGeneratorFactory;
+import hivemall.math.vector.DenseVector;
import hivemall.smile.classification.DecisionTree.Node;
import hivemall.smile.classification.DecisionTree.SplitRule;
import hivemall.smile.tools.TreeExportUDF.Evaluator;
import hivemall.smile.tools.TreeExportUDF.OutputType;
import hivemall.smile.utils.SmileExtUtils;
import hivemall.utils.codec.Base91;
+import hivemall.utils.lang.ArrayUtils;
+import hivemall.utils.lang.StringUtils;
+import hivemall.utils.math.MathUtils;
+import smile.data.Attribute;
import smile.data.AttributeDataset;
import smile.data.NominalAttribute;
import smile.data.parser.ArffParser;
@@ -43,6 +48,9 @@ import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.text.ParseException;
+import java.util.Arrays;
+import java.util.LinkedHashMap;
+import java.util.Random;
import javax.annotation.Nonnull;
@@ -99,6 +107,15 @@ public class DecisionTreeTest {
}
@Test
+ public void testIrisTracePredict() throws IOException, ParseException {
+ int responseIndex = 4;
+ int numLeafs = Integer.MAX_VALUE;
+ runTracePredict(
+
"https://gist.githubusercontent.com/myui/143fa9d05bd6e7db0114/raw/500f178316b802f1cade6e3bf8dc814a96e84b1e/iris.arff",
+ responseIndex, numLeafs);
+ }
+
+ @Test
public void testIrisDepth4() throws IOException, ParseException {
int responseIndex = 4;
int numLeafs = 4;
@@ -240,6 +257,69 @@ public class DecisionTreeTest {
}
}
+ private static void runTracePredict(String datasetUrl, int responseIndex,
int numLeafs)
+ throws IOException, ParseException {
+ URL url = new URL(datasetUrl);
+ InputStream is = new BufferedInputStream(url.openStream());
+
+ ArffParser arffParser = new ArffParser();
+ arffParser.setResponseIndex(responseIndex);
+
+ AttributeDataset ds = arffParser.parse(is);
+ final Attribute[] attrs = ds.attributes();
+ final Attribute targetAttr = ds.response();
+
+ double[][] x = ds.toArray(new double[ds.size()][]);
+ int[] y = ds.toArray(new int[ds.size()]);
+
+ Random rnd = new Random(43L);
+ int numTrain = (int) (x.length * 0.7);
+ int[] index = ArrayUtils.shuffle(MathUtils.permutation(x.length), rnd);
+ int[] cvTrain = Arrays.copyOf(index, numTrain);
+ int[] cvTest = Arrays.copyOfRange(index, numTrain, index.length);
+
+ double[][] trainx = Math.slice(x, cvTrain);
+ int[] trainy = Math.slice(y, cvTrain);
+ double[][] testx = Math.slice(x, cvTest);
+
+ DecisionTree tree = new
DecisionTree(SmileExtUtils.convertAttributeTypes(attrs),
+ matrix(trainx, false), trainy, numLeafs,
RandomNumberGeneratorFactory.createPRNG(43L));
+
+ final LinkedHashMap<String, Double> map = new LinkedHashMap<>();
+ final StringBuilder buf = new StringBuilder();
+ for (int i = 0; i < testx.length; i++) {
+ final DenseVector test = new DenseVector(testx[i]);
+ tree.predict(test, new PredictionHandler() {
+
+ @Override
+ public void visitBranch(Operator op, int splitFeatureIndex,
double splitFeature,
+ double splitValue) {
+ buf.append(attrs[splitFeatureIndex].name);
+ buf.append(" [" + splitFeature + "] ");
+ buf.append(op);
+ buf.append(' ');
+ buf.append(splitValue);
+ buf.append('\n');
+
+ map.put(attrs[splitFeatureIndex].name + " [" +
splitFeature + "] " + op,
+ splitValue);
+ }
+
+ @Override
+ public void visitLeaf(int output, double[] posteriori) {
+ buf.append(targetAttr.toString(output));
+ }
+ });
+
+ Assert.assertTrue(buf.length() > 0);
+ Assert.assertFalse(map.isEmpty());
+
+ StringUtils.clear(buf);
+ map.clear();
+ }
+
+ }
+
@Test
public void testIrisSerializedObj() throws IOException, ParseException,
HiveException {
URL url = new URL(
diff --git a/docs/gitbook/misc/funcs.md b/docs/gitbook/misc/funcs.md
index d860dba..e5e9dc8 100644
--- a/docs/gitbook/misc/funcs.md
+++ b/docs/gitbook/misc/funcs.md
@@ -589,6 +589,39 @@ Reference: <a
href="https://papers.nips.cc/paper/3848-adaptive-regularization-of
- `train_randomforest_regressor(array<double|string> features, double target
[, string options])` - Returns a relation consists of <int model_id, int
model_type, string model, array<double> var_importance, double
oob_errors, int oob_tests>
+- `decision_path(string modelId, string model, array<double|string> features [, const string options] [, optional array<string> featureNames=null, optional array<string> classNames=null])` - Returns the decision path taken for each prediction, as array<string>
+ ```sql
+ SELECT
+ t.passengerid,
+ decision_path(m.model_id, m.model, t.features, '-classification')
+ FROM
+ model_rf m
+ LEFT OUTER JOIN
+ test_rf t;
+ > | 892 | ["2 [0.0] = 0.0","0 [3.0] = 3.0","1 [696.0] != 107.0","7 [7.8292]
<= 7.9104","1 [696.0] != 828.0","1 [696.0] != 391.0","0 [0.961038961038961,
0.03896103896103896]"] |
+
+ -- Show the 100 most frequent branches
+ WITH tmp as (
+ SELECT
+ decision_path(m.model_id, m.model, t.features, '-classification
-no_verbose -no_leaf',
array('pclass','name','sex','age','sibsp','parch','ticket','fare','cabin','embarked'),
array('no','yes')) as path
+ FROM
+ model_rf m
+ LEFT OUTER JOIN -- CROSS JOIN
+ test_rf t
+ )
+ select
+ r.branch,
+ count(1) as cnt
+ from
+ tmp l
+ LATERAL VIEW explode(l.path) r as branch
+ group by
+ r.branch
+ order by
+ cnt desc
+ limit 100;
+ ```
+
- `guess_attribute_types(ANY, ...)` - Returns attribute types
```sql
select guess_attribute_types(*) from train limit 1;
diff --git a/resources/ddl/define-all-as-permanent.hive
b/resources/ddl/define-all-as-permanent.hive
index 17797a8..343215a 100644
--- a/resources/ddl/define-all-as-permanent.hive
+++ b/resources/ddl/define-all-as-permanent.hive
@@ -829,6 +829,9 @@ CREATE FUNCTION rf_ensemble as
'hivemall.smile.tools.RandomForestEnsembleUDAF' U
DROP FUNCTION IF EXISTS guess_attribute_types;
CREATE FUNCTION guess_attribute_types as
'hivemall.smile.tools.GuessAttributesUDF' USING JAR '${hivemall_jar}';
+DROP FUNCTION IF EXISTS decision_path;
+CREATE FUNCTION decision_path as 'hivemall.smile.tools.DecisionPathUDF' USING
JAR '${hivemall_jar}';
+
--------------------
-- Recommendation --
--------------------
diff --git a/resources/ddl/define-all.hive b/resources/ddl/define-all.hive
index 04e8915..2a9b437 100644
--- a/resources/ddl/define-all.hive
+++ b/resources/ddl/define-all.hive
@@ -821,6 +821,9 @@ create temporary function rf_ensemble as
'hivemall.smile.tools.RandomForestEnsem
drop temporary function if exists guess_attribute_types;
create temporary function guess_attribute_types as
'hivemall.smile.tools.GuessAttributesUDF';
+drop temporary function if exists decision_path;
+create temporary function decision_path as
'hivemall.smile.tools.DecisionPathUDF';
+
--------------------
-- Recommendation --
--------------------
@@ -889,3 +892,4 @@ log(10, n_docs / max2(1,df_t)) + 1.0;
create temporary macro tfidf(tf FLOAT, df_t DOUBLE, n_docs DOUBLE)
tf * (log(10, n_docs / max2(1,df_t)) + 1.0);
+
diff --git a/resources/ddl/define-all.spark b/resources/ddl/define-all.spark
index 19f01bc..d62e3a2 100644
--- a/resources/ddl/define-all.spark
+++ b/resources/ddl/define-all.spark
@@ -807,6 +807,9 @@ sqlContext.sql("CREATE TEMPORARY FUNCTION
guess_attribute_types AS 'hivemall.smi
sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS
train_gradient_tree_boosting_classifier")
sqlContext.sql("CREATE TEMPORARY FUNCTION
train_gradient_tree_boosting_classifier AS
'hivemall.smile.classification.GradientTreeBoostingClassifierUDTF'")
+sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS decision_path")
+sqlContext.sql("CREATE TEMPORARY FUNCTION decision_path AS
'hivemall.smile.tools.DecisionPathUDF'")
+
/**
* Recommendation
*/