Repository: incubator-hivemall Updated Branches: refs/heads/master 69aa64b73 -> e9c66f0a1
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java index 33568c7..e05755e 100644 --- a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java @@ -87,13 +87,13 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { int _batch_size = 128; CommandLine cl = null; - if(argOIs.length >= 5) { + if (argOIs.length >= 5) { String rawArgs = HiveUtils.getConstString(argOIs[4]); cl = this.parseOptions(rawArgs); _batch_size = Primitives.parseInt(cl.getOptionValue("_batch_size"), _batch_size); - if(_batch_size < 1) { - throw new IllegalArgumentException( - "batch_size must be greater than 0: " + _batch_size); + if (_batch_size < 1) { + throw new IllegalArgumentException("batch_size must be greater than 0: " + + _batch_size); } } this.batch_size = _batch_size; @@ -103,13 +103,12 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { /** Override this to output predicted results depending on a taks type */ abstract public StructObjectInspector getReturnOI(); - abstract public void forwardPredicted( - final List<LabeledPointWithRowId> testData, + abstract public void forwardPredicted(final List<LabeledPointWithRowId> testData, final float[][] predicted) throws HiveException; @Override public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { - if(argOIs.length != 4 && argOIs.length != 5) { + if (argOIs.length != 4 && argOIs.length != 5) { throw new UDFArgumentException(this.getClass().getSimpleName() + " takes 4 or 5 arguments: string rowid, string[] features, string model_id," + " array<byte> pred_model [, string options]: " + argOIs.length); @@ -128,9 +127,10 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { } } - private static DMatrix createDMatrix(final List<LabeledPointWithRowId> data) throws XGBoostError { + private static DMatrix createDMatrix(final List<LabeledPointWithRowId> data) + throws XGBoostError { final List<LabeledPoint> points = new ArrayList(data.size()); - for(LabeledPointWithRowId d : data) { + for (LabeledPointWithRowId d : data) { points.add(d.point); } return new DMatrix(points.iterator(), ""); @@ -158,22 +158,23 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { @Override public void process(Object[] args) throws HiveException { - if(args[1] != null) { + if (args[1] != null) { final String rowId = PrimitiveObjectInspectorUtils.getString(args[0], rowIdOI); final List<String> features = (List<String>) featureListOI.getList(args[1]); final String modelId = PrimitiveObjectInspectorUtils.getString(args[2], modelIdOI); - if(!mapToModel.containsKey(modelId)) { - final byte[] predModel = PrimitiveObjectInspectorUtils.getBinary(args[3], modelOI).getBytes(); + if (!mapToModel.containsKey(modelId)) { + final byte[] predModel = PrimitiveObjectInspectorUtils.getBinary(args[3], modelOI) + .getBytes(); mapToModel.put(modelId, initXgBooster(predModel)); } final LabeledPoint point = XGBoostUtils.parseFeatures(0.f, features); - if(point != null) { - if(!rowBuffer.containsKey(modelId)) { + if (point != null) { + if (!rowBuffer.containsKey(modelId)) { rowBuffer.put(modelId, new ArrayList()); } final List<LabeledPointWithRowId> buf = rowBuffer.get(modelId); buf.add(createLabeledPoint(rowId, point)); - if(buf.size() >= batch_size) { + if (buf.size() >= batch_size) { predictAndFlush(mapToModel.get(modelId), buf); } } @@ -182,7 +183,7 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { @Override public void close() throws HiveException { - for(Entry<String, List<LabeledPointWithRowId>> e : rowBuffer.entrySet()) { + for (Entry<String, List<LabeledPointWithRowId>> e : rowBuffer.entrySet()) { predictAndFlush(mapToModel.get(e.getKey()), e.getValue()); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java index b269549..b57925a 100644 --- a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java @@ -41,8 +41,8 @@ import hivemall.utils.hadoop.HadoopUtils; import hivemall.utils.hadoop.HiveUtils; /** - * This is a base class to handle the options for XGBoost and provide - * common functions among various tasks. + * This is a base class to handle the options for XGBoost and provide common functions among various + * tasks. */ public abstract class XGBoostUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(XGBoostUDTF.class); @@ -104,33 +104,49 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { final Options opts = new Options(); /** General parameters */ - opts.addOption("booster", true, "Set a booster to use, gbtree or gblinear. [default: gbree]"); + opts.addOption("booster", true, + "Set a booster to use, gbtree or gblinear. [default: gbree]"); opts.addOption("num_round", true, "Number of boosting iterations [default: 8]"); - opts.addOption("silent", true, "0 means printing running messages, 1 means silent mode [default: 1]"); - opts.addOption("nthread", true, "Number of parallel threads used to run xgboost [default: 1]"); - opts.addOption("num_pbuffer", true, "Size of prediction buffer [set automatically by xgboost]"); - opts.addOption("num_feature", true, "Feature dimension used in boosting [default: set automatically by xgboost]"); + opts.addOption("silent", true, + "0 means printing running messages, 1 means silent mode [default: 1]"); + opts.addOption("nthread", true, + "Number of parallel threads used to run xgboost [default: 1]"); + opts.addOption("num_pbuffer", true, + "Size of prediction buffer [set automatically by xgboost]"); + opts.addOption("num_feature", true, + "Feature dimension used in boosting [default: set automatically by xgboost]"); /** Parameters for both boosters */ opts.addOption("alpha", true, "L1 regularization term on weights [default: 0.0]"); - opts.addOption("lambda", true, "L2 regularization term on weights [default: 1.0 for gbtree, 0.0 for gblinear]"); + opts.addOption("lambda", true, + "L2 regularization term on weights [default: 1.0 for gbtree, 0.0 for gblinear]"); /** Parameters for Tree Booster */ - opts.addOption("eta", true, "Step size shrinkage used in update to prevents overfitting [default: 0.3]"); - opts.addOption("gamma", true, "Minimum loss reduction required to make a further partition on a leaf node of the tree [default: 0.0]"); + opts.addOption("eta", true, + "Step size shrinkage used in update to prevents overfitting [default: 0.3]"); + opts.addOption( + "gamma", + true, + "Minimum loss reduction required to make a further partition on a leaf node of the tree [default: 0.0]"); opts.addOption("max_depth", true, "Max depth of decision tree [default: 6]"); - opts.addOption("min_child_weight", true, "Minimum sum of instance weight(hessian) needed in a child [default: 1]"); - opts.addOption("max_delta_step", true, "Maximum delta step we allow each tree's weight estimation to be [default: 0]"); + opts.addOption("min_child_weight", true, + "Minimum sum of instance weight(hessian) needed in a child [default: 1]"); + opts.addOption("max_delta_step", true, + "Maximum delta step we allow each tree's weight estimation to be [default: 0]"); opts.addOption("subsample", true, "Subsample ratio of the training instance [default: 1.0]"); - opts.addOption("colsample_bytree", true, "Subsample ratio of columns when constructing each tree [default: 1.0]"); - opts.addOption("colsample_bylevel", true, "Subsample ratio of columns for each split, in each level [default: 1.0]"); + opts.addOption("colsample_bytree", true, + "Subsample ratio of columns when constructing each tree [default: 1.0]"); + opts.addOption("colsample_bylevel", true, + "Subsample ratio of columns for each split, in each level [default: 1.0]"); /** Parameters for Linear Booster */ opts.addOption("lambda_bias", true, "L2 regularization term on bias [default: 0.0]"); /** Learning Task Parameters */ - opts.addOption("base_score", true, "Initial prediction score of all instances, global bias [default: 0.5]"); - opts.addOption("eval_metric", true, "Evaluation metrics for validation data [default according to objective]"); + opts.addOption("base_score", true, + "Initial prediction score of all instances, global bias [default: 0.5]"); + opts.addOption("eval_metric", true, + "Evaluation metrics for validation data [default according to objective]"); return opts; } @@ -138,7 +154,7 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { @Override protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { CommandLine cl = null; - if(argOIs.length >= 3) { + if (argOIs.length >= 3) { final String rawArgs = HiveUtils.getConstString(argOIs[2]); cl = this.parseOptions(rawArgs); @@ -181,7 +197,8 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { params.put("max_depth", Integer.valueOf(cl.getOptionValue("max_depth"))); } if (cl.hasOption("min_child_weight")) { - params.put("min_child_weight", Integer.valueOf(cl.getOptionValue("min_child_weight"))); + params.put("min_child_weight", + Integer.valueOf(cl.getOptionValue("min_child_weight"))); } if (cl.hasOption("max_delta_step")) { params.put("max_delta_step", Integer.valueOf(cl.getOptionValue("max_delta_step"))); @@ -193,7 +210,8 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { params.put("colsamle_bytree", Double.valueOf(cl.getOptionValue("colsample_bytree"))); } if (cl.hasOption("colsample_bylevel")) { - params.put("colsamle_bylevel", Double.valueOf(cl.getOptionValue("colsample_bylevel"))); + params.put("colsamle_bylevel", + Double.valueOf(cl.getOptionValue("colsample_bylevel"))); } /** Parameters for Linear Booster */ @@ -249,36 +267,34 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { @Override public void process(Object[] args) throws HiveException { - if(args[0] != null) { + if (args[0] != null) { // TODO: Need to support dense inputs final List<String> features = (List<String>) featureListOI.getList(args[0]); double target = PrimitiveObjectInspectorUtils.getDouble(args[1], this.targetOI); checkTargetValue(target); final LabeledPoint point = XGBoostUtils.parseFeatures(target, features); - if(point != null) { + if (point != null) { this.featuresList.add(point); } } } /** - * Need to override this for a Spark wrapper because `MapredContext` - * does not work in there. + * Need to override this for a Spark wrapper because `MapredContext` does not work in there. */ protected String generateUniqueModelId() { return "xgbmodel-" + String.valueOf(HadoopUtils.getTaskId()); } - private static Booster createXGBooster( - final Map<String, Object> params, + private static Booster createXGBooster(final Map<String, Object> params, final List<LabeledPoint> input) throws XGBoostError { try { Class<?>[] args = {Map.class, DMatrix[].class}; Constructor<Booster> ctor; ctor = Booster.class.getDeclaredConstructor(args); ctor.setAccessible(true); - return ctor.newInstance( - new Object[]{params, new DMatrix[]{new DMatrix(input.iterator(), "")}}); + return ctor.newInstance(new Object[] {params, + new DMatrix[] {new DMatrix(input.iterator(), "")}}); } catch (InstantiationException e) { // Catch java reflection error as fast as possible e.printStackTrace(); @@ -300,7 +316,7 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { final DMatrix trainData = new DMatrix(featuresList.iterator(), ""); final Booster booster = createXGBooster(params, featuresList); int num_round = (Integer) params.get("num_round"); - for(int i = 0; i < num_round; i++) { + for (int i = 0; i < num_round; i++) { booster.update(trainData, i); } @@ -308,7 +324,7 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { final String modelId = generateUniqueModelId(); final byte[] predModel = booster.toByteArray(); logger.info("model_id:" + modelId.toString() + " size:" + predModel.length); - forward(new Object[]{modelId, predModel}); + forward(new Object[] {modelId, predModel}); } catch (Exception e) { throw new HiveException(e.getMessage()); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java index 9705f94..d0769f4 100644 --- a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java +++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java @@ -29,18 +29,18 @@ public final class XGBoostUtils { /** Transform List<String> inputs into a XGBoost input format */ public static LabeledPoint parseFeatures(double target, List<String> features) { final int size = features.size(); - if(size == 0) { + if (size == 0) { return null; } final int[] indices = new int[size]; final float[] values = new float[size]; - for(int i = 0; i < size; i++) { - if(features.get(i) == null) { + for (int i = 0; i < size; i++) { + if (features.get(i) == null) { continue; } final String str = features.get(i); final int pos = str.indexOf(':'); - if(pos >= 1) { + if (pos >= 1) { indices[i] = Integer.parseInt(str.substring(0, pos)); values[i] = Float.parseFloat(str.substring(pos + 1)); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java index 7e135ec..94282bb 100644 --- a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java @@ -24,13 +24,12 @@ import org.apache.hadoop.hive.ql.metadata.HiveException; import hivemall.xgboost.XGBoostUDTF; /** - * A XGBoost binary classification and the document is as follows; - * - https://github.com/dmlc/xgboost/tree/master/demo/binary_classification + * A XGBoost binary classification and the document is as follows; - + * https://github.com/dmlc/xgboost/tree/master/demo/binary_classification */ @Description( - name = "train_xgboost_classifier", - value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" -) + name = "train_xgboost_classifier", + value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>") public class XGBoostBinaryClassifierUDTF extends XGBoostUDTF { public XGBoostBinaryClassifierUDTF() {} @@ -43,7 +42,7 @@ public class XGBoostBinaryClassifierUDTF extends XGBoostUDTF { @Override public void checkTargetValue(double target) throws HiveException { - if(!(Double.compare(target, 0.0) == 0|| Double.compare(target, 1.0) == 0)) { + if (!(Double.compare(target, 0.0) == 0 || Double.compare(target, 1.0) == 0)) { throw new HiveException("target must be 0.0 or 1.0: " + target); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java index 171e4bc..3181473 100644 --- a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java @@ -28,13 +28,12 @@ import hivemall.xgboost.XGBoostUDTF; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; /** - * A XGBoost multiclass classification and the document is as follows; - * - https://github.com/dmlc/xgboost/tree/master/demo/multiclass_classification + * A XGBoost multiclass classification and the document is as follows; - + * https://github.com/dmlc/xgboost/tree/master/demo/multiclass_classification */ @Description( - name = "train_multiclass_xgboost_classifier", - value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" -) + name = "train_multiclass_xgboost_classifier", + value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>") public class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF { public XGBoostMulticlassClassifierUDTF() {} @@ -56,12 +55,12 @@ public class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF { @Override protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { final CommandLine cli = super.processOptions(argOIs); - if(cli != null) { - if(cli.hasOption("num_class")) { + if (cli != null) { + if (cli.hasOption("num_class")) { int _num_class = Integer.valueOf(cli.getOptionValue("num_class")); - if(_num_class < 2) { - throw new UDFArgumentException( - "num_class must be greater than 1: " + _num_class); + if (_num_class < 2) { + throw new UDFArgumentException("num_class must be greater than 1: " + + _num_class); } params.put("num_class", _num_class); } @@ -72,12 +71,10 @@ public class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF { @Override public void checkTargetValue(double target) throws HiveException { double num_class = ((Integer) params.get("num_class")).doubleValue(); - if(target < 0.0 || target > num_class + if (target < 0.0 || target > num_class || Double.compare(target - Math.floor(target), 0.0) != 0) { - throw new HiveException( - "target must be {0.0, ..., " - + String.format("%.1f", (num_class - 1.0)) - + "}: " + target); + throw new HiveException("target must be {0.0, ..., " + + String.format("%.1f", (num_class - 1.0)) + "}: " + target); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java b/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java index d00b430..98abc8a 100644 --- a/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java @@ -24,13 +24,12 @@ import org.apache.hadoop.hive.ql.metadata.HiveException; import hivemall.xgboost.XGBoostUDTF; /** - * A XGBoost regression and the document is as follows; - * - https://github.com/dmlc/xgboost/tree/master/demo/regression + * A XGBoost regression and the document is as follows; - + * https://github.com/dmlc/xgboost/tree/master/demo/regression */ @Description( - name = "train_xgboost_regr", - value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" -) + name = "train_xgboost_regr", + value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>") public class XGBoostRegressionUDTF extends XGBoostUDTF { public XGBoostRegressionUDTF() {} @@ -43,7 +42,7 @@ public class XGBoostRegressionUDTF extends XGBoostUDTF { @Override public void checkTargetValue(double target) throws HiveException { - if(target < 0.0 || target > 1.0) { + if (target < 0.0 || target > 1.0) { throw new HiveException("target must be in range 0 to 1: " + target); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java index 6ceb17e..4d1c0a2 100644 --- a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java @@ -28,10 +28,9 @@ import java.util.ArrayList; import java.util.List; @Description( - name = "xgboost_multiclass_predict", - value = "_FUNC_(string rowid, string[] features, string model_id, array<byte> pred_model [, string options]) " - + "- Returns a prediction result as (string rowid, int label, float probability)" -) + name = "xgboost_multiclass_predict", + value = "_FUNC_(string rowid, string[] features, string model_id, array<byte> pred_model [, string options]) " + + "- Returns a prediction result as (string rowid, int label, float probability)") public final class XGBoostMulticlassPredictUDTF extends hivemall.xgboost.XGBoostPredictUDTF { public XGBoostMulticlassPredictUDTF() {} @@ -51,16 +50,15 @@ public final class XGBoostMulticlassPredictUDTF extends hivemall.xgboost.XGBoost } @Override - public void forwardPredicted( - final List<LabeledPointWithRowId> testData, + public void forwardPredicted(final List<LabeledPointWithRowId> testData, final float[][] predicted) throws HiveException { - assert(predicted.length == testData.size()); - for(int i = 0; i < testData.size(); i++) { - assert(predicted[i].length > 1); + assert (predicted.length == testData.size()); + for (int i = 0; i < testData.size(); i++) { + assert (predicted[i].length > 1); final String rowId = testData.get(i).rowId; - for(int j = 0; j < predicted[i].length; j++) { + for (int j = 0; j < predicted[i].length; j++) { float prob = predicted[i][j]; - forward(new Object[]{rowId, j, prob}); + forward(new Object[] {rowId, j, prob}); } } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e9c66f0a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java index 4510206..594a738 100644 --- a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java @@ -28,10 +28,9 @@ import java.util.ArrayList; import java.util.List; @Description( - name = "xgboost_predict", - value = "_FUNC_(string rowid, string[] features, string model_id, array<byte> pred_model [, string options]) " - + "- Returns a prediction result as (string rowid, float predicted)" -) + name = "xgboost_predict", + value = "_FUNC_(string rowid, string[] features, string model_id, array<byte> pred_model [, string options]) " + + "- Returns a prediction result as (string rowid, float predicted)") public final class XGBoostPredictUDTF extends hivemall.xgboost.XGBoostPredictUDTF { public XGBoostPredictUDTF() {} @@ -49,15 +48,14 @@ public final class XGBoostPredictUDTF extends hivemall.xgboost.XGBoostPredictUDT } @Override - public void forwardPredicted( - final List<LabeledPointWithRowId> testData, + public void forwardPredicted(final List<LabeledPointWithRowId> testData, final float[][] predicted) throws HiveException { - assert(predicted.length == testData.size()); - for(int i = 0; i < testData.size(); i++) { - assert(predicted[i].length == 1); + assert (predicted.length == testData.size()); + for (int i = 0; i < testData.size(); i++) { + assert (predicted[i].length == 1); final String rowId = testData.get(i).rowId; float p = predicted[i][0]; - forward(new Object[]{rowId, p}); + forward(new Object[] {rowId, p}); } }
