http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java index 03db65c..5a831df 100644 --- a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java +++ b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java @@ -19,30 +19,43 @@ package hivemall.smile.classification; import hivemall.UDTFWithOptions; -import hivemall.smile.ModelType; +import hivemall.math.matrix.Matrix; +import hivemall.math.matrix.MatrixUtils; +import hivemall.math.matrix.builders.CSRMatrixBuilder; +import hivemall.math.matrix.builders.MatrixBuilder; +import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder; +import hivemall.math.matrix.ints.ColumnMajorIntMatrix; +import hivemall.math.matrix.ints.DoKIntMatrix; +import hivemall.math.matrix.ints.IntMatrix; +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; import hivemall.smile.classification.DecisionTree.SplitRule; import hivemall.smile.data.Attribute; import hivemall.smile.utils.SmileExtUtils; import hivemall.smile.utils.SmileTaskExecutor; -import hivemall.smile.vm.StackMachine; import hivemall.utils.codec.Base91; -import hivemall.utils.codec.DeflateCodec; -import hivemall.utils.collections.IntArrayList; +import hivemall.utils.collections.lists.IntArrayList; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hadoop.WritableUtils; -import hivemall.utils.io.IOUtils; +import hivemall.utils.lang.Preconditions; import hivemall.utils.lang.Primitives; import hivemall.utils.lang.RandomUtils; -import java.io.IOException; import 
java.util.ArrayList; +import java.util.Arrays; import java.util.BitSet; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; @@ -52,7 +65,9 @@ import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.MapredContext; import org.apache.hadoop.hive.ql.exec.MapredContextAccessor; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; @@ -67,9 +82,9 @@ import org.apache.hadoop.mapred.Reporter; @Description( name = "train_randomforest_classifier", - value = "_FUNC_(double[] features, int label [, string options]) - " + value = "_FUNC_(array<double|string> features, int label [, const string options, const array<double> classWeights]) - " + "Returns a relation consists of " - + "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>") + + "<string model_id, double model_weight, string model, array<double>|map<int,double> var_importance, int oob_errors, int oob_tests>") public final class RandomForestClassifierUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(RandomForestClassifierUDTF.class); @@ -77,8 +92,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { private PrimitiveObjectInspector 
featureElemOI; private PrimitiveObjectInspector labelOI; - private List<double[]> featuresList; + private boolean denseInput; + private MatrixBuilder matrixBuilder; private IntArrayList labels; + /** * The number of trees for each task */ @@ -99,8 +116,12 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { private int _minSamplesLeaf; private long _seed; private Attribute[] _attributes; - private ModelType _outputType; private SplitRule _splitRule; + private boolean _stratifiedSampling; + private double _subsample; + + @Nullable + private double[] _classWeight; @Nullable private Reporter _progressReporter; @@ -126,11 +147,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { opts.addOption("seed", true, "seed value in long [default: -1 (random)]"); opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types " + "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])"); - opts.addOption("output", "output_type", true, - "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]"); opts.addOption("rule", "split_rule", true, "Split algorithm [default: GINI, ENTROPY]"); - opts.addOption("disable_compression", false, - "Whether to disable compression of the output script [default: false]"); + opts.addOption("stratified", "stratified_sampling", false, + "Enable Stratified sampling for unbalanced data"); + opts.addOption("subsample", true, "Sampling rate in range (0.0,1.0]"); return opts; } @@ -141,9 +161,10 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { float numVars = -1.f; Attribute[] attrs = null; long seed = -1L; - String output = "serialization"; SplitRule splitRule = SplitRule.GINI; - boolean compress = true; + double[] classWeight = null; + boolean stratifiedSampling = false; + double subsample = 1.0d; CommandLine cl = null; if (argOIs.length >= 3) { @@ -162,10 +183,26 @@ public final class 
RandomForestClassifierUDTF extends UDTFWithOptions { minSamplesLeaf); seed = Primitives.parseLong(cl.getOptionValue("seed"), seed); attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types")); - output = cl.getOptionValue("output", output); splitRule = SmileExtUtils.resolveSplitRule(cl.getOptionValue("split_rule", "GINI")); - if (cl.hasOption("disable_compression")) { - compress = false; + stratifiedSampling = cl.hasOption("stratified_sampling"); + subsample = Primitives.parseDouble(cl.getOptionValue("subsample"), 1.0d); + Preconditions.checkArgument(subsample > 0.d && subsample <= 1.0d, + UDFArgumentException.class, "Invalid -subsample value: " + subsample); + + if (argOIs.length >= 4) { + classWeight = HiveUtils.getConstDoubleArray(argOIs[3]); + if (classWeight != null) { + for (int i = 0; i < classWeight.length; i++) { + double v = classWeight[i]; + if (Double.isNaN(v)) { + classWeight[i] = 1.0d; + } else if (v <= 0.d) { + throw new UDFArgumentTypeException(3, + "each classWeight must be greather than 0: " + + Arrays.toString(classWeight)); + } + } + } } } @@ -177,43 +214,60 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { this._minSamplesLeaf = minSamplesLeaf; this._seed = seed; this._attributes = attrs; - this._outputType = ModelType.resolve(output, compress); this._splitRule = splitRule; + this._stratifiedSampling = stratifiedSampling; + this._subsample = subsample; + this._classWeight = classWeight; return cl; } @Override public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - if (argOIs.length != 2 && argOIs.length != 3) { + if (argOIs.length < 2 || argOIs.length > 4) { throw new UDFArgumentException( - getClass().getSimpleName() - + " takes 2 or 3 arguments: double[] features, int label [, const string options]: " + "_FUNC_ takes 2 ~ 4 arguments: array<double|string> features, int label [, const string options, const array<double> classWeight]: " + argOIs.length); } 
ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]); ObjectInspector elemOI = listOI.getListElementObjectInspector(); this.featureListOI = listOI; - this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI); + if (HiveUtils.isNumberOI(elemOI)) { + this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI); + this.denseInput = true; + this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192); + } else if (HiveUtils.isStringOI(elemOI)) { + this.featureElemOI = HiveUtils.asStringOI(elemOI); + this.denseInput = false; + this.matrixBuilder = new CSRMatrixBuilder(8192); + } else { + throw new UDFArgumentException( + "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName()); + } this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]); processOptions(argOIs); - this.featuresList = new ArrayList<double[]>(1024); this.labels = new IntArrayList(1024); - ArrayList<String> fieldNames = new ArrayList<String>(6); - ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6); + final ArrayList<String> fieldNames = new ArrayList<String>(6); + final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6); fieldNames.add("model_id"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); - fieldNames.add("model_type"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); - fieldNames.add("pred_model"); + fieldNames.add("model_weight"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + fieldNames.add("model"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); fieldNames.add("var_importance"); - fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + if (denseInput) { + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + } else { + 
fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.writableIntObjectInspector, + PrimitiveObjectInspectorFactory.writableDoubleObjectInspector)); + } fieldNames.add("oob_errors"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); fieldNames.add("oob_tests"); @@ -227,13 +281,36 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { if (args[0] == null) { throw new HiveException("array<double> features was null"); } - double[] features = HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI); + parseFeatures(args[0], matrixBuilder); int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI); - - featuresList.add(features); labels.add(label); } + private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) { + if (denseInput) { + final int length = featureListOI.getListLength(argObj); + for (int i = 0; i < length; i++) { + Object o = featureListOI.getListElement(argObj, i); + if (o == null) { + continue; + } + double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI); + builder.nextColumn(i, v); + } + } else { + final int length = featureListOI.getListLength(argObj); + for (int i = 0; i < length; i++) { + Object o = featureListOI.getListElement(argObj, i); + if (o == null) { + continue; + } + String fv = o.toString(); + builder.nextColumn(fv); + } + } + builder.nextRow(); + } + @Override public void close() throws HiveException { this._progressReporter = getReporter(); @@ -242,10 +319,9 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { "finishedTreeBuildTasks"); reportProgress(_progressReporter); - int numExamples = featuresList.size(); - if (numExamples > 0) { - double[][] x = featuresList.toArray(new double[numExamples][]); - this.featuresList = null; + if (!labels.isEmpty()) { + Matrix x = matrixBuilder.buildMatrix(); + this.matrixBuilder = null; int[] y = labels.toArray(); this.labels = 
null; @@ -277,15 +353,16 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { * @param numVars The number of variables to pick up in each node. * @param seed The seed number for Random Forest */ - private void train(@Nonnull final double[][] x, @Nonnull final int[] y) throws HiveException { - if (x.length != y.length) { + private void train(@Nonnull Matrix x, @Nonnull final int[] y) throws HiveException { + final int numExamples = x.numRows(); + if (numExamples != y.length) { throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d", - x.length, y.length)); + numExamples, y.length)); } checkOptions(); - // Shuffle training samples - SmileExtUtils.shuffle(x, y, _seed); + // Shuffle training samples + x = SmileExtUtils.shuffle(x, y, _seed); int[] labels = SmileExtUtils.classLables(y); Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x); @@ -297,9 +374,8 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { + _maxLeafNodes + ", splitRule: " + _splitRule + ", seed: " + _seed); } - final int numExamples = x.length; - int[][] prediction = new int[numExamples][labels.length]; // placeholder for out-of-bag prediction - int[][] order = SmileExtUtils.sort(attributes, x); + IntMatrix prediction = new DoKIntMatrix(numExamples, labels.length); // placeholder for out-of-bag prediction + ColumnMajorIntMatrix order = SmileExtUtils.sort(attributes, x); AtomicInteger remainingTasks = new AtomicInteger(_numTrees); List<TrainingTask> tasks = new ArrayList<TrainingTask>(); for (int i = 0; i < _numTrees; i++) { @@ -321,17 +397,19 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { /** * Synchronized because {@link #forward(Object)} should be called from a single thread. 
+ * + * @param accuracy */ synchronized void forward(final int taskId, @Nonnull final Text model, - @Nonnull final double[] importance, final int[] y, final int[][] prediction, - final boolean lastTask) throws HiveException { + @Nonnull final Vector importance, @Nonnegative final double accuracy, final int[] y, + @Nonnull final IntMatrix prediction, final boolean lastTask) throws HiveException { int oobErrors = 0; int oobTests = 0; if (lastTask) { // out-of-bag error estimate for (int i = 0; i < y.length; i++) { - final int pred = smile.math.Math.whichMax(prediction[i]); - if (prediction[i][pred] > 0) { + final int pred = MatrixUtils.whichMax(prediction, i); + if (pred != -1 && prediction.get(i, pred) > 0) { oobTests++; if (pred != y[i]) { oobErrors++; @@ -340,12 +418,23 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { } } - String modelId = RandomUtils.getUUID(); final Object[] forwardObjs = new Object[6]; + String modelId = RandomUtils.getUUID(); forwardObjs[0] = new Text(modelId); - forwardObjs[1] = new IntWritable(_outputType.getId()); + forwardObjs[1] = new DoubleWritable(accuracy); forwardObjs[2] = model; - forwardObjs[3] = WritableUtils.toWritableList(importance); + if (denseInput) { + forwardObjs[3] = WritableUtils.toWritableList(importance.toArray()); + } else { + final Map<IntWritable, DoubleWritable> map = new HashMap<IntWritable, DoubleWritable>( + importance.size()); + importance.each(new VectorProcedure() { + public void apply(int i, double value) { + map.put(new IntWritable(i), new DoubleWritable(value)); + } + }); + forwardObjs[3] = map; + } forwardObjs[4] = new IntWritable(oobErrors); forwardObjs[5] = new IntWritable(oobTests); forward(forwardObjs); @@ -363,20 +452,23 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { /** * Attribute properties. */ + @Nonnull private final Attribute[] _attributes; /** * Training instances. 
*/ - private final double[][] _x; + @Nonnull + private final Matrix _x; /** * Training sample labels. */ + @Nonnull private final int[] _y; /** - * The index of training values in ascending order. Note that only numeric attributes will - * be sorted. + * The index of training values in ascending order. Note that only numeric attributes will be sorted. */ - private final int[][] _order; + @Nonnull + private final ColumnMajorIntMatrix _order; /** * The number of variables to pick up in each node. */ @@ -384,16 +476,21 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { /** * The out-of-bag predictions. */ - private final int[][] _prediction; + @Nonnull + @GuardedBy("_udtf") + private final IntMatrix _prediction; + @Nonnull private final RandomForestClassifierUDTF _udtf; private final int _taskId; private final long _seed; + @Nonnull private final AtomicInteger _remainingTasks; - TrainingTask(RandomForestClassifierUDTF udtf, int taskId, Attribute[] attributes, - double[][] x, int[] y, int numVars, int[][] order, int[][] prediction, long seed, - AtomicInteger remainingTasks) { + TrainingTask(@Nonnull RandomForestClassifierUDTF udtf, int taskId, + @Nonnull Attribute[] attributes, @Nonnull Matrix x, @Nonnull int[] y, int numVars, + @Nonnull ColumnMajorIntMatrix order, @Nonnull IntMatrix prediction, long seed, + @Nonnull AtomicInteger remainingTasks) { this._udtf = udtf; this._taskId = taskId; this._attributes = attributes; @@ -408,98 +505,107 @@ public final class RandomForestClassifierUDTF extends UDTFWithOptions { @Override public Integer call() throws HiveException { - long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() : new smile.math.Random( - _seed).nextLong(); - final smile.math.Random rnd1 = new smile.math.Random(s); - final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong()); - final int N = _x.length; + long s = (this._seed == -1L) ? 
SmileExtUtils.generateSeed() + : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong(); + final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s); + final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong()); + final int N = _x.numRows(); // Training samples draw with replacement. - final int[] bags = new int[N]; final BitSet sampled = new BitSet(N); - for (int i = 0; i < N; i++) { - int index = rnd1.nextInt(N); - bags[i] = index; - sampled.set(index); - } + final int[] bags = sampling(sampled, N, rnd1); DecisionTree tree = new DecisionTree(_attributes, _x, _y, _numVars, _udtf._maxDepth, _udtf._maxLeafNodes, _udtf._minSamplesSplit, _udtf._minSamplesLeaf, bags, _order, _udtf._splitRule, rnd2); // out-of-bag prediction + int oob = 0; + int correct = 0; + final Vector xProbe = _x.rowVector(); for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) { - final int p = tree.predict(_x[i]); - synchronized (_prediction[i]) { - _prediction[i][p]++; + oob++; + _x.getRow(i, xProbe); + final int p = tree.predict(xProbe); + if (p == _y[i]) { + correct++; + } + synchronized (_udtf) { + _prediction.incr(i, p); } } - Text model = getModel(tree, _udtf._outputType); - double[] importance = tree.importance(); + Text model = getModel(tree); + Vector importance = tree.importance(); + double accuracy = (oob == 0) ? 
1.0d : (double) correct / oob; int remain = _remainingTasks.decrementAndGet(); boolean lastTask = (remain == 0); - _udtf.forward(_taskId + 1, model, importance, _y, _prediction, lastTask); + _udtf.forward(_taskId + 1, model, importance, accuracy, _y, _prediction, lastTask); return Integer.valueOf(remain); } - private static Text getModel(@Nonnull final DecisionTree tree, - @Nonnull final ModelType outputType) throws HiveException { - final Text model; - switch (outputType) { - case serialization: - case serialization_compressed: { - byte[] b = tree.predictSerCodegen(outputType.isCompressed()); - b = Base91.encode(b); - model = new Text(b); - break; - } - case opscode: - case opscode_compressed: { - String s = tree.predictOpCodegen(StackMachine.SEP); - if (outputType.isCompressed()) { - byte[] b = s.getBytes(); - final DeflateCodec codec = new DeflateCodec(true, false); - try { - b = codec.compress(b); - } catch (IOException e) { - throw new HiveException("Failed to compressing a model", e); - } finally { - IOUtils.closeQuietly(codec); - } - b = Base91.encode(b); - model = new Text(b); - } else { - model = new Text(s); + @Nonnull + private int[] sampling(@Nonnull final BitSet sampled, final int N, @Nonnull PRNG rnd) { + return _udtf._stratifiedSampling ? stratifiedSampling(sampled, N, _udtf._subsample, rnd) + : uniformSampling(sampled, N, _udtf._subsample, rnd); + } + + @Nonnull + private static int[] uniformSampling(@Nonnull final BitSet sampled, final int N, + final double subsample, final PRNG rnd) { + final int size = (int) Math.round(N * subsample); + final int[] bags = new int[size]; // one slot per draw; with new int[N] and subsample < 1 the unfilled tail stayed 0 (row 0) + for (int i = 0; i < size; i++) { + int index = rnd.nextInt(N); + bags[i] = index; + sampled.set(index); + } + return bags; + } + + /** + * Stratified sampling for unbalanced data. 
+ * + * @see <a href="https://en.wikipedia.org/wiki/Stratified_sampling">Stratified sampling</a> + */ + @Nonnull + private int[] stratifiedSampling(@Nonnull final BitSet sampled, final int N, + final double subsample, final PRNG rnd) { + final IntArrayList bagsList = new IntArrayList(N); + final int k = smile.math.Math.max(_y) + 1; + final IntArrayList cj = new IntArrayList(N / k); + for (int l = 0; l < k; l++) { + int nj = 0; + for (int i = 0; i < N; i++) { + if (_y[i] == l) { + cj.add(i); + nj++; } - break; } - case javascript: - case javascript_compressed: { - String s = tree.predictJsCodegen(); - if (outputType.isCompressed()) { - byte[] b = s.getBytes(); - final DeflateCodec codec = new DeflateCodec(true, false); - try { - b = codec.compress(b); - } catch (IOException e) { - throw new HiveException("Failed to compressing a model", e); - } finally { - IOUtils.closeQuietly(codec); - } - b = Base91.encode(b); - model = new Text(b); - } else { - model = new Text(s); - } - break; + if (subsample != 1.0d) { + nj = (int) Math.round(nj * subsample); + } + final int size = (_udtf._classWeight == null) ? nj : (int) Math.round(nj + * _udtf._classWeight[l]); + for (int j = 0; j < size; j++) { + int xi = rnd.nextInt(cj.size()); // draw from every member of class l; nj was already scaled by subsample above + int index = cj.get(xi); + bagsList.add(index); + sampled.set(index); } - default: - throw new HiveException("Unexpected output type: " + outputType - + ". Use javascript for the output instead"); + cj.clear(); } - return model; + int[] bags = bagsList.toArray(true); + SmileExtUtils.shuffle(bags, rnd); + return bags; + } + + @Nonnull + private static Text getModel(@Nonnull final DecisionTree tree) throws HiveException { + byte[] b = tree.predictSerCodegen(true); + b = Base91.encode(b); + return new Text(b); } }
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/data/Attribute.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/data/Attribute.java b/core/src/main/java/hivemall/smile/data/Attribute.java index be6651a..6569726 100644 --- a/core/src/main/java/hivemall/smile/data/Attribute.java +++ b/core/src/main/java/hivemall/smile/data/Attribute.java @@ -18,6 +18,9 @@ */ package hivemall.smile.data; +import hivemall.annotations.Immutable; +import hivemall.annotations.Mutable; + import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; @@ -25,11 +28,9 @@ import java.io.ObjectOutput; public abstract class Attribute { public final AttributeType type; - public final int attrIndex; - Attribute(AttributeType type, int attrIndex) { + Attribute(AttributeType type) { this.type = type; - this.attrIndex = attrIndex; } public void setSize(int size) { @@ -44,24 +45,23 @@ public abstract class Attribute { } public void writeTo(ObjectOutput out) throws IOException { - out.writeInt(type.getTypeId()); - out.writeInt(attrIndex); + out.writeByte(type.getTypeId()); } public enum AttributeType { - NUMERIC(1), NOMINAL(2); + NUMERIC((byte) 1), NOMINAL((byte) 2); - private final int id; + private final byte id; - private AttributeType(int id) { + private AttributeType(byte id) { this.id = id; } - public int getTypeId() { + public byte getTypeId() { return id; } - public static AttributeType resolve(int id) { + public static AttributeType resolve(byte id) { final AttributeType type; switch (id) { case 1: @@ -78,25 +78,27 @@ public abstract class Attribute { } + @Immutable public static final class NumericAttribute extends Attribute { - public NumericAttribute(int attrIndex) { - super(AttributeType.NUMERIC, attrIndex); + public NumericAttribute() { + super(AttributeType.NUMERIC); } @Override public String toString() { - return "NumericAttribute 
[type=" + type + ", attrIndex=" + attrIndex + "]"; + return "NumericAttribute [type=" + type + "]"; } } + @Mutable public static final class NominalAttribute extends Attribute { private int size; - public NominalAttribute(int attrIndex) { - super(AttributeType.NOMINAL, attrIndex); + public NominalAttribute() { + super(AttributeType.NOMINAL); this.size = -1; } @@ -118,25 +120,23 @@ public abstract class Attribute { @Override public String toString() { - return "NominalAttribute [size=" + size + ", type=" + type + ", attrIndex=" + attrIndex - + "]"; + return "NominalAttribute [size=" + size + ", type=" + type + "]"; } } public static Attribute readFrom(ObjectInput in) throws IOException { - int typeId = in.readInt(); - int attrIndex = in.readInt(); - final Attribute attr; + + byte typeId = in.readByte(); final AttributeType type = AttributeType.resolve(typeId); switch (type) { case NUMERIC: { - attr = new NumericAttribute(attrIndex); + attr = new NumericAttribute(); break; } case NOMINAL: { - attr = new NominalAttribute(attrIndex); + attr = new NominalAttribute(); int size = in.readInt(); attr.setSize(size); break; http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java index ebb58c6..557df21 100644 --- a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java +++ b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java @@ -19,22 +19,25 @@ package hivemall.smile.regression; import hivemall.UDTFWithOptions; -import hivemall.smile.ModelType; +import hivemall.math.matrix.Matrix; +import hivemall.math.matrix.builders.CSRMatrixBuilder; +import hivemall.math.matrix.builders.MatrixBuilder; +import 
hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder; +import hivemall.math.matrix.ints.ColumnMajorIntMatrix; +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.math.vector.Vector; import hivemall.smile.data.Attribute; import hivemall.smile.utils.SmileExtUtils; import hivemall.smile.utils.SmileTaskExecutor; -import hivemall.smile.vm.StackMachine; import hivemall.utils.codec.Base91; -import hivemall.utils.codec.DeflateCodec; -import hivemall.utils.collections.DoubleArrayList; +import hivemall.utils.collections.lists.DoubleArrayList; import hivemall.utils.datetime.StopWatch; import hivemall.utils.hadoop.HiveUtils; import hivemall.utils.hadoop.WritableUtils; -import hivemall.utils.io.IOUtils; import hivemall.utils.lang.Primitives; import hivemall.utils.lang.RandomUtils; -import java.io.IOException; import java.util.ArrayList; import java.util.BitSet; import java.util.List; @@ -42,6 +45,7 @@ import java.util.concurrent.Callable; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -69,7 +73,7 @@ import org.apache.hadoop.mapred.Reporter; @Description( name = "train_randomforest_regression", - value = "_FUNC_(double[] features, double target [, string options]) - " + value = "_FUNC_(array<double|string> features, double target [, string options]) - " + "Returns a relation consists of " + "<int model_id, int model_type, string pred_model, array<double> var_importance, int oob_errors, int oob_tests>") public final class RandomForestRegressionUDTF extends UDTFWithOptions { @@ -79,7 +83,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { private PrimitiveObjectInspector featureElemOI; private PrimitiveObjectInspector targetOI; - private List<double[]> featuresList; + private boolean denseInput; + private MatrixBuilder matrixBuilder; private 
DoubleArrayList targets; /** * The number of trees for each task @@ -101,7 +106,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { private int _minSamplesLeaf; private long _seed; private Attribute[] _attributes; - private ModelType _outputType; @Nullable private Reporter _progressReporter; @@ -131,10 +135,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { opts.addOption("seed", true, "seed value in long [default: -1 (random)]"); opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types " + "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])"); - opts.addOption("output", "output_type", true, - "The output type (serialization/ser or opscode/vm or javascript/js) [default: serialization]"); - opts.addOption("disable_compression", false, - "Whether to disable compression of the output script [default: false]"); return opts; } @@ -145,8 +145,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { float numVars = -1.f; Attribute[] attrs = null; long seed = -1L; - String output = "serialization"; - boolean compress = true; CommandLine cl = null; if (argOIs.length >= 3) { @@ -165,10 +163,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { minSamplesLeaf); seed = Primitives.parseLong(cl.getOptionValue("seed"), seed); attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types")); - output = cl.getOptionValue("output", output); - if (cl.hasOption("disable_compression")) { - compress = false; - } } this._numTrees = trees; @@ -179,7 +173,6 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { this._minSamplesLeaf = minSamplesLeaf; this._seed = seed; this._attributes = attrs; - this._outputType = ModelType.resolve(output, compress); return cl; } @@ -189,19 +182,29 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { if (argOIs.length != 2 && argOIs.length != 3) { throw new 
UDFArgumentException( getClass().getSimpleName() - + " takes 2 or 3 arguments: double[] features, double target [, const string options]: " + + " takes 2 or 3 arguments: array<double|string> features, double target [, const string options]: " + argOIs.length); } ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]); ObjectInspector elemOI = listOI.getListElementObjectInspector(); this.featureListOI = listOI; - this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI); + if (HiveUtils.isNumberOI(elemOI)) { + this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI); + this.denseInput = true; + this.matrixBuilder = new RowMajorDenseMatrixBuilder(8192); + } else if (HiveUtils.isStringOI(elemOI)) { + this.featureElemOI = HiveUtils.asStringOI(elemOI); + this.denseInput = false; + this.matrixBuilder = new CSRMatrixBuilder(8192); + } else { + throw new UDFArgumentException( + "_FUNC_ takes double[] or string[] for the first argument: " + listOI.getTypeName()); + } this.targetOI = HiveUtils.asDoubleCompatibleOI(argOIs[1]); processOptions(argOIs); - this.featuresList = new ArrayList<double[]>(1024); this.targets = new DoubleArrayList(1024); ArrayList<String> fieldNames = new ArrayList<String>(5); @@ -209,8 +212,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { fieldNames.add("model_id"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); - fieldNames.add("model_type"); - fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("model_err"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); fieldNames.add("pred_model"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); fieldNames.add("var_importance"); @@ -228,13 +231,36 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { if (args[0] == null) { throw new HiveException("array<double> features was null"); } - double[] features = 
HiveUtils.asDoubleArray(args[0], featureListOI, featureElemOI); + parseFeatures(args[0], matrixBuilder); double target = PrimitiveObjectInspectorUtils.getDouble(args[1], targetOI); - - featuresList.add(features); targets.add(target); } + private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) { + if (denseInput) { + final int length = featureListOI.getListLength(argObj); + for (int i = 0; i < length; i++) { + Object o = featureListOI.getListElement(argObj, i); + if (o == null) { + continue; + } + double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI); + builder.nextColumn(i, v); + } + } else { + final int length = featureListOI.getListLength(argObj); + for (int i = 0; i < length; i++) { + Object o = featureListOI.getListElement(argObj, i); + if (o == null) { + continue; + } + String fv = o.toString(); + builder.nextColumn(fv); + } + } + builder.nextRow(); + } + @Override public void close() throws HiveException { this._progressReporter = getReporter(); @@ -250,10 +276,9 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { reportProgress(_progressReporter); - int numExamples = featuresList.size(); - if (numExamples > 0) { - double[][] x = featuresList.toArray(new double[numExamples][]); - this.featuresList = null; + if (!targets.isEmpty()) { + Matrix x = matrixBuilder.buildMatrix(); + this.matrixBuilder = null; double[] y = targets.toArray(); this.targets = null; @@ -285,15 +310,16 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { * @param _numVars The number of variables to pick up in each node. 
* @param _seed The seed number for Random Forest */ - private void train(@Nonnull final double[][] x, @Nonnull final double[] y) throws HiveException { - if (x.length != y.length) { + private void train(@Nonnull Matrix x, @Nonnull final double[] y) throws HiveException { + final int numExamples = x.numRows(); + if (numExamples != y.length) { throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d", - x.length, y.length)); + numExamples, y.length)); } checkOptions(); - // Shuffle training samples - SmileExtUtils.shuffle(x, y, _seed); + // Shuffle training samples + x = SmileExtUtils.shuffle(x, y, _seed); Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x); int numInputVars = SmileExtUtils.computeNumInputVars(_numVars, x); @@ -305,10 +331,9 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { + ", seed: " + _seed); } - int numExamples = x.length; double[] prediction = new double[numExamples]; // placeholder for out-of-bag prediction int[] oob = new int[numExamples]; - int[][] order = SmileExtUtils.sort(attributes, x); + ColumnMajorIntMatrix order = SmileExtUtils.sort(attributes, x); AtomicInteger remainingTasks = new AtomicInteger(_numTrees); List<TrainingTask> tasks = new ArrayList<TrainingTask>(); for (int i = 0; i < _numTrees; i++) { @@ -330,10 +355,13 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { /** * Synchronized because {@link #forward(Object)} should be called from a single thread. 
+ * + * @param error */ synchronized void forward(final int taskId, @Nonnull final Text model, - @Nonnull final double[] importance, final double[] y, final double[] prediction, - final int[] oob, final boolean lastTask) throws HiveException { + @Nonnull final double[] importance, @Nonnegative final double error, final double[] y, + final double[] prediction, final int[] oob, final boolean lastTask) + throws HiveException { double oobErrors = 0.d; int oobTests = 0; if (lastTask) { @@ -349,7 +377,7 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { String modelId = RandomUtils.getUUID(); final Object[] forwardObjs = new Object[6]; forwardObjs[0] = new Text(modelId); - forwardObjs[1] = new IntWritable(_outputType.getId()); + forwardObjs[1] = new DoubleWritable(error); forwardObjs[2] = model; forwardObjs[3] = WritableUtils.toWritableList(importance); forwardObjs[4] = new DoubleWritable(oobErrors); @@ -373,16 +401,15 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { /** * Training instances. */ - private final double[][] _x; + private final Matrix _x; /** * Training sample labels. */ private final double[] _y; /** - * The index of training values in ascending order. Note that only numeric attributes will - * be sorted. + * The index of training values in ascending order. Note that only numeric attributes will be sorted. */ - private final int[][] _order; + private final ColumnMajorIntMatrix _order; /** * The number of variables to pick up in each node. 
*/ @@ -401,8 +428,8 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { private final long _seed; private final AtomicInteger _remainingTasks; - TrainingTask(RandomForestRegressionUDTF udtf, int taskId, Attribute[] attributes, - double[][] x, double[] y, int numVars, int[][] order, double[] prediction, + TrainingTask(RandomForestRegressionUDTF udtf, int taskId, Attribute[] attributes, Matrix x, + double[] y, int numVars, ColumnMajorIntMatrix order, double[] prediction, int[] oob, long seed, AtomicInteger remainingTasks) { this._udtf = udtf; this._taskId = taskId; @@ -419,11 +446,11 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { @Override public Integer call() throws HiveException { - long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() : new smile.math.Random( - _seed).nextLong(); - final smile.math.Random rnd1 = new smile.math.Random(s); - final smile.math.Random rnd2 = new smile.math.Random(rnd1.nextLong()); - final int N = _x.length; + long s = (this._seed == -1L) ? SmileExtUtils.generateSeed() + : RandomNumberGeneratorFactory.createPRNG(_seed).nextLong(); + final PRNG rnd1 = RandomNumberGeneratorFactory.createPRNG(s); + final PRNG rnd2 = RandomNumberGeneratorFactory.createPRNG(rnd1.nextLong()); + final int N = _x.numRows(); // Training samples draw with replacement. 
final int[] bags = new int[N]; @@ -441,82 +468,40 @@ public final class RandomForestRegressionUDTF extends UDTFWithOptions { incrCounter(_udtf._treeConstuctionTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS)); // out-of-bag prediction + int oob = 0; + double error = 0.d; + final Vector xProbe = _x.rowVector(); for (int i = sampled.nextClearBit(0); i < N; i = sampled.nextClearBit(i + 1)) { - double pred = tree.predict(_x[i]); - synchronized (_x[i]) { + oob++; + _x.getRow(i, xProbe); + final double pred = tree.predict(xProbe); + synchronized (_prediction) { _prediction[i] += pred; _oob[i]++; } + error += Math.abs(pred - _y[i]); + } + if (oob != 0) { + error /= oob; } stopwatch.reset().start(); - Text model = getModel(tree, _udtf._outputType); + Text model = getModel(tree); double[] importance = tree.importance(); tree = null; // help GC int remain = _remainingTasks.decrementAndGet(); boolean lastTask = (remain == 0); - _udtf.forward(_taskId + 1, model, importance, _y, _prediction, _oob, lastTask); + _udtf.forward(_taskId + 1, model, importance, error, _y, _prediction, _oob, lastTask); incrCounter(_udtf._treeSerializationTimeCounter, stopwatch.elapsed(TimeUnit.SECONDS)); return Integer.valueOf(remain); } - private static Text getModel(@Nonnull final RegressionTree tree, - @Nonnull final ModelType outputType) throws HiveException { - final Text model; - switch (outputType) { - case serialization: - case serialization_compressed: { - byte[] b = tree.predictSerCodegen(outputType.isCompressed()); - b = Base91.encode(b); - model = new Text(b); - break; - } - case opscode: - case opscode_compressed: { - String s = tree.predictOpCodegen(StackMachine.SEP); - if (outputType.isCompressed()) { - byte[] b = s.getBytes(); - final DeflateCodec codec = new DeflateCodec(true, false); - try { - b = codec.compress(b); - } catch (IOException e) { - throw new HiveException("Failed to compressing a model", e); - } finally { - IOUtils.closeQuietly(codec); - } - b = Base91.encode(b); - 
model = new Text(b); - } else { - model = new Text(s); - } - break; - } - case javascript: - case javascript_compressed: { - String s = tree.predictJsCodegen(); - if (outputType.isCompressed()) { - byte[] b = s.getBytes(); - final DeflateCodec codec = new DeflateCodec(true, false); - try { - b = codec.compress(b); - } catch (IOException e) { - throw new HiveException("Failed to compressing a model", e); - } finally { - IOUtils.closeQuietly(codec); - } - b = Base91.encode(b); - model = new Text(b); - } else { - model = new Text(s); - } - break; - } - default: - throw new HiveException("Unexpected output type: " + outputType - + ". Use javascript for the output instead"); - } - return model; + @Nonnull + private static Text getModel(@Nonnull final RegressionTree tree) throws HiveException { + byte[] b = tree.predictSerCodegen(true); + b = Base91.encode(b); + return new Text(b); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/8dc3a024/core/src/main/java/hivemall/smile/regression/RegressionTree.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java b/core/src/main/java/hivemall/smile/regression/RegressionTree.java index 07887c1..da7e80b 100755 --- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java +++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java @@ -33,20 +33,28 @@ */ package hivemall.smile.regression; +import hivemall.annotations.VisibleForTesting; +import hivemall.math.matrix.Matrix; +import hivemall.math.matrix.ints.ColumnMajorIntMatrix; +import hivemall.math.random.PRNG; +import hivemall.math.random.RandomNumberGeneratorFactory; +import hivemall.math.vector.DenseVector; +import hivemall.math.vector.Vector; +import hivemall.math.vector.VectorProcedure; import hivemall.smile.data.Attribute; import hivemall.smile.data.Attribute.AttributeType; import hivemall.smile.utils.SmileExtUtils; -import 
hivemall.utils.collections.IntArrayList; +import hivemall.utils.collections.lists.IntArrayList; +import hivemall.utils.collections.sets.IntArraySet; +import hivemall.utils.collections.sets.IntSet; import hivemall.utils.lang.ObjectUtils; -import hivemall.utils.lang.StringUtils; +import hivemall.utils.math.MathUtils; import java.io.Externalizable; import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; import java.util.PriorityQueue; import javax.annotation.Nonnull; @@ -55,60 +63,48 @@ import javax.annotation.Nullable; import org.apache.hadoop.hive.ql.metadata.HiveException; import smile.math.Math; -import smile.math.Random; import smile.regression.GradientTreeBoost; import smile.regression.RandomForest; import smile.regression.Regression; /** - * Decision tree for regression. A decision tree can be learned by splitting the training set into - * subsets based on an attribute value test. This process is repeated on each derived subset in a - * recursive manner called recursive partitioning. + * Decision tree for regression. A decision tree can be learned by splitting the training set into subsets based on an attribute value test. This + * process is repeated on each derived subset in a recursive manner called recursive partitioning. * <p> - * Classification and Regression Tree techniques have a number of advantages over many of those - * alternative techniques. + * Classification and Regression Tree techniques have a number of advantages over many of those alternative techniques. * <dl> * <dt>Simple to understand and interpret.</dt> - * <dd>In most cases, the interpretation of results summarized in a tree is very simple. 
This - * simplicity is useful not only for purposes of rapid classification of new observations, but can - * also often yield a much simpler "model" for explaining why observations are classified or - * predicted in a particular manner.</dd> + * <dd>In most cases, the interpretation of results summarized in a tree is very simple. This simplicity is useful not only for purposes of rapid + * classification of new observations, but can also often yield a much simpler "model" for explaining why observations are classified or predicted in + * a particular manner.</dd> * <dt>Able to handle both numerical and categorical data.</dt> - * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of - * variable.</dd> + * <dd>Other techniques are usually specialized in analyzing datasets that have only one type of variable.</dd> * <dt>Tree methods are nonparametric and nonlinear.</dt> - * <dd>The final results of using tree methods for classification or regression can be summarized in - * a series of (usually few) logical if-then conditions (tree nodes). Therefore, there is no - * implicit assumption that the underlying relationships between the predictor variables and the - * dependent variable are linear, follow some specific non-linear link function, or that they are - * even monotonic in nature. Thus, tree methods are particularly well suited for data mining tasks, - * where there is often little a priori knowledge nor any coherent set of theories or predictions - * regarding which variables are related and how. In those types of data analytics, tree methods can - * often reveal simple relationships between just a few variables that could have easily gone - * unnoticed using other analytic techniques.</dd> + * <dd>The final results of using tree methods for classification or regression can be summarized in a series of (usually few) logical if-then + * conditions (tree nodes). 
Therefore, there is no implicit assumption that the underlying relationships between the predictor variables and the + * dependent variable are linear, follow some specific non-linear link function, or that they are even monotonic in nature. Thus, tree methods are + * particularly well suited for data mining tasks, where there is often little a priori knowledge nor any coherent set of theories or predictions + * regarding which variables are related and how. In those types of data analytics, tree methods can often reveal simple relationships between just a + * few variables that could have easily gone unnoticed using other analytic techniques.</dd> * </dl> - * One major problem with classification and regression trees is their high variance. Often a small - * change in the data can result in a very different series of splits, making interpretation - * somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause - * over-fitting. Mechanisms such as pruning are necessary to avoid this problem. Another limitation - * of trees is the lack of smoothness of the prediction surface. + * One major problem with classification and regression trees is their high variance. Often a small change in the data can result in a very different + * series of splits, making interpretation somewhat precarious. Besides, decision-tree learners can create over-complex trees that cause over-fitting. + * Mechanisms such as pruning are necessary to avoid this problem. Another limitation of trees is the lack of smoothness of the prediction surface. * <p> - * Some techniques such as bagging, boosting, and random forest use more than one decision tree for - * their analysis. + * Some techniques such as bagging, boosting, and random forest use more than one decision tree for their analysis. 
* * @see GradientTreeBoost * @see RandomForest */ -public final class RegressionTree implements Regression<double[]> { +public final class RegressionTree implements Regression<Vector> { /** * The attributes of independent variable. */ private final Attribute[] _attributes; private final boolean _hasNumericType; /** - * Variable importance. Every time a split of a node is made on variable the impurity criterion - * for the two descendant nodes is less than the parent node. Adding up the decreases for each - * individual variable over the tree gives a simple measure of variable importance. + * Variable importance. Every time a split of a node is made on variable the impurity criterion for the two descendant nodes is less than the + * parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of variable importance. */ private final double[] _importance; /** @@ -120,8 +116,7 @@ public final class RegressionTree implements Regression<double[]> { */ private final int _maxDepth; /** - * The number of instances in a node below which the tree will not split, setting S = 5 - * generally gives good results. + * The number of instances in a node below which the tree will not split, setting S = 5 generally gives good results. */ private final int _minSplit; /** @@ -133,19 +128,17 @@ public final class RegressionTree implements Regression<double[]> { */ private final int _numVars; /** - * The index of training values in ascending order. Note that only numeric attributes will be - * sorted. + * The index of training values in ascending order. Note that only numeric attributes will be sorted. */ - private final int[][] _order; + private final ColumnMajorIntMatrix _order; - private final Random _rnd; + private final PRNG _rnd; private final NodeOutput _nodeOutput; /** - * An interface to calculate node output. Note that samples[i] is the number of sampling of - * dataset[i]. 
0 means that the datum is not included and values of greater than 1 are possible - * because of sampling with replacement. + * An interface to calculate node output. Note that samples[i] is the number of sampling of dataset[i]. 0 means that the datum is not included and + * values of greater than 1 are possible because of sampling with replacement. */ public interface NodeOutput { /** @@ -205,22 +198,30 @@ public final class RegressionTree implements Regression<double[]> { this.output = output; } + private boolean isLeaf() { + return trueChild == null && falseChild == null; + } + + @VisibleForTesting + public double predict(@Nonnull final double[] x) { + return predict(new DenseVector(x)); + } + /** * Evaluate the regression tree over an instance. */ - public double predict(final double[] x) { + public double predict(@Nonnull final Vector x) { if (trueChild == null && falseChild == null) { return output; } else { if (splitFeatureType == AttributeType.NOMINAL) { - // REVIEWME if(Math.equals(x[splitFeature], splitValue)) { - if (x[splitFeature] == splitValue) { + if (x.get(splitFeature, Double.NaN) == splitValue) { return trueChild.predict(x); } else { return falseChild.predict(x); } } else if (splitFeatureType == AttributeType.NUMERIC) { - if (x[splitFeature] <= splitValue) { + if (x.get(splitFeature, Double.NaN) <= splitValue) { return trueChild.predict(x); } else { return falseChild.predict(x); @@ -283,99 +284,58 @@ public final class RegressionTree implements Regression<double[]> { } } - public int opCodegen(final List<String> scripts, int depth) { - int selfDepth = 0; - final StringBuilder buf = new StringBuilder(); - if (trueChild == null && falseChild == null) { - buf.append("push ").append(output); - scripts.add(buf.toString()); - buf.setLength(0); - buf.append("goto last"); - scripts.add(buf.toString()); - selfDepth += 2; - } else { - if (splitFeatureType == AttributeType.NOMINAL) { - buf.append("push ").append("x[").append(splitFeature).append("]"); - 
scripts.add(buf.toString()); - buf.setLength(0); - buf.append("push ").append(splitValue); - scripts.add(buf.toString()); - buf.setLength(0); - buf.append("ifeq "); - scripts.add(buf.toString()); - depth += 3; - selfDepth += 3; - int trueDepth = trueChild.opCodegen(scripts, depth); - selfDepth += trueDepth; - scripts.set(depth - 1, "ifeq " + String.valueOf(depth + trueDepth)); - int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth); - selfDepth += falseDepth; - } else if (splitFeatureType == AttributeType.NUMERIC) { - buf.append("push ").append("x[").append(splitFeature).append("]"); - scripts.add(buf.toString()); - buf.setLength(0); - buf.append("push ").append(splitValue); - scripts.add(buf.toString()); - buf.setLength(0); - buf.append("ifle "); - scripts.add(buf.toString()); - depth += 3; - selfDepth += 3; - int trueDepth = trueChild.opCodegen(scripts, depth); - selfDepth += trueDepth; - scripts.set(depth - 1, "ifle " + String.valueOf(depth + trueDepth)); - int falseDepth = falseChild.opCodegen(scripts, depth + trueDepth); - selfDepth += falseDepth; - } else { - throw new IllegalStateException("Unsupported attribute type: " - + splitFeatureType); - } - } - return selfDepth; - } - @Override public void writeExternal(ObjectOutput out) throws IOException { - out.writeDouble(output); out.writeInt(splitFeature); if (splitFeatureType == null) { - out.writeInt(-1); + out.writeByte(-1); } else { - out.writeInt(splitFeatureType.getTypeId()); + out.writeByte(splitFeatureType.getTypeId()); } out.writeDouble(splitValue); - if (trueChild == null) { - out.writeBoolean(false); - } else { + + if (isLeaf()) { out.writeBoolean(true); - trueChild.writeExternal(out); - } - if (falseChild == null) { - out.writeBoolean(false); + out.writeDouble(output); } else { - out.writeBoolean(true); - falseChild.writeExternal(out); + out.writeBoolean(false); + if (trueChild == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + trueChild.writeExternal(out); + } 
+ if (falseChild == null) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + falseChild.writeExternal(out); + } } } @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - this.output = in.readDouble(); this.splitFeature = in.readInt(); - int typeId = in.readInt(); + byte typeId = in.readByte(); if (typeId == -1) { this.splitFeatureType = null; } else { this.splitFeatureType = AttributeType.resolve(typeId); } this.splitValue = in.readDouble(); - if (in.readBoolean()) { - this.trueChild = new Node(); - trueChild.readExternal(in); - } - if (in.readBoolean()) { - this.falseChild = new Node(); - falseChild.readExternal(in); + + if (in.readBoolean()) {// isLeaf() + this.output = in.readDouble(); + } else { + if (in.readBoolean()) { + this.trueChild = new Node(); + trueChild.readExternal(in); + } + if (in.readBoolean()) { + this.falseChild = new Node(); + falseChild.readExternal(in); + } } } } @@ -406,7 +366,7 @@ public final class RegressionTree implements Regression<double[]> { /** * Training dataset. */ - final double[][] x; + final Matrix x; /** * Training data response value. */ @@ -419,7 +379,7 @@ public final class RegressionTree implements Regression<double[]> { /** * Constructor. */ - public TrainNode(Node node, double[][] x, double[] y, int[] bags, int depth) { + public TrainNode(Node node, Matrix x, double[] y, int[] bags, int depth) { this.node = node; this.x = x; this.y = y; @@ -452,8 +412,7 @@ public final class RegressionTree implements Regression<double[]> { } /** - * Finds the best attribute to split on at the current node. Returns true if a split exists - * to reduce squared error, false otherwise. + * Finds the best attribute to split on at the current node. Returns true if a split exists to reduce squared error, false otherwise. 
*/ public boolean findBestSplit() { // avoid split if tree depth is larger than threshold @@ -467,22 +426,14 @@ public final class RegressionTree implements Regression<double[]> { } final double sum = node.output * numSamples; - final int p = _attributes.length; - final int[] variables = new int[p]; - for (int i = 0; i < p; i++) { - variables[i] = i; - } - if (_numVars < p) { - SmileExtUtils.shuffle(variables, _rnd); - } // Loop through features and compute the reduction of squared error, // which is trueCount * trueMean^2 + falseCount * falseMean^2 - count * parentMean^2 - final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.length) + final int[] samples = _hasNumericType ? SmileExtUtils.bagsToSamples(bags, x.numRows()) : null; - for (int j = 0; j < _numVars; j++) { - Node split = findBestSplit(numSamples, sum, variables[j], samples); + for (int varJ : variableIndex(x, bags)) { + final Node split = findBestSplit(numSamples, sum, varJ, samples); if (split.splitScore > node.splitScore) { node.splitFeature = split.splitFeature; node.splitFeatureType = split.splitFeatureType; @@ -496,6 +447,31 @@ public final class RegressionTree implements Regression<double[]> { return node.splitFeature != -1; } + private int[] variableIndex(@Nonnull final Matrix x, @Nonnull final int[] bags) { + final int[] variableIndex; + if (x.isSparse()) { + final IntSet cols = new IntArraySet(_numVars); + final VectorProcedure proc = new VectorProcedure() { + public void apply(int col, double value) { + cols.add(col); + } + }; + for (final int row : bags) { + x.eachNonNullInRow(row, proc); + } + variableIndex = cols.toArray(false); + } else { + variableIndex = MathUtils.permutation(_attributes.length); + } + + if (_numVars < variableIndex.length) { + SmileExtUtils.shuffle(variableIndex, _rnd); + return Arrays.copyOf(variableIndex, _numVars); + + } + return variableIndex; + } + /** * Finds the best split cutoff for attribute j at the current node. 
* @@ -517,7 +493,11 @@ public final class RegressionTree implements Regression<double[]> { // For each true feature of this datum increment the // sufficient statistics for the "true" branch to evaluate // splitting on this feature. - int index = (int) x[i][j]; + final double v = x.get(i, j, Double.NaN); + if (Double.isNaN(v)) { + continue; + } + int index = (int) v; trueSum[index] += y[i]; ++trueCount[index]; } @@ -548,28 +528,38 @@ public final class RegressionTree implements Regression<double[]> { } } } else if (_attributes[j].type == AttributeType.NUMERIC) { - double trueSum = 0.0; - int trueCount = 0; - double prevx = Double.NaN; - - for (int i : _order[j]) { - final int sample = samples[i]; - if (sample > 0) { - if (Double.isNaN(prevx) || x[i][j] == prevx) { - prevx = x[i][j]; - trueSum += sample * y[i]; + + _order.eachNonNullInColumn(j, new VectorProcedure() { + double trueSum = 0.0; + int trueCount = 0; + double prevx = Double.NaN; + + public void apply(final int row, final int i) { + final int sample = samples[i]; + if (sample == 0) { + return; + } + final double x_ij = x.get(i, j, Double.NaN); + if (Double.isNaN(x_ij)) { + return; + } + final double y_i = y[i]; + + if (Double.isNaN(prevx) || x_ij == prevx) { + prevx = x_ij; + trueSum += sample * y_i; trueCount += sample; - continue; + return; } final double falseCount = n - trueCount; // If either side is empty, skip this feature. 
if (trueCount < _minSplit || falseCount < _minSplit) { - prevx = x[i][j]; - trueSum += sample * y[i]; + prevx = x_ij; + trueSum += sample * y_i; trueCount += sample; - continue; + return; } // compute penalized means @@ -586,17 +576,18 @@ public final class RegressionTree implements Regression<double[]> { // new best split split.splitFeature = j; split.splitFeatureType = AttributeType.NUMERIC; - split.splitValue = (x[i][j] + prevx) / 2; + split.splitValue = (x_ij + prevx) / 2; split.splitScore = gain; split.trueChildOutput = trueMean; split.falseChildOutput = falseMean; } - prevx = x[i][j]; - trueSum += sample * y[i]; + prevx = x_ij; + trueSum += sample * y_i; trueCount += sample; - } - } + }//apply + }); + } else { throw new IllegalStateException("Unsupported attribute type: " + _attributes[j].type); @@ -672,7 +663,7 @@ public final class RegressionTree implements Regression<double[]> { final double splitValue = node.splitValue; for (int i = 0, size = bags.length; i < size; i++) { final int index = bags[i]; - if (x[index][splitFeature] == splitValue) { + if (x.get(index, splitFeature, Double.NaN) == splitValue) { trueBags.add(index); tc++; } else { @@ -684,7 +675,7 @@ public final class RegressionTree implements Regression<double[]> { final double splitValue = node.splitValue; for (int i = 0, size = bags.length; i < size; i++) { final int index = bags[i]; - if (x[index][splitFeature] <= splitValue) { + if (x.get(index, splitFeature, Double.NaN) <= splitValue) { trueBags.add(index); tc++; } else { @@ -700,20 +691,19 @@ public final class RegressionTree implements Regression<double[]> { } - public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, - @Nonnull double[] y, int maxLeafs) { - this(attributes, x, y, x[0].length, Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, null); + public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y, + int maxLeafs) { + this(attributes, x, y, x.numColumns(), 
Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, null); } - public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, - @Nonnull double[] y, int maxLeafs, @Nullable smile.math.Random rand) { - this(attributes, x, y, x[0].length, Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, rand); + public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y, + int maxLeafs, @Nullable PRNG rand) { + this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 1, null, null, rand); } - public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, - @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits, - int minLeafSize, @Nullable int[][] order, @Nullable int[] bags, - @Nullable smile.math.Random rand) { + public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y, + int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize, + @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags, @Nullable PRNG rand) { this(attributes, x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize, order, bags, null, rand); } @@ -723,24 +713,22 @@ public final class RegressionTree implements Regression<double[]> { * @param attributes the attribute properties. * @param x the training instances. * @param y the response variable. - * @param numVars the number of input variables to pick to split on at each node. It seems that - * dim/3 give generally good performance, where dim is the number of variables. + * @param numVars the number of input variables to pick to split on at each node. It seems that dim/3 give generally good performance, where dim + * is the number of variables. * @param maxLeafs the maximum number of leaf nodes in the tree. - * @param minSplits number of instances in a node below which the tree will not split, setting S - * = 5 generally gives good results. - * @param order the index of training values in ascending order. 
Note that only numeric - * attributes need be sorted. + * @param minSplits number of instances in a node below which the tree will not split, setting S = 5 generally gives good results. + * @param order the index of training values in ascending order. Note that only numeric attributes need be sorted. * @param bags the sample set of instances for stochastic learning. * @param output An interface to calculate node output. */ - public RegressionTree(@Nullable Attribute[] attributes, @Nonnull double[][] x, - @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits, - int minLeafSize, @Nullable int[][] order, @Nullable int[] bags, - @Nullable NodeOutput output, @Nullable smile.math.Random rand) { + public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, @Nonnull double[] y, + int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize, + @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags, + @Nullable NodeOutput output, @Nullable PRNG rand) { checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, minLeafSize); this._attributes = SmileExtUtils.attributeTypes(attributes, x); - if (_attributes.length != x[0].length) { + if (_attributes.length != x.numColumns()) { throw new IllegalArgumentException("-attrs option is invliad: " + Arrays.toString(attributes)); } @@ -752,7 +740,7 @@ public final class RegressionTree implements Regression<double[]> { this._minLeafSize = minLeafSize; this._order = (order == null) ? SmileExtUtils.sort(_attributes, x) : order; this._importance = new double[_attributes.length]; - this._rnd = (rand == null) ? new smile.math.Random() : rand; + this._rnd = (rand == null) ? 
RandomNumberGeneratorFactory.createPRNG() : rand; this._nodeOutput = output; int n = 0; @@ -803,13 +791,13 @@ public final class RegressionTree implements Regression<double[]> { } } - private static void checkArgument(@Nonnull double[][] x, @Nonnull double[] y, int numVars, + private static void checkArgument(@Nonnull Matrix x, @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int minSplits, int minLeafSize) { - if (x.length != y.length) { + if (x.numRows() != y.length) { throw new IllegalArgumentException(String.format( - "The sizes of X and Y don't match: %d != %d", x.length, y.length)); + "The sizes of X and Y don't match: %d != %d", x.numRows(), y.length)); } - if (numVars <= 0 || numVars > x[0].length) { + if (numVars <= 0 || numVars > x.numColumns()) { throw new IllegalArgumentException( "Invalid number of variables to split on at a node of the tree: " + numVars); } @@ -830,10 +818,8 @@ public final class RegressionTree implements Regression<double[]> { } /** - * Returns the variable importance. Every time a split of a node is made on variable the - * impurity criterion for the two descendent nodes is less than the parent node. Adding up the - * decreases for each individual variable over the tree gives a simple measure of variable - * importance. + * Returns the variable importance. Every time a split of a node is made on variable the impurity criterion for the two descendent nodes is less + * than the parent node. Adding up the decreases for each individual variable over the tree gives a simple measure of variable importance. 
* * @return the variable importance */ @@ -841,8 +827,13 @@ public final class RegressionTree implements Regression<double[]> { return _importance; } + @VisibleForTesting + public double predict(@Nonnull final double[] x) { + return predict(new DenseVector(x)); + } + @Override - public double predict(double[] x) { + public double predict(@Nonnull final Vector x) { return _root.predict(x); } @@ -852,14 +843,6 @@ public final class RegressionTree implements Regression<double[]> { return buf.toString(); } - public String predictOpCodegen(@Nonnull String sep) { - List<String> opslist = new ArrayList<String>(); - _root.opCodegen(opslist, 0); - opslist.add("call end"); - String scripts = StringUtils.concat(opslist, sep); - return scripts; - } - @Nonnull public byte[] predictSerCodegen(boolean compress) throws HiveException { try {
