// community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMerger.java
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.vectorizer.common;

import java.io.IOException;

import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.math.VectorWritable;

/**
 * Groups a set of input vectors by key. The SequenceFile input is expected to carry a
 * {@link org.apache.hadoop.io.WritableComparable} key holding the document id and a
 * {@link VectorWritable} value holding the term frequency vector. The normalization settings
 * passed to {@link #mergePartialVectors} are published in the job configuration under the
 * keys declared on this class.
 */
public final class PartialVectorMerger {

  /** Sentinel value for {@code normPower} meaning that no normalization was requested. */
  public static final float NO_NORMALIZING = -1.0f;

  /** Configuration key under which the normalization power is stored. */
  public static final String NORMALIZATION_POWER = "normalization.power";

  /** Configuration key under which the vector cardinality is stored. */
  public static final String DIMENSION = "vector.dimension";

  /** Configuration key for the "optimize for sequential access" flag. */
  public static final String SEQUENTIAL_ACCESS = "vector.sequentialAccess";

  /** Configuration key for the "emit named vectors" flag. */
  public static final String NAMED_VECTOR = "vector.named";

  /** Configuration key for the "log normalize" flag. */
  public static final String LOG_NORMALIZE = "vector.lognormalize";

  /**
   * Not instantiable; use the static functions.
   */
  private PartialVectorMerger() {
  }

  /**
   * Merge all the partial {@link org.apache.mahout.math.RandomAccessSparseVector}s into the
   * complete document {@link org.apache.mahout.math.RandomAccessSparseVector}.
   *
   * <p>The arguments are validated, copied into a fresh {@link Configuration} derived from
   * {@code baseConf}, and a job using the identity {@link Mapper} and
   * {@code PartialVectorMergeReducer} is run synchronously. {@code HadoopUtil.delete} is invoked
   * on {@code output} immediately before the job is started.
   *
   * @param partialVectorPaths
   *          input directories of the vectors in {@link org.apache.hadoop.io.SequenceFile} format
   * @param output
   *          output directory where the merged vectors have to be created
   * @param baseConf
   *          job configuration; it is copied, not modified
   * @param normPower
   *          the normalization value. Must be greater than or equal to 0 or equal to
   *          {@link #NO_NORMALIZING}
   * @param logNormalize
   *          value stored under {@link #LOG_NORMALIZE}; when {@code true}, {@code normPower} must
   *          be either {@link #NO_NORMALIZING} or a finite value greater than 1
   * @param dimension
   *          cardinality of the vectors
   * @param sequentialAccess
   *          output vectors should be optimized for sequential access
   * @param namedVector
   *          output vectors should be named, retaining key (doc id) as a label
   * @param numReducers
   *          the number of reducers to spawn
   * @throws IllegalArgumentException
   *           if {@code normPower} violates one of the constraints above
   * @throws IllegalStateException
   *           if the job does not complete successfully
   */
  public static void mergePartialVectors(Iterable<Path> partialVectorPaths,
                                         Path output,
                                         Configuration baseConf,
                                         float normPower,
                                         boolean logNormalize,
                                         int dimension,
                                         boolean sequentialAccess,
                                         boolean namedVector,
                                         int numReducers)
    throws IOException, InterruptedException, ClassNotFoundException {
    checkNormalizationArguments(normPower, logNormalize);

    Configuration conf =
        createJobConfiguration(baseConf, normPower, logNormalize, dimension, sequentialAccess, namedVector);

    Job job = new Job(conf);
    job.setJobName("PartialVectorMerger::MergePartialVectors");
    job.setJarByClass(PartialVectorMerger.class);

    job.setOutputKeyClass(Text.class);
    job.setOutputValueClass(VectorWritable.class);

    String inputPaths = getCommaSeparatedPaths(partialVectorPaths);
    FileInputFormat.setInputPaths(job, inputPaths);
    FileOutputFormat.setOutputPath(job, output);

    // Identity map phase; all of the merging happens in the reducer.
    job.setMapperClass(Mapper.class);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setReducerClass(PartialVectorMergeReducer.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setNumReduceTasks(numReducers);

    HadoopUtil.delete(conf, output);

    if (!job.waitForCompletion(true)) {
      throw new IllegalStateException("Job failed!");
    }
  }

  /** Rejects normalization settings that the merge job does not accept. */
  private static void checkNormalizationArguments(float normPower, boolean logNormalize) {
    boolean normalizationDisabled = normPower == NO_NORMALIZING;
    Preconditions.checkArgument(normalizationDisabled || normPower >= 0,
        "If specified normPower must be nonnegative", normPower);

    boolean finiteAndAboveOne = normPower > 1 && !Float.isInfinite(normPower);
    Preconditions.checkArgument(normalizationDisabled || finiteAndAboveOne || !logNormalize,
        "normPower must be > 1 and not infinite if log normalization is chosen", normPower);
  }

  /** Copies {@code baseConf} and records the merge settings under this class's keys. */
  private static Configuration createJobConfiguration(Configuration baseConf,
                                                      float normPower,
                                                      boolean logNormalize,
                                                      int dimension,
                                                      boolean sequentialAccess,
                                                      boolean namedVector) {
    Configuration conf = new Configuration(baseConf);
    // this conf parameter needs to be set to enable serialisation of conf values
    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
        + "org.apache.hadoop.io.serializer.WritableSerialization");
    conf.setInt(DIMENSION, dimension);
    conf.setFloat(NORMALIZATION_POWER, normPower);
    conf.setBoolean(LOG_NORMALIZE, logNormalize);
    conf.setBoolean(SEQUENTIAL_ACCESS, sequentialAccess);
    conf.setBoolean(NAMED_VECTOR, namedVector);
    return conf;
  }

  /** Joins the string forms of {@code paths} with {@code ','}, in iteration order. */
  private static String getCommaSeparatedPaths(Iterable<Path> paths) {
    StringBuilder joined = new StringBuilder(100);
    boolean first = true;
    for (Path current : paths) {
      if (!first) {
        joined.append(",");
      }
      first = false;
      joined.append(current.toString());
    }
    return joined.toString();
  }
}
// community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/document/SequenceFileTokenizerMapper.java
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.vectorizer.document;

import java.io.IOException;
import java.io.StringReader;

import com.google.common.io.Closeables;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;

import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.mahout.common.StringTuple;
import org.apache.mahout.common.lucene.AnalyzerUtils;
import org.apache.mahout.vectorizer.DocumentProcessor;

/**
 * Tokenizes a text document and outputs its tokens in a {@link StringTuple}.
 *
 * <p>The analyzer class is read from the job configuration under
 * {@link DocumentProcessor#ANALYZER_CLASS}; {@link StandardAnalyzer} is used when the key is unset.
 */
public class SequenceFileTokenizerMapper extends Mapper<Text, Text, Text, StringTuple> {

  /** Created in {@link #setup} and released in {@link #cleanup}. */
  private Analyzer analyzer;

  /**
   * Runs the configured analyzer over {@code value} and writes every non-empty term, in stream
   * order, as one {@link StringTuple} keyed by the unchanged input key.
   *
   * @param key     document id; its string form is also passed to the analyzer as the field name
   * @param value   document text to tokenize
   * @param context task context receiving the {@code (key, tokens)} pair
   * @throws IOException if tokenizing the text or writing the output fails
   */
  @Override
  protected void map(Text key, Text value, Context context) throws IOException, InterruptedException {
    StringTuple document = new StringTuple();
    TokenStream stream = analyzer.tokenStream(key.toString(), new StringReader(value.toString()));
    try {
      CharTermAttribute termAtt = stream.addAttribute(CharTermAttribute.class);
      stream.reset();
      while (stream.incrementToken()) {
        if (termAtt.length() > 0) {
          document.add(new String(termAtt.buffer(), 0, termAtt.length()));
        }
      }
      stream.end();
    } finally {
      // Always release the stream, even when tokenizing fails: an unclosed stream would
      // otherwise leak and prevent the analyzer from handing out a stream for the next record.
      // Failures of close() itself are swallowed, as before.
      Closeables.close(stream, true);
    }
    context.write(key, document);
  }

  /**
   * Instantiates the analyzer named in the job configuration.
   *
   * @throws IOException if the configured analyzer class cannot be found
   */
  @Override
  protected void setup(Context context) throws IOException, InterruptedException {
    super.setup(context);

    String analyzerClassName = context.getConfiguration().get(DocumentProcessor.ANALYZER_CLASS,
            StandardAnalyzer.class.getName());
    try {
      analyzer = AnalyzerUtils.createAnalyzer(analyzerClassName);
    } catch (ClassNotFoundException e) {
      throw new IOException("Unable to create analyzer: " + analyzerClassName, e);
    }
  }

  /**
   * Closes the analyzer created in {@link #setup} once the task has processed all of its input.
   */
  @Override
  protected void cleanup(Context context) throws IOException, InterruptedException {
    try {
      if (analyzer != null) {
        analyzer.close();
        analyzer = null;
      }
    } finally {
      super.cleanup(context);
    }
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/AdaptiveWordValueEncoder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/AdaptiveWordValueEncoder.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/AdaptiveWordValueEncoder.java new file mode 100644 index 0000000..04b718e --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/AdaptiveWordValueEncoder.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.encoders; + +import com.google.common.base.Charsets; +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; +import org.apache.mahout.math.Vector; + +/** + * Encodes words into vectors much as does WordValueEncoder while maintaining + * an adaptive dictionary of values seen so far. This allows weighting of terms + * without a pre-scan of all of the data. 
+ */ +public class AdaptiveWordValueEncoder extends WordValueEncoder { + + private final Multiset<String> dictionary; + + public AdaptiveWordValueEncoder(String name) { + super(name); + dictionary = HashMultiset.create(); + } + + /** + * Adds a value to a vector. + * + * @param originalForm The original form of the value as a string. + * @param data The vector to which the value should be added. + */ + @Override + public void addToVector(String originalForm, double weight, Vector data) { + dictionary.add(originalForm); + super.addToVector(originalForm, weight, data); + } + + @Override + protected double getWeight(byte[] originalForm, double w) { + return w * weight(originalForm); + } + + @Override + protected double weight(byte[] originalForm) { + // the counts here are adjusted so that every observed value has an extra 0.5 count + // as does a hypothetical unobserved value. This smooths our estimates a bit and + // allows the first word seen to have a non-zero weight of -log(1.5 / 2) + double thisWord = dictionary.count(new String(originalForm, Charsets.UTF_8)) + 0.5; + double allWords = dictionary.size() + dictionary.elementSet().size() * 0.5 + 0.5; + return -Math.log(thisWord / allWords); + } + + public Multiset<String> getDictionary() { + return dictionary; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingContinuousValueEncoder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingContinuousValueEncoder.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingContinuousValueEncoder.java new file mode 100644 index 0000000..0b350c6 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingContinuousValueEncoder.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.encoders; + +import java.util.Arrays; + +import com.google.common.base.Preconditions; +import org.apache.mahout.math.map.OpenIntIntHashMap; + +public class CachingContinuousValueEncoder extends ContinuousValueEncoder { + private final int dataSize; + private OpenIntIntHashMap[] caches; + + public CachingContinuousValueEncoder(String name, int dataSize) { + super(name); + this.dataSize = dataSize; + initCaches(); + } + + private void initCaches() { + this.caches = new OpenIntIntHashMap[getProbes()]; + for (int probe = 0; probe < getProbes(); probe++) { + caches[probe] = new OpenIntIntHashMap(); + } + } + + OpenIntIntHashMap[] getCaches() { + return caches; + } + + @Override + public void setProbes(int probes) { + super.setProbes(probes); + initCaches(); + } + + @Override + protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) { + Preconditions.checkArgument(dataSize == this.dataSize, + "dataSize argument [" + dataSize + "] does not match expected dataSize [" + this.dataSize + ']'); + int originalHashcode = Arrays.hashCode(originalForm); + if (caches[probe].containsKey(originalHashcode)) { + return caches[probe].get(originalHashcode); + } + int hash 
= super.hashForProbe(originalForm, dataSize, name, probe); + caches[probe].put(originalHashcode, hash); + return hash; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingStaticWordValueEncoder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingStaticWordValueEncoder.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingStaticWordValueEncoder.java new file mode 100644 index 0000000..258ff84 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingStaticWordValueEncoder.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.vectorizer.encoders; + +import java.util.Arrays; + +import com.google.common.base.Preconditions; +import org.apache.mahout.math.map.OpenIntIntHashMap; + +public class CachingStaticWordValueEncoder extends StaticWordValueEncoder { + + private final int dataSize; + private OpenIntIntHashMap[] caches; + + public CachingStaticWordValueEncoder(String name, int dataSize) { + super(name); + this.dataSize = dataSize; + initCaches(); + } + + private void initCaches() { + caches = new OpenIntIntHashMap[getProbes()]; + for (int probe = 0; probe < getProbes(); probe++) { + caches[probe] = new OpenIntIntHashMap(); + } + } + + OpenIntIntHashMap[] getCaches() { + return caches; + } + + @Override + public void setProbes(int probes) { + super.setProbes(probes); + initCaches(); + } + + @Override + protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) { + Preconditions.checkArgument(dataSize == this.dataSize, + "dataSize argument [" + dataSize + "] does not match expected dataSize [" + this.dataSize + ']'); + int originalHashcode = Arrays.hashCode(originalForm); + if (caches[probe].containsKey(originalHashcode)) { + return caches[probe].get(originalHashcode); + } + int hash = super.hashForProbe(originalForm, dataSize, name, probe); + caches[probe].put(originalHashcode, hash); + return hash; + } +} + http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingTextValueEncoder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingTextValueEncoder.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingTextValueEncoder.java new file mode 100644 index 0000000..b109818 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingTextValueEncoder.java @@ -0,0 +1,25 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.encoders; + +public class CachingTextValueEncoder extends TextValueEncoder { + public CachingTextValueEncoder(String name, int dataSize) { + super(name); + setWordEncoder(new CachingStaticWordValueEncoder(name, dataSize)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java new file mode 100644 index 0000000..08d3d3e --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.encoders; + +import org.apache.mahout.math.MurmurHash; + +/** + * Provides basic hashing semantics for encoders where the probe locations + * depend only on the name of the variable. + */ +public abstract class CachingValueEncoder extends FeatureVectorEncoder { + private int[] cachedProbes; + + protected CachingValueEncoder(String name, int seed) { + super(name); + cacheProbeLocations(seed); + } + + /** + * Sets the number of locations in the feature vector that a value should be in. + * This causes the cached probe locations to be recomputed. + * + * @param probes Number of locations to increment. 
+ */ + @Override + public void setProbes(int probes) { + super.setProbes(probes); + cacheProbeLocations(getSeed()); + } + + protected abstract int getSeed(); + + private void cacheProbeLocations(int seed) { + cachedProbes = new int[getProbes()]; + for (int i = 0; i < getProbes(); i++) { + // note that the modulo operation is deferred + cachedProbes[i] = (int) MurmurHash.hash64A(bytesForString(getName()), seed + i); + } + } + + @Override + protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) { + int h = cachedProbes[probe] % dataSize; + if (h < 0) { + h += dataSize; + } + return h; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java new file mode 100644 index 0000000..d7dd9f6 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.encoders; + +import org.apache.mahout.math.Vector; + +/** + * An encoder that does the standard thing for a virtual bias term. + */ +public class ConstantValueEncoder extends CachingValueEncoder { + public ConstantValueEncoder(String name) { + super(name, 0); + } + + @Override + public void addToVector(byte[] originalForm, double weight, Vector data) { + int probes = getProbes(); + String name = getName(); + for (int i = 0; i < probes; i++) { + int n = hashForProbe(originalForm, data.size(), name, i); + if (isTraceEnabled()) { + trace((String) null, n); + } + data.set(n, data.get(n) + getWeight(originalForm,weight)); + } + } + + @Override + protected double getWeight(byte[] originalForm, double w) { + return w; + } + + @Override + public String asString(String originalForm) { + return getName(); + } + + @Override + protected int getSeed() { + return 0; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java new file mode 100644 index 0000000..14382a5 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.encoders; + +import com.google.common.base.Charsets; +import org.apache.mahout.math.Vector; + +/** + * Continuous values are stored in fixed randomized location in the feature vector. + */ +public class ContinuousValueEncoder extends CachingValueEncoder { + + public ContinuousValueEncoder(String name) { + super(name, CONTINUOUS_VALUE_HASH_SEED); + } + + /** + * Adds a value to a vector. + * + * @param originalForm The original form of the value as a string. + * @param data The vector to which the value should be added. + */ + @Override + public void addToVector(byte[] originalForm, double weight, Vector data) { + int probes = getProbes(); + String name = getName(); + for (int i = 0; i < probes; i++) { + int n = hashForProbe(originalForm, data.size(), name, i); + if (isTraceEnabled()) { + trace((String) null, n); + } + data.set(n, data.get(n) + getWeight(originalForm,weight)); + } + } + + @Override + protected double getWeight(byte[] originalForm, double w) { + if (originalForm == null) { + return w; + } + return w * Double.parseDouble(new String(originalForm, Charsets.UTF_8)); + } + + /** + * Converts a value into a form that would help a human understand the internals of how the value + * is being interpreted. For text-like things, this is likely to be a list of the terms found with + * associated weights (if any). 
+ * + * @param originalForm The original form of the value as a string. + * @return A string that a human can read. + */ + @Override + public String asString(String originalForm) { + return getName() + ':' + originalForm; + } + + @Override + protected int getSeed() { + return CONTINUOUS_VALUE_HASH_SEED; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/Dictionary.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/Dictionary.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/Dictionary.java new file mode 100644 index 0000000..60c89f7 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/Dictionary.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.encoders; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** +* Assigns integer codes to strings as they appear. 
+*/ +public class Dictionary { + private final Map<String, Integer> dict = new LinkedHashMap<>(); + + public int intern(String s) { + if (!dict.containsKey(s)) { + dict.put(s, dict.size()); + } + return dict.get(s); + } + + public List<String> values() { + // order of keySet is guaranteed to be insertion order + return new ArrayList<>(dict.keySet()); + } + + public int size() { + return dict.size(); + } + + public static Dictionary fromList(Iterable<String> values) { + Dictionary dict = new Dictionary(); + for (String value : values) { + dict.intern(value); + } + return dict; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/FeatureVectorEncoder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/FeatureVectorEncoder.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/FeatureVectorEncoder.java new file mode 100644 index 0000000..96498d7 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/encoders/FeatureVectorEncoder.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.vectorizer.encoders;

import com.google.common.base.Charsets;
import com.google.common.collect.Sets;
import org.apache.mahout.math.MurmurHash;
import org.apache.mahout.math.Vector;

import java.util.Collections;
import java.util.Map;
import java.util.Set;

/**
 * General interface for objects that record features into a feature vector.
 * <p/>
 * By convention, sub-classes should provide a constructor that accepts just a field name as well as
 * setters to customize properties of the conversion such as adding tokenizers or a weight
 * dictionary.
 */
public abstract class FeatureVectorEncoder {
  protected static final int CONTINUOUS_VALUE_HASH_SEED = 1;
  protected static final int WORD_LIKE_VALUE_HASH_SEED = 100;

  /** Shared byte form of a {@code null} string. */
  private static final byte[] EMPTY_ARRAY = new byte[0];

  private final String name;
  private int probes;

  /** When non-null, every traced location is recorded here under its feature key. */
  private Map<String, Set<Integer>> traceDictionary;

  /**
   * Creates an encoder that uses a single probe.
   *
   * @param name The name of the variable being encoded.
   */
  protected FeatureVectorEncoder(String name) {
    this(name, 1);
  }

  /**
   * Creates an encoder with an explicit number of probes.
   *
   * @param name   The name of the variable being encoded.
   * @param probes The number of probes reported by {@link #getProbes()}.
   */
  protected FeatureVectorEncoder(String name, int probes) {
    this.name = name;
    this.probes = probes;
  }

  /**
   * Adds a value expressed in string form to a vector, using a weight of 1.
   *
   * @param originalForm The original form of the value as a string.
   * @param data         The vector to which the value should be added.
   */
  public void addToVector(String originalForm, Vector data) {
    addToVector(originalForm, 1.0, data);
  }

  /**
   * Adds a value expressed in byte array form to a vector, using a weight of 1.
   *
   * @param originalForm The original form of the value as a byte array.
   * @param data         The vector to which the value should be added.
   */
  public void addToVector(byte[] originalForm, Vector data) {
    addToVector(originalForm, 1.0, data);
  }

  /**
   * Adds a weighted value expressed in string form to a vector.  In some cases it is convenient to
   * use this method to encode continuous values using the weight as the value.  In such cases, the
   * string value should typically be set to null.
   *
   * @param originalForm The original form of the value as a string; null is encoded as an empty
   *                     byte array.
   * @param weight       The weight to be applied to this feature.
   * @param data         The vector to which the value should be added.
   */
  public void addToVector(String originalForm, double weight, Vector data) {
    byte[] encoded = bytesForString(originalForm);
    addToVector(encoded, weight, data);
  }

  /**
   * Adds a weighted value expressed in byte array form to a vector.
   *
   * @param originalForm The original form of the value as a byte array.
   * @param weight       The weight to be applied to this feature.
   * @param data         The vector to which the value should be added.
   */
  public abstract void addToVector(byte[] originalForm, double weight, Vector data);

  /**
   * Provides the unique hash for a particular probe.  For all encoders except text, this
   * is all that is needed and the default implementation of hashesForProbe will do the right
   * thing.  For text and similar values, hashesForProbe should be over-ridden and this method
   * should not be used.
   *
   * @param originalForm  The original byte array value
   * @param dataSize      The length of the vector being encoded
   * @param name          The name of the variable being encoded
   * @param probe         The probe number
   * @return              The hash of the current probe
   */
  protected abstract int hashForProbe(byte[] originalForm, int dataSize, String name, int probe);

  /**
   * Returns all of the hashes for this probe.  For most encoders, this is a singleton, but
   * for text, many hashes are returned, one for each word (unique or not).  Most implementations
   * should only implement hashForProbe for simplicity.
   *
   * @param originalForm The original byte array value.
   * @param dataSize     The length of the vector being encoded
   * @param name         The name of the variable being encoded
   * @param probe        The probe number
   * @return an Iterable of the hashes
   */
  protected Iterable<Integer> hashesForProbe(byte[] originalForm, int dataSize, String name, int probe) {
    int singleHash = hashForProbe(originalForm, dataSize, name, probe);
    return Collections.singletonList(singleHash);
  }

  /**
   * Returns the weight to apply for a value.  This base implementation ignores both arguments
   * and always returns 1.0.
   *
   * @param originalForm The original byte array value (unused here).
   * @param w            The externally supplied weight (unused here).
   * @return Always 1.0.
   */
  protected double getWeight(byte[] originalForm, double w) {
    return 1.0;
  }

  // ******* Utility functions used by most implementations

  /**
   * Hash a string and an integer into the range [0..numFeatures-1].
   *
   * @param term        The string.
   * @param probe       An integer that modifies the resulting hash.
   * @param numFeatures The range into which the resulting hash must fit.
   * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
   *         term and probe.
   */
  protected int hash(String term, int probe, int numFeatures) {
    // Same computation as the byte-array overload, applied to the UTF-8 form of the string.
    return hash(bytesForString(term), probe, numFeatures);
  }

  /**
   * Hash a byte array and an integer into the range [0..numFeatures-1].
   *
   * @param term        The bytes.
   * @param probe       An integer that modifies the resulting hash.
   * @param numFeatures The range into which the resulting hash must fit.
   * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
   *         term and probe.
   */
  protected static int hash(byte[] term, int probe, int numFeatures) {
    return toBucket(MurmurHash.hash64A(term, probe) % numFeatures, numFeatures);
  }

  /**
   * Hash two strings and an integer into the range [0..numFeatures-1].
   *
   * @param term1       The first string.
   * @param term2       The second string.
   * @param probe       An integer that modifies the resulting hash.
   * @param numFeatures The range into which the resulting hash must fit.
   * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
   *         term and probe.
   */
  protected static int hash(String term1, String term2, int probe, int numFeatures) {
    // The (truncated) hash of the first term seeds the hash of the second one.
    int chainedSeed = (int) MurmurHash.hash64A(bytesForString(term1), probe);
    long remainder = MurmurHash.hash64A(bytesForString(term2), chainedSeed) % numFeatures;
    return toBucket(remainder, numFeatures);
  }

  /**
   * Hash two byte arrays and an integer into the range [0..numFeatures-1].
   *
   * @param term1       The first byte array.
   * @param term2       The second byte array.
   * @param probe       An integer that modifies the resulting hash.
   * @param numFeatures The range into which the resulting hash must fit.
   * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
   *         term and probe.
   */
  protected int hash(byte[] term1, byte[] term2, int probe, int numFeatures) {
    // The (truncated) hash of the first term seeds the hash of the second one.
    int chainedSeed = (int) MurmurHash.hash64A(term1, probe);
    long remainder = MurmurHash.hash64A(term2, chainedSeed) % numFeatures;
    return toBucket(remainder, numFeatures);
  }

  /**
   * Hash four strings and an integer into the range [0..numFeatures-1].
   *
   * @param term1       The first string.
   * @param term2       The second string.
   * @param term3       The third string
   * @param term4       And the fourth.
   * @param probe       An integer that modifies the resulting hash.
   * @param numFeatures The range into which the resulting hash must fit.
   * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
   *         term and probe.
   */
  protected int hash(String term1, String term2, String term3, String term4, int probe, int numFeatures) {
    // Each step is seeded with the previous result; from the second term on that result has
    // already been reduced modulo numFeatures (and may still be negative at that point).
    long chained = MurmurHash.hash64A(bytesForString(term1), probe);
    chained = MurmurHash.hash64A(bytesForString(term2), (int) chained) % numFeatures;
    chained = MurmurHash.hash64A(bytesForString(term3), (int) chained) % numFeatures;
    chained = MurmurHash.hash64A(bytesForString(term4), (int) chained) % numFeatures;
    return toBucket(chained, numFeatures);
  }

  /**
   * Shifts a possibly negative remainder of a division by numFeatures into [0..numFeatures-1].
   */
  private static int toBucket(long remainder, int numFeatures) {
    return (int) (remainder < 0 ? remainder + numFeatures : remainder);
  }

  /**
   * Converts a value into a form that would help a human understand the internals of how the value
   * is being interpreted.  For text-like things, this is likely to be a list of the terms found
   * with associated weights (if any).
   *
   * @param originalForm The original form of the value as a string.
   * @return A string that a human can read.
   */
  public abstract String asString(String originalForm);

  /**
   * @return The number of locations in the feature vector that a value should be in.
   */
  public int getProbes() {
    return probes;
  }

  /**
   * Sets the number of locations in the feature vector that a value should be in.
   *
   * @param probes Number of locations to increment.
   */
  public void setProbes(int probes) {
    this.probes = probes;
  }

  /**
   * @return The name of the variable being encoded.
   */
  public String getName() {
    return name;
  }

  /**
   * @return True if a trace dictionary has been set.
   */
  protected boolean isTraceEnabled() {
    return traceDictionary != null;
  }

  /**
   * Records that a value was written to location n.  Does nothing unless a trace dictionary has
   * been set.  The key is the encoder name, followed by '=' and subName when subName is not null.
   *
   * @param subName The value that was encoded, or null.
   * @param n       The location in the vector that was touched.
   */
  protected void trace(String subName, int n) {
    if (traceDictionary == null) {
      return;
    }
    String key = subName == null ? name : name + '=' + subName;
    Set<Integer> locations = traceDictionary.get(key);
    if (locations == null) {
      locations = Sets.newHashSet();
      traceDictionary.put(key, locations);
    }
    locations.add(n);
  }

  /**
   * Same as {@link #trace(String, int)} with subName decoded as UTF-8.
   *
   * @param subName The UTF-8 bytes of the value that was encoded.
   * @param n       The location in the vector that was touched.
   */
  protected void trace(byte[] subName, int n) {
    String decoded = new String(subName, Charsets.UTF_8);
    trace(decoded, n);
  }

  /**
   * Sets the map that receives traced locations; passing null turns tracing off.
   *
   * @param traceDictionary Map from feature key to the set of locations touched for that key.
   */
  public void setTraceDictionary(Map<String, Set<Integer>> traceDictionary) {
    this.traceDictionary = traceDictionary;
  }

  /**
   * Converts a string to its UTF-8 bytes, mapping null to an empty array.
   *
   * @param x The string to convert, possibly null.
   * @return The UTF-8 bytes of x, or an empty array if x is null.
   */
  protected static byte[] bytesForString(String x) {
    if (x == null) {
      return EMPTY_ARRAY;
    }
    return x.getBytes(Charsets.UTF_8);
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.vectorizer.encoders;

import java.util.Locale;

import org.apache.mahout.math.Vector;

import com.google.common.base.Charsets;

/**
 * Encodes the interaction of two values by combining the hashes produced by two underlying
 * encoders.  Only the addInteractionToVector methods are supported; the single-value
 * addToVector methods throw.
 */
public class InteractionValueEncoder extends FeatureVectorEncoder {
  private final FeatureVectorEncoder firstEncoder;
  private final FeatureVectorEncoder secondEncoder;

  /**
   * @param name       The name of the interaction variable.
   * @param encoderOne The encoder that hashes the first value.
   * @param encoderTwo The encoder that hashes the second value.
   */
  public InteractionValueEncoder(String name, FeatureVectorEncoder encoderOne, FeatureVectorEncoder encoderTwo) {
    super(name, 2);
    firstEncoder = encoderOne;
    secondEncoder = encoderTwo;
  }

  /**
   * Adds a value to a vector. (Unsupported)
   *
   * @param originalForm The original form of the first value as a string.
   * @param data         The vector to which the value should be added.
   * @throws UnsupportedOperationException always
   */
  @Override
  public void addToVector(String originalForm, double w, Vector data) {
    throw new UnsupportedOperationException("addToVector is not supported for InteractionVectorEncoder");
  }

  /**
   * Adds a value to a vector. (Unsupported)
   *
   * @param originalForm The original form of the first value as a byte array.
   * @param data         The vector to which the value should be added.
   * @throws UnsupportedOperationException always
   */
  @Override
  public void addToVector(byte[] originalForm, double w, Vector data) {
    throw new UnsupportedOperationException("addToVector is not supported for InteractionVectorEncoder");
  }

  /**
   * Adds a value to a vector.
   *
   * @param original1 The original form of the first value as a string.
   * @param original2 The original form of the second value as a string.
   * @param weight        How much to weight this interaction
   * @param data          The vector to which the value should be added.
   */
  public void addInteractionToVector(String original1, String original2, double weight, Vector data) {
    byte[] originalForm1 = bytesForString(original1);
    byte[] originalForm2 = bytesForString(original2);
    addInteractionToVector(originalForm1, originalForm2, weight, data);
  }

  /**
   * Adds a value to a vector.  For every probe, each hash of the first value is combined with
   * each hash of the second value and the resulting location is incremented by the interaction
   * weight.
   *
   * @param originalForm1 The original form of the first value as a byte array.
   * @param originalForm2 The original form of the second value as a byte array.
   * @param weight        How much to weight this interaction
   * @param data          The vector to which the value should be added.
   */
  public void addInteractionToVector(byte[] originalForm1, byte[] originalForm2, double weight, Vector data) {
    String name = getName();
    double w = getWeight(originalForm1, originalForm2, weight);
    int size = data.size();
    for (int i = 0; i < getProbes(); i++) {
      Iterable<Integer> jValues =
          secondEncoder.hashesForProbe(originalForm2, size, name, i % secondEncoder.getProbes());
      for (int k : firstEncoder.hashesForProbe(originalForm1, size, name, i % firstEncoder.getProbes())) {
        for (int j : jValues) {
          // Add in long arithmetic: two hashes below size can sum past Integer.MAX_VALUE when the
          // vector is very large, and an int sum would then wrap to a negative index.
          int n = (int) (((long) k + j) % size);
          if (isTraceEnabled()) {
            trace(String.format("%s:%s", new String(originalForm1, Charsets.UTF_8), new String(originalForm2,
                Charsets.UTF_8)), n);
          }
          data.set(n, data.get(n) + w);
        }
      }
    }
  }

  /**
   * Computes the weight of an interaction as the product of the weights that the two underlying
   * encoders report for their value (each queried with a weight of 1.0) and the supplied weight.
   *
   * @param originalForm1 The first value as a byte array.
   * @param originalForm2 The second value as a byte array.
   * @param w             The externally supplied weight.
   * @return The combined weight.
   */
  protected double getWeight(byte[] originalForm1, byte[] originalForm2, double w) {
    return firstEncoder.getWeight(originalForm1, 1.0) * secondEncoder.getWeight(originalForm2, 1.0) * w;
  }

  /**
   * Converts a value into a form that would help a human understand the internals of how the value
   * is being interpreted.  For text-like things, this is likely to be a list of the terms found with
   * associated weights (if any).
   *
   * @param originalForm The original form of the value as a string.
   * @return A string that a human can read.
   */
  @Override
  public String asString(String originalForm) {
    return String.format(Locale.ENGLISH, "%s:%s", getName(), originalForm);
  }

  /**
   * Hashes only the variable name and the probe number; originalForm is not used.
   */
  @Override
  protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
    return hash(name, probe, dataSize);
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.vectorizer.encoders;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.mahout.common.lucene.TokenStreamIterator;

import java.io.IOException;
import java.io.Reader;
import java.nio.CharBuffer;
import java.util.Iterator;

/**
 * Encodes text using a lucene style tokenizer.
 *
 * @see TextValueEncoder
 */
public class LuceneTextValueEncoder extends TextValueEncoder {
  /** Must be set with {@link #setAnalyzer(Analyzer)} before any text is tokenized. */
  private Analyzer analyzer;

  public LuceneTextValueEncoder(String name) {
    super(name);
  }

  /**
   * Sets the analyzer used to split text into tokens.
   *
   * @param analyzer The Lucene analyzer to use.
   */
  public void setAnalyzer(Analyzer analyzer) {
    this.analyzer = analyzer;
  }

  /**
   * Tokenizes a string with the configured analyzer, using the encoder name as the field name.
   *
   * @param originalForm The text to tokenize.
   * @return The tokens produced by the analyzer's token stream.
   */
  @Override
  protected Iterable<String> tokenize(CharSequence originalForm) {
    TokenStream ts = analyzer.tokenStream(getName(), new CharSequenceReader(originalForm));
    ts.addAttribute(CharTermAttribute.class);
    // firstTime == false: the stream is reset before every iteration, including the first one.
    return new LuceneTokenIterable(ts, false);
  }

  /** A Reader over an in-memory copy of a CharSequence. */
  private static final class CharSequenceReader extends Reader {
    private final CharBuffer buf;

    /**
     * Copies the characters of the input into an internal buffer positioned at its start.
     */
    private CharSequenceReader(CharSequence input) {
      int n = input.length();
      buf = CharBuffer.allocate(n);
      for (int i = 0; i < n; i++) {
        buf.put(input.charAt(i));
      }
      buf.rewind();
    }

    /**
     * Reads characters into a portion of an array.
     *
     * @param cbuf Destination buffer
     * @param off  Offset at which to start storing characters
     * @param len  Maximum number of characters to read
     * @return The number of characters read, or -1 if no characters could be read
     */
    @Override
    public int read(char[] cbuf, int off, int len) {
      int toRead = Math.min(len, buf.remaining());
      if (toRead > 0) {
        buf.get(cbuf, off, toRead);
        return toRead;
      } else {
        return -1;
      }
    }

    @Override
    public void close() {
      // do nothing
    }
  }

  /** Exposes the tokens of a TokenStream as an Iterable of strings. */
  private static final class LuceneTokenIterable implements Iterable<String> {
    /** True while the next call to iterator() should skip the reset. */
    private boolean firstTime;
    private final TokenStream tokenStream;

    private LuceneTokenIterable(TokenStream ts, boolean firstTime) {
      this.tokenStream = ts;
      this.firstTime = firstTime;
    }

    /**
     * Returns an iterator over the tokens of the stream.  The stream is reset first, except on
     * the first call when this iterable was created with firstTime set to true.
     *
     * @return an Iterator.
     * @throws IllegalStateException if the stream cannot be reset; the IOException is the cause.
     */
    @Override
    public Iterator<String> iterator() {
      if (firstTime) {
        firstTime = false;
      } else {
        try {
          tokenStream.reset();
        } catch (IOException e) {
          // keep the underlying failure attached so it is not lost to the caller
          throw new IllegalStateException("This token stream can't be reset", e);
        }
      }

      return new TokenStreamIterator(tokenStream);
    }
  }

}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.vectorizer.encoders;

import com.google.common.base.Charsets;

import java.util.Collections;
import java.util.Map;

/**
 * Encodes categorical values with an unbounded vocabulary.  Values are encoded by incrementing a
 * few locations in the output vector with a weight that is either defaulted to 1 or that is looked
 * up in a weight dictionary.  By default, two probes are used which should be fine but could
 * cause a decrease in the speed of learning because more features will be non-zero. If a large
 * feature vector is used so that the probability of feature collisions is suitably small, then this
 * can be decreased to 1.  If a very small feature vector is used, the number of probes should
 * probably be increased to 3.
 */
public class StaticWordValueEncoder extends WordValueEncoder {
  /** Optional word-to-weight table; null means every word gets the missing value weight. */
  private Map<String, Double> dictionary;
  private double missingValueWeight = 1;
  private final byte[] nameBytes;

  public StaticWordValueEncoder(String name) {
    super(name);
    nameBytes = bytesForString(name);
  }

  /**
   * Hashes the encoder name together with the value, seeded by the word-like seed plus the
   * probe number.  The name parameter is not used; the name given at construction is.
   */
  @Override
  protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
    return hash(nameBytes, originalForm, WORD_LIKE_VALUE_HASH_SEED + probe, dataSize);
  }

  /**
   * Sets the weighting dictionary to be used by this encoder.  Also sets the missing value weight
   * to be half the smallest weight in the dictionary.
   *
   * @param dictionary The dictionary to use to look up weights.
   */
  public void setDictionary(Map<String, Double> dictionary) {
    this.dictionary = dictionary;
    setMissingValueWeight(Collections.min(dictionary.values()) / 2);
  }

  /**
   * Sets the weight that is to be used for values that do not appear in the dictionary.
   *
   * @param missingValueWeight The default weight for missing values.
   */
  public void setMissingValueWeight(double missingValueWeight) {
    this.missingValueWeight = missingValueWeight;
  }

  /**
   * Looks up the weight of a value.
   *
   * @param originalForm The UTF-8 bytes of the value.
   * @return The dictionary weight of the value, or the missing value weight if there is no
   *         dictionary or the dictionary holds no weight for the value.
   */
  @Override
  protected double weight(byte[] originalForm) {
    if (dictionary == null) {
      return missingValueWeight;
    }
    // One lookup instead of containsKey + get; an entry mapped to null is treated as missing
    // rather than failing on unboxing.
    Double known = dictionary.get(new String(originalForm, Charsets.UTF_8));
    return known == null ? missingValueWeight : known;
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.vectorizer.encoders;

import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Multiset;
import org.apache.mahout.math.Vector;

import java.util.Collection;
import java.util.regex.Pattern;

/**
 * Encodes text that is tokenized on non-alphanum separators.  Each word is encoded using a
 * settable encoder which is by default an StaticWordValueEncoder which gives all
 * words the same weight.
 * @see LuceneTextValueEncoder
 */
public class TextValueEncoder extends FeatureVectorEncoder {

  private static final double LOG_2 = Math.log(2.0);

  private static final Splitter ON_NON_WORD = Splitter.on(Pattern.compile("\\W+")).omitEmptyStrings();

  private FeatureVectorEncoder wordEncoder;

  /** Token counts accumulated by addText and consumed by flush. */
  private final Multiset<String> counts;

  public TextValueEncoder(String name) {
    super(name, 2);
    wordEncoder = new StaticWordValueEncoder(name);
    counts = HashMultiset.create();
  }

  /**
   * Adds a value to a vector after tokenizing it by splitting on non-alphanum characters.
   * Equivalent to addText followed by flush.
   *
   * @param originalForm The original form of the value as UTF-8 bytes.
   * @param weight       The weight applied to every word of the text.
   * @param data         The vector to which the value should be added.
   */
  @Override
  public void addToVector(byte[] originalForm, double weight, Vector data) {
    addText(originalForm);
    flush(weight, data);
  }

  /**
   * Adds text to the internal word counter, but delays converting it to vector
   * form until flush is called.
   * @param originalForm  The original text encoded as UTF-8
   */
  public void addText(byte[] originalForm) {
    String decoded = new String(originalForm, Charsets.UTF_8);
    addText(decoded);
  }

  /**
   * Adds text to the internal word counter, but delays converting it to vector
   * form until flush is called.
   * @param text  The original text
   */
  public void addText(CharSequence text) {
    for (String token : tokenize(text)) {
      counts.add(token);
    }
  }

  /**
   * Adds all of the tokens that we counted up to a vector and clears the counter.
   *
   * @param weight The weight applied to every word in addition to its frequency weight.
   * @param data   The vector to which the words should be added.
   */
  public void flush(double weight, Vector data) {
    for (String token : counts.elementSet()) {
      // weight words by log_2(1 + tf) times whatever other weight we are given
      double tokenWeight = weight * Math.log1p(counts.count(token)) / LOG_2;
      wordEncoder.addToVector(token, tokenWeight, data);
    }
    counts.clear();
  }

  /**
   * Always returns 0.
   * NOTE(review): hashesForProbe below calls this method for every token, so it yields one 0 per
   * token -- confirm that this is intended before relying on it.
   */
  @Override
  protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
    return 0;
  }

  /**
   * Returns one hash per token of the text, each obtained from hashForProbe.
   *
   * @param originalForm The text as UTF-8 bytes.
   * @param dataSize     The length of the vector being encoded
   * @param name         The name of the variable being encoded
   * @param probe        The probe number
   * @return The hashes, one per token (unique or not).
   */
  @Override
  protected Iterable<Integer> hashesForProbe(byte[] originalForm, int dataSize, String name, int probe) {
    String text = new String(originalForm, Charsets.UTF_8);
    Collection<Integer> tokenHashes = Lists.newArrayList();
    for (String token : tokenize(text)) {
      byte[] tokenBytes = bytesForString(token);
      tokenHashes.add(hashForProbe(tokenBytes, dataSize, name, probe));
    }
    return tokenHashes;
  }

  /**
   * Tokenizes a string using the simplest method.  This should be over-ridden for more subtle
   * tokenization.
   * @see LuceneTextValueEncoder
   */
  protected Iterable<String> tokenize(CharSequence originalForm) {
    return ON_NON_WORD.split(originalForm);
  }

  /**
   * Converts a value into a form that would help a human understand the internals of how the value
   * is being interpreted.  For text-like things, this is likely to be a list of the terms found with
   * associated weights (if any).
   *
   * @param originalForm The original form of the value as a string.
   * @return A string that a human can read.
   */
  @Override
  public String asString(String originalForm) {
    StringBuilder rendered = new StringBuilder("[");
    for (String token : tokenize(originalForm)) {
      // anything beyond the opening bracket means an earlier word was already written
      if (rendered.length() > 1) {
        rendered.append(", ");
      }
      rendered.append(wordEncoder.asString(token));
    }
    return rendered.append(']').toString();
  }

  /**
   * Replaces the encoder used for the individual words.
   *
   * @param wordEncoder The encoder to apply to each word.
   */
  public final void setWordEncoder(FeatureVectorEncoder wordEncoder) {
    this.wordEncoder = wordEncoder;
  }
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.vectorizer.encoders;

import org.apache.mahout.math.Vector;

import java.util.Locale;

/**
 * Encodes words as sparse vector updates to a Vector.  Weighting is defined by a
 * sub-class.
 */
public abstract class WordValueEncoder extends FeatureVectorEncoder {
  /** UTF-8 bytes of the encoder name, hashed together with every value. */
  private final byte[] nameBytes;

  /**
   * Creates an encoder that uses two probes.
   *
   * @param name The name of the variable being encoded.
   */
  protected WordValueEncoder(String name) {
    super(name, 2);
    nameBytes = bytesForString(name);
  }

  /**
   * Adds a value to a vector by incrementing one location per probe.
   *
   * @param originalForm The original form of the value as a byte array.
   * @param w            The weight supplied by the caller.
   * @param data         The vector to which the value should be added.
   */
  @Override
  public void addToVector(byte[] originalForm, double w, Vector data) {
    int probeCount = getProbes();
    String fieldName = getName();
    double increment = getWeight(originalForm, w);
    for (int probe = 0; probe < probeCount; probe++) {
      int slot = hashForProbe(originalForm, data.size(), fieldName, probe);
      if (isTraceEnabled()) {
        trace(originalForm, slot);
      }
      data.set(slot, data.get(slot) + increment);
    }
  }

  /**
   * @return The caller supplied weight multiplied by the weight of the word itself.
   */
  @Override
  protected double getWeight(byte[] originalForm, double w) {
    double wordWeight = weight(originalForm);
    return w * wordWeight;
  }

  /**
   * Hashes the encoder name together with the value, seeded by the word-like seed plus the
   * probe number.  The name parameter is not used; the name given at construction is.
   */
  @Override
  protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
    int seed = WORD_LIKE_VALUE_HASH_SEED + probe;
    return hash(nameBytes, originalForm, seed, dataSize);
  }

  /**
   * Converts a value into a form that would help a human understand the internals of how the value
   * is being interpreted.  For text-like things, this is likely to be a list of the terms found with
   * associated weights (if any).
   *
   * @param originalForm The original form of the value as a string.
   * @return A string that a human can read.
   */
  @Override
  public String asString(String originalForm) {
    String fieldName = getName();
    double wordWeight = weight(bytesForString(originalForm));
    return String.format(Locale.ENGLISH, "%s:%s:%.4f", fieldName, originalForm, wordWeight);
  }

  /**
   * Returns the weight of a single word.
   *
   * @param originalForm The UTF-8 bytes of the word.
   * @return The weight to use for the word.
   */
  protected abstract double weight(byte[] originalForm);
}
package org.apache.mahout.vectorizer.pruner;
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.common.PartialVectorMerger;

import java.io.IOException;

/**
 * Sums all partial vectors that share a key and optionally normalizes the result.
 */
public class PrunedPartialVectorMergeReducer extends
    Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {

  /** Power of the norm to apply, or PartialVectorMerger.NO_NORMALIZING to skip normalization. */
  private double normPower;

  /** Selects logNormalize instead of normalize when normalization is enabled. */
  private boolean logNormalize;

  /**
   * Adds up the vectors of one key, normalizes the sum if configured to, and writes it out under
   * the same key.
   */
  @Override
  protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context context) throws IOException,
          InterruptedException {

    Vector merged = null;
    for (VectorWritable partial : values) {
      if (merged == null) {
        // start from a copy of the first vector so the incoming writable is left untouched
        merged = partial.get().clone();
      } else {
        merged.assign(partial.get(), Functions.PLUS);
      }
    }

    boolean normalizing = normPower != PartialVectorMerger.NO_NORMALIZING;
    if (merged != null && normalizing) {
      if (logNormalize) {
        merged = merged.logNormalize(normPower);
      } else {
        merged = merged.normalize(normPower);
      }
    }

    context.write(key, new VectorWritable(merged));
  }

  /**
   * Reads the normalization power and the log-normalize flag from the job configuration.
   */
  @Override
  protected void setup(Context context) throws IOException, InterruptedException {
    super.setup(context);
    Configuration conf = context.getConfiguration();
    normPower = conf.getFloat(PartialVectorMerger.NORMALIZATION_POWER, PartialVectorMerger.NO_NORMALIZING);
    logNormalize = conf.getBoolean(PartialVectorMerger.LOG_NORMALIZE, false);
  }
}
+ */ + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.map.OpenIntLongHashMap; +import org.apache.mahout.vectorizer.HighDFWordsPruner; + +import java.io.IOException; +import java.util.Iterator; + +public class WordsPrunerReducer extends + Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> { + + private final OpenIntLongHashMap dictionary = new OpenIntLongHashMap(); + private long maxDf = Long.MAX_VALUE; + private long minDf = -1; + + @Override + protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context context) + throws IOException, InterruptedException { + Iterator<VectorWritable> it = values.iterator(); + if (!it.hasNext()) { + return; + } + Vector value = it.next().get(); + Vector vector = value.clone(); + if (maxDf != Long.MAX_VALUE || minDf > -1) { + for (Vector.Element e : value.nonZeroes()) { + if (!dictionary.containsKey(e.index())) { + vector.setQuick(e.index(), 0.0); + continue; + } + long df = dictionary.get(e.index()); + if (df > maxDf || df < minDf) { + vector.setQuick(e.index(), 0.0); + } + } + } + + VectorWritable vectorWritable = new VectorWritable(vector); + context.write(key, vectorWritable); + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + Configuration conf = context.getConfiguration(); + //Path[] localFiles = HadoopUtil.getCachedFiles(conf); + + maxDf = conf.getLong(HighDFWordsPruner.MAX_DF, Long.MAX_VALUE); + minDf = 
conf.getLong(HighDFWordsPruner.MIN_DF, -1); + + Path dictionaryFile = HadoopUtil.getSingleCachedFile(conf); + + // key is feature, value is the document frequency + for (Pair<IntWritable, LongWritable> record + : new SequenceFileIterable<IntWritable, LongWritable>(dictionaryFile, true, conf)) { + dictionary.put(record.getFirst().get(), record.getSecond().get()); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TFPartialVectorReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TFPartialVectorReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TFPartialVectorReducer.java new file mode 100644 index 0000000..1496c90 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TFPartialVectorReducer.java @@ -0,0 +1,139 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.vectorizer.term; + +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.lucene.analysis.shingle.ShingleFilter; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.StringTuple; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.common.lucene.IteratorTokenStream; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.map.OpenObjectIntHashMap; +import org.apache.mahout.vectorizer.DictionaryVectorizer; +import org.apache.mahout.vectorizer.common.PartialVectorMerger; + +import java.io.IOException; +import java.net.URI; +import java.util.Iterator; +import java.util.List; + +/** + * Converts a document in to a sparse vector + */ +public class TFPartialVectorReducer extends Reducer<Text, StringTuple, Text, VectorWritable> { + + private final OpenObjectIntHashMap<String> dictionary = new OpenObjectIntHashMap<>(); + + private int dimension; + private boolean sequentialAccess; + private boolean namedVector; + private int maxNGramSize = 1; + + @Override + protected void reduce(Text key, Iterable<StringTuple> values, Context context) + throws IOException, InterruptedException { + Iterator<StringTuple> it = values.iterator(); + + if (!it.hasNext()) { + return; + } + + List<String> value = 
Lists.newArrayList(); + + while (it.hasNext()) { + value.addAll(it.next().getEntries()); + } + + Vector vector = new RandomAccessSparseVector(dimension, value.size()); // guess at initial size + + if (maxNGramSize >= 2) { + ShingleFilter sf = new ShingleFilter(new IteratorTokenStream(value.iterator()), maxNGramSize); + sf.reset(); + try { + do { + String term = sf.getAttribute(CharTermAttribute.class).toString(); + if (!term.isEmpty() && dictionary.containsKey(term)) { // ngram + int termId = dictionary.get(term); + vector.setQuick(termId, vector.getQuick(termId) + 1); + } + } while (sf.incrementToken()); + + sf.end(); + } finally { + Closeables.close(sf, true); + } + } else { + for (String term : value) { + if (!term.isEmpty() && dictionary.containsKey(term)) { // unigram + int termId = dictionary.get(term); + vector.setQuick(termId, vector.getQuick(termId) + 1); + } + } + } + if (sequentialAccess) { + vector = new SequentialAccessSparseVector(vector); + } + + if (namedVector) { + vector = new NamedVector(vector, key.toString()); + } + + // if the vector has no nonZero entries (nothing in the dictionary), let's not waste space sending it to disk. 
+ if (vector.getNumNondefaultElements() > 0) { + VectorWritable vectorWritable = new VectorWritable(vector); + context.write(key, vectorWritable); + } else { + context.getCounter("TFPartialVectorReducer", "emptyVectorCount").increment(1); + } + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + Configuration conf = context.getConfiguration(); + + dimension = conf.getInt(PartialVectorMerger.DIMENSION, Integer.MAX_VALUE); + sequentialAccess = conf.getBoolean(PartialVectorMerger.SEQUENTIAL_ACCESS, false); + namedVector = conf.getBoolean(PartialVectorMerger.NAMED_VECTOR, false); + maxNGramSize = conf.getInt(DictionaryVectorizer.MAX_NGRAMS, maxNGramSize); + + URI[] localFiles = DistributedCache.getCacheFiles(conf); + Path dictionaryFile = HadoopUtil.findInCacheByPartOfFilename(DictionaryVectorizer.DICTIONARY_FILE, localFiles); + // key is word value is id + for (Pair<Writable, IntWritable> record + : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) { + dictionary.put(record.getFirst().toString(), record.getSecond().get()); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountCombiner.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountCombiner.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountCombiner.java new file mode 100644 index 0000000..4c63333 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountCombiner.java @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.term; + +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Reducer; + +import java.io.IOException; + +/** + * @see TermCountReducer + */ +public class TermCountCombiner extends Reducer<Text, LongWritable, Text, LongWritable> { + + @Override + protected void reduce(Text key, Iterable<LongWritable> values, Context context) + throws IOException, InterruptedException { + long sum = 0; + for (LongWritable value : values) { + sum += value.get(); + } + context.write(key, new LongWritable(sum)); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountMapper.java new file mode 100644 index 0000000..9af3d57 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountMapper.java @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.term; + +import java.io.IOException; + +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.StringTuple; +import org.apache.mahout.math.function.ObjectLongProcedure; +import org.apache.mahout.math.map.OpenObjectLongHashMap; + +/** + * TextVectorizer Term Count Mapper. 
Tokenizes a text document and outputs the count of the words + */ +public class TermCountMapper extends Mapper<Text, StringTuple, Text, LongWritable> { + + @Override + protected void map(Text key, StringTuple value, final Context context) throws IOException, InterruptedException { + OpenObjectLongHashMap<String> wordCount = new OpenObjectLongHashMap<>(); + for (String word : value.getEntries()) { + if (wordCount.containsKey(word)) { + wordCount.put(word, wordCount.get(word) + 1); + } else { + wordCount.put(word, 1); + } + } + wordCount.forEachPair(new ObjectLongProcedure<String>() { + @Override + public boolean apply(String first, long second) { + try { + context.write(new Text(first), new LongWritable(second)); + } catch (IOException e) { + context.getCounter("Exception", "Output IO Exception").increment(1); + } catch (InterruptedException e) { + context.getCounter("Exception", "Interrupted Exception").increment(1); + } + return true; + } + }); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountReducer.java new file mode 100644 index 0000000..388bfc2 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountReducer.java @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.term; + +import java.io.IOException; + +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.vectorizer.DictionaryVectorizer; + +/** + * This accumulates all the words and the weights and sums them up. + * + * @see TermCountCombiner + */ +public class TermCountReducer extends Reducer<Text, LongWritable, Text, LongWritable> { + + private int minSupport; + + @Override + protected void reduce(Text key, Iterable<LongWritable> values, Context context) + throws IOException, InterruptedException { + long sum = 0; + for (LongWritable value : values) { + sum += value.get(); + } + if (sum >= minSupport) { + context.write(key, new LongWritable(sum)); + } + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + minSupport = context.getConfiguration().getInt(DictionaryVectorizer.MIN_SUPPORT, + DictionaryVectorizer.DEFAULT_MIN_SUPPORT); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountMapper.java new file mode 100644 index 0000000..30828bf --- /dev/null +++ 
b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountMapper.java @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.term; + +import java.io.IOException; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +/** + * TextVectorizer Document Frequency Count Mapper. Outputs 1 for each feature + */ +public class TermDocumentCountMapper extends Mapper<WritableComparable<?>, VectorWritable, IntWritable, LongWritable> { + + private static final LongWritable ONE = new LongWritable(1); + + private static final IntWritable TOTAL_COUNT = new IntWritable(-1); + + private final IntWritable out = new IntWritable(); + + @Override + protected void map(WritableComparable<?> key, VectorWritable value, Context context) + throws IOException, InterruptedException { + Vector vector = value.get(); + for (Vector.Element e : vector.nonZeroes()) { + out.set(e.index()); + context.write(out, ONE); + } + context.write(TOTAL_COUNT, ONE); + } +}
