Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 55c858816 -> 7e96c8a99


Close #95: [HIVEMALL-119] Fix type cast issues in XGBoostUDTF


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/04372d49
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/04372d49
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/04372d49

Branch: refs/heads/master
Commit: 04372d490194d598bbc79ec1adba8b0918225c38
Parents: 55c8588
Author: Takeshi Yamamuro <[email protected]>
Authored: Fri Jul 14 23:27:56 2017 +0900
Committer: Makoto Yui <[email protected]>
Committed: Fri Jul 14 23:27:56 2017 +0900

----------------------------------------------------------------------
 .../XGBoostBinaryClassifierUDTFWrapper.java     | 47 -----------------
 .../XGBoostMulticlassClassifierUDTFWrapper.java | 47 -----------------
 .../XGBoostRegressionUDTFWrapper.java           | 47 -----------------
 .../org/apache/spark/sql/hive/HivemallOps.scala |  6 +--
 .../XGBoostBinaryClassifierUDTFWrapper.java     | 47 -----------------
 .../XGBoostMulticlassClassifierUDTFWrapper.java | 47 -----------------
 .../XGBoostRegressionUDTFWrapper.java           | 47 -----------------
 .../org/apache/spark/sql/hive/HivemallOps.scala |  6 +--
 .../hivemall/xgboost/XGBoostPredictUDTF.java    | 12 +++--
 .../main/java/hivemall/xgboost/XGBoostUDTF.java | 54 ++++++++------------
 .../java/hivemall/xgboost/XGBoostUtils.java     |  8 +--
 11 files changed, 40 insertions(+), 328 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java b/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java
deleted file mode 100644
index 310d15e..0000000
--- a/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.xgboost.classification;
-
-import java.util.UUID;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-
-/** An alternative implementation of 
[[hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF]]. */
-@Description(
-    name = "train_xgboost_classifier",
-    value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>"
-)
-public class XGBoostBinaryClassifierUDTFWrapper extends 
XGBoostBinaryClassifierUDTF {
-    private long sequence;
-    private long taskId;
-
-    public XGBoostBinaryClassifierUDTFWrapper() {
-        this.sequence = 0L;
-        this.taskId = Thread.currentThread().getId();
-    }
-
-    @Override
-    protected String generateUniqueModelId() {
-        sequence++;
-        /**
-         * TODO: Check if it is unique over all tasks in executors of Spark.
-         */
-        return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence;
-    }
-}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java b/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java
deleted file mode 100644
index 81e6fe8..0000000
--- a/spark/spark-2.0/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.xgboost.classification;
-
-import java.util.UUID;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-
-/** An alternative implementation of 
[[hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper]]. */
-@Description(
-    name = "train_multiclass_xgboost_classifier",
-    value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>"
-)
-public class XGBoostMulticlassClassifierUDTFWrapper extends 
XGBoostMulticlassClassifierUDTF {
-    private long sequence;
-    private long taskId;
-
-    public XGBoostMulticlassClassifierUDTFWrapper() {
-        this.sequence = 0L;
-        this.taskId = Thread.currentThread().getId();
-    }
-
-    @Override
-    protected String generateUniqueModelId() {
-        sequence++;
-        /**
-         * TODO: Check if it is unique over all tasks in executors of Spark.
-         */
-        return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence;
-    }
-}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.0/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java b/spark/spark-2.0/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java
deleted file mode 100644
index b72e045..0000000
--- a/spark/spark-2.0/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.xgboost.regression;
-
-import java.util.UUID;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-
-/** An alternative implementation of 
[[hivemall.xgboost.regression.XGBoostRegressionUDTF]]. */
-@Description(
-    name = "train_xgboost_regr",
-    value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>"
-)
-public class XGBoostRegressionUDTFWrapper extends XGBoostRegressionUDTF {
-    private long sequence;
-    private long taskId;
-
-    public XGBoostRegressionUDTFWrapper() {
-        this.sequence = 0L;
-        this.taskId = Thread.currentThread().getId();
-    }
-
-    @Override
-    protected String generateUniqueModelId() {
-        sequence++;
-        /**
-         * TODO: Check if it is unique over all tasks in executors of Spark.
-         */
-        return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence;
-    }
-}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 7b13892..b5299ef 100644
--- a/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.0/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -519,7 +519,7 @@ final class HivemallOps(df: DataFrame) extends Logging {
   def train_xgboost_regr(exprs: Column*): DataFrame = withTypedPlan {
     planHiveGenericUDTF(
       df,
-      "hivemall.xgboost.regression.XGBoostRegressionUDTFWrapper",
+      "hivemall.xgboost.regression.XGBoostRegressionUDTF",
       "train_xgboost_regr",
       setMixServs(toHivemallFeatures(exprs)),
       Seq("model_id", "pred_model")
@@ -536,7 +536,7 @@ final class HivemallOps(df: DataFrame) extends Logging {
   def train_xgboost_classifier(exprs: Column*): DataFrame = withTypedPlan {
     planHiveGenericUDTF(
       df,
-      "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTFWrapper",
+      "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF",
       "train_xgboost_classifier",
       setMixServs(toHivemallFeatures(exprs)),
       Seq("model_id", "pred_model")
@@ -553,7 +553,7 @@ final class HivemallOps(df: DataFrame) extends Logging {
   def train_xgboost_multiclass_classifier(exprs: Column*): DataFrame = 
withTypedPlan {
     planHiveGenericUDTF(
       df,
-      "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper",
+      "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF",
       "train_xgboost_multiclass_classifier",
       setMixServs(toHivemallFeatures(exprs)),
       Seq("model_id", "pred_model")

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java b/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java
deleted file mode 100644
index 310d15e..0000000
--- a/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostBinaryClassifierUDTFWrapper.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.xgboost.classification;
-
-import java.util.UUID;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-
-/** An alternative implementation of 
[[hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF]]. */
-@Description(
-    name = "train_xgboost_classifier",
-    value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>"
-)
-public class XGBoostBinaryClassifierUDTFWrapper extends 
XGBoostBinaryClassifierUDTF {
-    private long sequence;
-    private long taskId;
-
-    public XGBoostBinaryClassifierUDTFWrapper() {
-        this.sequence = 0L;
-        this.taskId = Thread.currentThread().getId();
-    }
-
-    @Override
-    protected String generateUniqueModelId() {
-        sequence++;
-        /**
-         * TODO: Check if it is unique over all tasks in executors of Spark.
-         */
-        return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence;
-    }
-}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java b/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java
deleted file mode 100644
index 81e6fe8..0000000
--- a/spark/spark-2.1/src/main/java/hivemall/xgboost/classification/XGBoostMulticlassClassifierUDTFWrapper.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.xgboost.classification;
-
-import java.util.UUID;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-
-/** An alternative implementation of 
[[hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper]]. */
-@Description(
-    name = "train_multiclass_xgboost_classifier",
-    value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>"
-)
-public class XGBoostMulticlassClassifierUDTFWrapper extends 
XGBoostMulticlassClassifierUDTF {
-    private long sequence;
-    private long taskId;
-
-    public XGBoostMulticlassClassifierUDTFWrapper() {
-        this.sequence = 0L;
-        this.taskId = Thread.currentThread().getId();
-    }
-
-    @Override
-    protected String generateUniqueModelId() {
-        sequence++;
-        /**
-         * TODO: Check if it is unique over all tasks in executors of Spark.
-         */
-        return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence;
-    }
-}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java b/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java
deleted file mode 100644
index b72e045..0000000
--- a/spark/spark-2.1/src/main/java/hivemall/xgboost/regression/XGBoostRegressionUDTFWrapper.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package hivemall.xgboost.regression;
-
-import java.util.UUID;
-
-import org.apache.hadoop.hive.ql.exec.Description;
-
-/** An alternative implementation of 
[[hivemall.xgboost.regression.XGBoostRegressionUDTF]]. */
-@Description(
-    name = "train_xgboost_regr",
-    value = "_FUNC_(string[] features, double target [, string options]) - 
Returns a relation consisting of <string model_id, array<byte> pred_model>"
-)
-public class XGBoostRegressionUDTFWrapper extends XGBoostRegressionUDTF {
-    private long sequence;
-    private long taskId;
-
-    public XGBoostRegressionUDTFWrapper() {
-        this.sequence = 0L;
-        this.taskId = Thread.currentThread().getId();
-    }
-
-    @Override
-    protected String generateUniqueModelId() {
-        sequence++;
-        /**
-         * TODO: Check if it is unique over all tasks in executors of Spark.
-         */
-        return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence;
-    }
-}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 83129a7..9350a81 100644
--- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -523,7 +523,7 @@ final class HivemallOps(df: DataFrame) extends Logging {
   def train_xgboost_regr(exprs: Column*): DataFrame = withTypedPlan {
     planHiveGenericUDTF(
       df,
-      "hivemall.xgboost.regression.XGBoostRegressionUDTFWrapper",
+      "hivemall.xgboost.regression.XGBoostRegressionUDTF",
       "train_xgboost_regr",
       setMixServs(toHivemallFeatures(exprs)),
       Seq("model_id", "pred_model")
@@ -540,7 +540,7 @@ final class HivemallOps(df: DataFrame) extends Logging {
   def train_xgboost_classifier(exprs: Column*): DataFrame = withTypedPlan {
     planHiveGenericUDTF(
       df,
-      "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTFWrapper",
+      "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF",
       "train_xgboost_classifier",
       setMixServs(toHivemallFeatures(exprs)),
       Seq("model_id", "pred_model")
@@ -557,7 +557,7 @@ final class HivemallOps(df: DataFrame) extends Logging {
   def train_xgboost_multiclass_classifier(exprs: Column*): DataFrame = 
withTypedPlan {
     planHiveGenericUDTF(
       df,
-      "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper",
+      "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF",
       "train_xgboost_multiclass_classifier",
       setMixServs(toHivemallFeatures(exprs)),
       Seq("model_id", "pred_model")

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java
index e05755e..a175dd2 100644
--- a/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java
+++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostPredictUDTF.java
@@ -140,7 +140,7 @@ public abstract class XGBoostPredictUDTF extends 
UDTFWithOptions {
         try {
             return XGBoost.loadModel(new ByteArrayInputStream(input));
         } catch (Exception e) {
-            throw new HiveException(e.getMessage());
+            throw new HiveException(e);
         }
     }
 
@@ -151,7 +151,7 @@ public abstract class XGBoostPredictUDTF extends 
UDTFWithOptions {
             final float[][] predicted = model.predict(testData);
             forwardPredicted(buf, predicted);
         } catch (Exception e) {
-            throw new HiveException(e.getMessage());
+            throw new HiveException(e);
         }
         buf.clear();
     }
@@ -160,14 +160,18 @@ public abstract class XGBoostPredictUDTF extends 
UDTFWithOptions {
     public void process(Object[] args) throws HiveException {
         if (args[1] != null) {
             final String rowId = 
PrimitiveObjectInspectorUtils.getString(args[0], rowIdOI);
-            final List<String> features = (List<String>) 
featureListOI.getList(args[1]);
+            final List<?> features = (List<?>) featureListOI.getList(args[1]);
+            final String[] fv = new String[features.size()];
+            for (int i = 0; i < features.size(); i++) {
+                fv[i] = (String) 
featureElemOI.getPrimitiveJavaObject(features.get(i));
+            }
             final String modelId = 
PrimitiveObjectInspectorUtils.getString(args[2], modelIdOI);
             if (!mapToModel.containsKey(modelId)) {
                 final byte[] predModel = 
PrimitiveObjectInspectorUtils.getBinary(args[3], modelOI)
                                                                       
.getBytes();
                 mapToModel.put(modelId, initXgBooster(predModel));
             }
-            final LabeledPoint point = XGBoostUtils.parseFeatures(0.f, 
features);
+            final LabeledPoint point = XGBoostUtils.parseFeatures(0.f, fv);
             if (point != null) {
                 if (!rowBuffer.containsKey(modelId)) {
                     rowBuffer.put(modelId, new ArrayList());

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java
index b57925a..059cb1c 100644
--- a/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java
+++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostUDTF.java
@@ -21,6 +21,7 @@ package hivemall.xgboost;
 import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationTargetException;
 import java.util.*;
+import javax.annotation.Nonnull;
 
 import ml.dmlc.xgboost4j.LabeledPoint;
 import ml.dmlc.xgboost4j.java.Booster;
@@ -232,8 +233,8 @@ public abstract class XGBoostUDTF extends UDTFWithOptions {
             // Try to create a `Booster` instance to check if given XGBoost 
options
             // are valid, or not.
             createXGBooster(params, featuresList);
-        } catch (XGBoostError e) {
-            throw new UDFArgumentException(e.getMessage());
+        } catch (Exception e) {
+            throw new UDFArgumentException(e);
         }
 
         return cl;
@@ -264,49 +265,38 @@ public abstract class XGBoostUDTF extends UDTFWithOptions 
{
     /** It `target` has valid input range, it overrides this */
     public void checkTargetValue(double target) throws HiveException {}
 
-
     @Override
     public void process(Object[] args) throws HiveException {
         if (args[0] != null) {
             // TODO: Need to support dense inputs
-            final List<String> features = (List<String>) 
featureListOI.getList(args[0]);
+            final List<?> features = (List<?>) featureListOI.getList(args[0]);
+            final String[] fv = new String[features.size()];
+            for (int i = 0; i < features.size(); i++) {
+                fv[i] = (String) 
featureElemOI.getPrimitiveJavaObject(features.get(i));
+            }
             double target = PrimitiveObjectInspectorUtils.getDouble(args[1], 
this.targetOI);
             checkTargetValue(target);
-            final LabeledPoint point = XGBoostUtils.parseFeatures(target, 
features);
+            final LabeledPoint point = XGBoostUtils.parseFeatures(target, fv);
             if (point != null) {
                 this.featuresList.add(point);
             }
         }
     }
 
-    /**
-     * Need to override this for a Spark wrapper because `MapredContext` does 
not work in there.
-     */
-    protected String generateUniqueModelId() {
-        return "xgbmodel-" + String.valueOf(HadoopUtils.getTaskId());
+    private String generateUniqueModelId() {
+        return "xgbmodel-" + HadoopUtils.getUniqueTaskIdString();
     }
 
-    private static Booster createXGBooster(final Map<String, Object> params,
-            final List<LabeledPoint> input) throws XGBoostError {
-        try {
-            Class<?>[] args = {Map.class, DMatrix[].class};
-            Constructor<Booster> ctor;
-            ctor = Booster.class.getDeclaredConstructor(args);
-            ctor.setAccessible(true);
-            return ctor.newInstance(new Object[] {params,
-                    new DMatrix[] {new DMatrix(input.iterator(), "")}});
-        } catch (InstantiationException e) {
-            // Catch java reflection error as fast as possible
-            e.printStackTrace();
-        } catch (IllegalAccessException e) {
-            e.printStackTrace();
-        } catch (InvocationTargetException e) {
-            e.printStackTrace();
-        } catch (NoSuchMethodException e) {
-            e.printStackTrace();
-        }
-        // No one reach here
-        return null;
+    @Nonnull
+    private static Booster createXGBooster(
+            final Map<String, Object> params, final List<LabeledPoint> input)
+            throws NoSuchMethodException, XGBoostError, IllegalAccessException,
+                InvocationTargetException, InstantiationException {
+        Class<?>[] args = {Map.class, DMatrix[].class};
+        Constructor<Booster> ctor = Booster.class.getDeclaredConstructor(args);
+        ctor.setAccessible(true);
+        return ctor.newInstance(new Object[] {params,
+                new DMatrix[] {new DMatrix(input.iterator(), "")}});
     }
 
     @Override
@@ -326,7 +316,7 @@ public abstract class XGBoostUDTF extends UDTFWithOptions {
             logger.info("model_id:" + modelId.toString() + " size:" + 
predModel.length);
             forward(new Object[] {modelId, predModel});
         } catch (Exception e) {
-            throw new HiveException(e.getMessage());
+            throw new HiveException(e);
         }
     }
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/04372d49/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java
----------------------------------------------------------------------
diff --git a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java b/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java
index d0769f4..632d2fe 100644
--- a/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java
+++ b/xgboost/src/main/java/hivemall/xgboost/XGBoostUtils.java
@@ -27,18 +27,18 @@ public final class XGBoostUtils {
     private XGBoostUtils() {}
 
     /** Transform List<String> inputs into a XGBoost input format */
-    public static LabeledPoint parseFeatures(double target, List<String> 
features) {
-        final int size = features.size();
+    public static LabeledPoint parseFeatures(double target, String[] features) 
{
+        final int size = features.length;
         if (size == 0) {
             return null;
         }
         final int[] indices = new int[size];
         final float[] values = new float[size];
         for (int i = 0; i < size; i++) {
-            if (features.get(i) == null) {
+            if (features[i] == null) {
                 continue;
             }
-            final String str = features.get(i);
+            final String str = features[i];
             final int pos = str.indexOf(':');
             if (pos >= 1) {
                 indices[i] = Integer.parseInt(str.substring(0, pos));

Reply via email to