http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/CommunicationGroupDriverImpl.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/CommunicationGroupDriverImpl.java b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/CommunicationGroupDriverImpl.java new file mode 100644 index 0000000..d9a4cd9 --- /dev/null +++ b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/CommunicationGroupDriverImpl.java @@ -0,0 +1,451 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.io.network.group.impl.driver; + +import org.apache.reef.annotations.audience.DriverSide; +import org.apache.reef.annotations.audience.Private; +import org.apache.reef.driver.evaluator.FailedEvaluator; +import org.apache.reef.driver.parameters.DriverIdentifier; +import org.apache.reef.driver.task.FailedTask; +import org.apache.reef.driver.task.RunningTask; +import org.apache.reef.driver.task.TaskConfigurationOptions; +import org.apache.reef.io.network.group.api.config.OperatorSpec; +import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver; +import org.apache.reef.io.network.group.api.driver.Topology; +import org.apache.reef.io.network.group.impl.GroupCommunicationMessage; +import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec; +import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec; +import org.apache.reef.io.network.group.impl.config.parameters.CommunicationGroupName; +import org.apache.reef.io.network.group.impl.config.parameters.OperatorName; +import org.apache.reef.io.network.group.impl.config.parameters.SerializedOperConfigs; +import org.apache.reef.io.network.group.impl.utils.BroadcastingEventHandler; +import org.apache.reef.io.network.group.impl.utils.CountingSemaphore; +import org.apache.reef.io.network.group.impl.utils.SetMap; +import org.apache.reef.io.network.group.impl.utils.Utils; +import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos; +import org.apache.reef.tang.Configuration; +import org.apache.reef.tang.Injector; +import org.apache.reef.tang.JavaConfigurationBuilder; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.exceptions.InjectionException; +import org.apache.reef.tang.formats.ConfigurationSerializer; +import org.apache.reef.wake.EStage; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import 
java.util.concurrent.atomic.AtomicBoolean;
+import java.util.logging.Logger;
+
+/**
+ * Driver-side implementation of a communication group.
+ *
+ * Keeps one {@link Topology} per registered operator, tracks a
+ * {@code TaskState} for every task added to the group, and passes
+ * running-task, failed-task and message events on to those topologies.
+ *
+ * Locking as visible in this class: every write to {@code perTaskState}
+ * happens while holding {@code topologiesLock}; {@code configLock},
+ * {@code yetToRunLock} and {@code toBeRemovedLock} are used purely as
+ * wait/notify monitors. No method acquires two of these monitors at once.
+ */
+@DriverSide
+@Private
+public class CommunicationGroupDriverImpl implements CommunicationGroupDriver {
+
+  private static final Logger LOG = Logger.getLogger(CommunicationGroupDriverImpl.class.getName());
+
+  /** Source-class name passed to {@code LOG.entering} / {@code LOG.exiting}. */
+  private static final String CLASS_NAME = "CommunicationGroupDriverImpl";
+
+  private final Class<? extends Name<String>> groupName;
+  private final ConcurrentMap<Class<? extends Name<String>>, OperatorSpec> operatorSpecs = new ConcurrentHashMap<>();
+  private final ConcurrentMap<Class<? extends Name<String>>, Topology> topologies = new ConcurrentHashMap<>();
+  private final Map<String, TaskState> perTaskState = new HashMap<>();
+  private boolean finalised = false;
+  private final ConfigurationSerializer confSerializer;
+  private final EStage<GroupCommunicationMessage> senderStage;
+  private final String driverId;
+  private final int numberOfTasks;
+
+  private final CountingSemaphore allTasksAdded;
+
+  private final Object topologiesLock = new Object();
+  private final Object configLock = new Object();
+  private final AtomicBoolean initializing = new AtomicBoolean(true);
+
+  private final Object yetToRunLock = new Object();
+  private final Object toBeRemovedLock = new Object();
+
+  private final SetMap<MsgKey, IndexedMsg> msgQue = new SetMap<>();
+
+  private final int fanOut;
+
+  /**
+   * Creates the group and registers one topology handler with each of the
+   * four broadcasting handlers passed in.
+   *
+   * @param groupName name class identifying this group
+   * @param confSerializer serializer used for per-operator configurations and log output
+   * @param senderStage stage handed to every topology created by this group
+   * @param commGroupRunningTaskHandler receives a {@code TopologyRunningTaskHandler} for this group
+   * @param commGroupFailedTaskHandler receives a {@code TopologyFailedTaskHandler} for this group
+   * @param commGroupFailedEvaluatorHandler receives a {@code TopologyFailedEvaluatorHandler} for this group
+   * @param commGroupMessageHandler receives a {@code TopologyMessageHandler} for this group
+   * @param driverId identifier bound as {@code DriverIdentifier} in task configurations
+   * @param numberOfTasks initial count of the {@code allTasksAdded} semaphore; also passed to each topology
+   * @param fanOut fan-out passed to each {@code TreeTopology}
+   */
+  public CommunicationGroupDriverImpl(final Class<? extends Name<String>> groupName,
+                                      final ConfigurationSerializer confSerializer,
+                                      final EStage<GroupCommunicationMessage> senderStage,
+                                      final BroadcastingEventHandler<RunningTask> commGroupRunningTaskHandler,
+                                      final BroadcastingEventHandler<FailedTask> commGroupFailedTaskHandler,
+                                      final BroadcastingEventHandler<FailedEvaluator> commGroupFailedEvaluatorHandler,
+                                      final BroadcastingEventHandler<GroupCommunicationMessage> commGroupMessageHandler,
+                                      final String driverId, final int numberOfTasks, final int fanOut) {
+    super();
+    this.groupName = groupName;
+    this.numberOfTasks = numberOfTasks;
+    this.driverId = driverId;
+    this.confSerializer = confSerializer;
+    this.senderStage = senderStage;
+    this.fanOut = fanOut;
+    // getQualifiedName() reads groupName, so this must stay after the assignment above.
+    this.allTasksAdded = new CountingSemaphore(numberOfTasks, getQualifiedName(), topologiesLock);
+
+    commGroupRunningTaskHandler.addHandler(new TopologyRunningTaskHandler(this));
+    commGroupFailedTaskHandler.addHandler(new TopologyFailedTaskHandler(this));
+    commGroupFailedEvaluatorHandler.addHandler(new TopologyFailedEvaluatorHandler(this));
+    commGroupMessageHandler.addHandler(new TopologyMessageHandler(this));
+  }
+
+  /**
+   * Registers a broadcast operator whose topology is rooted at the spec's sender.
+   *
+   * @throws IllegalStateException if {@link #finalise()} has already been called
+   */
+  @Override
+  public CommunicationGroupDriver addBroadcast(final Class<? extends Name<String>> operatorName,
+                                               final BroadcastOperatorSpec spec) {
+    LOG.entering(CLASS_NAME, "addBroadcast", new Object[]{getQualifiedName(), Utils.simpleName(operatorName), spec});
+    ensureNotFinalised();
+    registerTreeTopology(operatorName, spec, spec.getSenderId());
+    LOG.exiting(CLASS_NAME, "addBroadcast", Arrays.toString(new Object[]{getQualifiedName(), Utils.simpleName(operatorName), " added"}));
+    return this;
+  }
+
+  /**
+   * Registers a reduce operator whose topology is rooted at the spec's receiver.
+   *
+   * @throws IllegalStateException if {@link #finalise()} has already been called
+   */
+  @Override
+  public CommunicationGroupDriver addReduce(final Class<? extends Name<String>> operatorName,
+                                            final ReduceOperatorSpec spec) {
+    LOG.entering(CLASS_NAME, "addReduce", new Object[]{getQualifiedName(), Utils.simpleName(operatorName), spec});
+    ensureNotFinalised();
+    LOG.finer(getQualifiedName() + "Adding reduce operator to tree topology: " + spec);
+    registerTreeTopology(operatorName, spec, spec.getReceiverId());
+    LOG.exiting(CLASS_NAME, "addReduce", Arrays.toString(new Object[]{getQualifiedName(), Utils.simpleName(operatorName), " added"}));
+    return this;
+  }
+
+  /** Rejects operator registration once the group has been finalised. */
+  private void ensureNotFinalised() {
+    if (finalised) {
+      throw new IllegalStateException("Can't add more operators to a finalised spec");
+    }
+  }
+
+  /** Stores the spec and creates the {@code TreeTopology} for one operator, rooted at {@code rootTaskId}. */
+  private void registerTreeTopology(final Class<? extends Name<String>> operatorName,
+                                    final OperatorSpec spec,
+                                    final String rootTaskId) {
+    operatorSpecs.put(operatorName, spec);
+    final Topology topology = new TreeTopology(senderStage, groupName, operatorName, driverId, numberOfTasks, fanOut);
+    topology.setRootTask(rootTaskId);
+    topology.setOperatorSpecification(spec);
+    topologies.put(operatorName, topology);
+  }
+
+  /**
+   * Builds the group-communication configuration for a task of this group.
+   *
+   * Blocks on {@code configLock} while the task's state is RUNNING
+   * (see {@link #cantGetConfig(String)}); {@link #failTask(String)} wakes
+   * the waiters.
+   *
+   * @param taskConf task configuration from which the task id is extracted
+   * @return the configuration, or {@code null} if the task was never added to this group
+   * @throws RuntimeException if the waiting thread is interrupted (the interrupt flag is restored)
+   */
+  @Override
+  public Configuration getTaskConfiguration(final Configuration taskConf) {
+    LOG.entering(CLASS_NAME, "getTaskConfiguration", new Object[]{getQualifiedName(), confSerializer.toString(taskConf)});
+    final JavaConfigurationBuilder jcb = Tang.Factory.getTang().newConfigurationBuilder();
+    final String taskId = taskId(taskConf);
+    if (!perTaskState.containsKey(taskId)) {
+      return null;
+    }
+
+    jcb.bindNamedParameter(DriverIdentifier.class, driverId);
+    jcb.bindNamedParameter(CommunicationGroupName.class, groupName.getName());
+    LOG.finest(getQualifiedName() + "Task has been added. Waiting to acquire configLock");
+    synchronized (configLock) {
+      LOG.finest(getQualifiedName() + "Acquired configLock");
+      while (cantGetConfig(taskId)) {
+        LOG.finest(getQualifiedName() + "Need to wait for failure");
+        try {
+          configLock.wait();
+        } catch (final InterruptedException e) {
+          throw interruptedWhileWaitingOn("configLock", e);
+        }
+      }
+      LOG.finest(getQualifiedName() + taskId + " - Will fetch configuration now.");
+      LOG.finest(getQualifiedName() + "Released configLock. Waiting to acquire topologiesLock");
+    }
+    synchronized (topologiesLock) {
+      LOG.finest(getQualifiedName() + "Acquired topologiesLock");
+      for (final Map.Entry<Class<? extends Name<String>>, OperatorSpec> operSpecEntry : operatorSpecs.entrySet()) {
+        final Class<? extends Name<String>> operName = operSpecEntry.getKey();
+        final Topology topology = topologies.get(operName);
+        // One serialized inner configuration per operator, added to the SerializedOperConfigs set.
+        final JavaConfigurationBuilder jcbInner = Tang.Factory.getTang()
+            .newConfigurationBuilder(topology.getTaskConfiguration(taskId));
+        jcbInner.bindNamedParameter(DriverIdentifier.class, driverId);
+        jcbInner.bindNamedParameter(OperatorName.class, operName.getName());
+        jcb.bindSetEntry(SerializedOperConfigs.class, confSerializer.toString(jcbInner.build()));
+      }
+      LOG.finest(getQualifiedName() + "Released topologiesLock");
+    }
+
+    final Configuration configuration = jcb.build();
+    LOG.exiting(CLASS_NAME, "getTaskConfiguration", Arrays.toString(new Object[]{getQualifiedName(), confSerializer.toString(configuration)}));
+    return configuration;
+  }
+
+  /**
+   * @return {@code true} only while the recorded state of {@code taskId} is RUNNING
+   */
+  private boolean cantGetConfig(final String taskId) {
+    LOG.entering(CLASS_NAME, "cantGetConfig", new Object[]{getQualifiedName(), taskId});
+    // NOTE(review): state is null if removeTask() ran while the caller was waiting; not handled here.
+    final TaskState taskState = perTaskState.get(taskId);
+    if (taskState.equals(TaskState.NOT_STARTED)) {
+      LOG.exiting(CLASS_NAME, "cantGetConfig", Arrays.toString(new Object[]{false, getQualifiedName(), taskId, " has not started. We can get config"}));
+      return false;
+    }
+    LOG.finest(getQualifiedName() + taskId + " has started.");
+    if (taskState.equals(TaskState.RUNNING)) {
+      LOG.exiting(CLASS_NAME, "cantGetConfig", Arrays.toString(new Object[]{true, getQualifiedName(), taskId, " is running. We can't get config"}));
+      return true;
+    }
+    LOG.exiting(CLASS_NAME, "cantGetConfig", Arrays.toString(new Object[]{false, getQualifiedName(), taskId, " has failed. We can get config"}));
+    return false;
+  }
+
+  /** Marks the group as finalised; later addBroadcast/addReduce calls throw. */
+  @Override
+  public void finalise() {
+    finalised = true;
+  }
+
+  /**
+   * Adds a task to every topology of this group and records it as NOT_STARTED.
+   *
+   * If a task with the same id is still registered, blocks on
+   * {@code toBeRemovedLock} until {@link #removeTask(String)} has removed it.
+   *
+   * @param partialTaskConf task configuration from which the task id is extracted
+   * @throws RuntimeException if the waiting thread is interrupted (the interrupt flag is restored)
+   */
+  @Override
+  public void addTask(final Configuration partialTaskConf) {
+    LOG.entering(CLASS_NAME, "addTask", new Object[]{getQualifiedName(), confSerializer.toString(partialTaskConf)});
+    final String taskId = taskId(partialTaskConf);
+    LOG.finest(getQualifiedName() + "AddTask(" + taskId + "). Waiting to acquire toBeRemovedLock");
+    synchronized (toBeRemovedLock) {
+      LOG.finest(getQualifiedName() + "Acquired toBeRemovedLock");
+      while (perTaskState.containsKey(taskId)) {
+        LOG.finest(getQualifiedName() + "Trying to add an existing task. Will wait for removeTask");
+        try {
+          toBeRemovedLock.wait();
+        } catch (final InterruptedException e) {
+          throw interruptedWhileWaitingOn("toBeRemovedLock", e);
+        }
+      }
+      LOG.finest(getQualifiedName() + "Released toBeRemovedLock. Waiting to acquire topologiesLock");
+    }
+    synchronized (topologiesLock) {
+      LOG.finest(getQualifiedName() + "Acquired topologiesLock");
+      for (final Class<? extends Name<String>> operName : operatorSpecs.keySet()) {
+        topologies.get(operName).addTask(taskId);
+      }
+      perTaskState.put(taskId, TaskState.NOT_STARTED);
+      LOG.finest(getQualifiedName() + "Released topologiesLock");
+    }
+    LOG.fine(getQualifiedName() + "Added " + taskId + " to topology");
+    LOG.exiting(CLASS_NAME, "addTask", Arrays.toString(new Object[]{getQualifiedName(), "Added task: ", taskId}));
+  }
+
+  /**
+   * Removes a task from every topology and from the state map, then wakes
+   * any {@link #addTask(Configuration)} call waiting to re-add the same id.
+   */
+  public void removeTask(final String taskId) {
+    LOG.entering(CLASS_NAME, "removeTask", new Object[]{getQualifiedName(), taskId});
+    LOG.info(getQualifiedName() + "Removing Task " + taskId +
+        " as the evaluator has failed.");
+    LOG.finest(getQualifiedName() + "Remove Task(" + taskId +
+        "): Waiting to acquire topologiesLock");
+    synchronized (topologiesLock) {
+      LOG.finest(getQualifiedName() + "Acquired topologiesLock");
+      for (final Class<? extends Name<String>> operName : operatorSpecs.keySet()) {
+        topologies.get(operName).removeTask(taskId);
+      }
+      perTaskState.remove(taskId);
+      LOG.finest(getQualifiedName() + "Released topologiesLock. Waiting to acquire toBeRemovedLock");
+    }
+    synchronized (toBeRemovedLock) {
+      LOG.finest(getQualifiedName() + "Acquired toBeRemovedLock");
+      LOG.finest(getQualifiedName() + "Removed Task " + taskId + " Notifying waiting threads");
+      toBeRemovedLock.notifyAll();
+      LOG.finest(getQualifiedName() + "Released toBeRemovedLock");
+    }
+    LOG.fine(getQualifiedName() + "Removed " + taskId + " from topology");
+    LOG.exiting(CLASS_NAME, "removeTask", Arrays.toString(new Object[]{getQualifiedName(), "Removed task: ", taskId}));
+  }
+
+  /**
+   * Records that a task is running, informs every topology, and wakes
+   * {@link #failTask(String)} callers waiting on {@code yetToRunLock}.
+   * Ids that were never added to this group are ignored.
+   */
+  public void runTask(final String id) {
+    LOG.entering(CLASS_NAME, "runTask", new Object[]{getQualifiedName(), id});
+    LOG.finest(getQualifiedName() + "Task-" + id + " running. Waiting to acquire topologiesLock");
+    LOG.fine(getQualifiedName() + "Got running Task: " + id);
+
+    boolean nonMember = false;
+    synchronized (topologiesLock) {
+      if (perTaskState.containsKey(id)) {
+        LOG.finest(getQualifiedName() + "Acquired topologiesLock");
+        for (final Class<? extends Name<String>> operName : operatorSpecs.keySet()) {
+          topologies.get(operName).onRunningTask(id);
+        }
+        allTasksAdded.decrement();
+        perTaskState.put(id, TaskState.RUNNING);
+        LOG.finest(getQualifiedName() + "Released topologiesLock. Waiting to acquire yetToRunLock");
+      } else {
+        nonMember = true;
+      }
+    }
+    synchronized (yetToRunLock) {
+      LOG.finest(getQualifiedName() + "Acquired yetToRunLock");
+      yetToRunLock.notifyAll();
+      LOG.finest(getQualifiedName() + "Released yetToRunLock");
+    }
+    if (nonMember) {
+      LOG.exiting(CLASS_NAME, "runTask", getQualifiedName() + id + " does not belong to this communication group. Ignoring");
+    } else {
+      LOG.fine(getQualifiedName() + "Status of task " + id + " changed to RUNNING");
+      LOG.exiting(CLASS_NAME, "runTask", Arrays.toString(new Object[]{getQualifiedName(), "Set running complete on task ", id}));
+    }
+  }
+
+  /**
+   * Records that a task has failed: waits until the task is RUNNING,
+   * informs every topology, drops queued messages whose source is the
+   * failed task, and wakes {@link #getTaskConfiguration(Configuration)}
+   * callers waiting on {@code configLock}.
+   *
+   * Ids that were never added to this group are ignored, mirroring
+   * {@link #runTask(String)}.
+   *
+   * @throws RuntimeException if the waiting thread is interrupted (the interrupt flag is restored)
+   */
+  public void failTask(final String id) {
+    LOG.entering(CLASS_NAME, "failTask", new Object[]{getQualifiedName(), id});
+    LOG.finest(getQualifiedName() + "Task-" + id + " failed. Waiting to acquire yetToRunLock");
+    LOG.fine(getQualifiedName() + "Got failed Task: " + id);
+
+    // Without this check cantFailTask() dereferences a null state for ids unknown to this group.
+    final boolean member;
+    synchronized (topologiesLock) {
+      member = perTaskState.containsKey(id);
+    }
+    if (!member) {
+      LOG.exiting(CLASS_NAME, "failTask", getQualifiedName() + id + " does not belong to this communication group. Ignoring");
+      return;
+    }
+
+    synchronized (yetToRunLock) {
+      LOG.finest(getQualifiedName() + "Acquired yetToRunLock");
+      while (cantFailTask(id)) {
+        LOG.finest(getQualifiedName() + "Need to wait for it run");
+        try {
+          yetToRunLock.wait();
+        } catch (final InterruptedException e) {
+          throw interruptedWhileWaitingOn("yetToRunLock", e);
+        }
+      }
+      LOG.finest(getQualifiedName() + id + " - Can safely set failure.");
+      LOG.finest(getQualifiedName() + "Released yetToRunLock. Waiting to acquire topologiesLock");
+    }
+    synchronized (topologiesLock) {
+      LOG.finest(getQualifiedName() + "Acquired topologiesLock");
+      for (final Class<? extends Name<String>> operName : operatorSpecs.keySet()) {
+        topologies.get(operName).onFailedTask(id);
+      }
+      allTasksAdded.increment();
+      perTaskState.put(id, TaskState.FAILED);
+      LOG.finest(getQualifiedName() + "Removing msgs associated with dead task " + id + " from msgQue.");
+      // Collect first, remove afterwards, so the key set is not modified while it is iterated.
+      final List<MsgKey> keysToBeRemoved = new ArrayList<>();
+      for (final MsgKey msgKey : msgQue.keySet()) {
+        if (msgKey.getSrc().equals(id)) {
+          keysToBeRemoved.add(msgKey);
+        }
+      }
+      LOG.finest(getQualifiedName() + keysToBeRemoved + " keys that will be removed");
+      for (final MsgKey key : keysToBeRemoved) {
+        msgQue.remove(key);
+      }
+      LOG.finest(getQualifiedName() + "Released topologiesLock. Waiting to acquire configLock");
+    }
+    synchronized (configLock) {
+      LOG.finest(getQualifiedName() + "Acquired configLock");
+      configLock.notifyAll();
+      LOG.finest(getQualifiedName() + "Released configLock");
+    }
+    LOG.fine(getQualifiedName() + "Status of task " + id + " changed to FAILED");
+    LOG.exiting(CLASS_NAME, "failTask", Arrays.toString(new Object[]{getQualifiedName(), "Set failed complete on task ", id}));
+  }
+
+  /**
+   * @return {@code false} only when the recorded state of {@code taskId} is RUNNING
+   */
+  private boolean cantFailTask(final String taskId) {
+    LOG.entering(CLASS_NAME, "cantFailTask", new Object[]{getQualifiedName(), taskId});
+    // NOTE(review): state is null if removeTask() ran while the caller was waiting; not handled here.
+    final TaskState taskState = perTaskState.get(taskId);
+    if (taskState.equals(TaskState.NOT_STARTED)) {
+      LOG.exiting(CLASS_NAME, "cantFailTask", Arrays.toString(new Object[]{true, getQualifiedName(), taskId, " has not started. We can't fail a task that hasn't started"}));
+      return true;
+    }
+    LOG.finest(getQualifiedName() + taskId + " has started.");
+    if (!taskState.equals(TaskState.RUNNING)) {
+      LOG.exiting(CLASS_NAME, "cantFailTask", Arrays.toString(new Object[]{true, getQualifiedName(), taskId, " is not running yet. Can't set failure"}));
+      return true;
+    }
+    LOG.exiting(CLASS_NAME, "cantFailTask", Arrays.toString(new Object[]{false, getQualifiedName(), taskId, " is running. Can set failure"}));
+    return false;
+  }
+
+  /**
+   * Queues a message under its {@code MsgKey}; once the number of queued
+   * messages for that key equals the number of topologies, removes them
+   * from the queue and hands each to the topology of its operator.
+   *
+   * @throws RuntimeException if the queue already contains this message for its key
+   */
+  public void queNProcessMsg(final GroupCommunicationMessage msg) {
+    LOG.entering(CLASS_NAME, "queNProcessMsg", new Object[]{getQualifiedName(), msg});
+    final IndexedMsg indMsg = new IndexedMsg(msg);
+    final Class<? extends Name<String>> operName = indMsg.getOperName();
+    final MsgKey key = new MsgKey(msg);
+    if (msgQue.contains(key, indMsg)) {
+      throw new RuntimeException(getQualifiedName() + "MsgQue already contains " + msg.getType() + " msg for " + key + " in "
+          + Utils.simpleName(operName));
+    }
+    LOG.finest(getQualifiedName() + "Adding msg to que");
+    msgQue.add(key, indMsg);
+    if (msgQue.count(key) == topologies.size()) {
+      LOG.finest(getQualifiedName() + "MsgQue for " + key + " contains " + msg.getType() + " msgs from: "
+          + msgQue.get(key));
+      for (final IndexedMsg innerIndMsg : msgQue.remove(key)) {
+        topologies.get(innerIndMsg.getOperName()).onReceiptOfMessage(innerIndMsg.getMsg());
+      }
+      LOG.finest(getQualifiedName() + "All msgs processed and removed");
+    }
+    LOG.exiting(CLASS_NAME, "queNProcessMsg", Arrays.toString(new Object[]{getQualifiedName(), "Que & Process done for: ", msg}));
+  }
+
+  /**
+   * Compares the source version carried by the message with the version
+   * the operator's topology reports for the source node.
+   *
+   * @throws RuntimeException if the message carries no version
+   */
+  private boolean isMsgVersionOk(final GroupCommunicationMessage msg) {
+    LOG.entering(CLASS_NAME, "isMsgVersionOk", new Object[]{getQualifiedName(), msg});
+    if (!msg.hasVersion()) {
+      throw new RuntimeException(getQualifiedName() + "can only deal with versioned msgs");
+    }
+    final String srcId = msg.getSrcid();
+    final int rcvSrcVersion = msg.getSrcVersion();
+    final int expSrcVersion = topologies.get(Utils.getClass(msg.getOperatorname())).getNodeVersion(srcId);
+
+    final boolean srcVersionChk = chkVersion(rcvSrcVersion, expSrcVersion, "Src Version Check: ");
+    LOG.exiting(CLASS_NAME, "isMsgVersionOk", Arrays.toString(new Object[]{srcVersionChk, getQualifiedName(), msg}));
+    return srcVersionChk;
+  }
+
+  /**
+   * @return {@code true} iff the received version equals the expected one; a mismatch is logged as a warning
+   */
+  private boolean chkVersion(final int rcvVersion, final int version, final String msg) {
+    if (rcvVersion < version) {
+      LOG.warning(getQualifiedName() + msg + "received a ver-" + rcvVersion + " msg while expecting ver-" + version);
+      return false;
+    }
+    if (rcvVersion > version) {
+      LOG.warning(getQualifiedName() + msg + "received a HIGHER ver-" + rcvVersion + " msg while expecting ver-"
+          + version + ". Something fishy!!!");
+      return false;
+    }
+    return true;
+  }
+
+  /**
+   * Entry point for an incoming group communication message.
+   *
+   * Under {@code topologiesLock}: discards the message if its version
+   * check fails; while the group is still initializing, or when the
+   * message is of type UpdateTopology, first calls
+   * {@code allTasksAdded.await()}; then queues the message via
+   * {@link #queNProcessMsg(GroupCommunicationMessage)}.
+   */
+  public void processMsg(final GroupCommunicationMessage msg) {
+    LOG.entering(CLASS_NAME, "processMsg", new Object[]{getQualifiedName(), msg});
+    LOG.finest(getQualifiedName() + "ProcessMsg: " + msg + ". Waiting to acquire topologiesLock");
+    synchronized (topologiesLock) {
+      LOG.finest(getQualifiedName() + "Acquired topologiesLock");
+      if (!isMsgVersionOk(msg)) {
+        LOG.finer(getQualifiedName() + "Discarding msg. Released topologiesLock");
+        return;
+      }
+      if (initializing.get() || msg.getType().equals(ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology)) {
+        LOG.fine(getQualifiedName() + msg.getSimpleOperName() + ": Waiting for all required(" + allTasksAdded.getInitialCount()
+            + ") nodes to run");
+        // NOTE(review): the semaphore was built with topologiesLock; presumably await() waits on that
+        // monitor and so releases it while blocked - confirm against CountingSemaphore.
+        allTasksAdded.await();
+        LOG.fine(getQualifiedName() + msg.getSimpleOperName() + ": All required(" + allTasksAdded.getInitialCount()
+            + ") nodes are running");
+        initializing.compareAndSet(true, false);
+      }
+      queNProcessMsg(msg);
+      LOG.finest(getQualifiedName() + "Released topologiesLock");
+    }
+    LOG.exiting(CLASS_NAME, "processMsg", Arrays.toString(new Object[]{getQualifiedName(), "ProcessMsg done for: ", msg}));
+  }
+
+  /**
+   * Restores the caller's interrupt flag and builds the exception thrown
+   * when a wait on one of this class's monitors is interrupted.
+   *
+   * @param lockName field name of the monitor, used in the message
+   * @param cause the interruption that ended the wait
+   */
+  private RuntimeException interruptedWhileWaitingOn(final String lockName, final InterruptedException cause) {
+    // Catching InterruptedException clears the flag; set it again so code further up can see it.
+    Thread.currentThread().interrupt();
+    return new RuntimeException(getQualifiedName() + "InterruptedException while waiting on " + lockName, cause);
+  }
+
+  /**
+   * Extracts the task identifier bound in the given task configuration.
+   *
+   * @throws RuntimeException wrapping the {@link InjectionException} if the identifier cannot be injected
+   */
+  private String taskId(final Configuration partialTaskConf) {
+    try {
+      final Injector injector = Tang.Factory.getTang().newInjector(partialTaskConf);
+      return injector.getNamedInstance(TaskConfigurationOptions.Identifier.class);
+    } catch (final InjectionException e) {
+      throw new RuntimeException(getQualifiedName() + "Injection exception while extracting taskId from partialTaskConf", e);
+    }
+  }
+
+  /** @return the simple name of the group followed by {@code " - "}, used as log prefix */
+  private String getQualifiedName() {
+    return Utils.simpleName(groupName) + " - ";
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/CtrlMsgSender.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/CtrlMsgSender.java b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/CtrlMsgSender.java new file mode 100644 index 0000000..5535511 --- /dev/null +++ b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/CtrlMsgSender.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.reef.io.network.group.impl.driver;
+
+
+import org.apache.reef.exception.evaluator.NetworkException;
+import org.apache.reef.io.network.Connection;
+import org.apache.reef.io.network.impl.NetworkService;
+import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
+import org.apache.reef.wake.EventHandler;
+import org.apache.reef.wake.Identifier;
+import org.apache.reef.wake.IdentifierFactory;
+
+import java.util.logging.Logger;
+
+/**
+ * Event handler that receives control messages and dispatches each one
+ * to the message's destination id through the network service.
+ */
+public class CtrlMsgSender implements EventHandler<GroupCommunicationMessage> {
+
+  private static final Logger LOG = Logger.getLogger(CtrlMsgSender.class.getName());
+
+  /** Turns the destination id string of a message into an {@link Identifier}. */
+  private final IdentifierFactory identifierFactory;
+
+  /** Supplies the connection on which a message is written. */
+  private final NetworkService<GroupCommunicationMessage> networkService;
+
+  /**
+   * @param idFac factory used to build the destination identifier
+   * @param netService network service used to obtain connections
+   */
+  public CtrlMsgSender(final IdentifierFactory idFac, final NetworkService<GroupCommunicationMessage> netService) {
+    this.identifierFactory = idFac;
+    this.networkService = netService;
+  }
+
+  /**
+   * Opens a connection to the message's destination and writes the message to it.
+   *
+   * @param srcCtrlMsg the control message to send
+   * @throws RuntimeException wrapping the {@link NetworkException} if opening or writing fails
+   */
+  @Override
+  public void onNext(final GroupCommunicationMessage srcCtrlMsg) {
+    LOG.entering("CtrlMsgSender", "onNext", srcCtrlMsg);
+    final Identifier destination = identifierFactory.getNewInstance(srcCtrlMsg.getDestid());
+    deliver(srcCtrlMsg, destination);
+    LOG.exiting("CtrlMsgSender", "onNext", srcCtrlMsg);
+  }
+
+  /** Obtains a connection to {@code destination}, opens it and writes {@code message}. */
+  private void deliver(final GroupCommunicationMessage message, final Identifier destination) {
+    final Connection<GroupCommunicationMessage> connection = networkService.newConnection(destination);
+    try {
+      connection.open();
+      connection.write(message);
+    } catch (final NetworkException e) {
+      throw new RuntimeException("Unable to send ctrl task msg to parent " + destination, e);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/ExceptionHandler.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/ExceptionHandler.java
b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/ExceptionHandler.java
new file mode 100644
index 0000000..38a1df6
--- /dev/null
+++ b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/ExceptionHandler.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.io.network.group.impl.driver;
+
+import org.apache.reef.wake.EventHandler;
+
+import javax.inject.Inject;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.logging.Logger;
+
+/**
+ * Event handler that stores every exception it receives so that callers
+ * can later ask whether any were reported. Both public methods are
+ * synchronized on this instance.
+ */
+public class ExceptionHandler implements EventHandler<Exception> {
+  private static final Logger LOG = Logger.getLogger(ExceptionHandler.class.getName());
+
+  /** Exceptions received since the last call to {@link #hasExceptions()}. */
+  List<Exception> exceptions = new ArrayList<>();
+
+  @Inject
+  public ExceptionHandler() {
+  }
+
+  /**
+   * Appends the exception to the internal list.
+   *
+   * @param ex the exception being reported
+   */
+  @Override
+  public synchronized void onNext(final Exception ex) {
+    LOG.entering("ExceptionHandler", "onNext", new Object[]{ex});
+    exceptions.add(ex);
+    final int collected = exceptions.size();
+    LOG.finest("Got an exception. Added it to list(" + collected + ")");
+    LOG.exiting("ExceptionHandler", "onNext");
+  }
+
+  /**
+   * Reports whether any exception has been received, and empties the list
+   * as a side effect: a second call without new exceptions returns
+   * {@code false}.
+   *
+   * @return {@code true} if at least one exception was stored before this call
+   */
+  public synchronized boolean hasExceptions() {
+    LOG.entering("ExceptionHandler", "hasExceptions");
+    final int pending = exceptions.size();
+    final boolean anyPending = pending > 0;
+    LOG.finest("There are " + pending + " exceptions. Clearing now");
+    exceptions.clear();
+    LOG.exiting("ExceptionHandler", "hasExceptions", anyPending);
+    return anyPending;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/FlatTopology.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/FlatTopology.java b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/FlatTopology.java
new file mode 100644
index 0000000..70670a2
--- /dev/null
+++ b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/FlatTopology.java
@@ -0,0 +1,307 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ +package org.apache.reef.io.network.group.impl.driver; + +import org.apache.reef.io.network.group.api.operators.GroupCommOperator; +import org.apache.reef.io.network.group.api.GroupChanges; +import org.apache.reef.io.network.group.api.config.OperatorSpec; +import org.apache.reef.io.network.group.api.driver.TaskNode; +import org.apache.reef.io.network.group.api.driver.Topology; +import org.apache.reef.io.network.group.impl.GroupChangesCodec; +import org.apache.reef.io.network.group.impl.GroupChangesImpl; +import org.apache.reef.io.network.group.impl.GroupCommunicationMessage; +import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec; +import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec; +import org.apache.reef.io.network.group.impl.config.parameters.DataCodec; +import org.apache.reef.io.network.group.impl.config.parameters.ReduceFunctionParam; +import org.apache.reef.io.network.group.impl.config.parameters.TaskVersion; +import org.apache.reef.io.network.group.impl.operators.BroadcastReceiver; +import org.apache.reef.io.network.group.impl.operators.BroadcastSender; +import org.apache.reef.io.network.group.impl.operators.ReduceReceiver; +import org.apache.reef.io.network.group.impl.operators.ReduceSender; +import org.apache.reef.io.network.group.impl.utils.Utils; +import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos; +import org.apache.reef.io.serialization.Codec; +import org.apache.reef.tang.Configuration; +import org.apache.reef.tang.JavaConfigurationBuilder; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.wake.EStage; +import org.apache.reef.wake.EventHandler; +import org.apache.reef.wake.impl.SingleThreadStage; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.logging.Logger; + +/** + * Implements a one level Tree 
Topology + */ +public class FlatTopology implements Topology { + + private static final Logger LOG = Logger.getLogger(FlatTopology.class.getName()); + + private final EStage<GroupCommunicationMessage> senderStage; + private final Class<? extends Name<String>> groupName; + private final Class<? extends Name<String>> operName; + private final String driverId; + private String rootId; + private OperatorSpec operatorSpec; + + private TaskNode root; + private final ConcurrentMap<String, TaskNode> nodes = new ConcurrentSkipListMap<>(); + + public FlatTopology(final EStage<GroupCommunicationMessage> senderStage, + final Class<? extends Name<String>> groupName, final Class<? extends Name<String>> operatorName, + final String driverId, final int numberOfTasks) { + this.senderStage = senderStage; + this.groupName = groupName; + this.operName = operatorName; + this.driverId = driverId; + } + + @Override + public void setRootTask(final String rootId) { + this.rootId = rootId; + } + + /** + * @return the rootId + */ + @Override + public String getRootId() { + return rootId; + } + + @Override + public void setOperatorSpecification(final OperatorSpec spec) { + this.operatorSpec = spec; + } + + @Override + public Configuration getTaskConfiguration(final String taskId) { + LOG.finest(getQualifiedName() + "Getting config for task " + taskId); + final TaskNode taskNode = nodes.get(taskId); + if (taskNode == null) { + throw new RuntimeException(getQualifiedName() + taskId + " does not exist"); + } + + final int version; + version = getNodeVersion(taskId); + final JavaConfigurationBuilder jcb = Tang.Factory.getTang().newConfigurationBuilder(); + jcb.bindNamedParameter(DataCodec.class, operatorSpec.getDataCodecClass()); + jcb.bindNamedParameter(TaskVersion.class, Integer.toString(version)); + if (operatorSpec instanceof BroadcastOperatorSpec) { + final BroadcastOperatorSpec broadcastOperatorSpec = (BroadcastOperatorSpec) operatorSpec; + if 
(taskId.equals(broadcastOperatorSpec.getSenderId())) { + jcb.bindImplementation(GroupCommOperator.class, BroadcastSender.class); + } else { + jcb.bindImplementation(GroupCommOperator.class, BroadcastReceiver.class); + } + } + if (operatorSpec instanceof ReduceOperatorSpec) { + final ReduceOperatorSpec reduceOperatorSpec = (ReduceOperatorSpec) operatorSpec; + jcb.bindNamedParameter(ReduceFunctionParam.class, reduceOperatorSpec.getRedFuncClass()); + if (taskId.equals(reduceOperatorSpec.getReceiverId())) { + jcb.bindImplementation(GroupCommOperator.class, ReduceReceiver.class); + } else { + jcb.bindImplementation(GroupCommOperator.class, ReduceSender.class); + } + } + return jcb.build(); + } + + @Override + public int getNodeVersion(final String taskId) { + final TaskNode node = nodes.get(taskId); + if (node == null) { + throw new RuntimeException(getQualifiedName() + taskId + " is not available on the nodes map"); + } + final int version = node.getVersion(); + return version; + } + + @Override + public void removeTask(final String taskId) { + if (!nodes.containsKey(taskId)) { + LOG.warning("Trying to remove a non-existent node in the task graph"); + return; + } + if (taskId.equals(rootId)) { + unsetRootNode(taskId); + } else { + removeChild(taskId); + } + } + + @Override + public void addTask(final String taskId) { + if (nodes.containsKey(taskId)) { + LOG.warning("Got a request to add a task that is already in the graph"); + LOG.warning("We need to block this request till the delete finishes"); + } + if (taskId.equals(rootId)) { + setRootNode(taskId); + } else { + addChild(taskId); + } + } + + /** + * @param taskId + */ + private void addChild(final String taskId) { + LOG.finest(getQualifiedName() + "Adding leaf " + taskId); + final TaskNode node = new TaskNodeImpl(senderStage, groupName, operName, taskId, driverId, false); + final TaskNode leaf = node; + if (root != null) { + leaf.setParent(root); + root.addChild(leaf); + } + nodes.put(taskId, leaf); + } + + /** + * 
@param taskId + */ + private void removeChild(final String taskId) { + LOG.finest(getQualifiedName() + "Removing leaf " + taskId); + if (root != null) { + root.removeChild(nodes.get(taskId)); + } + nodes.remove(taskId); + } + + private void setRootNode(final String rootId) { + LOG.finest(getQualifiedName() + "Setting " + rootId + " as root"); + final TaskNode node = new TaskNodeImpl(senderStage, groupName, operName, rootId, driverId, true); + this.root = node; + + for (final Map.Entry<String, TaskNode> nodeEntry : nodes.entrySet()) { + final String id = nodeEntry.getKey(); + + final TaskNode leaf = nodeEntry.getValue(); + root.addChild(leaf); + leaf.setParent(root); + } + nodes.put(rootId, root); + } + + /** + * @param taskId + */ + private void unsetRootNode(final String taskId) { + LOG.finest(getQualifiedName() + "Unsetting " + rootId + " as root"); + nodes.remove(rootId); + + for (final Map.Entry<String, TaskNode> nodeEntry : nodes.entrySet()) { + final String id = nodeEntry.getKey(); + final TaskNode leaf = nodeEntry.getValue(); + leaf.setParent(null); + } + } + + @Override + public void onFailedTask(final String id) { + LOG.finest(getQualifiedName() + "Task-" + id + " failed"); + final TaskNode taskNode = nodes.get(id); + if (taskNode == null) { + throw new RuntimeException(getQualifiedName() + id + " does not exist"); + } + + taskNode.onFailedTask(); + } + + @Override + public void onRunningTask(final String id) { + LOG.finest(getQualifiedName() + "Task-" + id + " is running"); + final TaskNode taskNode = nodes.get(id); + if (taskNode == null) { + throw new RuntimeException(getQualifiedName() + id + " does not exist"); + } + + taskNode.onRunningTask(); + } + + @Override + public void onReceiptOfMessage(final GroupCommunicationMessage msg) { + LOG.finest(getQualifiedName() + "processing " + msg.getType() + " from " + msg.getSrcid()); + if (msg.getType().equals(ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges)) { + processTopologyChanges(msg); + 
return; + } + if (msg.getType().equals(ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology)) { + processUpdateTopology(msg); + return; + } + final String id = msg.getSrcid(); + nodes.get(id).onReceiptOfAcknowledgement(msg); + } + + private void processUpdateTopology(final GroupCommunicationMessage msg) { + final String dstId = msg.getSrcid(); + final int version = getNodeVersion(dstId); + + LOG.finest(getQualifiedName() + "Creating NodeTopologyUpdateWaitStage to wait on nodes to be updated"); + final EventHandler<List<TaskNode>> topoUpdateWaitHandler = new TopologyUpdateWaitHandler(senderStage, groupName, + operName, driverId, 0, + dstId, version, + getQualifiedName()); + final EStage<List<TaskNode>> nodeTopologyUpdateWaitStage = new SingleThreadStage<>("NodeTopologyUpdateWaitStage", + topoUpdateWaitHandler, + nodes.size()); + + final List<TaskNode> toBeUpdatedNodes = new ArrayList<>(nodes.size()); + LOG.finest(getQualifiedName() + "Checking which nodes need to be updated"); + for (final TaskNode node : nodes.values()) { + if (node.isRunning() && node.hasChanges() && node.resetTopologySetupSent()) { + toBeUpdatedNodes.add(node); + } + } + for (final TaskNode node : toBeUpdatedNodes) { + node.updatingTopology(); + senderStage.onNext(Utils.bldVersionedGCM(groupName, operName, ReefNetworkGroupCommProtos.GroupCommMessage.Type.UpdateTopology, driverId, 0, node.getTaskId(), + node.getVersion(), Utils.EmptyByteArr)); + } + nodeTopologyUpdateWaitStage.onNext(toBeUpdatedNodes); + } + + private void processTopologyChanges(final GroupCommunicationMessage msg) { + final String dstId = msg.getSrcid(); + boolean hasTopologyChanged = false; + LOG.finest(getQualifiedName() + "Checking which nodes need to be updated"); + for (final TaskNode node : nodes.values()) { + if (!node.isRunning() || node.hasChanges()) { + hasTopologyChanged = true; + break; + } + } + final GroupChanges changes = new GroupChangesImpl(hasTopologyChanged); + final Codec<GroupChanges> changesCodec 
= new GroupChangesCodec(); + senderStage.onNext(Utils.bldVersionedGCM(groupName, operName, ReefNetworkGroupCommProtos.GroupCommMessage.Type.TopologyChanges, driverId, 0, dstId, getNodeVersion(dstId), + changesCodec.encode(changes))); + } + + private String getQualifiedName() { + return Utils.simpleName(groupName) + ":" + Utils.simpleName(operName) + " - "; + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommDriverImpl.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommDriverImpl.java b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommDriverImpl.java new file mode 100644 index 0000000..8c01b31 --- /dev/null +++ b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommDriverImpl.java @@ -0,0 +1,250 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
package org.apache.reef.io.network.group.impl.driver;

import org.apache.reef.driver.context.ActiveContext;
import org.apache.reef.driver.context.ContextConfiguration;
import org.apache.reef.driver.context.ServiceConfiguration;
import org.apache.reef.driver.evaluator.FailedEvaluator;
import org.apache.reef.driver.parameters.DriverIdentifier;
import org.apache.reef.driver.task.FailedTask;
import org.apache.reef.driver.task.RunningTask;
import org.apache.reef.io.network.Message;
import org.apache.reef.io.network.impl.*;
import org.apache.reef.io.network.naming.NameServer;
import org.apache.reef.io.network.naming.NameServerImpl;
import org.apache.reef.io.network.naming.NameServerParameters;
import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver;
import org.apache.reef.io.network.group.api.driver.GroupCommServiceDriver;
import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
import org.apache.reef.io.network.group.impl.GroupCommunicationMessageCodec;
import org.apache.reef.io.network.group.impl.config.parameters.SerializedGroupConfigs;
import org.apache.reef.io.network.group.impl.config.parameters.TreeTopologyFanOut;
import org.apache.reef.io.network.group.impl.task.GroupCommNetworkHandlerImpl;
import org.apache.reef.io.network.group.impl.utils.BroadcastingEventHandler;
import org.apache.reef.io.network.group.impl.utils.Utils;
import org.apache.reef.io.network.util.StringIdentifierFactory;
import org.apache.reef.tang.Configuration;
import org.apache.reef.tang.JavaConfigurationBuilder;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.annotations.Name;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.tang.formats.ConfigurationSerializer;
import org.apache.reef.util.SingletonAsserter;
import org.apache.reef.wake.EStage;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.IdentifierFactory;
import org.apache.reef.wake.impl.LoggingEventHandler;
import org.apache.reef.wake.impl.SingleThreadStage;
import org.apache.reef.wake.impl.SyncStage;
import org.apache.reef.wake.impl.ThreadPoolStage;
import org.apache.reef.wake.remote.NetUtils;

import javax.inject.Inject;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;

/**
 * Sets up various stages to handle REEF events and adds the per communication
 * group stages to them whenever a new communication group is created.
 * <p/>
 * Also starts the NameService and the NetworkService on the driver.
 */
public class GroupCommDriverImpl implements GroupCommServiceDriver {
  private static final Logger LOG = Logger.getLogger(GroupCommDriverImpl.class.getName());
  /**
   * TANG instance shared by every configuration built in this class.
   */
  private static final Tang tang = Tang.Factory.getTang();

  /** Source of the numeric suffix of the context identifiers handed out by this driver. */
  private final AtomicInteger contextIds = new AtomicInteger(0);

  private final IdentifierFactory idFac = new StringIdentifierFactory();

  // Port argument 0 - presumably lets the server pick a free port; the real port is read back
  // through getPort() in the constructor. TODO confirm against NameServerImpl.
  private final NameServer nameService = new NameServerImpl(0, idFac);

  private final String nameServiceAddr;
  private final int nameServicePort;

  /** Communication group drivers created so far, keyed by group name. */
  private final Map<Class<? extends Name<String>>, CommunicationGroupDriver> commGroupDrivers = new HashMap<>();

  private final ConfigurationSerializer confSerializer;

  private final NetworkService<GroupCommunicationMessage> netService;

  private final EStage<GroupCommunicationMessage> senderStage;

  private final String driverId;
  private final BroadcastingEventHandler<RunningTask> groupCommRunningTaskHandler;
  private final EStage<RunningTask> groupCommRunningTaskStage;
  private final BroadcastingEventHandler<FailedTask> groupCommFailedTaskHandler;
  private final EStage<FailedTask> groupCommFailedTaskStage;
  private final BroadcastingEventHandler<FailedEvaluator> groupCommFailedEvaluatorHandler;
  private final EStage<FailedEvaluator> groupCommFailedEvaluatorStage;
  private final GroupCommMessageHandler groupCommMessageHandler;
  private final EStage<GroupCommunicationMessage> groupCommMessageStage;
  private final int fanOut;

  /**
   * Creates the driver-wide event stages and the network service on which the
   * driver is registered under {@code driverId}.
   *
   * @param confSerializer serializer used for configurations, both for logging
   *                       and for the serialized group configurations
   * @param driverId       identifier of the driver; registered with the network service
   * @param fanOut         value passed on to every communication group driver created here
   */
  @Inject
  public GroupCommDriverImpl(final ConfigurationSerializer confSerializer,
                             @Parameter(DriverIdentifier.class) final String driverId,
                             @Parameter(TreeTopologyFanOut.class) final int fanOut) {
    assert (SingletonAsserter.assertSingleton(getClass()));
    this.driverId = driverId;
    this.fanOut = fanOut;
    this.nameServiceAddr = NetUtils.getLocalAddress();
    this.nameServicePort = nameService.getPort();
    this.confSerializer = confSerializer;
    this.groupCommRunningTaskHandler = new BroadcastingEventHandler<>();
    this.groupCommRunningTaskStage = new SyncStage<>("GroupCommRunningTaskStage", groupCommRunningTaskHandler);
    this.groupCommFailedTaskHandler = new BroadcastingEventHandler<>();
    this.groupCommFailedTaskStage = new SyncStage<>("GroupCommFailedTaskStage", groupCommFailedTaskHandler);
    this.groupCommFailedEvaluatorHandler = new BroadcastingEventHandler<>();
    this.groupCommFailedEvaluatorStage = new SyncStage<>("GroupCommFailedEvaluatorStage",
        groupCommFailedEvaluatorHandler);
    this.groupCommMessageHandler = new GroupCommMessageHandler();
    this.groupCommMessageStage = new SingleThreadStage<>("GroupCommMessageStage", groupCommMessageHandler, 100 * 1000);
    // The message stage must exist before the network service: the receive handler below feeds it.
    this.netService = new NetworkService<>(idFac, 0, nameServiceAddr, nameServicePort,
        new GroupCommunicationMessageCodec(), new MessagingTransportFactory(),
        new EventHandler<Message<GroupCommunicationMessage>>() {

          @Override
          public void onNext(final Message<GroupCommunicationMessage> msg) {
            groupCommMessageStage.onNext(Utils.getGCM(msg));
          }
        }, new LoggingEventHandler<Exception>());
    this.netService.registerId(idFac.getNewInstance(driverId));
    this.senderStage = new ThreadPoolStage<>("SrcCtrlMsgSender", new CtrlMsgSender(idFac, netService), 5);
  }

  /**
   * Creates the driver for a new communication group and hooks its running-task,
   * failed-task and message handlers into the driver-wide handlers.
   * <p>
   * NOTE(review): the per-group failed-evaluator handler is passed to the group
   * driver but is not added to {@code groupCommFailedEvaluatorHandler} here -
   * confirm whether failed-evaluator events are meant to reach the group.
   *
   * @param groupName     name of the communication group
   * @param numberOfTasks value passed on to the communication group driver
   * @return the driver of the newly created communication group
   */
  @Override
  public CommunicationGroupDriver newCommunicationGroup(final Class<? extends Name<String>> groupName,
                                                        final int numberOfTasks) {
    LOG.entering("GroupCommDriverImpl", "newCommunicationGroup",
        new Object[]{Utils.simpleName(groupName), numberOfTasks});
    final BroadcastingEventHandler<RunningTask> commGroupRunningTaskHandler = new BroadcastingEventHandler<>();
    final BroadcastingEventHandler<FailedTask> commGroupFailedTaskHandler = new BroadcastingEventHandler<>();
    final BroadcastingEventHandler<FailedEvaluator> commGroupFailedEvaluatorHandler =
        new BroadcastingEventHandler<>();
    final BroadcastingEventHandler<GroupCommunicationMessage> commGroupMessageHandler =
        new BroadcastingEventHandler<>();
    final CommunicationGroupDriver commGroupDriver = new CommunicationGroupDriverImpl(groupName, confSerializer,
        senderStage,
        commGroupRunningTaskHandler,
        commGroupFailedTaskHandler,
        commGroupFailedEvaluatorHandler,
        commGroupMessageHandler,
        driverId, numberOfTasks, fanOut);
    commGroupDrivers.put(groupName, commGroupDriver);
    groupCommRunningTaskHandler.addHandler(commGroupRunningTaskHandler);
    groupCommFailedTaskHandler.addHandler(commGroupFailedTaskHandler);
    groupCommMessageHandler.addHandler(groupName, commGroupMessageHandler);
    LOG.exiting("GroupCommDriverImpl", "newCommunicationGroup",
        "Created communication group: " + Utils.simpleName(groupName));
    return commGroupDriver;
  }

  /**
   * Tells whether a context was configured through {@link #getContextConfiguration()}.
   *
   * @param activeContext the context to inspect
   * @return true when the context id starts with {@code "GroupCommunicationContext-"}
   */
  @Override
  public boolean isConfigured(final ActiveContext activeContext) {
    LOG.entering("GroupCommDriverImpl", "isConfigured", activeContext.getId());
    final boolean retVal = activeContext.getId().startsWith("GroupCommunicationContext-");
    LOG.exiting("GroupCommDriverImpl", "isConfigured", retVal);
    return retVal;
  }

  /**
   * Builds a context configuration whose identifier is
   * {@code "GroupCommunicationContext-"} followed by the next value of an internal counter.
   *
   * @return the context configuration
   */
  @Override
  public Configuration getContextConfiguration() {
    LOG.entering("GroupCommDriverImpl", "getContextConf");
    final Configuration retVal = ContextConfiguration.CONF.set(ContextConfiguration.IDENTIFIER,
        "GroupCommunicationContext-" + contextIds.getAndIncrement()).build();
    LOG.exiting("GroupCommDriverImpl", "getContextConf", confSerializer.toString(retVal));
    return retVal;
  }

  /**
   * Builds the service configuration: the network service and the group
   * communication network handler as services, the context-stop and task
   * start/stop handlers, plus the codec, handler, exception handler and the
   * address and port of the name server started by this driver.
   *
   * @return the service configuration
   */
  @Override
  public Configuration getServiceConfiguration() {
    LOG.entering("GroupCommDriverImpl", "getServiceConf");
    final Configuration serviceConfiguration = ServiceConfiguration.CONF
        .set(ServiceConfiguration.SERVICES, NetworkService.class)
        .set(ServiceConfiguration.SERVICES, GroupCommNetworkHandlerImpl.class)
        .set(ServiceConfiguration.ON_CONTEXT_STOP, NetworkServiceClosingHandler.class)
        .set(ServiceConfiguration.ON_TASK_STARTED, BindNSToTask.class)
        .set(ServiceConfiguration.ON_TASK_STOP, UnbindNSFromTask.class)
        .build();
    final Configuration retVal = tang.newConfigurationBuilder(serviceConfiguration)
        .bindNamedParameter(NetworkServiceParameters.NetworkServiceCodec.class,
            GroupCommunicationMessageCodec.class)
        .bindNamedParameter(NetworkServiceParameters.NetworkServiceHandler.class,
            GroupCommNetworkHandlerImpl.class)
        .bindNamedParameter(NetworkServiceParameters.NetworkServiceExceptionHandler.class,
            ExceptionHandler.class)
        .bindNamedParameter(NameServerParameters.NameServerAddr.class, nameServiceAddr)
        .bindNamedParameter(NameServerParameters.NameServerPort.class, Integer.toString(nameServicePort))
        .bindNamedParameter(NetworkServiceParameters.NetworkServicePort.class, "0")
        .build();
    LOG.exiting("GroupCommDriverImpl", "getServiceConf", confSerializer.toString(retVal));
    return retVal;
  }

  /**
   * Extends a partial task configuration with one serialized configuration per
   * communication group that returns a non-null configuration for the task.
   *
   * @param partialTaskConf the task configuration to extend
   * @return the partial configuration plus the serialized group configurations
   */
  @Override
  public Configuration getTaskConfiguration(final Configuration partialTaskConf) {
    LOG.entering("GroupCommDriverImpl", "getTaskConfiguration",
        new Object[]{confSerializer.toString(partialTaskConf)});
    final JavaConfigurationBuilder jcb = tang.newConfigurationBuilder(partialTaskConf);
    for (final CommunicationGroupDriver commGroupDriver : commGroupDrivers.values()) {
      final Configuration commGroupConf = commGroupDriver.getTaskConfiguration(partialTaskConf);
      // A null configuration means this group contributes nothing for the task.
      if (commGroupConf != null) {
        jcb.bindSetEntry(SerializedGroupConfigs.class, confSerializer.toString(commGroupConf));
      }
    }
    final Configuration retVal = jcb.build();
    LOG.exiting("GroupCommDriverImpl", "getTaskConfiguration", confSerializer.toString(retVal));
    return retVal;
  }

  /**
   * @return the groupCommRunningTaskStage
   */
  @Override
  public EStage<RunningTask> getGroupCommRunningTaskStage() {
    LOG.entering("GroupCommDriverImpl", "getGroupCommRunningTaskStage");
    LOG.exiting("GroupCommDriverImpl", "getGroupCommRunningTaskStage", "Returning GroupCommRunningTaskStage");
    return groupCommRunningTaskStage;
  }

  /**
   * @return the groupCommFailedTaskStage
   */
  @Override
  public EStage<FailedTask> getGroupCommFailedTaskStage() {
    LOG.entering("GroupCommDriverImpl", "getGroupCommFailedTaskStage");
    LOG.exiting("GroupCommDriverImpl", "getGroupCommFailedTaskStage", "Returning GroupCommFailedTaskStage");
    return groupCommFailedTaskStage;
  }

  /**
   * @return the groupCommFailedEvaluatorStage
   */
  @Override
  public EStage<FailedEvaluator> getGroupCommFailedEvaluatorStage() {
    LOG.entering("GroupCommDriverImpl", "getGroupCommFailedEvaluatorStage");
    LOG.exiting("GroupCommDriverImpl", "getGroupCommFailedEvaluatorStage",
        "Returning GroupCommFailedEvaluatorStage");
    return groupCommFailedEvaluatorStage;
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommMessageHandler.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommMessageHandler.java b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommMessageHandler.java new file mode 100644 index 0000000..b466205 --- /dev/null +++ b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommMessageHandler.java @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
package org.apache.reef.io.network.group.impl.driver;

import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
import org.apache.reef.io.network.group.impl.utils.BroadcastingEventHandler;
import org.apache.reef.io.network.group.impl.utils.Utils;
import org.apache.reef.tang.annotations.Name;
import org.apache.reef.wake.EventHandler;

import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

/**
 * The network handler for the group communication service on the driver side.
 * <p/>
 * Keeps one broadcasting handler per communication group and routes every
 * incoming message to the handler registered for the group named in the message.
 * <p/>
 * NOTE(review): the handler map is a plain {@code HashMap}; confirm that
 * {@link #addHandler} and {@link #onNext} are never invoked concurrently.
 */
public class GroupCommMessageHandler implements EventHandler<GroupCommunicationMessage> {

  private static final Logger LOG = Logger.getLogger(GroupCommMessageHandler.class.getName());

  /** Message handlers keyed by communication group name. */
  private final Map<Class<? extends Name<String>>, BroadcastingEventHandler<GroupCommunicationMessage>>
      commGroupMessageHandlers = new HashMap<>();

  /**
   * Registers the message handler of a communication group, replacing any
   * handler stored earlier under the same group name.
   *
   * @param groupName name of the communication group
   * @param handler   handler that receives the messages addressed to that group
   */
  public void addHandler(final Class<? extends Name<String>> groupName,
                         final BroadcastingEventHandler<GroupCommunicationMessage> handler) {
    LOG.entering("GroupCommMessageHandler", "addHandler", new Object[]{Utils.simpleName(groupName), handler});
    commGroupMessageHandlers.put(groupName, handler);
    LOG.exiting("GroupCommMessageHandler", "addHandler", Utils.simpleName(groupName));
  }

  /**
   * Routes a message to the handler registered for its communication group.
   *
   * @param msg the incoming message
   * @throws IllegalStateException if no handler is registered for the group named in the message
   */
  @Override
  public void onNext(final GroupCommunicationMessage msg) {
    LOG.entering("GroupCommMessageHandler", "onNext", msg);
    final Class<? extends Name<String>> groupName = Utils.getClass(msg.getGroupname());
    final BroadcastingEventHandler<GroupCommunicationMessage> handler = commGroupMessageHandlers.get(groupName);
    if (handler == null) {
      // Fail with the offending group name instead of an anonymous NullPointerException.
      throw new IllegalStateException(
          "No message handler registered for communication group " + msg.getGroupname());
    }
    handler.onNext(msg);
    LOG.exiting("GroupCommMessageHandler", "onNext", msg);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommService.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommService.java b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommService.java new file mode 100644 index 0000000..0915fc5 --- /dev/null +++ b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/GroupCommService.java @@ -0,0 +1,111 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
package org.apache.reef.io.network.group.impl.driver;

import org.apache.reef.driver.evaluator.FailedEvaluator;
import org.apache.reef.driver.parameters.EvaluatorDispatcherThreads;
import org.apache.reef.driver.parameters.ServiceEvaluatorFailedHandlers;
import org.apache.reef.driver.parameters.ServiceTaskFailedHandlers;
import org.apache.reef.driver.parameters.TaskRunningHandlers;
import org.apache.reef.driver.task.FailedTask;
import org.apache.reef.driver.task.RunningTask;
import org.apache.reef.io.network.group.api.driver.GroupCommServiceDriver;
import org.apache.reef.io.network.group.impl.config.parameters.TreeTopologyFanOut;
import org.apache.reef.tang.Configuration;
import org.apache.reef.tang.JavaConfigurationBuilder;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.annotations.Unit;
import org.apache.reef.tang.formats.AvroConfigurationSerializer;
import org.apache.reef.tang.formats.ConfigurationSerializer;
import org.apache.reef.wake.EventHandler;

import javax.inject.Inject;
import java.util.logging.Logger;

/**
 * The Group Communication Service.
 * <p/>
 * Provides the driver configuration that registers the handlers declared in
 * this class; each handler passes the event it receives to the matching stage
 * of the injected {@link GroupCommServiceDriver}.
 */
@Unit
public class GroupCommService {

  private static final Logger LOG = Logger.getLogger(GroupCommService.class.getName());
  /** Serializer used only to render configurations in the exit log records. */
  private static final ConfigurationSerializer confSer = new AvroConfigurationSerializer();

  private final GroupCommServiceDriver groupCommDriver;

  /**
   * @param groupCommDriver driver whose stages receive the forwarded events
   */
  @Inject
  public GroupCommService(final GroupCommServiceDriver groupCommDriver) {
    this.groupCommDriver = groupCommDriver;
  }

  /**
   * Builds the configuration that registers the running-task, failed-task and
   * failed-evaluator handlers of this class and binds
   * {@code EvaluatorDispatcherThreads} to {@code "1"}.
   *
   * @return the service configuration
   */
  public static Configuration getConfiguration() {
    LOG.entering("GroupCommService", "getConfiguration");
    final JavaConfigurationBuilder builder = Tang.Factory.getTang().newConfigurationBuilder();
    builder.bindSetEntry(TaskRunningHandlers.class, RunningTaskHandler.class);
    builder.bindSetEntry(ServiceTaskFailedHandlers.class, FailedTaskHandler.class);
    builder.bindSetEntry(ServiceEvaluatorFailedHandlers.class, FailedEvaluatorHandler.class);
    builder.bindNamedParameter(EvaluatorDispatcherThreads.class, "1");
    final Configuration conf = builder.build();
    LOG.exiting("GroupCommService", "getConfiguration", confSer.toString(conf));
    return conf;
  }

  /**
   * Builds the configuration of {@link #getConfiguration()} and additionally
   * binds {@code TreeTopologyFanOut} to the given value.
   *
   * @param fanOut value bound to {@code TreeTopologyFanOut}
   * @return the service configuration including the fan-out
   */
  public static Configuration getConfiguration(final int fanOut) {
    LOG.entering("GroupCommService", "getConfiguration", fanOut);
    final Configuration baseConf = getConfiguration();
    final JavaConfigurationBuilder builder = Tang.Factory.getTang().newConfigurationBuilder(baseConf);
    builder.bindNamedParameter(TreeTopologyFanOut.class, Integer.toString(fanOut));
    final Configuration conf = builder.build();
    LOG.exiting("GroupCommService", "getConfiguration", confSer.toString(conf));
    return conf;
  }

  /**
   * Passes failed evaluators to the driver's failed-evaluator stage.
   */
  public class FailedEvaluatorHandler implements EventHandler<FailedEvaluator> {

    @Override
    public void onNext(final FailedEvaluator failedEvaluator) {
      LOG.entering("GroupCommService.FailedEvaluatorHandler", "onNext", failedEvaluator.getId());
      groupCommDriver.getGroupCommFailedEvaluatorStage().onNext(failedEvaluator);
      LOG.exiting("GroupCommService.FailedEvaluatorHandler", "onNext", failedEvaluator.getId());
    }

  }

  /**
   * Passes running tasks to the driver's running-task stage.
   */
  public class RunningTaskHandler implements EventHandler<RunningTask> {

    @Override
    public void onNext(final RunningTask runningTask) {
      LOG.entering("GroupCommService.RunningTaskHandler", "onNext", runningTask.getId());
      groupCommDriver.getGroupCommRunningTaskStage().onNext(runningTask);
      LOG.exiting("GroupCommService.RunningTaskHandler", "onNext", runningTask.getId());
    }

  }

  /**
   * Passes failed tasks to the driver's failed-task stage.
   */
  public class FailedTaskHandler implements EventHandler<FailedTask> {

    @Override
    public void onNext(final FailedTask failedTask) {
      LOG.entering("GroupCommService.FailedTaskHandler", "onNext", failedTask.getId());
      groupCommDriver.getGroupCommFailedTaskStage().onNext(failedTask);
      LOG.exiting("GroupCommService.FailedTaskHandler", "onNext", failedTask.getId());
    }

  }

}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/IndexedMsg.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/IndexedMsg.java b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/IndexedMsg.java new file mode 100644 index 0000000..b72979c --- /dev/null +++ b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/IndexedMsg.java @@ -0,0 +1,71 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.reef.io.network.group.impl.driver;

import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
import org.apache.reef.io.network.group.impl.utils.Utils;
import org.apache.reef.tang.annotations.Name;

/**
 * Helper class to wrap a msg together with the operator name found in that msg.
 * <p/>
 * Equality and hash code depend on the operator name only, so two wrappers of
 * different messages for the same operator compare equal.
 */
public class IndexedMsg {
  private final Class<? extends Name<String>> operName;
  private final GroupCommunicationMessage msg;

  /**
   * @param msg the message to wrap; its operator name is resolved to a class once, here
   */
  public IndexedMsg(final GroupCommunicationMessage msg) {
    this.operName = Utils.getClass(msg.getOperatorname());
    this.msg = msg;
  }

  /**
   * @return the operator name class resolved from the wrapped message
   */
  public Class<? extends Name<String>> getOperName() {
    return operName;
  }

  /**
   * @return the wrapped message
   */
  public GroupCommunicationMessage getMsg() {
    return msg;
  }

  /**
   * @return the hash code of the operator's fully qualified class name
   */
  @Override
  public int hashCode() {
    final String qualifiedOperName = operName.getName();
    return qualifiedOperName.hashCode();
  }

  /**
   * @return true when {@code obj} is an {@code IndexedMsg} holding the very
   * same operator name class; the wrapped message is not compared
   */
  @Override
  public boolean equals(final Object obj) {
    return obj instanceof IndexedMsg && this.operName == ((IndexedMsg) obj).operName;
  }

  /**
   * @return the simple name of the operator name class
   */
  @Override
  public String toString() {
    return operName.getSimpleName();
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/MsgKey.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/MsgKey.java b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/MsgKey.java new file mode 100644 index 0000000..9ac021e --- /dev/null +++ b/lang/java/reef-io/src/main/java/org/apache/reef/io/network/group/impl/driver/MsgKey.java @@ -0,0 +1,90 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */
package org.apache.reef.io.network.group.impl.driver;

import org.apache.reef.io.network.group.impl.GroupCommunicationMessage;
import org.apache.reef.io.network.proto.ReefNetworkGroupCommProtos;

/**
 * The key object used in map to aggregate msgs from
 * all the operators before updating state on driver.
 * <p/>
 * A key built from a message stores source and destination as
 * {@code "<id>:<version>"}; the getters return only the part in front of the
 * first ':'.
 * <p/>
 * NOTE(review): an id that itself contains ':' is cut at its first ':' by the
 * getters - confirm that task ids never contain that character.
 */
public class MsgKey {
  private final String src;
  private final String dst;
  private final ReefNetworkGroupCommProtos.GroupCommMessage.Type msgType;

  /**
   * Creates a key from values that are stored exactly as given.
   *
   * @param src     source part of the key
   * @param dst     destination part of the key
   * @param msgType message type part of the key
   */
  public MsgKey(final String src, final String dst,
                final ReefNetworkGroupCommProtos.GroupCommMessage.Type msgType) {
    this.src = src;
    this.dst = dst;
    this.msgType = msgType;
  }

  /**
   * Creates the key of a message: source id with source version, destination
   * id with version, and the message type.
   *
   * @param msg the message the key is derived from
   */
  public MsgKey(final GroupCommunicationMessage msg) {
    this(msg.getSrcid() + ":" + msg.getSrcVersion(),
        msg.getDestid() + ":" + msg.getVersion(),
        msg.getType());
  }

  /**
   * @return the stored source up to, but excluding, its first ':'
   */
  public String getSrc() {
    return beforeFirstColon(src);
  }

  /**
   * @return the stored destination up to, but excluding, its first ':'
   */
  public String getDst() {
    return beforeFirstColon(dst);
  }

  /**
   * @return the message type of this key
   */
  public ReefNetworkGroupCommProtos.GroupCommMessage.Type getMsgType() {
    return msgType;
  }

  @Override
  public String toString() {
    return "(" + src + "," + dst + "," + msgType + ")";
  }

  /**
   * @return true when {@code obj} is a {@code MsgKey} whose stored source,
   * destination (both including any version suffix) and message type are equal
   */
  @Override
  public boolean equals(final Object obj) {
    if (this == obj) {
      return true;
    }
    if (!(obj instanceof MsgKey)) {
      return false;
    }
    final MsgKey other = (MsgKey) obj;
    return this.src.equals(other.src)
        && this.dst.equals(other.dst)
        && this.msgType.equals(other.msgType);
  }

  @Override
  public int hashCode() {
    return 31 * (31 * src.hashCode() + dst.hashCode()) + msgType.hashCode();
  }

  /**
   * Returns the text in front of the first ':' of {@code value}, or the whole
   * value when it contains no ':'.
   */
  private static String beforeFirstColon(final String value) {
    final int colon = value.indexOf(':');
    return colon < 0 ? value : value.substring(0, colon);
  }
}
