Repository: incubator-hivemall Updated Branches: refs/heads/master 7e96c8a99 -> e3bbaf622 (forced update)
Applied refactoring for XGboost module Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/e3bbaf62 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/e3bbaf62 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/e3bbaf62 Branch: refs/heads/master Commit: e3bbaf622a06a07c3f15792f2b1e9a3bb3bb2e78 Parents: 04372d4 Author: Makoto Yui <[email protected]> Authored: Sat Jul 15 00:25:45 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Sat Jul 15 00:58:24 2017 +0900 ---------------------------------------------------------------------- .../java/hivemall/xgboost/NativeLibLoader.java | 16 +- .../hivemall/xgboost/XGBoostPredictUDTF.java | 146 +++++++++++-------- .../main/java/hivemall/xgboost/XGBoostUDTF.java | 88 ++++++----- .../java/hivemall/xgboost/XGBoostUtils.java | 9 +- .../XGBoostBinaryClassifierUDTF.java | 12 +- .../XGBoostMulticlassClassifierUDTF.java | 14 +- .../regression/XGBoostRegressionUDTF.java | 8 +- .../tools/XGBoostMulticlassPredictUDTF.java | 47 ++++-- .../xgboost/tools/XGBoostPredictUDTF.java | 40 +++-- 9 files changed, 235 insertions(+), 145 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3bbaf62/xgboost/src/main/java/hivemall/xgboost/NativeLibLoader.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/NativeLibLoader.java b/xgboost/src/main/java/hivemall/xgboost/NativeLibLoader.java index 63a5217..da6289d 100644 --- a/xgboost/src/main/java/hivemall/xgboost/NativeLibLoader.java +++ b/xgboost/src/main/java/hivemall/xgboost/NativeLibLoader.java @@ -18,9 +18,15 @@ */ package hivemall.xgboost; -import java.io.*; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.lang.reflect.Field; import java.util.UUID; + import javax.annotation.Nonnull; import org.apache.commons.logging.Log; @@ -53,6 +59,7 @@ public final class NativeLibLoader { return NativeLibLoader.class.getResource(path) != null; } + @Nonnull private static String getOSName() { return System.getProperty("os.name"); } @@ -73,7 +80,7 @@ public final class NativeLibLoader { } } try { - final File tempFile = createTempFileFromResource(resolvedLibName, + File tempFile = createTempFileFromResource(resolvedLibName, NativeLibLoader.class.getResourceAsStream(libPath + resolvedLibName)); logger.info("Copyed the native library in JAR as " + tempFile.getAbsolutePath()); addLibraryPath(tempFile.getParent()); @@ -89,7 +96,7 @@ public final class NativeLibLoader { logger.warn(userDefinedLib + " not found"); } else { try { - final File tempFile = createTempFileFromResource(userDefinedLibFile.getName(), + File tempFile = createTempFileFromResource(userDefinedLibFile.getName(), new FileInputStream(userDefinedLibFile.getAbsolutePath())); logger.info("Copyed the user-defined native library as " + tempFile.getAbsolutePath()); @@ -101,6 +108,7 @@ public final class NativeLibLoader { } } + @Nonnull private static String getPreffix(@Nonnull String fileName) { int point = fileName.lastIndexOf("."); if (point != -1) { @@ -132,7 +140,7 @@ public final class NativeLibLoader { } // Prepare buffer for data copying - byte[] buffer = new byte[8192]; + final byte[] buffer = new byte[8192]; int readBytes; // Open output stream and copy the native library into the temporary one http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3bbaf62/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java index a175dd2..fd4c0b4 100644 --- a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java @@ -18,27 +18,35 @@ */ package hivemall.xgboost; +import hivemall.UDTFWithOptions; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.Primitives; + import java.io.ByteArrayInputStream; -import java.util.*; -import java.util.Map.Entry; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.annotation.Nonnull; import ml.dmlc.xgboost4j.LabeledPoint; import ml.dmlc.xgboost4j.java.Booster; import ml.dmlc.xgboost4j.java.DMatrix; import ml.dmlc.xgboost4j.java.XGBoost; import ml.dmlc.xgboost4j.java.XGBoostError; + import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; -import hivemall.UDTFWithOptions; -import hivemall.utils.hadoop.HiveUtils; -import hivemall.utils.lang.Primitives; - public abstract class XGBoostPredictUDTF extends UDTFWithOptions { // For input parameters @@ -59,21 +67,8 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { NativeLibLoader.initXGBoost(); } - public XGBoostPredictUDTF() {} - - protected final class LabeledPointWithRowId { - public String rowId; - public LabeledPoint point; - - // Prevent other classes from instantiating this - LabeledPointWithRowId() {} - } - - private LabeledPointWithRowId createLabeledPoint(String rowId, LabeledPoint point) { - final LabeledPointWithRowId p = new LabeledPointWithRowId(); - p.rowId = rowId; - p.point = point; - return p; + public XGBoostPredictUDTF() { + super(); } @Override @@ -100,14 +95,16 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { return cl; } - /** Override this to output predicted results depending on a taks type */ - abstract public StructObjectInspector getReturnOI(); + /** Override this to output predicted results depending on a task type */ + @Nonnull + protected abstract StructObjectInspector getReturnOI(); - abstract public void forwardPredicted(final List<LabeledPointWithRowId> testData, - final float[][] predicted) throws HiveException; + protected abstract void forwardPredicted(@Nonnull final List<LabeledPointWithRowId> testData, + @Nonnull final float[][] predicted) throws HiveException; @Override - public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) + throws UDFArgumentException { if (argOIs.length != 4 && argOIs.length != 5) { throw new UDFArgumentException(this.getClass().getSimpleName() + " takes 4 or 5 arguments: string rowid, string[] features, string model_id," @@ -127,16 +124,18 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { } } - private static DMatrix createDMatrix(final List<LabeledPointWithRowId> data) + @Nonnull + private static DMatrix createDMatrix(@Nonnull final List<LabeledPointWithRowId> data) throws XGBoostError { - final List<LabeledPoint> points = new ArrayList(data.size()); + final List<LabeledPoint> points = new ArrayList<>(data.size()); for (LabeledPointWithRowId d : data) { points.add(d.point); } return new DMatrix(points.iterator(), ""); } - private static Booster initXgBooster(final byte[] input) throws HiveException { + @Nonnull + private static Booster initXgBooster(@Nonnull final byte[] input) throws HiveException { try { return XGBoost.loadModel(new ByteArrayInputStream(input)); } catch (Exception e) { @@ -146,42 +145,73 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { private void predictAndFlush(final Booster model, final List<LabeledPointWithRowId> buf) throws HiveException { + final DMatrix testData; + final float[][] predicted; try { - final DMatrix testData = createDMatrix(buf); - final float[][] predicted = model.predict(testData); - forwardPredicted(buf, predicted); - } catch (Exception e) { + testData = createDMatrix(buf); + predicted = model.predict(testData); + } catch (XGBoostError e) { throw new HiveException(e); } + forwardPredicted(buf, predicted); buf.clear(); } @Override public void process(Object[] args) throws HiveException { - if (args[1] != null) { - final String rowId = PrimitiveObjectInspectorUtils.getString(args[0], rowIdOI); - final List<?> features = (List<?>) featureListOI.getList(args[1]); - final String[] fv = new String[features.size()]; - for (int i = 0; i < features.size(); i++) { - fv[i] = (String) featureElemOI.getPrimitiveJavaObject(features.get(i)); - } - final String modelId = PrimitiveObjectInspectorUtils.getString(args[2], modelIdOI); - if (!mapToModel.containsKey(modelId)) { - final byte[] predModel = PrimitiveObjectInspectorUtils.getBinary(args[3], modelOI) - .getBytes(); - mapToModel.put(modelId, initXgBooster(predModel)); - } - final LabeledPoint point = XGBoostUtils.parseFeatures(0.f, fv); - if (point != null) { - if (!rowBuffer.containsKey(modelId)) { - rowBuffer.put(modelId, new ArrayList()); - } - final List<LabeledPointWithRowId> buf = rowBuffer.get(modelId); - buf.add(createLabeledPoint(rowId, point)); - if (buf.size() >= batch_size) { - predictAndFlush(mapToModel.get(modelId), buf); - } - } + if (args[1] == null) { + return; + } + + final String rowId = PrimitiveObjectInspectorUtils.getString(args[0], rowIdOI); + final List<?> features = (List<?>) featureListOI.getList(args[1]); + final String[] fv = new String[features.size()]; + for (int i = 0; i < features.size(); i++) { + fv[i] = (String) featureElemOI.getPrimitiveJavaObject(features.get(i)); + } + final String modelId = PrimitiveObjectInspectorUtils.getString(args[2], modelIdOI); + if (!mapToModel.containsKey(modelId)) { + final byte[] predModel = PrimitiveObjectInspectorUtils.getBinary(args[3], modelOI) + .getBytes(); + mapToModel.put(modelId, initXgBooster(predModel)); + } + + final LabeledPoint point = XGBoostUtils.parseFeatures(0.f, fv); + if (point == null) { + return; + } + + List<LabeledPointWithRowId> buf = rowBuffer.get(modelId); + if (buf == null) { + buf = new ArrayList<LabeledPointWithRowId>(); + rowBuffer.put(modelId, buf); + } + buf.add(new LabeledPointWithRowId(rowId, point)); + if (buf.size() >= batch_size) { + predictAndFlush(mapToModel.get(modelId), buf); + } + } + + public static final class LabeledPointWithRowId { + + @Nonnull + final String rowId; + @Nonnull + final LabeledPoint point; + + LabeledPointWithRowId(@Nonnull String rowId, @Nonnull LabeledPoint point) { + this.rowId = rowId; + this.point = point; + } + + @Nonnull + public String getRowId() { + return rowId; + } + + @Nonnull + public LabeledPoint getPoint() { + return point; } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3bbaf62/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java index 059cb1c..c67d35b 100644 --- a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java @@ -18,29 +18,38 @@ */ package hivemall.xgboost; +import hivemall.UDTFWithOptions; +import hivemall.utils.hadoop.HadoopUtils; +import hivemall.utils.hadoop.HiveUtils; + import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import javax.annotation.Nonnull; import ml.dmlc.xgboost4j.LabeledPoint; import ml.dmlc.xgboost4j.java.Booster; import ml.dmlc.xgboost4j.java.DMatrix; import ml.dmlc.xgboost4j.java.XGBoostError; + import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; -import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; -import hivemall.UDTFWithOptions; -import hivemall.utils.hadoop.HadoopUtils; -import hivemall.utils.hadoop.HiveUtils; - /** * This is a base class to handle the options for XGBoost and provide common functions among various * tasks. @@ -48,8 +57,10 @@ import hivemall.utils.hadoop.HiveUtils; public abstract class XGBoostUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(XGBoostUDTF.class); - // For XGBoost options - protected final Map<String, Object> params = new HashMap<String, Object>(); + // Settings for the XGBoost native library + static { + NativeLibLoader.initXGBoost(); + } // For input buffer private final List<LabeledPoint> featuresList; @@ -59,10 +70,9 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { private PrimitiveObjectInspector featureElemOI; private PrimitiveObjectInspector targetOI; - // Settings for the XGBoost native library - static { - NativeLibLoader.initXGBoost(); - } + // For XGBoost options + @Nonnull + protected final Map<String, Object> params = new HashMap<String, Object>(); // XGBoost options can be found in https://github.com/dmlc/xgboost/blob/master/doc/parameter.md // Most of default parameters are set along with the official one. @@ -97,7 +107,7 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { } public XGBoostUDTF() { - this.featuresList = new ArrayList(1024); + this.featuresList = new ArrayList<>(1024); } @Override @@ -241,9 +251,10 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { } /** All the functions return (string model_id, byte[] pred_model) as built models */ + @Nonnull private static StructObjectInspector getReturnOIs() { - final ArrayList fieldNames = new ArrayList(2); - final ArrayList fieldOIs = new ArrayList(2); + final List<String> fieldNames = new ArrayList<>(2); + final List<ObjectInspector> fieldOIs = new ArrayList<>(2); fieldNames.add("model_id"); fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector); fieldNames.add("pred_model"); @@ -252,7 +263,8 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { } @Override - public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + public StructObjectInspector initialize(@Nonnull ObjectInspector[] argOIs) + throws UDFArgumentException { processOptions(argOIs); final ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]); final ObjectInspector elemOI = listOI.getListElementObjectInspector(); @@ -263,35 +275,37 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { } /** It `target` has valid input range, it overrides this */ - public void checkTargetValue(double target) throws HiveException {} + protected abstract void checkTargetValue(double target) throws HiveException; @Override - public void process(Object[] args) throws HiveException { - if (args[0] != null) { - // TODO: Need to support dense inputs - final List<?> features = (List<?>) featureListOI.getList(args[0]); - final String[] fv = new String[features.size()]; - for (int i = 0; i < features.size(); i++) { - fv[i] = (String) featureElemOI.getPrimitiveJavaObject(features.get(i)); - } - double target = PrimitiveObjectInspectorUtils.getDouble(args[1], this.targetOI); - checkTargetValue(target); - final LabeledPoint point = XGBoostUtils.parseFeatures(target, fv); - if (point != null) { - this.featuresList.add(point); - } + public void process(@Nonnull Object[] args) throws HiveException { + if (args[0] == null) { + return; + } + + // TODO: Need to support dense inputs + final List<?> features = (List<?>) featureListOI.getList(args[0]); + final String[] fv = new String[features.size()]; + for (int i = 0; i < features.size(); i++) { + fv[i] = (String) featureElemOI.getPrimitiveJavaObject(features.get(i)); + } + double target = PrimitiveObjectInspectorUtils.getDouble(args[1], this.targetOI); + checkTargetValue(target); + final LabeledPoint point = XGBoostUtils.parseFeatures(target, fv); + if (point != null) { + this.featuresList.add(point); } } - private String generateUniqueModelId() { + @Nonnull + private static String generateUniqueModelId() { return "xgbmodel-" + HadoopUtils.getUniqueTaskIdString(); } @Nonnull - private static Booster createXGBooster( - final Map<String, Object> params, final List<LabeledPoint> input) - throws NoSuchMethodException, XGBoostError, IllegalAccessException, - InvocationTargetException, InstantiationException { + private static Booster createXGBooster(final Map<String, Object> params, + final List<LabeledPoint> input) throws NoSuchMethodException, XGBoostError, + IllegalAccessException, InvocationTargetException, InstantiationException { Class<?>[] args = {Map.class, DMatrix[].class}; Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(args); ctor.setAccessible(true); @@ -305,7 +319,7 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { // Kick off training with XGBoost final DMatrix trainData = new DMatrix(featuresList.iterator(), ""); final Booster booster = createXGBooster(params, featuresList); - int num_round = (Integer) params.get("num_round"); + final int num_round = (Integer) params.get("num_round"); for (int i = 0; i < num_round; i++) { booster.update(trainData, i); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3bbaf62/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java index 632d2fe..2e2bf25 100644 --- a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java +++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java @@ -18,20 +18,23 @@ */ package hivemall.xgboost; -import ml.dmlc.xgboost4j.LabeledPoint; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; -import java.util.List; +import ml.dmlc.xgboost4j.LabeledPoint; public final class XGBoostUtils { private XGBoostUtils() {} /** Transform List<String> inputs into a XGBoost input format */ - public static LabeledPoint parseFeatures(double target, String[] features) { + @Nullable + public static LabeledPoint parseFeatures(final double target, @Nonnull final String[] features) { final int size = features.length; if (size == 0) { return null; } + final int[] indices = new int[size]; final float[] values = new float[size]; for (int i = 0; i < size; i++) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3bbaf62/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java index 94282bb..6636bc1 100644 --- a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTF.java @@ -18,11 +18,11 @@ */ package hivemall.xgboost.classification; +import hivemall.xgboost.XGBoostUDTF; + import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.metadata.HiveException; -import hivemall.xgboost.XGBoostUDTF; - /** * A XGBoost binary classification and the document is as follows; - * https://github.com/dmlc/xgboost/tree/master/demo/binary_classification @@ -30,9 +30,11 @@ import hivemall.xgboost.XGBoostUDTF; @Description( name = "train_xgboost_classifier", value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>") -public class XGBoostBinaryClassifierUDTF extends XGBoostUDTF { +public final class XGBoostBinaryClassifierUDTF extends XGBoostUDTF { - public XGBoostBinaryClassifierUDTF() {} + public XGBoostBinaryClassifierUDTF() { + super(); + } { // Settings for binary classification @@ -41,7 +43,7 @@ public class XGBoostBinaryClassifierUDTF extends XGBoostUDTF { } @Override - public void checkTargetValue(double target) throws HiveException { + protected void checkTargetValue(final double target) throws HiveException { if (!(Double.compare(target, 0.0) == 0 || Double.compare(target, 1.0) == 0)) { throw new HiveException("target must be 0.0 or 1.0: " + target); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3bbaf62/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java index 3181473..62ede2c 100644 --- a/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTF.java @@ -18,13 +18,13 @@ */ package hivemall.xgboost.classification; +import hivemall.xgboost.XGBoostUDTF; + import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.exec.UDFArgumentException; import org.apache.hadoop.hive.ql.metadata.HiveException; - -import hivemall.xgboost.XGBoostUDTF; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; /** @@ -34,9 +34,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; @Description( name = "train_multiclass_xgboost_classifier", value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>") -public class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF { +public final class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF { - public XGBoostMulticlassClassifierUDTF() {} + public XGBoostMulticlassClassifierUDTF() { + super(); + } { // Settings for multiclass classification @@ -47,7 +49,7 @@ public class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF { @Override protected Options getOptions() { - final Options opts = super.getOptions(); + Options opts = super.getOptions(); opts.addOption("num_class", true, "Number of classes to classify"); return opts; } @@ -69,7 +71,7 @@ public class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF { } @Override - public void checkTargetValue(double target) throws HiveException { + protected void checkTargetValue(final double target) throws HiveException { double num_class = ((Integer) params.get("num_class")).doubleValue(); if (target < 0.0 || target > num_class || Double.compare(target - Math.floor(target), 0.0) != 0) { http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3bbaf62/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java b/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java index 98abc8a..3a7aec6 100644 --- a/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTF.java @@ -30,9 +30,11 @@ import hivemall.xgboost.XGBoostUDTF; @Description( name = "train_xgboost_regr", value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>") -public class XGBoostRegressionUDTF extends XGBoostUDTF { +public final class XGBoostRegressionUDTF extends XGBoostUDTF { - public XGBoostRegressionUDTF() {} + public XGBoostRegressionUDTF() { + super(); + } { // Settings for logistic regression @@ -41,7 +43,7 @@ public class XGBoostRegressionUDTF extends XGBoostUDTF { } @Override - public void checkTargetValue(double target) throws HiveException { + protected void checkTargetValue(final double target) throws HiveException { if (target < 0.0 || target > 1.0) { throw new HiveException("target must be in range 0 to 1: " + target); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3bbaf62/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java index 4d1c0a2..fd67c09 100644 --- a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostMulticlassPredictUDTF.java @@ -18,47 +18,62 @@ */ package hivemall.xgboost.tools; +import hivemall.utils.lang.Preconditions; + +import java.util.ArrayList; +import java.util.List; + +import javax.annotation.Nonnull; + import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import java.util.ArrayList; -import java.util.List; - @Description( name = "xgboost_multiclass_predict", value = "_FUNC_(string rowid, string[] features, string model_id, array<byte> pred_model [, string options]) " + "- Returns a prediction result as (string rowid, int label, float probability)") public final class XGBoostMulticlassPredictUDTF extends hivemall.xgboost.XGBoostPredictUDTF { - public XGBoostMulticlassPredictUDTF() {} + public XGBoostMulticlassPredictUDTF() { + super(); + } /** Return (string rowid, int label, float probability) as a result */ @Override - public StructObjectInspector getReturnOI() { - final ArrayList fieldNames = new ArrayList(3); - final ArrayList fieldOIs = new ArrayList(3); + protected StructObjectInspector getReturnOI() { + final List<String> fieldNames = new ArrayList<>(3); + final List<ObjectInspector> fieldOIs = new ArrayList<>(3); fieldNames.add("rowid"); fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector); fieldNames.add("label"); fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector); fieldNames.add("probability"); fieldOIs.add(PrimitiveObjectInspectorFactory.javaFloatObjectInspector); + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @Override - public void forwardPredicted(final List<LabeledPointWithRowId> testData, - final float[][] predicted) throws HiveException { - assert (predicted.length == testData.size()); - for (int i = 0; i < testData.size(); i++) { - assert (predicted[i].length > 1); - final String rowId = testData.get(i).rowId; - for (int j = 0; j < predicted[i].length; j++) { - float prob = predicted[i][j]; - forward(new Object[] {rowId, j, prob}); + protected void forwardPredicted(@Nonnull final List<LabeledPointWithRowId> testData, + @Nonnull final float[][] predicted) throws HiveException { + Preconditions.checkArgument(predicted.length == testData.size(), HiveException.class); + + final Object[] forwardObj = new Object[3]; + for (int i = 0, size = testData.size(); i < size; i++) { + final float[] predicted_i = predicted[i]; + final String rowId = testData.get(i).getRowId(); + forwardObj[0] = rowId; + + assert (predicted_i.length > 1); + for (int j = 0; j < predicted_i.length; j++) { + forwardObj[1] = j; + float prob = predicted_i[j]; + forwardObj[2] = prob; + forward(forwardObj); } } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/e3bbaf62/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java index 594a738..df5498d 100644 --- a/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/tools/XGBoostPredictUDTF.java @@ -18,44 +18,58 @@ */ package hivemall.xgboost.tools; +import hivemall.utils.lang.Preconditions; + +import java.util.ArrayList; +import java.util.List; + +import javax.annotation.Nonnull; + import org.apache.hadoop.hive.ql.exec.Description; import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import java.util.ArrayList; -import java.util.List; - @Description( name = "xgboost_predict", value = "_FUNC_(string rowid, string[] features, string model_id, array<byte> pred_model [, string options]) " + "- Returns a prediction result as (string rowid, float predicted)") public final class XGBoostPredictUDTF extends hivemall.xgboost.XGBoostPredictUDTF { - public XGBoostPredictUDTF() {} + public XGBoostPredictUDTF() { + super(); + } /** Return (string rowid, float predicted) as a result */ @Override - public StructObjectInspector getReturnOI() { - final ArrayList fieldNames = new ArrayList(2); - final ArrayList fieldOIs = new ArrayList(2); + protected StructObjectInspector getReturnOI() { + final List<String> fieldNames = new ArrayList<>(2); + final List<ObjectInspector> fieldOIs = new ArrayList<>(2); fieldNames.add("rowid"); fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector); fieldNames.add("predicted"); fieldOIs.add(PrimitiveObjectInspectorFactory.javaFloatObjectInspector); + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @Override - public void forwardPredicted(final List<LabeledPointWithRowId> testData, - final float[][] predicted) throws HiveException { - assert (predicted.length == testData.size()); - for (int i = 0; i < testData.size(); i++) { + protected void forwardPredicted(@Nonnull final List<LabeledPointWithRowId> testData, + @Nonnull final float[][] predicted) throws HiveException { + Preconditions.checkArgument(predicted.length == testData.size(), HiveException.class); + + final Object[] forwardObj = new Object[2]; + for (int i = 0, size = testData.size(); i < size; i++) { assert (predicted[i].length == 1); - final String rowId = testData.get(i).rowId; + + final String rowId = testData.get(i).getRowId(); float p = predicted[i][0]; - forward(new Object[] {rowId, p}); + forwardObj[0] = rowId; + forwardObj[1] = p; + + forward(forwardObj); } }
