http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterable.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterable.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterable.java new file mode 100644 index 0000000..d2fdf8d --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterable.java @@ -0,0 +1,67 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.iterator.sequencefile; + +import java.io.IOException; +import java.util.Iterator; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Writable; + +/** + * <p>{@link Iterable} counterpart to {@link SequenceFileValueIterator}.</p> + */ +public final class SequenceFileValueIterable<V extends Writable> implements Iterable<V> { + + private final Path path; + private final boolean reuseKeyValueInstances; + private final Configuration conf; + + /** + * Like {@link #SequenceFileValueIterable(Path, boolean, Configuration)} but instances are not reused + * by default. + * + * @param path file to iterate over + */ + public SequenceFileValueIterable(Path path, Configuration conf) { + this(path, false, conf); + } + + /** + * @param path file to iterate over + * @param reuseKeyValueInstances if true, reuses instances of the value object instead of creating a new + * one for each read from the file + */ + public SequenceFileValueIterable(Path path, boolean reuseKeyValueInstances, Configuration conf) { + this.path = path; + this.reuseKeyValueInstances = reuseKeyValueInstances; + this.conf = conf; + } + + @Override + public Iterator<V> iterator() { + try { + return new SequenceFileValueIterator<>(path, reuseKeyValueInstances, conf); + } catch (IOException ioe) { + throw new IllegalStateException(path.toString(), ioe); + } + } + +} +
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterator.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterator.java new file mode 100644 index 0000000..49d64c7 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterator.java @@ -0,0 +1,97 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.iterator.sequencefile; + +import java.io.Closeable; +import java.io.IOException; + +import com.google.common.collect.AbstractIterator; +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.util.ReflectionUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * <p>{@link java.util.Iterator} over a {@link SequenceFile}'s values only.</p> + */ +public final class SequenceFileValueIterator<V extends Writable> extends AbstractIterator<V> implements Closeable { + + private final SequenceFile.Reader reader; + private final Configuration conf; + private final Class<V> valueClass; + private final Writable key; + private V value; + private final boolean reuseKeyValueInstances; + + private static final Logger log = LoggerFactory.getLogger(SequenceFileValueIterator.class); + + /** + * @throws IOException if path can't be read, or its key or value class can't be instantiated + */ + + public SequenceFileValueIterator(Path path, boolean reuseKeyValueInstances, Configuration conf) throws IOException { + value = null; + FileSystem fs = path.getFileSystem(conf); + path = path.makeQualified(path.toUri(), path); + reader = new SequenceFile.Reader(fs, path, conf); + this.conf = conf; + Class<? extends Writable> keyClass = (Class<? 
extends Writable>) reader.getKeyClass(); + key = ReflectionUtils.newInstance(keyClass, conf); + valueClass = (Class<V>) reader.getValueClass(); + this.reuseKeyValueInstances = reuseKeyValueInstances; + } + + public Class<V> getValueClass() { + return valueClass; + } + + @Override + public void close() throws IOException { + value = null; + Closeables.close(reader, true); + endOfData(); + } + + @Override + protected V computeNext() { + if (!reuseKeyValueInstances || value == null) { + value = ReflectionUtils.newInstance(valueClass, conf); + } + try { + boolean available = reader.next(key, value); + if (!available) { + close(); + return null; + } + return value; + } catch (IOException ioe) { + try { + close(); + } catch (IOException e) { + log.error(e.getMessage(), e); + } + throw new IllegalStateException(ioe); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/lucene/AnalyzerUtils.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/lucene/AnalyzerUtils.java b/mr/src/main/java/org/apache/mahout/common/lucene/AnalyzerUtils.java new file mode 100644 index 0000000..37ca383 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/lucene/AnalyzerUtils.java @@ -0,0 +1,61 @@ +package org.apache.mahout.common.lucene; +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.util.Version; +import org.apache.mahout.common.ClassUtils; + +public final class AnalyzerUtils { + + private AnalyzerUtils() {} + + /** + * Create an Analyzer using the latest {@link org.apache.lucene.util.Version}. Note, if you need to pass in + * parameters to your constructor, you will need to wrap it in an implementation that does not take any arguments + * @param analyzerClassName - Lucene Analyzer Name + * @return {@link Analyzer} + * @throws ClassNotFoundException - {@link ClassNotFoundException} + */ + public static Analyzer createAnalyzer(String analyzerClassName) throws ClassNotFoundException { + return createAnalyzer(analyzerClassName, Version.LUCENE_46); + } + + public static Analyzer createAnalyzer(String analyzerClassName, Version version) throws ClassNotFoundException { + Class<? extends Analyzer> analyzerClass = Class.forName(analyzerClassName).asSubclass(Analyzer.class); + return createAnalyzer(analyzerClass, version); + } + + /** + * Create an Analyzer using the latest {@link org.apache.lucene.util.Version}. Note, if you need to pass in + * parameters to your constructor, you will need to wrap it in an implementation that does not take any arguments + * @param analyzerClass The Analyzer Class to instantiate + * @return {@link Analyzer} + */ + public static Analyzer createAnalyzer(Class<? extends Analyzer> analyzerClass) { + return createAnalyzer(analyzerClass, Version.LUCENE_46); + } + + public static Analyzer createAnalyzer(Class<? 
extends Analyzer> analyzerClass, Version version) { + try { + return ClassUtils.instantiateAs(analyzerClass, Analyzer.class, + new Class<?>[] { Version.class }, new Object[] { version }); + } catch (IllegalStateException e) { + return ClassUtils.instantiateAs(analyzerClass, Analyzer.class); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/lucene/IteratorTokenStream.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/lucene/IteratorTokenStream.java b/mr/src/main/java/org/apache/mahout/common/lucene/IteratorTokenStream.java new file mode 100644 index 0000000..5facad8 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/lucene/IteratorTokenStream.java @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.lucene; + +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; + +import java.util.Iterator; + +/** Used to emit tokens from an input string array in the style of TokenStream */ +public final class IteratorTokenStream extends TokenStream { + private final CharTermAttribute termAtt; + private final Iterator<String> iterator; + + public IteratorTokenStream(Iterator<String> iterator) { + this.iterator = iterator; + this.termAtt = addAttribute(CharTermAttribute.class); + } + + @Override + public boolean incrementToken() { + if (iterator.hasNext()) { + clearAttributes(); + termAtt.append(iterator.next()); + return true; + } else { + return false; + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/lucene/TokenStreamIterator.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/lucene/TokenStreamIterator.java b/mr/src/main/java/org/apache/mahout/common/lucene/TokenStreamIterator.java new file mode 100644 index 0000000..af60d8b --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/lucene/TokenStreamIterator.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.lucene; + +import com.google.common.collect.AbstractIterator; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; + +import java.io.IOException; + +/** + * Provide an Iterator for the tokens in a TokenStream. + * + * Note, it is the responsibility of the instantiating class to properly consume the + * {@link org.apache.lucene.analysis.TokenStream}. See the Lucene {@link org.apache.lucene.analysis.TokenStream} + * documentation for more information. + */ +//TODO: consider using the char/byte arrays instead of strings, esp. when we upgrade to Lucene 4.0 +public final class TokenStreamIterator extends AbstractIterator<String> { + + private final TokenStream tokenStream; + + public TokenStreamIterator(TokenStream tokenStream) { + this.tokenStream = tokenStream; + } + + @Override + protected String computeNext() { + try { + if (tokenStream.incrementToken()) { + return tokenStream.getAttribute(CharTermAttribute.class).toString(); + } else { + tokenStream.end(); + tokenStream.close(); + return endOfData(); + } + } catch (IOException e) { + throw new IllegalStateException("IO error while tokenizing", e); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsCombiner.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsCombiner.java b/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsCombiner.java new file mode 100644 index 0000000..8e0385d --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsCombiner.java @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.mapreduce; + +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.math.VectorWritable; + +import java.io.IOException; + +public class MergeVectorsCombiner + extends Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> { + + @Override + public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context ctx) + throws IOException, InterruptedException { + ctx.write(key, VectorWritable.merge(vectors.iterator())); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsReducer.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsReducer.java b/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsReducer.java new file mode 100644 index 0000000..b8d5dea --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsReducer.java @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.mapreduce; + +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.IOException; + +public class MergeVectorsReducer extends + Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> { + + private final VectorWritable result = new VectorWritable(); + + @Override + public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context ctx) + throws IOException, InterruptedException { + Vector merged = VectorWritable.merge(vectors.iterator()).get(); + result.set(new SequentialAccessSparseVector(merged)); + ctx.write(key, result); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/mapreduce/TransposeMapper.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/mapreduce/TransposeMapper.java b/mr/src/main/java/org/apache/mahout/common/mapreduce/TransposeMapper.java new file mode 100644 index 0000000..c6c3f05 --- /dev/null +++ 
b/mr/src/main/java/org/apache/mahout/common/mapreduce/TransposeMapper.java @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.mapreduce; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.IOException; + +public class TransposeMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> { + + public static final String NEW_NUM_COLS_PARAM = TransposeMapper.class.getName() + ".newNumCols"; + + private int newNumCols; + + @Override + protected void setup(Context ctx) throws IOException, InterruptedException { + newNumCols = ctx.getConfiguration().getInt(NEW_NUM_COLS_PARAM, Integer.MAX_VALUE); + } + + @Override + protected void map(IntWritable r, VectorWritable v, Context ctx) throws IOException, InterruptedException { + int row = r.get(); + for (Vector.Element e : v.get().nonZeroes()) { + RandomAccessSparseVector tmp = new RandomAccessSparseVector(newNumCols, 1); + tmp.setQuick(row, e.get()); + r.set(e.index()); + ctx.write(r, new VectorWritable(tmp)); + } + } 
+} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumCombiner.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumCombiner.java b/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumCombiner.java new file mode 100644 index 0000000..1d93386 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumCombiner.java @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.mapreduce; + +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.similarity.cooccurrence.Vectors; + +import java.io.IOException; + +public class VectorSumCombiner + extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> { + + private final VectorWritable result = new VectorWritable(); + + @Override + protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context ctx) + throws IOException, InterruptedException { + result.set(Vectors.sum(values.iterator())); + ctx.write(key, result); + } + } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumReducer.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumReducer.java b/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumReducer.java new file mode 100644 index 0000000..97d3805 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumReducer.java @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.mapreduce; + +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.similarity.cooccurrence.Vectors; + +import java.io.IOException; + +public class VectorSumReducer + extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> { + + @Override + protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context ctx) + throws IOException, InterruptedException { + ctx.write(key, new VectorWritable(Vectors.sum(values.iterator()))); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/nlp/NGrams.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/nlp/NGrams.java b/mr/src/main/java/org/apache/mahout/common/nlp/NGrams.java new file mode 100644 index 0000000..7adadc1 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/nlp/NGrams.java @@ -0,0 +1,94 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.nlp; + +import com.google.common.base.CharMatcher; +import com.google.common.base.Splitter; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +public class NGrams { + + private static final Splitter SPACE_TAB = Splitter.on(CharMatcher.anyOf(" \t")); + + private final String line; + private final int gramSize; + + public NGrams(String line, int gramSize) { + this.line = line; + this.gramSize = gramSize; + } + + public Map<String,List<String>> generateNGrams() { + Map<String,List<String>> returnDocument = Maps.newHashMap(); + + Iterator<String> tokenizer = SPACE_TAB.split(line).iterator(); + List<String> tokens = Lists.newArrayList(); + String labelName = tokenizer.next(); + List<String> previousN1Grams = Lists.newArrayList(); + while (tokenizer.hasNext()) { + + String nextToken = tokenizer.next(); + if (previousN1Grams.size() == gramSize) { + previousN1Grams.remove(0); + } + + previousN1Grams.add(nextToken); + + StringBuilder gramBuilder = new StringBuilder(); + + for (String gram : previousN1Grams) { + gramBuilder.append(gram); + String token = gramBuilder.toString(); + tokens.add(token); + gramBuilder.append(' '); + } + } + returnDocument.put(labelName, tokens); + return returnDocument; + } + + public List<String> generateNGramsWithoutLabel() { + + List<String> tokens = Lists.newArrayList(); + List<String> previousN1Grams = Lists.newArrayList(); + for (String nextToken : SPACE_TAB.split(line)) { + + if (previousN1Grams.size() == gramSize) { + previousN1Grams.remove(0); + } + + previousN1Grams.add(nextToken); + + StringBuilder gramBuilder = new StringBuilder(); + + for (String gram : previousN1Grams) { + gramBuilder.append(gram); + String token = gramBuilder.toString(); + tokens.add(token); + 
gramBuilder.append(' '); + } + } + + return tokens; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/AbstractParameter.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/AbstractParameter.java b/mr/src/main/java/org/apache/mahout/common/parameters/AbstractParameter.java new file mode 100644 index 0000000..f0a7aa8 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/parameters/AbstractParameter.java @@ -0,0 +1,120 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.parameters; + +import java.util.Collection; +import java.util.Collections; + +import org.apache.hadoop.conf.Configuration; + +public abstract class AbstractParameter<T> implements Parameter<T> { + + private T value; + private final String prefix; + private final String name; + private final String description; + private final Class<T> type; + private final String defaultValue; + + protected AbstractParameter(Class<T> type, + String prefix, + String name, + Configuration jobConf, + T defaultValue, + String description) { + this.type = type; + this.name = name; + this.description = description; + + this.value = defaultValue; + this.defaultValue = getStringValue(); + + this.prefix = prefix; + String jobConfValue = jobConf.get(prefix + name); + if (jobConfValue != null) { + setStringValue(jobConfValue); + } + + } + + @Override + public void configure(Configuration jobConf) { + // nothing to do + } + + @Override + public void createParameters(String prefix, Configuration jobConf) { } + + @Override + public String getStringValue() { + if (value == null) { + return null; + } + return value.toString(); + } + + @Override + public Collection<Parameter<?>> getParameters() { + return Collections.emptyList(); + } + + @Override + public String prefix() { + return prefix; + } + + @Override + public String name() { + return name; + } + + @Override + public String description() { + return description; + } + + @Override + public Class<T> type() { + return type; + } + + @Override + public String defaultValue() { + return defaultValue; + } + + @Override + public T get() { + return value; + } + + @Override + public void set(T value) { + this.value = value; + } + + @Override + public String toString() { + if (value != null) { + return value.toString(); + } else { + return super.toString(); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/ClassParameter.java 
---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/ClassParameter.java b/mr/src/main/java/org/apache/mahout/common/parameters/ClassParameter.java new file mode 100644 index 0000000..1d1c0bb --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/parameters/ClassParameter.java @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.parameters; + +import org.apache.hadoop.conf.Configuration; + +public class ClassParameter extends AbstractParameter<Class> { + + public ClassParameter(String prefix, String name, Configuration jobConf, Class<?> defaultValue, String description) { + super(Class.class, prefix, name, jobConf, defaultValue, description); + } + + @Override + public void setStringValue(String stringValue) { + try { + set(Class.forName(stringValue)); + } catch (ClassNotFoundException e) { + throw new IllegalStateException(e); + } + } + + @Override + public String getStringValue() { + if (get() == null) { + return null; + } + return get().getName(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/DoubleParameter.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/DoubleParameter.java b/mr/src/main/java/org/apache/mahout/common/parameters/DoubleParameter.java new file mode 100644 index 0000000..cb3efcf --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/parameters/DoubleParameter.java @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.parameters; + +import org.apache.hadoop.conf.Configuration; + +public class DoubleParameter extends AbstractParameter<Double> { + + public DoubleParameter(String prefix, String name, Configuration conf, double defaultValue, String description) { + super(Double.class, prefix, name, conf, defaultValue, description); + } + + @Override + public void setStringValue(String stringValue) { + set(Double.valueOf(stringValue)); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/Parameter.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/Parameter.java b/mr/src/main/java/org/apache/mahout/common/parameters/Parameter.java new file mode 100644 index 0000000..292fa27 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/parameters/Parameter.java @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.parameters; + +/** + * An accessor to a parameters in the job. 
+ * + * This is a composite entity that can it self contain more parameters. Say the parameters describes what + * DistanceMeasure class to use, once set this parameters would also produce the parameters available in that + * DistanceMeasure implementation. + */ +public interface Parameter<T> extends Parametered { + /** @return job configuration setting key prefix, e.g. 'org.apache.mahout.util.WeightedDistanceMeasure.' */ + String prefix(); + + /** @return configuration parameters name, e.g. 'weightsFile' */ + String name(); + + /** @return human readable description of parameters */ + String description(); + + /** @return value class type */ + Class<T> type(); + + /** + * @param stringValue + * value string representation + */ + void setStringValue(String stringValue); + + /** + * @return value string representation of current value + */ + String getStringValue(); + + /** + * @param value + * new parameters value + */ + void set(T value); + + /** @return current parameters value */ + T get(); + + /** @return value used if not set by consumer */ + String defaultValue(); +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/Parametered.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/Parametered.java b/mr/src/main/java/org/apache/mahout/common/parameters/Parametered.java new file mode 100644 index 0000000..96c9457 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/parameters/Parametered.java @@ -0,0 +1,206 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.parameters; + +import java.util.Collection; + +import org.apache.hadoop.conf.Configuration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Meta information and accessors for configuring a job. */ +public interface Parametered { + + Logger log = LoggerFactory.getLogger(Parametered.class); + + Collection<Parameter<?>> getParameters(); + + /** + * EXPERT: consumers should never have to call this method. It would be friendly visible to + * {@link ParameteredGeneralizations} if java supported it. Calling this method should create a new list of + * parameters and is called + * + * @param prefix + * ends with a dot if not empty. 
+ * @param jobConf + * configuration used for retrieving values + * @see ParameteredGeneralizations#configureParameters(String,Parametered,Configuration) + * invoking method + * @see ParameteredGeneralizations#configureParametersRecursively(Parametered,String,Configuration) + * invoking method + */ + void createParameters(String prefix, Configuration jobConf); + + void configure(Configuration config); + + /** "multiple inheritance" */ + final class ParameteredGeneralizations { + private ParameteredGeneralizations() { } + + public static void configureParameters(Parametered parametered, Configuration jobConf) { + configureParameters(parametered.getClass().getSimpleName() + '.', + parametered, jobConf); + + } + + /** + * Calls + * {@link Parametered#createParameters(String,org.apache.hadoop.conf.Configuration)} + * on parameter parmetered, and then recur down its composite tree to invoke + * {@link Parametered#createParameters(String,org.apache.hadoop.conf.Configuration)} + * and {@link Parametered#configure(org.apache.hadoop.conf.Configuration)} on + * each composite part. + * + * @param prefix + * ends with a dot if not empty. 
+ * @param parametered + * instance to be configured + * @param jobConf + * configuration used for retrieving values + */ + public static void configureParameters(String prefix, Parametered parametered, Configuration jobConf) { + parametered.createParameters(prefix, jobConf); + configureParametersRecursively(parametered, prefix, jobConf); + } + + private static void configureParametersRecursively(Parametered parametered, String prefix, Configuration jobConf) { + for (Parameter<?> parameter : parametered.getParameters()) { + if (log.isDebugEnabled()) { + log.debug("Configuring {}{}", prefix, parameter.name()); + } + String name = prefix + parameter.name() + '.'; + parameter.createParameters(name, jobConf); + parameter.configure(jobConf); + if (!parameter.getParameters().isEmpty()) { + configureParametersRecursively(parameter, name, jobConf); + } + } + } + + public static String help(Parametered parametered) { + return new Help(parametered).toString(); + } + + public static String conf(Parametered parametered) { + return new Conf(parametered).toString(); + } + + private static final class Help { + static final int NAME_DESC_DISTANCE = 8; + + private final StringBuilder sb; + private int longestName; + private int numChars = 100; // a few extra just to be sure + + private Help(Parametered parametered) { + recurseCount(parametered); + numChars += (longestName + NAME_DESC_DISTANCE) * parametered.getParameters().size(); + sb = new StringBuilder(numChars); + recurseWrite(parametered); + } + + @Override + public String toString() { + return sb.toString(); + } + + private void recurseCount(Parametered parametered) { + for (Parameter<?> parameter : parametered.getParameters()) { + int parameterNameLength = parameter.name().length(); + if (parameterNameLength > longestName) { + longestName = parameterNameLength; + } + recurseCount(parameter); + numChars += parameter.description().length(); + } + } + + private void recurseWrite(Parametered parametered) { + for (Parameter<?> 
parameter : parametered.getParameters()) { + sb.append(parameter.prefix()); + sb.append(parameter.name()); + int max = longestName - parameter.name().length() - parameter.prefix().length() + + NAME_DESC_DISTANCE; + for (int i = 0; i < max; i++) { + sb.append(' '); + } + sb.append(parameter.description()); + if (parameter.defaultValue() != null) { + sb.append(" (default value '"); + sb.append(parameter.defaultValue()); + sb.append("')"); + } + sb.append('\n'); + recurseWrite(parameter); + } + } + } + + private static final class Conf { + private final StringBuilder sb; + private int longestName; + private int numChars = 100; // a few extra just to be sure + + private Conf(Parametered parametered) { + recurseCount(parametered); + sb = new StringBuilder(numChars); + recurseWrite(parametered); + } + + @Override + public String toString() { + return sb.toString(); + } + + private void recurseCount(Parametered parametered) { + for (Parameter<?> parameter : parametered.getParameters()) { + int parameterNameLength = parameter.prefix().length() + parameter.name().length(); + if (parameterNameLength > longestName) { + longestName = parameterNameLength; + } + + numChars += parameterNameLength; + numChars += 5; // # $0\n$1 = $2\n\n + numChars += parameter.description().length(); + if (parameter.getStringValue() != null) { + numChars += parameter.getStringValue().length(); + } + + recurseCount(parameter); + } + } + + private void recurseWrite(Parametered parametered) { + for (Parameter<?> parameter : parametered.getParameters()) { + sb.append("# "); + sb.append(parameter.description()); + sb.append('\n'); + sb.append(parameter.prefix()); + sb.append(parameter.name()); + sb.append(" = "); + if (parameter.getStringValue() != null) { + sb.append(parameter.getStringValue()); + } + sb.append('\n'); + sb.append('\n'); + recurseWrite(parameter); + } + } + } + } +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/PathParameter.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/PathParameter.java b/mr/src/main/java/org/apache/mahout/common/parameters/PathParameter.java new file mode 100644 index 0000000..a617fe3 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/common/parameters/PathParameter.java @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.parameters; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; + +public class PathParameter extends AbstractParameter<Path> { + + public PathParameter(String prefix, String name, Configuration jobConf, Path defaultValue, String description) { + super(Path.class, prefix, name, jobConf, defaultValue, description); + } + + @Override + public void setStringValue(String stringValue) { + set(new Path(stringValue)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java b/mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java new file mode 100644 index 0000000..1fd5506 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java @@ -0,0 +1,244 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.driver; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.io.Closeables; +import org.apache.hadoop.util.ProgramDriver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * General-purpose driver class for Mahout programs. Utilizes org.apache.hadoop.util.ProgramDriver to run + * main methods of other classes, but first loads up default properties from a properties file. + * <p/> + * To run locally: + * + * <pre>$MAHOUT_HOME/bin/mahout run shortJobName [over-ride ops]</pre> + * <p/> + * Works like this: by default, the file "driver.classes.props" is loaded from the classpath, which + * defines a mapping between short names like "vectordump" and fully qualified class names. + * The format of driver.classes.props is like so: + * <p/> + * + * <pre>fully.qualified.class.name = shortJobName : descriptive string</pre> + * <p/> + * The default properties to be applied to the program run is pulled out of, by default, "<shortJobName>.props" + * (also off of the classpath). + * <p/> + * The format of the default properties files is as follows: + * <pre> + i|input = /path/to/my/input + o|output = /path/to/my/output + m|jarFile = /path/to/jarFile + # etc - each line is shortArg|longArg = value + </pre> + * + * The next argument to the Driver is supposed to be the short name of the class to be run (as defined in the + * driver.classes.props file). + * <p/> + * Then the class which will be run will have it's main called with + * + * <pre>main(new String[] { "--input", "/path/to/my/input", "--output", "/path/to/my/output" });</pre> + * + * After all the "default" properties are loaded from the file, any further command-line arguments are taken in, + * and over-ride the defaults. 
+ * <p/> + * So if your driver.classes.props looks like so: + * + * <pre>org.apache.mahout.utils.vectors.VectorDumper = vecDump : dump vectors from a sequence file</pre> + * + * and you have a file core/src/main/resources/vecDump.props which looks like + * <pre> + o|output = /tmp/vectorOut + s|seqFile = /my/vector/sequenceFile + </pre> + * + * And you execute the command-line: + * + * <pre>$MAHOUT_HOME/bin/mahout run vecDump -s /my/otherVector/sequenceFile</pre> + * + * Then org.apache.mahout.utils.vectors.VectorDumper.main() will be called with arguments: + * <pre>{"--output", "/tmp/vectorOut", "-s", "/my/otherVector/sequenceFile"}</pre> + */ +public final class MahoutDriver { + + private static final Logger log = LoggerFactory.getLogger(MahoutDriver.class); + + private MahoutDriver() { + } + + public static void main(String[] args) throws Throwable { + + Properties mainClasses = loadProperties("driver.classes.props"); + if (mainClasses == null) { + mainClasses = loadProperties("driver.classes.default.props"); + } + if (mainClasses == null) { + throw new IOException("Can't load any properties file?"); + } + + boolean foundShortName = false; + ProgramDriver programDriver = new ProgramDriver(); + for (Object key : mainClasses.keySet()) { + String keyString = (String) key; + if (args.length > 0 && shortName(mainClasses.getProperty(keyString)).equals(args[0])) { + foundShortName = true; + } + if (args.length > 0 && keyString.equalsIgnoreCase(args[0]) && isDeprecated(mainClasses, keyString)) { + log.error(desc(mainClasses.getProperty(keyString))); + return; + } + if (isDeprecated(mainClasses, keyString)) { + continue; + } + addClass(programDriver, keyString, mainClasses.getProperty(keyString)); + } + + if (args.length < 1 || args[0] == null || "-h".equals(args[0]) || "--help".equals(args[0])) { + programDriver.driver(args); + return; + } + + String progName = args[0]; + if (!foundShortName) { + addClass(programDriver, progName, progName); + } + shift(args); + + 
Properties mainProps = loadProperties(progName + ".props"); + if (mainProps == null) { + log.warn("No {}.props found on classpath, will use command-line arguments only", progName); + mainProps = new Properties(); + } + + Map<String,String[]> argMap = Maps.newHashMap(); + int i = 0; + while (i < args.length && args[i] != null) { + List<String> argValues = Lists.newArrayList(); + String arg = args[i]; + i++; + if (arg.startsWith("-D")) { // '-Dkey=value' or '-Dkey=value1,value2,etc' case + String[] argSplit = arg.split("="); + arg = argSplit[0]; + if (argSplit.length == 2) { + argValues.add(argSplit[1]); + } + } else { // '-key [values]' or '--key [values]' case. + while (i < args.length && args[i] != null) { + if (args[i].startsWith("-")) { + break; + } + argValues.add(args[i]); + i++; + } + } + argMap.put(arg, argValues.toArray(new String[argValues.size()])); + } + + // Add properties from the .props file that are not overridden on the command line + for (String key : mainProps.stringPropertyNames()) { + String[] argNamePair = key.split("\\|"); + String shortArg = '-' + argNamePair[0].trim(); + String longArg = argNamePair.length < 2 ? 
null : "--" + argNamePair[1].trim(); + if (!argMap.containsKey(shortArg) && (longArg == null || !argMap.containsKey(longArg))) { + argMap.put(longArg, new String[] {mainProps.getProperty(key)}); + } + } + + // Now add command-line args + List<String> argsList = Lists.newArrayList(); + argsList.add(progName); + for (Map.Entry<String,String[]> entry : argMap.entrySet()) { + String arg = entry.getKey(); + if (arg.startsWith("-D")) { // arg is -Dkey - if value for this !isEmpty(), then arg -> -Dkey + "=" + value + String[] argValues = entry.getValue(); + if (argValues.length > 0 && !argValues[0].trim().isEmpty()) { + arg += '=' + argValues[0].trim(); + } + argsList.add(1, arg); + } else { + argsList.add(arg); + for (String argValue : Arrays.asList(argMap.get(arg))) { + if (!argValue.isEmpty()) { + argsList.add(argValue); + } + } + } + } + + long start = System.currentTimeMillis(); + + programDriver.driver(argsList.toArray(new String[argsList.size()])); + + if (log.isInfoEnabled()) { + log.info("Program took {} ms (Minutes: {})", System.currentTimeMillis() - start, + (System.currentTimeMillis() - start) / 60000.0); + } + } + + private static boolean isDeprecated(Properties mainClasses, String keyString) { + return "deprecated".equalsIgnoreCase(shortName(mainClasses.getProperty(keyString))); + } + + private static Properties loadProperties(String resource) throws IOException { + InputStream propsStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(resource); + if (propsStream != null) { + try { + Properties properties = new Properties(); + properties.load(propsStream); + return properties; + } finally { + Closeables.close(propsStream, true); + } + } + return null; + } + + private static String[] shift(String[] args) { + System.arraycopy(args, 1, args, 0, args.length - 1); + args[args.length - 1] = null; + return args; + } + + private static String shortName(String valueString) { + return valueString.contains(":") ? 
valueString.substring(0, valueString.indexOf(':')).trim() : valueString; + } + + private static String desc(String valueString) { + return valueString.contains(":") ? valueString.substring(valueString.indexOf(':')).trim() : valueString; + } + + private static void addClass(ProgramDriver driver, String classString, String descString) { + try { + Class<?> clazz = Class.forName(classString); + driver.addClass(shortName(descString), clazz, desc(descString)); + } catch (Throwable t) { + log.warn("Unable to add class: {}", classString, t); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java b/mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java new file mode 100644 index 0000000..b744287 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java @@ -0,0 +1,228 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.ep; + +import com.google.common.collect.Lists; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.sgd.PolymorphicWritable; + +import java.io.Closeable; +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +/** + * Allows evolutionary optimization where the state function can't be easily + * packaged for the optimizer to execute. A good example of this is with + * on-line learning where optimizing the learning parameters is desirable. + * We would like to pass training examples to the learning algorithms, but + * we definitely want to do the training in multiple threads and then after + * several training steps, we want to do a selection and mutation step. + * + * In such a case, it is highly desirable to leave most of the control flow + * in the hands of our caller. As such, this class provides three functions, + * <ul> + * <li> Storage of the evolutionary state. The state variables have payloads + * which can be anything that implements Payload. + * <li> Threaded execution of a single operation on each of the members of the + * population being evolved. In the on-line learning example, this is used for + * training all of the classifiers in the population. + * <li> Propagating mutations of the most successful members of the population. + * This propagation involves copying the state and the payload and then updating + * the payload after mutation of the evolutionary state. + * </ul> + * + * The State class that we use for storing the state of each member of the + * population also provides parameter mapping. 
Check out Mapping and State + * for more info. + * + * @see Mapping + * @see Payload + * @see State + * + * @param <T> The payload class. + */ +public class EvolutionaryProcess<T extends Payload<U>, U> implements Writable, Closeable { + // used to execute operations on the population in thread parallel. + private ExecutorService pool; + + // threadCount is serialized so that we can reconstruct the thread pool + private int threadCount; + + // list of members of the population + private List<State<T, U>> population; + + // how big should the population be. If this is changed, it will take effect + // the next time the population is mutated. + + private int populationSize; + + public EvolutionaryProcess() { + population = Lists.newArrayList(); + } + + /** + * Creates an evolutionary optimization framework with specified threadiness, + * population size and initial state. + * @param threadCount How many threads to use in parallelDo + * @param populationSize How large a population to use + * @param seed An initial population member + */ + public EvolutionaryProcess(int threadCount, int populationSize, State<T, U> seed) { + this.populationSize = populationSize; + setThreadCount(threadCount); + initializePopulation(populationSize, seed); + } + + private void initializePopulation(int populationSize, State<T, U> seed) { + population = Lists.newArrayList(seed); + for (int i = 0; i < populationSize; i++) { + population.add(seed.mutate()); + } + } + + public void add(State<T, U> value) { + population.add(value); + } + + /** + * Nuke all but a few of the current population and then repopulate with + * variants of the survivors. + * @param survivors How many survivors we want to keep. 
+ */ + public void mutatePopulation(int survivors) { + // largest value first, oldest first in case of ties + Collections.sort(population); + + // we copy here to avoid concurrent modification + List<State<T, U>> parents = Lists.newArrayList(population.subList(0, survivors)); + population.subList(survivors, population.size()).clear(); + + // fill out the population with offspring from the survivors + int i = 0; + while (population.size() < populationSize) { + population.add(parents.get(i % survivors).mutate()); + i++; + } + } + + /** + * Execute an operation on all of the members of the population with many threads. The + * return value is taken as the current fitness of the corresponding member. + * @param fn What to do on each member. Gets payload and the mapped parameters as args. + * @return The member of the population with the best fitness. + * @throws InterruptedException Shouldn't happen. + * @throws ExecutionException If fn throws an exception, that exception will be collected + * and rethrown nested in an ExecutionException. 
+ */ + public State<T, U> parallelDo(final Function<Payload<U>> fn) throws InterruptedException, ExecutionException { + Collection<Callable<State<T, U>>> tasks = Lists.newArrayList(); + for (final State<T, U> state : population) { + tasks.add(new Callable<State<T, U>>() { + @Override + public State<T, U> call() { + double v = fn.apply(state.getPayload(), state.getMappedParams()); + state.setValue(v); + return state; + } + }); + } + + List<Future<State<T, U>>> r = pool.invokeAll(tasks); + + // zip through the results and find the best one + double max = Double.NEGATIVE_INFINITY; + State<T, U> best = null; + for (Future<State<T, U>> future : r) { + State<T, U> s = future.get(); + double value = s.getValue(); + if (!Double.isNaN(value) && value >= max) { + max = value; + best = s; + } + } + if (best == null) { + best = r.get(0).get(); + } + + return best; + } + + public void setThreadCount(int threadCount) { + this.threadCount = threadCount; + pool = Executors.newFixedThreadPool(threadCount); + } + + public int getThreadCount() { + return threadCount; + } + + public int getPopulationSize() { + return populationSize; + } + + public List<State<T, U>> getPopulation() { + return population; + } + + @Override + public void close() { + List<Runnable> remainingTasks = pool.shutdownNow(); + try { + pool.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new IllegalStateException("Had to forcefully shut down " + remainingTasks.size() + " tasks"); + } + if (!remainingTasks.isEmpty()) { + throw new IllegalStateException("Had to forcefully shut down " + remainingTasks.size() + " tasks"); + } + } + + public interface Function<T> { + double apply(T payload, double[] params); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(threadCount); + out.writeInt(population.size()); + for (State<T, U> state : population) { + PolymorphicWritable.write(out, state); + } + } + + @Override + public void readFields(DataInput 
input) throws IOException { + setThreadCount(input.readInt()); + int n = input.readInt(); + population = Lists.newArrayList(); + for (int i = 0; i < n; i++) { + State<T, U> state = (State<T, U>) PolymorphicWritable.read(input, State.class); + population.add(state); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/ep/Mapping.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/ep/Mapping.java b/mr/src/main/java/org/apache/mahout/ep/Mapping.java new file mode 100644 index 0000000..41a8942 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/ep/Mapping.java @@ -0,0 +1,206 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.ep; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.sgd.PolymorphicWritable; +import org.apache.mahout.math.function.DoubleFunction; + +/** + * Provides coordinate tranformations so that evolution can proceed on the entire space of + * reals but have the output limited and squished in convenient (and safe) ways. + */ +public abstract class Mapping extends DoubleFunction implements Writable { + + private Mapping() { + } + + public static final class SoftLimit extends Mapping { + private double min; + private double max; + private double scale; + + public SoftLimit() { + } + + private SoftLimit(double min, double max, double scale) { + this.min = min; + this.max = max; + this.scale = scale; + } + + @Override + public double apply(double v) { + return min + (max - min) * 1 / (1 + Math.exp(-v * scale)); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeDouble(min); + out.writeDouble(max); + out.writeDouble(scale); + } + + @Override + public void readFields(DataInput in) throws IOException { + min = in.readDouble(); + max = in.readDouble(); + scale = in.readDouble(); + } + } + + public static final class LogLimit extends Mapping { + private Mapping wrapped; + + public LogLimit() { + } + + private LogLimit(double low, double high) { + wrapped = softLimit(Math.log(low), Math.log(high)); + } + + @Override + public double apply(double v) { + return Math.exp(wrapped.apply(v)); + } + + @Override + public void write(DataOutput dataOutput) throws IOException { + PolymorphicWritable.write(dataOutput, wrapped); + } + + @Override + public void readFields(DataInput in) throws IOException { + wrapped = PolymorphicWritable.read(in, Mapping.class); + } + } + + public static final class Exponential extends Mapping { + private double scale; + + public 
Exponential() { + } + + private Exponential(double scale) { + this.scale = scale; + } + + @Override + public double apply(double v) { + return Math.exp(v * scale); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeDouble(scale); + } + + @Override + public void readFields(DataInput in) throws IOException { + scale = in.readDouble(); + } + } + + public static final class Identity extends Mapping { + @Override + public double apply(double v) { + return v; + } + + @Override + public void write(DataOutput dataOutput) { + // stateless + } + + @Override + public void readFields(DataInput dataInput) { + // stateless + } + } + + /** + * Maps input to the open interval (min, max) with 0 going to the mean of min and + * max. When scale is large, a larger proportion of values are mapped to points + * near the boundaries. When scale is small, a larger proportion of values are mapped to + * points well within the boundaries. + * @param min The largest lower bound on values to be returned. + * @param max The least upper bound on values to be returned. + * @param scale Defines how sharp the boundaries are. + * @return A mapping that satisfies the desired constraint. + */ + public static Mapping softLimit(double min, double max, double scale) { + return new SoftLimit(min, max, scale); + } + + /** + * Maps input to the open interval (min, max) with 0 going to the mean of min and + * max. When scale is large, a larger proportion of values are mapped to points + * near the boundaries. + * @see #softLimit(double, double, double) + * @param min The largest lower bound on values to be returned. + * @param max The least upper bound on values to be returned. + * @return A mapping that satisfies the desired constraint. + */ + public static Mapping softLimit(double min, double max) { + return softLimit(min, max, 1); + } + + /** + * Maps input to positive values in the open interval (min, max) with + * 0 going to the geometric mean. 
Near the geometric mean, values are + * distributed roughly geometrically. + * @param low The largest lower bound for output results. Must be >0. + * @param high The least upper bound for output results. Must be >0. + * @return A mapped value. + */ + public static Mapping logLimit(double low, double high) { + Preconditions.checkArgument(low > 0, "Lower bound for log limit must be > 0 but was %f", low); + Preconditions.checkArgument(high > 0, "Upper bound for log limit must be > 0 but was %f", high); + return new LogLimit(low, high); + } + + /** + * Maps results to positive values. + * @return A positive value. + */ + public static Mapping exponential() { + return exponential(1); + } + + /** + * Maps results to positive values. + * @param scale If large, then large values are more likely. + * @return A positive value. + */ + public static Mapping exponential(double scale) { + return new Exponential(scale); + } + + /** + * Maps results to themselves. + * @return The original value. + */ + public static Mapping identity() { + return new Identity(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/ep/Payload.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/ep/Payload.java b/mr/src/main/java/org/apache/mahout/ep/Payload.java new file mode 100644 index 0000000..920237d --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/ep/Payload.java @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.ep; + +import org.apache.hadoop.io.Writable; + +/** + * Payloads for evolutionary state must be copyable and updatable. The copy should be a deep copy + * unless some aspect of the state is sharable or immutable. + * <p/> + * During mutation, a copy is first made and then after the parameters in the State structure are + * suitably modified, update is called with the scaled versions of the parameters. + * + * @param <T> + * @see State + */ +public interface Payload<T> extends Writable { + Payload<T> copy(); + + void update(double[] params); +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/ep/State.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/ep/State.java b/mr/src/main/java/org/apache/mahout/ep/State.java new file mode 100644 index 0000000..7a0fb5e --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/ep/State.java @@ -0,0 +1,302 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.ep; + +import com.google.common.collect.Lists; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.sgd.PolymorphicWritable; +import org.apache.mahout.common.RandomUtils; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Locale; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Records evolutionary state and provides a mutation operation for recorded-step meta-mutation. + * + * You provide the payload, this class provides the mutation operations. During mutation, + * the payload is copied and after the state variables are changed, they are passed to the + * payload. + * + * Parameters are internally mutated in a state space that spans all of R^n, but parameters + * passed to the payload are transformed as specified by a call to setMap(). The default + * mapping is the identity map, but uniform-ish or exponential-ish coverage of a range are + * also supported. + * + * More information on the underlying algorithm can be found in the following paper + * + * http://arxiv.org/abs/0803.3838 + * + * @see Mapping + */ +public class State<T extends Payload<U>, U> implements Comparable<State<T, U>>, Writable { + + // object count is kept to break ties in comparison. 
+ private static final AtomicInteger OBJECT_COUNT = new AtomicInteger(); + + private int id = OBJECT_COUNT.getAndIncrement(); + private Random gen = RandomUtils.getRandom(); + // current state + private double[] params; + // mappers to transform state + private Mapping[] maps; + // omni-directional mutation + private double omni; + // directional mutation + private double[] step; + // current fitness value + private double value; + private T payload; + + public State() { + } + + /** + * Invent a new state with no momentum (yet). + */ + public State(double[] x0, double omni) { + params = Arrays.copyOf(x0, x0.length); + this.omni = omni; + step = new double[params.length]; + maps = new Mapping[params.length]; + } + + /** + * Deep copies a state, useful in mutation. + */ + public State<T, U> copy() { + State<T, U> r = new State<>(); + r.params = Arrays.copyOf(this.params, this.params.length); + r.omni = this.omni; + r.step = Arrays.copyOf(this.step, this.step.length); + r.maps = Arrays.copyOf(this.maps, this.maps.length); + if (this.payload != null) { + r.payload = (T) this.payload.copy(); + } + r.gen = this.gen; + return r; + } + + /** + * Clones this state with a random change in position. Copies the payload and + * lets it know about the change. + * + * @return A new state. + */ + public State<T, U> mutate() { + double sum = 0; + for (double v : step) { + sum += v * v; + } + sum = Math.sqrt(sum); + double lambda = 1 + gen.nextGaussian(); + + State<T, U> r = this.copy(); + double magnitude = 0.9 * omni + sum / 10; + r.omni = magnitude * -Math.log1p(-gen.nextDouble()); + for (int i = 0; i < step.length; i++) { + r.step[i] = lambda * step[i] + r.omni * gen.nextGaussian(); + r.params[i] += r.step[i]; + } + if (this.payload != null) { + r.payload.update(r.getMappedParams()); + } + return r; + } + + /** + * Defines the transformation for a parameter. + * @param i Which parameter's mapping to define. + * @param m The mapping to use. 
+ * @see org.apache.mahout.ep.Mapping + */ + public void setMap(int i, Mapping m) { + maps[i] = m; + } + + /** + * Returns a transformed parameter. + * @param i The parameter to return. + * @return The value of the parameter. + */ + public double get(int i) { + Mapping m = maps[i]; + return m == null ? params[i] : m.apply(params[i]); + } + + public int getId() { + return id; + } + + public double[] getParams() { + return params; + } + + public Mapping[] getMaps() { + return maps; + } + + /** + * Returns all the parameters in mapped form. + * @return An array of parameters. + */ + public double[] getMappedParams() { + double[] r = Arrays.copyOf(params, params.length); + for (int i = 0; i < params.length; i++) { + r[i] = get(i); + } + return r; + } + + public double getOmni() { + return omni; + } + + public double[] getStep() { + return step; + } + + public T getPayload() { + return payload; + } + + public double getValue() { + return value; + } + + public void setOmni(double omni) { + this.omni = omni; + } + + public void setId(int id) { + this.id = id; + } + + public void setStep(double[] step) { + this.step = step; + } + + public void setMaps(Mapping[] maps) { + this.maps = maps; + } + + public void setMaps(Iterable<Mapping> maps) { + Collection<Mapping> list = Lists.newArrayList(maps); + this.maps = list.toArray(new Mapping[list.size()]); + } + + public void setValue(double v) { + value = v; + } + + public void setPayload(T payload) { + this.payload = payload; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof State)) { + return false; + } + State<?,?> other = (State<?,?>) o; + return id == other.id && value == other.value; + } + + @Override + public int hashCode() { + return RandomUtils.hashDouble(value) ^ id; + } + + /** + * Natural order is to sort in descending order of score. Creation order is used as a + * tie-breaker. + * + * @param other The state to compare with. 
+ * @return -1, 0, 1 if the other state is better, identical or worse than this one. + */ + @Override + public int compareTo(State<T, U> other) { + int r = Double.compare(other.value, this.value); + if (r != 0) { + return r; + } + if (this.id < other.id) { + return -1; + } + if (this.id > other.id) { + return 1; + } + return 0; + } + + @Override + public String toString() { + double sum = 0; + for (double v : step) { + sum += v * v; + } + return String.format(Locale.ENGLISH, "<S/%s %.3f %.3f>", payload, omni + Math.sqrt(sum), value); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(id); + out.writeInt(params.length); + for (double v : params) { + out.writeDouble(v); + } + for (Mapping map : maps) { + PolymorphicWritable.write(out, map); + } + + out.writeDouble(omni); + for (double v : step) { + out.writeDouble(v); + } + + out.writeDouble(value); + PolymorphicWritable.write(out, payload); + } + + @Override + public void readFields(DataInput input) throws IOException { + id = input.readInt(); + int n = input.readInt(); + params = new double[n]; + for (int i = 0; i < n; i++) { + params[i] = input.readDouble(); + } + + maps = new Mapping[n]; + for (int i = 0; i < n; i++) { + maps[i] = PolymorphicWritable.read(input, Mapping.class); + } + omni = input.readDouble(); + step = new double[n]; + for (int i = 0; i < n; i++) { + step[i] = input.readDouble(); + } + value = input.readDouble(); + payload = (T) PolymorphicWritable.read(input, Payload.class); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/ep/package-info.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/ep/package-info.java b/mr/src/main/java/org/apache/mahout/ep/package-info.java new file mode 100644 index 0000000..4afe677 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/ep/package-info.java @@ -0,0 +1,26 @@ +/** + * <p>Provides basic 
evolutionary optimization using <a href="http://arxiv.org/abs/0803.3838">recorded-step</a> + * mutation.</p> + * + * <p>With this style of optimization, we can optimize a function {@code f: R^n -> R} by stochastic + * hill-climbing with some of the benefits of conjugate gradient style history encoded in the mutation function. + * This mutation function will adapt to allow weakly directed search rather than using the somewhat more + * conventional symmetric Gaussian.</p> + * + * <p>With recorded-step mutation, the meta-mutation parameters are all auto-encoded in the current state of each point. + * This avoids the classic problem of having more mutation rate parameters than are in the original state and then + * requiring even more parameters to describe the meta-mutation rate. Instead, we store the previous point and one + * omni-directional mutation component. Mutation is performed by first mutating along the line formed by the previous + * and current points and then adding a scaled symmetric Gaussian. The magnitude of the omni-directional mutation is + * then mutated using itself as a scale.</p> + * + * <p>Because it is convenient not to restrict the parameter space, this package also provides convenient parameter + * mapping methods. These mapping methods map the set of reals to a finite open interval (a,b) in such a way that + * {@code lim_{x->-\inf} f(x) = a} and {@code lim_{x->\inf} f(x) = b}. The soft-limit mapping ({@code Mapping.softLimit}) is defined so that + * {@code f(0) = (a+b)/2} and the log-limit mapping ({@code Mapping.logLimit}) requires that a and b are both positive and has + * {@code f(0) = sqrt(ab)}. The soft-limit mapping is useful for values that must stay roughly within a range but + * which are roughly uniform within the center of that range. The log-limit + * mapping is useful for values that must stay within a range but whose distribution is roughly geometric near + * the geometric mean of the end-points.
An identity mapping is also supplied.</p> + */ +package org.apache.mahout.ep; http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java b/mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java new file mode 100644 index 0000000..6618a1a --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.math; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; + +import java.io.IOException; + +public final class DistributedRowMatrixWriter { + + private DistributedRowMatrixWriter() { + } + + public static void write(Path outputDir, Configuration conf, Iterable<MatrixSlice> matrix) throws IOException { + FileSystem fs = outputDir.getFileSystem(conf); + SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outputDir, + IntWritable.class, VectorWritable.class); + IntWritable topic = new IntWritable(); + VectorWritable vector = new VectorWritable(); + for (MatrixSlice slice : matrix) { + topic.set(slice.index()); + vector.set(slice.vector()); + writer.append(topic, vector); + } + writer.close(); + + } + +}
