[REEF-118] Add Shimoga library for elastic group communication. Shimoga is a REEF library for elastic group communication. It provides MPI-style operators like Broadcast and Reduce for inter-task messaging.
JIRA: [REEF-118](https://issues.apache.org/jira/browse/REEF-118) Pull Request: This closes #63 Project: http://git-wip-us.apache.org/repos/asf/incubator-reef/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-reef/commit/6c6ad336 Tree: http://git-wip-us.apache.org/repos/asf/incubator-reef/tree/6c6ad336 Diff: http://git-wip-us.apache.org/repos/asf/incubator-reef/diff/6c6ad336 Branch: refs/heads/master Commit: 6c6ad33674c6e61e44015e0632023e776b07536e Parents: 0911c08 Author: Sergiy Matusevych <[email protected]> Authored: Thu Feb 12 14:22:58 2015 -0800 Committer: Markus Weimer <[email protected]> Committed: Thu Mar 5 17:52:55 2015 -0800 ---------------------------------------------------------------------- lang/java/reef-examples/pom.xml | 24 - .../reef/examples/group/bgd/BGDClient.java | 134 +++++ .../reef/examples/group/bgd/BGDDriver.java | 376 ++++++++++++ .../reef/examples/group/bgd/BGDLocal.java | 53 ++ .../apache/reef/examples/group/bgd/BGDYarn.java | 52 ++ .../examples/group/bgd/ControlMessages.java | 30 + .../reef/examples/group/bgd/ExampleList.java | 72 +++ .../group/bgd/LineSearchReduceFunction.java | 51 ++ .../bgd/LossAndGradientReduceFunction.java | 55 ++ .../reef/examples/group/bgd/MasterTask.java | 246 ++++++++ .../reef/examples/group/bgd/SlaveTask.java | 204 +++++++ .../reef/examples/group/bgd/data/Example.java | 52 ++ .../examples/group/bgd/data/SparseExample.java | 68 +++ .../examples/group/bgd/data/parser/Parser.java | 32 + .../group/bgd/data/parser/SVMLightParser.java | 98 ++++ .../group/bgd/loss/LogisticLossFunction.java | 50 ++ .../examples/group/bgd/loss/LossFunction.java | 46 ++ .../bgd/loss/SquaredErrorLossFunction.java | 49 ++ .../bgd/loss/WeightedLogisticLossFunction.java | 74 +++ .../ControlMessageBroadcaster.java | 29 + .../DescentDirectionBroadcaster.java | 29 + .../LineSearchEvaluationsReducer.java | 29 + .../operatornames/LossAndGradientReducer.java | 29 + .../bgd/operatornames/MinEtaBroadcaster.java | 26 + 
.../ModelAndDescentDirectionBroadcaster.java | 29 + .../bgd/operatornames/ModelBroadcaster.java | 29 + .../group/bgd/operatornames/package-info.java | 23 + .../bgd/parameters/AllCommunicationGroup.java | 26 + .../bgd/parameters/BGDControlParameters.java | 126 ++++ .../group/bgd/parameters/BGDLossType.java | 61 ++ .../group/bgd/parameters/EnableRampup.java | 29 + .../reef/examples/group/bgd/parameters/Eps.java | 30 + .../reef/examples/group/bgd/parameters/Eta.java | 30 + .../group/bgd/parameters/EvaluatorMemory.java | 29 + .../examples/group/bgd/parameters/InputDir.java | 29 + .../group/bgd/parameters/Iterations.java | 29 + .../examples/group/bgd/parameters/Lambda.java | 29 + .../group/bgd/parameters/LossFunctionType.java | 30 + .../examples/group/bgd/parameters/MinParts.java | 29 + .../group/bgd/parameters/ModelDimensions.java | 30 + .../group/bgd/parameters/NumSplits.java | 30 + .../group/bgd/parameters/NumberOfReceivers.java | 30 + .../bgd/parameters/ProbabilityOfFailure.java | 30 + .../ProbabilityOfSuccesfulIteration.java | 30 + .../examples/group/bgd/parameters/Timeout.java | 28 + .../examples/group/bgd/utils/StepSizes.java | 59 ++ .../group/bgd/utils/SubConfiguration.java | 73 +++ .../group/broadcast/BroadcastDriver.java | 285 +++++++++ .../examples/group/broadcast/BroadcastREEF.java | 148 +++++ .../group/broadcast/ControlMessages.java | 26 + .../examples/group/broadcast/MasterTask.java | 97 ++++ .../ModelReceiveAckReduceFunction.java | 39 ++ .../examples/group/broadcast/SlaveTask.java | 76 +++ .../parameters/AllCommunicationGroup.java | 30 + .../parameters/ControlMessageBroadcaster.java | 26 + .../group/broadcast/parameters/Dimensions.java | 30 + .../parameters/FailureProbability.java | 30 + .../broadcast/parameters/ModelBroadcaster.java | 26 + .../parameters/ModelReceiveAckReducer.java | 26 + .../broadcast/parameters/NumberOfReceivers.java | 30 + .../utils/math/AbstractImmutableVector.java | 103 ++++ .../group/utils/math/AbstractVector.java | 61 ++ 
.../examples/group/utils/math/DenseVector.java | 112 ++++ .../group/utils/math/ImmutableVector.java | 78 +++ .../examples/group/utils/math/SparseVector.java | 57 ++ .../reef/examples/group/utils/math/Vector.java | 72 +++ .../examples/group/utils/math/VectorCodec.java | 70 +++ .../reef/examples/group/utils/math/Window.java | 76 +++ .../reef/examples/group/utils/timer/Timer.java | 58 ++ .../reef/examples/scheduler/Scheduler.java | 5 +- .../utils/wake/BlockingEventHandler.java | 2 +- .../utils/wake/LoggingEventHandler.java | 17 +- lang/java/reef-io/pom.xml | 5 + .../reef/io/network/group/api/GroupChanges.java | 31 + .../network/group/api/config/OperatorSpec.java | 38 ++ .../api/driver/CommunicationGroupDriver.java | 87 +++ .../group/api/driver/GroupCommDriver.java | 76 +++ .../api/driver/GroupCommServiceDriver.java | 59 ++ .../io/network/group/api/driver/TaskNode.java | 94 +++ .../group/api/driver/TaskNodeStatus.java | 81 +++ .../io/network/group/api/driver/Topology.java | 115 ++++ .../operators/AbstractGroupCommOperator.java | 44 ++ .../network/group/api/operators/AllGather.java | 50 ++ .../network/group/api/operators/AllReduce.java | 55 ++ .../network/group/api/operators/Broadcast.java | 60 ++ .../io/network/group/api/operators/Gather.java | 64 ++ .../group/api/operators/GroupCommOperator.java | 33 ++ .../io/network/group/api/operators/Reduce.java | 99 ++++ .../group/api/operators/ReduceScatter.java | 67 +++ .../io/network/group/api/operators/Scatter.java | 74 +++ .../group/api/operators/package-info.java | 48 ++ .../group/api/task/CommGroupNetworkHandler.java | 41 ++ .../api/task/CommunicationGroupClient.java | 97 ++++ .../task/CommunicationGroupServiceClient.java | 34 ++ .../network/group/api/task/GroupCommClient.java | 42 ++ .../group/api/task/GroupCommNetworkHandler.java | 38 ++ .../io/network/group/api/task/NodeStruct.java | 42 ++ .../group/api/task/OperatorTopology.java | 58 ++ .../group/api/task/OperatorTopologyStruct.java | 73 +++ 
.../network/group/impl/GroupChangesCodec.java | 71 +++ .../io/network/group/impl/GroupChangesImpl.java | 45 ++ .../group/impl/GroupCommunicationMessage.java | 167 ++++++ .../impl/GroupCommunicationMessageCodec.java | 111 ++++ .../impl/config/BroadcastOperatorSpec.java | 86 +++ .../group/impl/config/ReduceOperatorSpec.java | 107 ++++ .../parameters/CommunicationGroupName.java | 28 + .../group/impl/config/parameters/DataCodec.java | 29 + .../impl/config/parameters/OperatorName.java | 28 + .../config/parameters/ReduceFunctionParam.java | 29 + .../parameters/SerializedGroupConfigs.java | 30 + .../parameters/SerializedOperConfigs.java | 30 + .../impl/config/parameters/TaskVersion.java | 28 + .../config/parameters/TreeTopologyFanOut.java | 28 + .../driver/CommunicationGroupDriverImpl.java | 451 +++++++++++++++ .../group/impl/driver/CtrlMsgSender.java | 61 ++ .../group/impl/driver/ExceptionHandler.java | 56 ++ .../network/group/impl/driver/FlatTopology.java | 307 ++++++++++ .../group/impl/driver/GroupCommDriverImpl.java | 250 ++++++++ .../impl/driver/GroupCommMessageHandler.java | 55 ++ .../group/impl/driver/GroupCommService.java | 111 ++++ .../network/group/impl/driver/IndexedMsg.java | 71 +++ .../io/network/group/impl/driver/MsgKey.java | 90 +++ .../network/group/impl/driver/TaskNodeImpl.java | 476 +++++++++++++++ .../group/impl/driver/TaskNodeStatusImpl.java | 267 +++++++++ .../io/network/group/impl/driver/TaskState.java | 23 + .../driver/TopologyFailedEvaluatorHandler.java | 50 ++ .../impl/driver/TopologyFailedTaskHandler.java | 45 ++ .../impl/driver/TopologyMessageHandler.java | 44 ++ .../impl/driver/TopologyRunningTaskHandler.java | 44 ++ .../impl/driver/TopologyUpdateWaitHandler.java | 94 +++ .../network/group/impl/driver/TreeTopology.java | 345 +++++++++++ .../network/group/impl/driver/package-info.java | 116 ++++ .../group/impl/operators/BroadcastReceiver.java | 159 +++++ .../group/impl/operators/BroadcastSender.java | 141 +++++ 
.../group/impl/operators/ReduceReceiver.java | 155 +++++ .../group/impl/operators/ReduceSender.java | 161 ++++++ .../io/network/group/impl/operators/Sender.java | 59 ++ .../group/impl/task/ChildNodeStruct.java | 42 ++ .../impl/task/CommGroupNetworkHandlerImpl.java | 102 ++++ .../impl/task/CommunicationGroupClientImpl.java | 296 ++++++++++ .../group/impl/task/GroupCommClientImpl.java | 85 +++ .../impl/task/GroupCommNetworkHandlerImpl.java | 68 +++ .../io/network/group/impl/task/InitHandler.java | 54 ++ .../network/group/impl/task/NodeStructImpl.java | 98 ++++ .../group/impl/task/OperatorTopologyImpl.java | 466 +++++++++++++++ .../impl/task/OperatorTopologyStructImpl.java | 579 +++++++++++++++++++ .../group/impl/task/ParentNodeStruct.java | 45 ++ .../impl/utils/BroadcastingEventHandler.java | 44 ++ .../group/impl/utils/ConcurrentCountingMap.java | 134 +++++ .../network/group/impl/utils/CountingMap.java | 98 ++++ .../group/impl/utils/CountingSemaphore.java | 103 ++++ .../impl/utils/ResettingCountDownLatch.java | 57 ++ .../io/network/group/impl/utils/SetMap.java | 95 +++ .../reef/io/network/group/impl/utils/Utils.java | 80 +++ .../reef/io/network/group/package-info.java | 33 ++ .../reef/io/network/naming/NameServer.java | 20 +- .../reef/io/network/naming/NameServerImpl.java | 1 - .../org/apache/reef/io/network/util/Utils.java | 119 ++++ .../org/apache/reef/io/storage/ram/RamMap.java | 6 +- .../src/main/proto/group_comm_protocol.proto | 64 ++ .../GroupCommunicationMessageCodecTest.java | 72 +++ .../apache/reef/io/network/util/TestUtils.java | 60 ++ .../services/network/NetworkServiceTest.java | 26 +- 163 files changed, 13180 insertions(+), 76 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/pom.xml ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/pom.xml b/lang/java/reef-examples/pom.xml 
index f910a77..68c4693 100644 --- a/lang/java/reef-examples/pom.xml +++ b/lang/java/reef-examples/pom.xml @@ -214,30 +214,6 @@ under the License. </plugins> </build> </profile> - <profile> - <id>MatMult</id> - <build> - <defaultGoal>exec:exec</defaultGoal> - <plugins> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>exec-maven-plugin</artifactId> - <configuration> - <executable>java</executable> - <arguments> - <argument>-classpath</argument> - <classpath/> - <argument>-Djava.util.logging.config.class=org.apache.reef.util.logging.Config - </argument> - <argument>-Dcom.microsoft.reef.runtime.local.folder=${project.build.directory} - </argument> - <argument>org.apache.reef.examples.groupcomm.matmul.MatMultREEF</argument> - </arguments> - </configuration> - </plugin> - </plugins> - </build> - </profile> <profile> <id>RetainedEval</id> http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDClient.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDClient.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDClient.java new file mode 100644 index 0000000..84865e8 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDClient.java @@ -0,0 +1,134 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd; + +import org.apache.hadoop.mapred.TextInputFormat; +import org.apache.reef.client.DriverConfiguration; +import org.apache.reef.client.DriverLauncher; +import org.apache.reef.client.LauncherStatus; +import org.apache.reef.client.REEF; +import org.apache.reef.driver.evaluator.EvaluatorRequest; +import org.apache.reef.examples.group.bgd.parameters.*; +import org.apache.reef.io.data.loading.api.DataLoadingRequestBuilder; +import org.apache.reef.io.network.group.impl.config.parameters.TreeTopologyFanOut; +import org.apache.reef.io.network.group.impl.driver.GroupCommService; +import org.apache.reef.tang.Configuration; +import org.apache.reef.tang.Configurations; +import org.apache.reef.tang.JavaConfigurationBuilder; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.annotations.Parameter; +import org.apache.reef.tang.formats.CommandLine; +import org.apache.reef.util.EnvironmentUtils; + +import javax.inject.Inject; + +/** + * A client to submit BGD Jobs + */ +public class BGDClient { + private final String input; + private final int numSplits; + private final int memory; + + private final BGDControlParameters bgdControlParameters; + private final int fanOut; + + @Inject + public BGDClient(final @Parameter(InputDir.class) String input, + final @Parameter(NumSplits.class) int numSplits, + final @Parameter(EvaluatorMemory.class) int memory, + final @Parameter(TreeTopologyFanOut.class) int fanOut, + final BGDControlParameters bgdControlParameters) { + this.input = input; + this.fanOut = fanOut; + 
this.bgdControlParameters = bgdControlParameters; + this.numSplits = numSplits; + this.memory = memory; + } + + /** + * Runs BGD on the given runtime. + * + * @param runtimeConfiguration the runtime to run on. + * @param jobName the name of the job on the runtime. + * @return + */ + public void submit(final Configuration runtimeConfiguration, final String jobName) throws Exception { + final Configuration driverConfiguration = getDriverConfiguration(jobName); + Tang.Factory.getTang().newInjector(runtimeConfiguration).getInstance(REEF.class).submit(driverConfiguration); + } + + /** + * Runs BGD on the given runtime - with timeout. + * + * @param runtimeConfiguration the runtime to run on. + * @param jobName the name of the job on the runtime. + * @param timeout the time after which the job will be killed if not completed, in ms + * @return job completion status + */ + public LauncherStatus run(final Configuration runtimeConfiguration, + final String jobName, final int timeout) throws Exception { + final Configuration driverConfiguration = getDriverConfiguration(jobName); + return DriverLauncher.getLauncher(runtimeConfiguration).run(driverConfiguration, timeout); + } + + private final Configuration getDriverConfiguration(final String jobName) { + return Configurations.merge( + getDataLoadConfiguration(jobName), + GroupCommService.getConfiguration(fanOut), + this.bgdControlParameters.getConfiguration()); + } + + private Configuration getDataLoadConfiguration(final String jobName) { + final EvaluatorRequest computeRequest = EvaluatorRequest.newBuilder() + .setNumber(1) + .setMemory(memory) + .build(); + final Configuration dataLoadConfiguration = new DataLoadingRequestBuilder() + .setMemoryMB(memory) + .setInputFormatClass(TextInputFormat.class) + .setInputPath(input) + .setNumberOfDesiredSplits(numSplits) + .setComputeRequest(computeRequest) + .renewFailedEvaluators(false) + .setDriverConfigurationModule(EnvironmentUtils + .addClasspath(DriverConfiguration.CONF, 
DriverConfiguration.GLOBAL_LIBRARIES) + .set(DriverConfiguration.DRIVER_MEMORY, Integer.toString(memory)) + .set(DriverConfiguration.ON_CONTEXT_ACTIVE, BGDDriver.ContextActiveHandler.class) + .set(DriverConfiguration.ON_TASK_RUNNING, BGDDriver.TaskRunningHandler.class) + .set(DriverConfiguration.ON_TASK_FAILED, BGDDriver.TaskFailedHandler.class) + .set(DriverConfiguration.ON_TASK_COMPLETED, BGDDriver.TaskCompletedHandler.class) + .set(DriverConfiguration.DRIVER_IDENTIFIER, jobName)) + .build(); + return dataLoadConfiguration; + } + + public static final BGDClient fromCommandLine(final String[] args) throws Exception { + final JavaConfigurationBuilder configurationBuilder = Tang.Factory.getTang().newConfigurationBuilder(); + final CommandLine commandLine = new CommandLine(configurationBuilder) + .registerShortNameOfClass(InputDir.class) + .registerShortNameOfClass(Timeout.class) + .registerShortNameOfClass(EvaluatorMemory.class) + .registerShortNameOfClass(NumSplits.class) + .registerShortNameOfClass(TreeTopologyFanOut.class); + BGDControlParameters.registerShortNames(commandLine); + commandLine.processCommandLine(args); + return Tang.Factory.getTang().newInjector(configurationBuilder.build()).getInstance(BGDClient.class); + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDDriver.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDDriver.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDDriver.java new file mode 100644 index 0000000..2a80581 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDDriver.java @@ -0,0 +1,376 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd; + +import org.apache.reef.annotations.audience.DriverSide; +import org.apache.reef.driver.context.ActiveContext; +import org.apache.reef.driver.context.ServiceConfiguration; +import org.apache.reef.driver.task.CompletedTask; +import org.apache.reef.driver.task.FailedTask; +import org.apache.reef.driver.task.RunningTask; +import org.apache.reef.driver.task.TaskConfiguration; +import org.apache.reef.evaluator.context.parameters.ContextIdentifier; +import org.apache.reef.examples.group.bgd.data.parser.Parser; +import org.apache.reef.examples.group.bgd.data.parser.SVMLightParser; +import org.apache.reef.examples.group.bgd.loss.LossFunction; +import org.apache.reef.examples.group.bgd.operatornames.*; +import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup; +import org.apache.reef.examples.group.bgd.parameters.BGDControlParameters; +import org.apache.reef.examples.group.bgd.parameters.ModelDimensions; +import org.apache.reef.examples.group.bgd.parameters.ProbabilityOfFailure; +import org.apache.reef.io.data.loading.api.DataLoadingService; +import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver; +import org.apache.reef.io.network.group.api.driver.GroupCommDriver; +import 
org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec; +import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec; +import org.apache.reef.io.serialization.Codec; +import org.apache.reef.io.serialization.SerializableCodec; +import org.apache.reef.poison.PoisonedConfiguration; +import org.apache.reef.tang.Configuration; +import org.apache.reef.tang.Configurations; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.annotations.Unit; +import org.apache.reef.tang.exceptions.InjectionException; +import org.apache.reef.tang.formats.ConfigurationSerializer; +import org.apache.reef.wake.EventHandler; + +import javax.inject.Inject; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Level; +import java.util.logging.Logger; + +@DriverSide +@Unit +public class BGDDriver { + + private static final Logger LOG = Logger.getLogger(BGDDriver.class.getName()); + + private static final Tang TANG = Tang.Factory.getTang(); + + private static final double STARTUP_FAILURE_PROB = 0.01; + + private final DataLoadingService dataLoadingService; + private final GroupCommDriver groupCommDriver; + private final ConfigurationSerializer confSerializer; + private final CommunicationGroupDriver communicationsGroup; + private final AtomicBoolean masterSubmitted = new AtomicBoolean(false); + private final AtomicInteger slaveIds = new AtomicInteger(0); + private final Map<String, RunningTask> runningTasks = new HashMap<>(); + private final AtomicBoolean jobComplete = new AtomicBoolean(false); + private final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<>(); + private final BGDControlParameters bgdControlParameters; + + private String communicationsGroupMasterContextId; + + @Inject + public BGDDriver(final DataLoadingService dataLoadingService, + final GroupCommDriver 
groupCommDriver, + final ConfigurationSerializer confSerializer, + final BGDControlParameters bgdControlParameters) { + this.dataLoadingService = dataLoadingService; + this.groupCommDriver = groupCommDriver; + this.confSerializer = confSerializer; + this.bgdControlParameters = bgdControlParameters; + + final int minNumOfPartitions = + bgdControlParameters.isRampup() + ? bgdControlParameters.getMinParts() + : dataLoadingService.getNumberOfPartitions(); + + final int numParticipants = minNumOfPartitions + 1; + + this.communicationsGroup = this.groupCommDriver.newCommunicationGroup( + AllCommunicationGroup.class, // NAME + numParticipants); // Number of participants + + LOG.log(Level.INFO, + "Obtained entire communication group: start with {0} partitions", numParticipants); + + this.communicationsGroup + .addBroadcast(ControlMessageBroadcaster.class, + BroadcastOperatorSpec.newBuilder() + .setSenderId(MasterTask.TASK_ID) + .setDataCodecClass(SerializableCodec.class) + .build()) + .addBroadcast(ModelBroadcaster.class, + BroadcastOperatorSpec.newBuilder() + .setSenderId(MasterTask.TASK_ID) + .setDataCodecClass(SerializableCodec.class) + .build()) + .addReduce(LossAndGradientReducer.class, + ReduceOperatorSpec.newBuilder() + .setReceiverId(MasterTask.TASK_ID) + .setDataCodecClass(SerializableCodec.class) + .setReduceFunctionClass(LossAndGradientReduceFunction.class) + .build()) + .addBroadcast(ModelAndDescentDirectionBroadcaster.class, + BroadcastOperatorSpec.newBuilder() + .setSenderId(MasterTask.TASK_ID) + .setDataCodecClass(SerializableCodec.class) + .build()) + .addBroadcast(DescentDirectionBroadcaster.class, + BroadcastOperatorSpec.newBuilder() + .setSenderId(MasterTask.TASK_ID) + .setDataCodecClass(SerializableCodec.class) + .build()) + .addReduce(LineSearchEvaluationsReducer.class, + ReduceOperatorSpec.newBuilder() + .setReceiverId(MasterTask.TASK_ID) + .setDataCodecClass(SerializableCodec.class) + .setReduceFunctionClass(LineSearchReduceFunction.class) + 
.build()) + .addBroadcast(MinEtaBroadcaster.class, + BroadcastOperatorSpec.newBuilder() + .setSenderId(MasterTask.TASK_ID) + .setDataCodecClass(SerializableCodec.class) + .build()) + .finalise(); + + LOG.log(Level.INFO, "Added operators to communicationsGroup"); + } + + final class ContextActiveHandler implements EventHandler<ActiveContext> { + + @Override + public void onNext(final ActiveContext activeContext) { + LOG.log(Level.INFO, "Got active context: {0}", activeContext.getId()); + if (jobRunning(activeContext)) { + if (!groupCommDriver.isConfigured(activeContext)) { + // The Context is not configured with the group communications service let's do that. + submitGroupCommunicationsService(activeContext); + } else { + // The group communications service is already active on this context. We can submit the task. + submitTask(activeContext); + } + } + } + + /** + * @param activeContext a context to be configured with group communications. + */ + private void submitGroupCommunicationsService(final ActiveContext activeContext) { + final Configuration contextConf = groupCommDriver.getContextConfiguration(); + final String contextId = getContextId(contextConf); + final Configuration serviceConf; + if (!dataLoadingService.isDataLoadedContext(activeContext)) { + communicationsGroupMasterContextId = contextId; + serviceConf = groupCommDriver.getServiceConfiguration(); + } else { + final Configuration parsedDataServiceConf = ServiceConfiguration.CONF + .set(ServiceConfiguration.SERVICES, ExampleList.class) + .build(); + serviceConf = Tang.Factory.getTang() + .newConfigurationBuilder(groupCommDriver.getServiceConfiguration(), parsedDataServiceConf) + .bindImplementation(Parser.class, SVMLightParser.class) + .build(); + } + + LOG.log(Level.FINEST, "Submit GCContext conf: {0} and Service conf: {1}", new Object[]{ + confSerializer.toString(contextConf), confSerializer.toString(serviceConf)}); + + activeContext.submitContextAndService(contextConf, serviceConf); + } + + private 
void submitTask(final ActiveContext activeContext) { + + assert (groupCommDriver.isConfigured(activeContext)); + + final Configuration partialTaskConfiguration; + if (activeContext.getId().equals(communicationsGroupMasterContextId) && !masterTaskSubmitted()) { + partialTaskConfiguration = getMasterTaskConfiguration(); + LOG.info("Submitting MasterTask conf"); + } else { + partialTaskConfiguration = getSlaveTaskConfiguration(getSlaveId(activeContext)); + // partialTaskConfiguration = Configurations.merge( + // getSlaveTaskConfiguration(getSlaveId(activeContext)), + // getTaskPoisonConfiguration()); + LOG.info("Submitting SlaveTask conf"); + } + communicationsGroup.addTask(partialTaskConfiguration); + final Configuration taskConfiguration = groupCommDriver.getTaskConfiguration(partialTaskConfiguration); + LOG.log(Level.FINEST, "{0}", confSerializer.toString(taskConfiguration)); + activeContext.submitTask(taskConfiguration); + } + + private boolean jobRunning(final ActiveContext activeContext) { + synchronized (runningTasks) { + if (!jobComplete.get()) { + return true; + } else { + LOG.log(Level.INFO, "Job complete. Not submitting any task. Closing context {0}", activeContext); + activeContext.close(); + return false; + } + } + } + } + + final class TaskRunningHandler implements EventHandler<RunningTask> { + + @Override + public void onNext(final RunningTask runningTask) { + synchronized (runningTasks) { + if (!jobComplete.get()) { + LOG.log(Level.INFO, "Job has not completed yet. Adding to runningTasks: {0}", runningTask); + runningTasks.put(runningTask.getId(), runningTask); + } else { + LOG.log(Level.INFO, "Job complete. 
Closing context: {0}", runningTask.getActiveContext().getId()); + runningTask.getActiveContext().close(); + } + } + } + } + + final class TaskFailedHandler implements EventHandler<FailedTask> { + + @Override + public void onNext(final FailedTask failedTask) { + + final String failedTaskId = failedTask.getId(); + + LOG.log(Level.WARNING, "Got failed Task: " + failedTaskId); + + if (jobRunning(failedTaskId)) { + + final ActiveContext activeContext = failedTask.getActiveContext().get(); + final Configuration partialTaskConf = getSlaveTaskConfiguration(failedTaskId); + + // Do not add the task back: + // allCommGroup.addTask(partialTaskConf); + + final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); + LOG.log(Level.FINEST, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf)); + + activeContext.submitTask(taskConf); + } + } + + private boolean jobRunning(final String failedTaskId) { + synchronized (runningTasks) { + if (!jobComplete.get()) { + return true; + } else { + final RunningTask rTask = runningTasks.remove(failedTaskId); + LOG.log(Level.INFO, "Job has completed. Not resubmitting"); + if (rTask != null) { + LOG.log(Level.INFO, "Closing activecontext"); + rTask.getActiveContext().close(); + } else { + LOG.log(Level.INFO, "Master must have closed my context"); + } + return false; + } + } + } + } + + final class TaskCompletedHandler implements EventHandler<CompletedTask> { + + @Override + public void onNext(final CompletedTask task) { + LOG.log(Level.INFO, "Got CompletedTask: {0}", task.getId()); + final byte[] retVal = task.get(); + if (retVal != null) { + final List<Double> losses = BGDDriver.this.lossCodec.decode(retVal); + for (final Double loss : losses) { + LOG.log(Level.INFO, "OUT: LOSS = {0}", loss); + } + } + synchronized (runningTasks) { + LOG.log(Level.INFO, "Acquired lock on runningTasks. 
Removing {0}", task.getId()); + final RunningTask rTask = runningTasks.remove(task.getId()); + if (rTask != null) { + LOG.log(Level.INFO, "Closing active context: {0}", task.getActiveContext().getId()); + task.getActiveContext().close(); + } else { + LOG.log(Level.INFO, "Master must have closed active context already for task {0}", task.getId()); + } + + if (MasterTask.TASK_ID.equals(task.getId())) { + jobComplete.set(true); + LOG.log(Level.INFO, "Master(=>Job) complete. Closing other running tasks: {0}", runningTasks.values()); + for (final RunningTask runTask : runningTasks.values()) { + runTask.getActiveContext().close(); + } + LOG.finest("Clearing runningTasks"); + runningTasks.clear(); + } + } + } + } + + /** + * @return Configuration for the MasterTask + */ + public Configuration getMasterTaskConfiguration() { + return Configurations.merge( + TaskConfiguration.CONF + .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID) + .set(TaskConfiguration.TASK, MasterTask.class) + .build(), + bgdControlParameters.getConfiguration()); + } + + /** + * @return Configuration for the SlaveTask + */ + private Configuration getSlaveTaskConfiguration(final String taskId) { + final double pSuccess = bgdControlParameters.getProbOfSuccessfulIteration(); + final int numberOfPartitions = dataLoadingService.getNumberOfPartitions(); + final double pFailure = 1 - Math.pow(pSuccess, 1.0 / numberOfPartitions); + return Tang.Factory.getTang() + .newConfigurationBuilder( + TaskConfiguration.CONF + .set(TaskConfiguration.IDENTIFIER, taskId) + .set(TaskConfiguration.TASK, SlaveTask.class) + .build()) + .bindNamedParameter(ModelDimensions.class, "" + bgdControlParameters.getDimensions()) + .bindImplementation(LossFunction.class, bgdControlParameters.getLossFunction()) + .bindNamedParameter(ProbabilityOfFailure.class, Double.toString(pFailure)) + .build(); + } + + private Configuration getTaskPoisonConfiguration() { + return PoisonedConfiguration.TASK_CONF + 
.set(PoisonedConfiguration.CRASH_PROBABILITY, STARTUP_FAILURE_PROB) + .set(PoisonedConfiguration.CRASH_TIMEOUT, 1) + .build(); + } + + private String getContextId(final Configuration contextConf) { + try { + return TANG.newInjector(contextConf).getNamedInstance(ContextIdentifier.class); + } catch (final InjectionException e) { + throw new RuntimeException("Unable to inject context identifier from context conf", e); + } + } + + private String getSlaveId(final ActiveContext activeContext) { + return "SlaveTask-" + slaveIds.getAndIncrement(); + } + + private boolean masterTaskSubmitted() { + return !masterSubmitted.compareAndSet(false, true); + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDLocal.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDLocal.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDLocal.java new file mode 100644 index 0000000..3a82314 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDLocal.java @@ -0,0 +1,53 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd; + +import org.apache.reef.client.LauncherStatus; +import org.apache.reef.examples.group.utils.timer.Timer; +import org.apache.reef.runtime.local.client.LocalRuntimeConfiguration; +import org.apache.reef.tang.Configuration; + +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Runs BGD on the local runtime. + */ +public class BGDLocal { + + private static final Logger LOG = Logger.getLogger(BGDLocal.class.getName()); + + private static final int NUM_LOCAL_THREADS = 20; + private static final int TIMEOUT = 10 * Timer.MINUTES; + + public static void main(final String[] args) throws Exception { + + final BGDClient bgdClient = BGDClient.fromCommandLine(args); + + final Configuration runtimeConfiguration = LocalRuntimeConfiguration.CONF + .set(LocalRuntimeConfiguration.NUMBER_OF_THREADS, "" + NUM_LOCAL_THREADS) + .build(); + + final String jobName = System.getProperty("user.name") + "-" + "ResourceAwareBGDLocal"; + + final LauncherStatus status = bgdClient.run(runtimeConfiguration, jobName, TIMEOUT); + + LOG.log(Level.INFO, "OUT: Status = {0}", status); + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDYarn.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDYarn.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDYarn.java new file mode 100644 index 0000000..19d3b10 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDYarn.java @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd; + +import org.apache.reef.client.LauncherStatus; +import org.apache.reef.examples.group.utils.timer.Timer; +import org.apache.reef.runtime.yarn.client.YarnClientConfiguration; +import org.apache.reef.tang.Configuration; + +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Runs BGD on the YARN runtime. 
+ */ +public class BGDYarn { + + private static final Logger LOG = Logger.getLogger(BGDYarn.class.getName()); + + private static final int TIMEOUT = 4 * Timer.HOURS; + + public static void main(final String[] args) throws Exception { + + final BGDClient bgdClient = BGDClient.fromCommandLine(args); + + final Configuration runtimeConfiguration = YarnClientConfiguration.CONF + .set(YarnClientConfiguration.JVM_HEAP_SLACK, "0.1") + .build(); + + final String jobName = System.getProperty("user.name") + "-" + "BR-ResourceAwareBGD-YARN"; + + final LauncherStatus status = bgdClient.run(runtimeConfiguration, jobName, TIMEOUT); + + LOG.log(Level.INFO, "OUT: Status = {0}", status); + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ControlMessages.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ControlMessages.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ControlMessages.java new file mode 100644 index 0000000..aeea56b --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ControlMessages.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
/**
 * Commands the BGD master broadcasts to tell the other tasks which step comes next.
 *
 * In MasterTask each command is sent on the control-message broadcaster immediately
 * before the payload broadcast noted on the constant. The declaration order is kept
 * as-is because ordinals (and Java serialization) depend on it.
 */
public enum ControlMessages implements Serializable {
  /** Compute loss and gradient; the master follows up by broadcasting the full model. */
  ComputeGradientWithModel,
  /** Compute loss and gradient; the master follows up by broadcasting only the chosen step size (minEta). */
  ComputeGradientWithMinEta,
  /** Evaluate the line search; the master follows up by broadcasting only the descent direction. */
  DoLineSearch,
  /** Evaluate the line search; the master follows up by broadcasting the (model, descent direction) pair. */
  DoLineSearchWithModel,
  /** NOTE(review): not sent by MasterTask in this change; presumably a barrier-style message, confirm with SlaveTask. */
  Synchronize,
  /** Sent once by the master after its iteration loop ends. */
  Stop
}
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd; + +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.reef.examples.group.bgd.data.Example; +import org.apache.reef.examples.group.bgd.data.parser.Parser; +import org.apache.reef.io.data.loading.api.DataSet; +import org.apache.reef.io.network.util.Pair; + +import javax.inject.Inject; +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * + */ +public class ExampleList { + + private static final Logger LOG = Logger.getLogger(ExampleList.class.getName()); + + private final List<Example> examples = new ArrayList<>(); + private final DataSet<LongWritable, Text> dataSet; + private final Parser<String> parser; + + @Inject + public ExampleList(final DataSet<LongWritable, Text> dataSet, final Parser<String> parser) { + this.dataSet = dataSet; + this.parser = parser; + } + + /** + * @return the examples + */ + public List<Example> getExamples() { + if (examples.isEmpty()) { + loadData(); + } + return examples; + } + + private void loadData() { + LOG.info("Loading data"); + int i = 0; + for (final Pair<LongWritable, Text> examplePair : dataSet) { + final Example example = parser.parse(examplePair.second.toString()); + examples.add(example); + if (++i % 2000 == 0) { + LOG.log(Level.FINE, "Done parsing {0} lines", i); + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LineSearchReduceFunction.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LineSearchReduceFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LineSearchReduceFunction.java new file mode 
100644 index 0000000..9132583 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LineSearchReduceFunction.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd; + +import org.apache.reef.examples.group.utils.math.DenseVector; +import org.apache.reef.examples.group.utils.math.Vector; +import org.apache.reef.io.network.group.api.operators.Reduce; +import org.apache.reef.io.network.util.Pair; + +import javax.inject.Inject; + +public class LineSearchReduceFunction implements Reduce.ReduceFunction<Pair<Vector, Integer>> { + + @Inject + public LineSearchReduceFunction() { + } + + @Override + public Pair<Vector, Integer> apply(final Iterable<Pair<Vector, Integer>> evals) { + + Vector combinedEvaluations = null; + int numEx = 0; + + for (final Pair<Vector, Integer> eval : evals) { + if (combinedEvaluations == null) { + combinedEvaluations = new DenseVector(eval.first); + } else { + combinedEvaluations.add(eval.first); + } + numEx += eval.second; + } + + return new Pair<>(combinedEvaluations, numEx); + } +} 
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LossAndGradientReduceFunction.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LossAndGradientReduceFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LossAndGradientReduceFunction.java new file mode 100644 index 0000000..cf4d0be --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LossAndGradientReduceFunction.java @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd; + +import org.apache.reef.examples.group.utils.math.DenseVector; +import org.apache.reef.examples.group.utils.math.Vector; +import org.apache.reef.io.network.group.api.operators.Reduce.ReduceFunction; +import org.apache.reef.io.network.util.Pair; + +import javax.inject.Inject; + +public class LossAndGradientReduceFunction + implements ReduceFunction<Pair<Pair<Double, Integer>, Vector>> { + + @Inject + public LossAndGradientReduceFunction() { + } + + @Override + public Pair<Pair<Double, Integer>, Vector> apply( + final Iterable<Pair<Pair<Double, Integer>, Vector>> lags) { + + double lossSum = 0.0; + int numEx = 0; + Vector combinedGradient = null; + + for (final Pair<Pair<Double, Integer>, Vector> lag : lags) { + if (combinedGradient == null) { + combinedGradient = new DenseVector(lag.second); + } else { + combinedGradient.add(lag.second); + } + lossSum += lag.first.first; + numEx += lag.first.second; + } + + return new Pair<>(new Pair<>(lossSum, numEx), combinedGradient); + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/MasterTask.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/MasterTask.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/MasterTask.java new file mode 100644 index 0000000..06ed5fd --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/MasterTask.java @@ -0,0 +1,246 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd; + +import org.apache.reef.examples.group.bgd.operatornames.*; +import org.apache.reef.examples.group.bgd.parameters.*; +import org.apache.reef.examples.group.bgd.utils.StepSizes; +import org.apache.reef.examples.group.utils.math.DenseVector; +import org.apache.reef.examples.group.utils.math.Vector; +import org.apache.reef.examples.group.utils.timer.Timer; +import org.apache.reef.exception.evaluator.NetworkException; +import org.apache.reef.io.Tuple; +import org.apache.reef.io.network.group.api.operators.Broadcast; +import org.apache.reef.io.network.group.api.operators.Reduce; +import org.apache.reef.io.network.group.api.GroupChanges; +import org.apache.reef.io.network.group.api.task.CommunicationGroupClient; +import org.apache.reef.io.network.group.api.task.GroupCommClient; +import org.apache.reef.io.network.util.Pair; +import org.apache.reef.io.serialization.Codec; +import org.apache.reef.io.serialization.SerializableCodec; +import org.apache.reef.tang.annotations.Parameter; +import org.apache.reef.task.Task; + +import javax.inject.Inject; +import java.util.ArrayList; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class MasterTask implements Task { + + public static final String TASK_ID = "MasterTask"; + + private static final Logger LOG = Logger.getLogger(MasterTask.class.getName()); + + private final 
CommunicationGroupClient communicationGroupClient; + private final Broadcast.Sender<ControlMessages> controlMessageBroadcaster; + private final Broadcast.Sender<Vector> modelBroadcaster; + private final Reduce.Receiver<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer; + private final Broadcast.Sender<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster; + private final Broadcast.Sender<Vector> descentDriectionBroadcaster; + private final Reduce.Receiver<Pair<Vector, Integer>> lineSearchEvaluationsReducer; + private final Broadcast.Sender<Double> minEtaBroadcaster; + private final boolean ignoreAndContinue; + private final StepSizes ts; + private final double lambda; + private final int maxIters; + final ArrayList<Double> losses = new ArrayList<>(); + final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<ArrayList<Double>>(); + private final Vector model; + + boolean sendModel = true; + double minEta = 0; + + @Inject + public MasterTask( + final GroupCommClient groupCommClient, + @Parameter(ModelDimensions.class) final int dimensions, + @Parameter(Lambda.class) final double lambda, + @Parameter(Iterations.class) final int maxIters, + @Parameter(EnableRampup.class) final boolean rampup, + final StepSizes ts) { + + this.lambda = lambda; + this.maxIters = maxIters; + this.ts = ts; + this.ignoreAndContinue = rampup; + this.model = new DenseVector(dimensions); + this.communicationGroupClient = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class); + this.controlMessageBroadcaster = communicationGroupClient.getBroadcastSender(ControlMessageBroadcaster.class); + this.modelBroadcaster = communicationGroupClient.getBroadcastSender(ModelBroadcaster.class); + this.lossAndGradientReducer = communicationGroupClient.getReduceReceiver(LossAndGradientReducer.class); + this.modelAndDescentDirectionBroadcaster = communicationGroupClient.getBroadcastSender(ModelAndDescentDirectionBroadcaster.class); + this.descentDriectionBroadcaster = 
communicationGroupClient.getBroadcastSender(DescentDirectionBroadcaster.class); + this.lineSearchEvaluationsReducer = communicationGroupClient.getReduceReceiver(LineSearchEvaluationsReducer.class); + this.minEtaBroadcaster = communicationGroupClient.getBroadcastSender(MinEtaBroadcaster.class); + } + + @Override + public byte[] call(final byte[] memento) throws Exception { + + double gradientNorm = Double.MAX_VALUE; + for (int iteration = 1; !converged(iteration, gradientNorm); ++iteration) { + try (final Timer t = new Timer("Current Iteration(" + (iteration) + ")")) { + final Pair<Double, Vector> lossAndGradient = computeLossAndGradient(); + losses.add(lossAndGradient.first); + final Vector descentDirection = getDescentDirection(lossAndGradient.second); + + updateModel(descentDirection); + + gradientNorm = descentDirection.norm2(); + } + } + LOG.log(Level.INFO, "OUT: Stop"); + controlMessageBroadcaster.send(ControlMessages.Stop); + + for (final Double loss : losses) { + LOG.log(Level.INFO, "OUT: LOSS = {0}", loss); + } + return lossCodec.encode(losses); + } + + private void updateModel(final Vector descentDirection) throws NetworkException, InterruptedException { + try (final Timer t = new Timer("GetDescentDirection + FindMinEta + UpdateModel")) { + final Vector lineSearchEvals = lineSearch(descentDirection); + minEta = findMinEta(model, descentDirection, lineSearchEvals); + model.multAdd(minEta, descentDirection); + } + + LOG.log(Level.INFO, "OUT: New Model = {0}", model); + } + + private Vector lineSearch(final Vector descentDirection) throws NetworkException, InterruptedException { + Vector lineSearchResults = null; + boolean allDead = false; + do { + try (final Timer t = new Timer("LineSearch - Broadcast(" + + (sendModel ? 
"ModelAndDescentDirection" : "DescentDirection") + ") + Reduce(LossEvalsInLineSearch)")) { + if (sendModel) { + LOG.log(Level.INFO, "OUT: DoLineSearchWithModel"); + controlMessageBroadcaster.send(ControlMessages.DoLineSearchWithModel); + modelAndDescentDirectionBroadcaster.send(new Pair<>(model, descentDirection)); + } else { + LOG.log(Level.INFO, "OUT: DoLineSearch"); + controlMessageBroadcaster.send(ControlMessages.DoLineSearch); + descentDriectionBroadcaster.send(descentDirection); + } + final Pair<Vector, Integer> lineSearchEvals = lineSearchEvaluationsReducer.reduce(); + if (lineSearchEvals != null) { + final int numExamples = lineSearchEvals.second; + lineSearchResults = lineSearchEvals.first; + lineSearchResults.scale(1.0 / numExamples); + LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples); + LOG.log(Level.INFO, "OUT: LineSearchEvals: {0}", lineSearchResults); + allDead = false; + } else { + allDead = true; + } + } + + sendModel = chkAndUpdate(); + } while (allDead || (!ignoreAndContinue && sendModel)); + return lineSearchResults; + } + + private Pair<Double, Vector> computeLossAndGradient() throws NetworkException, InterruptedException { + Pair<Double, Vector> returnValue = null; + boolean allDead = false; + do { + try (final Timer t = new Timer("Broadcast(" + (sendModel ? 
"Model" : "MinEta") + ") + Reduce(LossAndGradient)")) { + if (sendModel) { + LOG.log(Level.INFO, "OUT: ComputeGradientWithModel"); + controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithModel); + modelBroadcaster.send(model); + } else { + LOG.log(Level.INFO, "OUT: ComputeGradientWithMinEta"); + controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithMinEta); + minEtaBroadcaster.send(minEta); + } + final Pair<Pair<Double, Integer>, Vector> lossAndGradient = lossAndGradientReducer.reduce(); + + if (lossAndGradient != null) { + final int numExamples = lossAndGradient.first.second; + LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples); + final double lossPerExample = lossAndGradient.first.first / numExamples; + LOG.log(Level.INFO, "OUT: Loss: {0}", lossPerExample); + final double objFunc = ((lambda / 2) * model.norm2Sqr()) + lossPerExample; + LOG.log(Level.INFO, "OUT: Objective Func Value: {0}", objFunc); + final Vector gradient = lossAndGradient.second; + gradient.scale(1.0 / numExamples); + LOG.log(Level.INFO, "OUT: Gradient: {0}", gradient); + returnValue = new Pair<>(objFunc, gradient); + allDead = false; + } else { + allDead = true; + } + } + sendModel = chkAndUpdate(); + } while (allDead || (!ignoreAndContinue && sendModel)); + return returnValue; + } + + private boolean chkAndUpdate() { + long t1 = System.currentTimeMillis(); + final GroupChanges changes = communicationGroupClient.getTopologyChanges(); + long t2 = System.currentTimeMillis(); + LOG.log(Level.INFO, "OUT: Time to get TopologyChanges = " + (t2 - t1) / 1000.0 + " sec"); + if (changes.exist()) { + LOG.log(Level.INFO, "OUT: There exist topology changes. Asking to update Topology"); + t1 = System.currentTimeMillis(); + communicationGroupClient.updateTopology(); + t2 = System.currentTimeMillis(); + LOG.log(Level.INFO, "OUT: Time to get TopologyChanges = " + (t2 - t1) / 1000.0 + " sec"); + return true; + } else { + LOG.log(Level.INFO, "OUT: No changes in topology exist. 
So not updating topology"); + return false; + } + } + + private boolean converged(final int iters, final double gradNorm) { + return iters >= maxIters || Math.abs(gradNorm) <= 1e-3; + } + + private double findMinEta(final Vector model, final Vector descentDir, final Vector lineSearchEvals) { + final double wNormSqr = model.norm2Sqr(); + final double dNormSqr = descentDir.norm2Sqr(); + final double wDotd = model.dot(descentDir); + final double[] t = ts.getT(); + int i = 0; + for (final double eta : t) { + final double modelNormSqr = wNormSqr + (eta * eta) * dNormSqr + 2 * eta * wDotd; + final double loss = lineSearchEvals.get(i) + ((lambda / 2) * modelNormSqr); + lineSearchEvals.set(i, loss); + ++i; + } + LOG.log(Level.INFO, "OUT: Regularized LineSearchEvals: {0}", lineSearchEvals); + final Tuple<Integer, Double> minTup = lineSearchEvals.min(); + LOG.log(Level.INFO, "OUT: MinTup: {0}", minTup); + final double minT = t[minTup.getKey()]; + LOG.log(Level.INFO, "OUT: MinT: {0}", minT); + return minT; + } + + private Vector getDescentDirection(final Vector gradient) { + gradient.multAdd(lambda, model); + gradient.scale(-1); + LOG.log(Level.INFO, "OUT: DescentDirection: {0}", gradient); + return gradient; + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/SlaveTask.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/SlaveTask.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/SlaveTask.java new file mode 100644 index 0000000..fadc16e --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/SlaveTask.java @@ -0,0 +1,204 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd; + +import org.apache.reef.examples.group.bgd.data.Example; +import org.apache.reef.examples.group.bgd.loss.LossFunction; +import org.apache.reef.examples.group.bgd.operatornames.*; +import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup; +import org.apache.reef.examples.group.bgd.parameters.ProbabilityOfFailure; +import org.apache.reef.examples.group.bgd.utils.StepSizes; +import org.apache.reef.examples.group.utils.math.DenseVector; +import org.apache.reef.examples.group.utils.math.Vector; +import org.apache.reef.io.network.group.api.operators.Broadcast; +import org.apache.reef.io.network.group.api.operators.Reduce; +import org.apache.reef.io.network.group.api.task.CommunicationGroupClient; +import org.apache.reef.io.network.group.api.task.GroupCommClient; +import org.apache.reef.io.network.util.Pair; +import org.apache.reef.tang.annotations.Parameter; +import org.apache.reef.task.Task; + +import javax.inject.Inject; +import java.util.List; +import java.util.logging.Logger; + +public class SlaveTask implements Task { + + private static final Logger LOG = Logger.getLogger(SlaveTask.class.getName()); + + private final double FAILURE_PROB; + + private final CommunicationGroupClient communicationGroup; + 
private final Broadcast.Receiver<ControlMessages> controlMessageBroadcaster; + private final Broadcast.Receiver<Vector> modelBroadcaster; + private final Reduce.Sender<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer; + private final Broadcast.Receiver<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster; + private final Broadcast.Receiver<Vector> descentDirectionBroadcaster; + private final Reduce.Sender<Pair<Vector, Integer>> lineSearchEvaluationsReducer; + private final Broadcast.Receiver<Double> minEtaBroadcaster; + private List<Example> examples = null; + private final ExampleList dataSet; + private final LossFunction lossFunction; + private final StepSizes ts; + + private Vector model = null; + private Vector descentDirection = null; + + @Inject + public SlaveTask( + final GroupCommClient groupCommClient, + final ExampleList dataSet, + final LossFunction lossFunction, + @Parameter(ProbabilityOfFailure.class) final double pFailure, + final StepSizes ts) { + + this.dataSet = dataSet; + this.lossFunction = lossFunction; + this.FAILURE_PROB = pFailure; + LOG.info("Using pFailure=" + this.FAILURE_PROB); + this.ts = ts; + + this.communicationGroup = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class); + this.controlMessageBroadcaster = communicationGroup.getBroadcastReceiver(ControlMessageBroadcaster.class); + this.modelBroadcaster = communicationGroup.getBroadcastReceiver(ModelBroadcaster.class); + this.lossAndGradientReducer = communicationGroup.getReduceSender(LossAndGradientReducer.class); + this.modelAndDescentDirectionBroadcaster = communicationGroup.getBroadcastReceiver(ModelAndDescentDirectionBroadcaster.class); + this.descentDirectionBroadcaster = communicationGroup.getBroadcastReceiver(DescentDirectionBroadcaster.class); + this.lineSearchEvaluationsReducer = communicationGroup.getReduceSender(LineSearchEvaluationsReducer.class); + this.minEtaBroadcaster = communicationGroup.getBroadcastReceiver(MinEtaBroadcaster.class); + } 
+ + @Override + public byte[] call(final byte[] memento) throws Exception { + /* + * In the case where there will be evaluator failure and data is not in + * memory we want to load the data while waiting to join the communication + * group + */ + loadData(); + + for (boolean repeat = true; repeat; ) { + + final ControlMessages controlMessage = controlMessageBroadcaster.receive(); + switch (controlMessage) { + + case Stop: + repeat = false; + break; + + case ComputeGradientWithModel: + failPerhaps(); + this.model = modelBroadcaster.receive(); + lossAndGradientReducer.send(computeLossAndGradient()); + break; + + case ComputeGradientWithMinEta: + failPerhaps(); + final double minEta = minEtaBroadcaster.receive(); + assert (descentDirection != null); + this.descentDirection.scale(minEta); + assert (model != null); + this.model.add(descentDirection); + lossAndGradientReducer.send(computeLossAndGradient()); + break; + + case DoLineSearch: + failPerhaps(); + this.descentDirection = descentDirectionBroadcaster.receive(); + lineSearchEvaluationsReducer.send(lineSearchEvals()); + break; + + case DoLineSearchWithModel: + failPerhaps(); + final Pair<Vector, Vector> modelAndDescentDir = modelAndDescentDirectionBroadcaster.receive(); + this.model = modelAndDescentDir.first; + this.descentDirection = modelAndDescentDir.second; + lineSearchEvaluationsReducer.send(lineSearchEvals()); + break; + + default: + break; + } + } + + return null; + } + + private void failPerhaps() { + if (Math.random() < FAILURE_PROB) { + throw new RuntimeException("Simulated Failure"); + } + } + + private Pair<Vector, Integer> lineSearchEvals() { + + if (examples == null) { + loadData(); + } + + final Vector zed = new DenseVector(examples.size()); + final Vector ee = new DenseVector(examples.size()); + + for (int i = 0; i < examples.size(); i++) { + final Example example = examples.get(i); + double f = example.predict(model); + zed.set(i, f); + f = example.predict(descentDirection); + ee.set(i, f); + } + 
+ final double[] t = ts.getT(); + final Vector evaluations = new DenseVector(t.length); + int i = 0; + for (final double d : t) { + double loss = 0; + for (int j = 0; j < examples.size(); j++) { + final Example example = examples.get(j); + final double val = zed.get(j) + d * ee.get(j); + loss += this.lossFunction.computeLoss(example.getLabel(), val); + } + evaluations.set(i++, loss); + } + + return new Pair<>(evaluations, examples.size()); + } + + private Pair<Pair<Double, Integer>, Vector> computeLossAndGradient() { + + if (examples == null) { + loadData(); + } + + final Vector gradient = new DenseVector(model.size()); + double loss = 0.0; + for (final Example example : examples) { + final double f = example.predict(model); + final double g = this.lossFunction.computeGradient(example.getLabel(), f); + example.addGradient(gradient, g); + loss += this.lossFunction.computeLoss(example.getLabel(), f); + } + + return new Pair<>(new Pair<>(loss, examples.size()), gradient); + } + + private void loadData() { + LOG.info("Loading data"); + examples = dataSet.getExamples(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/Example.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/Example.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/Example.java new file mode 100644 index 0000000..2ec7146 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/Example.java @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.data; + +import org.apache.reef.examples.group.utils.math.Vector; + +import java.io.Serializable; + +/** + * Base interface for Examples for linear models. + */ +public interface Example extends Serializable { + + /** + * Access to the label. + * + * @return the label + */ + double getLabel(); + + /** + * Computes the prediction for this Example, given the model w. + * <p/> + * w.dot(this.getFeatures()) + * + * @param w the model + * @return the prediction for this Example, given the model w. + */ + double predict(Vector w); + + /** + * Adds the current example's gradient to the gradientVector, assuming that + * the gradient with respect to the prediction is gradient. 
+ */ + void addGradient(Vector gradientVector, double gradient); +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/SparseExample.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/SparseExample.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/SparseExample.java new file mode 100644 index 0000000..094f1d8 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/SparseExample.java @@ -0,0 +1,68 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.data; + +import org.apache.reef.examples.group.utils.math.Vector; + +/** + * Example implementation on a index and value array. 
+ */ +public final class SparseExample implements Example { + + private static final long serialVersionUID = -2127500625316875426L; + + private final float[] values; + private final int[] indices; + private final double label; + + public SparseExample(final double label, final float[] values, final int[] indices) { + this.label = label; + this.values = values; + this.indices = indices; + } + + public int getFeatureLength() { + return this.values.length; + } + + @Override + public double getLabel() { + return this.label; + } + + @Override + public double predict(final Vector w) { + double result = 0.0; + for (int i = 0; i < this.indices.length; ++i) { + result += w.get(this.indices[i]) * this.values[i]; + } + return result; + } + + @Override + public void addGradient(final Vector gradientVector, final double gradient) { + for (int i = 0; i < this.indices.length; ++i) { + final int index = this.indices[i]; + final double contribution = gradient * this.values[i]; + final double oldValue = gradientVector.get(index); + final double newValue = oldValue + contribution; + gradientVector.set(index, newValue); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/Parser.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/Parser.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/Parser.java new file mode 100644 index 0000000..f4d8d09 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/Parser.java @@ -0,0 +1,32 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.data.parser; + +import org.apache.reef.examples.group.bgd.data.Example; + +/** + * Parses inputs into Examples. + * + * @param <T> + */ +public interface Parser<T> { + + public Example parse(final T input); + +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/SVMLightParser.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/SVMLightParser.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/SVMLightParser.java new file mode 100644 index 0000000..5f64606 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/SVMLightParser.java @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.data.parser; + +import org.apache.commons.lang.StringUtils; +import org.apache.reef.examples.group.bgd.data.Example; +import org.apache.reef.examples.group.bgd.data.SparseExample; + +import javax.inject.Inject; +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A Parser for SVMLight records + */ +public class SVMLightParser implements Parser<String> { + + private static final Logger LOG = Logger.getLogger(SVMLightParser.class.getName()); + + @Inject + public SVMLightParser() { + } + + @Override + public Example parse(final String line) { + + final int entriesCount = StringUtils.countMatches(line, ":"); + final int[] indices = new int[entriesCount]; + final float[] values = new float[entriesCount]; + + final String[] entries = StringUtils.split(line, ' '); + String labelStr = entries[0]; + + final boolean pipeExists = labelStr.indexOf('|') != -1; + if (pipeExists) { + labelStr = labelStr.substring(0, labelStr.indexOf('|')); + } + double label = Double.parseDouble(labelStr); + + if (label != 1) { + label = -1; + } + + for (int j = 1; j < entries.length; ++j) { + final String x = entries[j]; + final String[] entity = StringUtils.split(x, ':'); + final int offset = pipeExists ? 
0 : 1; + indices[j - 1] = Integer.parseInt(entity[0]) - offset; + values[j - 1] = Float.parseFloat(entity[1]); + } + return new SparseExample(label, values, indices); + } + + public static void main(final String[] args) { + final Parser<String> parser = new SVMLightParser(); + for (int i = 0; i < 10; i++) { + final List<SparseExample> examples = new ArrayList<>(); + float avgFtLen = 0; + try (final BufferedReader br = new BufferedReader(new FileReader( + "C:\\Users\\shravan\\data\\splice\\hdi\\hdi_uncomp\\part-r-0000" + i))) { + String line = null; + while ((line = br.readLine()) != null) { + final SparseExample spEx = (SparseExample) parser.parse(line); + avgFtLen += spEx.getFeatureLength(); + examples.add(spEx); + } + } catch (final IOException e) { + throw new RuntimeException("Exception", e); + } + + LOG.log(Level.INFO, "OUT: {0} {1} {2}", + new Object[] { examples.size(), avgFtLen, avgFtLen / examples.size() }); + + examples.clear(); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LogisticLossFunction.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LogisticLossFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LogisticLossFunction.java new file mode 100644 index 0000000..78eb16f --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LogisticLossFunction.java @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.loss; + +import javax.inject.Inject; + +public final class LogisticLossFunction implements LossFunction { + + /** + * Trivial constructor. + */ + @Inject + public LogisticLossFunction() { + } + + @Override + public double computeLoss(final double y, final double f) { + final double predictedTimesLabel = y * f; + return Math.log(1 + Math.exp(-predictedTimesLabel)); + } + + @Override + public double computeGradient(final double y, final double f) { + final double predictedTimesLabel = y * f; + return -y / (1 + Math.exp(predictedTimesLabel)); + } + + @Override + public String toString() { + return "LogisticLossFunction{}"; + } +} + + http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LossFunction.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LossFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LossFunction.java new file mode 100644 index 0000000..e762add --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LossFunction.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.loss; + +import org.apache.reef.tang.annotations.DefaultImplementation; + +/** + * Interface for Loss Functions. + */ +@DefaultImplementation(SquaredErrorLossFunction.class) +public interface LossFunction { + + /** + * Computes the loss incurred by predicting f, if y is the true label. + * + * @param y the label + * @param f the prediction + * @return the loss incurred by predicting f, if y is the true label. + */ + double computeLoss(final double y, final double f); + + /** + * Computes the gradient with respect to f, if y is the true label. 
+ * + * @param y the label + * @param f the prediction + * @return the gradient with respect to f + */ + double computeGradient(final double y, final double f); +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/SquaredErrorLossFunction.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/SquaredErrorLossFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/SquaredErrorLossFunction.java new file mode 100644 index 0000000..327f566 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/SquaredErrorLossFunction.java @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.loss; + +import javax.inject.Inject; + +/** + * The Squared Error {@link LossFunction}. + */ +public class SquaredErrorLossFunction implements LossFunction { + + /** + * Trivial constructor. 
+ */ + @Inject + public SquaredErrorLossFunction() { + } + + @Override + public double computeLoss(double y, double f) { + return Math.pow(y - f, 2.0); + } + + @Override + public double computeGradient(double y, double f) { + return (f - y) * 0.5; + } + + @Override + public String toString() { + return "SquaredErrorLossFunction{}"; + } +}
