http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java b/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java deleted file mode 100644 index 1c7a9a1..0000000 --- a/core/src/main/java/hivemall/matrix/ReadOnlyCSRMatrix.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.matrix; - -import hivemall.utils.lang.Preconditions; - -import java.util.Arrays; - -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; - -/** - * Read-only CSR Matrix. - * - * @see http://netlib.org/linalg/html_templates/node91.html#SECTION00931100000000000000 - */ -public final class ReadOnlyCSRMatrix extends Matrix { - - @Nonnull - private final int[] rowPointers; - @Nonnull - private final int[] columnIndices; - @Nonnull - private final double[] values; - - @Nonnegative - private final int numRows; - @Nonnegative - private final int numColumns; - - public ReadOnlyCSRMatrix(@Nonnull int[] rowPointers, @Nonnull int[] columnIndices, - @Nonnull double[] values, @Nonnegative int numColumns) { - super(); - Preconditions.checkArgument(rowPointers.length >= 1, - "rowPointers must be greather than 0: " + rowPointers.length); - Preconditions.checkArgument(columnIndices.length == values.length, "#columnIndices (" - + columnIndices.length + ") must be equals to #values (" + values.length + ")"); - this.rowPointers = rowPointers; - this.columnIndices = columnIndices; - this.values = values; - this.numRows = rowPointers.length - 1; - this.numColumns = numColumns; - } - - @Override - public boolean readOnly() { - return true; - } - - @Override - public int numRows() { - return numRows; - } - - @Override - public int numColumns() { - return numColumns; - } - - @Override - public int numColumns(@Nonnegative final int row) { - checkRowIndex(row, numRows); - - int columns = rowPointers[row + 1] - rowPointers[row]; - return columns; - } - - @Override - public double get(@Nonnegative final int row, @Nonnegative final int col, - final double defaultValue) { - checkIndex(row, col, numRows, numColumns); - - final int index = getIndex(row, col); - if (index < 0) { - return defaultValue; - } - return values[index]; - } - - @Override - public double getAndSet(@Nonnegative final int row, @Nonnegative final int col, - final double value) { - checkIndex(row, col, numRows, numColumns); - - final int index = getIndex(row, col); - if (index < 0) { - throw new UnsupportedOperationException("Cannot update value in row " + row + ", col " - + col); - } - - double old = values[index]; - values[index] = value; - return old; - } - - @Override - public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) { - checkIndex(row, col, numRows, numColumns); - - final int index = getIndex(row, col); - if (index < 0) { - throw new UnsupportedOperationException("Cannot update value in row " + row + ", col " - + col); - } - values[index] = value; - } - - private int getIndex(@Nonnegative final int row, @Nonnegative final int col) { - int leftIn = rowPointers[row]; - int rightEx = rowPointers[row + 1]; - final int index = Arrays.binarySearch(columnIndices, leftIn, rightEx, col); - if (index >= 0 && index >= values.length) { - throw new IndexOutOfBoundsException("Value index " + index + " out of range " - + values.length); - } - return index; - } - -}
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java b/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java deleted file mode 100644 index 040fef8..0000000 --- a/core/src/main/java/hivemall/matrix/ReadOnlyDenseMatrix2d.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.matrix; - -import javax.annotation.Nonnegative; -import javax.annotation.Nonnull; - -public final class ReadOnlyDenseMatrix2d extends Matrix { - - @Nonnull - private final double[][] data; - - @Nonnegative - private final int numRows; - @Nonnegative - private final int numColumns; - - public ReadOnlyDenseMatrix2d(@Nonnull double[][] data, @Nonnegative int numColumns) { - this.data = data; - this.numRows = data.length; - this.numColumns = numColumns; - } - - @Override - public boolean readOnly() { - return true; - } - - @Override - public void setDefaultValue(double value) { - throw new UnsupportedOperationException("The defaultValue of a DenseMatrix is fixed to 0.d"); - } - - @Override - public int numRows() { - return numRows; - } - - @Override - public int numColumns() { - return numColumns; - } - - @Override - public int numColumns(@Nonnegative final int row) { - checkRowIndex(row, numRows); - - return data[row].length; - } - - @Override - public double get(@Nonnegative final int row, @Nonnegative final int col, - final double defaultValue) { - checkIndex(row, col, numRows, numColumns); - - final double[] rowData = data[row]; - if (col >= rowData.length) { - return defaultValue; - } - return rowData[col]; - } - - @Override - public double getAndSet(@Nonnegative final int row, @Nonnegative final int col, - final double value) { - checkIndex(row, col, numRows, numColumns); - - final double[] rowData = data[row]; - checkColIndex(col, rowData.length); - - double old = rowData[col]; - rowData[col] = value; - return old; - } - - @Override - public void set(@Nonnegative final int row, @Nonnegative final int col, final double value) { - checkIndex(row, col, numRows, numColumns); - - final double[] rowData = data[row]; - checkColIndex(col, rowData.length); - - rowData[col] = value; - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/mf/FactorizedModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/mf/FactorizedModel.java b/core/src/main/java/hivemall/mf/FactorizedModel.java index b92a5d8..a4bea00 100644 --- a/core/src/main/java/hivemall/mf/FactorizedModel.java +++ b/core/src/main/java/hivemall/mf/FactorizedModel.java @@ -18,7 +18,7 @@ */ package hivemall.mf; -import hivemall.utils.collections.IntOpenHashMap; +import hivemall.utils.collections.maps.IntOpenHashMap; import hivemall.utils.math.MathUtils; import java.util.Random; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/model/AbstractPredictionModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/AbstractPredictionModel.java b/core/src/main/java/hivemall/model/AbstractPredictionModel.java index 37b69da..b48282b 100644 --- a/core/src/main/java/hivemall/model/AbstractPredictionModel.java +++ b/core/src/main/java/hivemall/model/AbstractPredictionModel.java @@ -21,8 +21,8 @@ package hivemall.model; import hivemall.mix.MixedWeight; import hivemall.mix.MixedWeight.WeightWithCovar; import hivemall.mix.MixedWeight.WeightWithDelta; -import hivemall.utils.collections.IntOpenHashMap; -import hivemall.utils.collections.OpenHashMap; +import hivemall.utils.collections.maps.IntOpenHashMap; +import hivemall.utils.collections.maps.OpenHashMap; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/model/SparseModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/model/SparseModel.java b/core/src/main/java/hivemall/model/SparseModel.java index 96e1d5a..a2b4708 100644 --- a/core/src/main/java/hivemall/model/SparseModel.java +++ b/core/src/main/java/hivemall/model/SparseModel.java @@ -22,7 +22,7 @@ import hivemall.model.WeightValueWithClock.WeightValueParamsF1Clock; import hivemall.model.WeightValueWithClock.WeightValueParamsF2Clock; import hivemall.model.WeightValueWithClock.WeightValueWithCovarClock; import hivemall.utils.collections.IMapIterator; -import hivemall.utils.collections.OpenHashMap; +import hivemall.utils.collections.maps.OpenHashMap; import javax.annotation.Nonnull; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/ModelType.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/ModelType.java b/core/src/main/java/hivemall/smile/ModelType.java deleted file mode 100644 index 8925075..0000000 --- a/core/src/main/java/hivemall/smile/ModelType.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.smile; - -public enum ModelType { - - // not compressed - opscode(1, false), javascript(2, false), serialization(3, false), - // compressed - opscode_compressed(-1, true), javascript_compressed(-2, true), - serialization_compressed(-3, true); - - private final int id; - private final boolean compressed; - - private ModelType(int id, boolean compressed) { - this.id = id; - this.compressed = compressed; - } - - public int getId() { - return id; - } - - public boolean isCompressed() { - return compressed; - } - - public static ModelType resolve(String name, boolean compressed) { - name = name.toLowerCase(); - if ("opscode".equals(name) || "vm".equals(name)) { - return compressed ? opscode_compressed : opscode; - } else if ("javascript".equals(name) || "js".equals(name)) { - return compressed ? javascript_compressed : javascript; - } else if ("serialization".equals(name) || "ser".equals(name)) { - return compressed ? serialization_compressed : serialization; - } else { - throw new IllegalStateException("Unexpected output type: " + name); - } - } - - public static ModelType resolve(final int id) { - final ModelType type; - switch (id) { - case 1: - type = opscode; - break; - case -1: - type = opscode_compressed; - break; - case 2: - type = javascript; - break; - case -2: - type = javascript_compressed; - break; - case 3: - type = serialization; - break; - case -3: - type = serialization_compressed; - break; - default: - throw new IllegalStateException("Unexpected ID for ModelType: " + id); - } - return type; - } - -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/DecisionTree.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java b/core/src/main/java/hivemall/smile/classification/DecisionTree.java index 6b22473..2d086b9 100644 --- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java +++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java @@ -33,100 +33,94 @@ */ package hivemall.smile.classification; +import hivemall.annotations.VisibleForTesting; +import hivemall.math.matrix.Matrix; +import hivemall.math.matrix.ints.ColumnMajorIntMatrix; +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.math.vector.DenseVector; +import hivemall.math.vector.SparseVector; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; import hivemall.smile.data.Attribute; import hivemall.smile.data.Attribute.AttributeType; import hivemall.smile.utils.SmileExtUtils; -import hivemall.utils.collections.IntArrayList; +import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.lang.ObjectUtils; -import hivemall.utils.lang.StringUtils; +import hivemall.utils.sampling.IntReservoirSampler; import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.PriorityQueue; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.roaringbitmap.IntConsumer; +import org.roaringbitmap.RoaringBitmap; import smile.classification.Classifier; import smile.math.Math; -import smile.math.Random; /** - * Decision tree for classification. A decision tree can be learned by splitting the training set - * into subsets based on an attribute value test. This process is repeated on each derived subset in - * a recursive manner called recursive partitioning. The recursion is completed when the subset at a - * node all has the same value of the target variable, or when splitting no longer adds value to the - * predictions. + * Decision tree for classification. A decision tree can be learned by splitting the training set into subsets based on an attribute value test. This + * process is repeated on each derived subset in a recursive manner called recursive partitioning. The recursion is completed when the subset at a + * node all has the same value of the target variable, or when splitting no longer adds value to the predictions. * <p> - * The algorithms that are used for constructing decision trees usually work top-down by choosing a - * variable at each step that is the next best variable to use in splitting the set of items. "Best" - * is defined by how well the variable splits the set into homogeneous subsets that have the same - * value of the target variable. Different algorithms use different formulae for measuring "best". - * Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen element - * from the set would be incorrectly labeled if it were randomly labeled according to the - * distribution of labels in the subset. Gini impurity can be computed by summing the probability of - * each item being chosen times the probability of a mistake in categorizing that item. It reaches - * its minimum (zero) when all cases in the node fall into a single target category. Information - * gain is another popular measure, used by the ID3, C4.5 and C5.0 algorithms. Information gain is - * based on the concept of entropy used in information theory. For categorical variables with - * different number of levels, however, information gain are biased in favor of those attributes - * with more levels. Instead, one may employ the information gain ratio, which solves the drawback - * of information gain. + * The algorithms that are used for constructing decision trees usually work top-down by choosing a variable at each step that is the next best + * variable to use in splitting the set of items. "Best" is defined by how well the variable splits the set into homogeneous subsets that have the + * same value of the target variable. Different algorithms use different formulae for measuring "best". Used by the CART algorithm, Gini impurity is a + * measure of how often a randomly chosen element from the set would be incorrectly labeled if it were randomly labeled according to the distribution + * of labels in the subset. Gini impurity can be computed by summing the probability of each item being chosen times the probability of a mistake in + * categorizing that item. It reaches its minimum (zero) when all cases in the node fall into a single target category. Information gain is another + * popular measure, used by the ID3, C4.5 and C5.0 algorithms. Information gain is based on the concept of entropy used in information theory. For + * categorical variables with different number of levels, however, information gain are biased in favor of those attributes with more levels. Instead, + * one may employ the information gain ratio, which solves the drawback of information gain. * <p> - * Classification and Regression Tree techniques have a number of advantages over many of those - * alternative techniques. + * Classification and Regression Tree techniques have a number of advantages over many of those alternative techniques. * <dl> * <dt>Simple to understand and interpret.</dt> - * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This - * simplicity is useful not only for purposes of rapid classification of new observations, but can - * also often yield a much simpler "model" for explaining why observations are classified or - * predicted in a particular manner.</dd> + * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This simplicity is useful not only for purposes of rapid + * classification of new observations, but can also often yield a much simpler "model" for explaining why observations are classified or predicted in + * a particular manner.</dd> * <dt>Able to handle both numerical and categorical data.</dt> - * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of - * variable.</dd> + * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of variable.</dd> * <dt>Tree methods are nonparametric and nonlinear.</dt> - * <dd>The final results of using tree methods for classification or regression can be summarized in - * a series of (usually few) logical if-then conditions (tree nodes). Therefore, there is no - * implicit assumption that the underlying relationships between the predictor variables and the - * dependent variable are linear, follow some specific non-linear link function, or that they are - * even monotonic in nature. Thus, tree methods are particularly well suited for data mining tasks, - * where there is often little a priori knowledge nor any coherent set of theories or predictions - * regarding which variables are related and how. In those types of data analytics, tree methods can - * often reveal simple relationships between just a few variables that could have easily gone - * unnoticed using other analytic techniques.</dd> + * <dd>The final results of using tree methods for classification or regression can be summarized in a series of (usually few) logical if-then + * conditions (tree nodes). Therefore, there is no implicit assumption that the underlying relationships between the predictor variables and the + * dependent variable are linear, follow some specific non-linear link function, or that they are even monotonic in nature. Thus, tree methods are + * particularly well suited for data mining tasks, where there is often little a priori knowledge nor any coherent set of theories or predictions + * regarding which variables are related and how. In those types of data analytics, tree methods can often reveal simple relationships between just a + * few variables that could have easily gone unnoticed using other analytic techniques.</dd> * </dl> - * One major problem with classification and regression trees is their high variance. Often a small - * change in the data can result in a very different series of splits, making interpretation - * somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause - * over-fitting. Mechanisms such as pruning are necessary to avoid this problem. Another limitation - * of trees is the lack of smoothness of the prediction surface. + * One major problem with classification and regression trees is their high variance. Often a small change in the data can result in a very different + * series of splits, making interpretation somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause over-fitting. + * Mechanisms such as pruning are necessary to avoid this problem. Another limitation of trees is the lack of smoothness of the prediction surface. * <p> - * Some techniques such as bagging, boosting, and random forest use more than one decision tree for - * their analysis. + * Some techniques such as bagging, boosting, and random forest use more than one decision tree for their analysis. */ -public final class DecisionTree implements Classifier<double[]> { +public final class DecisionTree implements Classifier<Vector> { /** * The attributes of independent variable. */ + @Nonnull private final Attribute[] _attributes; private final boolean _hasNumericType; /** - * Variable importance. Every time a split of a node is made on variable the (GINI, information - * gain, etc.) impurity criterion for the two descendant nodes is less than the parent node. - * Adding up the decreases for each individual variable over the tree gives a simple measure of + * Variable importance. Every time a split of a node is made on variable the (GINI, information gain, etc.) impurity criterion for the two + * descendant nodes is less than the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of * variable importance. */ - private final double[] _importance; + @Nonnull + private final Vector _importance; /** * The root of the regression tree */ + @Nonnull private final Node _root; /** * The maximum number of the tree depth @@ -135,6 +129,7 @@ public final class DecisionTree implements Classifier<double[]> { /** * The splitting rule. */ + @Nonnull private final SplitRule _rule; /** * The number of classes. @@ -153,24 +148,23 @@ public final class DecisionTree implements Classifier<double[]> { */ private final int _minLeafSize; /** - * The index of training values in ascending order. Note that only numeric attributes will be - * sorted. + * The index of training values in ascending order. Note that only numeric attributes will be sorted. */ - private final int[][] _order; + @Nonnull + private final ColumnMajorIntMatrix _order; - private final Random _rnd; + @Nonnull + private final PRNG _rnd; /** * The criterion to choose variable to split instances. */ public static enum SplitRule { /** - * Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen - * element from the set would be incorrectly labeled if it were randomly labeled according - * to the distribution of labels in the subset. Gini impurity can be computed by summing the - * probability of each item being chosen times the probability of a mistake in categorizing - * that item. It reaches its minimum (zero) when all cases in the node fall into a single - * target category. + * Used by the CART algorithm, Gini impurity is a measure of how often a randomly chosen element from the set would be incorrectly labeled if + * it were randomly labeled according to the distribution of labels in the subset. Gini impurity can be computed by summing the probability of + * each item being chosen times the probability of a mistake in categorizing that item. It reaches its minimum (zero) when all cases in the + * node fall into a single target category. */ GINI, /** @@ -193,6 +187,11 @@ public final class DecisionTree implements Classifier<double[]> { */ int output = -1; /** + * Posteriori probability based on sample ratios in this node. + */ + @Nullable + double[] posteriori = null; + /** * The split feature for this node. */ int splitFeature = -1; @@ -227,28 +226,35 @@ public final class DecisionTree implements Classifier<double[]> { public Node() {}// for Externalizable - /** - * Constructor. - */ - public Node(int output) { + public Node(int output, @Nonnull double[] posteriori) { this.output = output; + this.posteriori = posteriori; + } + + private boolean isLeaf() { + return posteriori != null; + } + + @VisibleForTesting + public int predict(@Nonnull final double[] x) { + return predict(new DenseVector(x)); } /** * Evaluate the regression tree over an instance. */ - public int predict(final double[] x) { + public int predict(@Nonnull final Vector x) { if (trueChild == null && falseChild == null) { return output; } else { if (splitFeatureType == AttributeType.NOMINAL) { - if (x[splitFeature] == splitValue) { + if (x.get(splitFeature, Double.NaN) == splitValue) { return trueChild.predict(x); } else { return falseChild.predict(x); } } else if (splitFeatureType == AttributeType.NUMERIC) { - if (x[splitFeature] <= splitValue) { + if (x.get(splitFeature, Double.NaN) <= splitValue) { return trueChild.predict(x); } else { return falseChild.predict(x); @@ -260,6 +266,32 @@ public final class DecisionTree implements Classifier<double[]> { } } + /** + * Evaluate the regression tree over an instance. + */ + public void predict(@Nonnull final Vector x, @Nonnull final PredictionHandler handler) { + if (trueChild == null && falseChild == null) { + handler.handle(output, posteriori); + } else { + if (splitFeatureType == AttributeType.NOMINAL) { + if (x.get(splitFeature, Double.NaN) == splitValue) { + trueChild.predict(x, handler); + } else { + falseChild.predict(x, handler); + } + } else if (splitFeatureType == AttributeType.NUMERIC) { + if (x.get(splitFeature, Double.NaN) <= splitValue) { + trueChild.predict(x, handler); + } else { + falseChild.predict(x, handler); + } + } else { + throw new IllegalStateException("Unsupported attribute type: " + + splitFeatureType); + } + } + } + public void jsCodegen(@Nonnull final StringBuilder builder, final int depth) { if (trueChild == null && falseChild == null) { indent(builder, depth); @@ -298,99 +330,71 @@ public final class DecisionTree implements Classifier<double[]> { } } - public int opCodegen(final List<String> scripts, int depth) { - int selfDepth = 0; - final StringBuilder buf = new StringBuilder(); - if (trueChild == null && falseChild == null) { - buf.append("push ").append(output); - scripts.add(buf.toString()); - buf.setLength(0); - buf.append("goto last"); - scripts.add(buf.toString()); - selfDepth += 2; - } else { - if (splitFeatureType == AttributeType.NOMINAL) { - buf.append("push ").append("x[").append(splitFeature).append("]"); - scripts.add(buf.toString()); - buf.setLength(0); - buf.append("push ").append(splitValue); - scripts.add(buf.toString()); - buf.setLength(0); - buf.append("ifeq "); - scripts.add(buf.toString()); - depth += 3; - selfDepth += 3; - int trueDepth = trueChild.opCodegen(scripts, depth); - selfDepth += trueDepth; - scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth)); - int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth); - selfDepth += falseDepth; - } else if (splitFeatureType == AttributeType.NUMERIC) { - buf.append("push ").append("x[").append(splitFeature).append("]"); - scripts.add(buf.toString()); - buf.setLength(0); - buf.append("push ").append(splitValue); - scripts.add(buf.toString()); - buf.setLength(0); - buf.append("ifle "); - scripts.add(buf.toString()); - depth += 3; - selfDepth += 3; - int trueDepth = trueChild.opCodegen(scripts, depth); - selfDepth += trueDepth; - scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth)); - int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth); - selfDepth += falseDepth; - } else { - throw new IllegalStateException("Unsupported attribute type: " - + splitFeatureType); - } - } - return selfDepth; - } - @Override public void writeExternal(ObjectOutput out) throws IOException { - out.writeInt(output); out.writeInt(splitFeature); if (splitFeatureType == null) { - out.writeInt(-1); + out.writeByte(-1); } else { - out.writeInt(splitFeatureType.getTypeId()); + out.writeByte(splitFeatureType.getTypeId()); } out.writeDouble(splitValue); - if (trueChild == null) { - out.writeBoolean(false); - } else { + + if (isLeaf()) { out.writeBoolean(true); - trueChild.writeExternal(out); - } - if (falseChild == null) { - out.writeBoolean(false); + + out.writeInt(output); + out.writeInt(posteriori.length); + for (int i = 0; i < posteriori.length; i++) { + out.writeDouble(posteriori[i]); + } } else { - out.writeBoolean(true); - falseChild.writeExternal(out); + out.writeBoolean(false); + + if (trueChild == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + trueChild.writeExternal(out); + } + if (falseChild == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + falseChild.writeExternal(out); + } } } @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - this.output = in.readInt(); this.splitFeature = in.readInt(); - int typeId = in.readInt(); + byte typeId = in.readByte(); if (typeId == -1) { this.splitFeatureType = null; } else { this.splitFeatureType = AttributeType.resolve(typeId); } this.splitValue = in.readDouble(); - if (in.readBoolean()) { - this.trueChild = new Node(); - trueChild.readExternal(in); - } - if (in.readBoolean()) { - this.falseChild = new Node(); - falseChild.readExternal(in); + + if (in.readBoolean()) {//isLeaf + this.output = in.readInt(); + + final int size = in.readInt(); + final double[] posteriori = new double[size]; + for (int i = 0; i < size; i++) { + posteriori[i] = in.readDouble(); + } + this.posteriori = posteriori; + } else { + if (in.readBoolean()) { + this.trueChild = new Node(); + trueChild.readExternal(in); + } + if (in.readBoolean()) { + this.falseChild = new Node(); + falseChild.readExternal(in); + } } } @@ -413,7 +417,7 @@ public final class DecisionTree implements Classifier<double[]> { /** * Training dataset. */ - final double[][] x; + final Matrix x; /** * class labels. */ @@ -426,7 +430,7 @@ public final class DecisionTree implements Classifier<double[]> { /** * Constructor. */ - public TrainNode(Node node, double[][] x, int[] y, int[] bags, int depth) { + public TrainNode(Node node, Matrix x, int[] y, int[] bags, int depth) { this.node = node; this.x = x; this.y = y; @@ -466,21 +470,12 @@ public final class DecisionTree implements Classifier<double[]> { final double impurity = impurity(count, numSamples, _rule); - final int p = _attributes.length; - final int[] variableIndex = new int[p]; - for (int i = 0; i < p; i++) { - variableIndex[i] = i; - } - if (_numVars < p) { - SmileExtUtils.shuffle(variableIndex, _rnd); - } - - final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.length) + final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.numRows()) : null; final int[] falseCount = new int[_k]; - for (int j = 0; j < _numVars; j++) { - Node split = findBestSplit(numSamples, count, falseCount, impurity, - variableIndex[j], samples); + for (int varJ : variableIndex(x, bags)) { + final Node split = findBestSplit(numSamples, count, falseCount, impurity, varJ, + samples); if (split.splitScore > node.splitScore) { node.splitFeature = split.splitFeature; node.splitFeatureType = split.splitFeatureType; @@ -491,7 +486,33 @@ public final class DecisionTree implements Classifier<double[]> { } } - return (node.splitFeature != -1); + return node.splitFeature != -1; + } + + @Nonnull + private int[] variableIndex(@Nonnull final Matrix x, @Nonnull final int[] bags) { + final IntReservoirSampler sampler = new IntReservoirSampler(_numVars, _rnd.nextLong()); + if (x.isSparse()) { + final RoaringBitmap cols = new RoaringBitmap(); + final VectorProcedure proc = new VectorProcedure() { + public void apply(final int col) { + cols.add(col); + } + }; + for (final int row : bags) { + x.eachColumnIndexInRow(row, proc); + } + cols.forEach(new IntConsumer() { + public void accept(final int k) { + sampler.add(k); + } + }); + } else { + for (int i = 0, size = _attributes.length; i < size; i++) { + sampler.add(i); + } + } + return sampler.getSample(); } private boolean sampleCount(@Nonnull final int[] count) { @@ -530,7 +551,11 @@ public final class DecisionTree implements Classifier<double[]> { for (int i = 0, size = bags.length; i < size; i++) { int index = bags[i]; - int x_ij = (int) x[index][j]; + final double v = x.get(index, j, Double.NaN); + if (Double.isNaN(v)) { + continue; + } + int x_ij = (int) v; trueCount[x_ij][y[index]]++; } @@ -563,21 +588,28 @@ public final class DecisionTree implements Classifier<double[]> { } } else if (_attributes[j].type == AttributeType.NUMERIC) { final int[] trueCount = new int[_k]; - double prevx = Double.NaN; - int prevy = -1; - - assert (samples != null); - for (final int i : _order[j]) { - final int sample = samples[i]; - if (sample > 0) { - final double x_ij = x[i][j]; + + _order.eachNonNullInColumn(j, new VectorProcedure() { + double prevx = Double.NaN; + int prevy = -1; + + public void apply(final int row, final int i) { + final int sample = samples[i]; + if (sample == 0) { + return; + } + + final double x_ij = x.get(i, j, Double.NaN); + if (Double.isNaN(x_ij)) { + return; + } final int y_i = y[i]; if (Double.isNaN(prevx) || x_ij == prevx || y_i == prevy) { prevx = x_ij; prevy = y_i; trueCount[y_i] += sample; - continue; + return; } final int tc = Math.sum(trueCount); @@ -588,7 +620,7 @@ public final class DecisionTree implements Classifier<double[]> { prevx = x_ij; prevy = y_i; trueCount[y_i] += sample; - continue; + return; } for (int l = 0; l < _k; l++) { @@ -612,8 +644,8 @@ public final class DecisionTree implements Classifier<double[]> { prevx = x_ij; prevy = y_i; trueCount[y_i] += sample; - } - } + }//apply() + }); } else { throw new IllegalStateException("Unsupported attribute type: " + _attributes[j].type); @@ -634,7 +666,9 @@ public final class DecisionTree implements Classifier<double[]> { int childBagSize = (int) (bags.length * 0.4); IntArrayList trueBags = new IntArrayList(childBagSize); IntArrayList falseBags = new IntArrayList(childBagSize); - int tc = splitSamples(trueBags, falseBags); + double[] trueChildPosteriori = new double[_k]; + double[] falseChildPosteriori = new double[_k]; + int tc = splitSamples(trueBags, falseBags, trueChildPosteriori, falseChildPosteriori); int fc = bags.length - tc; this.bags = null; // help GC for recursive call @@ -647,7 +681,12 @@ public final class DecisionTree implements Classifier<double[]> { return false; } - node.trueChild = new Node(node.trueChildOutput); + for (int i = 0; i < _k; i++) { + trueChildPosteriori[i] /= tc; + falseChildPosteriori[i] /= fc; + } + + node.trueChild = new Node(node.trueChildOutput, trueChildPosteriori); TrainNode trueChild = new TrainNode(node.trueChild, x, y, trueBags.toArray(), depth + 1); trueBags = null; // help GC for recursive call if (tc >= _minSplit && trueChild.findBestSplit()) { @@ -658,7 +697,7 @@ public final class DecisionTree implements Classifier<double[]> { } } - node.falseChild = new Node(node.falseChildOutput); + node.falseChild = new Node(node.falseChildOutput, falseChildPosteriori); TrainNode falseChild = new TrainNode(node.falseChild, x, y, falseBags.toArray(), depth + 1); falseBags = null; // help GC for recursive call @@ -670,27 +709,33 @@ public final class DecisionTree implements Classifier<double[]> { } } - _importance[node.splitFeature] += node.splitScore; + _importance.incr(node.splitFeature, node.splitScore); + node.posteriori = null; // posteriori is not needed for non-leaf nodes return true; } /** + * @param falseChildPosteriori + * @param trueChildPosteriori * @return the number of true samples */ private int splitSamples(@Nonnull final IntArrayList trueBags, - @Nonnull final IntArrayList falseBags) { + @Nonnull final IntArrayList falseBags, @Nonnull final double[] trueChildPosteriori, + @Nonnull final double[] falseChildPosteriori) { int tc = 0; if (node.splitFeatureType == AttributeType.NOMINAL) { final int splitFeature = node.splitFeature; final double splitValue = node.splitValue; for (int i = 0, size = bags.length; i < size; i++) { final int index = bags[i]; - if (x[index][splitFeature] == splitValue) { + if (x.get(index, splitFeature, Double.NaN) == splitValue) { trueBags.add(index); + trueChildPosteriori[y[index]]++; tc++; } else { falseBags.add(index); + falseChildPosteriori[y[index]]++; } } } else if (node.splitFeatureType == AttributeType.NUMERIC) { @@ -698,11 +743,13 @@ public final class DecisionTree implements Classifier<double[]> { final double splitValue = node.splitValue; for (int i = 0, size = bags.length; i < size; i++) { final int index = bags[i]; - if (x[index][splitFeature] <= splitValue) { + if (x.get(index, splitFeature, Double.NaN) <= splitValue) { trueBags.add(index); + trueChildPosteriori[y[index]]++; tc++; } else { falseBags.add(index); + falseChildPosteriori[y[index]]++; } } } else { @@ -714,7 +761,6 @@ public final class DecisionTree implements Classifier<double[]> { } - /** * Returns the impurity of a node. * @@ -731,8 +777,9 @@ public final class DecisionTree implements Classifier<double[]> { case GINI: { impurity = 1.0; for (int i = 0; i < count.length; i++) { - if (count[i] > 0) { - double p = (double) count[i] / n; + final int count_i = count[i]; + if (count_i > 0) { + double p = (double) count_i / n; impurity -= p * p; } } @@ -740,8 +787,9 @@ public final class DecisionTree implements Classifier<double[]> { } case ENTROPY: { for (int i = 0; i < count.length; i++) { - if (count[i] > 0) { - double p = (double) count[i] / n; + final int count_i = count[i]; + if (count_i > 0) { + double p = (double) count_i / n; impurity -= p * Math.log2(p); } } @@ -750,8 +798,9 @@ public final class DecisionTree implements Classifier<double[]> { case CLASSIFICATION_ERROR: { impurity = 0.d; for (int i = 0; i < count.length; i++) { - if (count[i] > 0) { - impurity = Math.max(impurity, (double) count[i] / n); + final int count_i = count[i]; + if (count_i > 0) { + impurity = Math.max(impurity, (double) count_i / n); } } impurity = Math.abs(1.d - impurity); @@ -762,14 +811,14 @@ public final class DecisionTree implements Classifier<double[]> { return impurity; } - public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y, + public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y, int numLeafs) { - this(attributes, x, y, x[0].length, Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, null); + this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, null); } - public DecisionTree(@Nullable Attribute[] attributes, @Nullable double[][] x, - @Nullable int[] y, int numLeafs, @Nullable smile.math.Random rand) { - this(attributes, x, y, x[0].length, Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, rand); + public DecisionTree(@Nullable Attribute[] attributes, @Nullable Matrix x, @Nullable int[] y, + int numLeafs, @Nullable PRNG rand) { + this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, numLeafs, 2, 1, null, null, SplitRule.GINI, rand); } /** @@ -778,21 +827,20 @@ public final class DecisionTree implements Classifier<double[]> { * @param attributes the attribute properties. * @param x the training instances. * @param y the response variable. - * @param numVars the number of input variables to pick to split on at each node. It seems that - * dim/3 give generally good performance, where dim is the number of variables. + * @param numVars the number of input variables to pick to split on at each node. It seems that dim/3 give generally good performance, where dim + * is the number of variables. * @param maxLeafs the maximum number of leaf nodes in the tree. * @param minSplits the number of minimum elements in a node to split * @param minLeafSize the minimum size of leaf nodes. - * @param order the index of training values in ascending order. Note that only numeric - * attributes need be sorted. + * @param order the index of training values in ascending order. Note that only numeric attributes need be sorted. * @param bags the sample set of instances for stochastic learning. * @param rule the splitting rule. * @param seed */ - public DecisionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, @Nonnull int[] y, + public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y, int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize, - @Nullable int[] bags, @Nullable int[][] order, @Nonnull SplitRule rule, - @Nullable smile.math.Random rand) { + @Nullable int[] bags, @Nullable ColumnMajorIntMatrix order, @Nonnull SplitRule rule, + @Nullable PRNG rand) { checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize); this._k = Math.max(y) + 1; @@ -801,7 +849,7 @@ public final class DecisionTree implements Classifier<double[]> { } this._attributes = SmileExtUtils.attributeTypes(attributes, x); - if (attributes.length != x[0].length) { + if (attributes.length != x.numColumns()) { throw new IllegalArgumentException("-attrs option is invliad: " + Arrays.toString(attributes)); } @@ -813,8 +861,8 @@ public final class DecisionTree implements Classifier<double[]> { this._minLeafSize = minLeafSize; this._rule = rule; this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order; - this._importance = new double[_attributes.length]; - this._rnd = (rand == null) ? new smile.math.Random() : rand; + this._importance = x.isSparse() ? new SparseVector() : new DenseVector(_attributes.length); + this._rnd = (rand == null) ? RandomNumberGeneratorFactory.createPRNG() : rand; final int n = y.length; final int[] count = new int[_k]; @@ -825,13 +873,17 @@ public final class DecisionTree implements Classifier<double[]> { count[y[i]]++; } } else { - for (int i = 0; i < n; i++) { + for (int i = 0, size = bags.length; i < size; i++) { int index = bags[i]; count[y[index]]++; } } - this._root = new Node(Math.whichMax(count)); + final double[] posteriori = new double[_k]; + for (int i = 0; i < _k; i++) { + posteriori[i] = (double) count[i] / n; + } + this._root = new Node(Math.whichMax(count), posteriori); final TrainNode trainRoot = new TrainNode(_root, x, y, bags, 1); if (maxLeafs == Integer.MAX_VALUE) { @@ -858,13 +910,13 @@ public final class DecisionTree implements Classifier<double[]> { } } - private static void checkArgument(@Nonnull double[][] x, @Nonnull int[] y, int numVars, + private static void checkArgument(@Nonnull Matrix x, @Nonnull int[] y, int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize) { - if (x.length != y.length) { + if (x.numRows() != y.length) { throw new IllegalArgumentException(String.format( - "The sizes of X and Y don't match: %d != %d", x.length, y.length)); + "The sizes of X and Y don't match: %d != %d", x.numRows(), y.length)); } - if (numVars <= 0 || numVars > x[0].length) { + if (numVars <= 0 || numVars > x.numColumns()) { throw new IllegalArgumentException( "Invalid number of variables to split on at a node of the tree: " + numVars); } @@ -885,28 +937,31 @@ public final class DecisionTree implements Classifier<double[]> { } /** - * Returns the variable importance. Every time a split of a node is made on variable the (GINI, - * information gain, etc.) impurity criterion for the two descendent nodes is less than the - * parent node. Adding up the decreases for each individual variable over the tree gives a - * simple measure of variable importance. + * Returns the variable importance. Every time a split of a node is made on variable the (GINI, information gain, etc.) impurity criterion for the + * two descendent nodes is less than the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of + * variable importance. * * @return the variable importance */ - public double[] importance() { + @Nonnull + public Vector importance() { return _importance; } + @VisibleForTesting + public int predict(@Nonnull final double[] x) { + return predict(new DenseVector(x)); + } + @Override - public int predict(final double[] x) { + public int predict(@Nonnull final Vector x) { return _root.predict(x); } /** - * Predicts the class label of an instance and also calculate a posteriori probabilities. Not - * supported. + * Predicts the class label of an instance and also calculate a posteriori probabilities. Not supported. */ - @Override - public int predict(double[] x, double[] posteriori) { + public int predict(Vector x, double[] posteriori) { throw new UnsupportedOperationException("Not supported."); } @@ -916,14 +971,6 @@ public final class DecisionTree implements Classifier<double[]> { return buf.toString(); } - public String predictOpCodegen(String sep) { - List<String> opslist = new ArrayList<String>(); - _root.opCodegen(opslist, 0); - opslist.add("call end"); - String scripts = StringUtils.concat(opslist, sep); - return scripts; - } - @Nonnull public byte[] predictSerCodegen(boolean compress) throws HiveException { try { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java index 3a0924e..a380a11 100644 --- a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java +++ b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java @@ -19,24 +19,27 @@ package hivemall.smile.classification; import hivemall.UDTFWithOptions; -import hivemall.smile.ModelType; +import hivemall.math.matrix.Matrix; +import hivemall.math.matrix.builders.CSRMatrixBuilder; +import hivemall.math.matrix.builders.MatrixBuilder; +import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder; +import hivemall.math.matrix.ints.ColumnMajorIntMatrix; +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.math.vector.Vector; import hivemall.smile.data.Attribute; import hivemall.smile.regression.RegressionTree; import hivemall.smile.utils.SmileExtUtils; -import hivemall.smile.vm.StackMachine; import hivemall.utils.codec.Base91; -import hivemall.utils.codec.DeflateCodec; -import hivemall.utils.collections.IntArrayList; +import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hadoop.WritableUtils; -import hivemall.utils.io.IOUtils; import hivemall.utils.lang.Primitives; +import hivemall.utils.math.MathUtils; -import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.BitSet; -import java.util.List; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -63,7 +66,7 @@ import org.apache.hadoop.mapred.Counters.Counter; import org.apache.hadoop.mapred.Reporter; @Description(name = "train_gradient_tree_boosting_classifier", - value = "_FUNC_(double[] features, int label [, string options]) - " + value = "_FUNC_(array<double|string> features, int label [, string options]) - " + "Returns a relation consists of " + "<int iteration, int model_type, array<string> pred_models, double intercept, " + "double shrinkage, array<double> var_importance, float oob_error_rate>") @@ -74,7 +77,8 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { private PrimitiveObjectInspector featureElemOI; private PrimitiveObjectInspector labelOI; - private List<double[]> featuresList; + private boolean denseInput; + private MatrixBuilder matrixBuilder; private IntArrayList labels; /** * The number of trees for each task @@ -104,7 +108,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { private int _minSamplesLeaf; private long _seed; private Attribute[] _attributes; - private ModelType _outputType; @Nullable private Reporter _progressReporter; @@ -134,10 +137,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { opts.addOption("seed", true, "seed value in long [default: -1 (random)]"); opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types " + "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])"); - opts.addOption("output", "output_type", true, - "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]"); - opts.addOption("disable_compression", false, - "Whether to disable compression of the output script [default: false]"); return opts; } @@ -149,8 +148,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { double eta = 0.05d, subsample = 0.7d; Attribute[] attrs = null; long seed = -1L; - String output = "serialization"; - boolean compress = true; CommandLine cl = null; if (argOIs.length >= 3) { @@ -171,10 +168,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { minSamplesLeaf); seed = Primitives.parseLong(cl.getOptionValue("seed"), seed); attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types")); - output = cl.getOptionValue("output", output); - if (cl.hasOption("disable_compression")) { - compress = false; - } } this._numTrees = trees; @@ -187,7 +180,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { this._minSamplesLeaf = minSamplesLeaf; this._seed = seed; this._attributes = attrs; - this._outputType = ModelType.resolve(output, compress); return cl; } @@ -197,19 +189,29 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { if (argOIs.length != 2 && argOIs.length != 3) { throw new UDFArgumentException( getClass().getSimpleName() - + " takes 2 or 3 arguments: double[] features, int label [, const string options]: " + + " takes 2 or 3 arguments: array<double|string> features, int label [, const string options]: " + argOIs.length); } ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]); ObjectInspector elemOI = listOI.getListElementObjectInspector(); this.featureListOI = listOI; - this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI); + if (HiveUtils.isNumberOI(elemOI)) { + this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI); + this.denseInput = true; + this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192); + } else if (HiveUtils.isStringOI(elemOI)) { + this.featureElemOI = HiveUtils.asStringOI(elemOI); + this.denseInput = false; + this.matrixBuilder = new CSRMatrixBuilder(8192); + } else { + throw new UDFArgumentException( + "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName()); + } this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]); processOptions(argOIs); - this.featuresList = new ArrayList<double[]>(1024); this.labels = new IntArrayList(1024); ArrayList<String> fieldNames = new ArrayList<String>(6); @@ -217,8 +219,6 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { fieldNames.add("iteration"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); - fieldNames.add("model_type"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); fieldNames.add("pred_models"); fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector)); fieldNames.add("intercept"); @@ -238,13 +238,36 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { if (args[0] == null) { throw new HiveException("array<double> features was null"); } - double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI); + parseFeatures(args[0], matrixBuilder); int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI); - - featuresList.add(features); labels.add(label); } + private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) { + if (denseInput) { + final int length = featureListOI.getListLength(argObj); + for (int i = 0; i < length; i++) { + Object o = featureListOI.getListElement(argObj, i); + if (o == null) { + continue; + } + double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI); + builder.nextColumn(i, v); + } + } else { + final int length = featureListOI.getListLength(argObj); + for (int i = 0; i < length; i++) { + Object o = featureListOI.getListElement(argObj, i); + if (o == null) { + continue; + } + String fv = o.toString(); + builder.nextColumn(fv); + } + } + builder.nextRow(); + } + @Override public void close() throws HiveException { this._progressReporter = getReporter(); @@ -252,14 +275,15 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { "hivemall.smile.GradientTreeBoostingClassifier$Counter", "iteration"); reportProgress(_progressReporter); - int numExamples = featuresList.size(); - double[][] x = featuresList.toArray(new double[numExamples][]); - this.featuresList = null; - int[] y = labels.toArray(); - this.labels = null; + if (!labels.isEmpty()) { + Matrix x = matrixBuilder.buildMatrix(); + this.matrixBuilder = null; + int[] y = labels.toArray(); + this.labels = null; - // run training - train(x, y); + // run training + train(x, y); + } // clean up this.featureListOI = null; @@ -287,25 +311,25 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { * @param x features * @param y label */ - private void train(@Nonnull final double[][] x, @Nonnull final int[] y) throws HiveException { - if (x.length != y.length) { + private void train(@Nonnull Matrix x, @Nonnull final int[] y) throws HiveException { + final int numRows = x.numRows(); + if (numRows != y.length) { throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d", - x.length, y.length)); + numRows, y.length)); } checkOptions(); this._attributes = SmileExtUtils.attributeTypes(_attributes, x); // Shuffle training samples - SmileExtUtils.shuffle(x, y, _seed); + x = SmileExtUtils.shuffle(x, y, _seed); final int k = smile.math.Math.max(y) + 1; if (k < 2) { throw new UDFArgumentException("Only one class or negative class labels."); } if (k == 2) { - int n = x.length; - final int[] y2 = new int[n]; - for (int i = 0; i < n; i++) { + final int[] y2 = new int[numRows]; + for (int i = 0; i < numRows; i++) { if (y[i] == 1) { y2[i] = 1; } else { @@ -318,7 +342,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { } } - private void train2(@Nonnull final double[][] x, @Nonnull final int[] y) throws HiveException { + private void train2(@Nonnull final Matrix x, @Nonnull final int[] y) throws HiveException { final int numVars = SmileExtUtils.computeNumInputVars(_numVars, x); if (logger.isInfoEnabled()) { logger.info("k: " + 2 + ", numTrees: " + _numTrees + ", shirinkage: " + _eta @@ -327,7 +351,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { + _maxLeafNodes + ", seed: " + _seed); } - final int numInstances = x.length; + final int numInstances = x.numRows(); final int numSamples = (int) Math.round(numInstances * _subsample); final double[] h = new double[numInstances]; // current F(x_i) @@ -340,7 +364,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { h[i] = intercept; } - final int[][] order = SmileExtUtils.sort(_attributes, x); + final ColumnMajorIntMatrix order = SmileExtUtils.sort(_attributes, x); final RegressionTree.NodeOutput output = new L2NodeOutput(response); final BitSet sampled = new BitSet(numInstances); @@ -351,10 +375,11 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { } long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() - : new smile.math.Random(_seed).nextLong(); - final smile.math.Random rnd1 = new smile.math.Random(s); - final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong()); + : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong(); + final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s); + final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong()); + final Vector xProbe = x.rowVector(); for (int m = 0; m < _numTrees; m++) { reportProgress(_progressReporter); @@ -373,7 +398,8 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { _maxLeafNodes, _minSamplesSplit, _minSamplesLeaf, order, bag, output, rnd2); for (int i = 0; i < numInstances; i++) { - h[i] += _eta * tree.predict(x[i]); + x.getRow(i, xProbe); + h[i] += _eta * tree.predict(xProbe); } // out-of-bag error estimate @@ -398,7 +424,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { /** * Train L-k tree boost. */ - private void traink(final double[][] x, final int[] y, final int k) throws HiveException { + private void traink(final Matrix x, final int[] y, final int k) throws HiveException { final int numVars = SmileExtUtils.computeNumInputVars(_numVars, x); if (logger.isInfoEnabled()) { logger.info("k: " + k + ", numTrees: " + _numTrees + ", shirinkage: " + _eta @@ -407,14 +433,14 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { + ", maxLeafs: " + _maxLeafNodes + ", seed: " + _seed); } - final int numInstances = x.length; + final int numInstances = x.numRows(); final int numSamples = (int) Math.round(numInstances * _subsample); final double[][] h = new double[k][numInstances]; // boost tree output. final double[][] p = new double[k][numInstances]; // posteriori probabilities. final double[][] response = new double[k][numInstances]; // pseudo response. - final int[][] order = SmileExtUtils.sort(_attributes, x); + final ColumnMajorIntMatrix order = SmileExtUtils.sort(_attributes, x); final RegressionTree.NodeOutput[] output = new LKNodeOutput[k]; for (int i = 0; i < k; i++) { output[i] = new LKNodeOutput(response[i], k); @@ -422,19 +448,16 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { final BitSet sampled = new BitSet(numInstances); final int[] bag = new int[numSamples]; - final int[] perm = new int[numSamples]; - for (int i = 0; i < numSamples; i++) { - perm[i] = i; - } + final int[] perm = MathUtils.permutation(numInstances); long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() - : new smile.math.Random(_seed).nextLong(); - final smile.math.Random rnd1 = new smile.math.Random(s); - final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong()); + : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong(); + final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s); + final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong()); // out-of-bag prediction final int[] prediction = new int[numInstances]; - + final Vector xProbe = x.rowVector(); for (int m = 0; m < _numTrees; m++) { for (int i = 0; i < numInstances; i++) { double max = Double.NEGATIVE_INFINITY; @@ -490,7 +513,8 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { trees[j] = tree; for (int i = 0; i < numInstances; i++) { - double h_ji = h_j[i] + _eta * tree.predict(x[i]); + x.getRow(i, xProbe); + double h_ji = h_j[i] + _eta * tree.predict(xProbe); h_j[i] += h_ji; if (h_ji > max_h) { max_h = h_ji; @@ -524,7 +548,7 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { */ private void forward(final int m, final double intercept, final double shrinkage, final float oobErrorRate, @Nonnull final RegressionTree... trees) throws HiveException { - Text[] models = getModel(trees, _outputType); + Text[] models = getModel(trees); double[] importance = new double[_attributes.length]; for (RegressionTree tree : trees) { @@ -534,14 +558,13 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { } } - Object[] forwardObjs = new Object[7]; + Object[] forwardObjs = new Object[6]; forwardObjs[0] = new IntWritable(m); - forwardObjs[1] = new IntWritable(_outputType.getId()); - forwardObjs[2] = models; - forwardObjs[3] = new DoubleWritable(intercept); - forwardObjs[4] = new DoubleWritable(shrinkage); - forwardObjs[5] = WritableUtils.toWritableList(importance); - forwardObjs[6] = new FloatWritable(oobErrorRate); + forwardObjs[1] = models; + forwardObjs[2] = new DoubleWritable(intercept); + forwardObjs[3] = new DoubleWritable(shrinkage); + forwardObjs[4] = WritableUtils.toWritableList(importance); + forwardObjs[5] = new FloatWritable(oobErrorRate); forward(forwardObjs); @@ -551,67 +574,14 @@ public final class GradientTreeBoostingClassifierUDTF extends UDTFWithOptions { logger.info("Forwarded the output of " + m + "-th Boosting iteration out of " + _numTrees); } - private static Text[] getModel(@Nonnull final RegressionTree[] trees, - @Nonnull final ModelType outputType) throws HiveException { + @Nonnull + private static Text[] getModel(@Nonnull final RegressionTree[] trees) throws HiveException { final int m = trees.length; final Text[] models = new Text[m]; - switch (outputType) { - case serialization: - case serialization_compressed: { - for (int i = 0; i < m; i++) { - byte[] b = trees[i].predictSerCodegen(outputType.isCompressed()); - b = Base91.encode(b); - models[i] = new Text(b); - } - break; - } - case opscode: - case opscode_compressed: { - for (int i = 0; i < m; i++) { - String s = trees[i].predictOpCodegen(StackMachine.SEP); - if (outputType.isCompressed()) { - byte[] b = s.getBytes(); - final DeflateCodec codec = new DeflateCodec(true, false); - try { - b = codec.compress(b); - } catch (IOException e) { - throw new HiveException("Failed to compressing a model", e); - } finally { - IOUtils.closeQuietly(codec); - } - b = Base91.encode(b); - models[i] = new Text(b); - } else { - models[i] = new Text(s); - } - } - break; - } - case javascript: - case javascript_compressed: { - for (int i = 0; i < m; i++) { - String s = trees[i].predictJsCodegen(); - if (outputType.isCompressed()) { - byte[] b = s.getBytes(); - final DeflateCodec codec = new DeflateCodec(true, false); - try { - b = codec.compress(b); - } catch (IOException e) { - throw new HiveException("Failed to compressing a model", e); - } finally { - IOUtils.closeQuietly(codec); - } - b = Base91.encode(b); - models[i] = new Text(b); - } else { - models[i] = new Text(s); - } - } - break; - } - default: - throw new HiveException("Unexpected output type: " + outputType - + ". Use javascript for the output instead"); + for (int i = 0; i < m; i++) { + byte[] b = trees[i].predictSerCodegen(true); + b = Base91.encode(b); + models[i] = new Text(b); } return models; } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/PredictionHandler.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/classification/PredictionHandler.java b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java new file mode 100644 index 0000000..84ef244 --- /dev/null +++ b/core/src/main/java/hivemall/smile/classification/PredictionHandler.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.smile.classification; + +import javax.annotation.Nonnull; + +public interface PredictionHandler { + + void handle(int output, @Nonnull double[] posteriori); + +}