Close #71: [HIVEMALL-74] Implement pLSA
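This change adds probabilistic latent semantic analysis (pLSA) as a counterpart to the existing LDA support: an incremental mini-batch EM trainer (IncrementalPLSAModel), a train_plsa UDTF, and a plsa_predict UDAF, together with DDL registrations and a gitbook page. For each document in a mini-batch, the E-step sets P(z|d,w) proportional to P(z|d) * P(w|z), and the M-step re-estimates P(z|d) and P(w|z) from n(d,w) * P(z|d,w), carrying over the previous P(w|z) weighted by -alpha; the two steps repeat until the average change in P(z|d) falls below -delta. The LDA option -topic is also renamed to -topics for consistency with the new functions.

A minimal, illustrative sketch of the new in-memory API, using only the constructor and methods that appear in the diff below (the wrapper class and the convergence loop are this sketch's own; feature strings follow Hivemall's "word:value" format):

    import hivemall.topicmodel.IncrementalPLSAModel;

    public class PLSAExample {
        public static void main(String[] args) {
            // K=2 topics; alpha weights the previous P(w|z); delta is the P(z|d) convergence threshold
            IncrementalPLSAModel model = new IncrementalPLSAModel(2, 0.5f, 1E-5d);

            String[] doc1 = {"fruits:1", "healthy:1", "vegetables:1"};
            String[] doc2 = {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", "oranges:1"};

            // train on one mini-batch of two documents until the perplexity stabilizes
            float prev = Float.MAX_VALUE;
            for (int i = 0; i < 1024; i++) {
                model.train(new String[][] {doc1, doc2});
                float perplexity = model.computePerplexity();
                if (Math.abs(prev - perplexity) < 1E-4f) {
                    break;
                }
                prev = perplexity;
            }

            float[] p_dz = model.getTopicDistribution(doc1);    // P(z|doc1)
            float p_wz = model.getProbability("vegetables", 0); // P(w="vegetables"|z=0)
            System.out.println(p_dz[0] + " " + p_wz);
        }
    }

The same flow is exercised by testMiniBatch in IncrementalPLSAModelTest; in Hive, it is exposed through the train_plsa and plsa_predict functions registered in the DDL files.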
Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f2bf3a72 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f2bf3a72 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f2bf3a72 Branch: refs/heads/master Commit: f2bf3a72b2f8deb0835feed649369c885a23053c Parents: bffd2c7 Author: Takuya Kitazawa <[email protected]> Authored: Thu Apr 27 22:44:44 2017 +0900 Committer: myui <[email protected]> Committed: Thu Apr 27 22:44:44 2017 +0900 ---------------------------------------------------------------------- .../topicmodel/IncrementalPLSAModel.java | 316 +++++++++++ .../hivemall/topicmodel/LDAPredictUDAF.java | 24 +- .../main/java/hivemall/topicmodel/LDAUDTF.java | 22 +- .../hivemall/topicmodel/PLSAPredictUDAF.java | 480 +++++++++++++++++ .../main/java/hivemall/topicmodel/PLSAUDTF.java | 535 +++++++++++++++++++ .../java/hivemall/utils/lang/ArrayUtils.java | 10 + .../java/hivemall/utils/math/MathUtils.java | 12 + .../topicmodel/IncrementalPLSAModelTest.java | 291 ++++++++++ .../hivemall/topicmodel/LDAPredictUDAFTest.java | 4 +- .../java/hivemall/topicmodel/LDAUDTFTest.java | 2 +- .../topicmodel/PLSAPredictUDAFTest.java | 217 ++++++++ .../java/hivemall/topicmodel/PLSAUDTFTest.java | 106 ++++ docs/gitbook/SUMMARY.md | 1 + docs/gitbook/clustering/plsa.md | 154 ++++++ resources/ddl/define-all-as-permanent.hive | 6 + resources/ddl/define-all.hive | 6 + resources/ddl/define-all.spark | 6 + resources/ddl/define-udfs.td.hql | 2 + 18 files changed, 2168 insertions(+), 26 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java new file mode 100644 index 0000000..745e510 --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/IncrementalPLSAModel.java @@ -0,0 +1,316 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.topicmodel; + +import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray; +import static hivemall.utils.math.MathUtils.l1normalize; +import hivemall.model.FeatureValue; +import hivemall.utils.math.MathUtils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.SortedMap; +import java.util.TreeMap; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +public final class IncrementalPLSAModel { + + // --------------------------------- + // HyperParameters + + // number of topics + private final int _K; + + // control how much P(w|z) update is affected by the last value + private final float _alpha; + + // check convergence of P(w|z) for a document + private final double _delta; + + // --------------------------------- + + // random number generator + @Nonnull + private final Random _rnd; + + // optimized in the E step + private List<Map<String, float[]>> _p_dwz; // P(z|d,w) probability of topics for each document-word (i.e., instance-feature) pair + + // optimized in the M step + @Nonnull + private List<float[]> _p_dz; // P(z|d) probability of topics for documents + private Map<String, float[]> _p_zw; // P(w|z) probability of words for each topic + + @Nonnull + private final List<Map<String, Float>> _miniBatchDocs; + private int _miniBatchSize; + + public IncrementalPLSAModel(int K, float alpha, double delta) { + this._K = K; + this._alpha = alpha; + this._delta = delta; + + this._rnd = new Random(1001); + + this._p_zw = new HashMap<String, float[]>(); + + this._miniBatchDocs = new ArrayList<Map<String, Float>>(); + } + + public void train(@Nonnull final String[][] miniBatch) { + initMiniBatch(miniBatch, _miniBatchDocs); + + this._miniBatchSize = _miniBatchDocs.size(); + + initParams(); + + final List<float[]> pPrev_dz = new ArrayList<float[]>(); + + for (int d = 0; d < _miniBatchSize; d++) { + do { + pPrev_dz.clear(); + pPrev_dz.addAll(_p_dz); + + // Expectation + eStep(d); + + // Maximization + mStep(d); + } while (!isPdzConverged(d, pPrev_dz, _p_dz)); // until get stable value of P(z|d) + } + } + + private static void initMiniBatch(@Nonnull final String[][] miniBatch, + @Nonnull final List<Map<String, Float>> docs) { + docs.clear(); + + final FeatureValue probe = new FeatureValue(); + + // parse document + for (final String[] e : miniBatch) { + if (e == null || e.length == 0) { + continue; + } + + final Map<String, Float> doc = new HashMap<String, Float>(); + + // parse features + for (String fv : e) { + if (fv == null) { + continue; + } + FeatureValue.parseFeatureAsString(fv, probe); + String word = probe.getFeatureAsString(); + float value = probe.getValueAsFloat(); + doc.put(word, Float.valueOf(value)); + } + + docs.add(doc); + } + } + + private void initParams() { + final List<float[]> p_dz = new ArrayList<float[]>(); + final List<Map<String, float[]>> p_dwz = new ArrayList<Map<String, float[]>>(); + + for (int d = 0; d < _miniBatchSize; d++) { + // init P(z|d) + float[] p_dz_d = l1normalize(newRandomFloatArray(_K, _rnd)); + p_dz.add(p_dz_d); + + final Map<String, float[]> p_dwz_d = new HashMap<String, float[]>(); + p_dwz.add(p_dwz_d); + + for (final String w : _miniBatchDocs.get(d).keySet()) { + // init P(z|d,w) + float[] p_dwz_dw = l1normalize(newRandomFloatArray(_K, _rnd)); + p_dwz_d.put(w, p_dwz_dw); + + // insert new labels to P(w|z) + if (!_p_zw.containsKey(w)) { + _p_zw.put(w, 
newRandomFloatArray(_K, _rnd)); + } + } + } + + // ensure \sum_w P(w|z) = 1 + final double[] sums = new double[_K]; + for (float[] p_zw_w : _p_zw.values()) { + MathUtils.add(p_zw_w, sums, _K); + } + for (float[] p_zw_w : _p_zw.values()) { + for (int z = 0; z < _K; z++) { + p_zw_w[z] /= sums[z]; + } + } + + this._p_dz = p_dz; + this._p_dwz = p_dwz; + } + + private void eStep(@Nonnegative final int d) { + final Map<String, float[]> p_dwz_d = _p_dwz.get(d); + final float[] p_dz_d = _p_dz.get(d); + + // update P(z|d,w) = P(z|d) * P(w|z) + for (final String w : _miniBatchDocs.get(d).keySet()) { + final float[] p_dwz_dw = p_dwz_d.get(w); + final float[] p_zw_w = _p_zw.get(w); + for (int z = 0; z < _K; z++) { + p_dwz_dw[z] = p_dz_d[z] * p_zw_w[z]; + } + l1normalize(p_dwz_dw); + } + } + + private void mStep(@Nonnegative final int d) { + final Map<String, Float> doc = _miniBatchDocs.get(d); + final Map<String, float[]> p_dwz_d = _p_dwz.get(d); + + // update P(z|d) = n(d,w) * P(z|d,w) + final float[] p_dz_d = _p_dz.get(d); + Arrays.fill(p_dz_d, 0.f); // zero-fill w/ keeping pointer to _p_dz.get(d) + for (Map.Entry<String, Float> e : doc.entrySet()) { + final float[] p_dwz_dw = p_dwz_d.get(e.getKey()); + final float n = e.getValue().floatValue(); + for (int z = 0; z < _K; z++) { + p_dz_d[z] += n * p_dwz_dw[z]; + } + } + l1normalize(p_dz_d); + + // update P(w|z) = n(d,w) * P(z|d,w) + alpha * P(w|z)^(n-1) + final double[] sums = new double[_K]; + for (Map.Entry<String, float[]> e : _p_zw.entrySet()) { + String w = e.getKey(); + final float[] p_zw_w = e.getValue(); + + Float w_value = doc.get(w); + if (w_value != null) { // all words in the document + final float n = w_value.floatValue(); + final float[] p_dwz_dw = p_dwz_d.get(w); + + for (int z = 0; z < _K; z++) { + p_zw_w[z] = n * p_dwz_dw[z] + _alpha * p_zw_w[z]; + } + } + + MathUtils.add(p_zw_w, sums, _K); + } + // normalize to ensure \sum_w P(w|z) = 1 + for (float[] p_zw_w : _p_zw.values()) { + for (int z = 0; z < _K; z++) { + p_zw_w[z] /= sums[z]; + } + } + } + + private boolean isPdzConverged(@Nonnegative final int d, @Nonnull final List<float[]> pPrev_dz, + @Nonnull final List<float[]> p_dz) { + final float[] pPrev_dz_d = pPrev_dz.get(d); + final float[] p_dz_d = p_dz.get(d); + + double diff = 0.d; + for (int z = 0; z < _K; z++) { + diff += Math.abs(pPrev_dz_d[z] - p_dz_d[z]); + } + return (diff / _K) < _delta; + } + + public float computePerplexity() { + double numer = 0.d; + double denom = 0.d; + + for (int d = 0; d < _miniBatchSize; d++) { + final float[] p_dz_d = _p_dz.get(d); + for (Map.Entry<String, Float> e : _miniBatchDocs.get(d).entrySet()) { + String w = e.getKey(); + float w_value = e.getValue().floatValue(); + + final float[] p_zw_w = _p_zw.get(w); + double p_dw = 0.d; + for (int z = 0; z < _K; z++) { + p_dw += (double) p_zw_w[z] * p_dz_d[z]; + } + + numer += w_value * Math.log(p_dw); + denom += w_value; + } + } + + return (float) Math.exp(-1.d * (numer / denom)); + } + + @Nonnull + public SortedMap<Float, List<String>> getTopicWords(@Nonnegative final int z) { + final SortedMap<Float, List<String>> res = new TreeMap<Float, List<String>>( + Collections.reverseOrder()); + + for (Map.Entry<String, float[]> e : _p_zw.entrySet()) { + final String w = e.getKey(); + final float prob = e.getValue()[z]; + + List<String> words = res.get(prob); + if (words == null) { + words = new ArrayList<String>(); + res.put(prob, words); + } + words.add(w); + } + + return res; + } + + @Nonnull + public float[] getTopicDistribution(@Nonnull final String[] 
doc) { + train(new String[][] {doc}); + return _p_dz.get(0); + } + + public float getProbability(@Nonnull final String w, @Nonnegative final int z) { + return _p_zw.get(w)[z]; + } + + public void setProbability(@Nonnull final String w, @Nonnegative final int z, final float prob) { + float[] prob_label = _p_zw.get(w); + if (prob_label == null) { + prob_label = newRandomFloatArray(_K, _rnd); + _p_zw.put(w, prob_label); + } + prob_label[z] = prob; + + // ensure \sum_w P(w|z) = 1 + final double[] sums = new double[_K]; + for (float[] p_zw_w : _p_zw.values()) { + MathUtils.add(p_zw_w, sums, _K); + } + for (float[] p_zw_w : _p_zw.values()) { + for (int zi = 0; zi < _K; zi++) { + p_zw_w[zi] /= sums[zi]; + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java index 811af2e..a4076b6 100644 --- a/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java +++ b/core/src/main/java/hivemall/topicmodel/LDAPredictUDAF.java @@ -112,7 +112,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { private PrimitiveObjectInspector lambdaOI; // Hyperparameters - private int topic; + private int topics; private float alpha; private double delta; @@ -134,7 +134,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { protected Options getOptions() { Options opts = new Options(); - opts.addOption("k", "topic", true, "The number of topics [required]"); + opts.addOption("k", "topics", true, "The number of topics [required]"); opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]"); opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-5]"); @@ -176,19 +176,19 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { CommandLine cl = null; if (argOIs.length != 5) { - throw new UDFArgumentException("At least 1 option `-topic` MUST be specified"); + throw new UDFArgumentException("At least 1 option `-topics` MUST be specified"); } String rawArgs = HiveUtils.getConstString(argOIs[4]); cl = parseOptions(rawArgs); - this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 0); - if (topic < 1) { + this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 0); + if (topics < 1) { throw new UDFArgumentException( - "A positive integer MUST be set to an option `-topic`: " + topic); + "A positive integer MUST be set to an option `-topics`: " + topics); } - this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic); + this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics); this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-5d); return cl; @@ -211,7 +211,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { this.internalMergeOI = soi; this.wcListField = soi.getStructFieldRef("wcList"); this.lambdaMapField = soi.getStructFieldRef("lambdaMap"); - this.topicOptionField = soi.getStructFieldRef("topic"); + this.topicOptionField = soi.getStructFieldRef("topics"); this.alphaOptionField = soi.getStructFieldRef("alpha"); this.deltaOptionField = soi.getStructFieldRef("delta"); this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; @@ -253,7 +253,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { 
PrimitiveObjectInspectorFactory.javaStringObjectInspector, ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector))); - fieldNames.add("topic"); + fieldNames.add("topics"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); fieldNames.add("alpha"); @@ -278,7 +278,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { throws HiveException { OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; myAggr.reset(); - myAggr.setOptions(topic, alpha, delta); + myAggr.setOptions(topics, alpha, delta); } @Override @@ -359,7 +359,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { // restore options from partial result Object topicObj = internalMergeOI.getStructFieldData(partial, topicOptionField); - this.topic = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj); + this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj); Object alphaObj = internalMergeOI.getStructFieldData(partial, alphaOptionField); this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj); @@ -368,7 +368,7 @@ public final class LDAPredictUDAF extends AbstractGenericUDAFResolver { this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(deltaObj); OnlineLDAPredictAggregationBuffer myAggr = (OnlineLDAPredictAggregationBuffer) agg; - myAggr.setOptions(topic, alpha, delta); + myAggr.setOptions(topics, alpha, delta); myAggr.merge(wcList, lambdaMap); } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/LDAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java index 9aa15e2..1e28a30 100644 --- a/core/src/main/java/hivemall/topicmodel/LDAUDTF.java +++ b/core/src/main/java/hivemall/topicmodel/LDAUDTF.java @@ -63,7 +63,7 @@ public class LDAUDTF extends UDTFWithOptions { private static final Log logger = LogFactory.getLog(LDAUDTF.class); // Options - protected int topic; + protected int topics; protected float alpha; protected float eta; protected long numDocs; @@ -93,9 +93,9 @@ public class LDAUDTF extends UDTFWithOptions { protected ByteBuffer inputBuf; public LDAUDTF() { - this.topic = 10; - this.alpha = 1.f / topic; - this.eta = 1.f / topic; + this.topics = 10; + this.alpha = 1.f / topics; + this.eta = 1.f / topics; this.numDocs = -1L; this.tau0 = 64.d; this.kappa = 0.7; @@ -108,7 +108,7 @@ public class LDAUDTF extends UDTFWithOptions { @Override protected Options getOptions() { Options opts = new Options(); - opts.addOption("k", "topic", true, "The number of topics [default: 10]"); + opts.addOption("k", "topics", true, "The number of topics [default: 10]"); opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]"); opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]"); opts.addOption("d", "num_docs", true, "The total number of documents [default: auto]"); @@ -131,9 +131,9 @@ public class LDAUDTF extends UDTFWithOptions { if (argOIs.length >= 2) { String rawArgs = HiveUtils.getConstString(argOIs[1]); cl = parseOptions(rawArgs); - this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 10); - this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic); - this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / 
topic); + this.topics = Primitives.parseInt(cl.getOptionValue("topics"), 10); + this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topics); + this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topics); this.numDocs = Primitives.parseLong(cl.getOptionValue("num_docs"), -1L); this.tau0 = Primitives.parseDouble(cl.getOptionValue("tau0"), 64.d); if (tau0 <= 0.d) { @@ -187,7 +187,7 @@ public class LDAUDTF extends UDTFWithOptions { } protected void initModel() { - this.model = new OnlineLDAModel(topic, alpha, eta, numDocs, tau0, kappa, delta); + this.model = new OnlineLDAModel(topics, alpha, eta, numDocs, tau0, kappa, delta); } @Override @@ -527,7 +527,7 @@ public class LDAUDTF extends UDTFWithOptions { forwardObjs[1] = word; forwardObjs[2] = score; - for (int k = 0; k < topic; k++) { + for (int k = 0; k < topics; k++) { topicIdx.set(k); final SortedMap<Float, List<String>> topicWords = model.getTopicWords(k); @@ -541,7 +541,7 @@ public class LDAUDTF extends UDTFWithOptions { } } - logger.info("Forwarded topic words each of " + topic + " topics"); + logger.info("Forwarded topic words each of " + topics + " topics"); } /* http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java new file mode 100644 index 0000000..c0b60fc --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/PLSAPredictUDAF.java @@ -0,0 +1,480 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.topicmodel; + +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.lang.CommandLineUtils; +import hivemall.utils.lang.Primitives; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Options; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException; +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.parse.SemanticException; +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator; +import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.StructField; +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; + +@Description(name = "plsa_predict", + value = "_FUNC_(string word, float value, int label, float prob[, const string options])" + + " - Returns a list which consists of <int label, float prob>") +public final class PLSAPredictUDAF extends AbstractGenericUDAFResolver { + + @Override + public Evaluator getEvaluator(TypeInfo[] typeInfo) throws SemanticException { + if (typeInfo.length != 4 && typeInfo.length != 5) { + throw new UDFArgumentLengthException( + "Expected argument length is 4 or 5 but given argument length was " + + typeInfo.length); + } + + if (!HiveUtils.isStringTypeInfo(typeInfo[0])) { + throw new UDFArgumentTypeException(0, + "String type is expected for the first argument word: " + typeInfo[0].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[1])) { + throw new UDFArgumentTypeException(1, + "Number type is expected for the second argument value: " + + typeInfo[1].getTypeName()); + } + if (!HiveUtils.isIntegerTypeInfo(typeInfo[2])) { + throw new UDFArgumentTypeException(2, + "Integer type is expected for the third argument label: " + + typeInfo[2].getTypeName()); + } + if (!HiveUtils.isNumberTypeInfo(typeInfo[3])) { + throw new UDFArgumentTypeException(3, + "Number type is expected for the forth argument prob: " + typeInfo[3].getTypeName()); + } + + if (typeInfo.length == 5) { + if (!HiveUtils.isStringTypeInfo(typeInfo[4])) { + throw new UDFArgumentTypeException(4, + "String type is expected for the fifth argument prob: " + + 
typeInfo[4].getTypeName()); + } + } + + return new Evaluator(); + } + + public static class Evaluator extends GenericUDAFEvaluator { + + // input OI + private PrimitiveObjectInspector wordOI; + private PrimitiveObjectInspector valueOI; + private PrimitiveObjectInspector labelOI; + private PrimitiveObjectInspector probOI; + + // Hyperparameters + private int topics; + private float alpha; + private double delta; + + // merge OI + private StructObjectInspector internalMergeOI; + private StructField wcListField; + private StructField probMapField; + private StructField topicOptionField; + private StructField alphaOptionField; + private StructField deltaOptionField; + private PrimitiveObjectInspector wcListElemOI; + private StandardListObjectInspector wcListOI; + private StandardMapObjectInspector probMapOI; + private PrimitiveObjectInspector probMapKeyOI; + private StandardListObjectInspector probMapValueOI; + private PrimitiveObjectInspector probMapValueElemOI; + + public Evaluator() {} + + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", "topics", true, "The number of topics [default: 10]"); + opts.addOption("alpha", true, "The hyperparameter for P(w|z) update [default: 0.5]"); + opts.addOption("delta", true, + "Check convergence in the expectation step [default: 1E-5]"); + return opts; + } + + @Nonnull + protected final CommandLine parseOptions(String optionValue) throws UDFArgumentException { + String[] args = optionValue.split("\\s+"); + Options opts = getOptions(); + opts.addOption("help", false, "Show function help"); + CommandLine cl = CommandLineUtils.parseOptions(args, opts); + + if (cl.hasOption("help")) { + Description funcDesc = getClass().getAnnotation(Description.class); + final String cmdLineSyntax; + if (funcDesc == null) { + cmdLineSyntax = getClass().getSimpleName(); + } else { + String funcName = funcDesc.name(); + cmdLineSyntax = funcName == null ? 
getClass().getSimpleName() + : funcDesc.value().replace("_FUNC_", funcDesc.name()); + } + StringWriter sw = new StringWriter(); + sw.write('\n'); + PrintWriter pw = new PrintWriter(sw); + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp(pw, HelpFormatter.DEFAULT_WIDTH, cmdLineSyntax, null, opts, + HelpFormatter.DEFAULT_LEFT_PAD, HelpFormatter.DEFAULT_DESC_PAD, null, true); + pw.flush(); + String helpMsg = sw.toString(); + throw new UDFArgumentException(helpMsg); + } + + return cl; + } + + @Nullable + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length != 5) { + return null; + } + + String rawArgs = HiveUtils.getConstString(argOIs[4]); + CommandLine cl = parseOptions(rawArgs); + + this.topics = Primitives.parseInt(cl.getOptionValue("topics"), PLSAUDTF.DEFAULT_TOPICS); + if (topics < 1) { + throw new UDFArgumentException( + "A positive integer MUST be set to an option `-topics`: " + topics); + } + + this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), PLSAUDTF.DEFAULT_ALPHA); + this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), PLSAUDTF.DEFAULT_DELTA); + + return cl; + } + + @Override + public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException { + assert (parameters.length == 4 || parameters.length == 5); + super.init(mode, parameters); + + // initialize input + if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data + processOptions(parameters); + this.wordOI = HiveUtils.asStringOI(parameters[0]); + this.valueOI = HiveUtils.asDoubleCompatibleOI(parameters[1]); + this.labelOI = HiveUtils.asIntegerOI(parameters[2]); + this.probOI = HiveUtils.asDoubleCompatibleOI(parameters[3]); + } else {// from partial aggregation + StructObjectInspector soi = (StructObjectInspector) parameters[0]; + this.internalMergeOI = soi; + this.wcListField = soi.getStructFieldRef("wcList"); + this.probMapField = soi.getStructFieldRef("probMap"); + this.topicOptionField = soi.getStructFieldRef("topics"); + this.alphaOptionField = soi.getStructFieldRef("alpha"); + this.deltaOptionField = soi.getStructFieldRef("delta"); + this.wcListElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + this.wcListOI = ObjectInspectorFactory.getStandardListObjectInspector(wcListElemOI); + this.probMapKeyOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + this.probMapValueElemOI = PrimitiveObjectInspectorFactory.javaStringObjectInspector; + this.probMapValueOI = ObjectInspectorFactory.getStandardListObjectInspector(probMapValueElemOI); + this.probMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(probMapKeyOI, + probMapValueOI); + } + + // initialize output + final ObjectInspector outputOI; + if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial + outputOI = internalMergeOI(); + } else { + final ArrayList<String> fieldNames = new ArrayList<String>(); + final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldNames.add("label"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("probability"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + outputOI = ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorFactory.getStandardStructObjectInspector( + fieldNames, fieldOIs)); + } + return outputOI; + } + + private static StructObjectInspector internalMergeOI() { + ArrayList<String> fieldNames = new 
ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + + fieldNames.add("wcList"); + fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector)); + + fieldNames.add("probMap"); + fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector( + PrimitiveObjectInspectorFactory.javaStringObjectInspector, + ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaFloatObjectInspector))); + + fieldNames.add("topics"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + + fieldNames.add("alpha"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + fieldNames.add("delta"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + @SuppressWarnings("deprecation") + @Override + public AggregationBuffer getNewAggregationBuffer() throws HiveException { + AggregationBuffer myAggr = new PLSAPredictAggregationBuffer(); + reset(myAggr); + return myAggr; + } + + @Override + public void reset(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg; + myAggr.reset(); + myAggr.setOptions(topics, alpha, delta); + } + + @Override + public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, + Object[] parameters) throws HiveException { + PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg; + + if (parameters[0] == null || parameters[1] == null || parameters[2] == null + || parameters[3] == null) { + return; + } + + String word = PrimitiveObjectInspectorUtils.getString(parameters[0], wordOI); + float value = PrimitiveObjectInspectorUtils.getFloat(parameters[1], valueOI); + int label = PrimitiveObjectInspectorUtils.getInt(parameters[2], labelOI); + float prob = PrimitiveObjectInspectorUtils.getFloat(parameters[3], probOI); + + myAggr.iterate(word, value, label, prob); + } + + @Override + public Object terminatePartial(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg; + if (myAggr.wcList.size() == 0) { + return null; + } + + Object[] partialResult = new Object[5]; + partialResult[0] = myAggr.wcList; + partialResult[1] = myAggr.probMap; + partialResult[2] = new IntWritable(myAggr.topics); + partialResult[3] = new FloatWritable(myAggr.alpha); + partialResult[4] = new DoubleWritable(myAggr.delta); + + return partialResult; + } + + @Override + public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object partial) + throws HiveException { + if (partial == null) { + return; + } + + Object wcListObj = internalMergeOI.getStructFieldData(partial, wcListField); + + List<?> wcListRaw = wcListOI.getList(HiveUtils.castLazyBinaryObject(wcListObj)); + + // fix list elements to Java String objects + int wcListSize = wcListRaw.size(); + List<String> wcList = new ArrayList<String>(); + for (int i = 0; i < wcListSize; i++) { + wcList.add(PrimitiveObjectInspectorUtils.getString(wcListRaw.get(i), wcListElemOI)); + } + + Object probMapObj = internalMergeOI.getStructFieldData(partial, probMapField); + Map<?, ?> probMapRaw = probMapOI.getMap(HiveUtils.castLazyBinaryObject(probMapObj)); + + Map<String, List<Float>> probMap = new HashMap<String, List<Float>>(); + 
for (Map.Entry<?, ?> e : probMapRaw.entrySet()) { + // fix map keys to Java String objects + String word = PrimitiveObjectInspectorUtils.getString(e.getKey(), probMapKeyOI); + + Object probMapValueObj = e.getValue(); + List<?> probMapValueRaw = probMapValueOI.getList(HiveUtils.castLazyBinaryObject(probMapValueObj)); + + // fix map values to lists of Java Float objects + int probMapValueSize = probMapValueRaw.size(); + List<Float> prob_word = new ArrayList<Float>(); + for (int i = 0; i < probMapValueSize; i++) { + prob_word.add(HiveUtils.getFloat(probMapValueRaw.get(i), probMapValueElemOI)); + } + + probMap.put(word, prob_word); + } + + // restore options from partial result + Object topicObj = internalMergeOI.getStructFieldData(partial, topicOptionField); + this.topics = PrimitiveObjectInspectorFactory.writableIntObjectInspector.get(topicObj); + + Object alphaObj = internalMergeOI.getStructFieldData(partial, alphaOptionField); + this.alpha = PrimitiveObjectInspectorFactory.writableFloatObjectInspector.get(alphaObj); + + Object deltaObj = internalMergeOI.getStructFieldData(partial, deltaOptionField); + this.delta = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(deltaObj); + + PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg; + myAggr.setOptions(topics, alpha, delta); + myAggr.merge(wcList, probMap); + } + + @Override + public Object terminate(@SuppressWarnings("deprecation") AggregationBuffer agg) + throws HiveException { + PLSAPredictAggregationBuffer myAggr = (PLSAPredictAggregationBuffer) agg; + float[] topicDistr = myAggr.get(); + + SortedMap<Float, Integer> sortedDistr = new TreeMap<Float, Integer>( + Collections.reverseOrder()); + for (int i = 0; i < topicDistr.length; i++) { + sortedDistr.put(topicDistr[i], i); + } + + List<Object[]> result = new ArrayList<Object[]>(); + for (Map.Entry<Float, Integer> e : sortedDistr.entrySet()) { + Object[] struct = new Object[2]; + struct[0] = new IntWritable(e.getValue().intValue()); // label + struct[1] = new FloatWritable(e.getKey().floatValue()); // probability + result.add(struct); + } + return result; + } + + } + + public static class PLSAPredictAggregationBuffer extends + GenericUDAFEvaluator.AbstractAggregationBuffer { + + private List<String> wcList; + private Map<String, List<Float>> probMap; + + private int topics; + private float alpha; + private double delta; + + PLSAPredictAggregationBuffer() { + super(); + } + + void setOptions(int topics, float alpha, double delta) { + this.topics = topics; + this.alpha = alpha; + this.delta = delta; + } + + void reset() { + this.wcList = new ArrayList<String>(); + this.probMap = new HashMap<String, List<Float>>(); + } + + void iterate(@Nonnull final String word, final float value, final int label, + final float prob) { + wcList.add(word + ":" + value); + + // for an unforeseen word, initialize its probs w/ -1s + List<Float> prob_word = probMap.get(word); + + if (prob_word == null) { + prob_word = new ArrayList<Float>(Collections.nCopies(topics, -1.f)); + probMap.put(word, prob_word); + } + + // set the given prob value + prob_word.set(label, Float.valueOf(prob)); + } + + void merge(@Nonnull final List<String> o_wcList, + @Nonnull final Map<String, List<Float>> o_probMap) { + wcList.addAll(o_wcList); + + for (Map.Entry<String, List<Float>> e : o_probMap.entrySet()) { + String o_word = e.getKey(); + List<Float> o_prob_word = e.getValue(); + + final List<Float> prob_word = probMap.get(o_word); + if (prob_word == null) {// for a partially observed word + 
probMap.put(o_word, o_prob_word); + } else { // for an unforeseen word + for (int k = 0; k < topics; k++) { + final float prob_k = o_prob_word.get(k).floatValue(); + if (prob_k != -1.f) { // not default value + prob_word.set(k, prob_k); // set the partial prob value + } + } + probMap.put(o_word, prob_word); + } + } + } + + float[] get() { + final IncrementalPLSAModel model = new IncrementalPLSAModel(topics, alpha, delta); + + for (Map.Entry<String, List<Float>> e : probMap.entrySet()) { + final String word = e.getKey(); + final List<Float> prob_word = e.getValue(); + for (int k = 0; k < topics; k++) { + final float prob_k = prob_word.get(k).floatValue(); + if (prob_k != -1.f) { + model.setProbability(word, k, prob_k); + } + } + } + + String[] wcArray = wcList.toArray(new String[wcList.size()]); + return model.getTopicDistribution(wcArray); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java new file mode 100644 index 0000000..2616133 --- /dev/null +++ b/core/src/main/java/hivemall/topicmodel/PLSAUDTF.java @@ -0,0 +1,535 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package hivemall.topicmodel; + +import hivemall.UDTFWithOptions; +import hivemall.annotations.VisibleForTesting; +import hivemall.utils.hadoop.HiveUtils; +import hivemall.utils.io.FileUtils; +import hivemall.utils.io.NioStatefullSegment; +import hivemall.utils.lang.NumberUtils; +import hivemall.utils.lang.Primitives; +import hivemall.utils.lang.SizeOf; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.SortedMap; + +import javax.annotation.Nonnegative; +import javax.annotation.Nonnull; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.Options; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.serde2.objectinspector.*; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; +import org.apache.hadoop.io.FloatWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapred.Counters; +import org.apache.hadoop.mapred.Reporter; + +@Description(name = "train_plsa", value = "_FUNC_(array<string> words[, const string options])" + + " - Returns a relation consists of <int topic, string word, float score>") +public class PLSAUDTF extends UDTFWithOptions { + private static final Log logger = LogFactory.getLog(PLSAUDTF.class); + + public static final int DEFAULT_TOPICS = 10; + public static final float DEFAULT_ALPHA = 0.5f; + public static final double DEFAULT_DELTA = 1E-3d; + + // Options + protected int topics; + protected float alpha; + protected int iterations; + protected double delta; + protected double eps; + protected int miniBatchSize; + + // number of proceeded training samples + protected long count; + + protected String[][] miniBatch; + protected int miniBatchCount; + + protected transient IncrementalPLSAModel model; + + protected ListObjectInspector wordCountsOI; + + // for iterations + protected NioStatefullSegment fileIO; + protected ByteBuffer inputBuf; + + public PLSAUDTF() { + this.topics = DEFAULT_TOPICS; + this.alpha = DEFAULT_ALPHA; + this.iterations = 10; + this.delta = DEFAULT_DELTA; + this.eps = 1E-1d; + this.miniBatchSize = 128; + } + + @Override + protected Options getOptions() { + Options opts = new Options(); + opts.addOption("k", "topics", true, "The number of topics [default: 10]"); + opts.addOption("alpha", true, "The hyperparameter for P(w|z) update [default: 0.5]"); + opts.addOption("iter", "iterations", true, "The maximum number of iterations [default: 10]"); + opts.addOption("delta", true, "Check convergence in the expectation step [default: 1E-3]"); + opts.addOption("eps", "epsilon", true, + "Check convergence based on the difference of perplexity [default: 1E-1]"); + opts.addOption("s", "mini_batch_size", true, + "Repeat model updating per mini-batch [default: 128]"); + return opts; + } + + @Override + protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException { + CommandLine cl = null; + + if (argOIs.length >= 2) { + String rawArgs = HiveUtils.getConstString(argOIs[1]); + cl = parseOptions(rawArgs); + this.topics = Primitives.parseInt(cl.getOptionValue("topics"), DEFAULT_TOPICS); + this.alpha = 
Primitives.parseFloat(cl.getOptionValue("alpha"), DEFAULT_ALPHA); + this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 10); + if (iterations < 1) { + throw new UDFArgumentException( + "'-iterations' must be greater than or equals to 1: " + iterations); + } + this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), DEFAULT_DELTA); + this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d); + this.miniBatchSize = Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 128); + } + + return cl; + } + + @Override + public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException { + if (argOIs.length < 1) { + throw new UDFArgumentException( + "_FUNC_ takes 1 arguments: array<string> words [, const string options]"); + } + + this.wordCountsOI = HiveUtils.asListOI(argOIs[0]); + HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector()); + + processOptions(argOIs); + + this.model = null; + this.count = 0L; + this.miniBatch = new String[miniBatchSize][]; + this.miniBatchCount = 0; + + ArrayList<String> fieldNames = new ArrayList<String>(); + ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(); + fieldNames.add("topic"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); + fieldNames.add("word"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector); + fieldNames.add("score"); + fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector); + + return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); + } + + protected void initModel() { + this.model = new IncrementalPLSAModel(topics, alpha, delta); + } + + @Override + public void process(Object[] args) throws HiveException { + if (model == null) { + initModel(); + } + + int length = wordCountsOI.getListLength(args[0]); + String[] wordCounts = new String[length]; + int j = 0; + for (int i = 0; i < length; i++) { + Object o = wordCountsOI.getListElement(args[0], i); + if (o == null) { + throw new HiveException("Given feature vector contains invalid elements"); + } + String s = o.toString(); + wordCounts[j] = s; + j++; + } + if (j == 0) {// avoid empty documents + return; + } + + count++; + + recordTrainSampleToTempFile(wordCounts); + + miniBatch[miniBatchCount] = wordCounts; + miniBatchCount++; + + if (miniBatchCount == miniBatchSize) { + model.train(miniBatch); + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + } + } + + protected void recordTrainSampleToTempFile(@Nonnull final String[] wordCounts) + throws HiveException { + if (iterations == 1) { + return; + } + + ByteBuffer buf = inputBuf; + NioStatefullSegment dst = fileIO; + + if (buf == null) { + final File file; + try { + file = File.createTempFile("hivemall_plsa", ".sgmt"); + file.deleteOnExit(); + if (!file.canWrite()) { + throw new UDFArgumentException("Cannot write a temporary file: " + + file.getAbsolutePath()); + } + logger.info("Record training samples to a file: " + file.getAbsolutePath()); + } catch (IOException ioe) { + throw new UDFArgumentException(ioe); + } catch (Throwable e) { + throw new UDFArgumentException(e); + } + this.inputBuf = buf = ByteBuffer.allocateDirect(1024 * 1024); // 1 MB + this.fileIO = dst = new NioStatefullSegment(file, false); + } + + int wcLength = 0; + for (String wc : wordCounts) { + if (wc == null) { + continue; + } + wcLength += wc.getBytes().length; + } + // recordBytes, wordCounts length, wc1 length, wc1 string, wc2 length, wc2 string, ... 
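+ // the size estimate below is a conservative upper bound; flush the buffer to the temporary file when this record would not fit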
+ int recordBytes = (Integer.SIZE * 2 + Integer.SIZE * wcLength) / 8 + wcLength; + int remain = buf.remaining(); + if (remain < recordBytes) { + writeBuffer(buf, dst); + } + + buf.putInt(recordBytes); + buf.putInt(wordCounts.length); + for (String wc : wordCounts) { + if (wc == null) { + continue; + } + buf.putInt(wc.length()); + buf.put(wc.getBytes()); + } + } + + private static void writeBuffer(@Nonnull ByteBuffer srcBuf, @Nonnull NioStatefullSegment dst) + throws HiveException { + srcBuf.flip(); + try { + dst.write(srcBuf); + } catch (IOException e) { + throw new HiveException("Exception causes while writing a buffer to file", e); + } + srcBuf.clear(); + } + + @Override + public void close() throws HiveException { + if (count == 0) { + this.model = null; + return; + } + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + } + if (iterations > 1) { + runIterativeTraining(iterations); + } + forwardModel(); + this.model = null; + } + + protected final void runIterativeTraining(@Nonnegative final int iterations) + throws HiveException { + final ByteBuffer buf = this.inputBuf; + final NioStatefullSegment dst = this.fileIO; + assert (buf != null); + assert (dst != null); + final long numTrainingExamples = count; + + final Reporter reporter = getReporter(); + final Counters.Counter iterCounter = (reporter == null) ? null : reporter.getCounter( + "hivemall.plsa.IncrementalPLSA$Counter", "iteration"); + + try { + if (dst.getPosition() == 0L) {// run iterations w/o temporary file + if (buf.position() == 0) { + return; // no training example + } + buf.flip(); + + int iter = 2; + float perplexityPrev = Float.MAX_VALUE; + float perplexity; + int numTrain; + for (; iter <= iterations; iter++) { + perplexity = 0.f; + numTrain = 0; + + reportProgress(reporter); + setCounterValue(iterCounter, iter); + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + + while (buf.remaining() > 0) { + int recordBytes = buf.getInt(); + assert (recordBytes > 0) : recordBytes; + int wcLength = buf.getInt(); + final String[] wordCounts = new String[wcLength]; + for (int j = 0; j < wcLength; j++) { + int len = buf.getInt(); + byte[] bytes = new byte[len]; + buf.get(bytes); + wordCounts[j] = new String(bytes); + } + + miniBatch[miniBatchCount] = wordCounts; + miniBatchCount++; + + if (miniBatchCount == miniBatchSize) { + model.train(miniBatch); + perplexity += model.computePerplexity(); + numTrain++; + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + } + } + buf.rewind(); + + // update for remaining samples + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + perplexity += model.computePerplexity(); + numTrain++; + } + + logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain); + perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches + if (Math.abs(perplexityPrev - perplexity) < eps) { + break; + } + perplexityPrev = perplexity; + } + logger.info("Performed " + + Math.min(iter, iterations) + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on memory (thus " + + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) + + " training updates in total) "); + } else {// read training examples in the temporary file and invoke train for each example + + // write training examples in buffer to a temporary file + if (buf.remaining() > 0) { + writeBuffer(buf, dst); + } + 
try { + dst.flush(); + } catch (IOException e) { + throw new HiveException("Failed to flush a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (logger.isInfoEnabled()) { + File tmpFile = dst.getFile(); + logger.info("Wrote " + numTrainingExamples + + " records to a temporary file for iterative training: " + + tmpFile.getAbsolutePath() + " (" + FileUtils.prettyFileSize(tmpFile) + + ")"); + } + + // run iterations + int iter = 2; + float perplexityPrev = Float.MAX_VALUE; + float perplexity; + int numTrain; + for (; iter <= iterations; iter++) { + perplexity = 0.f; + numTrain = 0; + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + + setCounterValue(iterCounter, iter); + + buf.clear(); + dst.resetPosition(); + while (true) { + reportProgress(reporter); + // TODO prefetch + // writes training examples to a buffer in the temporary file + final int bytesRead; + try { + bytesRead = dst.read(buf); + } catch (IOException e) { + throw new HiveException("Failed to read a file: " + + dst.getFile().getAbsolutePath(), e); + } + if (bytesRead == 0) { // reached file EOF + break; + } + assert (bytesRead > 0) : bytesRead; + + // reads training examples from a buffer + buf.flip(); + int remain = buf.remaining(); + if (remain < SizeOf.INT) { + throw new HiveException("Illegal file format was detected"); + } + while (remain >= SizeOf.INT) { + int pos = buf.position(); + int recordBytes = buf.getInt(); + remain -= SizeOf.INT; + if (remain < recordBytes) { + buf.position(pos); + break; + } + + int wcLength = buf.getInt(); + final String[] wordCounts = new String[wcLength]; + for (int j = 0; j < wcLength; j++) { + int len = buf.getInt(); + byte[] bytes = new byte[len]; + buf.get(bytes); + wordCounts[j] = new String(bytes); + } + + miniBatch[miniBatchCount] = wordCounts; + miniBatchCount++; + + if (miniBatchCount == miniBatchSize) { + model.train(miniBatch); + perplexity += model.computePerplexity(); + numTrain++; + + Arrays.fill(miniBatch, null); // clear + miniBatchCount = 0; + } + + remain -= recordBytes; + } + buf.compact(); + } + + // update for remaining samples + if (miniBatchCount > 0) { // update for remaining samples + model.train(Arrays.copyOfRange(miniBatch, 0, miniBatchCount)); + perplexity += model.computePerplexity(); + numTrain++; + } + + logger.info("Perplexity: " + perplexity + ", Num train: " + numTrain); + perplexity /= numTrain; // mean perplexity over `numTrain` mini-batches + if (Math.abs(perplexityPrev - perplexity) < eps) { + break; + } + perplexityPrev = perplexity; + } + logger.info("Performed " + + Math.min(iter, iterations) + + " iterations of " + + NumberUtils.formatNumber(numTrainingExamples) + + " training examples on a secondary storage (thus " + + NumberUtils.formatNumber(numTrainingExamples * Math.min(iter, iterations)) + + " training updates in total)"); + } + } finally { + // delete the temporary file and release resources + try { + dst.close(true); + } catch (IOException e) { + throw new HiveException("Failed to close a file: " + + dst.getFile().getAbsolutePath(), e); + } + this.inputBuf = null; + this.fileIO = null; + } + } + + protected void forwardModel() throws HiveException { + final IntWritable topicIdx = new IntWritable(); + final Text word = new Text(); + final FloatWritable score = new FloatWritable(); + + final Object[] forwardObjs = new Object[3]; + forwardObjs[0] = topicIdx; + forwardObjs[1] = word; + forwardObjs[2] = score; + + for (int k = 0; k < topics; k++) { + topicIdx.set(k); + + final SortedMap<Float, List<String>> topicWords = 
model.getTopicWords(k); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + score.set(e.getKey()); + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + word.set(words.get(i)); + forward(forwardObjs); + } + } + } + + logger.info("Forwarded topic words each of " + topics + " topics"); + } + + /* + * For testing: + */ + + @VisibleForTesting + double getProbability(String label, int k) { + return model.getProbability(label, k); + } + + @VisibleForTesting + SortedMap<Float, List<String>> getTopicWords(int k) { + return model.getTopicWords(k); + } + + @VisibleForTesting + float[] getTopicDistribution(@Nonnull String[] doc) { + return model.getTopicDistribution(doc); + } +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/utils/lang/ArrayUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java index c20c363..4177d70 100644 --- a/core/src/main/java/hivemall/utils/lang/ArrayUtils.java +++ b/core/src/main/java/hivemall/utils/lang/ArrayUtils.java @@ -735,4 +735,14 @@ public final class ArrayUtils { return ret; } + @Nonnull + public static float[] newRandomFloatArray(@Nonnegative final int size, + @Nonnull final Random rnd) { + final float[] ret = new float[size]; + for (int i = 0; i < size; i++) { + ret[i] = rnd.nextFloat(); + } + return ret; + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/main/java/hivemall/utils/math/MathUtils.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index 061b75d..8ffb89c 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -408,4 +408,16 @@ public final class MathUtils { return Math.log(logsumexp) + max; } + @Nonnull + public static float[] l1normalize(@Nonnull final float[] arr) { + double sum = 0.d; + for (int i = 0; i < arr.length; i++) { + sum += Math.abs(arr[i]); + } + for (int i = 0; i < arr.length; i++) { + arr[i] /= sum; + } + return arr; + } + } http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java new file mode 100644 index 0000000..db34a38 --- /dev/null +++ b/core/src/test/java/hivemall/topicmodel/IncrementalPLSAModelTest.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package hivemall.topicmodel; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.SortedMap; +import java.util.Set; +import java.util.HashSet; +import java.util.Arrays; +import java.util.StringTokenizer; +import java.util.zip.GZIPInputStream; + +import hivemall.classifier.KernelExpansionPassiveAggressiveUDTFTest; + +import org.junit.Assert; +import org.junit.Test; + +import javax.annotation.Nonnull; + +public class IncrementalPLSAModelTest { + private static final boolean DEBUG = false; + + @Test + public void testOnline() { + int K = 2; + int it = 0; + int maxIter = 1024; + float perplexityPrev; + float perplexity = Float.MAX_VALUE; + + IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.f, 1E-5d); + + String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"}; + String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", + "oranges:1"}; + + do { + perplexityPrev = perplexity; + perplexity = 0.f; + + // online (i.e., one-by-one) updating + model.train(new String[][] {doc1}); + perplexity += model.computePerplexity(); + + model.train(new String[][] {doc2}); + perplexity += model.computePerplexity(); + + perplexity /= 2.f; // mean perplexity for the 2 docs + + it++; + println("Iteration " + it + ": mean perplexity = " + perplexity); + } while (it < maxIter && Math.abs(perplexityPrev - perplexity) >= 1E-4f); + + SortedMap<Float, List<String>> topicWords; + + println("Topic 0:"); + println("========"); + topicWords = model.getTopicWords(0); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + println("Topic 1:"); + println("========"); + topicWords = model.getTopicWords(1); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + + int k1, k2; + float[] topicDistr = model.getTopicDistribution(doc1); + if (topicDistr[0] > topicDistr[1]) { + // topic 0 MUST represent doc#1 + k1 = 0; + k2 = 1; + } else { + k1 = 1; + k2 = 0; + } + Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + + "and `vegetables` SHOULD be more suitable topic word than `flu` in the topic", + model.getProbability("vegetables", k1) > model.getProbability("flu", k1)); + Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " + + "and `avocados` SHOULD be more suitable topic word than `healthy` in the topic", + model.getProbability("avocados", k2) > model.getProbability("healthy", k2)); + } + + @Test + public void testMiniBatch() { + int K = 2; + int it = 0; + int maxIter = 2048; + float perplexityPrev; + float perplexity = Float.MAX_VALUE; + + IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.f, 1E-5d); + + String[] doc1 = new String[] {"fruits:1", "healthy:1", "vegetables:1"}; + String[] doc2 = new String[] {"apples:1", "avocados:1", "colds:1", "flu:1", "like:2", + "oranges:1"}; + + do { + perplexityPrev = perplexity; + + model.train(new String[][] {doc1, doc2}); + perplexity = 
model.computePerplexity(); + + it++; + println("Iteration " + it + ": perplexity = " + perplexity); + } while (it < maxIter && Math.abs(perplexityPrev - perplexity) >= 1E-4f); + + SortedMap<Float, List<String>> topicWords; + + println("Topic 0:"); + println("========"); + topicWords = model.getTopicWords(0); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + println("Topic 1:"); + println("========"); + topicWords = model.getTopicWords(1); + for (Map.Entry<Float, List<String>> e : topicWords.entrySet()) { + List<String> words = e.getValue(); + for (int i = 0; i < words.size(); i++) { + println(e.getKey() + " " + words.get(i)); + } + } + println("========"); + + + int k1, k2; + float[] topicDistr = model.getTopicDistribution(doc1); + if (topicDistr[0] > topicDistr[1]) { + // topic 0 MUST represent doc#1 + k1 = 0; + k2 = 1; + } else { + k1 = 1; + k2 = 0; + } + Assert.assertTrue("doc1 is in topic " + k1 + " (" + (topicDistr[k1] * 100) + "%), " + + "and `vegetables` SHOULD be a more suitable topic word than `flu` in the topic", + model.getProbability("vegetables", k1) > model.getProbability("flu", k1)); + Assert.assertTrue("doc2 is in topic " + k2 + " (" + (topicDistr[k2] * 100) + "%), " + + "and `avocados` SHOULD be a more suitable topic word than `healthy` in the topic", + model.getProbability("avocados", k2) > model.getProbability("healthy", k2)); + } + + @Test + public void testNews20() throws IOException { + int K = 20; + int miniBatchSize = 2; + + int cnt, it; + int maxIter = 64; + + IncrementalPLSAModel model = new IncrementalPLSAModel(K, 0.8f, 1E-5d); + + BufferedReader news20 = readFile("news20-multiclass.gz"); + + String[][] docs = new String[K][]; + + String line = news20.readLine(); + List<String> doc = new ArrayList<String>(); + + cnt = 0; + while (line != null) { + StringTokenizer tokens = new StringTokenizer(line, " "); + + int k = Integer.parseInt(tokens.nextToken()) - 1; + + while (tokens.hasMoreTokens()) { + doc.add(tokens.nextToken()); + } + + // store first document in each of K classes + if (docs[k] == null) { + docs[k] = doc.toArray(new String[doc.size()]); + cnt++; + } + + if (cnt == K) { + break; + } + + doc.clear(); + line = news20.readLine(); + } + println("Stored " + cnt + " docs.
Start training w/ mini-batch size: " + miniBatchSize); + + float perplexityPrev; + float perplexity = Float.MAX_VALUE; + + it = 0; + do { + perplexityPrev = perplexity; + perplexity = 0.f; + + int head = 0; + cnt = 0; + while (head < K) { + int tail = head + miniBatchSize; + model.train(Arrays.copyOfRange(docs, head, tail)); + perplexity += model.computePerplexity(); + head = tail; + cnt++; + println("Processed mini-batch#" + cnt); + } + + perplexity /= cnt; + + it++; + println("Iteration " + it + ": mean perplexity = " + perplexity); + } while (it < maxIter && Math.abs(perplexityPrev - perplexity) >= 1E-3f); + + Set<Integer> topics = new HashSet<Integer>(); + for (int k = 0; k < K; k++) { + topics.add(findMaxTopic(model.getTopicDistribution(docs[k]))); + } + + int n = topics.size(); + println("# of unique topics: " + n); + Assert.assertTrue("At least 15 documents SHOULD be classified into different topics, " + + "but there are only " + n + " unique topics.", n >= 15); + } + + private static void println(String msg) { + if (DEBUG) { + System.out.println(msg); + } + } + + @Nonnull + private static BufferedReader readFile(@Nonnull String fileName) throws IOException { + // use data stored for KPA UDTF test + InputStream is = KernelExpansionPassiveAggressiveUDTFTest.class.getResourceAsStream(fileName); + if (fileName.endsWith(".gz")) { + is = new GZIPInputStream(is); + } + return new BufferedReader(new InputStreamReader(is)); + } + + private static int findMaxTopic(@Nonnull float[] topicDistr) { + int maxIdx = 0; + for (int i = 1; i < topicDistr.length; i++) { + if (topicDistr[maxIdx] < topicDistr[i]) { + maxIdx = i; + } + } + return maxIdx; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java index a23d917..2c08560 100644 --- a/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java +++ b/core/src/test/java/hivemall/topicmodel/LDAPredictUDAFTest.java @@ -100,7 +100,7 @@ public class LDAPredictUDAFTest { PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector( PrimitiveObjectInspector.PrimitiveCategory.FLOAT), ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2")}; + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2")}; evaluator = udaf.getEvaluator(new SimpleGenericUDAFParameterInfo(inputOIs, false, false)); @@ -117,7 +117,7 @@ public class LDAPredictUDAFTest { ObjectInspectorFactory.getStandardListObjectInspector( PrimitiveObjectInspectorFactory.javaFloatObjectInspector))); - fieldNames.add("topic"); + fieldNames.add("topics"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector); fieldNames.add("alpha"); http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f2bf3a72/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java index d1e3f81..a5881d4 100644 --- a/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java +++ b/core/src/test/java/hivemall/topicmodel/LDAUDTFTest.java @@ -42,7 +42,7 @@ public class LDAUDTFTest { ObjectInspector[] argOIs = new
ObjectInspector[] { ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector), ObjectInspectorUtils.getConstantObjectInspector( - PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topic 2 -num_docs 2 -s 1")}; + PrimitiveObjectInspectorFactory.javaStringObjectInspector, "-topics 2 -num_docs 2 -s 1")}; udtf.initialize(argOIs);
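
Note (illustrative, not part of the commit): the two utility methods added above, ArrayUtils#newRandomFloatArray and MathUtils#l1normalize, are what IncrementalPLSAModel composes in initParams() to draw its random initial distributions such as P(z|d). A minimal standalone sketch of that pattern follows; the class and method names below are hypothetical, and only the two static imports come from this commit:

    import static hivemall.utils.lang.ArrayUtils.newRandomFloatArray;
    import static hivemall.utils.math.MathUtils.l1normalize;

    import java.util.Arrays;
    import java.util.Random;

    public final class RandomDistributionSketch {

        public static void main(String[] args) {
            // fixed seed, mirroring IncrementalPLSAModel's new Random(1001),
            // so the initialization is reproducible across runs
            final Random rnd = new Random(1001);
            final int K = 2; // number of topics

            // draw K uniform floats in [0, 1)
            float[] p = newRandomFloatArray(K, rnd);

            // rescale in place so the entries sum to 1, yielding a valid
            // initial probability distribution over the K topics
            l1normalize(p);

            System.out.println(Arrays.toString(p)); // K non-negative entries summing to ~1.0
        }
    }

Since Random#nextFloat() only returns non-negative values, the Math.abs in l1normalize changes nothing here and the result is a proper probability vector.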
