This is an automated email from the ASF dual-hosted git repository.

myui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-hivemall.git


The following commit(s) were added to refs/heads/master by this push:
     new eea73d5  [HIVEMALL-243] Fix nominal variable handling in DecisionTree 
and RegressionTree
eea73d5 is described below

commit eea73d52eb9f0edbf5c7bc9b4dccdb85b577d6fd
Author: Makoto Yui <[email protected]>
AuthorDate: Wed Mar 13 16:56:17 2019 +0900

    [HIVEMALL-243] Fix nominal variable handling in DecisionTree and 
RegressionTree
    
    ## What changes were proposed in this pull request?
    
    For NOMINAL variable, the maximum attribute index 'm' is used for computing 
splits.
    
    This causes performance issues for sparse nominal variables. So, revise this 
handling for better performance.
    
    
https://github.com/apache/incubator-hivemall/blob/master/core/src/main/java/hivemall/smile/classification/DecisionTree.java#L703
    
    ## What type of PR is it?
    
    Improvement
    
    ## What is the Jira issue?
    
    https://issues.apache.org/jira/browse/HIVEMALL-243
    
    ## How was this patch tested?
    
    - [x] manual test on EMR
    
    ## Checklist
    
    - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, 
for your commit?
    - [x] Did you run system tests on Hive (or Spark)?
    
    Author: Makoto Yui <[email protected]>
    
    Closes #185 from myui/HIVEMALL-243.
---
 .../smile/classification/DecisionTree.java         |  42 +++---
 .../GradientTreeBoostingClassifierUDTF.java        |   8 +-
 .../classification/RandomForestClassifierUDTF.java |  16 +-
 .../main/java/hivemall/smile/data/Attribute.java   | 167 ---------------------
 .../java/hivemall/smile/data/AttributeType.java    |  67 +++++++++
 .../regression/RandomForestRegressionUDTF.java     |  16 +-
 .../hivemall/smile/regression/RegressionTree.java  |  62 ++++----
 .../hivemall/smile/tools/TreePredictUDFv1.java     |   2 +-
 .../java/hivemall/smile/utils/SmileExtUtils.java   | 114 +++-----------
 .../java/hivemall/utils/hadoop/JsonSerdeUtils.java |   1 -
 .../smile/classification/DecisionTreeTest.java     |  12 +-
 .../smile/regression/RegressionTreeTest.java       |  19 ++-
 .../hivemall/smile/tools/TreePredictUDFTest.java   |  23 ++-
 .../hivemall/smile/tools/TreePredictUDFv1Test.java |  10 +-
 docs/gitbook/binaryclass/news20_rf.md              |   7 +-
 15 files changed, 205 insertions(+), 361 deletions(-)

diff --git a/core/src/main/java/hivemall/smile/classification/DecisionTree.java 
b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
index a80a299..00ebae3 100644
--- a/core/src/main/java/hivemall/smile/classification/DecisionTree.java
+++ b/core/src/main/java/hivemall/smile/classification/DecisionTree.java
@@ -29,14 +29,15 @@ import hivemall.math.vector.DenseVector;
 import hivemall.math.vector.SparseVector;
 import hivemall.math.vector.Vector;
 import hivemall.math.vector.VectorProcedure;
-import hivemall.smile.data.Attribute;
-import hivemall.smile.data.Attribute.AttributeType;
+import hivemall.smile.data.AttributeType;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.utils.collections.lists.IntArrayList;
 import hivemall.utils.lang.ObjectUtils;
 import hivemall.utils.lang.StringUtils;
 import hivemall.utils.lang.mutable.MutableInt;
 import hivemall.utils.sampling.IntReservoirSampler;
+import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
+import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
 import smile.classification.Classifier;
 import smile.math.Math;
 
@@ -114,7 +115,7 @@ public final class DecisionTree implements 
Classifier<Vector> {
      * The attributes of independent variable.
      */
     @Nonnull
-    private final Attribute[] _attributes;
+    private final AttributeType[] _attributes;
     private final boolean _hasNumericType;
     /**
      * Variable importance. Every time a split of a node is made on variable 
the (GINI, information
@@ -686,9 +687,8 @@ public final class DecisionTree implements 
Classifier<Vector> {
                 final double impurity, final int j, @Nullable final int[] 
samples) {
             final Node splitNode = new Node();
 
-            if (_attributes[j].type == AttributeType.NOMINAL) {
-                final int m = _attributes[j].getSize();
-                final int[][] trueCount = new int[m][_k];
+            if (_attributes[j] == AttributeType.NOMINAL) {
+                final Int2ObjectMap<int[]> trueCount = new 
Int2ObjectOpenHashMap<int[]>();
 
                 for (int i = 0, size = bags.length; i < size; i++) {
                     int index = bags[i];
@@ -697,11 +697,18 @@ public final class DecisionTree implements 
Classifier<Vector> {
                         continue;
                     }
                     int x_ij = (int) v;
-                    trueCount[x_ij][y[index]]++;
+                    int[] tc_x = trueCount.get(x_ij);
+                    if (tc_x == null) {
+                        tc_x = new int[_k];
+                    }
+                    tc_x[y[index]]++;
                 }
 
-                for (int l = 0; l < m; l++) {
-                    final int tc = Math.sum(trueCount[l]);
+                for (Int2ObjectMap.Entry<int[]> e : 
trueCount.int2ObjectEntrySet()) {
+                    final int l = e.getIntKey();
+                    final int[] trueCount_l = e.getValue();
+
+                    final int tc = Math.sum(trueCount_l);
                     final int fc = n - tc;
 
                     // skip splitting this feature.
@@ -710,11 +717,11 @@ public final class DecisionTree implements 
Classifier<Vector> {
                     }
 
                     for (int q = 0; q < _k; q++) {
-                        falseCount[q] = count[q] - trueCount[l][q];
+                        falseCount[q] = count[q] - trueCount_l[q];
                     }
 
                     final double gain =
-                            impurity - (double) tc / n * 
impurity(trueCount[l], tc, _rule)
+                            impurity - (double) tc / n * impurity(trueCount_l, 
tc, _rule)
                                     - (double) fc / n * impurity(falseCount, 
fc, _rule);
 
                     if (gain > splitNode.splitScore) {
@@ -723,11 +730,11 @@ public final class DecisionTree implements 
Classifier<Vector> {
                         splitNode.splitFeatureType = AttributeType.NOMINAL;
                         splitNode.splitValue = l;
                         splitNode.splitScore = gain;
-                        splitNode.trueChildOutput = 
Math.whichMax(trueCount[l]);
+                        splitNode.trueChildOutput = Math.whichMax(trueCount_l);
                         splitNode.falseChildOutput = Math.whichMax(falseCount);
                     }
                 }
-            } else if (_attributes[j].type == AttributeType.NUMERIC) {
+            } else if (_attributes[j] == AttributeType.NUMERIC) {
                 final int[] trueCount = new int[_k];
 
                 _order.eachNonNullInColumn(j, new VectorProcedure() {
@@ -788,8 +795,7 @@ public final class DecisionTree implements 
Classifier<Vector> {
                     }//apply()
                 });
             } else {
-                throw new IllegalStateException(
-                    "Unsupported attribute type: " + _attributes[j].type);
+                throw new IllegalStateException("Unsupported attribute type: " 
+ _attributes[j]);
             }
 
             return splitNode;
@@ -953,12 +959,12 @@ public final class DecisionTree implements 
Classifier<Vector> {
         return impurity;
     }
 
-    public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, 
@Nonnull int[] y,
+    public DecisionTree(@Nullable AttributeType[] attributes, @Nonnull Matrix 
x, @Nonnull int[] y,
             int numLeafs) {
         this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, numLeafs, 2, 
1, null, null, SplitRule.GINI, null);
     }
 
-    public DecisionTree(@Nullable Attribute[] attributes, @Nullable Matrix x, 
@Nullable int[] y,
+    public DecisionTree(@Nullable AttributeType[] attributes, @Nullable Matrix 
x, @Nullable int[] y,
             int numLeafs, @Nullable PRNG rand) {
         this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, numLeafs, 2, 
1, null, null, SplitRule.GINI, rand);
     }
@@ -980,7 +986,7 @@ public final class DecisionTree implements 
Classifier<Vector> {
      * @param rule the splitting rule.
      * @param seed
      */
-    public DecisionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, 
@Nonnull int[] y,
+    public DecisionTree(@Nullable AttributeType[] attributes, @Nonnull Matrix 
x, @Nonnull int[] y,
             int numVars, int maxDepth, int maxLeafs, int minSplits, int 
minLeafSize,
             @Nullable int[] bags, @Nullable ColumnMajorIntMatrix order, 
@Nonnull SplitRule rule,
             @Nullable PRNG rand) {
diff --git 
a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
 
b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
index 9edc63d..5feaa36 100644
--- 
a/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
+++ 
b/core/src/main/java/hivemall/smile/classification/GradientTreeBoostingClassifierUDTF.java
@@ -30,7 +30,7 @@ import hivemall.math.vector.DenseVector;
 import hivemall.math.vector.SparseVector;
 import hivemall.math.vector.Vector;
 import hivemall.math.vector.VectorProcedure;
-import hivemall.smile.data.Attribute;
+import hivemall.smile.data.AttributeType;
 import hivemall.smile.regression.RegressionTree;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.utils.codec.Base91;
@@ -43,8 +43,8 @@ import hivemall.utils.math.MathUtils;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.BitSet;
-import java.util.Map;
 import java.util.HashMap;
+import java.util.Map;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
@@ -112,7 +112,7 @@ public final class GradientTreeBoostingClassifierUDTF 
extends UDTFWithOptions {
     private int _minSamplesSplit;
     private int _minSamplesLeaf;
     private long _seed;
-    private Attribute[] _attributes;
+    private AttributeType[] _attributes;
 
     @Nullable
     private Reporter _progressReporter;
@@ -151,7 +151,7 @@ public final class GradientTreeBoostingClassifierUDTF 
extends UDTFWithOptions {
         int maxLeafs = Integer.MAX_VALUE, minSplit = 5, minSamplesLeaf = 1;
         float numVars = -1.f;
         double eta = 0.05d, subsample = 0.7d;
-        Attribute[] attrs = null;
+        AttributeType[] attrs = null;
         long seed = -1L;
 
         CommandLine cl = null;
diff --git 
a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
 
b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
index b2f5e9e..7f2966b 100644
--- 
a/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
+++ 
b/core/src/main/java/hivemall/smile/classification/RandomForestClassifierUDTF.java
@@ -32,7 +32,7 @@ import hivemall.math.random.RandomNumberGeneratorFactory;
 import hivemall.math.vector.Vector;
 import hivemall.math.vector.VectorProcedure;
 import hivemall.smile.classification.DecisionTree.SplitRule;
-import hivemall.smile.data.Attribute;
+import hivemall.smile.data.AttributeType;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.smile.utils.SmileTaskExecutor;
 import hivemall.utils.codec.Base91;
@@ -114,7 +114,7 @@ public final class RandomForestClassifierUDTF extends 
UDTFWithOptions {
     private int _minSamplesSplit;
     private int _minSamplesLeaf;
     private long _seed;
-    private Attribute[] _attributes;
+    private AttributeType[] _attributes;
     private SplitRule _splitRule;
     private boolean _stratifiedSampling;
     private double _subsample;
@@ -159,7 +159,7 @@ public final class RandomForestClassifierUDTF extends 
UDTFWithOptions {
         int trees = 50, maxDepth = Integer.MAX_VALUE;
         int numLeafs = Integer.MAX_VALUE, minSplits = 2, minSamplesLeaf = 1;
         float numVars = -1.f;
-        Attribute[] attrs = null;
+        AttributeType[] attrs = null;
         long seed = -1L;
         SplitRule splitRule = SplitRule.GINI;
         double[] classWeight = null;
@@ -367,7 +367,7 @@ public final class RandomForestClassifierUDTF extends 
UDTFWithOptions {
         x = SmileExtUtils.shuffle(x, y, _seed);
 
         int[] labels = SmileExtUtils.classLabels(y);
-        Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
+        AttributeType[] attributes = SmileExtUtils.attributeTypes(_attributes, 
x);
         int numInputVars = SmileExtUtils.computeNumInputVars(_numVars, x);
 
         if (logger.isInfoEnabled()) {
@@ -455,7 +455,7 @@ public final class RandomForestClassifierUDTF extends 
UDTFWithOptions {
          * Attribute properties.
          */
         @Nonnull
-        private final Attribute[] _attributes;
+        private final AttributeType[] _attributes;
         /**
          * Training instances.
          */
@@ -491,9 +491,9 @@ public final class RandomForestClassifierUDTF extends 
UDTFWithOptions {
         private final AtomicInteger _remainingTasks;
 
         TrainingTask(@Nonnull RandomForestClassifierUDTF udtf, int taskId,
-                @Nonnull Attribute[] attributes, @Nonnull Matrix x, @Nonnull 
int[] y, int numVars,
-                @Nonnull ColumnMajorIntMatrix order, @Nonnull IntMatrix 
prediction, long seed,
-                @Nonnull AtomicInteger remainingTasks) {
+                @Nonnull AttributeType[] attributes, @Nonnull Matrix x, 
@Nonnull int[] y,
+                int numVars, @Nonnull ColumnMajorIntMatrix order, @Nonnull 
IntMatrix prediction,
+                long seed, @Nonnull AtomicInteger remainingTasks) {
             this._udtf = udtf;
             this._taskId = taskId;
             this._attributes = attributes;
diff --git a/core/src/main/java/hivemall/smile/data/Attribute.java 
b/core/src/main/java/hivemall/smile/data/Attribute.java
deleted file mode 100644
index f9cb5a6..0000000
--- a/core/src/main/java/hivemall/smile/data/Attribute.java
+++ /dev/null
@@ -1,167 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.smile.data;
-
-import hivemall.annotations.BackwardCompatibility;
-import hivemall.annotations.Immutable;
-import hivemall.annotations.Mutable;
-
-import java.io.IOException;
-import java.io.ObjectInput;
-import java.io.ObjectOutput;
-
-public abstract class Attribute {
-
-    public final AttributeType type;
-
-    Attribute(AttributeType type) {
-        this.type = type;
-    }
-
-    public void setSize(int size) {
-        throw new UnsupportedOperationException();
-    }
-
-    /**
-     * @return -1 if not set
-     */
-    public int getSize() {
-        throw new UnsupportedOperationException();
-    }
-
-    public void writeTo(ObjectOutput out) throws IOException {
-        out.writeByte(type.getTypeId());
-    }
-
-    public enum AttributeType {
-        NUMERIC((byte) 1), NOMINAL((byte) 2);
-
-        private final byte id;
-
-        private AttributeType(byte id) {
-            this.id = id;
-        }
-
-        public byte getTypeId() {
-            return id;
-        }
-
-        public static AttributeType resolve(byte id) {
-            final AttributeType type;
-            switch (id) {
-                case 1:
-                    type = NUMERIC;
-                    break;
-                case 2:
-                    type = NOMINAL;
-                    break;
-                default:
-                    throw new IllegalStateException("Unexpected type: " + id);
-            }
-            return type;
-        }
-
-        @BackwardCompatibility
-        public static AttributeType resolve(int id) {
-            final AttributeType type;
-            switch (id) {
-                case 1:
-                    type = NUMERIC;
-                    break;
-                case 2:
-                    type = NOMINAL;
-                    break;
-                default:
-                    throw new IllegalStateException("Unexpected type: " + id);
-            }
-            return type;
-        }
-
-    }
-
-    @Immutable
-    public static final class NumericAttribute extends Attribute {
-
-        public NumericAttribute() {
-            super(AttributeType.NUMERIC);
-        }
-
-        @Override
-        public String toString() {
-            return "NumericAttribute [type=" + type + "]";
-        }
-
-    }
-
-    @Mutable
-    public static final class NominalAttribute extends Attribute {
-
-        private int size;
-
-        public NominalAttribute() {
-            super(AttributeType.NOMINAL);
-            this.size = -1;
-        }
-
-        @Override
-        public int getSize() {
-            return size;
-        }
-
-        @Override
-        public void setSize(int size) {
-            this.size = size;
-        }
-
-        @Override
-        public void writeTo(ObjectOutput out) throws IOException {
-            super.writeTo(out);
-            out.writeInt(size);
-        }
-
-        @Override
-        public String toString() {
-            return "NominalAttribute [size=" + size + ", type=" + type + "]";
-        }
-
-    }
-
-    public static Attribute readFrom(ObjectInput in) throws IOException {
-        final Attribute attr;
-
-        byte typeId = in.readByte();
-        final AttributeType type = AttributeType.resolve(typeId);
-        switch (type) {
-            case NUMERIC: {
-                attr = new NumericAttribute();
-                break;
-            }
-            case NOMINAL: {
-                attr = new NominalAttribute();
-                int size = in.readInt();
-                attr.setSize(size);
-                break;
-            }
-            default:
-                throw new IllegalStateException("Unexpected type: " + type);
-        }
-        return attr;
-    }
-
-}
diff --git a/core/src/main/java/hivemall/smile/data/AttributeType.java 
b/core/src/main/java/hivemall/smile/data/AttributeType.java
new file mode 100644
index 0000000..559ad2d
--- /dev/null
+++ b/core/src/main/java/hivemall/smile/data/AttributeType.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package hivemall.smile.data;
+
+import hivemall.annotations.BackwardCompatibility;
+
+public enum AttributeType {
+    NUMERIC((byte) 1), NOMINAL((byte) 2);
+
+    private final byte id;
+
+    private AttributeType(byte id) {
+        this.id = id;
+    }
+
+    public byte getTypeId() {
+        return id;
+    }
+
+    public static AttributeType resolve(final byte id) {
+        final AttributeType type;
+        switch (id) {
+            case 1:
+                type = NUMERIC;
+                break;
+            case 2:
+                type = NOMINAL;
+                break;
+            default:
+                throw new IllegalStateException("Unexpected type: " + id);
+        }
+        return type;
+    }
+
+    @BackwardCompatibility
+    public static AttributeType resolve(final int id) {
+        final AttributeType type;
+        switch (id) {
+            case 1:
+                type = NUMERIC;
+                break;
+            case 2:
+                type = NOMINAL;
+                break;
+            default:
+                throw new IllegalStateException("Unexpected type: " + id);
+        }
+        return type;
+    }
+
+}
\ No newline at end of file
diff --git 
a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java 
b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
index dc148e2..ec2e25d 100644
--- 
a/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
+++ 
b/core/src/main/java/hivemall/smile/regression/RandomForestRegressionUDTF.java
@@ -28,7 +28,7 @@ import hivemall.math.random.PRNG;
 import hivemall.math.random.RandomNumberGeneratorFactory;
 import hivemall.math.vector.Vector;
 import hivemall.math.vector.VectorProcedure;
-import hivemall.smile.data.Attribute;
+import hivemall.smile.data.AttributeType;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.smile.utils.SmileTaskExecutor;
 import hivemall.utils.codec.Base91;
@@ -107,7 +107,7 @@ public final class RandomForestRegressionUDTF extends 
UDTFWithOptions {
     private int _minSamplesSplit;
     private int _minSamplesLeaf;
     private long _seed;
-    private Attribute[] _attributes;
+    private AttributeType[] _attributes;
 
     @Nullable
     private Reporter _progressReporter;
@@ -145,7 +145,7 @@ public final class RandomForestRegressionUDTF extends 
UDTFWithOptions {
         int trees = 50, maxDepth = Integer.MAX_VALUE;
         int maxLeafs = Integer.MAX_VALUE, minSplit = 5, minSamplesLeaf = 1;
         float numVars = -1.f;
-        Attribute[] attrs = null;
+        AttributeType[] attrs = null;
         long seed = -1L;
 
         CommandLine cl = null;
@@ -330,7 +330,7 @@ public final class RandomForestRegressionUDTF extends 
UDTFWithOptions {
         // Shuffle training samples
         x = SmileExtUtils.shuffle(x, y, _seed);
 
-        Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
+        AttributeType[] attributes = SmileExtUtils.attributeTypes(_attributes, 
x);
         int numInputVars = SmileExtUtils.computeNumInputVars(_numVars, x);
 
         if (logger.isInfoEnabled()) {
@@ -417,7 +417,7 @@ public final class RandomForestRegressionUDTF extends 
UDTFWithOptions {
         /**
          * Attribute properties.
          */
-        private final Attribute[] _attributes;
+        private final AttributeType[] _attributes;
         /**
          * Training instances.
          */
@@ -449,9 +449,9 @@ public final class RandomForestRegressionUDTF extends 
UDTFWithOptions {
         private final long _seed;
         private final AtomicInteger _remainingTasks;
 
-        TrainingTask(RandomForestRegressionUDTF udtf, int taskId, Attribute[] 
attributes, Matrix x,
-                double[] y, int numVars, ColumnMajorIntMatrix order, double[] 
prediction, int[] oob,
-                long seed, AtomicInteger remainingTasks) {
+        TrainingTask(RandomForestRegressionUDTF udtf, int taskId, 
AttributeType[] attributes,
+                Matrix x, double[] y, int numVars, ColumnMajorIntMatrix order, 
double[] prediction,
+                int[] oob, long seed, AtomicInteger remainingTasks) {
             this._udtf = udtf;
             this._taskId = taskId;
             this._attributes = attributes;
diff --git a/core/src/main/java/hivemall/smile/regression/RegressionTree.java 
b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
index f1fe7b0..0e42094 100755
--- a/core/src/main/java/hivemall/smile/regression/RegressionTree.java
+++ b/core/src/main/java/hivemall/smile/regression/RegressionTree.java
@@ -28,8 +28,7 @@ import hivemall.math.vector.DenseVector;
 import hivemall.math.vector.SparseVector;
 import hivemall.math.vector.Vector;
 import hivemall.math.vector.VectorProcedure;
-import hivemall.smile.data.Attribute;
-import hivemall.smile.data.Attribute.AttributeType;
+import hivemall.smile.data.AttributeType;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.utils.collections.lists.IntArrayList;
 import hivemall.utils.collections.sets.IntArraySet;
@@ -38,6 +37,9 @@ import hivemall.utils.lang.ObjectUtils;
 import hivemall.utils.lang.StringUtils;
 import hivemall.utils.lang.mutable.MutableInt;
 import hivemall.utils.math.MathUtils;
+import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;
+import it.unimi.dsi.fastutil.ints.Int2IntMap.Entry;
+import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
 import smile.math.Math;
 import smile.regression.GradientTreeBoost;
 import smile.regression.RandomForest;
@@ -100,7 +102,7 @@ public final class RegressionTree implements 
Regression<Vector> {
     /**
      * The attributes of independent variable.
      */
-    private final Attribute[] _attributes;
+    private final AttributeType[] _attributes;
     private final boolean _hasNumericType;
     /**
      * Variable importance. Every time a split of a node is made on variable 
the impurity criterion
@@ -619,10 +621,12 @@ public final class RegressionTree implements 
Regression<Vector> {
         private Node findBestSplit(final int n, final double sum, final int j,
                 @Nullable final int[] samples) {
             final Node split = new Node(0.d);
-            if (_attributes[j].type == AttributeType.NOMINAL) {
-                final int m = _attributes[j].getSize();
-                final double[] trueSum = new double[m];
-                final int[] trueCount = new int[m];
+            if (_attributes[j] == AttributeType.NOMINAL) {
+                //final int m = _attributes[j].getSize();
+                //final double[] trueSum = new double[m];
+                //final int[] trueCount = new int[m];
+                final Int2DoubleOpenHashMap trueSum = new 
Int2DoubleOpenHashMap();
+                final Int2IntOpenHashMap trueCount = new Int2IntOpenHashMap();
 
                 for (int b = 0, size = bags.length; b < size; b++) {
                     int i = bags[b];
@@ -634,12 +638,15 @@ public final class RegressionTree implements 
Regression<Vector> {
                         continue;
                     }
                     int index = (int) v;
-                    trueSum[index] += y[i];
-                    ++trueCount[index];
+
+                    trueSum.addTo(index, y[i]);
+                    trueCount.addTo(index, 1);
                 }
 
-                for (int k = 0; k < m; k++) {
-                    final double tc = (double) trueCount[k];
+                for (Entry e : trueCount.int2IntEntrySet()) {
+                    final int k = e.getIntKey();
+                    final double tc = e.getIntValue();
+
                     final double fc = n - tc;
 
                     // skip splitting
@@ -648,8 +655,9 @@ public final class RegressionTree implements 
Regression<Vector> {
                     }
 
                     // compute penalized means
-                    final double trueMean = trueSum[k] / tc;
-                    final double falseMean = (sum - trueSum[k]) / fc;
+                    double trueSum_k = trueSum.get(k);
+                    final double trueMean = trueSum_k / tc;
+                    final double falseMean = (sum - trueSum_k) / fc;
 
                     final double gain = (tc * trueMean * trueMean + fc * 
falseMean * falseMean)
                             - n * split.output * split.output;
@@ -663,7 +671,7 @@ public final class RegressionTree implements 
Regression<Vector> {
                         split.falseChildOutput = falseMean;
                     }
                 }
-            } else if (_attributes[j].type == AttributeType.NUMERIC) {
+            } else if (_attributes[j] == AttributeType.NUMERIC) {
 
                 _order.eachNonNullInColumn(j, new VectorProcedure() {
                     double trueSum = 0.0;
@@ -725,8 +733,7 @@ public final class RegressionTree implements 
Regression<Vector> {
                 });
 
             } else {
-                throw new IllegalStateException(
-                    "Unsupported attribute type: " + _attributes[j].type);
+                throw new IllegalStateException("Unsupported attribute type: " 
+ _attributes[j]);
             }
 
             return split;
@@ -827,19 +834,20 @@ public final class RegressionTree implements 
Regression<Vector> {
 
     }
 
-    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, 
@Nonnull double[] y,
-            int maxLeafs) {
+    public RegressionTree(@Nullable AttributeType[] attributes, @Nonnull 
Matrix x,
+            @Nonnull double[] y, int maxLeafs) {
         this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 
1, null, null, null);
     }
 
-    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, 
@Nonnull double[] y,
-            int maxLeafs, @Nullable PRNG rand) {
+    public RegressionTree(@Nullable AttributeType[] attributes, @Nonnull 
Matrix x,
+            @Nonnull double[] y, int maxLeafs, @Nullable PRNG rand) {
         this(attributes, x, y, x.numColumns(), Integer.MAX_VALUE, maxLeafs, 5, 
1, null, null, rand);
     }
 
-    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, 
@Nonnull double[] y,
-            int numVars, int maxDepth, int maxLeafs, int minSplits, int 
minLeafSize,
-            @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags, 
@Nullable PRNG rand) {
+    public RegressionTree(@Nullable AttributeType[] attributes, @Nonnull 
Matrix x,
+            @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int 
minSplits,
+            int minLeafSize, @Nullable ColumnMajorIntMatrix order, @Nullable 
int[] bags,
+            @Nullable PRNG rand) {
         this(attributes, x, y, numVars, maxDepth, maxLeafs, minSplits, 
minLeafSize, order, bags, null, rand);
     }
 
@@ -859,10 +867,10 @@ public final class RegressionTree implements 
Regression<Vector> {
      * @param bags the sample set of instances for stochastic learning.
      * @param output An interface to calculate node output.
      */
-    public RegressionTree(@Nullable Attribute[] attributes, @Nonnull Matrix x, 
@Nonnull double[] y,
-            int numVars, int maxDepth, int maxLeafs, int minSplits, int 
minLeafSize,
-            @Nullable ColumnMajorIntMatrix order, @Nullable int[] bags, 
@Nullable NodeOutput output,
-            @Nullable PRNG rand) {
+    public RegressionTree(@Nullable AttributeType[] attributes, @Nonnull 
Matrix x,
+            @Nonnull double[] y, int numVars, int maxDepth, int maxLeafs, int 
minSplits,
+            int minLeafSize, @Nullable ColumnMajorIntMatrix order, @Nullable 
int[] bags,
+            @Nullable NodeOutput output, @Nullable PRNG rand) {
         checkArgument(x, y, numVars, maxDepth, maxLeafs, minSplits, 
minLeafSize);
 
         this._attributes = SmileExtUtils.attributeTypes(attributes, x);
diff --git a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java 
b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
index 549c984..12afa7c 100644
--- a/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
+++ b/core/src/main/java/hivemall/smile/tools/TreePredictUDFv1.java
@@ -20,7 +20,7 @@ package hivemall.smile.tools;
 
 import hivemall.annotations.Since;
 import hivemall.annotations.VisibleForTesting;
-import hivemall.smile.data.Attribute.AttributeType;
+import hivemall.smile.data.AttributeType;
 import hivemall.smile.vm.StackMachine;
 import hivemall.smile.vm.VMRuntimeException;
 import hivemall.utils.codec.Base91;
diff --git a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java 
b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
index de7f01e..0c72866 100644
--- a/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
+++ b/core/src/main/java/hivemall/smile/utils/SmileExtUtils.java
@@ -18,6 +18,7 @@
  */
 package hivemall.smile.utils;
 
+import hivemall.annotations.VisibleForTesting;
 import hivemall.math.matrix.ColumnMajorMatrix;
 import hivemall.math.matrix.Matrix;
 import hivemall.math.matrix.MatrixUtils;
@@ -27,15 +28,14 @@ import hivemall.math.random.PRNG;
 import hivemall.math.random.RandomNumberGeneratorFactory;
 import hivemall.math.vector.VectorProcedure;
 import hivemall.smile.classification.DecisionTree.SplitRule;
-import hivemall.smile.data.Attribute;
-import hivemall.smile.data.Attribute.AttributeType;
-import hivemall.smile.data.Attribute.NominalAttribute;
-import hivemall.smile.data.Attribute.NumericAttribute;
+import hivemall.smile.data.AttributeType;
 import hivemall.utils.collections.lists.DoubleArrayList;
 import hivemall.utils.collections.lists.IntArrayList;
 import hivemall.utils.lang.Preconditions;
-import hivemall.utils.lang.mutable.MutableInt;
 import hivemall.utils.math.MathUtils;
+import smile.data.NominalAttribute;
+import smile.data.NumericAttribute;
+import smile.sort.QuickSort;
 
 import java.util.Arrays;
 
@@ -46,8 +46,6 @@ import javax.annotation.Nullable;
 import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 
-import smile.sort.QuickSort;
-
 public final class SmileExtUtils {
 
     private SmileExtUtils() {}
@@ -56,21 +54,20 @@ public final class SmileExtUtils {
      * Q for {@link NumericAttribute}, C for {@link NominalAttribute}.
      */
     @Nullable
-    public static Attribute[] resolveAttributes(@Nullable final String opt)
+    public static AttributeType[] resolveAttributes(@Nullable final String opt)
             throws UDFArgumentException {
         if (opt == null) {
             return null;
         }
         final String[] opts = opt.split(",");
         final int size = opts.length;
-        final NumericAttribute immutableNumAttr = new NumericAttribute();
-        final Attribute[] attr = new Attribute[size];
+        final AttributeType[] attr = new AttributeType[size];
         for (int i = 0; i < size; i++) {
             final String type = opts[i];
             if ("Q".equals(type)) {
-                attr[i] = immutableNumAttr;
+                attr[i] = AttributeType.NUMERIC;
             } else if ("C".equals(type)) {
-                attr[i] = new NominalAttribute();
+                attr[i] = AttributeType.NOMINAL;
             } else {
                 throw new UDFArgumentException("Unexpected type: " + type);
             }
@@ -79,97 +76,32 @@ public final class SmileExtUtils {
     }
 
     @Nonnull
-    public static Attribute[] attributeTypes(@Nullable final Attribute[] 
attributes,
+    public static AttributeType[] attributeTypes(@Nullable final 
AttributeType[] attributes,
             @Nonnull final Matrix x) {
         if (attributes == null) {
             int p = x.numColumns();
-            Attribute[] newAttributes = new Attribute[p];
-            Arrays.fill(newAttributes, new NumericAttribute());
+            AttributeType[] newAttributes = new AttributeType[p];
+            Arrays.fill(newAttributes, AttributeType.NUMERIC);
             return newAttributes;
         }
-
-        if (x.isRowMajorMatrix()) {
-            final VectorProcedure proc = new VectorProcedure() {
-                @Override
-                public void apply(final int j, final double value) {
-                    final Attribute attr = attributes[j];
-                    if (attr.type == AttributeType.NOMINAL) {
-                        final int x_ij = ((int) value) + 1;
-                        final int prevSize = attr.getSize();
-                        if (x_ij > prevSize) {
-                            attr.setSize(x_ij);
-                        }
-                    }
-                }
-            };
-            for (int i = 0, rows = x.numRows(); i < rows; i++) {
-                x.eachNonNullInRow(i, proc);
-            }
-        } else if (x.isColumnMajorMatrix()) {
-            final MutableInt max_x = new MutableInt(0);
-            final VectorProcedure proc = new VectorProcedure() {
-                @Override
-                public void apply(final int i, final double value) {
-                    final int x_ij = (int) value;
-                    if (x_ij > max_x.getValue()) {
-                        max_x.setValue(x_ij);
-                    }
-                }
-            };
-
-            final int size = attributes.length;
-            for (int j = 0; j < size; j++) {
-                final Attribute attr = attributes[j];
-                if (attr.type == AttributeType.NOMINAL) {
-                    if (attr.getSize() != -1) {
-                        continue;
-                    }
-                    max_x.setValue(0);
-                    x.eachNonNullInColumn(j, proc);
-                    attr.setSize(max_x.getValue() + 1);
-                }
-            }
-        } else {
-            int size = attributes.length;
-            for (int j = 0; j < size; j++) {
-                Attribute attr = attributes[j];
-                if (attr.type == AttributeType.NOMINAL) {
-                    if (attr.getSize() != -1) {
-                        continue;
-                    }
-                    int max_x = 0;
-                    for (int i = 0, rows = x.numRows(); i < rows; i++) {
-                        final double v = x.get(i, j, Double.NaN);
-                        if (Double.isNaN(v)) {
-                            continue;
-                        }
-                        int x_ij = (int) v;
-                        if (x_ij > max_x) {
-                            max_x = x_ij;
-                        }
-                    }
-                    attr.setSize(max_x + 1);
-                }
-            }
-        }
         return attributes;
     }
 
+    @VisibleForTesting
     @Nonnull
-    public static Attribute[] convertAttributeTypes(
+    public static AttributeType[] convertAttributeTypes(
             @Nonnull final smile.data.Attribute[] original) {
         final int size = original.length;
-        final NumericAttribute immutableNumAttr = new NumericAttribute();
-        final Attribute[] dst = new Attribute[size];
+        final AttributeType[] dst = new AttributeType[size];
         for (int i = 0; i < size; i++) {
             smile.data.Attribute o = original[i];
             switch (o.type) {
                 case NOMINAL: {
-                    dst[i] = new NominalAttribute();
+                    dst[i] = AttributeType.NOMINAL;
                     break;
                 }
                 case NUMERIC: {
-                    dst[i] = immutableNumAttr;
+                    dst[i] = AttributeType.NUMERIC;
                     break;
                 }
                 default:
@@ -180,7 +112,7 @@ public final class SmileExtUtils {
     }
 
     @Nonnull
-    public static ColumnMajorIntMatrix sort(@Nonnull final Attribute[] 
attributes,
+    public static ColumnMajorIntMatrix sort(@Nonnull final AttributeType[] 
attributes,
             @Nonnull final Matrix x) {
         final int n = x.numRows();
         final int p = x.numColumns();
@@ -200,7 +132,7 @@ public final class SmileExtUtils {
 
             final ColumnMajorMatrix x2 = x.toColumnMajorMatrix();
             for (int j = 0; j < p; j++) {
-                if (attributes[j].type != AttributeType.NUMERIC) {
+                if (attributes[j] != AttributeType.NUMERIC) {
                     continue;
                 }
                 x2.eachNonNullInColumn(j, proc);
@@ -216,7 +148,7 @@ public final class SmileExtUtils {
         } else {
             final double[] a = new double[n];
             for (int j = 0; j < p; j++) {
-                if (attributes[j].type == AttributeType.NUMERIC) {
+                if (attributes[j] == AttributeType.NUMERIC) {
                     for (int i = 0; i < n; i++) {
                         a[i] = x.get(i, j);
                     }
@@ -388,9 +320,9 @@ public final class SmileExtUtils {
         return samples;
     }
 
-    public static boolean containsNumericType(@Nonnull final Attribute[] 
attributes) {
-        for (Attribute attr : attributes) {
-            if (attr.type == AttributeType.NUMERIC) {
+    public static boolean containsNumericType(@Nonnull final AttributeType[] 
attributes) {
+        for (AttributeType attr : attributes) {
+            if (attr == AttributeType.NUMERIC) {
                 return true;
             }
         }
diff --git a/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java 
b/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java
index d26019c..dff97a3 100644
--- a/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java
+++ b/core/src/main/java/hivemall/utils/hadoop/JsonSerdeUtils.java
@@ -28,7 +28,6 @@ import java.nio.charset.CharacterCodingException;
 import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.List;
diff --git 
a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java 
b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
index b789e71..3f287af 100644
--- a/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/classification/DecisionTreeTest.java
@@ -25,7 +25,7 @@ import hivemall.math.matrix.builders.CSRMatrixBuilder;
 import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
 import hivemall.math.random.RandomNumberGeneratorFactory;
 import hivemall.smile.classification.DecisionTree.Node;
-import hivemall.smile.data.Attribute;
+import hivemall.smile.data.AttributeType;
 import hivemall.smile.tools.TreeExportUDF.Evaluator;
 import hivemall.smile.tools.TreeExportUDF.OutputType;
 import hivemall.smile.utils.SmileExtUtils;
@@ -165,7 +165,7 @@ public class DecisionTreeTest {
         double[][] x = ds.toArray(new double[ds.size()][]);
         int[] y = ds.toArray(new int[ds.size()]);
 
-        Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(ds.attributes());
+        AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(ds.attributes());
         DecisionTree tree = new DecisionTree(attrs, matrix(x, dense), y, 
numLeafs,
             RandomNumberGeneratorFactory.createPRNG(31));
 
@@ -196,7 +196,7 @@ public class DecisionTreeTest {
             double[][] trainx = Math.slice(x, loocv.train[i]);
             int[] trainy = Math.slice(y, loocv.train[i]);
 
-            Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(ds.attributes());
+            AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(ds.attributes());
             DecisionTree tree = new DecisionTree(attrs, matrix(trainx, dense), 
trainy, numLeafs,
                 RandomNumberGeneratorFactory.createPRNG(i));
             if (y[loocv.test[i]] != tree.predict(x[loocv.test[i]])) {
@@ -226,7 +226,7 @@ public class DecisionTreeTest {
             double[][] trainx = Math.slice(x, loocv.train[i]);
             int[] trainy = Math.slice(y, loocv.train[i]);
 
-            Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(ds.attributes());
+            AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(ds.attributes());
             DecisionTree dtree = new DecisionTree(attrs, matrix(trainx, true), 
trainy, numLeafs,
                 RandomNumberGeneratorFactory.createPRNG(i));
             DecisionTree stree = new DecisionTree(attrs, matrix(trainx, 
false), trainy, numLeafs,
@@ -253,7 +253,7 @@ public class DecisionTreeTest {
             double[][] trainx = Math.slice(x, loocv.train[i]);
             int[] trainy = Math.slice(y, loocv.train[i]);
 
-            Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(iris.attributes());
+            AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(iris.attributes());
             DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), 
trainy, 4);
 
             byte[] b = tree.serialize(false);
@@ -280,7 +280,7 @@ public class DecisionTreeTest {
             double[][] trainx = Math.slice(x, loocv.train[i]);
             int[] trainy = Math.slice(y, loocv.train[i]);
 
-            Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(iris.attributes());
+            AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(iris.attributes());
             DecisionTree tree = new DecisionTree(attrs, matrix(trainx, true), 
trainy, 4);
 
             byte[] b1 = tree.serialize(true);
diff --git 
a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java 
b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
index 9d24b54..75aa65a 100644
--- a/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
+++ b/core/src/test/java/hivemall/smile/regression/RegressionTreeTest.java
@@ -22,8 +22,7 @@ import hivemall.math.matrix.Matrix;
 import hivemall.math.matrix.builders.CSRMatrixBuilder;
 import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
 import hivemall.math.random.RandomNumberGeneratorFactory;
-import hivemall.smile.data.Attribute;
-import hivemall.smile.data.Attribute.NumericAttribute;
+import hivemall.smile.data.AttributeType;
 import hivemall.smile.tools.TreeExportUDF.Evaluator;
 import hivemall.smile.tools.TreeExportUDF.OutputType;
 import hivemall.utils.codec.Base91;
@@ -67,8 +66,8 @@ public class RegressionTreeTest {
         double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 
104.6, 108.4, 110.8,
                 112.6, 114.2, 115.7, 116.9};
 
-        Attribute[] attrs = new Attribute[longley[0].length];
-        Arrays.fill(attrs, new NumericAttribute());
+        AttributeType[] attrs = new AttributeType[longley[0].length];
+        Arrays.fill(attrs, AttributeType.NUMERIC);
 
         int n = longley.length;
         LOOCV loocv = new LOOCV(n);
@@ -110,8 +109,8 @@ public class RegressionTreeTest {
         double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 
104.6, 108.4, 110.8,
                 112.6, 114.2, 115.7, 116.9};
 
-        Attribute[] attrs = new Attribute[longley[0].length];
-        Arrays.fill(attrs, new NumericAttribute());
+        AttributeType[] attrs = new AttributeType[longley[0].length];
+        Arrays.fill(attrs, AttributeType.NUMERIC);
 
         int n = longley.length;
         LOOCV loocv = new LOOCV(n);
@@ -153,8 +152,8 @@ public class RegressionTreeTest {
         double[] y = {83.0, 88.5, 88.2, 89.5, 96.2, 98.1, 99.0, 100.0, 101.2, 
104.6, 108.4, 110.8,
                 112.6, 114.2, 115.7, 116.9};
 
-        Attribute[] attrs = new Attribute[longley[0].length];
-        Arrays.fill(attrs, new NumericAttribute());
+        AttributeType[] attrs = new AttributeType[longley[0].length];
+        Arrays.fill(attrs, AttributeType.NUMERIC);
 
         int n = longley.length;
         LOOCV loocv = new LOOCV(n);
@@ -205,8 +204,8 @@ public class RegressionTreeTest {
     private static String graphvizOutput(double[][] x, double[] y, int 
maxLeafs, boolean dense,
             String[] featureNames, String outputName)
             throws IOException, HiveException, ParseException {
-        Attribute[] attrs = new Attribute[x[0].length];
-        Arrays.fill(attrs, new NumericAttribute());
+        AttributeType[] attrs = new AttributeType[x[0].length];
+        Arrays.fill(attrs, AttributeType.NUMERIC);
         RegressionTree tree = new RegressionTree(attrs, matrix(x, dense), y, 
maxLeafs);
 
         Text model = new Text(Base91.encode(tree.serialize(true)));
diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java 
b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
index c98b087..f44b9ec 100644
--- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
+++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFTest.java
@@ -21,11 +21,17 @@ package hivemall.smile.tools;
 import hivemall.TestUtils;
 import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
 import hivemall.smile.classification.DecisionTree;
-import hivemall.smile.data.Attribute;
+import hivemall.smile.data.AttributeType;
 import hivemall.smile.regression.RegressionTree;
 import hivemall.smile.utils.SmileExtUtils;
 import hivemall.utils.codec.Base91;
 import hivemall.utils.lang.ArrayUtils;
+import smile.data.AttributeDataset;
+import smile.data.parser.ArffParser;
+import smile.math.Math;
+import smile.validation.CrossValidation;
+import smile.validation.LOOCV;
+import smile.validation.RMSE;
 
 import java.io.BufferedInputStream;
 import java.io.IOException;
@@ -46,13 +52,6 @@ import org.apache.hadoop.io.Text;
 import org.junit.Assert;
 import org.junit.Test;
 
-import smile.data.AttributeDataset;
-import smile.data.parser.ArffParser;
-import smile.math.Math;
-import smile.validation.CrossValidation;
-import smile.validation.LOOCV;
-import smile.validation.RMSE;
-
 public class TreePredictUDFTest {
     private static final boolean DEBUG = false;
 
@@ -77,7 +76,7 @@ public class TreePredictUDFTest {
             double[][] trainx = Math.slice(x, loocv.train[i]);
             int[] trainy = Math.slice(y, loocv.train[i]);
 
-            Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(iris.attributes());
+            AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(iris.attributes());
             DecisionTree tree = new DecisionTree(attrs,
                 new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4);
             Assert.assertEquals(tree.predict(x[loocv.test[i]]),
@@ -106,7 +105,7 @@ public class TreePredictUDFTest {
             double[] trainy = Math.slice(datay, cv.train[i]);
             double[][] testx = Math.slice(datax, cv.test[i]);
 
-            Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
+            AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
             RegressionTree tree = new RegressionTree(attrs,
                 new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 
20);
 
@@ -146,7 +145,7 @@ public class TreePredictUDFTest {
             testy[i - m] = datay[index[i]];
         }
 
-        Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
+        AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
         RegressionTree tree = new RegressionTree(attrs,
             new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
         debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));
@@ -241,7 +240,7 @@ public class TreePredictUDFTest {
             testy[i - m] = datay[index[i]];
         }
 
-        Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
+        AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
         RegressionTree tree = new RegressionTree(attrs,
             new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
 
diff --git a/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java 
b/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java
index 68edf33..25e1cc6 100644
--- a/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java
+++ b/core/src/test/java/hivemall/smile/tools/TreePredictUDFv1Test.java
@@ -23,7 +23,7 @@ import static org.junit.Assert.assertEquals;
 import hivemall.TestUtils;
 import hivemall.math.matrix.dense.RowMajorDenseMatrix2d;
 import hivemall.smile.classification.DecisionTree;
-import hivemall.smile.data.Attribute;
+import hivemall.smile.data.AttributeType;
 import hivemall.smile.regression.RegressionTree;
 import hivemall.smile.tools.TreePredictUDFv1.DtNodeV1;
 import hivemall.smile.tools.TreePredictUDFv1.JavaSerializationEvaluator;
@@ -97,7 +97,7 @@ public class TreePredictUDFv1Test {
             double[][] trainx = Math.slice(x, loocv.train[i]);
             int[] trainy = Math.slice(y, loocv.train[i]);
 
-            Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(iris.attributes());
+            AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(iris.attributes());
             DecisionTree tree = new DecisionTree(attrs,
                 new RowMajorDenseMatrix2d(trainx, x[0].length), trainy, 4);
             assertEquals(tree.predict(x[loocv.test[i]]), evalPredict(tree, 
x[loocv.test[i]]));
@@ -125,7 +125,7 @@ public class TreePredictUDFv1Test {
             double[] trainy = Math.slice(datay, cv.train[i]);
             double[][] testx = Math.slice(datax, cv.test[i]);
 
-            Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
+            AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
             RegressionTree tree = new RegressionTree(attrs,
                 new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 
20);
 
@@ -165,7 +165,7 @@ public class TreePredictUDFv1Test {
             testy[i - m] = datay[index[i]];
         }
 
-        Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
+        AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
         RegressionTree tree = new RegressionTree(attrs,
             new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
         debugPrint(String.format("RMSE = %.4f\n", rmse(tree, testx, testy)));
@@ -260,7 +260,7 @@ public class TreePredictUDFv1Test {
             testy[i - m] = datay[index[i]];
         }
 
-        Attribute[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
+        AttributeType[] attrs = 
SmileExtUtils.convertAttributeTypes(data.attributes());
         RegressionTree tree = new RegressionTree(attrs,
             new RowMajorDenseMatrix2d(trainx, trainx[0].length), trainy, 20);
         String opScript = tree.predictOpCodegen(StackMachine.SEP);
diff --git a/docs/gitbook/binaryclass/news20_rf.md 
b/docs/gitbook/binaryclass/news20_rf.md
index 9a0d1f8..7b1f2be 100644
--- a/docs/gitbook/binaryclass/news20_rf.md
+++ b/docs/gitbook/binaryclass/news20_rf.md
@@ -85,9 +85,10 @@ WITH submit as (
     test t 
     JOIN rf_predicted p on (t.rowid = p.rowid)
 )
-select count(1) / 4996.0
-from submit 
-where actual = predicted;
+select
+  sum(if(actual = predicted, 1, 0)) / count(1) as accuracy
+from
+  submit;
 ```
 
 > 0.8112489991993594

Reply via email to