Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 69aa64b73 -> e9c66f0a1


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java
index 33568c7..e05755e 100644
--- a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java
+++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java
@@ -87,13 +87,13 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions {
     protected CommandLine processOptions(ObjectInspector[] argOIs) throws 
UDFArgumentException {
         int _batch_size = 128;
         CommandLine cl = null;
-        if(argOIs.length >= 5) {
+        if (argOIs.length >= 5) {
             String rawArgs = HiveUtils.getConstString(argOIs[4]);
             cl = this.parseOptions(rawArgs);
             _batch_size = 
Primitives.parseInt(cl.getOptionValue("_batch_size"), _batch_size);
-            if(_batch_size < 1) {
-                throw new IllegalArgumentException(
-                        "batch_size must be greater than 0: " + _batch_size);
+            if (_batch_size < 1) {
+                throw new IllegalArgumentException("batch_size must be greater 
than 0: "
+                        + _batch_size);
             }
         }
         this.batch_size = _batch_size;
@@ -103,13 +103,12 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions {
     /** Override this to output predicted results depending on a taks type */
     abstract public StructObjectInspector getReturnOI();
 
-    abstract public void forwardPredicted(
-            final List<LabeledPointWithRowId> testData,
+    abstract public void forwardPredicted(final List<LabeledPointWithRowId> 
testData,
             final float[][] predicted) throws HiveException;
 
     @Override
     public StructObjectInspector initialize(ObjectInspector[] argOIs) throws 
UDFArgumentException {
-        if(argOIs.length != 4 && argOIs.length != 5) {
+        if (argOIs.length != 4 && argOIs.length != 5) {
             throw new UDFArgumentException(this.getClass().getSimpleName()
                     + " takes 4 or 5 arguments: string rowid, string[] 
features, string model_id,"
                     + " array<byte> pred_model [, string options]: " + 
argOIs.length);
@@ -128,9 +127,10 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions {
         }
     }
 
-    private static DMatrix createDMatrix(final List<LabeledPointWithRowId> 
data) throws XGBoostError {
+    private static DMatrix createDMatrix(final List<LabeledPointWithRowId> 
data)
+            throws XGBoostError {
         final List<LabeledPoint> points = new ArrayList(data.size());
-        for(LabeledPointWithRowId d : data) {
+        for (LabeledPointWithRowId d : data) {
             points.add(d.point);
         }
         return new DMatrix(points.iterator(), "");
@@ -158,22 +158,23 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions {
 
     @Override
     public void process(Object[] args) throws HiveException {
-        if(args[1] != null) {
+        if (args[1] != null) {
             final String rowId = 
PrimitiveObjectInspectorUtils.getString(args[0], rowIdOI);
             final List<String> features = (List<String>) 
featureListOI.getList(args[1]);
             final String modelId = 
PrimitiveObjectInspectorUtils.getString(args[2], modelIdOI);
-            if(!mapToModel.containsKey(modelId)) {
-                final byte[] predModel = 
PrimitiveObjectInspectorUtils.getBinary(args[3], modelOI).getBytes();
+            if (!mapToModel.containsKey(modelId)) {
+                final byte[] predModel = 
PrimitiveObjectInspectorUtils.getBinary(args[3], modelOI)
+                                                                      
.getBytes();
                 mapToModel.put(modelId, initXgBooster(predModel));
             }
             final LabeledPoint point = XGBoostUtils.parseFeatures(0.f, 
features);
-            if(point != null) {
-                if(!rowBuffer.containsKey(modelId)) {
+            if (point != null) {
+                if (!rowBuffer.containsKey(modelId)) {
                     rowBuffer.put(modelId, new ArrayList());
                 }
                 final List<LabeledPointWithRowId> buf = rowBuffer.get(modelId);
                 buf.add(createLabeledPoint(rowId, point));
-                if(buf.size() >= batch_size) {
+                if (buf.size() >= batch_size) {
                     predictAndFlush(mapToModel.get(modelId), buf);
                 }
             }
@@ -182,7 +183,7 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions {
 
     @Override
     public void close() throws HiveException {
-        for(Entry<String, List<LabeledPointWithRowId>> e : 
rowBuffer.entrySet()) {
+        for (Entry<String, List<LabeledPointWithRowId>> e : 
rowBuffer.entrySet()) {
             predictAndFlush(mapToModel.get(e.getKey()), e.getValue());
         }
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java
index b269549..b57925a 100644
--- a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java
+++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java
@@ -41,8 +41,8 @@ import hivemall.utils.hadoop.HadoopUtils;
 import hivemall.utils.hadoop.HiveUtils;
 
 /**
- * This is a base class to handle the options for XGBoost and provide
- * common functions among various tasks.
+ * This is a base class to handle the options for XGBoost and provide common 
functions among various
+ * tasks.
  */
 public abstract class XGBoostUDTF extends UDTFWithOptions {
     private static final Log logger = LogFactory.getLog(XGBoostUDTF.class);
@@ -104,33 +104,49 @@ public abstract class XGBoostUDTF extends UDTFWithOptions {
         final Options opts = new Options();
 
         /** General parameters */
-        opts.addOption("booster", true, "Set a booster to use, gbtree or 
gblinear. [default: gbree]");
+        opts.addOption("booster", true,
+            "Set a booster to use, gbtree or gblinear. [default: gbree]");
         opts.addOption("num_round", true, "Number of boosting iterations 
[default: 8]");
-        opts.addOption("silent", true, "0 means printing running messages, 1 
means silent mode [default: 1]");
-        opts.addOption("nthread", true, "Number of parallel threads used to 
run xgboost [default: 1]");
-        opts.addOption("num_pbuffer", true, "Size of prediction buffer [set 
automatically by xgboost]");
-        opts.addOption("num_feature", true, "Feature dimension used in 
boosting [default: set automatically by xgboost]");
+        opts.addOption("silent", true,
+            "0 means printing running messages, 1 means silent mode [default: 
1]");
+        opts.addOption("nthread", true,
+            "Number of parallel threads used to run xgboost [default: 1]");
+        opts.addOption("num_pbuffer", true,
+            "Size of prediction buffer [set automatically by xgboost]");
+        opts.addOption("num_feature", true,
+            "Feature dimension used in boosting [default: set automatically by 
xgboost]");
 
         /** Parameters for both boosters */
         opts.addOption("alpha", true, "L1 regularization term on weights 
[default: 0.0]");
-        opts.addOption("lambda", true, "L2 regularization term on weights 
[default: 1.0 for gbtree, 0.0 for gblinear]");
+        opts.addOption("lambda", true,
+            "L2 regularization term on weights [default: 1.0 for gbtree, 0.0 
for gblinear]");
 
         /** Parameters for Tree Booster */
-        opts.addOption("eta", true, "Step size shrinkage used in update to 
prevents overfitting [default: 0.3]");
-        opts.addOption("gamma", true, "Minimum loss reduction required to make 
a further partition on a leaf node of the tree [default: 0.0]");
+        opts.addOption("eta", true,
+            "Step size shrinkage used in update to prevents overfitting 
[default: 0.3]");
+        opts.addOption(
+            "gamma",
+            true,
+            "Minimum loss reduction required to make a further partition on a 
leaf node of the tree [default: 0.0]");
         opts.addOption("max_depth", true, "Max depth of decision tree 
[default: 6]");
-        opts.addOption("min_child_weight", true, "Minimum sum of instance 
weight(hessian) needed in a child [default: 1]");
-        opts.addOption("max_delta_step", true, "Maximum delta step we allow 
each tree's weight estimation to be [default: 0]");
+        opts.addOption("min_child_weight", true,
+            "Minimum sum of instance weight(hessian) needed in a child 
[default: 1]");
+        opts.addOption("max_delta_step", true,
+            "Maximum delta step we allow each tree's weight estimation to be 
[default: 0]");
         opts.addOption("subsample", true, "Subsample ratio of the training 
instance [default: 1.0]");
-        opts.addOption("colsample_bytree", true, "Subsample ratio of columns 
when constructing each tree [default: 1.0]");
-        opts.addOption("colsample_bylevel", true, "Subsample ratio of columns 
for each split, in each level [default: 1.0]");
+        opts.addOption("colsample_bytree", true,
+            "Subsample ratio of columns when constructing each tree [default: 
1.0]");
+        opts.addOption("colsample_bylevel", true,
+            "Subsample ratio of columns for each split, in each level 
[default: 1.0]");
 
         /** Parameters for Linear Booster */
         opts.addOption("lambda_bias", true, "L2 regularization term on bias 
[default: 0.0]");
 
         /** Learning Task Parameters */
-        opts.addOption("base_score", true, "Initial prediction score of all 
instances, global bias [default: 0.5]");
-        opts.addOption("eval_metric", true, "Evaluation metrics for validation 
data [default according to objective]");
+        opts.addOption("base_score", true,
+            "Initial prediction score of all instances, global bias [default: 
0.5]");
+        opts.addOption("eval_metric", true,
+            "Evaluation metrics for validation data [default according to 
objective]");
 
         return opts;
     }
@@ -138,7 +154,7 @@ public abstract class XGBoostUDTF extends UDTFWithOptions {
     @Override
     protected CommandLine processOptions(ObjectInspector[] argOIs) throws 
UDFArgumentException {
         CommandLine cl = null;
-        if(argOIs.length >= 3) {
+        if (argOIs.length >= 3) {
             final String rawArgs = HiveUtils.getConstString(argOIs[2]);
             cl = this.parseOptions(rawArgs);
 
@@ -181,7 +197,8 @@ public abstract class XGBoostUDTF extends UDTFWithOptions {
                 params.put("max_depth", 
Integer.valueOf(cl.getOptionValue("max_depth")));
             }
             if (cl.hasOption("min_child_weight")) {
-                params.put("min_child_weight", 
Integer.valueOf(cl.getOptionValue("min_child_weight")));
+                params.put("min_child_weight",
+                    Integer.valueOf(cl.getOptionValue("min_child_weight")));
             }
             if (cl.hasOption("max_delta_step")) {
                 params.put("max_delta_step", 
Integer.valueOf(cl.getOptionValue("max_delta_step")));
@@ -193,7 +210,8 @@ public abstract class XGBoostUDTF extends UDTFWithOptions {
                 params.put("colsamle_bytree", 
Double.valueOf(cl.getOptionValue("colsample_bytree")));
             }
             if (cl.hasOption("colsample_bylevel")) {
-                params.put("colsamle_bylevel", 
Double.valueOf(cl.getOptionValue("colsample_bylevel")));
+                params.put("colsamle_bylevel",
+                    Double.valueOf(cl.getOptionValue("colsample_bylevel")));
             }
 
             /** Parameters for Linear Booster */
@@ -249,36 +267,34 @@ public abstract class XGBoostUDTF extends UDTFWithOptions {
 
     @Override
     public void process(Object[] args) throws HiveException {
-        if(args[0] != null) {
+        if (args[0] != null) {
             // TODO: Need to support dense inputs
             final List<String> features = (List<String>) 
featureListOI.getList(args[0]);
             double target = PrimitiveObjectInspectorUtils.getDouble(args[1], 
this.targetOI);
             checkTargetValue(target);
             final LabeledPoint point = XGBoostUtils.parseFeatures(target, 
features);
-            if(point != null) {
+            if (point != null) {
                 this.featuresList.add(point);
             }
         }
     }
 
     /**
-     * Need to override this for a Spark wrapper because `MapredContext`
-     * does not work in there.
+     * Need to override this for a Spark wrapper because `MapredContext` does 
not work in there.
      */
     protected String generateUniqueModelId() {
         return "xgbmodel-" + String.valueOf(HadoopUtils.getTaskId());
     }
 
-    private static Booster createXGBooster(
-            final Map<String, Object> params,
+    private static Booster createXGBooster(final Map<String, Object> params,
             final List<LabeledPoint> input) throws XGBoostError {
         try {
             Class<?>[] args = {Map.class, DMatrix[].class};
             Constructor<Booster> ctor;
             ctor = Booster.class.getDeclaredConstructor(args);
             ctor.setAccessible(true);
-            return ctor.newInstance(
-                    new Object[]{params, new DMatrix[]{new 
DMatrix(input.iterator(), "")}});
+            return ctor.newInstance(new Object[] {params,
+                    new DMatrix[] {new DMatrix(input.iterator(), "")}});
         } catch (InstantiationException e) {
             // Catch java reflection error as fast as possible
             e.printStackTrace();
@@ -300,7 +316,7 @@ public abstract class XGBoostUDTF extends UDTFWithOptions {
             final DMatrix trainData = new DMatrix(featuresList.iterator(), "");
             final Booster booster = createXGBooster(params, featuresList);
             int num_round = (Integer) params.get("num_round");
-            for(int i = 0; i < num_round; i++) {
+            for (int i = 0; i < num_round; i++) {
                 booster.update(trainData, i);
             }
 
@@ -308,7 +324,7 @@ public abstract class XGBoostUDTF extends UDTFWithOptions {
             final String modelId = generateUniqueModelId();
             final byte[] predModel = booster.toByteArray();
             logger.info("model_id:" + modelId.toString() + " size:" + 
predModel.length);
-            forward(new Object[]{modelId, predModel});
+            forward(new Object[] {modelId, predModel});
         } catch (Exception e) {
             throw new HiveException(e.getMessage());
         }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java
index 9705f94..d0769f4 100644
--- a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java
+++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java
@@ -29,18 +29,18 @@ public final class XGBoostUtils {
     /** Transform List<String> inputs into a XGBoost input format */
     public static LabeledPoint parseFeatures(double target, List<String> 
features) {
         final int size = features.size();
-        if(size == 0) {
+        if (size == 0) {
             return null;
         }
         final int[] indices = new int[size];
         final float[] values = new float[size];
-        for(int i = 0; i < size; i++) {
-            if(features.get(i) == null) {
+        for (int i = 0; i < size; i++) {
+            if (features.get(i) == null) {
                 continue;
             }
             final String str = features.get(i);
             final int pos = str.indexOf(':');
-            if(pos >= 1) {
+            if (pos >= 1) {
                 indices[i] = Integer.parseInt(str.substring(0, pos));
                 values[i] = Float.parseFloat(str.substring(pos + 1));
             }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java
index 7e135ec..94282bb 100644
--- a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java
+++ b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java
@@ -24,13 +24,12 @@ import org.apache.hadoop.hive.ql.metadata.HiveException;
 import hivemall.xgboost.XGBoostUDTF;
 
 /**
- * A XGBoost binary classification and the document is as follows;
- *  - https://github.com/dmlc/xgboost/tree/master/demo/binary_classification
+ * A XGBoost binary classification and the document is as follows; -
+ * https://github.com/dmlc/xgboost/tree/master/demo/binary_classification
  */
 @Description(
-    name = "train_xgboost_classifier",
-    value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>"
-)
+        name = "train_xgboost_classifier",
+        value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>")
 public class XGBoostBinaryClassifierUDTF extends XGBoostUDTF {
 
     public XGBoostBinaryClassifierUDTF() {}
@@ -43,7 +42,7 @@ public class XGBoostBinaryClassifierUDTF extends XGBoostUDTF {
 
     @Override
     public void checkTargetValue(double target) throws HiveException {
-        if(!(Double.compare(target, 0.0) == 0|| Double.compare(target, 1.0) == 
0)) {
+        if (!(Double.compare(target, 0.0) == 0 || Double.compare(target, 1.0) 
== 0)) {
             throw new HiveException("target must be 0.0 or 1.0: " + target);
         }
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java
index 171e4bc..3181473 100644
--- a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java
+++ b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java
@@ -28,13 +28,12 @@ import hivemall.xgboost.XGBoostUDTF;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 
 /**
- * A XGBoost multiclass classification and the document is as follows;
- *  - 
https://github.com/dmlc/xgboost/tree/master/demo/multiclass_classification
+ * A XGBoost multiclass classification and the document is as follows; -
+ * https://github.com/dmlc/xgboost/tree/master/demo/multiclass_classification
  */
 @Description(
-    name = "train_multiclass_xgboost_classifier",
-    value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>"
-)
+        name = "train_multiclass_xgboost_classifier",
+        value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>")
 public class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF {
 
     public XGBoostMulticlassClassifierUDTF() {}
@@ -56,12 +55,12 @@ public class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF {
     @Override
     protected CommandLine processOptions(ObjectInspector[] argOIs) throws 
UDFArgumentException {
         final CommandLine cli = super.processOptions(argOIs);
-        if(cli != null) {
-            if(cli.hasOption("num_class")) {
+        if (cli != null) {
+            if (cli.hasOption("num_class")) {
                 int _num_class = 
Integer.valueOf(cli.getOptionValue("num_class"));
-                if(_num_class < 2) {
-                    throw new UDFArgumentException(
-                            "num_class must be greater than 1: " + _num_class);
+                if (_num_class < 2) {
+                    throw new UDFArgumentException("num_class must be greater 
than 1: "
+                            + _num_class);
                 }
                 params.put("num_class", _num_class);
             }
@@ -72,12 +71,10 @@ public class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF {
     @Override
     public void checkTargetValue(double target) throws HiveException {
         double num_class = ((Integer) params.get("num_class")).doubleValue();
-        if(target < 0.0 || target > num_class
+        if (target < 0.0 || target > num_class
                 || Double.compare(target - Math.floor(target), 0.0) != 0) {
-            throw new HiveException(
-                    "target must be {0.0, ..., "
-                            + String.format("%.1f", (num_class - 1.0))
-                            + "}: " + target);
+            throw new HiveException("target must be {0.0, ..., "
+                    + String.format("%.1f", (num_class - 1.0)) + "}: " + 
target);
         }
     }
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java b/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java
index d00b430..98abc8a 100644
--- a/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java
+++ b/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java
@@ -24,13 +24,12 @@ import org.apache.hadoop.hive.ql.metadata.HiveException;
 import hivemall.xgboost.XGBoostUDTF;
 
 /**
- * A XGBoost regression and the document is as follows;
- *  - https://github.com/dmlc/xgboost/tree/master/demo/regression
+ * A XGBoost regression and the document is as follows; -
+ * https://github.com/dmlc/xgboost/tree/master/demo/regression
  */
 @Description(
-    name = "train_xgboost_regr",
-    value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>"
-)
+        name = "train_xgboost_regr",
+        value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>")
 public class XGBoostRegressionUDTF extends XGBoostUDTF {
 
     public XGBoostRegressionUDTF() {}
@@ -43,7 +42,7 @@ public class XGBoostRegressionUDTF extends XGBoostUDTF {
 
     @Override
     public void checkTargetValue(double target) throws HiveException {
-        if(target < 0.0 || target > 1.0) {
+        if (target < 0.0 || target > 1.0) {
             throw new HiveException("target must be in range 0 to 1: " + 
target);
         }
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java
index 6ceb17e..4d1c0a2 100644
--- a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java
+++ b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java
@@ -28,10 +28,9 @@ import java.util.ArrayList;
 import java.util.List;
 
 @Description(
-    name = "xgboost_multiclass_predict",
-    value = "_FUNC_(string rowid, string[] features, string model_id, 
array<byte> pred_model [, string options]) "
-                + "- Returns a prediction result as (string rowid, int label, 
float probability)"
-)
+        name = "xgboost_multiclass_predict",
+        value = "_FUNC_(string rowid, string[] features, string model_id, 
array<byte> pred_model [, string options]) "
+                + "- Returns a prediction result as (string rowid, int label, 
float probability)")
 public final class XGBoostMulticlassPredictUDTF extends 
hivemall.xgboost.XGBoostPredictUDTF {
 
     public XGBoostMulticlassPredictUDTF() {}
@@ -51,16 +50,15 @@ public final class XGBoostMulticlassPredictUDTF extends hivemall.xgboost.XGBoost
     }
 
     @Override
-    public void forwardPredicted(
-            final List<LabeledPointWithRowId> testData,
+    public void forwardPredicted(final List<LabeledPointWithRowId> testData,
             final float[][] predicted) throws HiveException {
-        assert(predicted.length == testData.size());
-        for(int i = 0; i < testData.size(); i++) {
-            assert(predicted[i].length > 1);
+        assert (predicted.length == testData.size());
+        for (int i = 0; i < testData.size(); i++) {
+            assert (predicted[i].length > 1);
             final String rowId = testData.get(i).rowId;
-            for(int j = 0; j < predicted[i].length; j++) {
+            for (int j = 0; j < predicted[i].length; j++) {
                 float prob = predicted[i][j];
-                forward(new Object[]{rowId, j, prob});
+                forward(new Object[] {rowId, j, prob});
             }
         }
     }

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java
index 4510206..594a738 100644
--- a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java
+++ b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java
@@ -28,10 +28,9 @@ import java.util.ArrayList;
 import java.util.List;
 
 @Description(
-    name = "xgboost_predict",
-    value = "_FUNC_(string rowid, string[] features, string model_id, 
array<byte> pred_model [, string options]) "
-                + "- Returns a prediction result as (string rowid, float 
predicted)"
-)
+        name = "xgboost_predict",
+        value = "_FUNC_(string rowid, string[] features, string model_id, 
array<byte> pred_model [, string options]) "
+                + "- Returns a prediction result as (string rowid, float 
predicted)")
 public final class XGBoostPredictUDTF extends 
hivemall.xgboost.XGBoostPredictUDTF {
 
     public XGBoostPredictUDTF() {}
@@ -49,15 +48,14 @@ public final class XGBoostPredictUDTF extends hivemall.xgboost.XGBoostPredictUDT
     }
 
     @Override
-    public void forwardPredicted(
-            final List<LabeledPointWithRowId> testData,
+    public void forwardPredicted(final List<LabeledPointWithRowId> testData,
             final float[][] predicted) throws HiveException {
-        assert(predicted.length == testData.size());
-        for(int i = 0; i < testData.size(); i++) {
-            assert(predicted[i].length == 1);
+        assert (predicted.length == testData.size());
+        for (int i = 0; i < testData.size(); i++) {
+            assert (predicted[i].length == 1);
             final String rowId = testData.get(i).rowId;
             float p = predicted[i][0];
-            forward(new Object[]{rowId, p});
+            forward(new Object[] {rowId, p});
         }
     }
 

Reply via email to