Repository: incubator-hivemall Updated Branches: refs/heads/master 30593b14b -> 47d1100c1
[HIVEMALL-218] Fixed train_lda NPE where input row is null ## What changes were proposed in this pull request? Fixed a NegativeArraySizeException in `train_lda` that was thrown when the input row is NULL ## What type of PR is it? Bug Fix ## What is the Jira issue? https://issues.apache.org/jira/browse/HIVEMALL-218 ## How was this patch tested? manual tests ## Checklist - [x] Did you apply source code formatter, i.e., `./bin/format_code.sh`, for your commit? - [x] Did you run system tests on Hive (or Spark)? Author: Makoto Yui <[email protected]> Closes #164 from myui/HIVEMALL-218. Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/47d1100c Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/47d1100c Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/47d1100c Branch: refs/heads/master Commit: 47d1100c1fab6796f09f0998624b3a445869f1d4 Parents: 30593b1 Author: Makoto Yui <[email protected]> Authored: Fri Sep 7 19:19:35 2018 +0900 Committer: Makoto Yui <[email protected]> Committed: Fri Sep 7 19:19:35 2018 +0900 ---------------------------------------------------------------------- .../ProbabilisticTopicModelBaseUDTF.java | 23 +++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/47d1100c/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java index 33d940d..23a021d 100644 --- a/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/ProbabilisticTopicModelBaseUDTF.java @@ -57,6 +57,8 @@ import 
org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.Counters; import org.apache.hadoop.mapred.Reporter; +import com.google.common.base.Preconditions; + public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(ProbabilisticTopicModelBaseUDTF.class); @@ -159,11 +161,17 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions { this.model = createModel(); } - final int length = wordCountsOI.getListLength(args[0]); + Preconditions.checkArgument(args.length >= 1); + Object arg0 = args[0]; + if (arg0 == null) { + return; + } + + final int length = wordCountsOI.getListLength(arg0); final String[] wordCounts = new String[length]; int j = 0; for (int i = 0; i < length; i++) { - Object o = wordCountsOI.getListElement(args[0], i); + Object o = wordCountsOI.getListElement(arg0, i); if (o == null) { throw new HiveException("Given feature vector contains invalid null elements"); } @@ -268,6 +276,10 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions { @Override public void close() throws HiveException { + if (model.getDocCount() == 0L) { + this.model = null; + throw new HiveException("No training exmples to learn. 
Please revise input data."); + } finalizeTraining(); forwardModel(); this.model = null; @@ -275,10 +287,6 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions { @VisibleForTesting void finalizeTraining() throws HiveException { - if (model.getDocCount() == 0L) { - this.model = null; - return; - } if (miniBatchCount > 0) { // update for remaining samples model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); } @@ -462,6 +470,9 @@ public abstract class ProbabilisticTopicModelBaseUDTF extends UDTFWithOptions { topicIdx.set(k); final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k); + if (topicWords == null) { + continue; + } for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { score.set(e.getKey().floatValue()); for (String v : e.getValue()) {
