http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java
new file mode 100644
index 0000000..61a9f56
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.BoundType;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import com.google.common.collect.TreeMultiset;
+import org.apache.mahout.math.random.RandomProjector;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Does approximate nearest neighbor search by projecting the data.
+ * <p/>
+ * Every added vector is projected onto numProjections random unit vectors (the rows of
+ * basisMatrix) and each scalar projection is kept in a sorted multiset, one per basis vector.
+ * A query is projected the same way; in every multiset, up to searchSize entries on each side
+ * of the query's projection become candidates, and only those candidates are ranked with the
+ * distance measure, so the result is approximate.
+ */
+public class ProjectionSearch extends UpdatableSearcher {
+
+  /**
+   * A list of sorted multisets containing the scalar projections of each added vector.
+   * The elements of the multiset at index i are WeightedThing&lt;Vector&gt;s whose value is an
+   * added vector and whose weight is that vector's scalar projection on row i of basisMatrix.
+   * Null until the first vector is added.
+   */
+  private List<TreeMultiset<WeightedThing<Vector>>> scalarProjections;
+
+  /**
+   * The random normalized projection vectors forming a basis, one per row.
+   * The multiset at index i in scalarProjections corresponds to row i of this matrix.
+   * Null until the first vector is added.
+   */
+  private Matrix basisMatrix;
+
+  /**
+   * The number of elements to consider on each side of the query's projection in every
+   * multiset from scalarProjections.
+   */
+  private final int searchSize;
+
+  private final int numProjections;
+  private boolean initialized = false;
+
+  /**
+   * Lazily creates the projection basis and one empty multiset per projection. The basis can
+   * only be generated once the dimension of the data is known, i.e. on the first add.
+   * Subsequent calls do nothing.
+   *
+   * @param numDimensions the size of the vectors being added
+   */
+  private void initialize(int numDimensions) {
+    if (initialized) {
+      return;
+    }
+    initialized = true;
+    basisMatrix = RandomProjector.generateBasisNormal(numProjections, numDimensions);
+    scalarProjections = Lists.newArrayList();
+    for (int i = 0; i < numProjections; ++i) {
+      scalarProjections.add(TreeMultiset.<WeightedThing<Vector>>create());
+    }
+  }
+
+  /**
+   * @param distanceMeasure the distance used to rank candidate vectors
+   * @param numProjections the number of random projections, {@code 0 < numProjections < 100}
+   * @param searchSize the number of entries inspected on each side of the query in every
+   *                   projection, {@code searchSize > 0}
+   * @throws IllegalArgumentException if numProjections or searchSize is out of range
+   */
+  public ProjectionSearch(DistanceMeasure distanceMeasure, int numProjections, int searchSize) {
+    super(distanceMeasure);
+    Preconditions.checkArgument(numProjections > 0 && numProjections < 100,
+        "Unreasonable value for number of projections. Must be: 0 < numProjections < 100");
+    // A non-positive searchSize would otherwise only fail (or find nothing) at search time.
+    Preconditions.checkArgument(searchSize > 0,
+        "Unreasonable value for search size. Must be: searchSize > 0");
+
+    this.searchSize = searchSize;
+    this.numProjections = numProjections;
+  }
+
+  /**
+   * Adds a vector into the set of projections for later searching.
+   * The vector is stored as is, not copied.
+   *
+   * @param vector the vector to add
+   */
+  @Override
+  public void add(Vector vector) {
+    initialize(vector.size());
+    Vector projection = basisMatrix.times(vector);
+    // Add the new vector and its projected distance to each multiset separately. A TreeMultiset
+    // keeps itself sorted and each multiset gains exactly one element here, so there is no need
+    // to re-walk all of them after every add (that made each add linear in the data size).
+    int i = 0;
+    for (TreeMultiset<WeightedThing<Vector>> s : scalarProjections) {
+      s.add(new WeightedThing<>(vector, projection.get(i++)));
+    }
+  }
+
+  /**
+   * Returns the number of vectors that can be searched.
+   *
+   * @return the number of vectors added (and not removed) so far; 0 before the first add
+   */
+  @Override
+  public int size() {
+    if (scalarProjections == null) {
+      return 0;
+    }
+    return scalarProjections.get(0).size();
+  }
+
+  /**
+   * Searches for the query vector returning the (approximately) closest limit vectors.
+   *
+   * @param query the vector to search for.
+   * @param limit the number of results to return.
+   * @return a list of Vectors wrapped in WeightedThings where the "thing"'s weight is the
+   * distance; empty if no vector has been added yet.
+   */
+  @Override
+  public List<WeightedThing<Vector>> search(Vector query, int limit) {
+    if (scalarProjections == null) {
+      return Lists.newArrayList();
+    }
+    Set<Vector> candidates = Sets.newHashSet();
+
+    Iterator<? extends Vector> projections = basisMatrix.iterator();
+    for (TreeMultiset<WeightedThing<Vector>> v : scalarProjections) {
+      WeightedThing<Vector> projectedQuery = new WeightedThing<>(query, query.dot(projections.next()));
+      for (WeightedThing<Vector> candidate : neighborsOf(v, projectedQuery)) {
+        candidates.add(candidate.getValue());
+      }
+    }
+
+    // If searchSize * scalarProjections.size() is small enough not to cause much memory pressure,
+    // this is probably just as fast as a priority queue here.
+    List<WeightedThing<Vector>> top = Lists.newArrayList();
+    for (Vector candidate : candidates) {
+      top.add(new WeightedThing<>(candidate, distanceMeasure.distance(query, candidate)));
+    }
+    Collections.sort(top);
+    return top.subList(0, Math.min(limit, top.size()));
+  }
+
+  /**
+   * Returns the (approximately) closest vector to the query.
+   * When only the nearest vector is needed, use this method, NOT search(query, limit) because
+   * it's faster (less overhead).
+   *
+   * @param query the vector to search for
+   * @param differentThanQuery if true, returns the closest vector different than the query (this
+   *                           only matters if the query is among the searched vectors), otherwise,
+   *                           returns the closest vector to the query (even the same vector).
+   * @return the weighted vector closest to the query; if no candidate qualifies (e.g. nothing
+   * has been added), it is built from a null vector and Double.POSITIVE_INFINITY
+   */
+  @Override
+  public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
+    double bestDistance = Double.POSITIVE_INFINITY;
+    Vector bestVector = null;
+
+    if (scalarProjections != null) {
+      Iterator<? extends Vector> projections = basisMatrix.iterator();
+      for (TreeMultiset<WeightedThing<Vector>> v : scalarProjections) {
+        WeightedThing<Vector> projectedQuery = new WeightedThing<>(query, query.dot(projections.next()));
+        for (WeightedThing<Vector> candidate : neighborsOf(v, projectedQuery)) {
+          double distance = distanceMeasure.distance(query, candidate.getValue());
+          if (distance < bestDistance && (!differentThanQuery || !candidate.getValue().equals(query))) {
+            bestDistance = distance;
+            bestVector = candidate.getValue();
+          }
+        }
+      }
+    }
+
+    return new WeightedThing<>(bestVector, bestDistance);
+  }
+
+  /**
+   * Returns the candidates taken from one projection multiset: up to searchSize entries at or
+   * above the projected query (ascending), followed by up to searchSize entries strictly below
+   * it (descending), i.e. the nearest projections on each side come first.
+   */
+  private Iterable<WeightedThing<Vector>> neighborsOf(TreeMultiset<WeightedThing<Vector>> projection,
+                                                      WeightedThing<Vector> projectedQuery) {
+    return Iterables.concat(
+        Iterables.limit(projection.tailMultiset(projectedQuery, BoundType.CLOSED), searchSize),
+        Iterables.limit(projection.headMultiset(projectedQuery, BoundType.OPEN).descendingMultiset(),
+            searchSize));
+  }
+
+  /**
+   * Iterates over the added vectors, in the order of the first projection multiset.
+   * The iterator is empty if no vector has been added yet.
+   */
+  @Override
+  public Iterator<Vector> iterator() {
+    if (scalarProjections == null) {
+      return Collections.<Vector>emptyIterator();
+    }
+    return new AbstractIterator<Vector>() {
+      private final Iterator<WeightedThing<Vector>> projected = scalarProjections.get(0).iterator();
+      @Override
+      protected Vector computeNext() {
+        if (!projected.hasNext()) {
+          return endOfData();
+        }
+        return projected.next().getValue();
+      }
+    };
+  }
+
+  /**
+   * Removes the vector found by searchFirst(vector, false) if it lies within epsilon of vector.
+   *
+   * @param vector the vector to remove, or a vector within epsilon of it
+   * @param epsilon the maximum distance between the argument and the removed vector
+   * @return true if a vector was removed, false otherwise
+   */
+  @Override
+  public boolean remove(Vector vector, double epsilon) {
+    if (scalarProjections == null) {
+      return false;
+    }
+    WeightedThing<Vector> toRemove = searchFirst(vector, false);
+    if (toRemove.getWeight() < epsilon) {
+      // Remove the stored vector that was found, not the argument: when the two differ (while
+      // still being within epsilon) only the stored vector's projections are in the multisets.
+      Vector stored = toRemove.getValue();
+      Iterator<? extends Vector> basisVectors = basisMatrix.iterator();
+      for (TreeMultiset<WeightedThing<Vector>> projection : scalarProjections) {
+        if (!projection.remove(new WeightedThing<>(stored, stored.dot(basisVectors.next())))) {
+          throw new RuntimeException("Internal inconsistency in ProjectionSearch");
+        }
+      }
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  /**
+   * Removes all added vectors; the projection basis is kept for later adds.
+   */
+  @Override
+  public void clear() {
+    if (scalarProjections == null) {
+      return;
+    }
+    for (TreeMultiset<WeightedThing<Vector>> set : scalarProjections) {
+      set.clear();
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java
new file mode 100644
index 0000000..dd387b5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.List;
+
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Describes how to search a bunch of vectors.
+ * The vectors can be of any type (weighted, sparse, ...) but only their values matter when
+ * searching; anything else (weights, indices, ...) does not.
+ *
+ * When iterating through a Searcher, the Vectors added to it are returned.
+ */
+public abstract class Searcher implements Iterable<Vector> {
+  protected DistanceMeasure distanceMeasure;
+
+  protected Searcher(DistanceMeasure distanceMeasure) {
+    this.distanceMeasure = distanceMeasure;
+  }
+
+  /**
+   * @return the distance measure used to compare vectors
+   */
+  public DistanceMeasure getDistanceMeasure() {
+    return distanceMeasure;
+  }
+
+  /**
+   * Add a new Vector to the Searcher that will be checked when getting
+   * the nearest neighbors.
+   *
+   * The vector IS NOT CLONED. Do not modify the vector externally otherwise the internal
+   * Searcher data structures could be invalidated.
+   */
+  public abstract void add(Vector vector);
+
+  /**
+   * Returns the number of vectors being searched for nearest neighbors.
+   */
+  public abstract int size();
+
+  /**
+   * When querying the Searcher for the closest vectors, a list of WeightedThing<Vector>s is
+   * returned. The value of the WeightedThing is the neighbor and the weight is the
+   * distance (calculated by some metric - see a concrete implementation) between the query
+   * and neighbor.
+   * The actual type of vector in the pair is the same as the vector added to the Searcher.
+   * @param query the vector to search for
+   * @param limit the number of results to return
+   * @return the list of weighted vectors closest to the query
+   */
+  public abstract List<WeightedThing<Vector>> search(Vector query, int limit);
+
+  /**
+   * Runs {@link #search(Vector, int)} for every query.
+   *
+   * @param queries the vectors to search for
+   * @param limit the number of results to return per query
+   * @return one result list per query, in the iteration order of queries
+   */
+  public List<List<WeightedThing<Vector>>> search(Iterable<? extends Vector> queries, int limit) {
+    List<List<WeightedThing<Vector>>> results = Lists.newArrayListWithExpectedSize(Iterables.size(queries));
+    for (Vector query : queries) {
+      results.add(search(query, limit));
+    }
+    return results;
+  }
+
+  /**
+   * Returns the closest vector to the query.
+   * When only the nearest vector is needed, use this method, NOT search(query, limit) because
+   * it's faster (less overhead).
+   *
+   * @param query the vector to search for
+   * @param differentThanQuery if true, returns the closest vector different than the query (this
+   *                           only matters if the query is among the searched vectors), otherwise,
+   *                           returns the closest vector to the query (even the same vector).
+   * @return the weighted vector closest to the query
+   */
+  public abstract WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery);
+
+  /**
+   * Runs {@link #searchFirst(Vector, boolean)} for every query.
+   *
+   * @param queries the vectors to search for
+   * @param differentThanQuery passed to searchFirst for every query
+   * @return one result per query, in the iteration order of queries
+   */
+  public List<WeightedThing<Vector>> searchFirst(Iterable<? extends Vector> queries, boolean differentThanQuery) {
+    List<WeightedThing<Vector>> results = Lists.newArrayListWithExpectedSize(Iterables.size(queries));
+    for (Vector query : queries) {
+      results.add(searchFirst(query, differentThanQuery));
+    }
+    return results;
+  }
+
+  /**
+   * Adds all the data elements to the Searcher.
+   *
+   * @param data an iterable of Vectors to add.
+   */
+  public void addAll(Iterable<? extends Vector> data) {
+    for (Vector vector : data) {
+      add(vector);
+    }
+  }
+
+  /**
+   * Adds the vectors of all the matrix slices to the Searcher.
+   *
+   * @param data an iterable of MatrixSlices to add.
+   */
+  public void addAllMatrixSlices(Iterable<MatrixSlice> data) {
+    for (MatrixSlice slice : data) {
+      add(slice.vector());
+    }
+  }
+
+  /**
+   * Adds the vectors of all the matrix slices to the Searcher, each wrapped in a WeightedVector
+   * with weight 1 and the slice's index.
+   *
+   * @param data an iterable of MatrixSlices to add.
+   */
+  public void addAllMatrixSlicesAsWeightedVectors(Iterable<MatrixSlice> data) {
+    for (MatrixSlice slice : data) {
+      add(new WeightedVector(slice.vector(), 1, slice.index()));
+    }
+  }
+
+  /**
+   * Removes a vector close to v (within epsilon). The base implementation does not support
+   * removal; see UpdatableSearcher.
+   *
+   * @throws UnsupportedOperationException unless overridden
+   */
+  public boolean remove(Vector v, double epsilon) {
+    throw new UnsupportedOperationException("Can't remove a vector from a "
+        + this.getClass().getName());
+  }
+
+  /**
+   * Removes all vectors. The base implementation does not support removal; see UpdatableSearcher.
+   *
+   * @throws UnsupportedOperationException unless overridden
+   */
+  public void clear() {
+    throw new UnsupportedOperationException("Can't remove vectors from a "
+        + this.getClass().getName());
+  }
+
+  /**
+   * Returns a bounded size priority queue, in reverse order, that keeps track of the best nearest
+   * neighbor vectors.
+   * @param limit maximum size of the heap.
+   * @return the priority queue.
+   */
+  public static PriorityQueue<WeightedThing<Vector>> getCandidateQueue(int limit) {
+    return new PriorityQueue<WeightedThing<Vector>>(limit) {
+      @Override
+      protected boolean lessThan(WeightedThing<Vector> a, WeightedThing<Vector> b) {
+        return a.getWeight() > b.getWeight();
+      }
+    };
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java
new file mode 100644
index 0000000..68365c7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Describes how we search vectors. A class should extend UpdatableSearcher only if it can handle
+ * removing vectors, i.e. if it implements both remove and clear.
+ */
+public abstract class UpdatableSearcher extends Searcher {
+
+  protected UpdatableSearcher(DistanceMeasure distanceMeasure) {
+    super(distanceMeasure);
+  }
+
+  @Override
+  public abstract boolean remove(Vector v, double epsilon);
+
+  @Override
+  public abstract void clear();
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java b/mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java
new file mode 100644
index 0000000..79fe4b6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.lang.math.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Static factories for random projection bases.
+ */
+public final class RandomProjector {
+  private RandomProjector() {
+  }
+
+  /**
+   * Generates a basis matrix of size projectedVectorSize x vectorSize. Multiplying a vector by
+   * this matrix results in the projected vector.
+   *
+   * The rows of the matrix are sampled from a multi normal distribution and then normalized.
+   *
+   * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+   * @param vectorSize initial vector size
+   * @return a projection matrix
+   */
+  public static Matrix generateBasisNormal(int projectedVectorSize, int vectorSize) {
+    Matrix basisMatrix = new DenseMatrix(projectedVectorSize, vectorSize);
+    basisMatrix.assign(new Normal());
+    for (MatrixSlice row : basisMatrix) {
+      row.vector().assign(row.normalize());
+    }
+    return basisMatrix;
+  }
+
+  /**
+   * Generates a basis matrix of size projectedVectorSize x vectorSize. Multiplying a vector by
+   * this matrix results in the projected vector.
+   *
+   * The rows of the matrix are sampled from a distribution where:
+   * - +1 has probability 1/2,
+   * - -1 has probability 1/2
+   * and are then normalized.
+   *
+   * See Achlioptas, D. (2003). Database-friendly random projections: Johnson-Lindenstrauss with binary coins.
+   * Journal of Computer and System Sciences, 66(4), 671-687. doi:10.1016/S0022-0000(03)00025-4
+   *
+   * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+   * @param vectorSize initial vector size
+   * @return a projection matrix
+   */
+  public static Matrix generateBasisPlusMinusOne(int projectedVectorSize, int vectorSize) {
+    Matrix basisMatrix = new DenseMatrix(projectedVectorSize, vectorSize);
+    for (int i = 0; i < projectedVectorSize; ++i) {
+      for (int j = 0; j < vectorSize; ++j) {
+        basisMatrix.set(i, j, RandomUtils.nextInt(2) == 0 ? 1 : -1);
+      }
+    }
+    for (MatrixSlice row : basisMatrix) {
+      row.vector().assign(row.normalize());
+    }
+    return basisMatrix;
+  }
+
+  /**
+   * Generates a basis matrix of size projectedVectorSize x vectorSize. Multiplying a vector by
+   * this matrix results in the projected vector.
+   *
+   * The rows of the matrix are sampled from a distribution where:
+   * - 0 has probability 2/3,
+   * - +sqrt(3) has probability 1/6,
+   * - -sqrt(3) has probability 1/6
+   * and are then normalized, so only the relative values 0, +1 and -1 matter.
+   *
+   * See Achlioptas, D. (2003). Database-friendly random projections: Johnson-Lindenstrauss with binary coins.
+   * Journal of Computer and System Sciences, 66(4), 671-687. doi:10.1016/S0022-0000(03)00025-4
+   *
+   * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+   * @param vectorSize initial vector size
+   * @return a projection matrix
+   */
+  public static Matrix generateBasisZeroPlusMinusOne(int projectedVectorSize, int vectorSize) {
+    Matrix basisMatrix = new DenseMatrix(projectedVectorSize, vectorSize);
+    Multinomial<Double> choice = new Multinomial<>();
+    choice.add(0.0, 2 / 3.0);
+    choice.add(Math.sqrt(3.0), 1 / 6.0);
+    choice.add(-Math.sqrt(3.0), 1 / 6.0);
+    for (int i = 0; i < projectedVectorSize; ++i) {
+      for (int j = 0; j < vectorSize; ++j) {
+        basisMatrix.set(i, j, choice.sample());
+      }
+    }
+    for (MatrixSlice row : basisMatrix) {
+      row.vector().assign(row.normalize());
+    }
+    return basisMatrix;
+  }
+
+  /**
+   * Generates a list of projectedVectorSize normalized vectors, each of size vectorSize, whose
+   * entries are sampled using {@link Normal}. This looks like a matrix of size
+   * (projectedVectorSize, vectorSize).
+   * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+   * @param vectorSize initial vector size
+   * @return a list of projection vectors
+   */
+  public static List<Vector> generateVectorBasis(int projectedVectorSize, int vectorSize) {
+    DoubleFunction random = new Normal();
+    List<Vector> basisVectors = Lists.newArrayList();
+    for (int i = 0; i < projectedVectorSize; ++i) {
+      Vector basisVector = new DenseVector(vectorSize);
+      basisVector.assign(random);
+      // normalize() returns a new vector rather than scaling in place, so keep its result.
+      basisVectors.add(basisVector.normalize());
+    }
+    return basisVectors;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/ssvd/SequentialOutOfCoreSvd.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/ssvd/SequentialOutOfCoreSvd.java b/mr/src/main/java/org/apache/mahout/math/ssvd/SequentialOutOfCoreSvd.java
new file mode 100644
index 0000000..f7724f7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/ssvd/SequentialOutOfCoreSvd.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.ssvd;
+
+import org.apache.mahout.math.CholeskyDecomposition;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.RandomTrinaryMatrix;
+import org.apache.mahout.math.SingularValueDecomposition;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+/**
+ * Sequential block-oriented out of core SVD algorithm.
+ * <p/>
+ * The basic algorithm (in-core version) is that we do a random projection, get a basis of that and
+ * then re-project the original matrix using that basis. This re-projected matrix allows us to get
+ * an approximate SVD of the original matrix.
+ * <p/>
+ * The input to this program is a list of files that contain the sub-matrices A_i. The result is a
+ * vector of singular values and optionally files that contain the left and right singular vectors.
+ * <p/>
+ * Mathematically, to decompose A, we do this:
+ * <p/>
+ * Y = A * \Omega
+ * <p/>
+ * Q R = Y
+ * <p/>
+ * B = Q' A
+ * <p/>
+ * U D V' = B
+ * <p/>
+ * (Q U) D V' \approx A
+ * <p/>
+ * To do this out of core, we break A into blocks each with the same number of rows. This gives a
+ * block-wise version of Y. As we are computing Y, we can also accumulate Y' Y and when done, we
+ * can use a Cholesky decomposition to do the QR decomposition of Y in a latent form. That gives us
+ * B in block-wise form and we can do the same trick to get an LQ of B. The L part can be
+ * decomposed in memory. Then we can recombine to get the final decomposition.
+ * <p/>
+ * The details go like this. Start with a block form of A.
+ * <p/>
+ * Y_i = A_i * \Omega
+ * <p/>
+ * Instead of doing a QR decomposition of Y, we do a Cholesky decomposition of Y' Y. This is a
+ * small in-memory operation. Q is large and dense and won't fit in memory.
+ * <p/>
+ * R' R = \sum_i Y_i' Y_i
+ * <p/>
+ * For reference, R is all we need to compute explicitly. Q will be computed on the fly when
+ * needed.
+ * <p/>
+ * Q = Y R^-1
+ * <p/>
+ * B = Q' A = \sum_i (A_i \Omega R^-1)' A_i
+ * <p/>
+ * As B is generated, it needs to be segmented in column-wise blocks since it is wide but not tall.
+ * This storage requires something like a map-reduce to accumulate the partial sums. In this code,
+ * we do this by re-reading previously computed chunks and augmenting them.
+ * <p/>
+ * While the pieces of B are being computed, we can accumulate B B' in preparation for a second
+ * Cholesky decomposition
+ * <p/>
+ * L L' = B B' = sum B_j B_j'
+ * <p/>
+ * Again, this is an LQ decomposition of BB', but we don't compute the Q part explicitly. L will be
+ * small and thus tractable.
+ * <p/>
+ * Finally, we do the actual SVD decomposition.
+ * <p/>
+ * U_0 D V_0' = L
+ * <p/>
+ * D contains the singular values of A. The left and right singular vectors can be reconstructed
+ * using Y and B. Note that both of these reconstructions can be done with single passes through
+ * the blocked forms of Y and B.
+ * <p/>
+ * U = A \Omega R^{-1} U_0
+ * <p/>
+ * V = B' L'^{-1} V_0
+ */
+public class SequentialOutOfCoreSvd {
+
+  private final CholeskyDecomposition l2;
+  private final SingularValueDecomposition svd;
+  private final CholeskyDecomposition r2;
+  private final int columnsPerSlice;
+  private final int seed;
+  private final int dim;
+
+  /**
+   * Runs the decomposition: computes R (step 1), the column slices of B (step 2, written to
+   * tmpDir as B-* files) and the SVD of L (step 3). The singular vectors are only produced by
+   * {@link #computeU} and {@link #computeV}.
+   *
+   * @param partsOfA the files holding the row blocks A_i, each a serialized MatrixWritable
+   * @param tmpDir directory for the B-* slices; existing B-* files in it are added to, not replaced
+   * @param internalDimension the number of random projections (columns of \Omega)
+   * @param columnsPerSlice the number of columns of A (and B) per B-* slice; must be positive
+   * @throws IOException if a file cannot be read or written
+   * @throws IllegalArgumentException if columnsPerSlice is not positive or partsOfA is empty
+   */
+  public SequentialOutOfCoreSvd(Iterable<File> partsOfA, File tmpDir, int internalDimension, int columnsPerSlice)
+    throws IOException {
+    // A non-positive slice width would make the column loops below fail or never terminate.
+    if (columnsPerSlice <= 0) {
+      throw new IllegalArgumentException("columnsPerSlice must be positive: " + columnsPerSlice);
+    }
+    this.columnsPerSlice = columnsPerSlice;
+    this.dim = internalDimension;
+
+    seed = 1;
+    Matrix y2 = null;
+
+    // step 1, compute R as in R'R = Y'Y where Y = A \Omega
+    for (File file : partsOfA) {
+      Matrix aI = readMatrix(file);
+      Matrix omega = new RandomTrinaryMatrix(seed, aI.columnSize(), internalDimension, false);
+      Matrix y = aI.times(omega);
+
+      if (y2 == null) {
+        y2 = y.transpose().times(y);
+      } else {
+        y2.assign(y.transpose().times(y), Functions.PLUS);
+      }
+    }
+    if (y2 == null) {
+      // Without any block there is nothing to decompose (y2 would still be null below).
+      throw new IllegalArgumentException("partsOfA must contain at least one file");
+    }
+    r2 = new CholeskyDecomposition(y2);
+
+    // step 2, compute B
+    int ncols = 0;
+    for (File file : partsOfA) {
+      Matrix aI = readMatrix(file);
+      ncols = Math.max(ncols, aI.columnSize());
+
+      Matrix omega = new RandomTrinaryMatrix(seed, aI.numCols(), internalDimension, false);
+      // Q_i' = (A_i \Omega R^-1)' does not depend on the column slice j, so compute it once per block.
+      Matrix qIt = r2.solveRight(aI.times(omega)).transpose();
+      for (int j = 0; j < aI.numCols(); j += columnsPerSlice) {
+        Matrix aIJ = aI.viewPart(0, aI.rowSize(), j, Math.min(columnsPerSlice, aI.columnSize() - j));
+        addToSavedCopy(bFile(tmpDir, j), qIt.times(aIJ));
+      }
+    }
+
+    // step 3, compute BB', L and SVD(L)
+    Matrix b2 = new DenseMatrix(internalDimension, internalDimension);
+    for (int j = 0; j < ncols; j += columnsPerSlice) {
+      File bPath = bFile(tmpDir, j);
+      if (bPath.exists()) {
+        Matrix bJ = readMatrix(bPath);
+        b2.assign(bJ.times(bJ.transpose()), Functions.PLUS);
+      }
+    }
+    l2 = new CholeskyDecomposition(b2);
+    svd = new SingularValueDecomposition(l2.getL());
+  }
+
+  /**
+   * Step 5: computes the row blocks of the right singular vectors, V_j = B_j' L'^-1 V_0, from
+   * the B-* slices in tmpDir and writes each one to a V-* file in tmpDir.
+   *
+   * @param tmpDir the directory holding the B-* slices written by the constructor
+   * @param ncols an upper bound on the number of columns of A; slices are looked up at every
+   *              multiple of columnsPerSlice below it
+   * @throws IOException if a file cannot be read or written
+   */
+  public void computeV(File tmpDir, int ncols) throws IOException {
+    // step 5, compute pieces of V
+    for (int j = 0; j < ncols; j += columnsPerSlice) {
+      File bPath = bFile(tmpDir, j);
+      if (bPath.exists()) {
+        Matrix vJ = l2.solveRight(readMatrix(bPath).transpose()).times(svd.getV());
+        writeMatrix(new File(tmpDir, String.format("V-%s", bPath.getName().replaceAll(".*-", ""))), vJ);
+      }
+    }
+  }
+
+  /**
+   * Step 4: computes the row blocks of the left singular vectors, U_i = A_i \Omega R^-1 U_0, and
+   * writes each to a file in tmpDir named "U-" followed by the part of the input file's name after
+   * its last '-' (the whole name if it has none).
+   *
+   * @param partsOfA the row blocks A_i; presumably the same files that were passed to the constructor
+   * @param tmpDir the output directory
+   * @throws IOException if a file cannot be read or written
+   */
+  public void computeU(Iterable<File> partsOfA, File tmpDir) throws IOException {
+    // step 4, compute pieces of U
+    for (File file : partsOfA) {
+      // readMatrix closes the input stream (try-with-resources) even if reading fails.
+      Matrix aI = readMatrix(file);
+
+      Matrix y = aI.times(new RandomTrinaryMatrix(seed, aI.numCols(), dim, false));
+      Matrix uI = r2.solveRight(y).times(svd.getU());
+      writeMatrix(new File(tmpDir, String.format("U-%s", file.getName().replaceAll(".*-", ""))), uI);
+    }
+  }
+
+  /** Reads a single serialized MatrixWritable from file, always closing the stream. */
+  private static Matrix readMatrix(File file) throws IOException {
+    MatrixWritable m = new MatrixWritable();
+    try (DataInputStream in = new DataInputStream(new FileInputStream(file))) {
+      m.readFields(in);
+    }
+    return m.get();
+  }
+
+  /** Writes matrix to file as a serialized MatrixWritable, overwriting any previous content. */
+  private static void writeMatrix(File file, Matrix matrix) throws IOException {
+    MatrixWritable m = new MatrixWritable();
+    m.set(matrix);
+    try (DataOutputStream out = new DataOutputStream(new FileOutputStream(file))) {
+      m.write(out);
+    }
+  }
+
+  /**
+   * Adds matrix to the one saved in file, or saves it if the file does not exist yet; this is
+   * how the partial sums of each B slice are accumulated across row blocks.
+   */
+  private static void addToSavedCopy(File file, Matrix matrix) throws IOException {
+    Matrix sum;
+    if (file.exists()) {
+      sum = readMatrix(file);
+      sum.assign(matrix, Functions.PLUS);
+    } else {
+      sum = matrix;
+    }
+    writeMatrix(file, sum);
+  }
+
+  private static File bFile(File tmpDir, int j) {
+    return new File(tmpDir, String.format("B-%09d", j));
+  }
+
+  /**
+   * @return the approximate singular values of A computed by the constructor
+   */
+  public Vector getSingularValues() {
+    return new DenseVector(svd.getSingularValues());
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java b/mr/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java
new file mode 100644
index 0000000..4485bbe
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java
@@ -0,0 +1,168 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.mahout.math.stats; + +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixWritable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Random; + +/** + * Computes a running estimate of AUC (see http://en.wikipedia.org/wiki/Receiver_operating_characteristic). + * <p/> + * Since AUC is normally a global property of labeled scores, it is almost always computed in a + * batch fashion. The probabilistic definition (the probability that a random element of one set + * has a higher score than a random element of another set) gives us a way to estimate this + * on-line. + * + * @see GroupedOnlineAuc + */ +public class GlobalOnlineAuc implements OnlineAuc { + enum ReplacementPolicy { + FIFO, FAIR, RANDOM + } + + // increasing this to 100 causes very small improvements in accuracy. Decreasing it to 2 + // causes substantial degradation for the FAIR and RANDOM policies, but almost no change + // for the FIFO policy + public static final int HISTORY = 10; + + // defines the exponential averaging window for results + private int windowSize = Integer.MAX_VALUE; + + // FIFO has distinctly the best properties as a policy. 
See OnlineAucTest for details + private ReplacementPolicy policy = ReplacementPolicy.FIFO; + private final Random random = RandomUtils.getRandom(); + private Matrix scores; + private Vector averages; + private Vector samples; + + public GlobalOnlineAuc() { + int numCategories = 2; + scores = new DenseMatrix(numCategories, HISTORY); + scores.assign(Double.NaN); + averages = new DenseVector(numCategories); + averages.assign(0.5); + samples = new DenseVector(numCategories); + } + + @Override + public double addSample(int category, String groupKey, double score) { + return addSample(category, score); + } + + @Override + public double addSample(int category, double score) { + int n = (int) samples.get(category); + if (n < HISTORY) { + scores.set(category, n, score); + } else { + switch (policy) { + case FIFO: + scores.set(category, n % HISTORY, score); + break; + case FAIR: + int j1 = random.nextInt(n + 1); + if (j1 < HISTORY) { + scores.set(category, j1, score); + } + break; + case RANDOM: + int j2 = random.nextInt(HISTORY); + scores.set(category, j2, score); + break; + default: + throw new IllegalStateException("Unknown policy: " + policy); + } + } + + samples.set(category, n + 1); + + if (samples.minValue() >= 1) { + // compare to previous scores for other category + Vector row = scores.viewRow(1 - category); + double m = 0.0; + double count = 0.0; + for (Vector.Element element : row.all()) { + double v = element.get(); + if (Double.isNaN(v)) { + continue; + } + count++; + if (score > v) { + m++; + // } else if (score < v) { + // m += 0 + } else if (score == v) { + m += 0.5; + } + } + averages.set(category, averages.get(category) + + (m / count - averages.get(category)) / Math.min(windowSize, samples.get(category))); + } + return auc(); + } + + @Override + public double auc() { + // return an unweighted average of all averages. 
+ return (1 - averages.get(0) + averages.get(1)) / 2; + } + + public double value() { + return auc(); + } + + @Override + public void setPolicy(ReplacementPolicy policy) { + this.policy = policy; + } + + @Override + public void setWindowSize(int windowSize) { + this.windowSize = windowSize; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(windowSize); + out.writeInt(policy.ordinal()); + MatrixWritable.writeMatrix(out, scores); + VectorWritable.writeVector(out, averages); + VectorWritable.writeVector(out, samples); + } + + @Override + public void readFields(DataInput in) throws IOException { + windowSize = in.readInt(); + policy = ReplacementPolicy.values()[in.readInt()]; + + scores = MatrixWritable.readMatrix(in); + averages = VectorWritable.readVector(in); + samples = VectorWritable.readVector(in); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java b/mr/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java new file mode 100644 index 0000000..3fa1b79 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.stats; + +import com.google.common.collect.Maps; +import org.apache.mahout.classifier.sgd.PolymorphicWritable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Map; + +/** + * Implements a variant on AUC where the result returned is an average of several AUC measurements + * made on sub-groups of the overall data. Controlling for the grouping factor allows the effects + * of the grouping factor on the model to be ignored. This is useful, for instance, when using a + * classifier as a click prediction engine. In that case you want AUC to refer only to the ranking + * of items for a particular user, not to the discrimination of users from each other. Grouping by + * user (or user cluster) helps avoid optimizing for the wrong quality. 
+ */ +public class GroupedOnlineAuc implements OnlineAuc { + private final Map<String, OnlineAuc> map = Maps.newHashMap(); + private GlobalOnlineAuc.ReplacementPolicy policy; + private int windowSize; + + @Override + public double addSample(int category, String groupKey, double score) { + if (groupKey == null) { + addSample(category, score); + } + + OnlineAuc group = map.get(groupKey); + if (group == null) { + group = new GlobalOnlineAuc(); + if (policy != null) { + group.setPolicy(policy); + } + if (windowSize > 0) { + group.setWindowSize(windowSize); + } + map.put(groupKey, group); + } + return group.addSample(category, score); + } + + @Override + public double addSample(int category, double score) { + throw new UnsupportedOperationException("Can't add to " + this.getClass() + " without group key"); + } + + @Override + public double auc() { + double sum = 0; + for (OnlineAuc auc : map.values()) { + sum += auc.auc(); + } + return sum / map.size(); + } + + @Override + public void setPolicy(GlobalOnlineAuc.ReplacementPolicy policy) { + this.policy = policy; + for (OnlineAuc auc : map.values()) { + auc.setPolicy(policy); + } + } + + @Override + public void setWindowSize(int windowSize) { + this.windowSize = windowSize; + for (OnlineAuc auc : map.values()) { + auc.setWindowSize(windowSize); + } + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(map.size()); + for (Map.Entry<String,OnlineAuc> entry : map.entrySet()) { + out.writeUTF(entry.getKey()); + PolymorphicWritable.write(out, entry.getValue()); + } + out.writeInt(policy.ordinal()); + out.writeInt(windowSize); + } + + @Override + public void readFields(DataInput in) throws IOException { + int n = in.readInt(); + map.clear(); + for (int i = 0; i < n; i++) { + String key = in.readUTF(); + map.put(key, PolymorphicWritable.read(in, OnlineAuc.class)); + } + policy = GlobalOnlineAuc.ReplacementPolicy.values()[in.readInt()]; + windowSize = in.readInt(); + } +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java b/mr/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
new file mode 100644
index 0000000..d21ae6b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Describes the generic outline of how to compute AUC. Currently there are two
+ * implementations of this, one for computing a global estimate of AUC and the other
+ * for computing average grouped AUC. Grouped AUC is useful when misusing a classifier
+ * as a recommendation system.
+ */
+public interface OnlineAuc extends Writable {
+  /**
+   * Adds a labeled score that belongs to a group.
+   *
+   * @param category the true label of the sample (the two-category estimator uses 0 and 1)
+   * @param groupKey the group of the sample; an implementation that does not group may ignore it
+   * @param score    the classifier score of the sample
+   * @return an AUC estimate after adding the sample (which estimate is implementation-specific)
+   */
+  double addSample(int category, String groupKey, double score);
+
+  /**
+   * Adds a labeled score without a group; a grouping implementation may reject this with
+   * {@link UnsupportedOperationException}.
+   *
+   * @return an AUC estimate after adding the sample
+   */
+  double addSample(int category, double score);
+
+  /** Returns the current AUC estimate. */
+  double auc();
+
+  /** Selects how retained scores are replaced once an implementation's history is full. */
+  void setPolicy(GlobalOnlineAuc.ReplacementPolicy policy);
+
+  /** Sets the window used to average comparison results over time. */
+  void setWindowSize(int windowSize);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/stats/Sampler.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/stats/Sampler.java b/mr/src/main/java/org/apache/mahout/math/stats/Sampler.java
new file mode 100644
index 0000000..4b9e8a9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/stats/Sampler.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.mahout.math.stats; + +import com.google.common.base.Preconditions; +import org.apache.mahout.math.Vector; + +import java.util.Arrays; +import java.util.Random; + +/** + * Discrete distribution sampler: + * + * Samples from a given discrete distribution: you provide a source of randomness and a Vector + * (cardinality N) which describes a distribution over [0,N), and calls to sample() sample + * from 0 to N using this distribution + */ +public class Sampler { + + private final Random random; + private final double[] sampler; + + public Sampler(Random random) { + this.random = random; + sampler = null; + } + + public Sampler(Random random, double[] sampler) { + this.random = random; + this.sampler = sampler; + } + + public Sampler(Random random, Vector distribution) { + this.random = random; + this.sampler = samplerFor(distribution); + } + + public int sample(Vector distribution) { + return sample(samplerFor(distribution)); + } + + public int sample() { + Preconditions.checkNotNull(sampler, + "Sampler must have been constructed with a distribution, or else sample(Vector) should be used to sample"); + return sample(sampler); + } + + private static double[] samplerFor(Vector vectorDistribution) { + int size = vectorDistribution.size(); + double[] partition = new double[size]; + double norm = vectorDistribution.norm(1); + double sum = 0; + for (int i = 0; i < size; i++) { + sum += vectorDistribution.get(i) / norm; + partition[i] = sum; + } + return partition; + } + + private int sample(double[] sampler) { + int index = Arrays.binarySearch(sampler, random.nextDouble()); + return index < 0 ? 
-(index + 1) : index; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/DictionaryVectorizer.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/DictionaryVectorizer.java b/mr/src/main/java/org/apache/mahout/vectorizer/DictionaryVectorizer.java new file mode 100644 index 0000000..8a1f8f8 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/vectorizer/DictionaryVectorizer.java @@ -0,0 +1,416 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.vectorizer; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.StringTuple; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.vectorizer.collocations.llr.CollocDriver; +import org.apache.mahout.vectorizer.collocations.llr.LLRReducer; +import org.apache.mahout.vectorizer.common.PartialVectorMerger; +import org.apache.mahout.vectorizer.term.TFPartialVectorReducer; +import org.apache.mahout.vectorizer.term.TermCountCombiner; +import org.apache.mahout.vectorizer.term.TermCountMapper; +import org.apache.mahout.vectorizer.term.TermCountReducer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; + +/** + * This class 
converts a set of input documents in the sequence file format to vectors. The Sequence file + * input should have a {@link Text} key containing the unique document identifier and a {@link StringTuple} + * value containing the tokenized document. You may use {@link DocumentProcessor} to tokenize the document. + * This is a dictionary based Vectorizer. + */ +public final class DictionaryVectorizer extends AbstractJob implements Vectorizer { + private static final Logger log = LoggerFactory.getLogger(DictionaryVectorizer.class); + + public static final String DOCUMENT_VECTOR_OUTPUT_FOLDER = "tf-vectors"; + public static final String MIN_SUPPORT = "min.support"; + public static final String MAX_NGRAMS = "max.ngrams"; + public static final int DEFAULT_MIN_SUPPORT = 2; + public static final String DICTIONARY_FILE = "dictionary.file-"; + + private static final int MAX_CHUNKSIZE = 10000; + private static final int MIN_CHUNKSIZE = 100; + private static final String OUTPUT_FILES_PATTERN = "part-*"; + // 4 byte overhead for each entry in the OpenObjectIntHashMap + private static final int DICTIONARY_BYTE_OVERHEAD = 4; + private static final String VECTOR_OUTPUT_FOLDER = "partial-vectors-"; + private static final String DICTIONARY_JOB_FOLDER = "wordcount"; + + /** + * Cannot be initialized. Use the static functions + */ + private DictionaryVectorizer() { + } + + //TODO: move more of SparseVectorsFromSequenceFile in here, and then fold SparseVectorsFrom with + // EncodedVectorsFrom to have one framework. 
+ + @Override + public void createVectors(Path input, Path output, VectorizerConfig config) + throws IOException, ClassNotFoundException, InterruptedException { + createTermFrequencyVectors(input, + output, + config.getTfDirName(), + config.getConf(), + config.getMinSupport(), + config.getMaxNGramSize(), + config.getMinLLRValue(), + config.getNormPower(), + config.isLogNormalize(), + config.getNumReducers(), + config.getChunkSizeInMegabytes(), + config.isSequentialAccess(), + config.isNamedVectors()); + } + + /** + * Create Term Frequency (Tf) Vectors from the input set of documents in {@link SequenceFile} format. This + * tries to fix the maximum memory used by the feature chunk per node thereby splitting the process across + * multiple map/reduces. + * + * @param input + * input directory of the documents in {@link SequenceFile} format + * @param output + * output directory where {@link org.apache.mahout.math.RandomAccessSparseVector}'s of the document + * are generated + * @param tfVectorsFolderName + * The name of the folder in which the final output vectors will be stored + * @param baseConf + * job configuration + * @param normPower + * L_p norm to be computed + * @param logNormalize + * whether to use log normalization + * @param minSupport + * the minimum frequency of the feature in the entire corpus to be considered for inclusion in the + * sparse vector + * @param maxNGramSize + * 1 = unigram, 2 = unigram and bigram, 3 = unigram, bigram and trigram + * @param minLLRValue + * minValue of log likelihood ratio to used to prune ngrams + * @param chunkSizeInMegabytes + * the size in MB of the feature => id chunk to be kept in memory at each node during Map/Reduce + * stage. Its recommended you calculated this based on the number of cores and the free memory + * available to you per node. 
Say, you have 2 cores and around 1GB extra memory to spare we + * recommend you use a split size of around 400-500MB so that two simultaneous reducers can create + * partial vectors without thrashing the system due to increased swapping + */ + public static void createTermFrequencyVectors(Path input, + Path output, + String tfVectorsFolderName, + Configuration baseConf, + int minSupport, + int maxNGramSize, + float minLLRValue, + float normPower, + boolean logNormalize, + int numReducers, + int chunkSizeInMegabytes, + boolean sequentialAccess, + boolean namedVectors) + throws IOException, InterruptedException, ClassNotFoundException { + Preconditions.checkArgument(normPower == PartialVectorMerger.NO_NORMALIZING || normPower >= 0, + "If specified normPower must be nonnegative", normPower); + Preconditions.checkArgument(normPower == PartialVectorMerger.NO_NORMALIZING + || (normPower > 1 && !Double.isInfinite(normPower)) + || !logNormalize, + "normPower must be > 1 and not infinite if log normalization is chosen", normPower); + if (chunkSizeInMegabytes < MIN_CHUNKSIZE) { + chunkSizeInMegabytes = MIN_CHUNKSIZE; + } else if (chunkSizeInMegabytes > MAX_CHUNKSIZE) { // 10GB + chunkSizeInMegabytes = MAX_CHUNKSIZE; + } + if (minSupport < 0) { + minSupport = DEFAULT_MIN_SUPPORT; + } + + Path dictionaryJobPath = new Path(output, DICTIONARY_JOB_FOLDER); + + log.info("Creating dictionary from {} and saving at {}", input, dictionaryJobPath); + + int[] maxTermDimension = new int[1]; + List<Path> dictionaryChunks; + if (maxNGramSize == 1) { + startWordCounting(input, dictionaryJobPath, baseConf, minSupport); + dictionaryChunks = + createDictionaryChunks(dictionaryJobPath, output, baseConf, chunkSizeInMegabytes, maxTermDimension); + } else { + CollocDriver.generateAllGrams(input, dictionaryJobPath, baseConf, maxNGramSize, + minSupport, minLLRValue, numReducers); + dictionaryChunks = + createDictionaryChunks(new Path(new Path(output, DICTIONARY_JOB_FOLDER), + 
CollocDriver.NGRAM_OUTPUT_DIRECTORY), + output, + baseConf, + chunkSizeInMegabytes, + maxTermDimension); + } + + int partialVectorIndex = 0; + Collection<Path> partialVectorPaths = Lists.newArrayList(); + for (Path dictionaryChunk : dictionaryChunks) { + Path partialVectorOutputPath = new Path(output, VECTOR_OUTPUT_FOLDER + partialVectorIndex++); + partialVectorPaths.add(partialVectorOutputPath); + makePartialVectors(input, baseConf, maxNGramSize, dictionaryChunk, partialVectorOutputPath, + maxTermDimension[0], sequentialAccess, namedVectors, numReducers); + } + + Configuration conf = new Configuration(baseConf); + + Path outputDir = new Path(output, tfVectorsFolderName); + PartialVectorMerger.mergePartialVectors(partialVectorPaths, outputDir, conf, normPower, logNormalize, + maxTermDimension[0], sequentialAccess, namedVectors, numReducers); + HadoopUtil.delete(conf, partialVectorPaths); + } + + /** + * Read the feature frequency List which is built at the end of the Word Count Job and assign ids to them. 
+ * This will use constant memory and will run at the speed of your disk read + */ + private static List<Path> createDictionaryChunks(Path wordCountPath, + Path dictionaryPathBase, + Configuration baseConf, + int chunkSizeInMegabytes, + int[] maxTermDimension) throws IOException { + List<Path> chunkPaths = Lists.newArrayList(); + + Configuration conf = new Configuration(baseConf); + + FileSystem fs = FileSystem.get(wordCountPath.toUri(), conf); + + long chunkSizeLimit = chunkSizeInMegabytes * 1024L * 1024L; + int chunkIndex = 0; + Path chunkPath = new Path(dictionaryPathBase, DICTIONARY_FILE + chunkIndex); + chunkPaths.add(chunkPath); + + SequenceFile.Writer dictWriter = new SequenceFile.Writer(fs, conf, chunkPath, Text.class, IntWritable.class); + + try { + long currentChunkSize = 0; + Path filesPattern = new Path(wordCountPath, OUTPUT_FILES_PATTERN); + int i = 0; + for (Pair<Writable,Writable> record + : new SequenceFileDirIterable<>(filesPattern, PathType.GLOB, null, null, true, conf)) { + if (currentChunkSize > chunkSizeLimit) { + Closeables.close(dictWriter, false); + chunkIndex++; + + chunkPath = new Path(dictionaryPathBase, DICTIONARY_FILE + chunkIndex); + chunkPaths.add(chunkPath); + + dictWriter = new SequenceFile.Writer(fs, conf, chunkPath, Text.class, IntWritable.class); + currentChunkSize = 0; + } + + Writable key = record.getFirst(); + int fieldSize = DICTIONARY_BYTE_OVERHEAD + key.toString().length() * 2 + Integer.SIZE / 8; + currentChunkSize += fieldSize; + dictWriter.append(key, new IntWritable(i++)); + } + maxTermDimension[0] = i; + } finally { + Closeables.close(dictWriter, false); + } + + return chunkPaths; + } + + /** + * Create a partial vector using a chunk of features from the input documents. 
The input documents has to be + * in the {@link SequenceFile} format + * + * @param input + * input directory of the documents in {@link SequenceFile} format + * @param baseConf + * job configuration + * @param maxNGramSize + * maximum size of ngrams to generate + * @param dictionaryFilePath + * location of the chunk of features and the id's + * @param output + * output directory were the partial vectors have to be created + * @param dimension + * @param sequentialAccess + * output vectors should be optimized for sequential access + * @param namedVectors + * output vectors should be named, retaining key (doc id) as a label + * @param numReducers + * the desired number of reducer tasks + */ + private static void makePartialVectors(Path input, + Configuration baseConf, + int maxNGramSize, + Path dictionaryFilePath, + Path output, + int dimension, + boolean sequentialAccess, + boolean namedVectors, + int numReducers) + throws IOException, InterruptedException, ClassNotFoundException { + + Configuration conf = new Configuration(baseConf); + // this conf parameter needs to be set enable serialisation of conf values + conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization," + + "org.apache.hadoop.io.serializer.WritableSerialization"); + conf.setInt(PartialVectorMerger.DIMENSION, dimension); + conf.setBoolean(PartialVectorMerger.SEQUENTIAL_ACCESS, sequentialAccess); + conf.setBoolean(PartialVectorMerger.NAMED_VECTOR, namedVectors); + conf.setInt(MAX_NGRAMS, maxNGramSize); + DistributedCache.addCacheFile(dictionaryFilePath.toUri(), conf); + + Job job = new Job(conf); + job.setJobName("DictionaryVectorizer::MakePartialVectors: input-folder: " + input + + ", dictionary-file: " + dictionaryFilePath); + job.setJarByClass(DictionaryVectorizer.class); + + job.setMapOutputKeyClass(Text.class); + job.setMapOutputValueClass(StringTuple.class); + job.setOutputKeyClass(Text.class); + job.setOutputValueClass(VectorWritable.class); + 
FileInputFormat.setInputPaths(job, input); + + FileOutputFormat.setOutputPath(job, output); + + job.setMapperClass(Mapper.class); + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setReducerClass(TFPartialVectorReducer.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + job.setNumReduceTasks(numReducers); + + HadoopUtil.delete(conf, output); + + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + } + + /** + * Count the frequencies of words in parallel using Map/Reduce. The input documents have to be in + * {@link SequenceFile} format + */ + private static void startWordCounting(Path input, Path output, Configuration baseConf, int minSupport) + throws IOException, InterruptedException, ClassNotFoundException { + + Configuration conf = new Configuration(baseConf); + // this conf parameter needs to be set enable serialisation of conf values + conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization," + + "org.apache.hadoop.io.serializer.WritableSerialization"); + conf.setInt(MIN_SUPPORT, minSupport); + + Job job = new Job(conf); + + job.setJobName("DictionaryVectorizer::WordCount: input-folder: " + input); + job.setJarByClass(DictionaryVectorizer.class); + + job.setOutputKeyClass(Text.class); + job.setOutputValueClass(LongWritable.class); + + FileInputFormat.setInputPaths(job, input); + FileOutputFormat.setOutputPath(job, output); + + job.setMapperClass(TermCountMapper.class); + + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setCombinerClass(TermCountCombiner.class); + job.setReducerClass(TermCountReducer.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + HadoopUtil.delete(conf, output); + + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + } + + @Override + public int run(String[] args) throws Exception { + addInputOption(); + 
addOutputOption(); + addOption("tfDirName", "tf", "The folder to store the TF calculations", "tfDirName"); + addOption("minSupport", "s", "(Optional) Minimum Support. Default Value: 2", "2"); + addOption("maxNGramSize", "ng", "(Optional) The maximum size of ngrams to create" + + " (2 = bigrams, 3 = trigrams, etc) Default Value:1"); + addOption("minLLR", "ml", "(Optional)The minimum Log Likelihood Ratio(Float) Default is " + + LLRReducer.DEFAULT_MIN_LLR); + addOption("norm", "n", "The norm to use, expressed as either a float or \"INF\" " + + "if you want to use the Infinite norm. " + + "Must be greater or equal to 0. The default is not to normalize"); + addOption("logNormalize", "lnorm", "(Optional) Whether output vectors should be logNormalize. " + + "If set true else false", "false"); + addOption(DefaultOptionCreator.numReducersOption().create()); + addOption("chunkSize", "chunk", "The chunkSize in MegaBytes. 100-10000 MB", "100"); + addOption(DefaultOptionCreator.methodOption().create()); + addOption("namedVector", "nv", "(Optional) Whether output vectors should be NamedVectors. 
" + + "If set true else false", "false"); + if (parseArguments(args) == null) { + return -1; + } + String tfDirName = getOption("tfDirName", "tfDir"); + int minSupport = getInt("minSupport", 2); + int maxNGramSize = getInt("maxNGramSize", 1); + float minLLRValue = getFloat("minLLR", LLRReducer.DEFAULT_MIN_LLR); + float normPower = getFloat("norm", PartialVectorMerger.NO_NORMALIZING); + boolean logNormalize = hasOption("logNormalize"); + int numReducers = getInt(DefaultOptionCreator.MAX_REDUCERS_OPTION); + int chunkSizeInMegs = getInt("chunkSize", 100); + boolean sequential = hasOption("sequential"); + boolean namedVecs = hasOption("namedVectors"); + //TODO: add support for other paths + createTermFrequencyVectors(getInputPath(), getOutputPath(), + tfDirName, + getConf(), minSupport, maxNGramSize, minLLRValue, + normPower, logNormalize, numReducers, chunkSizeInMegs, sequential, namedVecs); + return 0; + } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new DictionaryVectorizer(), args); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/DocumentProcessor.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/DocumentProcessor.java b/mr/src/main/java/org/apache/mahout/vectorizer/DocumentProcessor.java new file mode 100644 index 0000000..2c3c236 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/vectorizer/DocumentProcessor.java @@ -0,0 +1,99 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.vectorizer;

import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.lucene.analysis.Analyzer;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.StringTuple;
import org.apache.mahout.vectorizer.document.SequenceFileTokenizerMapper;

/**
 * Tokenizes documents stored in a {@link org.apache.hadoop.io.SequenceFile} and emits one {@link StringTuple}
 * of tokens per document. Each input record must have a {@link Text} key holding the unique document
 * identifier and a {@link Text} value holding the whole document, encoded in UTF-8 so Hadoop can read it.
 * The configured Lucene {@link Analyzer} splits each document into
 * {@link org.apache.lucene.analysis.Token}s.
 */
public final class DocumentProcessor {

  public static final String TOKENIZED_DOCUMENT_OUTPUT_FOLDER = "tokenized-documents";
  public static final String ANALYZER_CLASS = "analyzer.class";

  /** Holder for static helpers only; never instantiated. */
  private DocumentProcessor() {
  }

  /**
   * Runs a map-only Hadoop job that turns every input document into a {@link StringTuple} token array.
   * {@code HadoopUtil.delete} is invoked on {@code output} before the job is started.
   *
   * @param input directory holding the documents in {@link org.apache.hadoop.io.SequenceFile} format
   * @param analyzerClass the Lucene {@link Analyzer} used to tokenize the UTF-8 text
   * @param output directory where the {@link StringTuple} token array of each document is written
   * @param baseConf configuration copied as the starting point for the job; the copy is modified, not this
   * @throws IllegalStateException if the job does not complete successfully
   */
  public static void tokenizeDocuments(Path input,
                                       Class<? extends Analyzer> analyzerClass,
                                       Path output,
                                       Configuration baseConf)
    throws IOException, InterruptedException, ClassNotFoundException {
    Configuration jobConf = new Configuration(baseConf);
    // Java serialization must be registered (next to Writable serialization) so conf values can be serialized.
    jobConf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
        + "org.apache.hadoop.io.serializer.WritableSerialization");
    jobConf.set(ANALYZER_CLASS, analyzerClass.getName());

    Job tokenizerJob = buildTokenizerJob(jobConf, input, output);
    HadoopUtil.delete(jobConf, output);

    if (!tokenizerJob.waitForCompletion(true)) {
      throw new IllegalStateException("Job failed!");
    }
  }

  /** Wires the map-only tokenizer job: SequenceFile input, Text/StringTuple SequenceFile output. */
  private static Job buildTokenizerJob(Configuration conf, Path input, Path output) throws IOException {
    Job job = new Job(conf);
    job.setJobName("DocumentProcessor::DocumentTokenizer: input-folder: " + input);
    job.setJarByClass(DocumentProcessor.class);

    job.setInputFormatClass(SequenceFileInputFormat.class);
    FileInputFormat.setInputPaths(job, input);
    job.setMapperClass(SequenceFileTokenizerMapper.class);
    job.setNumReduceTasks(0);

    job.setOutputKeyClass(Text.class);
    job.setOutputValueClass(StringTuple.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    FileOutputFormat.setOutputPath(job, output);
    return job;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/EncodedVectorsFromSequenceFiles.java
---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/EncodedVectorsFromSequenceFiles.java b/mr/src/main/java/org/apache/mahout/vectorizer/EncodedVectorsFromSequenceFiles.java new file mode 100644 index 0000000..1cf7ad7 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/vectorizer/EncodedVectorsFromSequenceFiles.java @@ -0,0 +1,104 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.vectorizer; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.ToolRunner; +import org.apache.lucene.analysis.Analyzer; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder; +import org.apache.mahout.vectorizer.encoders.LuceneTextValueEncoder; + +/** + * Converts a given set of sequence files into SparseVectors + */ +public final class EncodedVectorsFromSequenceFiles extends AbstractJob { + + public static void main(String[] args) throws Exception { + ToolRunner.run(new Configuration(), new EncodedVectorsFromSequenceFiles(), args); + } + + @Override + public int run(String[] args) throws Exception { + addInputOption(); + addOutputOption(); + addOption(DefaultOptionCreator.analyzerOption().create()); + addOption(buildOption("sequentialAccessVector", "seq", + "(Optional) Whether output vectors should be SequentialAccessVectors. " + + "If set true else false", + false, false, null)); + addOption(buildOption("namedVector", "nv", + "Create named vectors using the key. False by default", false, false, null)); + addOption("cardinality", "c", + "The cardinality to use for creating the vectors. Default is 5000", "5000"); + addOption("encoderFieldName", "en", + "The name of the encoder to be passed to the FeatureVectorEncoder constructor. Default is text. " + + "Note this is not the class name of a FeatureValueEncoder, but is instead the construction " + + "argument.", + "text"); + addOption("encoderClass", "ec", + "The class name of the encoder to be used. 
Default is " + LuceneTextValueEncoder.class.getName(), + LuceneTextValueEncoder.class.getName()); + addOption(DefaultOptionCreator.overwriteOption().create()); + if (parseArguments(args) == null) { + return -1; + } + + Path input = getInputPath(); + Path output = getOutputPath(); + + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(getConf(), output); + } + + Class<? extends Analyzer> analyzerClass = getAnalyzerClassFromOption(); + + Configuration conf = getConf(); + + boolean sequentialAccessOutput = hasOption("sequentialAccessVector"); + + boolean namedVectors = hasOption("namedVector"); + int cardinality = 5000; + if (hasOption("cardinality")) { + cardinality = Integer.parseInt(getOption("cardinality")); + } + String encoderName = "text"; + if (hasOption("encoderFieldName")) { + encoderName = getOption("encoderFieldName"); + } + String encoderClass = LuceneTextValueEncoder.class.getName(); + if (hasOption("encoderClass")) { + encoderClass = getOption("encoderClass"); + ClassUtils.instantiateAs(encoderClass, FeatureVectorEncoder.class, new Class[] { String.class }, + new Object[] { encoderName }); //try instantiating it + } + + SimpleTextEncodingVectorizer vectorizer = new SimpleTextEncodingVectorizer(); + VectorizerConfig config = new VectorizerConfig(conf, analyzerClass.getName(), encoderClass, encoderName, + sequentialAccessOutput, namedVectors, cardinality); + + vectorizer.createVectors(input, output, config); + + return 0; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/EncodingMapper.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/EncodingMapper.java b/mr/src/main/java/org/apache/mahout/vectorizer/EncodingMapper.java new file mode 100644 index 0000000..63ccea4 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/vectorizer/EncodingMapper.java @@ -0,0 +1,92 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.vectorizer;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.mahout.common.ClassUtils;
import org.apache.mahout.common.lucene.AnalyzerUtils;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.LuceneTextValueEncoder;

import java.io.IOException;

/**
 * The Mapper that does the work of encoding text: each (document id, document text) pair becomes a
 * sparse vector produced by the configured {@link FeatureVectorEncoder}.
 */
public class EncodingMapper extends Mapper<Text, Text, Text, VectorWritable> {

  public static final String USE_NAMED_VECTORS = "namedVectors";
  public static final String USE_SEQUENTIAL = "sequential";
  public static final String ANALYZER_NAME = "analyzer";
  public static final String ENCODER_FIELD_NAME = "encoderFieldName";
  public static final String ENCODER_CLASS = "encoderClass";
  public static final String CARDINALITY = "cardinality";
  private boolean sequentialVectors;
  private boolean namedVectors;
  private FeatureVectorEncoder encoder;
  private int cardinality;

  /**
   * Reads the vector layout options from the job configuration and builds the encoder.
   * Defaults: non-sequential, unnamed vectors; {@link StandardAnalyzer}; field name {@code "text"};
   * cardinality 5000; encoder class {@link LuceneTextValueEncoder}.
   *
   * @throws IOException if the configured analyzer class cannot be found
   */
  @Override
  protected void setup(Context context) throws IOException, InterruptedException {
    Configuration conf = context.getConfiguration();
    sequentialVectors = conf.getBoolean(USE_SEQUENTIAL, false);
    namedVectors = conf.getBoolean(USE_NAMED_VECTORS, false);
    String analyzerName = conf.get(ANALYZER_NAME, StandardAnalyzer.class.getName());
    Analyzer analyzer;
    try {
      analyzer = AnalyzerUtils.createAnalyzer(analyzerName);
    } catch (ClassNotFoundException e) {
      //TODO: hmmm, don't like this approach
      throw new IOException("Unable to create Analyzer for name: " + analyzerName, e);
    }

    String encoderName = conf.get(ENCODER_FIELD_NAME, "text");
    cardinality = conf.getInt(CARDINALITY, 5000);
    // Default the encoder class rather than passing a null class name to ClassUtils when unset.
    String encClass = conf.get(ENCODER_CLASS, LuceneTextValueEncoder.class.getName());
    encoder = ClassUtils.instantiateAs(encClass,
                                       FeatureVectorEncoder.class,
                                       new Class[]{String.class},
                                       new Object[]{encoderName});
    if (encoder instanceof LuceneTextValueEncoder) {
      ((LuceneTextValueEncoder) encoder).setAnalyzer(analyzer);
    }
  }

  /**
   * Encodes one document into a fresh vector of the configured cardinality and emits it under its key.
   */
  @Override
  protected void map(Text key, Text value, Context context) throws IOException, InterruptedException {
    Vector vector;
    if (sequentialVectors) {
      vector = new SequentialAccessSparseVector(cardinality);
    } else {
      vector = new RandomAccessSparseVector(cardinality);
    }
    if (namedVectors) {
      vector = new NamedVector(vector, key.toString());
    }
    encoder.addToVector(value.toString(), vector);
    // Emit the key as-is: Hadoop serializes the pair inside write(), so reusing the framework's Text is
    // safe, and it avoids a String round trip that would replace malformed UTF-8 bytes in the id.
    context.write(key, new VectorWritable(vector));
  }
}
