Repository: incubator-hivemall Updated Branches: refs/heads/master 5ac14b7d8 -> 4be8adbc8
Close #144: [HIVEMALL-190][HOTFIX] Fixed a bug in tree_predict_v1 on loading old prediction models Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/4be8adbc Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/4be8adbc Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/4be8adbc Branch: refs/heads/master Commit: 4be8adbc8e6648e15edf5c9ead9b51f3ad9fa45d Parents: 5ac14b7 Author: Makoto Yui <m...@apache.org> Authored: Fri Apr 13 16:20:09 2018 +0900 Committer: Makoto Yui <m...@apache.org> Committed: Fri Apr 13 16:20:09 2018 +0900 ---------------------------------------------------------------------- .../annotations/BackwardCompatibility.java | 34 +++ .../java/hivemall/smile/data/Attribute.java | 17 ++ .../hivemall/smile/tools/TreePredictUDFv1.java | 274 ++++++++++++++++++- .../smile/tools/TreePredictUDFv1Test.java | 92 ++++--- .../hivemall/smile/tools/dtv1_serialized.csv.gz | Bin 0 -> 668 bytes 5 files changed, 370 insertions(+), 47 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4be8adbc/core/src/main/java/hivemall/annotations/BackwardCompatibility.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/annotations/BackwardCompatibility.java b/core/src/main/java/hivemall/annotations/BackwardCompatibility.java new file mode 100644 index 0000000..35f9c4e --- /dev/null +++ b/core/src/main/java/hivemall/annotations/BackwardCompatibility.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.annotations; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotate program elements that is for backward compatibility + */ +@Retention(RetentionPolicy.CLASS) +@Target({ElementType.TYPE, ElementType.METHOD, ElementType.CONSTRUCTOR, ElementType.PACKAGE}) +@Documented +public @interface BackwardCompatibility { +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4be8adbc/core/src/main/java/hivemall/smile/data/Attribute.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/data/Attribute.java b/core/src/main/java/hivemall/smile/data/Attribute.java index 6569726..f9cb5a6 100644 --- a/core/src/main/java/hivemall/smile/data/Attribute.java +++ b/core/src/main/java/hivemall/smile/data/Attribute.java @@ -18,6 +18,7 @@ */ package hivemall.smile.data; +import hivemall.annotations.BackwardCompatibility; import hivemall.annotations.Immutable; import hivemall.annotations.Mutable; @@ -76,6 +77,22 @@ public abstract class Attribute { return type; } + @BackwardCompatibility + public static AttributeType resolve(int id) { + final AttributeType type; + switch (id) { + case 1: + type = NUMERIC; + break; + case 2: + type = NOMINAL; + break; + default: + throw new IllegalStateException("Unexpected type: " + id); + } + return type; + } + } @Immutable http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4be8adbc/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java index 6e422a5..5d16248 100644 --- a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java +++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java @@ -19,17 +19,21 @@ package hivemall.smile.tools; import hivemall.annotations.Since; -import hivemall.smile.classification.DecisionTree; -import hivemall.smile.regression.RegressionTree; +import hivemall.annotations.VisibleForTesting; +import hivemall.smile.data.Attribute.AttributeType; import hivemall.smile.vm.StackMachine; import hivemall.smile.vm.VMRuntimeException; import hivemall.utils.codec.Base91; import hivemall.utils.codec.DeflateCodec; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.io.IOUtils; +import hivemall.utils.lang.ObjectUtils; import java.io.Closeable; +import java.io.Externalizable; import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.Arrays; import javax.annotation.Nonnull; @@ -59,8 +63,7 @@ import org.apache.hadoop.io.Text; import org.apache.hadoop.io.Writable; import org.apache.hadoop.mapred.JobConf; -@Description( - name = "tree_predict_v1", +@Description(name = "tree_predict_v1", value = "_FUNC_(string modelId, int modelType, string script, array<double> features [, const boolean classification])" + " - Returns a prediction result of a random forest") @UDFType(deterministic = true, stateful = false) @@ -263,8 +266,8 @@ public final class TreePredictUDFv1 extends GenericUDF { @Nullable private String prevModelId = null; - private DecisionTree.Node cNode = null; - private RegressionTree.Node rNode = null; + private DtNodeV1 cNode = null; + private RtNodeV1 rNode = null; JavaSerializationEvaluator() {} @@ -285,13 +288,34 @@ public final class TreePredictUDFv1 extends GenericUDF { int length = script.getLength(); byte[] b = script.getBytes(); b = Base91.decode(b, 0, length); - this.cNode = DecisionTree.deserialize(b, b.length, compressed); + this.cNode = deserializeDecisionTree(b, b.length, compressed); } assert (cNode != null); int result = cNode.predict(features); return new IntWritable(result); } + @Nonnull + @VisibleForTesting + static DtNodeV1 deserializeDecisionTree(@Nonnull final byte[] serializedObj, + final int length, final boolean compressed) throws HiveException { + final DtNodeV1 root = new DtNodeV1(); + try { + if (compressed) { + ObjectUtils.readCompressedObject(serializedObj, 0, length, root); + } else { + ObjectUtils.readObject(serializedObj, length, root); + } + } catch (IOException ioe) { + throw new HiveException("IOException cause while deserializing DecisionTree object", + ioe); + } catch (Exception e) { + throw new HiveException("Exception cause while deserializing DecisionTree object", + e); + } + return root; + } + private DoubleWritable evaluateRegression(@Nonnull String modelId, boolean compressed, @Nonnull Text script, double[] features) throws HiveException { if (!modelId.equals(prevModelId)) { @@ -299,18 +323,246 @@ public final class TreePredictUDFv1 extends GenericUDF { int length = script.getLength(); byte[] b = script.getBytes(); b = Base91.decode(b, 0, length); - this.rNode = RegressionTree.deserialize(b, b.length, compressed); + this.rNode = deserializeRegressionTree(b, b.length, compressed); } assert (rNode != null); double result = rNode.predict(features); return new DoubleWritable(result); } + @Nonnull + @VisibleForTesting + static RtNodeV1 deserializeRegressionTree(final byte[] serializedObj, + final int length, final boolean compressed) throws HiveException { + final RtNodeV1 root = new RtNodeV1(); + try { + if (compressed) { + ObjectUtils.readCompressedObject(serializedObj, 0, length, root); + } else { + ObjectUtils.readObject(serializedObj, length, root); + } + } catch (IOException ioe) { + throw new HiveException("IOException cause while deserializing DecisionTree object", + ioe); + } catch (Exception e) { + throw new HiveException("Exception cause while deserializing DecisionTree object", + e); + } + return root; + } + @Override public void close() throws IOException {} } + /** + * Classification tree node. + */ + static final class DtNodeV1 implements Externalizable { + + /** + * Predicted class label for this node. + */ + int output = -1; + /** + * The split feature for this node. + */ + int splitFeature = -1; + /** + * The type of split feature + */ + AttributeType splitFeatureType = null; + /** + * The split value. + */ + double splitValue = Double.NaN; + /** + * Reduction in splitting criterion. + */ + double splitScore = 0.0; + /** + * Children node. + */ + DtNodeV1 trueChild = null; + /** + * Children node. + */ + DtNodeV1 falseChild = null; + /** + * Predicted output for children node. + */ + int trueChildOutput = -1; + /** + * Predicted output for children node. + */ + int falseChildOutput = -1; + + DtNodeV1() {}// for Externalizable + + /** + * Constructor. + */ + DtNodeV1(int output) { + this.output = output; + } + + /** + * Evaluate the regression tree over an instance. + */ + int predict(final double[] x) { + if (trueChild == null && falseChild == null) { + return output; + } else { + if (splitFeatureType == AttributeType.NOMINAL) { + if (x[splitFeature] == splitValue) { + return trueChild.predict(x); + } else { + return falseChild.predict(x); + } + } else if (splitFeatureType == AttributeType.NUMERIC) { + if (x[splitFeature] <= splitValue) { + return trueChild.predict(x); + } else { + return falseChild.predict(x); + } + } else { + throw new IllegalStateException( + "Unsupported attribute type: " + splitFeatureType); + } + } + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this.output = in.readInt(); + this.splitFeature = in.readInt(); + int typeId = in.readInt(); + if (typeId == -1) { + this.splitFeatureType = null; + } else { + this.splitFeatureType = AttributeType.resolve(typeId); + } + this.splitValue = in.readDouble(); + if (in.readBoolean()) { + this.trueChild = new DtNodeV1(); + trueChild.readExternal(in); + } + if (in.readBoolean()) { + this.falseChild = new DtNodeV1(); + falseChild.readExternal(in); + } + } + + } + + /** + * Regression tree node. + */ + static final class RtNodeV1 implements Externalizable { + + /** + * Predicted real value for this node. + */ + double output = 0.0; + /** + * The split feature for this node. + */ + int splitFeature = -1; + /** + * The type of split feature + */ + AttributeType splitFeatureType = null; + /** + * The split value. + */ + double splitValue = Double.NaN; + /** + * Reduction in squared error compared to parent. + */ + double splitScore = 0.0; + /** + * Children node. + */ + RtNodeV1 trueChild; + /** + * Children node. + */ + RtNodeV1 falseChild; + /** + * Predicted output for children node. + */ + double trueChildOutput = 0.0; + /** + * Predicted output for children node. + */ + double falseChildOutput = 0.0; + + RtNodeV1() {}//for Externalizable + + RtNodeV1(double output) { + this.output = output; + } + + /** + * Evaluate the regression tree over an instance. + */ + double predict(final double[] x) { + if (trueChild == null && falseChild == null) { + return output; + } else { + if (splitFeatureType == AttributeType.NOMINAL) { + // REVIEWME if(Math.equals(x[splitFeature], splitValue)) { + if (x[splitFeature] == splitValue) { + return trueChild.predict(x); + } else { + return falseChild.predict(x); + } + } else if (splitFeatureType == AttributeType.NUMERIC) { + if (x[splitFeature] <= splitValue) { + return trueChild.predict(x); + } else { + return falseChild.predict(x); + } + } else { + throw new IllegalStateException( + "Unsupported attribute type: " + splitFeatureType); + } + } + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this.output = in.readDouble(); + this.splitFeature = in.readInt(); + int typeId = in.readInt(); + if (typeId == -1) { + this.splitFeatureType = null; + } else { + this.splitFeatureType = AttributeType.resolve(typeId); + } + this.splitValue = in.readDouble(); + if (in.readBoolean()) { + this.trueChild = new RtNodeV1(); + trueChild.readExternal(in); + } + if (in.readBoolean()) { + this.falseChild = new RtNodeV1(); + falseChild.readExternal(in); + } + } + } + static final class StackmachineEvaluator implements Evaluator { private String prevModelId = null; @@ -394,9 +646,9 @@ public final class TreePredictUDFv1 extends GenericUDF { ScriptEngineManager manager = new ScriptEngineManager(); ScriptEngine engine = manager.getEngineByExtension("js"); if (!(engine instanceof Compilable)) { - throw new UDFArgumentException("ScriptEngine was not compilable: " - + engine.getFactory().getEngineName() + " version " - + engine.getFactory().getEngineVersion()); + throw new UDFArgumentException( + "ScriptEngine was not compilable: " + engine.getFactory().getEngineName() + + " version " + engine.getFactory().getEngineVersion()); } this.scriptEngine = engine; this.compilableEngine = (Compilable) engine; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4be8adbc/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java b/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java index f885041..75bbe78 100644 --- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java +++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java @@ -19,20 +19,32 @@ package hivemall.smile.tools; import static org.junit.Assert.assertEquals; + import hivemall.math.matrix.dense.RowMajorDenseMatrix2d; import hivemall.smile.classification.DecisionTree; import hivemall.smile.data.Attribute; import hivemall.smile.regression.RegressionTree; +import hivemall.smile.tools.TreePredictUDFv1.DtNodeV1; +import hivemall.smile.tools.TreePredictUDFv1.JavaSerializationEvaluator; import hivemall.smile.tools.TreePredictUDFv1.ModelType; import hivemall.smile.utils.SmileExtUtils; import hivemall.smile.vm.StackMachine; +import hivemall.utils.codec.Base91; +import hivemall.utils.io.IOUtils; import hivemall.utils.lang.ArrayUtils; +import smile.data.AttributeDataset; +import smile.data.parser.ArffParser; +import smile.math.Math; +import smile.validation.CrossValidation; +import smile.validation.LOOCV; +import smile.validation.RMSE; import java.io.BufferedInputStream; import java.io.IOException; import java.io.InputStream; import java.net.URL; import java.text.ParseException; +import java.util.zip.GZIPInputStream; import org.apache.hadoop.hive.ql.metadata.HiveException; import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject; @@ -43,19 +55,26 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.io.IntWritable; +import org.junit.Assert; import org.junit.Test; -import smile.data.AttributeDataset; -import smile.data.parser.ArffParser; -import smile.math.Math; -import smile.validation.CrossValidation; -import smile.validation.LOOCV; -import smile.validation.RMSE; - @SuppressWarnings("deprecation") public class TreePredictUDFv1Test { private static final boolean DEBUG = false; + @Test + public void testDeserializationOfV1() throws IOException, HiveException { + InputStream io = TreePredictUDFv1Test.class.getResourceAsStream("dtv1_serialized.csv.gz"); + GZIPInputStream gis = new GZIPInputStream(io); + byte[] serialized = IOUtils.toByteArray(gis); + + byte[] b = Base91.decode(serialized); + DtNodeV1 deserialized = + JavaSerializationEvaluator.deserializeDecisionTree(b, b.length, true); + + Assert.assertNotNull(deserialized); + } + /** * Test of learn method, of class DecisionTree. */ @@ -78,8 +97,8 @@ public class TreePredictUDFv1Test { int[] trainy = Math.slice(y, loocv.train[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(iris.attributes()); - DecisionTree tree = new DecisionTree(attrs, new RowMajorDenseMatrix2d(trainx, - x[0].length), trainy, 4); + DecisionTree tree = new DecisionTree(attrs, + new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4); assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, x[loocv.test[i]])); } } @@ -106,8 +125,8 @@ public class TreePredictUDFv1Test { double[][] testx = Math.slice(datax, cv.test[i]); Attribute[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes()); - RegressionTree tree = new RegressionTree(attrs, new RowMajorDenseMatrix2d(trainx, - trainx[0].length), trainy, 20); + RegressionTree tree = new RegressionTree(attrs, + new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20); for (int j = 0; j < testx.length; j++) { assertEquals(tree.predict(testx[j]), evalPredict(tree, testx[j]), 1.0); @@ -146,8 +165,8 @@ public class TreePredictUDFv1Test { } Attribute[] attrs = SmileExtUtils.convertAttributeTypes(data.attributes()); - RegressionTree tree = new RegressionTree(attrs, new RowMajorDenseMatrix2d(trainx, - trainx[0].length), trainy, 20); + RegressionTree tree = new RegressionTree(attrs, + new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20); debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy))); for (int i = m; i < n; i++) { @@ -164,45 +183,46 @@ public class TreePredictUDFv1Test { return new RMSE().measure(y, predictions); } - private static int evalPredict(DecisionTree tree, double[] x) throws HiveException, IOException { + private static int evalPredict(DecisionTree tree, double[] x) + throws HiveException, IOException { String opScript = tree.predictOpCodegen(StackMachine.SEP); debugPrint(opScript); TreePredictUDFv1 udf = new TreePredictUDFv1(); - udf.initialize(new ObjectInspector[] { - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), - ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)}); + udf.initialize( + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, true)}); DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"), - new DeferredJavaObject(ModelType.opscode.getId()), - new DeferredJavaObject(opScript), new DeferredJavaObject(ArrayUtils.toList(x)), - new DeferredJavaObject(true)}; + new DeferredJavaObject(ModelType.opscode.getId()), new DeferredJavaObject(opScript), + new DeferredJavaObject(ArrayUtils.toList(x)), new DeferredJavaObject(true)}; IntWritable result = (IntWritable) udf.evaluate(arguments); udf.close(); return result.get(); } - private static double evalPredict(RegressionTree tree, double[] x) throws HiveException, - IOException { + private static double evalPredict(RegressionTree tree, double[] x) + throws HiveException, IOException { String opScript = tree.predictOpCodegen(StackMachine.SEP); debugPrint(opScript); TreePredictUDFv1 udf = new TreePredictUDFv1(); - udf.initialize(new ObjectInspector[] { - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaStringObjectInspector, - ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), - ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)}); + udf.initialize( + new ObjectInspector[] {PrimitiveObjectInspectorFactory.javaStringObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorFactory.getStandardListObjectInspector( + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector), + ObjectInspectorUtils.getConstantObjectInspector( + PrimitiveObjectInspectorFactory.javaBooleanObjectInspector, false)}); DeferredObject[] arguments = new DeferredObject[] {new DeferredJavaObject("model_id#1"), - new DeferredJavaObject(ModelType.opscode.getId()), - new DeferredJavaObject(opScript), new DeferredJavaObject(ArrayUtils.toList(x)), - new DeferredJavaObject(false)}; + new DeferredJavaObject(ModelType.opscode.getId()), new DeferredJavaObject(opScript), + new DeferredJavaObject(ArrayUtils.toList(x)), new DeferredJavaObject(false)}; DoubleWritable result = (DoubleWritable) udf.evaluate(arguments); udf.close(); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/4be8adbc/core/src/test/java/hivemall/smile/tools/dtv1_serialized.csv.gz ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/smile/tools/dtv1_serialized.csv.gz b/core/src/test/java/hivemall/smile/tools/dtv1_serialized.csv.gz new file mode 100644 index 0000000..69579c9 Binary files /dev/null and b/core/src/test/java/hivemall/smile/tools/dtv1_serialized.csv.gz differ