Repository: incubator-hivemall Updated Branches: refs/heads/master 55c858816 -> 7e96c8a99
Close #95: [HIVEMALL-119] Fix type cast issues in XGBoostUDTF Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/04372d49 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/04372d49 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/04372d49 Branch: refs/heads/master Commit: 04372d490194d598bbc79ec1adba8b0918225c38 Parents: 55c8588 Author: Takeshi Yamamuro <[email protected]> Authored: Fri Jul 14 23:27:56 2017 +0900 Committer: Makoto Yui <[email protected]> Committed: Fri Jul 14 23:27:56 2017 +0900 ---------------------------------------------------------------------- .../XGBoostBinaryClassifierUDTFWrapper.java | 47 ----------------- .../XGBoostMulticlassClassifierUDTFWrapper.java | 47 ----------------- .../XGBoostRegressionUDTFWrapper.java | 47 ----------------- .../org/apache/spark/sql/hive/HivemallOps.scala | 6 +-- .../XGBoostBinaryClassifierUDTFWrapper.java | 47 ----------------- .../XGBoostMulticlassClassifierUDTFWrapper.java | 47 ----------------- .../XGBoostRegressionUDTFWrapper.java | 47 ----------------- .../org/apache/spark/sql/hive/HivemallOps.scala | 6 +-- .../hivemall/xgboost/XGBoostPredictUDTF.java | 12 +++-- .../main/java/hivemall/xgboost/XGBoostUDTF.java | 54 ++++++++------------ .../java/hivemall/xgboost/XGBoostUtils.java | 8 +-- 11 files changed, 40 insertions(+), 328 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java b/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java deleted file mode 100644 index 310d15e..0000000 --- a/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.xgboost.classification; - -import java.util.UUID; - -import org.apache.hadoop.hive.ql.exec.Description; - -/** An alternative implementation of [[hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF]]. */ -@Description( - name = "train_xgboost_classifier", - value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" -) -public class XGBoostBinaryClassifierUDTFWrapper extends XGBoostBinaryClassifierUDTF { - private long sequence; - private long taskId; - - public XGBoostBinaryClassifierUDTFWrapper() { - this.sequence = 0L; - this.taskId = Thread.currentThread().getId(); - } - - @Override - protected String generateUniqueModelId() { - sequence++; - /** - * TODO: Check if it is unique over all tasks in executors of Spark. - */ - return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java b/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java deleted file mode 100644 index 81e6fe8..0000000 --- a/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.xgboost.classification; - -import java.util.UUID; - -import org.apache.hadoop.hive.ql.exec.Description; - -/** An alternative implementation of [[hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper]]. */ -@Description( - name = "train_multiclass_xgboost_classifier", - value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" -) -public class XGBoostMulticlassClassifierUDTFWrapper extends XGBoostMulticlassClassifierUDTF { - private long sequence; - private long taskId; - - public XGBoostMulticlassClassifierUDTFWrapper() { - this.sequence = 0L; - this.taskId = Thread.currentThread().getId(); - } - - @Override - protected String generateUniqueModelId() { - sequence++; - /** - * TODO: Check if it is unique over all tasks in executors of Spark. - */ - return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.0/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java b/spark/spark-2.0/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java deleted file mode 100644 index b72e045..0000000 --- a/spark/spark-2.0/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.xgboost.regression; - -import java.util.UUID; - -import org.apache.hadoop.hive.ql.exec.Description; - -/** An alternative implementation of [[hivemall.xgboost.regression.XGBoostRegressionUDTF]]. */ -@Description( - name = "train_xgboost_regr", - value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" -) -public class XGBoostRegressionUDTFWrapper extends XGBoostRegressionUDTF { - private long sequence; - private long taskId; - - public XGBoostRegressionUDTFWrapper() { - this.sequence = 0L; - this.taskId = Thread.currentThread().getId(); - } - - @Override - protected String generateUniqueModelId() { - sequence++; - /** - * TODO: Check if it is unique over all tasks in executors of Spark. - */ - return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala index 7b13892..b5299ef 100644 --- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala +++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -519,7 +519,7 @@ final class HivemallOps(df: DataFrame) extends Logging { def train_xgboost_regr(exprs: Column*): DataFrame = withTypedPlan { planHiveGenericUDTF( df, - "hivemall.xgboost.regression.XGBoostRegressionUDTFWrapper", + "hivemall.xgboost.regression.XGBoostRegressionUDTF", "train_xgboost_regr", setMixServs(toHivemallFeatures(exprs)), Seq("model_id", "pred_model") @@ -536,7 +536,7 @@ final class HivemallOps(df: DataFrame) extends Logging { def train_xgboost_classifier(exprs: Column*): DataFrame = withTypedPlan { planHiveGenericUDTF( df, - "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTFWrapper", + "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF", "train_xgboost_classifier", setMixServs(toHivemallFeatures(exprs)), Seq("model_id", "pred_model") @@ -553,7 +553,7 @@ final class HivemallOps(df: DataFrame) extends Logging { def train_xgboost_multiclass_classifier(exprs: Column*): DataFrame = withTypedPlan { planHiveGenericUDTF( df, - "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper", + "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF", "train_xgboost_multiclass_classifier", setMixServs(toHivemallFeatures(exprs)), Seq("model_id", "pred_model") http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java b/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java deleted file mode 100644 index 310d15e..0000000 --- a/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.xgboost.classification; - -import java.util.UUID; - -import org.apache.hadoop.hive.ql.exec.Description; - -/** An alternative implementation of [[hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF]]. */ -@Description( - name = "train_xgboost_classifier", - value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" -) -public class XGBoostBinaryClassifierUDTFWrapper extends XGBoostBinaryClassifierUDTF { - private long sequence; - private long taskId; - - public XGBoostBinaryClassifierUDTFWrapper() { - this.sequence = 0L; - this.taskId = Thread.currentThread().getId(); - } - - @Override - protected String generateUniqueModelId() { - sequence++; - /** - * TODO: Check if it is unique over all tasks in executors of Spark. - */ - return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java b/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java deleted file mode 100644 index 81e6fe8..0000000 --- a/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.xgboost.classification; - -import java.util.UUID; - -import org.apache.hadoop.hive.ql.exec.Description; - -/** An alternative implementation of [[hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper]]. */ -@Description( - name = "train_multiclass_xgboost_classifier", - value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" -) -public class XGBoostMulticlassClassifierUDTFWrapper extends XGBoostMulticlassClassifierUDTF { - private long sequence; - private long taskId; - - public XGBoostMulticlassClassifierUDTFWrapper() { - this.sequence = 0L; - this.taskId = Thread.currentThread().getId(); - } - - @Override - protected String generateUniqueModelId() { - sequence++; - /** - * TODO: Check if it is unique over all tasks in executors of Spark. - */ - return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java b/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java deleted file mode 100644 index b72e045..0000000 --- a/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package hivemall.xgboost.regression; - -import java.util.UUID; - -import org.apache.hadoop.hive.ql.exec.Description; - -/** An alternative implementation of [[hivemall.xgboost.regression.XGBoostRegressionUDTF]]. */ -@Description( - name = "train_xgboost_regr", - value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" -) -public class XGBoostRegressionUDTFWrapper extends XGBoostRegressionUDTF { - private long sequence; - private long taskId; - - public XGBoostRegressionUDTFWrapper() { - this.sequence = 0L; - this.taskId = Thread.currentThread().getId(); - } - - @Override - protected String generateUniqueModelId() { - sequence++; - /** - * TODO: Check if it is unique over all tasks in executors of Spark. - */ - return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence; - } -} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala index 83129a7..9350a81 100644 --- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -523,7 +523,7 @@ final class HivemallOps(df: DataFrame) extends Logging { def train_xgboost_regr(exprs: Column*): DataFrame = withTypedPlan { planHiveGenericUDTF( df, - "hivemall.xgboost.regression.XGBoostRegressionUDTFWrapper", + "hivemall.xgboost.regression.XGBoostRegressionUDTF", "train_xgboost_regr", setMixServs(toHivemallFeatures(exprs)), Seq("model_id", "pred_model") @@ -540,7 +540,7 @@ final class HivemallOps(df: DataFrame) extends Logging { def train_xgboost_classifier(exprs: Column*): DataFrame = withTypedPlan { planHiveGenericUDTF( df, - "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTFWrapper", + "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF", "train_xgboost_classifier", setMixServs(toHivemallFeatures(exprs)), Seq("model_id", "pred_model") @@ -557,7 +557,7 @@ final class HivemallOps(df: DataFrame) extends Logging { def train_xgboost_multiclass_classifier(exprs: Column*): DataFrame = withTypedPlan { planHiveGenericUDTF( df, - "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper", + "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF", "train_xgboost_multiclass_classifier", setMixServs(toHivemallFeatures(exprs)), Seq("model_id", "pred_model") http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java index e05755e..a175dd2 100644 --- a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java @@ -140,7 +140,7 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { try { return XGBoost.loadModel(new ByteArrayInputStream(input)); } catch (Exception e) { - throw new HiveException(e.getMessage()); + throw new HiveException(e); } } @@ -151,7 +151,7 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { final float[][] predicted = model.predict(testData); forwardPredicted(buf, predicted); } catch (Exception e) { - throw new HiveException(e.getMessage()); + throw new HiveException(e); } buf.clear(); } @@ -160,14 +160,18 @@ public abstract class XGBoostPredictUDTF extends UDTFWithOptions { public void process(Object[] args) throws HiveException { if (args[1] != null) { final String rowId = PrimitiveObjectInspectorUtils.getString(args[0], rowIdOI); - final List<String> features = (List<String>) featureListOI.getList(args[1]); + final List<?> features = (List<?>) featureListOI.getList(args[1]); + final String[] fv = new String[features.size()]; + for (int i = 0; i < features.size(); i++) { + fv[i] = (String) featureElemOI.getPrimitiveJavaObject(features.get(i)); + } final String modelId = PrimitiveObjectInspectorUtils.getString(args[2], modelIdOI); if (!mapToModel.containsKey(modelId)) { final byte[] predModel = PrimitiveObjectInspectorUtils.getBinary(args[3], modelOI) .getBytes(); mapToModel.put(modelId, initXgBooster(predModel)); } - final LabeledPoint point = XGBoostUtils.parseFeatures(0.f, features); + final LabeledPoint point = XGBoostUtils.parseFeatures(0.f, fv); if (point != null) { if (!rowBuffer.containsKey(modelId)) { rowBuffer.put(modelId, new ArrayList()); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java index b57925a..059cb1c 100644 --- a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java +++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java @@ -21,6 +21,7 @@ package hivemall.xgboost; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.util.*; +import javax.annotation.Nonnull; import ml.dmlc.xgboost4j.LabeledPoint; import ml.dmlc.xgboost4j.java.Booster; @@ -232,8 +233,8 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { // Try to create a `Booster` instance to check if given XGBoost options // are valid, or not. createXGBooster(params, featuresList); - } catch (XGBoostError e) { - throw new UDFArgumentException(e.getMessage()); + } catch (Exception e) { + throw new UDFArgumentException(e); } return cl; @@ -264,49 +265,38 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { /** It `target` has valid input range, it overrides this */ public void checkTargetValue(double target) throws HiveException {} - @Override public void process(Object[] args) throws HiveException { if (args[0] != null) { // TODO: Need to support dense inputs - final List<String> features = (List<String>) featureListOI.getList(args[0]); + final List<?> features = (List<?>) featureListOI.getList(args[0]); + final String[] fv = new String[features.size()]; + for (int i = 0; i < features.size(); i++) { + fv[i] = (String) featureElemOI.getPrimitiveJavaObject(features.get(i)); + } double target = PrimitiveObjectInspectorUtils.getDouble(args[1], this.targetOI); checkTargetValue(target); - final LabeledPoint point = XGBoostUtils.parseFeatures(target, features); + final LabeledPoint point = XGBoostUtils.parseFeatures(target, fv); if (point != null) { this.featuresList.add(point); } } } - /** - * Need to override this for a Spark wrapper because `MapredContext` does not work in there. - */ - protected String generateUniqueModelId() { - return "xgbmodel-" + String.valueOf(HadoopUtils.getTaskId()); + private String generateUniqueModelId() { + return "xgbmodel-" + HadoopUtils.getUniqueTaskIdString(); } - private static Booster createXGBooster(final Map<String, Object> params, - final List<LabeledPoint> input) throws XGBoostError { - try { - Class<?>[] args = {Map.class, DMatrix[].class}; - Constructor<Booster> ctor; - ctor = Booster.class.getDeclaredConstructor(args); - ctor.setAccessible(true); - return ctor.newInstance(new Object[] {params, - new DMatrix[] {new DMatrix(input.iterator(), "")}}); - } catch (InstantiationException e) { - // Catch java reflection error as fast as possible - e.printStackTrace(); - } catch (IllegalAccessException e) { - e.printStackTrace(); - } catch (InvocationTargetException e) { - e.printStackTrace(); - } catch (NoSuchMethodException e) { - e.printStackTrace(); - } - // No one reach here - return null; + @Nonnull + private static Booster createXGBooster( + final Map<String, Object> params, final List<LabeledPoint> input) + throws NoSuchMethodException, XGBoostError, IllegalAccessException, + InvocationTargetException, InstantiationException { + Class<?>[] args = {Map.class, DMatrix[].class}; + Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(args); + ctor.setAccessible(true); + return ctor.newInstance(new Object[] {params, + new DMatrix[] {new DMatrix(input.iterator(), "")}}); } @Override @@ -326,7 +316,7 @@ public abstract class XGBoostUDTF extends UDTFWithOptions { logger.info("model_id:" + modelId.toString() + " size:" + predModel.length); forward(new Object[] {modelId, predModel}); } catch (Exception e) { - throw new HiveException(e.getMessage()); + throw new HiveException(e); } } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java ---------------------------------------------------------------------- diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java index d0769f4..632d2fe 100644 --- a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java +++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java @@ -27,18 +27,18 @@ public final class XGBoostUtils { private XGBoostUtils() {} /** Transform List<String> inputs into a XGBoost input format */ - public static LabeledPoint parseFeatures(double target, List<String> features) { - final int size = features.size(); + public static LabeledPoint parseFeatures(double target, String[] features) { + final int size = features.length; if (size == 0) { return null; } final int[] indices = new int[size]; final float[] values = new float[size]; for (int i = 0; i < size; i++) { - if (features.get(i) == null) { + if (features[i] == null) { continue; } - final String str = features.get(i); + final String str = features[i]; final int pos = str.indexOf(':'); if (pos >= 1) { indices[i] = Integer.parseInt(str.substring(0, pos));
