/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.driver;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import com.google.common.io.Closeables;
import org.apache.hadoop.util.ProgramDriver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * General-purpose driver class for Mahout programs.  Utilizes org.apache.hadoop.util.ProgramDriver to run
 * main methods of other classes, but first loads up default properties from a properties file.
 * <p/>
 * To run locally:
 *
 * <pre>$MAHOUT_HOME/bin/mahout run shortJobName [over-ride ops]</pre>
 * <p/>
 * Works like this: by default, the file "driver.classes.props" is loaded from the classpath, which
 * defines a mapping between short names like "vectordump" and fully qualified class names.
 * The format of driver.classes.props is like so:
 *
 * <pre>fully.qualified.class.name = shortJobName : descriptive string</pre>
 * <p/>
 * The default properties to be applied to the program run are pulled out of, by default,
 * "&lt;shortJobName&gt;.props" (also off of the classpath).  The format of the default properties files
 * is as follows:
 * <pre>
 i|input = /path/to/my/input
 o|output = /path/to/my/output
 m|jarFile = /path/to/jarFile
 # etc - each line is shortArg|longArg = value
 </pre>
 *
 * The next argument to the Driver is supposed to be the short name of the class to be run (as defined in the
 * driver.classes.props file).  The named class's main() is then invoked with the merged arguments, e.g.
 *
 * <pre>main(new String[] { "--input", "/path/to/my/input", "--output", "/path/to/my/output" });</pre>
 *
 * After all the "default" properties are loaded from the file, any further command-line arguments are taken in,
 * and over-ride the defaults.
 */
public final class MahoutDriver {

  private static final Logger log = LoggerFactory.getLogger(MahoutDriver.class);

  private MahoutDriver() {
  }

  public static void main(String[] args) throws Throwable {

    Properties mainClasses = loadProperties("driver.classes.props");
    if (mainClasses == null) {
      mainClasses = loadProperties("driver.classes.default.props");
    }
    if (mainClasses == null) {
      throw new IOException("Can't load any properties file?");
    }

    // Register every non-deprecated class with the Hadoop ProgramDriver, and remember
    // whether the requested short name is one of the known ones.
    boolean foundShortName = false;
    ProgramDriver programDriver = new ProgramDriver();
    for (Object key : mainClasses.keySet()) {
      String keyString = (String) key;
      if (args.length > 0 && shortName(mainClasses.getProperty(keyString)).equals(args[0])) {
        foundShortName = true;
      }
      if (args.length > 0 && keyString.equalsIgnoreCase(args[0]) && isDeprecated(mainClasses, keyString)) {
        // Invoking a deprecated class by its fully qualified name: print the message and stop.
        log.error(desc(mainClasses.getProperty(keyString)));
        return;
      }
      if (isDeprecated(mainClasses, keyString)) {
        continue;
      }
      addClass(programDriver, keyString, mainClasses.getProperty(keyString));
    }

    // No job name (or an explicit help flag): let ProgramDriver print the usage listing.
    if (args.length < 1 || args[0] == null || "-h".equals(args[0]) || "--help".equals(args[0])) {
      programDriver.driver(args);
      return;
    }

    String progName = args[0];
    if (!foundShortName) {
      // Unknown short name: treat the argument as a fully qualified class name.
      addClass(programDriver, progName, progName);
    }
    shift(args);

    Properties mainProps = loadProperties(progName + ".props");
    if (mainProps == null) {
      log.warn("No {}.props found on classpath, will use command-line arguments only", progName);
      mainProps = new Properties();
    }

    // Parse remaining command-line arguments into a map of flag -> values.
    // '-Dkey=value' style flags keep their value glued to the key; ordinary
    // '-flag v1 v2 ...' flags collect values until the next '-'-prefixed token.
    Map<String,String[]> argMap = new HashMap<>();
    int i = 0;
    while (i < args.length && args[i] != null) {
      List<String> argValues = new ArrayList<>();
      String arg = args[i];
      i++;
      if (arg.startsWith("-D")) { // '-Dkey=value' or '-Dkey=value1,value2,etc' case
        String[] argSplit = arg.split("=");
        arg = argSplit[0];
        if (argSplit.length == 2) {
          argValues.add(argSplit[1]);
        }
      } else { // '-key [values]' or '--key [values]' case.
        while (i < args.length && args[i] != null) {
          if (args[i].startsWith("-")) {
            break;
          }
          argValues.add(args[i]);
          i++;
        }
      }
      argMap.put(arg, argValues.toArray(new String[argValues.size()]));
    }

    // Add properties from the .props file that are not overridden on the command line.
    for (String key : mainProps.stringPropertyNames()) {
      String[] argNamePair = key.split("\\|");
      String shortArg = '-' + argNamePair[0].trim();
      String longArg = argNamePair.length < 2 ? null : "--" + argNamePair[1].trim();
      if (!argMap.containsKey(shortArg) && (longArg == null || !argMap.containsKey(longArg))) {
        // BUGFIX: previously this always inserted longArg, so a property key without a '|'
        // separator put a null key into the map, which caused a NullPointerException below
        // at arg.startsWith("-D").  Fall back to the short form when no long form exists.
        argMap.put(longArg == null ? shortArg : longArg, new String[] {mainProps.getProperty(key)});
      }
    }

    // Now assemble the final argument list: program name first, then all flags.
    // '-D' flags are inserted right after the program name so ToolRunner-style
    // parsing sees them before positional options.
    List<String> argsList = new ArrayList<>();
    argsList.add(progName);
    for (Map.Entry<String,String[]> entry : argMap.entrySet()) {
      String arg = entry.getKey();
      if (arg.startsWith("-D")) { // arg is -Dkey - if value for this !isEmpty(), then arg -> -Dkey + "=" + value
        String[] argValues = entry.getValue();
        if (argValues.length > 0 && !argValues[0].trim().isEmpty()) {
          arg += '=' + argValues[0].trim();
        }
        argsList.add(1, arg);
      } else {
        argsList.add(arg);
        for (String argValue : Arrays.asList(argMap.get(arg))) {
          if (!argValue.isEmpty()) {
            argsList.add(argValue);
          }
        }
      }
    }

    long start = System.currentTimeMillis();

    programDriver.driver(argsList.toArray(new String[argsList.size()]));

    if (log.isInfoEnabled()) {
      log.info("Program took {} ms (Minutes: {})", System.currentTimeMillis() - start,
          (System.currentTimeMillis() - start) / 60000.0);
    }
  }

  /** Returns true when the registered short name for this class is the literal "deprecated". */
  private static boolean isDeprecated(Properties mainClasses, String keyString) {
    return "deprecated".equalsIgnoreCase(shortName(mainClasses.getProperty(keyString)));
  }

  /**
   * Loads a properties resource from the context classpath.
   *
   * @param resource classpath-relative resource name
   * @return the loaded properties, or null if the resource does not exist
   * @throws IOException if the resource exists but cannot be read
   */
  private static Properties loadProperties(String resource) throws IOException {
    InputStream propsStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(resource);
    if (propsStream != null) {
      try {
        Properties properties = new Properties();
        properties.load(propsStream);
        return properties;
      } finally {
        Closeables.close(propsStream, true);
      }
    }
    return null;
  }

  /**
   * Shifts the argument array left by one IN PLACE, nulling the last slot.
   * The null terminator is what the parsing loops in main() stop on.
   */
  private static String[] shift(String[] args) {
    System.arraycopy(args, 1, args, 0, args.length - 1);
    args[args.length - 1] = null;
    return args;
  }

  /** Extracts the short job name from a "shortName : description" property value. */
  private static String shortName(String valueString) {
    return valueString.contains(":")
        ? valueString.substring(0, valueString.indexOf(':')).trim()
        : valueString;
  }

  /**
   * Extracts the description from a "shortName : description" property value.
   * NOTE(review): this keeps the leading ':' in the returned text — presumably
   * intentional for display, but worth confirming against the usage messages.
   */
  private static String desc(String valueString) {
    return valueString.contains(":") ? valueString.substring(valueString.indexOf(':')).trim() : valueString;
  }

  /** Registers a class with the driver, logging (not failing) if it cannot be loaded. */
  private static void addClass(ProgramDriver driver, String classString, String descString) {
    try {
      Class<?> clazz = Class.forName(classString);
      driver.addClass(shortName(descString), clazz, desc(descString));
    } catch (Throwable t) {
      log.warn("Unable to add class: {}", classString, t);
    }
  }

}
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.ep;

import java.io.Closeable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import com.google.common.collect.Lists;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;

/**
 * Allows evolutionary optimization where the state function can't be easily
 * packaged for the optimizer to execute.  A good example of this is with
 * on-line learning where optimizing the learning parameters is desirable.
 * We would like to pass training examples to the learning algorithms, but
 * we definitely want to do the training in multiple threads and then after
 * several training steps, we want to do a selection and mutation step.
 *
 * In such a case, it is highly desirable to leave most of the control flow
 * in the hands of our caller.  As such, this class provides three functions,
 * <ul>
 * <li> Storage of the evolutionary state.  The state variables have payloads
 * which can be anything that implements Payload.
 * <li> Threaded execution of a single operation on each of the members of the
 * population being evolved.  In the on-line learning example, this is used for
 * training all of the classifiers in the population.
 * <li> Propagating mutations of the most successful members of the population.
 * This propagation involves copying the state and the payload and then updating
 * the payload after mutation of the evolutionary state.
 * </ul>
 *
 * The State class that we use for storing the state of each member of the
 * population also provides parameter mapping.  Check out Mapping and State
 * for more info.
 *
 * @see Mapping
 * @see Payload
 * @see State
 *
 * @param <T> The payload class.
 */
public class EvolutionaryProcess<T extends Payload<U>, U> implements Writable, Closeable {
  // used to execute operations on the population in thread parallel.
  private ExecutorService pool;

  // threadCount is serialized so that we can reconstruct the thread pool
  private int threadCount;

  // list of members of the population
  private List<State<T, U>> population;

  // how big should the population be.  If this is changed, it will take effect
  // the next time the population is mutated.
  private int populationSize;

  public EvolutionaryProcess() {
    population = new ArrayList<>();
  }

  /**
   * Creates an evolutionary optimization framework with specified threadiness,
   * population size and initial state.
   * @param threadCount How many threads to use in parallelDo
   * @param populationSize How large a population to use
   * @param seed An initial population member
   */
  public EvolutionaryProcess(int threadCount, int populationSize, State<T, U> seed) {
    this.populationSize = populationSize;
    setThreadCount(threadCount);
    initializePopulation(populationSize, seed);
  }

  // NOTE(review): this produces populationSize + 1 members (seed plus
  // populationSize mutants); mutatePopulation later trims to populationSize.
  private void initializePopulation(int populationSize, State<T, U> seed) {
    population = Lists.newArrayList(seed);
    for (int i = 0; i < populationSize; i++) {
      population.add(seed.mutate());
    }
  }

  public void add(State<T, U> value) {
    population.add(value);
  }

  /**
   * Nuke all but a few of the current population and then repopulate with
   * variants of the survivors.
   * @param survivors How many survivors we want to keep.
   */
  public void mutatePopulation(int survivors) {
    // largest value first, oldest first in case of ties
    Collections.sort(population);

    // we copy here to avoid concurrent modification
    List<State<T, U>> parents = new ArrayList<>(population.subList(0, survivors));
    population.subList(survivors, population.size()).clear();

    // fill out the population with offspring from the survivors
    int i = 0;
    while (population.size() < populationSize) {
      population.add(parents.get(i % survivors).mutate());
      i++;
    }
  }

  /**
   * Execute an operation on all of the members of the population with many threads.  The
   * return value is taken as the current fitness of the corresponding member.
   * @param fn What to do on each member.  Gets payload and the mapped parameters as args.
   * @return The member of the population with the best fitness.
   * @throws InterruptedException Shouldn't happen.
   * @throws ExecutionException If fn throws an exception, that exception will be collected
   * and rethrown nested in an ExecutionException.
   */
  public State<T, U> parallelDo(final Function<Payload<U>> fn) throws InterruptedException, ExecutionException {
    Collection<Callable<State<T, U>>> tasks = new ArrayList<>();
    for (final State<T, U> state : population) {
      tasks.add(new Callable<State<T, U>>() {
        @Override
        public State<T, U> call() {
          double v = fn.apply(state.getPayload(), state.getMappedParams());
          state.setValue(v);
          return state;
        }
      });
    }

    List<Future<State<T, U>>> r = pool.invokeAll(tasks);

    // zip through the results and find the best one
    double max = Double.NEGATIVE_INFINITY;
    State<T, U> best = null;
    for (Future<State<T, U>> future : r) {
      State<T, U> s = future.get();
      double value = s.getValue();
      if (!Double.isNaN(value) && value >= max) {
        max = value;
        best = s;
      }
    }
    if (best == null) {
      // all fitness values were NaN; fall back to the first member
      // (assumes a non-empty population — an empty one throws IndexOutOfBoundsException)
      best = r.get(0).get();
    }

    return best;
  }

  /** Replaces the thread pool.  The previous pool (if any) is NOT shut down by this call. */
  public void setThreadCount(int threadCount) {
    this.threadCount = threadCount;
    pool = Executors.newFixedThreadPool(threadCount);
  }

  public int getThreadCount() {
    return threadCount;
  }

  public int getPopulationSize() {
    return populationSize;
  }

  public List<State<T, U>> getPopulation() {
    return population;
  }

  /**
   * Shuts down the internal thread pool, waiting up to 10 seconds for tasks to stop.
   * @throws IllegalStateException if tasks were still queued or the wait was interrupted.
   */
  @Override
  public void close() {
    List<Runnable> remainingTasks = pool.shutdownNow();
    try {
      pool.awaitTermination(10, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      // BUGFIX: restore the thread's interrupt status and preserve the cause instead of
      // silently swallowing the interrupt with an unrelated error message.
      Thread.currentThread().interrupt();
      throw new IllegalStateException(
          "Interrupted while awaiting termination of " + remainingTasks.size() + " tasks", e);
    }
    if (!remainingTasks.isEmpty()) {
      throw new IllegalStateException("Had to forcefully shut down " + remainingTasks.size() + " tasks");
    }
  }

  /** Operation applied to each population member by {@link #parallelDo}. */
  public interface Function<T> {
    double apply(T payload, double[] params);
  }

  @Override
  public void write(DataOutput out) throws IOException {
    out.writeInt(threadCount);
    out.writeInt(population.size());
    for (State<T, U> state : population) {
      PolymorphicWritable.write(out, state);
    }
  }

  @Override
  public void readFields(DataInput input) throws IOException {
    setThreadCount(input.readInt());
    int n = input.readInt();
    population = new ArrayList<>();
    for (int i = 0; i < n; i++) {
      State<T, U> state = (State<T, U>) PolymorphicWritable.read(input, State.class);
      population.add(state);
    }
  }
}
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.ep;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import com.google.common.base.Preconditions;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.math.function.DoubleFunction;

/**
 * Provides coordinate transformations so that evolution can proceed on the entire space of
 * reals but have the output limited and squished in convenient (and safe) ways.
 */
public abstract class Mapping extends DoubleFunction implements Writable {

  private Mapping() {
  }

  /** Sigmoid squashing of the reals into the open interval (min, max). */
  public static final class SoftLimit extends Mapping {
    private double min;
    private double max;
    private double scale;

    public SoftLimit() {
    }

    private SoftLimit(double min, double max, double scale) {
      this.min = min;
      this.max = max;
      this.scale = scale;
    }

    @Override
    public double apply(double v) {
      // logistic curve scaled into (min, max); apply(0) == (min + max) / 2
      return min + (max - min) * 1 / (1 + Math.exp(-v * scale));
    }

    @Override
    public void write(DataOutput out) throws IOException {
      out.writeDouble(min);
      out.writeDouble(max);
      out.writeDouble(scale);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
      min = in.readDouble();
      max = in.readDouble();
      scale = in.readDouble();
    }
  }

  /** Soft limit applied in log space, giving positive outputs in (low, high). */
  public static final class LogLimit extends Mapping {
    private Mapping wrapped;

    public LogLimit() {
    }

    private LogLimit(double low, double high) {
      wrapped = softLimit(Math.log(low), Math.log(high));
    }

    @Override
    public double apply(double v) {
      return Math.exp(wrapped.apply(v));
    }

    @Override
    public void write(DataOutput dataOutput) throws IOException {
      PolymorphicWritable.write(dataOutput, wrapped);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
      wrapped = PolymorphicWritable.read(in, Mapping.class);
    }
  }

  /** Maps the reals onto (0, inf) via exp(v * scale). */
  public static final class Exponential extends Mapping {
    private double scale;

    public Exponential() {
    }

    private Exponential(double scale) {
      this.scale = scale;
    }

    @Override
    public double apply(double v) {
      return Math.exp(v * scale);
    }

    @Override
    public void write(DataOutput out) throws IOException {
      out.writeDouble(scale);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
      scale = in.readDouble();
    }
  }

  /** The identity transform: parameters are passed through unchanged. */
  public static final class Identity extends Mapping {
    @Override
    public double apply(double v) {
      return v;
    }

    @Override
    public void write(DataOutput dataOutput) {
      // stateless
    }

    @Override
    public void readFields(DataInput dataInput) {
      // stateless
    }
  }

  /**
   * Maps input to the open interval (min, max) with 0 going to the mean of min and
   * max.  When scale is large, a larger proportion of values are mapped to points
   * near the boundaries.  When scale is small, a larger proportion of values are mapped to
   * points well within the boundaries.
   * @param min The largest lower bound on values to be returned.
   * @param max The least upper bound on values to be returned.
   * @param scale Defines how sharp the boundaries are.
   * @return A mapping that satisfies the desired constraint.
   */
  public static Mapping softLimit(double min, double max, double scale) {
    return new SoftLimit(min, max, scale);
  }

  /**
   * Maps input to the open interval (min, max) with 0 going to the mean of min and
   * max.  When scale is large, a larger proportion of values are mapped to points
   * near the boundaries.
   * @see #softLimit(double, double, double)
   * @param min The largest lower bound on values to be returned.
   * @param max The least upper bound on values to be returned.
   * @return A mapping that satisfies the desired constraint.
   */
  public static Mapping softLimit(double min, double max) {
    return softLimit(min, max, 1);
  }

  /**
   * Maps input to positive values in the open interval (min, max) with
   * 0 going to the geometric mean.  Near the geometric mean, values are
   * distributed roughly geometrically.
   * @param low The largest lower bound for output results.  Must be &gt;0.
   * @param high The least upper bound for output results.  Must be &gt;0.
   * @return A mapping that satisfies the desired constraint.
   */
  public static Mapping logLimit(double low, double high) {
    // BUGFIX: Guava's Preconditions message templates only support the %s placeholder;
    // %f is not substituted and produces a garbled error message.
    Preconditions.checkArgument(low > 0, "Lower bound for log limit must be > 0 but was %s", low);
    Preconditions.checkArgument(high > 0, "Upper bound for log limit must be > 0 but was %s", high);
    return new LogLimit(low, high);
  }

  /**
   * Maps results to positive values.
   * @return A mapping to positive values.
   */
  public static Mapping exponential() {
    return exponential(1);
  }

  /**
   * Maps results to positive values.
   * @param scale If large, then large values are more likely.
   * @return A mapping to positive values.
   */
  public static Mapping exponential(double scale) {
    return new Exponential(scale);
  }

  /**
   * Maps results to themselves.
   * @return The identity mapping.
   */
  public static Mapping identity() {
    return new Identity();
  }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.ep; + +import org.apache.hadoop.io.Writable; + +/** + * Payloads for evolutionary state must be copyable and updatable. The copy should be a deep copy + * unless some aspect of the state is sharable or immutable. + * <p/> + * During mutation, a copy is first made and then after the parameters in the State structure are + * suitably modified, update is called with the scaled versions of the parameters. + * + * @param <T> + * @see State + */ +public interface Payload<T> extends Writable { + Payload<T> copy(); + + void update(double[] params); +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/ep/State.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/ep/State.java b/community/mahout-mr/src/main/java/org/apache/mahout/ep/State.java new file mode 100644 index 0000000..7a0fb5e --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/ep/State.java @@ -0,0 +1,302 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.ep; + +import com.google.common.collect.Lists; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.sgd.PolymorphicWritable; +import org.apache.mahout.common.RandomUtils; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Locale; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Records evolutionary state and provides a mutation operation for recorded-step meta-mutation. + * + * You provide the payload, this class provides the mutation operations. During mutation, + * the payload is copied and after the state variables are changed, they are passed to the + * payload. + * + * Parameters are internally mutated in a state space that spans all of R^n, but parameters + * passed to the payload are transformed as specified by a call to setMap(). The default + * mapping is the identity map, but uniform-ish or exponential-ish coverage of a range are + * also supported. + * + * More information on the underlying algorithm can be found in the following paper + * + * http://arxiv.org/abs/0803.3838 + * + * @see Mapping + */ +public class State<T extends Payload<U>, U> implements Comparable<State<T, U>>, Writable { + + // object count is kept to break ties in comparison. 
+ private static final AtomicInteger OBJECT_COUNT = new AtomicInteger(); + + private int id = OBJECT_COUNT.getAndIncrement(); + private Random gen = RandomUtils.getRandom(); + // current state + private double[] params; + // mappers to transform state + private Mapping[] maps; + // omni-directional mutation + private double omni; + // directional mutation + private double[] step; + // current fitness value + private double value; + private T payload; + + public State() { + } + + /** + * Invent a new state with no momentum (yet). + */ + public State(double[] x0, double omni) { + params = Arrays.copyOf(x0, x0.length); + this.omni = omni; + step = new double[params.length]; + maps = new Mapping[params.length]; + } + + /** + * Deep copies a state, useful in mutation. + */ + public State<T, U> copy() { + State<T, U> r = new State<>(); + r.params = Arrays.copyOf(this.params, this.params.length); + r.omni = this.omni; + r.step = Arrays.copyOf(this.step, this.step.length); + r.maps = Arrays.copyOf(this.maps, this.maps.length); + if (this.payload != null) { + r.payload = (T) this.payload.copy(); + } + r.gen = this.gen; + return r; + } + + /** + * Clones this state with a random change in position. Copies the payload and + * lets it know about the change. + * + * @return A new state. + */ + public State<T, U> mutate() { + double sum = 0; + for (double v : step) { + sum += v * v; + } + sum = Math.sqrt(sum); + double lambda = 1 + gen.nextGaussian(); + + State<T, U> r = this.copy(); + double magnitude = 0.9 * omni + sum / 10; + r.omni = magnitude * -Math.log1p(-gen.nextDouble()); + for (int i = 0; i < step.length; i++) { + r.step[i] = lambda * step[i] + r.omni * gen.nextGaussian(); + r.params[i] += r.step[i]; + } + if (this.payload != null) { + r.payload.update(r.getMappedParams()); + } + return r; + } + + /** + * Defines the transformation for a parameter. + * @param i Which parameter's mapping to define. + * @param m The mapping to use. 
+ * @see org.apache.mahout.ep.Mapping + */ + public void setMap(int i, Mapping m) { + maps[i] = m; + } + + /** + * Returns a transformed parameter. + * @param i The parameter to return. + * @return The value of the parameter. + */ + public double get(int i) { + Mapping m = maps[i]; + return m == null ? params[i] : m.apply(params[i]); + } + + public int getId() { + return id; + } + + public double[] getParams() { + return params; + } + + public Mapping[] getMaps() { + return maps; + } + + /** + * Returns all the parameters in mapped form. + * @return An array of parameters. + */ + public double[] getMappedParams() { + double[] r = Arrays.copyOf(params, params.length); + for (int i = 0; i < params.length; i++) { + r[i] = get(i); + } + return r; + } + + public double getOmni() { + return omni; + } + + public double[] getStep() { + return step; + } + + public T getPayload() { + return payload; + } + + public double getValue() { + return value; + } + + public void setOmni(double omni) { + this.omni = omni; + } + + public void setId(int id) { + this.id = id; + } + + public void setStep(double[] step) { + this.step = step; + } + + public void setMaps(Mapping[] maps) { + this.maps = maps; + } + + public void setMaps(Iterable<Mapping> maps) { + Collection<Mapping> list = Lists.newArrayList(maps); + this.maps = list.toArray(new Mapping[list.size()]); + } + + public void setValue(double v) { + value = v; + } + + public void setPayload(T payload) { + this.payload = payload; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof State)) { + return false; + } + State<?,?> other = (State<?,?>) o; + return id == other.id && value == other.value; + } + + @Override + public int hashCode() { + return RandomUtils.hashDouble(value) ^ id; + } + + /** + * Natural order is to sort in descending order of score. Creation order is used as a + * tie-breaker. + * + * @param other The state to compare with. 
  /**
   * Natural order is to sort in descending order of score. Creation order is used as a
   * tie-breaker.
   *
   * @param other The state to compare with.
   * @return -1, 0, 1 if the other state is better, identical or worse than this one.
   */
  @Override
  public int compareTo(State<T, U> other) {
    // descending by value: higher fitness sorts first
    int r = Double.compare(other.value, this.value);
    if (r != 0) {
      return r;
    }
    // tie-break by creation order (ascending id)
    if (this.id < other.id) {
      return -1;
    }
    if (this.id > other.id) {
      return 1;
    }
    return 0;
  }

  @Override
  public String toString() {
    // report total mutation magnitude as omni + |step|
    double sum = 0;
    for (double v : step) {
      sum += v * v;
    }
    return String.format(Locale.ENGLISH, "<S/%s %.3f %.3f>", payload, omni + Math.sqrt(sum), value);
  }

  /**
   * Wire format: id, n = params.length, n params, n mappings, omni, n steps,
   * value, payload. {@link #readFields(DataInput)} must read in this order.
   */
  @Override
  public void write(DataOutput out) throws IOException {
    out.writeInt(id);
    out.writeInt(params.length);
    for (double v : params) {
      out.writeDouble(v);
    }
    for (Mapping map : maps) {
      PolymorphicWritable.write(out, map);
    }

    out.writeDouble(omni);
    for (double v : step) {
      out.writeDouble(v);
    }

    out.writeDouble(value);
    PolymorphicWritable.write(out, payload);
  }

  /** Mirrors {@link #write(DataOutput)} field-for-field, in the same order. */
  @Override
  public void readFields(DataInput input) throws IOException {
    id = input.readInt();
    int n = input.readInt();
    params = new double[n];
    for (int i = 0; i < n; i++) {
      params[i] = input.readDouble();
    }

    maps = new Mapping[n];
    for (int i = 0; i < n; i++) {
      maps[i] = PolymorphicWritable.read(input, Mapping.class);
    }
    omni = input.readDouble();
    step = new double[n];
    for (int i = 0; i < n; i++) {
      step[i] = input.readDouble();
    }
    value = input.readDouble();
    // unchecked: data written by write() stored a Payload, so the cast is
    // safe for well-formed input; corrupt input can still fail at runtime
    payload = (T) PolymorphicWritable.read(input, Payload.class);
  }
}
b/community/mahout-mr/src/main/java/org/apache/mahout/ep/package-info.java @@ -0,0 +1,26 @@ +/** + * <p>Provides basic evolutionary optimization using <a href="http://arxiv.org/abs/0803.3838">recorded-step</a> + * mutation.</p> + * + * <p>With this style of optimization, we can optimize a function {@code f: R^n -> R} by stochastic + * hill-climbing with some of the benefits of conjugate gradient style history encoded in the mutation function. + * This mutation function will adapt to allow weakly directed search rather than using the somewhat more + * conventional symmetric Gaussian.</p> + * + * <p>With recorded-step mutation, the meta-mutation parameters are all auto-encoded in the current state of each point. + * This avoids the classic problem of having more mutation rate parameters than are in the original state and then + * requiring even more parameters to describe the meta-mutation rate. Instead, we store the previous point and one + * omni-directional mutation component. Mutation is performed by first mutating along the line formed by the previous + * and current points and then adding a scaled symmetric Gaussian. The magnitude of the omni-directional mutation is + * then mutated using itself as a scale.</p> + * + * <p>Because it is convenient to not restrict the parameter space, this package also provides convenient parameter + * mapping methods. These mapping methods map the set of reals to a finite open interval (a,b) in such a way that + * {@code lim_{x->-\inf} f(x) = a} and {@code lim_{x->\inf} f(x) = b}. The linear mapping is defined so that + * {@code f(0) = (a+b)/2} and the exponential mapping requires that a and b are both positive and has + * {@code f(0) = sqrt(ab)}. The linear mapping is useful for values that must stay roughly within a range but + * which are roughly uniform within the center of that range. 
The exponential + * mapping is useful for values that must stay within a range but whose distribution is roughly exponential near + * geometric mean of the end-points. An identity mapping is also supplied.</p> + */ +package org.apache.mahout.ep; http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java new file mode 100644 index 0000000..6618a1a --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.math; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; + +import java.io.IOException; + +public final class DistributedRowMatrixWriter { + + private DistributedRowMatrixWriter() { + } + + public static void write(Path outputDir, Configuration conf, Iterable<MatrixSlice> matrix) throws IOException { + FileSystem fs = outputDir.getFileSystem(conf); + SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outputDir, + IntWritable.class, VectorWritable.class); + IntWritable topic = new IntWritable(); + VectorWritable vector = new VectorWritable(); + for (MatrixSlice slice : matrix) { + topic.set(slice.index()); + vector.set(slice.vector()); + writer.append(topic, vector); + } + writer.close(); + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/MatrixUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/MatrixUtils.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/MatrixUtils.java new file mode 100644 index 0000000..f9ca52e --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/MatrixUtils.java @@ -0,0 +1,114 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.math; + +import com.google.common.collect.Lists; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.map.OpenObjectIntHashMap; + +import java.io.IOException; +import java.util.List; + +public final class MatrixUtils { + + private MatrixUtils() { + } + + public static void write(Path outputDir, Configuration conf, VectorIterable matrix) + throws IOException { + FileSystem fs = outputDir.getFileSystem(conf); + fs.delete(outputDir, true); + SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outputDir, + IntWritable.class, VectorWritable.class); + IntWritable topic = new IntWritable(); + VectorWritable vector = new VectorWritable(); + for (MatrixSlice slice : matrix) { + topic.set(slice.index()); + vector.set(slice.vector()); + writer.append(topic, vector); + } + writer.close(); + } + + public static Matrix read(Configuration conf, Path... 
modelPaths) throws IOException { + int numRows = -1; + int numCols = -1; + boolean sparse = false; + List<Pair<Integer, Vector>> rows = Lists.newArrayList(); + for (Path modelPath : modelPaths) { + for (Pair<IntWritable, VectorWritable> row + : new SequenceFileIterable<IntWritable, VectorWritable>(modelPath, true, conf)) { + rows.add(Pair.of(row.getFirst().get(), row.getSecond().get())); + numRows = Math.max(numRows, row.getFirst().get()); + sparse = !row.getSecond().get().isDense(); + if (numCols < 0) { + numCols = row.getSecond().get().size(); + } + } + } + if (rows.isEmpty()) { + throw new IOException(Arrays.toString(modelPaths) + " have no vectors in it"); + } + numRows++; + Vector[] arrayOfRows = new Vector[numRows]; + for (Pair<Integer, Vector> pair : rows) { + arrayOfRows[pair.getFirst()] = pair.getSecond(); + } + Matrix matrix; + if (sparse) { + matrix = new SparseRowMatrix(numRows, numCols, arrayOfRows); + } else { + matrix = new DenseMatrix(numRows, numCols); + for (int i = 0; i < numRows; i++) { + matrix.assignRow(i, arrayOfRows[i]); + } + } + return matrix; + } + + public static OpenObjectIntHashMap<String> readDictionary(Configuration conf, Path... 
dictPath) { + OpenObjectIntHashMap<String> dictionary = new OpenObjectIntHashMap<>(); + for (Path dictionaryFile : dictPath) { + for (Pair<Writable, IntWritable> record + : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) { + dictionary.put(record.getFirst().toString(), record.getSecond().get()); + } + } + return dictionary; + } + + public static String[] invertDictionary(OpenObjectIntHashMap<String> termIdMap) { + int maxTermId = -1; + for (String term : termIdMap.keys()) { + maxTermId = Math.max(maxTermId, termIdMap.get(term)); + } + maxTermId++; + String[] dictionary = new String[maxTermId]; + for (String term : termIdMap.keys()) { + dictionary[termIdMap.get(term)] = term; + } + return dictionary; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java new file mode 100644 index 0000000..0c45c9a --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/**
 * Writable to handle serialization of a vector and a variable list of
 * associated label indexes.
 */
public final class MultiLabelVectorWritable implements Writable {

  private final VectorWritable vectorWritable = new VectorWritable();
  // label indexes associated with the vector; null until set or read
  private int[] labels;

  public MultiLabelVectorWritable() {
  }

  public MultiLabelVectorWritable(Vector vector, int[] labels) {
    this.vectorWritable.set(vector);
    this.labels = labels;
  }

  public Vector getVector() {
    return vectorWritable.get();
  }

  public void setVector(Vector vector) {
    vectorWritable.set(vector);
  }

  public void setLabels(int[] labels) {
    this.labels = labels;
  }

  // NOTE(review): returns the internal array, not a copy.
  public int[] getLabels() {
    return labels;
  }

  /** Reads the vector, then a label count, then that many label ints. */
  @Override
  public void readFields(DataInput in) throws IOException {
    vectorWritable.readFields(in);
    int labelSize = in.readInt();
    labels = new int[labelSize];
    for (int i = 0; i < labelSize; i++) {
      labels[i] = in.readInt();
    }
  }

  /**
   * Writes the vector, then the label count, then the labels — the mirror of
   * {@link #readFields(DataInput)}.
   * NOTE(review): throws NullPointerException if labels was never set.
   */
  @Override
  public void write(DataOutput out) throws IOException {
    vectorWritable.write(out);
    out.writeInt(labels.length);
    for (int label : labels) {
      out.writeInt(label);
    }
  }

  /** Convenience factory: reads a complete instance from {@code in}. */
  public static MultiLabelVectorWritable read(DataInput in) throws IOException {
    MultiLabelVectorWritable writable = new MultiLabelVectorWritable();
    writable.readFields(in);
    return writable;
  }

  /** Convenience: serializes the given vector/labels pair to {@code out}. */
  public static void write(DataOutput out, SequentialAccessSparseVector ssv, int[] labels) throws IOException {
    new MultiLabelVectorWritable(ssv, labels).write(out);
  }

}
/**
 * See
 * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
 * this paper.</a>
 */
public final class AlternatingLeastSquaresSolver {

  private AlternatingLeastSquaresSolver() {}

  //TODO make feature vectors a simple array
  /**
   * Solves the regularized least-squares system for one row of the ALS
   * factorization: (MiIi * t(MiIi) + lambda * nui * E) ui = MiIi * t(R(i,Ii)).
   *
   * NOTE(review): featureVectors is traversed several times below (isEmpty,
   * size, createMiIi); a one-shot Iterable would break this — the existing
   * TODO suggests replacing it with an array.
   *
   * @param featureVectors feature vectors of the items (or users) rated
   * @param ratingVector the ratings; must have one non-default element per feature vector
   * @param lambda regularization weight
   * @param numFeatures number of latent features
   * @return the solved feature vector for this row
   */
  public static Vector solve(Iterable<Vector> featureVectors, Vector ratingVector, double lambda, int numFeatures) {

    Preconditions.checkNotNull(featureVectors, "Feature Vectors cannot be null");
    Preconditions.checkArgument(!Iterables.isEmpty(featureVectors));
    Preconditions.checkNotNull(ratingVector, "Rating Vector cannot be null");
    Preconditions.checkArgument(ratingVector.getNumNondefaultElements() > 0, "Rating Vector cannot be empty");
    Preconditions.checkArgument(Iterables.size(featureVectors) == ratingVector.getNumNondefaultElements());

    // nui = number of ratings in this row; scales the regularization term
    int nui = ratingVector.getNumNondefaultElements();

    Matrix MiIi = createMiIi(featureVectors, numFeatures);
    Matrix RiIiMaybeTransposed = createRiIiMaybeTransposed(ratingVector);

    /* compute Ai = MiIi * t(MiIi) + lambda * nui * E */
    Matrix Ai = miTimesMiTransposePlusLambdaTimesNuiTimesE(MiIi, lambda, nui);
    /* compute Vi = MiIi * t(R(i,Ii)) */
    Matrix Vi = MiIi.times(RiIiMaybeTransposed);
    /* compute Ai * ui = Vi */
    return solve(Ai, Vi);
  }

  // solves the linear system Ai * x = Vi via QR decomposition
  private static Vector solve(Matrix Ai, Matrix Vi) {
    return new QRDecomposition(Ai).solve(Vi).viewColumn(0);
  }

  /** Adds lambda * nui to the diagonal of a square matrix, in place. */
  static Matrix addLambdaTimesNuiTimesE(Matrix matrix, double lambda, int nui) {
    Preconditions.checkArgument(matrix.numCols() == matrix.numRows(), "Must be a Square Matrix");
    double lambdaTimesNui = lambda * nui;
    int numCols = matrix.numCols();
    for (int n = 0; n < numCols; n++) {
      matrix.setQuick(n, n, matrix.getQuick(n, n) + lambdaTimesNui);
    }
    return matrix;
  }

  // computes MiIi * t(MiIi) + lambda * nui * E in one pass: symmetric dot
  // products between rows fill both triangles, regularization goes on the
  // diagonal
  private static Matrix miTimesMiTransposePlusLambdaTimesNuiTimesE(Matrix MiIi, double lambda, int nui) {

    double lambdaTimesNui = lambda * nui;
    int rows = MiIi.numRows();

    double[][] result = new double[rows][rows];

    for (int i = 0; i < rows; i++) {
      for (int j = i; j < rows; j++) {
        double dot = MiIi.viewRow(i).dot(MiIi.viewRow(j));
        if (i != j) {
          result[i][j] = dot;
          result[j][i] = dot;
        } else {
          result[i][i] = dot + lambdaTimesNui;
        }
      }
    }
    return new DenseMatrix(result, true);
  }


  /** Packs the feature vectors column-wise into a numFeatures x |vectors| dense matrix. */
  static Matrix createMiIi(Iterable<Vector> featureVectors, int numFeatures) {
    double[][] MiIi = new double[numFeatures][Iterables.size(featureVectors)];
    int n = 0;
    for (Vector featureVector : featureVectors) {
      for (int m = 0; m < numFeatures; m++) {
        MiIi[m][n] = featureVector.getQuick(m);
      }
      n++;
    }
    return new DenseMatrix(MiIi, true);
  }

  /** Column matrix of the non-zero ratings, taken in sequential index order. */
  static Matrix createRiIiMaybeTransposed(Vector ratingVector) {
    Preconditions.checkArgument(ratingVector.isSequentialAccess(), "Ratings should be iterable in Index or Sequential Order");

    double[][] RiIiMaybeTransposed = new double[ratingVector.getNumNondefaultElements()][1];
    int index = 0;
    for (Vector.Element elem : ratingVector.nonZeroes()) {
      RiIiMaybeTransposed[index++][0] = elem.get();
    }
    return new DenseMatrix(RiIiMaybeTransposed, true);
  }
}
index 0000000..5d77898 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/als/ImplicitFeedbackAlternatingLeastSquaresSolver.java @@ -0,0 +1,171 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.math.als; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.QRDecomposition; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.list.IntArrayList; +import org.apache.mahout.math.map.OpenIntObjectHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; + +/** see <a href="http://research.yahoo.com/pub/2433">Collaborative Filtering for Implicit Feedback Datasets</a> */ +public class ImplicitFeedbackAlternatingLeastSquaresSolver { + + private final int numFeatures; + private final double alpha; + private final double lambda; + private final int numTrainingThreads; + + 
private final OpenIntObjectHashMap<Vector> Y; + private final Matrix YtransposeY; + + private static final Logger log = LoggerFactory.getLogger(ImplicitFeedbackAlternatingLeastSquaresSolver.class); + + public ImplicitFeedbackAlternatingLeastSquaresSolver(int numFeatures, double lambda, double alpha, + OpenIntObjectHashMap<Vector> Y, int numTrainingThreads) { + this.numFeatures = numFeatures; + this.lambda = lambda; + this.alpha = alpha; + this.Y = Y; + this.numTrainingThreads = numTrainingThreads; + YtransposeY = getYtransposeY(Y); + } + + public Vector solve(Vector ratings) { + return solve(YtransposeY.plus(getYtransponseCuMinusIYPlusLambdaI(ratings)), getYtransponseCuPu(ratings)); + } + + private static Vector solve(Matrix A, Matrix y) { + return new QRDecomposition(A).solve(y).viewColumn(0); + } + + double confidence(double rating) { + return 1 + alpha * rating; + } + + /* Y' Y */ + public Matrix getYtransposeY(final OpenIntObjectHashMap<Vector> Y) { + + ExecutorService queue = Executors.newFixedThreadPool(numTrainingThreads); + if (log.isInfoEnabled()) { + log.info("Starting the computation of Y'Y"); + } + long startTime = System.nanoTime(); + final IntArrayList indexes = Y.keys(); + final int numIndexes = indexes.size(); + + final double[][] YtY = new double[numFeatures][numFeatures]; + + // Compute Y'Y by dot products between the 'columns' of Y + for (int i = 0; i < numFeatures; i++) { + for (int j = i; j < numFeatures; j++) { + + final int ii = i; + final int jj = j; + queue.execute(new Runnable() { + @Override + public void run() { + double dot = 0; + for (int k = 0; k < numIndexes; k++) { + Vector row = Y.get(indexes.getQuick(k)); + dot += row.getQuick(ii) * row.getQuick(jj); + } + YtY[ii][jj] = dot; + if (ii != jj) { + YtY[jj][ii] = dot; + } + } + }); + + } + } + queue.shutdown(); + try { + queue.awaitTermination(1, TimeUnit.DAYS); + } catch (InterruptedException e) { + log.error("Error during Y'Y queue shutdown", e); + throw new RuntimeException("Error 
during Y'Y queue shutdown"); + } + if (log.isInfoEnabled()) { + log.info("Computed Y'Y in " + (System.nanoTime() - startTime) / 1000000.0 + " ms" ); + } + return new DenseMatrix(YtY, true); + } + + /** Y' (Cu - I) Y + λ I */ + private Matrix getYtransponseCuMinusIYPlusLambdaI(Vector userRatings) { + Preconditions.checkArgument(userRatings.isSequentialAccess(), "need sequential access to ratings!"); + + /* (Cu -I) Y */ + OpenIntObjectHashMap<Vector> CuMinusIY = new OpenIntObjectHashMap<>(userRatings.getNumNondefaultElements()); + for (Element e : userRatings.nonZeroes()) { + CuMinusIY.put(e.index(), Y.get(e.index()).times(confidence(e.get()) - 1)); + } + + Matrix YtransponseCuMinusIY = new DenseMatrix(numFeatures, numFeatures); + + /* Y' (Cu -I) Y by outer products */ + for (Element e : userRatings.nonZeroes()) { + for (Element feature : Y.get(e.index()).all()) { + Vector partial = CuMinusIY.get(e.index()).times(feature.get()); + YtransponseCuMinusIY.viewRow(feature.index()).assign(partial, Functions.PLUS); + } + } + + /* Y' (Cu - I) Y + λ I add lambda on the diagonal */ + for (int feature = 0; feature < numFeatures; feature++) { + YtransponseCuMinusIY.setQuick(feature, feature, YtransponseCuMinusIY.getQuick(feature, feature) + lambda); + } + + return YtransponseCuMinusIY; + } + + /** Y' Cu p(u) */ + private Matrix getYtransponseCuPu(Vector userRatings) { + Preconditions.checkArgument(userRatings.isSequentialAccess(), "need sequential access to ratings!"); + + Vector YtransponseCuPu = new DenseVector(numFeatures); + + for (Element e : userRatings.nonZeroes()) { + YtransponseCuPu.assign(Y.get(e.index()).times(confidence(e.get())), Functions.PLUS); + } + + return columnVectorAsMatrix(YtransponseCuPu); + } + + private Matrix columnVectorAsMatrix(Vector v) { + double[][] matrix = new double[numFeatures][1]; + for (Element e : v.all()) { + matrix[e.index()][0] = e.get(); + } + return new DenseMatrix(matrix, true); + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java new file mode 100644 index 0000000..0233848 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/AsyncEigenVerifier.java @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Eigen verifier that runs the (potentially expensive) verification on a
 * single background thread. {@link #verify(VectorIterable, Vector)} starts an
 * asynchronous check and immediately returns the most recent status, so
 * callers poll rather than block.
 */
public class AsyncEigenVerifier extends SimpleEigenVerifier implements Closeable {

  private final ExecutorService threadPool;
  // status of the last completed verification (or the in-progress placeholder)
  private EigenStatus status;
  // true once a background verification has completed and not yet been observed
  private boolean finished;
  // true while a background verification is queued or running
  private boolean started;

  public AsyncEigenVerifier() {
    // single worker: at most one verification is in flight at a time
    threadPool = Executors.newFixedThreadPool(1);
    status = new EigenStatus(-1, 0);
  }

  /**
   * Kicks off a background verification if none is running, then returns the
   * current status — which may still be the previous result or the
   * in-progress placeholder.
   */
  @Override
  public synchronized EigenStatus verify(VectorIterable corpus, Vector vector) {
    if (!finished && !started) { // not yet started or finished, so start!
      status = new EigenStatus(-1, 0);
      // clone so the caller can keep mutating its vector while we verify
      Vector vectorCopy = vector.clone();
      threadPool.execute(new VerifierRunnable(corpus, vectorCopy));
      started = true;
    }
    if (finished) {
      // consume the completion flag so the NEXT call starts a fresh check
      finished = false;
    }
    return status;
  }

  @Override
  public void close() {
    this.threadPool.shutdownNow();
  }

  // hook for the worker; delegates to the synchronous superclass verification
  protected EigenStatus innerVerify(VectorIterable corpus, Vector vector) {
    return super.verify(corpus, vector);
  }

  private class VerifierRunnable implements Runnable {
    private final VectorIterable corpus;
    private final Vector vector;

    protected VerifierRunnable(VectorIterable corpus, Vector vector) {
      this.corpus = corpus;
      this.vector = vector;
    }

    @Override
    public void run() {
      EigenStatus status = innerVerify(corpus, vector);
      // publish under the outer instance's lock so verify() always sees a
      // consistent (status, finished, started) triple
      synchronized (AsyncEigenVerifier.this) {
        AsyncEigenVerifier.this.status = status;
        finished = true;
        started = false;
      }
    }
  }
}
b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java new file mode 100644 index 0000000..a284f50 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/EigenStatus.java @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Immutable-except-for-progress result of an eigenvector verification:
 * the estimated eigenvalue, the cosine of the angle between the candidate
 * vector and its image under the matrix, and whether verification is still
 * running.
 */
public class EigenStatus {
  private final double eigenValue;
  private final double cosAngle;
  // primitive instead of the original boxed "volatile Boolean": both
  // constructors take a primitive so it can never be null, and the primitive
  // avoids boxing on every read/write; volatile keeps cross-thread visibility
  // (AsyncEigenVerifier's worker flips it from another thread)
  private volatile boolean inProgress;

  /** Creates a status that is marked in-progress. */
  public EigenStatus(double eigenValue, double cosAngle) {
    this(eigenValue, cosAngle, true);
  }

  public EigenStatus(double eigenValue, double cosAngle, boolean inProgress) {
    this.eigenValue = eigenValue;
    this.cosAngle = cosAngle;
    this.inProgress = inProgress;
  }

  /** @return cosine of the angle between the candidate and its image. */
  public double getCosAngle() {
    return cosAngle;
  }

  /** @return the estimated eigenvalue. */
  public double getEigenValue() {
    return eigenValue;
  }

  /** @return true while the verification producing this status is still running. */
  public boolean inProgress() {
    return inProgress;
  }

  void setInProgress(boolean status) {
    inProgress = status;
  }
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.math.decomposer; + +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorIterable; + +public class SimpleEigenVerifier implements SingularVectorVerifier { + + @Override + public EigenStatus verify(VectorIterable corpus, Vector vector) { + Vector resultantVector = corpus.timesSquared(vector); + double newNorm = resultantVector.norm(2); + double oldNorm = vector.norm(2); + double eigenValue; + double cosAngle; + if (newNorm > 0 && oldNorm > 0) { + eigenValue = newNorm / oldNorm; + cosAngle = resultantVector.dot(vector) / newNorm * oldNorm; + } else { + eigenValue = 1.0; + cosAngle = 0.0; + } + return new EigenStatus(eigenValue, cosAngle, false); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java new file mode 100644 index 0000000..a9a7af8 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/SingularVectorVerifier.java @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.decomposer;

import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;

/**
 * Strategy for checking how "eigen-like" a candidate singular vector is with respect to a
 * given matrix.  Implementations may verify synchronously ({@link SimpleEigenVerifier}) or
 * in the background on another thread, reporting partial results via
 * {@link EigenStatus#inProgress()}.
 */
public interface SingularVectorVerifier {
  /**
   * @param eigenMatrix the matrix the candidate vector is verified against
   * @param vector the candidate singular vector
   * @return the verification status (eigenvalue estimate, cosine of angle, in-progress flag)
   */
  EigenStatus verify(VectorIterable eigenMatrix, Vector vector);
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java ----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java
new file mode 100644
index 0000000..ac9cc41
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/EigenUpdater.java
@@ -0,0 +1,25 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.decomposer.hebbian;

import org.apache.mahout.math.Vector;


/**
 * Strategy for updating a candidate (pseudo-)eigenvector in place, one training
 * (corpus) vector at a time, using the mutable solver state carried in
 * {@link TrainingState}.
 */
public interface EigenUpdater {
  /**
   * @param pseudoEigen the candidate eigenvector; mutated in place by the update
   * @param trainingVector one corpus row to present to the learner
   * @param currentState shared solver state (training index, helper vector, etc.)
   */
  void update(Vector pseudoEigen, Vector trainingVector, TrainingState currentState);
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java ----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java
new file mode 100644
index 0000000..5b5cc9b
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/math/decomposer/hebbian/HebbianSolver.java
@@ -0,0 +1,342 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.math.decomposer.hebbian;

import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Random;

import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.decomposer.AsyncEigenVerifier;
import org.apache.mahout.math.decomposer.EigenStatus;
import org.apache.mahout.math.decomposer.SingularVectorVerifier;
import org.apache.mahout.math.function.PlusMult;
import org.apache.mahout.math.function.TimesFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * The Hebbian solver is an iterative, sparse, singular value decomposition solver, based on the paper
 * <a href="http://www.dcs.shef.ac.uk/~genevieve/gorrell_webb.pdf">Generalized Hebbian Algorithm for
 * Latent Semantic Analysis</a> (2005) by Genevieve Gorrell and Brandyn Webb (a.k.a. Simon Funk).
 * TODO: more description here!  For now: read the inline comments, and the comments for the constructors.
 */
public class HebbianSolver {

  private static final Logger log = LoggerFactory.getLogger(HebbianSolver.class);
  // Compile-time switch: the DEBUG branch in solve() is dead code unless this is flipped.
  private static final boolean DEBUG = false;

  // Performs the per-training-vector Hebbian update on the current pseudo-eigen.
  private final EigenUpdater updater;
  // Checks (possibly asynchronously) how close the current pseudo-eigen is to convergence.
  private final SingularVectorVerifier verifier;
  // Convergence epsilon: we stop when 1 - cosAngle <= convergenceTarget.
  private final double convergenceTarget;
  // Hard cap on convergence checks per eigenvector, even if the target is never reached.
  private final int maxPassesPerEigen;
  private final Random rng = RandomUtils.getRandom();

  // Number of corpus passes made for the eigenvector currently being trained;
  // reset to 0 in solve() after each eigenvector is finished.
  private int numPasses = 0;

  /**
   * Creates a new HebbianSolver
   *
   * @param updater
   *  {@link EigenUpdater} used to do the actual work of iteratively updating the current "best guess"
   *  singular vector one data-point presentation at a time.
   * @param verifier
   *  {@link SingularVectorVerifier } an object which perpetually tries to check how close to
   *  convergence the current singular vector is (typically is a
   *  {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier } which does this
   *  in the background in another thread, while the main thread continues to converge)
   * @param convergenceTarget a small "epsilon" value which tells the solver how small you want the cosine of the
   *  angle between a proposed eigenvector and that same vector after being multiplied by the (square of the) input
   *  corpus
   * @param maxPassesPerEigen a cutoff which tells the solver after how many times of checking for convergence (done
   *  by the verifier) should the solver stop trying, even if it has not reached the convergenceTarget.
   */
  public HebbianSolver(EigenUpdater updater,
                       SingularVectorVerifier verifier,
                       double convergenceTarget,
                       int maxPassesPerEigen) {
    this.updater = updater;
    this.verifier = verifier;
    this.convergenceTarget = convergenceTarget;
    this.maxPassesPerEigen = maxPassesPerEigen;
  }

  /**
   * Creates a new HebbianSolver with maxPassesPerEigen = Integer.MAX_VALUE (i.e. keep on iterating until
   * convergenceTarget is reached).  <b>Not recommended</b> unless only looking for
   * the first few (5, maybe 10?) singular
   * vectors, as small errors which compound early on quickly put a minimum error on subsequent vectors.
   *
   * @param updater {@link EigenUpdater} used to do the actual work of iteratively updating the current "best guess"
   *  singular vector one data-point presentation at a time.
   * @param verifier {@link org.apache.mahout.math.decomposer.SingularVectorVerifier }
   *  an object which perpetually tries to check how close to
   *  convergence the current singular vector is (typically is a
   *  {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier } which does this
   *  in the background in another thread, while the main thread continues to converge)
   * @param convergenceTarget a small "epsilon" value which tells the solver how small you want the cosine of the
   *  angle between a proposed eigenvector and that same vector after being multiplied by the (square of the) input
   *  corpus
   */
  public HebbianSolver(EigenUpdater updater,
                       SingularVectorVerifier verifier,
                       double convergenceTarget) {
    this(updater,
         verifier,
         convergenceTarget,
         Integer.MAX_VALUE);
  }

  /**
   * <b>This is the recommended constructor to use if you're not sure</b>
   * Creates a new HebbianSolver with the default {@link HebbianUpdater } to do the updating work, and the default
   * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier } to check for convergence in a
   * (single) background thread.
   *
   * @param convergenceTarget a small "epsilon" value which tells the solver how small you want the cosine of the
   *  angle between a proposed eigenvector and that same vector after being multiplied by the (square of the) input
   *  corpus
   * @param maxPassesPerEigen a cutoff which tells the solver after how many times of checking for convergence (done
   *  by the verifier) should the solver stop trying, even if it has not reached the convergenceTarget.
   */
  public HebbianSolver(double convergenceTarget, int maxPassesPerEigen) {
    this(new HebbianUpdater(),
         new AsyncEigenVerifier(),
         convergenceTarget,
         maxPassesPerEigen);
  }

  /**
   * Creates a new HebbianSolver with the default {@link HebbianUpdater } to do the updating work, and the default
   * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier } to check for convergence in a (single)
   * background thread, with
   * maxPassesPerEigen set to Integer.MAX_VALUE.  <b>Not recommended</b> unless only looking
   * for the first few (5, maybe 10?) singular
   * vectors, as small errors which compound early on quickly put a minimum error on subsequent vectors.
   *
   * @param convergenceTarget a small "epsilon" value which tells the solver how small you want the cosine of the
   *  angle between a proposed eigenvector and that same vector after being multiplied by the (square of the) input
   *  corpus
   */
  public HebbianSolver(double convergenceTarget) {
    this(convergenceTarget, Integer.MAX_VALUE);
  }

  /**
   * Creates a new HebbianSolver with the default {@link HebbianUpdater } to do the updating work, and the default
   * {@link org.apache.mahout.math.decomposer.AsyncEigenVerifier } to check for convergence in a (single)
   * background thread, with
   * convergenceTarget set to 0, which means that the solver will not really care about convergence as a loop-exiting
   * criterion (but will be checking for convergence anyways, so it will be logged and singular values will be
   * saved).
   *
   * @param numPassesPerEigen the exact number of times the verifier will check convergence status in the background
   *  before the solver will move on to the next eigen-vector.
   */
  public HebbianSolver(int numPassesPerEigen) {
    this(0.0, numPassesPerEigen);
  }

  /**
   * Primary singular vector solving method.
   *
   * @param corpus input matrix to find singular vectors of.  Need not be symmetric, should probably be sparse (in
   *  fact the input vectors are not mutated, and accessed only via dot-products and sums, so they should be
   *  {@link org.apache.mahout.math.SequentialAccessSparseVector })
   * @param desiredRank the number of singular vectors to find (in roughly decreasing order by singular value)
   * @return the final {@link TrainingState } of the solver, after desiredRank singular vectors (and approximate
   *  singular values) have been found.
   */
  public TrainingState solve(Matrix corpus,
                             int desiredRank) {
    int cols = corpus.numCols();
    // Row i of this matrix will hold the i'th (normalized) singular vector found.
    Matrix eigens = new DenseMatrix(desiredRank, cols);
    List<Double> eigenValues = new ArrayList<>();
    log.info("Finding {} singular vectors of matrix with {} rows, via Hebbian", desiredRank, corpus.numRows());
    /*
     * The corpusProjections matrix is a running cache of the residual projection of each corpus vector against all
     * of the previously found singular vectors.  Without this, if multiple passes over the data is made (per
     * singular vector), recalculating these projections eventually dominates the computational complexity of the
     * solver.
     */
    Matrix corpusProjections = new DenseMatrix(corpus.numRows(), desiredRank);
    TrainingState state = new TrainingState(eigens, corpusProjections);
    for (int i = 0; i < desiredRank; i++) {
      Vector currentEigen = new DenseVector(cols);
      Vector previousEigen = null;
      // Keep presenting the whole corpus to the updater until the verifier reports
      // convergence (or maxPassesPerEigen checks have been exhausted).
      while (hasNotConverged(currentEigen, corpus, state)) {
        int randomStartingIndex = getRandomStartingIndex(corpus, eigens);
        Vector initialTrainingVector = corpus.viewRow(randomStartingIndex);
        state.setTrainingIndex(randomStartingIndex);
        updater.update(currentEigen, initialTrainingVector, state);
        // Present every other corpus row exactly once (the seed row was already presented).
        for (int corpusRow = 0; corpusRow < corpus.numRows(); corpusRow++) {
          state.setTrainingIndex(corpusRow);
          if (corpusRow != randomStartingIndex) {
            updater.update(currentEigen, corpus.viewRow(corpusRow), state);
          }
        }
        state.setFirstPass(false);
        if (DEBUG) {
          // Dead unless DEBUG is flipped: tracks the angle between successive passes.
          if (previousEigen == null) {
            previousEigen = currentEigen.clone();
          } else {
            double dot = currentEigen.dot(previousEigen);
            if (dot > 0.0) {
              dot /= currentEigen.norm(2) * previousEigen.norm(2);
            }
            // log.info("Current pass * previous pass = {}", dot);
          }
        }
      }
      // converged!  The eigenvalue estimate comes from the most recent verifier status.
      double eigenValue = state.getStatusProgress().get(state.getStatusProgress().size() - 1).getEigenValue();
      // it's actually more efficient to do this to normalize than to call currentEigen = currentEigen.normalize(),
      // because the latter does a clone, which isn't necessary here.
      currentEigen.assign(new TimesFunction(), 1 / currentEigen.norm(2));
      eigens.assignRow(i, currentEigen);
      eigenValues.add(eigenValue);
      state.setCurrentEigenValues(eigenValues);
      log.info("Found eigenvector {}, eigenvalue: {}", i, eigenValue);

      // TODO: Persist intermediate output!
      // Reset per-eigenvector state before training the next singular vector.
      state.setFirstPass(true);
      state.setNumEigensProcessed(state.getNumEigensProcessed() + 1);
      state.setActivationDenominatorSquared(0);
      state.setActivationNumerator(0);
      state.getStatusProgress().clear();
      numPasses = 0;
    }
    return state;
  }

  /**
   * You have to start somewhere...
   * TODO: start instead wherever you find a vector with maximum residual length after subtracting off the projection
   * TODO: onto all previous eigenvectors.
   *
   * @param corpus the corpus matrix
   * @param eigens not currently used, but should be (see above TODO)
   * @return the index into the corpus where the "starting seed" input vector lies.
   */
  private int getRandomStartingIndex(Matrix corpus, Matrix eigens) {
    int index;
    Vector v;
    do {
      // NOTE(review): rng.nextInt(corpus.numRows()) would be the idiomatic form of this.
      double r = rng.nextDouble();
      index = (int) (r * corpus.numRows());
      v = corpus.viewRow(index);
      // Reject empty/near-empty rows as seeds.  NOTE(review): this loops forever if no row
      // has >= 5 non-default elements and nonzero norm — assumed not to occur in practice.
    } while (v == null || v.norm(2) == 0 || v.getNumNondefaultElements() < 5);
    return index;
  }

  /**
   * Uses the {@link SingularVectorVerifier } to check for convergence
   *
   * @param currentPseudoEigen the purported singular vector whose convergence is being checked;
   *  mutated in place by the orthogonalization in step 1 below
   * @param corpus the corpus to check against
   * @param state contains the previous eigens, various other solving state {@link TrainingState}
   * @return true if we have <em>neither</em> converged <em>nor</em> exceeded maxPassesPerEigen
   *  (i.e. true means "keep iterating")
   */
  protected boolean hasNotConverged(Vector currentPseudoEigen,
                                    Matrix corpus,
                                    TrainingState state) {
    numPasses++;
    if (state.isFirstPass()) {
      log.info("First pass through the corpus, no need to check convergence...");
      return true;
    }
    Matrix previousEigens = state.getCurrentEigens();
    log.info("Have made {} passes through the corpus, checking convergence...", numPasses);
    /*
     * Step 1: orthogonalize currentPseudoEigen by subtracting off eigen(i) * helper.get(i)
     * Step 2: zero-out the helper vector because it has already helped.
     */
    for (int i = 0; i < state.getNumEigensProcessed(); i++) {
      Vector previousEigen = previousEigens.viewRow(i);
      currentPseudoEigen.assign(previousEigen, new PlusMult(-state.getHelperVector().get(i)));
      state.getHelperVector().set(i, 0);
    }
    if (currentPseudoEigen.norm(2) > 0) {
      for (int i = 0; i < state.getNumEigensProcessed(); i++) {
        Vector previousEigen = previousEigens.viewRow(i);
        // Should be ~0 if orthogonalization worked; logged for diagnostics only.
        log.info("dot with previous: {}", previousEigen.dot(currentPseudoEigen) / currentPseudoEigen.norm(2));
      }
    }
    /*
     * Step 3: verify how eigen-like the prospective eigen is.  This is potentially asynchronous.
     */
    EigenStatus status = verify(corpus, currentPseudoEigen);
    if (status.inProgress()) {
      log.info("Verifier not finished, making another pass...");
    } else {
      log.info("Has 1 - cosAngle: {}, convergence target is: {}", 1.0 - status.getCosAngle(), convergenceTarget);
      state.getStatusProgress().add(status);
    }
    // Note: only finished (not in-progress) statuses are added to statusProgress, so the
    // size() bound counts completed convergence checks, matching maxPassesPerEigen's contract.
    return
        state.getStatusProgress().size() <= maxPassesPerEigen
        && 1.0 - status.getCosAngle() > convergenceTarget;
  }

  // Seam for subclasses/tests to intercept verification.
  protected EigenStatus verify(Matrix corpus, Vector currentPseudoEigen) {
    return verifier.verify(corpus, currentPseudoEigen);
  }

  public static void main(String[] args) {
    Properties props = new Properties();
    String propertiesFile = args.length > 0 ? args[0] : "config/solver.properties";
    // NOTE(review): the load below is commented out, so props is always empty and the
    // null-check branch that follows always triggers — this main is effectively a stub.
    // props.load(new FileInputStream(propertiesFile));

    String corpusDir = props.getProperty("solver.input.dir");
    String outputDir = props.getProperty("solver.output.dir");
    if (corpusDir == null || corpusDir.isEmpty() || outputDir == null || outputDir.isEmpty()) {
      log.error("{} must contain values for solver.input.dir and solver.output.dir", propertiesFile);
      return;
    }
    //int inBufferSize = Integer.parseInt(props.getProperty("solver.input.bufferSize"));
    int rank = Integer.parseInt(props.getProperty("solver.output.desiredRank"));
    double convergence = Double.parseDouble(props.getProperty("solver.convergence"));
    int maxPasses = Integer.parseInt(props.getProperty("solver.maxPasses"));
    //int numThreads = Integer.parseInt(props.getProperty("solver.verifier.numThreads"));

    HebbianUpdater updater = new HebbianUpdater();
    SingularVectorVerifier verifier = new AsyncEigenVerifier();
    HebbianSolver solver = new HebbianSolver(updater, verifier, convergence, maxPasses);
    // NOTE(review): corpus is never initialized (the loaders below are commented out);
    // if execution ever reached solve(), it would NPE on corpus.numCols().
    Matrix corpus = null;
    /*
    if (numThreads <= 1) {
      //  corpus = new DiskBufferedDoubleMatrix(new File(corpusDir), inBufferSize);
    } else {
      //  corpus = new ParallelMultiplyingDiskBufferedDoubleMatrix(new File(corpusDir), inBufferSize, numThreads);
    }
    */
    long now = System.currentTimeMillis();
    TrainingState finalState = solver.solve(corpus, rank);
    long time = (System.currentTimeMillis() - now) / 1000;
    log.info("Solved {} eigenVectors in {} seconds.  Persisted to {}",
        finalState.getCurrentEigens().rowSize(), time, outputDir);
  }


}
