http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java new file mode 100644 index 0000000..0e7ee96 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java @@ -0,0 +1,417 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.commandline; + +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure; +import org.apache.mahout.clustering.kernel.TriangularKernelProfile; + + +public final class DefaultOptionCreator { + + public static final String CLUSTERING_OPTION = "clustering"; + + public static final String CLUSTERS_IN_OPTION = "clusters"; + + public static final String CONVERGENCE_DELTA_OPTION = "convergenceDelta"; + + public static final String DISTANCE_MEASURE_OPTION = "distanceMeasure"; + + public static final String EMIT_MOST_LIKELY_OPTION = "emitMostLikely"; + + public static final String INPUT_OPTION = "input"; + + public static final String MAX_ITERATIONS_OPTION = "maxIter"; + + public static final String MAX_REDUCERS_OPTION = "maxRed"; + + public static final String METHOD_OPTION = "method"; + + public static final String NUM_CLUSTERS_OPTION = "numClusters"; + + public static final String OUTPUT_OPTION = "output"; + + public static final String OVERWRITE_OPTION = "overwrite"; + + public static final String T1_OPTION = "t1"; + + public static final String T2_OPTION = "t2"; + + public static final String T3_OPTION = "t3"; + + public static final String T4_OPTION = "t4"; + + public static final String OUTLIER_THRESHOLD = "outlierThreshold"; + + public static final String CLUSTER_FILTER_OPTION = "clusterFilter"; + + public static final String THRESHOLD_OPTION = "threshold"; + + public static final String SEQUENTIAL_METHOD = "sequential"; + + public static final String MAPREDUCE_METHOD = "mapreduce"; + + public static final String KERNEL_PROFILE_OPTION = "kernelProfile"; + + public static final String ANALYZER_NAME_OPTION = "analyzerName"; + + public static final String RANDOM_SEED = "randomSeed"; + + private 
DefaultOptionCreator() {} + + /** + * Returns a default command line option for help. Used by all clustering jobs + * and many others + * */ + public static Option helpOption() { + return new DefaultOptionBuilder().withLongName("help") + .withDescription("Print out help").withShortName("h").create(); + } + + /** + * Returns a default command line option for input directory specification. + * Used by all clustering jobs plus others + */ + public static DefaultOptionBuilder inputOption() { + return new DefaultOptionBuilder() + .withLongName(INPUT_OPTION) + .withRequired(false) + .withShortName("i") + .withArgument( + new ArgumentBuilder().withName(INPUT_OPTION).withMinimum(1) + .withMaximum(1).create()) + .withDescription("Path to job input directory."); + } + + /** + * Returns a default command line option for clusters input directory + * specification. Used by FuzzyKmeans, Kmeans + */ + public static DefaultOptionBuilder clustersInOption() { + return new DefaultOptionBuilder() + .withLongName(CLUSTERS_IN_OPTION) + .withRequired(true) + .withArgument( + new ArgumentBuilder().withName(CLUSTERS_IN_OPTION).withMinimum(1) + .withMaximum(1).create()) + .withDescription( + "The path to the initial clusters directory. Must be a SequenceFile of some type of Cluster") + .withShortName("c"); + } + + /** + * Returns a default command line option for output directory specification. + * Used by all clustering jobs plus others + */ + public static DefaultOptionBuilder outputOption() { + return new DefaultOptionBuilder() + .withLongName(OUTPUT_OPTION) + .withRequired(false) + .withShortName("o") + .withArgument( + new ArgumentBuilder().withName(OUTPUT_OPTION).withMinimum(1) + .withMaximum(1).create()) + .withDescription("The directory pathname for output."); + } + + /** + * Returns a default command line option for output directory overwriting. 
+ * Used by all clustering jobs + */ + public static DefaultOptionBuilder overwriteOption() { + return new DefaultOptionBuilder() + .withLongName(OVERWRITE_OPTION) + .withRequired(false) + .withDescription( + "If present, overwrite the output directory before running job") + .withShortName("ow"); + } + + /** + * Returns a default command line option for specification of distance measure + * class to use. Used by Canopy, FuzzyKmeans, Kmeans, MeanShift + */ + public static DefaultOptionBuilder distanceMeasureOption() { + return new DefaultOptionBuilder() + .withLongName(DISTANCE_MEASURE_OPTION) + .withRequired(false) + .withShortName("dm") + .withArgument( + new ArgumentBuilder().withName(DISTANCE_MEASURE_OPTION) + .withDefault(SquaredEuclideanDistanceMeasure.class.getName()) + .withMinimum(1).withMaximum(1).create()) + .withDescription( + "The classname of the DistanceMeasure. Default is SquaredEuclidean"); + } + + /** + * Returns a default command line option for specification of sequential or + * parallel operation. Used by Canopy, FuzzyKmeans, Kmeans, MeanShift, + * Dirichlet + */ + public static DefaultOptionBuilder methodOption() { + return new DefaultOptionBuilder() + .withLongName(METHOD_OPTION) + .withRequired(false) + .withShortName("xm") + .withArgument( + new ArgumentBuilder().withName(METHOD_OPTION) + .withDefault(MAPREDUCE_METHOD).withMinimum(1).withMaximum(1) + .create()) + .withDescription( + "The execution method to use: sequential or mapreduce. Default is mapreduce"); + } + + /** + * Returns a default command line option for specification of T1. 
Used by + * Canopy, MeanShift + */ + public static DefaultOptionBuilder t1Option() { + return new DefaultOptionBuilder() + .withLongName(T1_OPTION) + .withRequired(true) + .withArgument( + new ArgumentBuilder().withName(T1_OPTION).withMinimum(1) + .withMaximum(1).create()).withDescription("T1 threshold value") + .withShortName(T1_OPTION); + } + + /** + * Returns a default command line option for specification of T2. Used by + * Canopy, MeanShift + */ + public static DefaultOptionBuilder t2Option() { + return new DefaultOptionBuilder() + .withLongName(T2_OPTION) + .withRequired(true) + .withArgument( + new ArgumentBuilder().withName(T2_OPTION).withMinimum(1) + .withMaximum(1).create()).withDescription("T2 threshold value") + .withShortName(T2_OPTION); + } + + /** + * Returns a default command line option for specification of T3 (Reducer T1). + * Used by Canopy + */ + public static DefaultOptionBuilder t3Option() { + return new DefaultOptionBuilder() + .withLongName(T3_OPTION) + .withRequired(false) + .withArgument( + new ArgumentBuilder().withName(T3_OPTION).withMinimum(1) + .withMaximum(1).create()) + .withDescription("T3 (Reducer T1) threshold value") + .withShortName(T3_OPTION); + } + + /** + * Returns a default command line option for specification of T4 (Reducer T2). 
+ * Used by Canopy + */ + public static DefaultOptionBuilder t4Option() { + return new DefaultOptionBuilder() + .withLongName(T4_OPTION) + .withRequired(false) + .withArgument( + new ArgumentBuilder().withName(T4_OPTION).withMinimum(1) + .withMaximum(1).create()) + .withDescription("T4 (Reducer T2) threshold value") + .withShortName(T4_OPTION); + } + + /** + * @return a DefaultOptionBuilder for the clusterFilter option + */ + public static DefaultOptionBuilder clusterFilterOption() { + return new DefaultOptionBuilder() + .withLongName(CLUSTER_FILTER_OPTION) + .withShortName("cf") + .withRequired(false) + .withArgument( + new ArgumentBuilder().withName(CLUSTER_FILTER_OPTION).withMinimum(1) + .withMaximum(1).create()) + .withDescription("Cluster filter suppresses small canopies from mapper") + .withShortName(CLUSTER_FILTER_OPTION); + } + + /** + * Returns a default command line option for specification of max number of + * iterations. Used by Dirichlet, FuzzyKmeans, Kmeans, LDA + */ + public static DefaultOptionBuilder maxIterationsOption() { + // default value used by LDA which overrides withRequired(false) + return new DefaultOptionBuilder() + .withLongName(MAX_ITERATIONS_OPTION) + .withRequired(true) + .withShortName("x") + .withArgument( + new ArgumentBuilder().withName(MAX_ITERATIONS_OPTION) + .withDefault("-1").withMinimum(1).withMaximum(1).create()) + .withDescription("The maximum number of iterations."); + } + + /** + * Returns a default command line option for specification of numbers of + * clusters to create. 
Used by Dirichlet, FuzzyKmeans, Kmeans + */ + public static DefaultOptionBuilder numClustersOption() { + return new DefaultOptionBuilder() + .withLongName(NUM_CLUSTERS_OPTION) + .withRequired(false) + .withArgument( + new ArgumentBuilder().withName("k").withMinimum(1).withMaximum(1) + .create()).withDescription("The number of clusters to create") + .withShortName("k"); + } + + public static DefaultOptionBuilder useSetRandomSeedOption() { + return new DefaultOptionBuilder() + .withLongName(RANDOM_SEED) + .withRequired(false) + .withArgument(new ArgumentBuilder().withName(RANDOM_SEED).create()) + .withDescription("Seed to initaize Random Number Generator with") + .withShortName("rs"); + } + + /** + * Returns a default command line option for convergence delta specification. + * Used by FuzzyKmeans, Kmeans, MeanShift + */ + public static DefaultOptionBuilder convergenceOption() { + return new DefaultOptionBuilder() + .withLongName(CONVERGENCE_DELTA_OPTION) + .withRequired(false) + .withShortName("cd") + .withArgument( + new ArgumentBuilder().withName(CONVERGENCE_DELTA_OPTION) + .withDefault("0.5").withMinimum(1).withMaximum(1).create()) + .withDescription("The convergence delta value. Default is 0.5"); + } + + /** + * Returns a default command line option for specifying the max number of + * reducers. Used by Dirichlet, FuzzyKmeans, Kmeans and LDA + * + * @deprecated + */ + @Deprecated + public static DefaultOptionBuilder numReducersOption() { + return new DefaultOptionBuilder() + .withLongName(MAX_REDUCERS_OPTION) + .withRequired(false) + .withShortName("r") + .withArgument( + new ArgumentBuilder().withName(MAX_REDUCERS_OPTION) + .withDefault("2").withMinimum(1).withMaximum(1).create()) + .withDescription("The number of reduce tasks. Defaults to 2"); + } + + /** + * Returns a default command line option for clustering specification. 
Used by + * all clustering except LDA + */ + public static DefaultOptionBuilder clusteringOption() { + return new DefaultOptionBuilder() + .withLongName(CLUSTERING_OPTION) + .withRequired(false) + .withDescription( + "If present, run clustering after the iterations have taken place") + .withShortName("cl"); + } + + /** + * Returns a default command line option for specifying a Lucene analyzer class + * @return {@link DefaultOptionBuilder} + */ + public static DefaultOptionBuilder analyzerOption() { + return new DefaultOptionBuilder() + .withLongName(ANALYZER_NAME_OPTION) + .withRequired(false) + .withDescription("If present, the name of a Lucene analyzer class to use") + .withArgument(new ArgumentBuilder().withName(ANALYZER_NAME_OPTION).withDefault(StandardAnalyzer.class.getName()) + .withMinimum(1).withMaximum(1).create()) + .withShortName("an"); + } + + + /** + * Returns a default command line option for specifying the emitMostLikely + * flag. Used by Dirichlet and FuzzyKmeans + */ + public static DefaultOptionBuilder emitMostLikelyOption() { + return new DefaultOptionBuilder() + .withLongName(EMIT_MOST_LIKELY_OPTION) + .withRequired(false) + .withShortName("e") + .withArgument( + new ArgumentBuilder().withName(EMIT_MOST_LIKELY_OPTION) + .withDefault("true").withMinimum(1).withMaximum(1).create()) + .withDescription( + "True if clustering should emit the most likely point only, " + + "false for threshold clustering. Default is true"); + } + + /** + * Returns a default command line option for specifying the clustering + * threshold value. Used by Dirichlet and FuzzyKmeans + */ + public static DefaultOptionBuilder thresholdOption() { + return new DefaultOptionBuilder() + .withLongName(THRESHOLD_OPTION) + .withRequired(false) + .withShortName("t") + .withArgument( + new ArgumentBuilder().withName(THRESHOLD_OPTION).withDefault("0") + .withMinimum(1).withMaximum(1).create()) + .withDescription( + "The pdf threshold used for cluster determination. 
Default is 0"); + } + + public static DefaultOptionBuilder kernelProfileOption() { + return new DefaultOptionBuilder() + .withLongName(KERNEL_PROFILE_OPTION) + .withRequired(false) + .withShortName("kp") + .withArgument( + new ArgumentBuilder() + .withName(KERNEL_PROFILE_OPTION) + .withDefault(TriangularKernelProfile.class.getName()) + .withMinimum(1).withMaximum(1).create()) + .withDescription( + "The classname of the IKernelProfile. Default is TriangularKernelProfile"); + } + + /** + * Returns a default command line option for specification of OUTLIER THRESHOLD value. Used for + * Cluster Classification. + */ + public static DefaultOptionBuilder outlierThresholdOption() { + return new DefaultOptionBuilder() + .withLongName(OUTLIER_THRESHOLD) + .withRequired(false) + .withArgument( + new ArgumentBuilder().withName(OUTLIER_THRESHOLD).withMinimum(1) + .withMaximum(1).create()).withDescription("Outlier threshold value") + .withShortName(OUTLIER_THRESHOLD); + } + +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java new file mode 100644 index 0000000..61aa9a5 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java @@ -0,0 +1,63 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import java.util.Collection; +import java.util.Collections; + +import org.apache.hadoop.conf.Configuration; +import org.apache.mahout.common.parameters.Parameter; +import org.apache.mahout.math.CardinalityException; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +/** + * This class implements a "Chebyshev distance" metric by finding the maximum difference + * between each coordinate. Also 'chessboard distance' due to the moves a king can make. 
+ */ +public class ChebyshevDistanceMeasure implements DistanceMeasure { + + @Override + public void configure(Configuration job) { + // nothing to do + } + + @Override + public Collection<Parameter<?>> getParameters() { + return Collections.emptyList(); + } + + @Override + public void createParameters(String prefix, Configuration jobConf) { + // nothing to do + } + + @Override + public double distance(Vector v1, Vector v2) { + if (v1.size() != v2.size()) { + throw new CardinalityException(v1.size(), v2.size()); + } + return v1.aggregate(v2, Functions.MAX_ABS, Functions.MINUS); + } + + @Override + public double distance(double centroidLengthSquare, Vector centroid, Vector v) { + return distance(centroid, v); // TODO + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java new file mode 100644 index 0000000..37265eb --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java @@ -0,0 +1,119 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import java.util.Collection; +import java.util.Collections; + +import org.apache.hadoop.conf.Configuration; +import org.apache.mahout.common.parameters.Parameter; +import org.apache.mahout.math.CardinalityException; +import org.apache.mahout.math.Vector; + +/** + * This class implements a cosine distance metric by dividing the dot product of two vectors by the product of their + * lengths. That gives the cosine of the angle between the two vectors. To convert this to a usable distance, + * 1-cos(angle) is what is actually returned. 
+ */ +public class CosineDistanceMeasure implements DistanceMeasure { + + @Override + public void configure(Configuration job) { + // nothing to do + } + + @Override + public Collection<Parameter<?>> getParameters() { + return Collections.emptyList(); + } + + @Override + public void createParameters(String prefix, Configuration jobConf) { + // nothing to do + } + + public static double distance(double[] p1, double[] p2) { + double dotProduct = 0.0; + double lengthSquaredp1 = 0.0; + double lengthSquaredp2 = 0.0; + for (int i = 0; i < p1.length; i++) { + lengthSquaredp1 += p1[i] * p1[i]; + lengthSquaredp2 += p2[i] * p2[i]; + dotProduct += p1[i] * p2[i]; + } + double denominator = Math.sqrt(lengthSquaredp1) * Math.sqrt(lengthSquaredp2); + + // correct for floating-point rounding errors + if (denominator < dotProduct) { + denominator = dotProduct; + } + + // correct for zero-vector corner case + if (denominator == 0 && dotProduct == 0) { + return 0; + } + + return 1.0 - dotProduct / denominator; + } + + @Override + public double distance(Vector v1, Vector v2) { + if (v1.size() != v2.size()) { + throw new CardinalityException(v1.size(), v2.size()); + } + double lengthSquaredv1 = v1.getLengthSquared(); + double lengthSquaredv2 = v2.getLengthSquared(); + + double dotProduct = v2.dot(v1); + double denominator = Math.sqrt(lengthSquaredv1) * Math.sqrt(lengthSquaredv2); + + // correct for floating-point rounding errors + if (denominator < dotProduct) { + denominator = dotProduct; + } + + // correct for zero-vector corner case + if (denominator == 0 && dotProduct == 0) { + return 0; + } + + return 1.0 - dotProduct / denominator; + } + + @Override + public double distance(double centroidLengthSquare, Vector centroid, Vector v) { + + double lengthSquaredv = v.getLengthSquared(); + + double dotProduct = v.dot(centroid); + double denominator = Math.sqrt(centroidLengthSquare) * Math.sqrt(lengthSquaredv); + + // correct for floating-point rounding errors + if (denominator < 
dotProduct) { + denominator = dotProduct; + } + + // correct for zero-vector corner case + if (denominator == 0 && dotProduct == 0) { + return 0; + } + + return 1.0 - dotProduct / denominator; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java new file mode 100644 index 0000000..696e79c --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.common.parameters.Parametered; +import org.apache.mahout.math.Vector; + +/** This interface is used for objects which can determine a distance metric between two points */ +public interface DistanceMeasure extends Parametered { + + /** + * Returns the distance metric applied to the arguments + * + * @param v1 + * a Vector defining a multidimensional point in some feature space + * @param v2 + * a Vector defining a multidimensional point in some feature space + * @return a scalar doubles of the distance + */ + double distance(Vector v1, Vector v2); + + /** + * Optimized version of distance metric for sparse vectors. This distance computation requires operations + * proportional to the number of non-zero elements in the vector instead of the cardinality of the vector. + * + * @param centroidLengthSquare + * Square of the length of centroid + * @param centroid + * Centroid vector + */ + double distance(double centroidLengthSquare, Vector centroid, Vector v); + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java new file mode 100644 index 0000000..665678d --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.math.Vector; + +/** + * This class implements a Euclidean distance metric by summing the square root of the squared differences + * between each coordinate. + * <p/> + * If you don't care about the true distance and only need the values for comparison, then the base class, + * {@link SquaredEuclideanDistanceMeasure}, will be faster since it doesn't do the actual square root of the + * squared differences. 
+ */ +public class EuclideanDistanceMeasure extends SquaredEuclideanDistanceMeasure { + + @Override + public double distance(Vector v1, Vector v2) { + return Math.sqrt(super.distance(v1, v2)); + } + + @Override + public double distance(double centroidLengthSquare, Vector centroid, Vector v) { + return Math.sqrt(super.distance(centroidLengthSquare, centroid, v)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java new file mode 100644 index 0000000..17ee714 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java @@ -0,0 +1,197 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +import java.io.DataInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.parameters.ClassParameter; +import org.apache.mahout.common.parameters.Parameter; +import org.apache.mahout.common.parameters.PathParameter; +import org.apache.mahout.math.Algebra; +import org.apache.mahout.math.CardinalityException; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixWritable; +import org.apache.mahout.math.SingularValueDecomposition; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +//See http://en.wikipedia.org/wiki/Mahalanobis_distance for details +public class MahalanobisDistanceMeasure implements DistanceMeasure { + + private Matrix inverseCovarianceMatrix; + private Vector meanVector; + + private ClassParameter vectorClass; + private ClassParameter matrixClass; + private List<Parameter<?>> parameters; + private Parameter<Path> inverseCovarianceFile; + private Parameter<Path> meanVectorFile; + + /*public MahalanobisDistanceMeasure(Vector meanVector,Matrix inputMatrix, boolean inversionNeeded) + { + this.meanVector=meanVector; + if (inversionNeeded) + setCovarianceMatrix(inputMatrix); + else + setInverseCovarianceMatrix(inputMatrix); + }*/ + + @Override + public void configure(Configuration jobConf) { + if (parameters == null) { + ParameteredGeneralizations.configureParameters(this, jobConf); + } + try { + if (inverseCovarianceFile.get() != null) { + FileSystem fs = FileSystem.get(inverseCovarianceFile.get().toUri(), 
jobConf); + MatrixWritable inverseCovarianceMatrix = + ClassUtils.instantiateAs((Class<? extends MatrixWritable>) matrixClass.get(), MatrixWritable.class); + if (!fs.exists(inverseCovarianceFile.get())) { + throw new FileNotFoundException(inverseCovarianceFile.get().toString()); + } + try (DataInputStream in = fs.open(inverseCovarianceFile.get())){ + inverseCovarianceMatrix.readFields(in); + } + this.inverseCovarianceMatrix = inverseCovarianceMatrix.get(); + Preconditions.checkArgument(this.inverseCovarianceMatrix != null, "inverseCovarianceMatrix not initialized"); + } + + if (meanVectorFile.get() != null) { + FileSystem fs = FileSystem.get(meanVectorFile.get().toUri(), jobConf); + VectorWritable meanVector = + ClassUtils.instantiateAs((Class<? extends VectorWritable>) vectorClass.get(), VectorWritable.class); + if (!fs.exists(meanVectorFile.get())) { + throw new FileNotFoundException(meanVectorFile.get().toString()); + } + try (DataInputStream in = fs.open(meanVectorFile.get())){ + meanVector.readFields(in); + } + this.meanVector = meanVector.get(); + Preconditions.checkArgument(this.meanVector != null, "meanVector not initialized"); + } + + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + @Override + public Collection<Parameter<?>> getParameters() { + return parameters; + } + + @Override + public void createParameters(String prefix, Configuration jobConf) { + parameters = new ArrayList<>(); + inverseCovarianceFile = new PathParameter(prefix, "inverseCovarianceFile", jobConf, null, + "Path on DFS to a file containing the inverse covariance matrix."); + parameters.add(inverseCovarianceFile); + + matrixClass = new ClassParameter(prefix, "maxtrixClass", jobConf, DenseMatrix.class, + "Class<Matix> file specified in parameter inverseCovarianceFile has been serialized with."); + parameters.add(matrixClass); + + meanVectorFile = new PathParameter(prefix, "meanVectorFile", jobConf, null, + "Path on DFS to a file containing the mean Vector."); + 
parameters.add(meanVectorFile); + + vectorClass = new ClassParameter(prefix, "vectorClass", jobConf, DenseVector.class, + "Class file specified in parameter meanVectorFile has been serialized with."); + parameters.add(vectorClass); + } + + /** + * @param v The vector to compute the distance to + * @return Mahalanobis distance of a multivariate vector + */ + public double distance(Vector v) { + return Math.sqrt(v.minus(meanVector).dot(Algebra.mult(inverseCovarianceMatrix, v.minus(meanVector)))); + } + + @Override + public double distance(Vector v1, Vector v2) { + if (v1.size() != v2.size()) { + throw new CardinalityException(v1.size(), v2.size()); + } + return Math.sqrt(v1.minus(v2).dot(Algebra.mult(inverseCovarianceMatrix, v1.minus(v2)))); + } + + @Override + public double distance(double centroidLengthSquare, Vector centroid, Vector v) { + return distance(centroid, v); // TODO + } + + public void setInverseCovarianceMatrix(Matrix inverseCovarianceMatrix) { + Preconditions.checkArgument(inverseCovarianceMatrix != null, "inverseCovarianceMatrix not initialized"); + this.inverseCovarianceMatrix = inverseCovarianceMatrix; + } + + + /** + * Computes the inverse covariance from the input covariance matrix given in input. + * + * @param m A covariance matrix. + * @throws IllegalArgumentException if <tt>eigen values equal to 0 found</tt>. 
+ */ + public void setCovarianceMatrix(Matrix m) { + if (m.numRows() != m.numCols()) { + throw new CardinalityException(m.numRows(), m.numCols()); + } + // See http://www.mlahanas.de/Math/svd.htm for details, + // which specifically details the case of covariance matrix inversion + // Complexity: O(min(nm2,mn2)) + SingularValueDecomposition svd = new SingularValueDecomposition(m); + Matrix sInv = svd.getS(); + // Inverse Diagonal Elems + for (int i = 0; i < sInv.numRows(); i++) { + double diagElem = sInv.get(i, i); + if (diagElem > 0.0) { + sInv.set(i, i, 1 / diagElem); + } else { + throw new IllegalStateException("Eigen Value equals to 0 found."); + } + } + inverseCovarianceMatrix = svd.getU().times(sInv.times(svd.getU().transpose())); + Preconditions.checkArgument(inverseCovarianceMatrix != null, "inverseCovarianceMatrix not initialized"); + } + + public Matrix getInverseCovarianceMatrix() { + return inverseCovarianceMatrix; + } + + public void setMeanVector(Vector meanVector) { + Preconditions.checkArgument(meanVector != null, "meanVector not initialized"); + this.meanVector = meanVector; + } + + public Vector getMeanVector() { + return meanVector; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java new file mode 100644 index 0000000..5c32fcf --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java @@ -0,0 +1,70 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import java.util.Collection; +import java.util.Collections; + +import org.apache.hadoop.conf.Configuration; +import org.apache.mahout.common.parameters.Parameter; +import org.apache.mahout.math.CardinalityException; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +/** + * This class implements a "manhattan distance" metric by summing the absolute values of the difference + * between each coordinate + */ +public class ManhattanDistanceMeasure implements DistanceMeasure { + + public static double distance(double[] p1, double[] p2) { + double result = 0.0; + for (int i = 0; i < p1.length; i++) { + result += Math.abs(p2[i] - p1[i]); + } + return result; + } + + @Override + public void configure(Configuration job) { + // nothing to do + } + + @Override + public Collection<Parameter<?>> getParameters() { + return Collections.emptyList(); + } + + @Override + public void createParameters(String prefix, Configuration jobConf) { + // nothing to do + } + + @Override + public double distance(Vector v1, Vector v2) { + if (v1.size() != v2.size()) { + throw new CardinalityException(v1.size(), v2.size()); + } + return v1.aggregate(v2, Functions.PLUS, Functions.MINUS_ABS); + } + + @Override + public 
double distance(double centroidLengthSquare, Vector centroid, Vector v) { + return distance(centroid, v); // TODO + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java new file mode 100644 index 0000000..c3a48cb --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.mahout.common.parameters.DoubleParameter; +import org.apache.mahout.common.parameters.Parameter; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +/** + * Implement Minkowski distance, a real-valued generalization of the + * integral L(n) distances: Manhattan = L1, Euclidean = L2. + * For high numbers of dimensions, very high exponents give more useful distances. + * + * Note: Math.pow is clever about integer-valued doubles. + **/ +public class MinkowskiDistanceMeasure implements DistanceMeasure { + + private static final double EXPONENT = 3.0; + + private List<Parameter<?>> parameters; + private double exponent = EXPONENT; + + public MinkowskiDistanceMeasure() { + } + + public MinkowskiDistanceMeasure(double exponent) { + this.exponent = exponent; + } + + @Override + public void createParameters(String prefix, Configuration conf) { + parameters = new ArrayList<>(); + Parameter<?> param = + new DoubleParameter(prefix, "exponent", conf, EXPONENT, "Exponent for Fractional Lagrange distance"); + parameters.add(param); + } + + @Override + public Collection<Parameter<?>> getParameters() { + return parameters; + } + + @Override + public void configure(Configuration jobConf) { + if (parameters == null) { + ParameteredGeneralizations.configureParameters(this, jobConf); + } + } + + public double getExponent() { + return exponent; + } + + public void setExponent(double exponent) { + this.exponent = exponent; + } + + /** + * Math.pow is clever about integer-valued doubles + */ + @Override + public double distance(Vector v1, Vector v2) { + return Math.pow(v1.aggregate(v2, Functions.PLUS, Functions.minusAbsPow(exponent)), 1.0 / exponent); + } + + // TODO: how? 
+ @Override + public double distance(double centroidLengthSquare, Vector centroid, Vector v) { + return distance(centroid, v); // TODO - can this use centroidLengthSquare somehow? + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java new file mode 100644 index 0000000..66da121 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import java.util.Collection; +import java.util.Collections; + +import org.apache.hadoop.conf.Configuration; +import org.apache.mahout.common.parameters.Parameter; +import org.apache.mahout.math.Vector; + +/** + * Like {@link EuclideanDistanceMeasure} but it does not take the square root. 
+ * <p/> + * Thus, it is not actually the Euclidean Distance, but it is saves on computation when you only need the + * distance for comparison and don't care about the actual value as a distance. + */ +public class SquaredEuclideanDistanceMeasure implements DistanceMeasure { + + @Override + public void configure(Configuration job) { + // nothing to do + } + + @Override + public Collection<Parameter<?>> getParameters() { + return Collections.emptyList(); + } + + @Override + public void createParameters(String prefix, Configuration jobConf) { + // nothing to do + } + + @Override + public double distance(Vector v1, Vector v2) { + return v2.getDistanceSquared(v1); + } + + @Override + public double distance(double centroidLengthSquare, Vector centroid, Vector v) { + return centroidLengthSquare - 2 * v.dot(centroid) + v.getLengthSquared(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java new file mode 100644 index 0000000..cfeb119 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java @@ -0,0 +1,69 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +/** + * Tanimoto coefficient implementation. + * + * http://en.wikipedia.org/wiki/Jaccard_index + */ +public class TanimotoDistanceMeasure extends WeightedDistanceMeasure { + + /** + * Calculates the distance between two vectors. + * + * The coefficient (a measure of similarity) is: T(a, b) = a.b / (|a|^2 + |b|^2 - a.b) + * + * The distance d(a,b) = 1 - T(a,b) + * + * @return 0 for perfect match, > 0 for greater distance + */ + @Override + public double distance(Vector a, Vector b) { + double ab; + double denominator; + if (getWeights() != null) { + ab = a.times(b).aggregate(getWeights(), Functions.PLUS, Functions.MULT); + denominator = a.aggregate(getWeights(), Functions.PLUS, Functions.MULT_SQUARE_LEFT) + + b.aggregate(getWeights(), Functions.PLUS, Functions.MULT_SQUARE_LEFT) + - ab; + } else { + ab = b.dot(a); // b is SequentialAccess + denominator = a.getLengthSquared() + b.getLengthSquared() - ab; + } + + if (denominator < ab) { // correct for fp round-off: distance >= 0 + denominator = ab; + } + if (denominator > 0) { + // denominator == 0 only when dot(a,a) == dot(b,b) == dot(a,b) == 0 + return 1.0 - ab / denominator; + } else { + return 0.0; + } + } + + @Override + public double distance(double centroidLengthSquare, Vector centroid, Vector v) { + return distance(centroid, v); // TODO + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java new file mode 100644 index 0000000..1acbe86 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +import java.io.DataInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.parameters.ClassParameter; +import org.apache.mahout.common.parameters.Parameter; +import org.apache.mahout.common.parameters.PathParameter; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +/** Abstract implementation of DistanceMeasure with support for weights. */ +public abstract class WeightedDistanceMeasure implements DistanceMeasure { + + private List<Parameter<?>> parameters; + private Parameter<Path> weightsFile; + private ClassParameter vectorClass; + private Vector weights; + + @Override + public void createParameters(String prefix, Configuration jobConf) { + parameters = new ArrayList<>(); + weightsFile = new PathParameter(prefix, "weightsFile", jobConf, null, + "Path on DFS to a file containing the weights."); + parameters.add(weightsFile); + vectorClass = new ClassParameter(prefix, "vectorClass", jobConf, DenseVector.class, + "Class<Vector> file specified in parameter weightsFile has been serialized with."); + parameters.add(vectorClass); + } + + @Override + public Collection<Parameter<?>> getParameters() { + return parameters; + } + + @Override + public void configure(Configuration jobConf) { + if (parameters == null) { + ParameteredGeneralizations.configureParameters(this, jobConf); + } + try { + if (weightsFile.get() != null) { + FileSystem fs = FileSystem.get(weightsFile.get().toUri(), jobConf); + VectorWritable weights = + ClassUtils.instantiateAs((Class<? 
extends VectorWritable>) vectorClass.get(), VectorWritable.class); + if (!fs.exists(weightsFile.get())) { + throw new FileNotFoundException(weightsFile.get().toString()); + } + try (DataInputStream in = fs.open(weightsFile.get())){ + weights.readFields(in); + } + this.weights = weights.get(); + } + } catch (IOException e) { + throw new IllegalStateException(e); + } + } + + public Vector getWeights() { + return weights; + } + + public void setWeights(Vector weights) { + this.weights = weights; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java new file mode 100644 index 0000000..4c78d9f --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; + +/** + * This class implements a Euclidean distance metric by summing the square root of the squared differences + * between each coordinate, optionally adding weights. + */ +public class WeightedEuclideanDistanceMeasure extends WeightedDistanceMeasure { + + @Override + public double distance(Vector p1, Vector p2) { + double result = 0; + Vector res = p2.minus(p1); + Vector theWeights = getWeights(); + if (theWeights == null) { + for (Element elt : res.nonZeroes()) { + result += elt.get() * elt.get(); + } + } else { + for (Element elt : res.nonZeroes()) { + result += elt.get() * elt.get() * theWeights.get(elt.index()); + } + } + return Math.sqrt(result); + } + + @Override + public double distance(double centroidLengthSquare, Vector centroid, Vector v) { + return distance(centroid, v); // TODO + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java new file mode 100644 index 0000000..2c280e2 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java @@ -0,0 +1,53 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.distance; + +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; + +/** + * This class implements a "Manhattan distance" metric by summing the absolute values of the difference + * between each coordinate, optionally with weights. + */ +public class WeightedManhattanDistanceMeasure extends WeightedDistanceMeasure { + + @Override + public double distance(Vector p1, Vector p2) { + double result = 0; + + Vector res = p2.minus(p1); + if (getWeights() == null) { + for (Element elt : res.nonZeroes()) { + result += Math.abs(elt.get()); + } + + } else { + for (Element elt : res.nonZeroes()) { + result += Math.abs(elt.get() * getWeights().get(elt.index())); + } + } + + return result; + } + + @Override + public double distance(double centroidLengthSquare, Vector centroid, Vector v) { + return distance(centroid, v); // TODO + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java new file mode 100644 index 0000000..73cc821 --- /dev/null 
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.iterator; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.Iterator; + +import com.google.common.base.Function; +import com.google.common.collect.ForwardingIterator; +import com.google.common.collect.Iterators; + +/** + * An iterator that copies the values in an underlying iterator by finding an appropriate copy constructor. + */ +public final class CopyConstructorIterator<T> extends ForwardingIterator<T> { + + private final Iterator<T> delegate; + private Constructor<T> constructor; + + public CopyConstructorIterator(Iterator<? 
extends T> copyFrom) { + this.delegate = Iterators.transform( + copyFrom, + new Function<T,T>() { + @Override + public T apply(T from) { + if (constructor == null) { + Class<T> elementClass = (Class<T>) from.getClass(); + try { + constructor = elementClass.getConstructor(elementClass); + } catch (NoSuchMethodException e) { + throw new IllegalStateException(e); + } + } + try { + return constructor.newInstance(from); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException(e); + } + } + }); + } + + @Override + protected Iterator<T> delegate() { + return delegate; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java new file mode 100644 index 0000000..658c1f1 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.iterator; + +import com.google.common.collect.AbstractIterator; + +/** + * Iterates over the integers from 0 through {@code to-1}. + */ +public final class CountingIterator extends AbstractIterator<Integer> { + + private int count; + private final int to; + + public CountingIterator(int to) { + this.to = to; + } + + @Override + protected Integer computeNext() { + if (count < to) { + return count++; + } else { + return endOfData(); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java new file mode 100644 index 0000000..cfc18d6 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common.iterator; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.util.Iterator; + +import com.google.common.base.Charsets; + +/** + * Iterable representing the lines of a text file. It can produce an {@link Iterator} over those lines. This + * assumes the text file's lines are delimited in a manner consistent with how {@link java.io.BufferedReader} + * defines lines. + * + */ +public final class FileLineIterable implements Iterable<String> { + + private final InputStream is; + private final Charset encoding; + private final boolean skipFirstLine; + private final String origFilename; + + /** Creates a {@link FileLineIterable} over a given file, assuming a UTF-8 encoding. */ + public FileLineIterable(File file) throws IOException { + this(file, Charsets.UTF_8, false); + } + + /** Creates a {@link FileLineIterable} over a given file, assuming a UTF-8 encoding. */ + public FileLineIterable(File file, boolean skipFirstLine) throws IOException { + this(file, Charsets.UTF_8, skipFirstLine); + } + + /** Creates a {@link FileLineIterable} over a given file, using the given encoding. 
*/ + public FileLineIterable(File file, Charset encoding, boolean skipFirstLine) throws IOException { + this(FileLineIterator.getFileInputStream(file), encoding, skipFirstLine); + } + + public FileLineIterable(InputStream is) { + this(is, Charsets.UTF_8, false); + } + + public FileLineIterable(InputStream is, boolean skipFirstLine) { + this(is, Charsets.UTF_8, skipFirstLine); + } + + public FileLineIterable(InputStream is, Charset encoding, boolean skipFirstLine) { + this.is = is; + this.encoding = encoding; + this.skipFirstLine = skipFirstLine; + this.origFilename = ""; + } + + public FileLineIterable(InputStream is, Charset encoding, boolean skipFirstLine, String filename) { + this.is = is; + this.encoding = encoding; + this.skipFirstLine = skipFirstLine; + this.origFilename = filename; + } + + + @Override + public Iterator<String> iterator() { + try { + return new FileLineIterator(is, encoding, skipFirstLine, this.origFilename); + } catch (IOException ioe) { + throw new IllegalStateException(ioe); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java new file mode 100644 index 0000000..b7cc51e --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java @@ -0,0 +1,167 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.common.iterator;

import java.io.BufferedReader;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.util.zip.GZIPInputStream;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;

import com.google.common.base.Charsets;
import com.google.common.collect.AbstractIterator;
import com.google.common.io.Closeables;
import com.google.common.io.Files;
import org.apache.mahout.cf.taste.impl.common.SkippingIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Iterates over the lines of a text file. This assumes the text file's lines are delimited in a manner
 * consistent with how {@link BufferedReader} defines lines.
 * <p>
 * Input whose name ends in {@code .gz} is read through a {@link GZIPInputStream}. Input whose name ends in
 * {@code .zip} is read through a {@link ZipInputStream} positioned on the first non-directory entry of the
 * archive; any further entries are ignored.
 */
public final class FileLineIterator extends AbstractIterator<String> implements SkippingIterator<String>, Closeable {

  private static final Logger log = LoggerFactory.getLogger(FileLineIterator.class);

  private final BufferedReader reader;

  /**
   * Creates a {@link FileLineIterator} over a given file, assuming a UTF-8 encoding.
   *
   * @throws java.io.FileNotFoundException if the file does not exist
   * @throws IOException if the file cannot be read
   */
  public FileLineIterator(File file) throws IOException {
    this(file, Charsets.UTF_8, false);
  }

  /**
   * Creates a {@link FileLineIterator} over a given file, assuming a UTF-8 encoding.
   *
   * @param skipFirstLine if true, the first line is read and discarded before iteration starts
   * @throws java.io.FileNotFoundException if the file does not exist
   * @throws IOException if the file cannot be read
   */
  public FileLineIterator(File file, boolean skipFirstLine) throws IOException {
    this(file, Charsets.UTF_8, skipFirstLine);
  }

  /**
   * Creates a {@link FileLineIterator} over a given file, using the given encoding.
   *
   * @param skipFirstLine if true, the first line is read and discarded before iteration starts
   * @throws java.io.FileNotFoundException if the file does not exist
   * @throws IOException if the file cannot be read
   */
  public FileLineIterator(File file, Charset encoding, boolean skipFirstLine) throws IOException {
    this(getFileInputStream(file), encoding, skipFirstLine);
  }

  /**
   * Creates a {@link FileLineIterator} over a stream that is read as-is, assuming a UTF-8 encoding.
   *
   * @throws IOException declared for consistency with the other constructors
   */
  public FileLineIterator(InputStream is) throws IOException {
    this(is, Charsets.UTF_8, false);
  }

  /**
   * Creates a {@link FileLineIterator} over a stream that is read as-is, assuming a UTF-8 encoding.
   *
   * @param skipFirstLine if true, the first line is read and discarded before iteration starts
   * @throws IOException if the first line cannot be read
   */
  public FileLineIterator(InputStream is, boolean skipFirstLine) throws IOException {
    this(is, Charsets.UTF_8, skipFirstLine);
  }

  /**
   * Creates a {@link FileLineIterator} over a stream that is read as-is (no decompression is attempted),
   * using the given encoding.
   *
   * @param skipFirstLine if true, the first line is read and discarded before iteration starts
   * @throws IOException if the first line cannot be read
   */
  public FileLineIterator(InputStream is, Charset encoding, boolean skipFirstLine) throws IOException {
    reader = new BufferedReader(new InputStreamReader(is, encoding));
    if (skipFirstLine) {
      reader.readLine();
    }
  }

  /**
   * Creates a {@link FileLineIterator} over a stream, choosing a decompressor from the extension of
   * {@code filename}.
   *
   * @param filename name the stream was opened from; only its extension is examined. A {@code null} or
   *  extension-less name means the stream is read as-is.
   * @param skipFirstLine if true, the first line is read and discarded before iteration starts
   * @throws IOException if the compressed header or the first line cannot be read
   */
  public FileLineIterator(InputStream is, Charset encoding, boolean skipFirstLine, String filename)
    throws IOException {
    this(decompressIfNeeded(is, filename), encoding, skipFirstLine);
  }

  /**
   * Opens {@code file}, wrapping it in a decompressing stream when its extension asks for one.
   *
   * @throws java.io.FileNotFoundException if the file does not exist
   * @throws IOException if the compressed header cannot be read
   */
  static InputStream getFileInputStream(File file) throws IOException {
    InputStream raw = new FileInputStream(file);
    try {
      return decompressIfNeeded(raw, file.getName());
    } catch (IOException | RuntimeException e) {
      // The wrapper was never handed to the caller, so nobody else can release the file handle.
      try {
        raw.close();
      } catch (IOException closeFailure) {
        e.addSuppressed(closeFailure);
      }
      throw e;
    }
  }

  /**
   * Wraps {@code is} according to the extension of {@code name}: {@code gz} and {@code zip} (compared
   * case-insensitively) get a decompressing stream, anything else is returned unchanged.
   */
  private static InputStream decompressIfNeeded(InputStream is, String name) throws IOException {
    if (name == null) {
      return is;
    }
    String extension = Files.getFileExtension(name);
    if ("gz".equalsIgnoreCase(extension)) {
      return new GZIPInputStream(is);
    }
    if ("zip".equalsIgnoreCase(extension)) {
      ZipInputStream zip = new ZipInputStream(is);
      // A ZipInputStream yields no data until it has been positioned on an entry; without this the
      // reader would see an empty stream. An archive without file entries still iterates as empty.
      ZipEntry entry = zip.getNextEntry();
      while (entry != null && entry.isDirectory()) {
        entry = zip.getNextEntry();
      }
      return zip;
    }
    return is;
  }

  /**
   * Reads the next line, ending the iteration when the reader is exhausted.
   *
   * @throws IllegalStateException if the underlying reader fails; the reader is closed first
   */
  @Override
  protected String computeNext() {
    String line;
    try {
      line = reader.readLine();
    } catch (IOException ioe) {
      closeAfterFailure();
      throw new IllegalStateException(ioe);
    }
    return line == null ? endOfData() : line;
  }

  /**
   * Reads and discards up to {@code n} lines directly from the reader, stopping early at end of input.
   *
   * @throws IllegalStateException if the underlying reader fails; the reader is closed first. This mirrors
   *  {@link #computeNext()} instead of silently truncating the input.
   */
  @Override
  public void skip(int n) {
    try {
      for (int i = 0; i < n; i++) {
        if (reader.readLine() == null) {
          break;
        }
      }
    } catch (IOException ioe) {
      closeAfterFailure();
      throw new IllegalStateException(ioe);
    }
  }

  /** Marks the iteration as finished and closes the reader. */
  @Override
  public void close() throws IOException {
    endOfData();
    Closeables.close(reader, true);
  }

  /** Closes this iterator after a read failure, logging (not propagating) any failure to close. */
  private void closeAfterFailure() {
    try {
      close();
    } catch (IOException e) {
      log.error(e.getMessage(), e);
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java new file mode 100644 index 0000000..1905654 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.common.iterator;

import java.util.Iterator;
import java.util.List;
import java.util.Random;

import com.google.common.collect.ForwardingIterator;
import com.google.common.collect.Lists;
import org.apache.mahout.common.RandomUtils;

/**
 * Draws a fixed-size sample from an {@link Iterator}. The source is consumed completely while the sampler is
 * being constructed; the retained elements are then handed out in no particular order.
 */
public final class FixedSizeSamplingIterator<T> extends ForwardingIterator<T> {

  private final Iterator<T> delegate;

  /**
   * @param size maximum number of elements to retain
   * @param source iterator to sample from; it is exhausted by this constructor
   */
  public FixedSizeSamplingIterator(int size, Iterator<T> source) {
    List<T> reservoir = Lists.newArrayListWithCapacity(size);
    Random rng = RandomUtils.getRandom();
    // 'seen' counts the elements pulled from the source so far, including the current one.
    for (int seen = 1; source.hasNext(); seen++) {
      T candidate = source.next();
      if (reservoir.size() < size) {
        // Still filling up: keep everything.
        reservoir.add(candidate);
        continue;
      }
      // Full: draw a slot among all elements seen and overwrite only if it lands inside the buffer.
      int slot = rng.nextInt(seen);
      if (slot < reservoir.size()) {
        reservoir.set(slot, candidate);
      }
    }
    delegate = reservoir.iterator();
  }

  @Override
  protected Iterator<T> delegate() {
    return delegate;
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java
b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java new file mode 100644 index 0000000..425b44b --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.common.iterator;

import java.util.Iterator;

/**
 * An {@link Iterable} view over another one whose {@link Iterable#iterator()} yields only a subset of the
 * underlying elements, chosen according to a sampling rate parameter.
 */
public final class SamplingIterable<T> implements Iterable<T> {

  private final Iterable<? extends T> delegate;
  private final double samplingRate;

  /**
   * @param delegate source of the elements to sample from
   * @param samplingRate rate handed unchanged to every {@link SamplingIterator} this object creates
   */
  public SamplingIterable(Iterable<? extends T> delegate, double samplingRate) {
    this.delegate = delegate;
    this.samplingRate = samplingRate;
  }

  /** Returns a fresh {@link SamplingIterator} over a fresh iterator of the wrapped {@link Iterable}. */
  @Override
  public Iterator<T> iterator() {
    Iterator<? extends T> underlying = delegate.iterator();
    return new SamplingIterator<>(underlying, samplingRate);
  }

  /**
   * Returns {@code delegate} itself when {@code samplingRate} is at least 1.0, and a sampling wrapper around
   * it otherwise.
   */
  public static <T> Iterable<T> maybeWrapIterable(Iterable<T> delegate, double samplingRate) {
    if (samplingRate >= 1.0) {
      return delegate;
    }
    return new SamplingIterable<>(delegate, samplingRate);
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java new file mode 100644 index 0000000..2ba46fd --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java @@ -0,0 +1,73 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.mahout.common.iterator;

import java.util.Iterator;

import com.google.common.base.Preconditions;
import com.google.common.collect.AbstractIterator;
import org.apache.commons.math3.distribution.PascalDistribution;
import org.apache.mahout.cf.taste.impl.common.SkippingIterator;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;

/**
 * An {@link Iterator} view that yields only a subset of the elements of the iterator it wraps, chosen
 * according to a sampling rate parameter.
 */
public final class SamplingIterator<T> extends AbstractIterator<T> {

  private final PascalDistribution geometricDistribution;
  private final Iterator<? extends T> delegate;

  /** Same as the three-argument constructor, using the generator from {@link RandomUtils#getRandom()}. */
  public SamplingIterator(Iterator<? extends T> delegate, double samplingRate) {
    this(RandomUtils.getRandom(), delegate, samplingRate);
  }

  /**
   * @param random source of randomness for drawing the gaps between returned elements
   * @param delegate iterator to sample from; must not be null
   * @param samplingRate must satisfy {@code 0.0 < samplingRate <= 1.0}
   * @throws NullPointerException if {@code delegate} is null
   * @throws IllegalArgumentException if {@code samplingRate} is outside the accepted range
   */
  public SamplingIterator(RandomWrapper random, Iterator<? extends T> delegate, double samplingRate) {
    Preconditions.checkNotNull(delegate);
    boolean rateInRange = samplingRate > 0.0 && samplingRate <= 1.0;
    Preconditions.checkArgument(rateInRange,
        "Must be: 0.0 < samplingRate <= 1.0. But samplingRate = " + samplingRate);
    // A geometric distribution is the negative binomial (Pascal) distribution with r = 1.
    this.geometricDistribution = new PascalDistribution(random.getRandomGenerator(), 1, samplingRate);
    this.delegate = delegate;
  }

  /**
   * Draws a gap from the distribution, discards that many elements of the wrapped iterator, and returns the
   * element that follows; ends the iteration once the wrapped iterator runs dry.
   */
  @Override
  protected T computeNext() {
    int gap = geometricDistribution.sample();
    if (delegate instanceof SkippingIterator<?>) {
      // Let the wrapped iterator skip in bulk when it knows how to.
      ((SkippingIterator<? extends T>) delegate).skip(gap);
    } else {
      int discarded = 0;
      while (discarded < gap && delegate.hasNext()) {
        delegate.next();
        discarded++;
      }
    }
    return delegate.hasNext() ? delegate.next() : endOfData();
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java new file mode 100644 index 0000000..c4ddf7b --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.mahout.common.iterator;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import com.google.common.base.Function;
import com.google.common.collect.ForwardingIterator;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.RandomUtils;

/**
 * Draws a fixed-size sample from an {@link Iterator} and serves the retained elements in the order in which
 * they appeared in the source. Each element is stored together with its position and the sample is sorted
 * before use, which costs extra time and memory compared with {@link FixedSizeSamplingIterator}.
 */
public class StableFixedSizeSamplingIterator<T> extends ForwardingIterator<T> {

  private final Iterator<T> delegate;

  /**
   * @param size maximum number of elements to retain
   * @param source iterator to sample from; it is exhausted by this constructor
   */
  public StableFixedSizeSamplingIterator(int size, Iterator<T> source) {
    List<Pair<Integer,T>> kept = Lists.newArrayListWithCapacity(size);
    Random rng = RandomUtils.getRandom();
    // 'ordinal' is the 1-based position of the current element in the source.
    for (int ordinal = 1; source.hasNext(); ordinal++) {
      T candidate = source.next();
      if (kept.size() < size) {
        kept.add(new Pair<>(ordinal, candidate));
        continue;
      }
      int slot = rng.nextInt(ordinal);
      if (slot < kept.size()) {
        kept.set(slot, new Pair<>(ordinal, candidate));
      }
    }

    // Replacements land in arbitrary slots, so sort to put the survivors back in position order.
    Collections.sort(kept);
    delegate = Iterators.transform(kept.iterator(), new ElementOf<T>());
  }

  @Override
  protected Iterator<T> delegate() {
    return delegate;
  }

  /** Strips the position tag from a retained entry, leaving the sampled element. */
  private static final class ElementOf<E> implements Function<Pair<Integer,E>,E> {
    @Override
    public E apply(Pair<Integer,E> tagged) {
      return tagged.getSecond();
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java new file mode 100644 index 0000000..73b841e --- /dev/null +++
b/community/mahout-mr/mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.common.iterator;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;

import com.google.common.base.Function;
import com.google.common.collect.ForwardingIterator;
import com.google.common.collect.Iterators;
import org.apache.mahout.common.Pair;

/**
 * Turns a sequence of strings into records: each string is split on a regular expression and paired with
 * the constant count 1.
 */
public class StringRecordIterator extends ForwardingIterator<Pair<List<String>,Long>> {

  private static final Long ONE = 1L;

  private final Pattern splitter;
  private final Iterator<Pair<List<String>,Long>> delegate;

  /**
   * @param stringIterator source of the strings to split; its iterator is obtained once, here
   * @param pattern regular expression used as the field separator
   */
  public StringRecordIterator(Iterable<String> stringIterator, String pattern) {
    this.splitter = Pattern.compile(pattern);
    Iterator<String> lines = stringIterator.iterator();
    this.delegate = Iterators.transform(lines, new RecordParser());
  }

  @Override
  protected Iterator<Pair<List<String>,Long>> delegate() {
    return delegate;
  }

  /** Splits one input string with the enclosing iterator's pattern and attaches the constant count. */
  private final class RecordParser implements Function<String,Pair<List<String>,Long>> {
    @Override
    public Pair<List<String>,Long> apply(String line) {
      List<String> fields = Arrays.asList(splitter.split(line));
      return new Pair<>(fields, ONE);
    }
  }

}
