[REEF-173] Simplify the group communication API * Operator specs are no longer exposed on the API * Create a default group when the MpiDriver is created * Use a default codec if none is given * Update test cases * Add new test cases for the new API * Mark the old APIs as obsolete
JIRA: [REEF-173](https://issues.apache.org/jira/browse/REEF-173) Pull Request: This closes #100 Project: http://git-wip-us.apache.org/repos/asf/incubator-reef/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-reef/commit/87107120 Tree: http://git-wip-us.apache.org/repos/asf/incubator-reef/tree/87107120 Diff: http://git-wip-us.apache.org/repos/asf/incubator-reef/diff/87107120 Branch: refs/heads/master Commit: 871071203d734e3ac21192739a5aeefd7ac028ed Parents: 860c2fc Author: Julia Wang <[email protected]> Authored: Tue Mar 3 17:04:25 2015 -0800 Committer: Markus Weimer <[email protected]> Committed: Wed Mar 11 13:13:08 2015 -0700 ---------------------------------------------------------------------- .../KMeans/KMeansDriverHandlers.cs | 33 +- .../Group/Config/MpiConfigurationOptions.cs | 10 + .../Group/Driver/ICommunicationGroupDriver.cs | 60 +- .../Group/Driver/IMpiDriver.cs | 3 + .../Driver/Impl/CommunicationGroupDriver.cs | 151 ++++- .../Group/Driver/Impl/MpiDriver.cs | 47 +- .../Group/Operators/Impl/ReduceSender.cs | 2 +- .../Group/Task/Impl/OperatorTopology.cs | 3 + .../Injection/TestSetInjection.cs | 35 ++ .../Functional/ML/KMeans/TestKMeans.cs | 37 +- .../BroadcastReduceDriver.cs | 24 +- .../BroadcastReduceTest/BroadcastReduceTest.cs | 18 +- .../ScatterReduceTest/ScatterReduceDriver.cs | 24 +- .../MPI/ScatterReduceTest/ScatterReduceTest.cs | 19 +- .../Network/GroupCommunicationTests.cs | 551 +++++++------------ .../GroupCommunicationTreeTopologyTests.cs | 411 ++------------ 16 files changed, 655 insertions(+), 773 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansDriverHandlers.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansDriverHandlers.cs 
b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansDriverHandlers.cs index 4920782..0c67777 100644 --- a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansDriverHandlers.cs +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansDriverHandlers.cs @@ -30,6 +30,7 @@ using Org.Apache.REEF.Driver.Bridge; using Org.Apache.REEF.Driver.Context; using Org.Apache.REEF.Driver.Evaluator; using Org.Apache.REEF.Examples.MachineLearning.KMeans.codecs; +using Org.Apache.REEF.Network.Group.Config; using Org.Apache.REEF.Network.Group.Driver; using Org.Apache.REEF.Network.Group.Driver.Impl; using Org.Apache.REEF.Network.Group.Operators.Impl; @@ -56,9 +57,7 @@ namespace Org.Apache.REEF.Examples.MachineLearning.KMeans private readonly string _executionDirectory; // TODO: we may want to make this injectable - private readonly int _partitionsNumber = 2; private readonly int _clustersNumber = 3; - private readonly int _fanOut = 2; private readonly int _totalEvaluators; private int _partitionInex = 0; private readonly IMpiDriver _mpiDriver; @@ -66,7 +65,7 @@ namespace Org.Apache.REEF.Examples.MachineLearning.KMeans private readonly TaskStarter _mpiTaskStarter; [Inject] - public KMeansDriverHandlers() + public KMeansDriverHandlers([Parameter(typeof(NumPartitions))] int numPartitions, MpiDriver mpiDriver) { Identifier = "KMeansDriverId"; _executionDirectory = Path.Combine(Directory.GetCurrentDirectory(), Constants.KMeansExecutionBaseDirectory, Guid.NewGuid().ToString("N").Substring(0, 4)); @@ -74,20 +73,20 @@ namespace Org.Apache.REEF.Examples.MachineLearning.KMeans string dataFile = arguments.Single(a => a.StartsWith("DataFile", StringComparison.Ordinal)).Split(':')[1]; DataVector.ShuffleDataAndGetInitialCentriods( Path.Combine(Directory.GetCurrentDirectory(), "reef", "global", dataFile), - _partitionsNumber, + numPartitions, _clustersNumber, - _executionDirectory); + _executionDirectory); - _totalEvaluators = _partitionsNumber + 1; - _mpiDriver = new 
MpiDriver(Identifier, Constants.MasterTaskId, _fanOut, new AvroConfigurationSerializer()); + _totalEvaluators = numPartitions + 1; - _commGroup = _mpiDriver.NewCommunicationGroup( - Constants.KMeansCommunicationGroupName, - _totalEvaluators) - .AddBroadcast(Constants.CentroidsBroadcastOperatorName, new BroadcastOperatorSpec<Centroids>(Constants.MasterTaskId, new CentroidsCodec())) - .AddBroadcast(Constants.ControlMessageBroadcastOperatorName, new BroadcastOperatorSpec<ControlMessage>(Constants.MasterTaskId, new ControlMessageCodec())) - .AddReduce(Constants.MeansReduceOperatorName, new ReduceOperatorSpec<ProcessedResults>(Constants.MasterTaskId, new ProcessedResultsCodec(), new KMeansMasterTask.AggregateMeans())) + _mpiDriver = mpiDriver; + + _commGroup = _mpiDriver.DefaultGroup + .AddBroadcast(Constants.CentroidsBroadcastOperatorName,Constants.MasterTaskId, new CentroidsCodec()) + .AddBroadcast(Constants.ControlMessageBroadcastOperatorName, Constants.MasterTaskId, new ControlMessageCodec()) + .AddReduce(Constants.MeansReduceOperatorName, Constants.MasterTaskId, new ProcessedResultsCodec(), new KMeansMasterTask.AggregateMeans()) .Build(); + _mpiTaskStarter = new TaskStarter(_mpiDriver, _totalEvaluators); CreateClassHierarchy(); @@ -97,10 +96,9 @@ namespace Org.Apache.REEF.Examples.MachineLearning.KMeans public void OnNext(IEvaluatorRequestor evalutorRequestor) { - int evaluatorsNumber = _totalEvaluators; int memory = 2048; int core = 1; - EvaluatorRequest request = new EvaluatorRequest(evaluatorsNumber, memory, core); + EvaluatorRequest request = new EvaluatorRequest(_totalEvaluators, memory, core); evalutorRequestor.Submit(request); } @@ -189,4 +187,9 @@ namespace Org.Apache.REEF.Examples.MachineLearning.KMeans ClrHandlerHelper.GenerateClassHierarchy(clrDlls); } } + + [NamedParameter("Number of partitions")] + public class NumPartitions : Name<int> + { + } } 
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Network/Group/Config/MpiConfigurationOptions.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Config/MpiConfigurationOptions.cs b/lang/cs/Org.Apache.REEF.Network/Group/Config/MpiConfigurationOptions.cs index bdef88b..b7bd357 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Config/MpiConfigurationOptions.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Config/MpiConfigurationOptions.cs @@ -54,6 +54,16 @@ namespace Org.Apache.REEF.Network.Group.Config { } + [NamedParameter("Group name", defaultValue: "Group1")] + public class GroupName : Name<string> + { + } + + [NamedParameter("Number of tasks", defaultValue: "5")] + public class NumberOfTasks : Name<int> + { + } + [NamedParameter("with of the tree in topology", defaultValue:"2")] public class FanOut : Name<int> { http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Network/Group/Driver/ICommunicationGroupDriver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Driver/ICommunicationGroupDriver.cs b/lang/cs/Org.Apache.REEF.Network/Group/Driver/ICommunicationGroupDriver.cs index 22caebd..5a857e0 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Driver/ICommunicationGroupDriver.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Driver/ICommunicationGroupDriver.cs @@ -18,9 +18,11 @@ */ using System.Collections.Generic; +using Org.Apache.REEF.Network.Group.Operators; using Org.Apache.REEF.Network.Group.Operators.Impl; using Org.Apache.REEF.Network.Group.Topology; using Org.Apache.REEF.Tang.Interface; +using Org.Apache.REEF.Wake.Remote; namespace Org.Apache.REEF.Network.Group.Driver { @@ -34,7 +36,7 @@ namespace Org.Apache.REEF.Network.Group.Driver /// <summary> /// Returns the list of task ids that belong to this Communication Group /// </summary> 
- List<string> TaskIds { get; } + List<string> TaskIds { get; } /// <summary> /// Adds the Broadcast MPI operator to the communication group. @@ -43,18 +45,53 @@ namespace Org.Apache.REEF.Network.Group.Driver /// <param name="operatorName">The name of the broadcast operator</param> /// <param name="spec">The specification that defines the Broadcast operator</param> /// <returns>The same CommunicationGroupDriver with the added Broadcast operator info</returns> + [System.Obsolete("use AddBroadcast<T>(string operatorName, string masterTaskId, ICodec<T> codecType, TopologyTypes topologyType = TopologyTypes.Flat)")] ICommunicationGroupDriver AddBroadcast<T>(string operatorName, BroadcastOperatorSpec<T> spec, TopologyTypes topologyType = TopologyTypes.Flat); /// <summary> + /// Adds the Broadcast MPI operator to the communication group. + /// </summary> + /// <typeparam name="T">The type of messages that operators will send</typeparam> + /// <param name="operatorName">The name of the broadcast operator</param> + /// <param name="masterTaskId">The master task id in broadcast operator</param> + /// <param name="codecType">The Codec used for serialization</param> + /// <param name="topologyType">The topology type for the operator</param> + /// <returns>The same CommunicationGroupDriver with the added Broadcast operator info</returns> + ICommunicationGroupDriver AddBroadcast<T>(string operatorName, string masterTaskId, ICodec<T> codecType, TopologyTypes topologyType = TopologyTypes.Flat); + + /// <summary> + /// Adds the Broadcast MPI operator to the communication group. 
Default to IntCodec + /// </summary> + /// <typeparam name="T">The type of messages that operators will send</typeparam> + /// <param name="operatorName">The name of the broadcast operator</param> + /// <param name="masterTaskId">The master task id in broadcast operator</param> + /// <param name="topologyType">The topology type for the operator</param> + /// <returns>The same CommunicationGroupDriver with the added Broadcast operator info</returns> + ICommunicationGroupDriver AddBroadcast(string operatorName, string masterTaskId, TopologyTypes topologyType = TopologyTypes.Flat); + + /// <summary> /// Adds the Reduce MPI operator to the communication group. /// </summary> /// <typeparam name="T">The type of messages that operators will send</typeparam> /// <param name="operatorName">The name of the reduce operator</param> /// <param name="spec">The specification that defines the Reduce operator</param> /// <returns>The same CommunicationGroupDriver with the added Reduce operator info</returns> + [System.Obsolete("use AddReduce<T>(string operatorName, string masterTaskId, ICodec<T> codecType, IReduceFunction<T> reduceFunction, TopologyTypes topologyType = TopologyTypes.Flat)")] ICommunicationGroupDriver AddReduce<T>(string operatorName, ReduceOperatorSpec<T> spec, TopologyTypes topologyType = TopologyTypes.Flat); /// <summary> + /// Adds the Reduce MPI operator to the communication group. 
+ /// </summary> + /// <typeparam name="T">The type of messages that operators will send</typeparam> + /// <param name="operatorName">The name of the reduce operator</param> + /// <param name="masterTaskId">The master task id for the typology</param> + /// <param name="codecType">The codec used for serializing messages.</param> + /// <param name="reduceFunction">The class used to aggregate all messages.</param> + /// <param name="topologyType">The topology for the operator</param> + /// <returns>The same CommunicationGroupDriver with the added Reduce operator info</returns> + ICommunicationGroupDriver AddReduce<T>(string operatorName, string masterTaskId, ICodec<T> codecType, IReduceFunction<T> reduceFunction, TopologyTypes topologyType = TopologyTypes.Flat); + + /// <summary> /// Adds the Scatter MPI operator to the communication group. /// </summary> /// <typeparam name="T">The type of messages that operators will send</typeparam> @@ -62,9 +99,30 @@ namespace Org.Apache.REEF.Network.Group.Driver /// <param name="spec">The specification that defines the Scatter operator</param> /// <param name="topologyType">type of topology used in the operaor</param> /// <returns>The same CommunicationGroupDriver with the added Scatter operator info</returns> + [System.Obsolete("use AddScatter<T>(string operatorName, string senderId, ICodec<T> codecType, TopologyTypes topologyType = TopologyTypes.Flat)")] ICommunicationGroupDriver AddScatter<T>(string operatorName, ScatterOperatorSpec<T> spec, TopologyTypes topologyType = TopologyTypes.Flat); /// <summary> + /// Adds the Scatter MPI operator to the communication group. 
+ /// </summary> + /// <typeparam name="T">The type of messages that operators will send</typeparam> + /// <param name="operatorName">The name of the scatter operator</param> + /// <param name="senderId">The sender id</param> + /// <param name="codecType">The codec used for serializing messages.</param> + /// <param name="topologyType">The type of topology used by the operator</param> + /// <returns>The same CommunicationGroupDriver with the added Scatter operator info</returns> + ICommunicationGroupDriver AddScatter<T>(string operatorName, string senderId, ICodec<T> codecType, TopologyTypes topologyType = TopologyTypes.Flat); + + /// <summary> + /// Adds the Scatter MPI operator to the communication group with the default codec + /// </summary> + /// <param name="operatorName">The name of the scatter operator</param> + /// <param name="senderId">The sender id</param> + /// <param name="topologyType">The type of topology used by the operator</param> + /// <returns>The same CommunicationGroupDriver with the added Scatter operator info</returns> + ICommunicationGroupDriver AddScatter(string operatorName, string senderId, TopologyTypes topologyType = TopologyTypes.Flat); + + /// <summary> /// Finalizes the CommunicationGroupDriver. /// After the CommunicationGroupDriver has been finalized, no more operators may /// be added to the group.
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Network/Group/Driver/IMpiDriver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Driver/IMpiDriver.cs b/lang/cs/Org.Apache.REEF.Network/Group/Driver/IMpiDriver.cs index 422d63e..9c6eef2 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Driver/IMpiDriver.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Driver/IMpiDriver.cs @@ -33,6 +33,9 @@ namespace Org.Apache.REEF.Network.Group.Driver /// </summary> string MasterTaskId { get; } + ICommunicationGroupDriver DefaultGroup { get; } + + /// <summary> /// Create a new CommunicationGroup with the given name and number of tasks/operators. /// </summary> http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/CommunicationGroupDriver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/CommunicationGroupDriver.cs b/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/CommunicationGroupDriver.cs index 0b426a6..154a0f5 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/CommunicationGroupDriver.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/CommunicationGroupDriver.cs @@ -20,6 +20,7 @@ using System.Collections.Generic; using System.Reflection; using Org.Apache.REEF.Network.Group.Config; +using Org.Apache.REEF.Network.Group.Operators; using Org.Apache.REEF.Network.Group.Operators.Impl; using Org.Apache.REEF.Network.Group.Topology; using Org.Apache.REEF.Tang.Exceptions; @@ -28,6 +29,8 @@ using Org.Apache.REEF.Tang.Implementations.Tang; using Org.Apache.REEF.Tang.Interface; using Org.Apache.REEF.Tang.Util; using Org.Apache.REEF.Utilities.Logging; +using Org.Apache.REEF.Wake.Remote; +using Org.Apache.REEF.Wake.Remote.Impl; namespace Org.Apache.REEF.Network.Group.Driver.Impl { @@ -94,6 +97,7 @@ namespace 
Org.Apache.REEF.Network.Group.Driver.Impl /// <param name="operatorName">The name of the broadcast operator</param> /// <param name="spec">The specification that defines the Broadcast operator</param> /// <returns>The same CommunicationGroupDriver with the added Broadcast operator info</returns> + [System.Obsolete("use AddBroadcast<T>(string operatorName, string masterTaskId, ICodec<T> codecType, TopologyTypes topologyType = TopologyTypes.Flat)")] public ICommunicationGroupDriver AddBroadcast<T>( string operatorName, BroadcastOperatorSpec<T> spec, @@ -122,12 +126,109 @@ namespace Org.Apache.REEF.Network.Group.Driver.Impl } /// <summary> + /// Adds the Broadcast MPI operator to the communication group. + /// </summary> + /// <typeparam name="T">The type of messages that operators will send</typeparam> + /// <param name="operatorName">The name of the broadcast operator</param> + /// <param name="masterTaskId">The master task id in broadcast operator</param> + /// <param name="codecType">The Codec used for serialization</param> + /// <param name="topologyType">The topology type for the operator</param> + /// <returns>The same CommunicationGroupDriver with the added Broadcast operator info</returns> + public ICommunicationGroupDriver AddBroadcast<T>(string operatorName, string masterTaskId, ICodec<T> codecType, TopologyTypes topologyType = TopologyTypes.Flat) + { + if (_finalized) + { + throw new IllegalStateException("Can't add operators once the spec has been built."); + } + + var spec = new BroadcastOperatorSpec<T>( + masterTaskId, + codecType); + + ITopology<T> topology; + if (topologyType == TopologyTypes.Flat) + { + topology = new FlatTopology<T>(operatorName, _groupName, spec.SenderId, _driverId, spec); + } + else + { + topology = new TreeTopology<T>(operatorName, _groupName, spec.SenderId, _driverId, spec, + _fanOut); + } + + _topologies[operatorName] = topology; + _operatorSpecs[operatorName] = spec; + + return this; + } + + /// <summary> + /// Adds the 
Broadcast MPI operator to the communication group. Default to IntCodec + /// </summary> + /// <typeparam name="T">The type of messages that operators will send</typeparam> + /// <param name="operatorName">The name of the broadcast operator</param> + /// <param name="masterTaskId">The master task id in broadcast operator</param> + /// <param name="topologyType">The topology type for the operator</param> + /// <returns>The same CommunicationGroupDriver with the added Broadcast operator info</returns> + public ICommunicationGroupDriver AddBroadcast(string operatorName, string masterTaskId, + TopologyTypes topologyType = TopologyTypes.Flat) + { + return AddBroadcast(operatorName, masterTaskId, new IntCodec(), topologyType); + } + + /// <summary> + /// Adds the Reduce MPI operator to the communication group. + /// </summary> + /// <typeparam name="T">The type of messages that operators will send</typeparam> + /// <param name="operatorName">The name of the reduce operator</param> + /// <param name="masterTaskId">The master task id for the typology</param> + /// <param name="codecType">The codec used for serializing messages.</param> + /// <param name="reduceFunction">The class used to aggregate all messages.</param> + /// <param name="topologyType">The topology for the operator</param> + /// <returns>The same CommunicationGroupDriver with the added Reduce operator info</returns> + public ICommunicationGroupDriver AddReduce<T>( + string operatorName, + string masterTaskId, + ICodec<T> codecType, + IReduceFunction<T> reduceFunction, + TopologyTypes topologyType = TopologyTypes.Flat) + { + if (_finalized) + { + throw new IllegalStateException("Can't add operators once the spec has been built."); + } + + var spec = new ReduceOperatorSpec<T>( + masterTaskId, + codecType, + reduceFunction); + + ITopology<T> topology; + + if (topologyType == TopologyTypes.Flat) + { + topology = new FlatTopology<T>(operatorName, _groupName, spec.ReceiverId, _driverId, spec); + } + else + { + 
topology = new TreeTopology<T>(operatorName, _groupName, spec.ReceiverId, _driverId, spec, + _fanOut); + } + + _topologies[operatorName] = topology; + _operatorSpecs[operatorName] = spec; + + return this; + } + + /// <summary> /// Adds the Reduce MPI operator to the communication group. /// </summary> /// <typeparam name="T">The type of messages that operators will send</typeparam> /// <param name="operatorName">The name of the reduce operator</param> /// <param name="spec">The specification that defines the Reduce operator</param> /// <returns>The same CommunicationGroupDriver with the added Reduce operator info</returns> + [System.Obsolete("use AddReduce<T>(string operatorName, string masterTaskId, ICodec<T> codecType, IReduceFunction<T> reduceFunction, TopologyTypes topologyType = TopologyTypes.Flat)")] public ICommunicationGroupDriver AddReduce<T>( string operatorName, ReduceOperatorSpec<T> spec, @@ -162,6 +263,7 @@ namespace Org.Apache.REEF.Network.Group.Driver.Impl /// <param name="operatorName">The name of the scatter operator</param> /// <param name="spec">The specification that defines the Scatter operator</param> /// <returns>The same CommunicationGroupDriver with the added Scatter operator info</returns> + [System.Obsolete("use AddScatter<T>(string operatorName, string senderId, ICodec<T> codecType, TopologyTypes topologyType = TopologyTypes.Flat)")] public ICommunicationGroupDriver AddScatter<T>(string operatorName, ScatterOperatorSpec<T> spec, TopologyTypes topologyType = TopologyTypes.Flat) { if (_finalized) @@ -169,7 +271,7 @@ namespace Org.Apache.REEF.Network.Group.Driver.Impl throw new IllegalStateException("Can't add operators once the spec has been built."); } - ITopology<T> topology; + ITopology<T> topology; if (topologyType == TopologyTypes.Flat) { @@ -187,6 +289,53 @@ namespace Org.Apache.REEF.Network.Group.Driver.Impl } /// <summary> + /// Adds the Scatter MPI operator to the communication group. 
+ /// </summary> + /// <typeparam name="T">The type of messages that operators will send</typeparam> + /// <param name="operatorName">The name of the scatter operator</param> + /// <param name="senderId">The sender id</param> + /// <param name="codecType">The codec used for serializing messages.</param> + /// <param name="topologyType">type of topology used in the operaor</param> + /// <returns>The same CommunicationGroupDriver with the added Scatter operator info</returns> + public ICommunicationGroupDriver AddScatter<T>(string operatorName, string senderId, ICodec<T> codecType, TopologyTypes topologyType = TopologyTypes.Flat) + { + if (_finalized) + { + throw new IllegalStateException("Can't add operators once the spec has been built."); + } + + var spec = new ScatterOperatorSpec<T>(senderId, codecType); + + ITopology<T> topology; + + if (topologyType == TopologyTypes.Flat) + { + topology = new FlatTopology<T>(operatorName, _groupName, spec.SenderId, _driverId, spec); + } + else + { + topology = new TreeTopology<T>(operatorName, _groupName, spec.SenderId, _driverId, spec, + _fanOut); + } + _topologies[operatorName] = topology; + _operatorSpecs[operatorName] = spec; + + return this; + } + + /// <summary> + /// Adds the Scatter MPI operator to the communication group with default Codec + /// </summary> + /// <param name="operatorName">The name of the scatter operator</param> + /// <param name="senderId">The sender id</param> + /// <param name="topologyType">type of topology used in the operaor</param> + /// <returns>The same CommunicationGroupDriver with the added Scatter operator info</returns> + public ICommunicationGroupDriver AddScatter(string operatorName, string senderId, TopologyTypes topologyType = TopologyTypes.Flat) + { + return AddScatter(operatorName, senderId, new IntCodec(), topologyType); + } + + /// <summary> /// Finalizes the CommunicationGroupDriver. 
/// After the CommunicationGroupDriver has been finalized, no more operators may /// be added to the group. http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/MpiDriver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/MpiDriver.cs b/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/MpiDriver.cs index 40ca02c..291b7d6 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/MpiDriver.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/MpiDriver.cs @@ -52,10 +52,11 @@ namespace Org.Apache.REEF.Network.Group.Driver.Impl private static Logger LOGGER = Logger.GetLogger(typeof(MpiDriver)); private readonly string _driverId; - private readonly string _nameServerAddr; + private readonly string _nameServerAddr; private readonly int _nameServerPort; private int _contextIds; private int _fanOut; + private string _groupName; private readonly Dictionary<string, ICommunicationGroupDriver> _commGroups; private readonly AvroConfigurationSerializer _configSerializer; @@ -66,7 +67,9 @@ namespace Org.Apache.REEF.Network.Group.Driver.Impl /// </summary> /// <param name="driverId">Identifer for the REEF driver</param> /// <param name="masterTaskId">Identifer for MPI master task</param> + /// <param name="fanOut">fanOut for tree topology</param> /// <param name="configSerializer">Used to serialize task configuration</param> + [System.Obsolete("user the other constructor")] [Inject] public MpiDriver( [Parameter(typeof(MpiConfigurationOptions.DriverId))] string driverId, @@ -89,10 +92,50 @@ namespace Org.Apache.REEF.Network.Group.Driver.Impl } /// <summary> + /// Create a new MpiDriver object. 
+ /// </summary> + /// <param name="driverId">Identifer for the REEF driver</param> + /// <param name="masterTaskId">Identifer for MPI master task</param> + /// <param name="fanOut">fanOut for tree topology</param> + /// <param name="groupName">default communication group name</param> + /// <param name="numberOfTasks">Number of tasks in the default group</param> + /// <param name="configSerializer">Used to serialize task configuration</param> + [Inject] + public MpiDriver( + [Parameter(typeof(MpiConfigurationOptions.DriverId))] string driverId, + [Parameter(typeof(MpiConfigurationOptions.MasterTaskId))] string masterTaskId, + [Parameter(typeof(MpiConfigurationOptions.FanOut))] int fanOut, + [Parameter(typeof(MpiConfigurationOptions.GroupName))] string groupName, + [Parameter(typeof(MpiConfigurationOptions.NumberOfTasks))] int numberOfTasks, + AvroConfigurationSerializer configSerializer) + { + _driverId = driverId; + _contextIds = -1; + _fanOut = fanOut; + MasterTaskId = masterTaskId; + _groupName = groupName; + + _configSerializer = configSerializer; + _commGroups = new Dictionary<string, ICommunicationGroupDriver>(); + _nameServer = new NameServer(0); + + IPEndPoint localEndpoint = _nameServer.LocalEndpoint; + _nameServerAddr = localEndpoint.Address.ToString(); + _nameServerPort = localEndpoint.Port; + + NewCommunicationGroup(groupName, numberOfTasks); + } + + /// <summary> /// Returns the identifier for the master task /// </summary> public string MasterTaskId { get; private set; } + public ICommunicationGroupDriver DefaultGroup + { + get { return _commGroups[_groupName]; } + } + /// <summary> /// Create a new CommunicationGroup with the given name and number of tasks/operators. /// </summary> @@ -163,10 +206,8 @@ namespace Org.Apache.REEF.Network.Group.Driver.Impl /// <summary> /// Get the configuration for a particular task. 
- /// /// The task may belong to many Communication Groups, so each one is serialized /// in the configuration as a SerializedGroupConfig. - /// /// The user must merge their part of task configuration (task id, task class) /// with this returned MPI task configuration. /// </summary> http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceSender.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceSender.cs b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceSender.cs index a2c7d93..d21983a 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceSender.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceSender.cs @@ -83,7 +83,7 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl public int Version { get; private set; } /// <summary> - /// Sends data to the operator's ReduceReceiver to be aggregated. 
+ /// Gets the reduced data from the children, reduces it with the given data, then sends the reduced result to the parent /// </summary> /// <param name="data">The data to send</param> public void Send(T data) http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/OperatorTopology.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/OperatorTopology.cs b/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/OperatorTopology.cs index fcec282..088e2e7 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/OperatorTopology.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/OperatorTopology.cs @@ -45,6 +45,9 @@ namespace Org.Apache.REEF.Network.Group.Task.Impl /// <typeparam name="T">The message type</typeparam> public class OperatorTopology<T> : IObserver<GroupCommunicationMessage> { + private const int DefaultTimeout = 50000; + private const int RetryCount = 10; + private static readonly Logger LOGGER = Logger.GetLogger(typeof(OperatorTopology<>)); private readonly string _groupName; http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Tang.Tests/Injection/TestSetInjection.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Tang.Tests/Injection/TestSetInjection.cs b/lang/cs/Org.Apache.REEF.Tang.Tests/Injection/TestSetInjection.cs index 0350249..7101bf9 100644 --- a/lang/cs/Org.Apache.REEF.Tang.Tests/Injection/TestSetInjection.cs +++ b/lang/cs/Org.Apache.REEF.Tang.Tests/Injection/TestSetInjection.cs @@ -61,6 +61,26 @@ namespace Org.Apache.REEF.Tang.Tests.Injection } [TestMethod] + public void TestStringInjectNoDefault() + { + BoxNoDefault b = (BoxNoDefault)TangFactory.GetTang().NewInjector().GetInstance(typeof(BoxNoDefault)); + ISet<string> actual = b.Numbers; + Assert.AreEqual(actual.Count, 0); + } + + [TestMethod] + public void
TestStringInjectNoDefaultWithValue() + { + var cb = TangFactory.GetTang().NewConfigurationBuilder(); + cb.BindSetEntry<SetOfNumbersNoDefault, string>(GenericType<SetOfNumbersNoDefault>.Class, "123"); + BoxNoDefault b = (BoxNoDefault)TangFactory.GetTang().NewInjector(cb.Build()).GetInstance(typeof(BoxNoDefault)); + + ISet<string> actual = b.Numbers; + + Assert.AreEqual(actual.Count, 1); + } + + [TestMethod] public void TestObjectInjectDefault() { IInjector i = TangFactory.GetTang().NewInjector(); @@ -327,6 +347,21 @@ namespace Org.Apache.REEF.Tang.Tests.Injection { } + [NamedParameter()] + public class SetOfNumbersNoDefault : Name<ISet<string>> + { + } + public class BoxNoDefault + { + [Inject] + public BoxNoDefault([Parameter(typeof(SetOfNumbersNoDefault))] ISet<string> numbers) + { + this.Numbers = numbers; + } + + public ISet<string> Numbers { get; set; } + } + public class Box { [Inject] http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Tests/Functional/ML/KMeans/TestKMeans.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Tests/Functional/ML/KMeans/TestKMeans.cs b/lang/cs/Org.Apache.REEF.Tests/Functional/ML/KMeans/TestKMeans.cs index 982d799..6b3b26c 100644 --- a/lang/cs/Org.Apache.REEF.Tests/Functional/ML/KMeans/TestKMeans.cs +++ b/lang/cs/Org.Apache.REEF.Tests/Functional/ML/KMeans/TestKMeans.cs @@ -23,9 +23,13 @@ using System.Globalization; using System.IO; using Microsoft.VisualStudio.TestTools.UnitTesting; using Org.Apache.REEF.Common.Io; +using Org.Apache.REEF.Common.Tasks; using Org.Apache.REEF.Driver.Bridge; using Org.Apache.REEF.Examples.MachineLearning.KMeans; +using Org.Apache.REEF.Network.Group.Config; using Org.Apache.REEF.Network.NetworkService; +using Org.Apache.REEF.Tang.Implementations.Configuration; +using Org.Apache.REEF.Tang.Implementations.Tang; using Org.Apache.REEF.Tang.Interface; using Org.Apache.REEF.Tang.Util; using 
Org.Apache.REEF.Utilities.Logging; @@ -140,14 +144,31 @@ namespace Org.Apache.REEF.Tests.Functional.ML.KMeans private IConfiguration DriverConfiguration() { - return DriverBridgeConfiguration.ConfigurationModule - .Set(DriverBridgeConfiguration.OnDriverStarted, GenericType<KMeansDriverHandlers>.Class) - .Set(DriverBridgeConfiguration.OnEvaluatorAllocated, GenericType<KMeansDriverHandlers>.Class) - .Set(DriverBridgeConfiguration.OnEvaluatorRequested, GenericType<KMeansDriverHandlers>.Class) - .Set(DriverBridgeConfiguration.OnContextActive, GenericType<KMeansDriverHandlers>.Class) - .Set(DriverBridgeConfiguration.CommandLineArguments, "DataFile:" + _dataFile) - .Set(DriverBridgeConfiguration.CustomTraceLevel, Level.Info.ToString()) - .Build(); + int fanOut = 2; + int totalEvaluators = Partitions + 1; + string Identifier = "KMeansDriverId"; + + IConfiguration driverConfig = TangFactory.GetTang().NewConfigurationBuilder( + DriverBridgeConfiguration.ConfigurationModule + .Set(DriverBridgeConfiguration.OnDriverStarted, GenericType<KMeansDriverHandlers>.Class) + .Set(DriverBridgeConfiguration.OnEvaluatorAllocated, GenericType<KMeansDriverHandlers>.Class) + .Set(DriverBridgeConfiguration.OnEvaluatorRequested, GenericType<KMeansDriverHandlers>.Class) + .Set(DriverBridgeConfiguration.OnContextActive, GenericType<KMeansDriverHandlers>.Class) + .Set(DriverBridgeConfiguration.CommandLineArguments, "DataFile:" + _dataFile) + .Set(DriverBridgeConfiguration.CustomTraceLevel, Level.Info.ToString()) + .Build()) + .BindIntNamedParam<NumPartitions>(Partitions.ToString()) + .Build(); + + IConfiguration mpiDriverConfig = TangFactory.GetTang().NewConfigurationBuilder() + .BindStringNamedParam<MpiConfigurationOptions.DriverId>(Identifier) + .BindStringNamedParam<MpiConfigurationOptions.MasterTaskId>(Constants.MasterTaskId) + .BindStringNamedParam<MpiConfigurationOptions.GroupName>(Constants.KMeansCommunicationGroupName) + 
.BindIntNamedParam<MpiConfigurationOptions.FanOut>(fanOut.ToString(CultureInfo.InvariantCulture).ToString(CultureInfo.InvariantCulture)) + .BindIntNamedParam<MpiConfigurationOptions.NumberOfTasks>(totalEvaluators.ToString()) + .Build(); + + return Configurations.Merge(driverConfig, mpiDriverConfig); } private HashSet<string> AssembliesToCopy() http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/BroadcastReduceTest/BroadcastReduceDriver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/BroadcastReduceTest/BroadcastReduceDriver.cs b/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/BroadcastReduceTest/BroadcastReduceDriver.cs index 947dfdc..ed3e7b5 100644 --- a/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/BroadcastReduceTest/BroadcastReduceDriver.cs +++ b/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/BroadcastReduceTest/BroadcastReduceDriver.cs @@ -27,6 +27,7 @@ using Org.Apache.REEF.Driver; using Org.Apache.REEF.Driver.Bridge; using Org.Apache.REEF.Driver.Context; using Org.Apache.REEF.Driver.Evaluator; +using Org.Apache.REEF.Network.Group.Config; using Org.Apache.REEF.Network.Group.Driver; using Org.Apache.REEF.Network.Group.Driver.Impl; using Org.Apache.REEF.Network.Group.Operators; @@ -57,33 +58,22 @@ namespace Org.Apache.REEF.Tests.Functional.MPI.BroadcastReduceTest public BroadcastReduceDriver( [Parameter(typeof(MpiTestConfig.NumEvaluators))] int numEvaluators, [Parameter(typeof(MpiTestConfig.NumIterations))] int numIterations, - [Parameter(typeof(MpiTestConfig.FanOut))] int fanOut, - AvroConfigurationSerializer confSerializer) + MpiDriver mpiDriver) { Identifier = "BroadcastStartHandler"; _numEvaluators = numEvaluators; _numIterations = numIterations; - - _mpiDriver = new MpiDriver( - MpiTestConstants.DriverId, - MpiTestConstants.MasterTaskId, - fanOut, - confSerializer); - - _commGroup = _mpiDriver.NewCommunicationGroup( - 
MpiTestConstants.GroupName, - numEvaluators) + _mpiDriver = mpiDriver; + _commGroup = _mpiDriver.DefaultGroup .AddBroadcast( MpiTestConstants.BroadcastOperatorName, - new BroadcastOperatorSpec<int>( - MpiTestConstants.MasterTaskId, - new IntCodec())) + MpiTestConstants.MasterTaskId, + new IntCodec()) .AddReduce( MpiTestConstants.ReduceOperatorName, - new ReduceOperatorSpec<int>( MpiTestConstants.MasterTaskId, new IntCodec(), - new SumFunction())) + new SumFunction()) .Build(); _mpiTaskStarter = new TaskStarter(_mpiDriver, numEvaluators); http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/BroadcastReduceTest/BroadcastReduceTest.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/BroadcastReduceTest/BroadcastReduceTest.cs b/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/BroadcastReduceTest/BroadcastReduceTest.cs index 17b28f4..2224c04 100644 --- a/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/BroadcastReduceTest/BroadcastReduceTest.cs +++ b/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/BroadcastReduceTest/BroadcastReduceTest.cs @@ -24,7 +24,10 @@ using Org.Apache.REEF.Common.Io; using Org.Apache.REEF.Common.Tasks; using Org.Apache.REEF.Driver; using Org.Apache.REEF.Driver.Bridge; +using Org.Apache.REEF.Network.Group.Config; using Org.Apache.REEF.Network.NetworkService; +using Org.Apache.REEF.Tang.Formats; +using Org.Apache.REEF.Tang.Implementations.Configuration; using Org.Apache.REEF.Tang.Implementations.Tang; using Org.Apache.REEF.Tang.Interface; using Org.Apache.REEF.Tang.Util; @@ -64,13 +67,20 @@ namespace Org.Apache.REEF.Tests.Functional.MPI.BroadcastReduceTest .BindNamedParameter<MpiTestConfig.NumIterations, int>( GenericType<MpiTestConfig.NumIterations>.Class, MpiTestConstants.NumIterations.ToString(CultureInfo.InvariantCulture)) - .BindNamedParameter<MpiTestConfig.FanOut, int>( - GenericType<MpiTestConfig.FanOut>.Class, - 
MpiTestConstants.FanOut.ToString(CultureInfo.InvariantCulture)) .BindNamedParameter<MpiTestConfig.NumEvaluators, int>( GenericType<MpiTestConfig.NumEvaluators>.Class, numTasks.ToString(CultureInfo.InvariantCulture)) .Build(); + + IConfiguration mpiDriverConfig = TangFactory.GetTang().NewConfigurationBuilder() + .BindStringNamedParam<MpiConfigurationOptions.DriverId>(MpiTestConstants.DriverId) + .BindStringNamedParam<MpiConfigurationOptions.MasterTaskId>(MpiTestConstants.MasterTaskId) + .BindStringNamedParam<MpiConfigurationOptions.GroupName>(MpiTestConstants.GroupName) + .BindIntNamedParam<MpiConfigurationOptions.FanOut>(MpiTestConstants.FanOut.ToString(CultureInfo.InvariantCulture).ToString(CultureInfo.InvariantCulture)) + .BindIntNamedParam<MpiConfigurationOptions.NumberOfTasks>(numTasks.ToString(CultureInfo.InvariantCulture)) + .Build(); + + IConfiguration merged = Configurations.Merge(driverConfig, mpiDriverConfig); HashSet<string> appDlls = new HashSet<string>(); appDlls.Add(typeof(IDriver).Assembly.GetName().Name); @@ -79,7 +89,7 @@ namespace Org.Apache.REEF.Tests.Functional.MPI.BroadcastReduceTest appDlls.Add(typeof(INameClient).Assembly.GetName().Name); appDlls.Add(typeof(INetworkService<>).Assembly.GetName().Name); - TestRun(appDlls, driverConfig, false, JavaLoggingSetting.VERBOSE); + TestRun(appDlls, merged, false, JavaLoggingSetting.VERBOSE); ValidateSuccessForLocalRuntime(numTasks); } } http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/ScatterReduceTest/ScatterReduceDriver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/ScatterReduceTest/ScatterReduceDriver.cs b/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/ScatterReduceTest/ScatterReduceDriver.cs index e16a917..6029bfe 100644 --- a/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/ScatterReduceTest/ScatterReduceDriver.cs +++ 
b/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/ScatterReduceTest/ScatterReduceDriver.cs @@ -19,6 +19,7 @@ using System; using System.Collections.Generic; +using System.Globalization; using System.Linq; using Org.Apache.REEF.Common.Io; using Org.Apache.REEF.Common.Tasks; @@ -26,6 +27,7 @@ using Org.Apache.REEF.Driver; using Org.Apache.REEF.Driver.Bridge; using Org.Apache.REEF.Driver.Context; using Org.Apache.REEF.Driver.Evaluator; +using Org.Apache.REEF.Network.Group.Config; using Org.Apache.REEF.Network.Group.Driver; using Org.Apache.REEF.Network.Group.Driver.Impl; using Org.Apache.REEF.Network.Group.Operators; @@ -34,6 +36,7 @@ using Org.Apache.REEF.Network.Group.Topology; using Org.Apache.REEF.Network.NetworkService; using Org.Apache.REEF.Tang.Annotations; using Org.Apache.REEF.Tang.Formats; +using Org.Apache.REEF.Tang.Implementations.Tang; using Org.Apache.REEF.Tang.Interface; using Org.Apache.REEF.Tang.Util; using Org.Apache.REEF.Utilities.Logging; @@ -54,33 +57,22 @@ namespace Org.Apache.REEF.Tests.Functional.MPI.ScatterReduceTest [Inject] public ScatterReduceDriver( [Parameter(typeof(MpiTestConfig.NumEvaluators))] int numEvaluators, - [Parameter(typeof(MpiTestConfig.FanOut))] int fanOut, - AvroConfigurationSerializer confSerializer) + MpiDriver mpiDriver) { Identifier = "BroadcastStartHandler"; _numEvaluators = numEvaluators; - - _mpiDriver = new MpiDriver( - MpiTestConstants.DriverId, - MpiTestConstants.MasterTaskId, - fanOut, - confSerializer); - - _commGroup = _mpiDriver.NewCommunicationGroup( - MpiTestConstants.GroupName, - numEvaluators) + _mpiDriver = mpiDriver; + _commGroup = _mpiDriver.DefaultGroup .AddScatter( MpiTestConstants.ScatterOperatorName, - new ScatterOperatorSpec<int>( MpiTestConstants.MasterTaskId, - new IntCodec()), + new IntCodec(), TopologyTypes.Tree) .AddReduce( MpiTestConstants.ReduceOperatorName, - new ReduceOperatorSpec<int>( MpiTestConstants.MasterTaskId, new IntCodec(), - new SumFunction())) + new SumFunction()) .Build(); 
_mpiTaskStarter = new TaskStarter(_mpiDriver, numEvaluators); http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/ScatterReduceTest/ScatterReduceTest.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/ScatterReduceTest/ScatterReduceTest.cs b/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/ScatterReduceTest/ScatterReduceTest.cs index d03036c..59e27b4 100644 --- a/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/ScatterReduceTest/ScatterReduceTest.cs +++ b/lang/cs/Org.Apache.REEF.Tests/Functional/MPI/ScatterReduceTest/ScatterReduceTest.cs @@ -24,7 +24,9 @@ using Org.Apache.REEF.Common.Io; using Org.Apache.REEF.Common.Tasks; using Org.Apache.REEF.Driver; using Org.Apache.REEF.Driver.Bridge; +using Org.Apache.REEF.Network.Group.Config; using Org.Apache.REEF.Network.NetworkService; +using Org.Apache.REEF.Tang.Implementations.Configuration; using Org.Apache.REEF.Tang.Implementations.Tang; using Org.Apache.REEF.Tang.Interface; using Org.Apache.REEF.Tang.Util; @@ -64,11 +66,18 @@ namespace Org.Apache.REEF.Tests.Functional.MPI.ScatterReduceTest .BindNamedParameter<MpiTestConfig.NumEvaluators, int>( GenericType<MpiTestConfig.NumEvaluators>.Class, numTasks.ToString(CultureInfo.InvariantCulture)) - .BindNamedParameter<MpiTestConfig.FanOut, int>( - GenericType<MpiTestConfig.FanOut>.Class, - MpiTestConstants.FanOut.ToString(CultureInfo.InvariantCulture)) .Build(); - + + IConfiguration mpiDriverConfig = TangFactory.GetTang().NewConfigurationBuilder() + .BindStringNamedParam<MpiConfigurationOptions.DriverId>(MpiTestConstants.DriverId) + .BindStringNamedParam<MpiConfigurationOptions.MasterTaskId>(MpiTestConstants.MasterTaskId) + .BindStringNamedParam<MpiConfigurationOptions.GroupName>(MpiTestConstants.GroupName) + 
.BindIntNamedParam<MpiConfigurationOptions.FanOut>(MpiTestConstants.FanOut.ToString(CultureInfo.InvariantCulture).ToString(CultureInfo.InvariantCulture)) + .BindIntNamedParam<MpiConfigurationOptions.NumberOfTasks>(numTasks.ToString()) + .Build(); + + IConfiguration merged = Configurations.Merge(driverConfig, mpiDriverConfig); + HashSet<string> appDlls = new HashSet<string>(); appDlls.Add(typeof(IDriver).Assembly.GetName().Name); appDlls.Add(typeof(ITask).Assembly.GetName().Name); @@ -76,7 +85,7 @@ namespace Org.Apache.REEF.Tests.Functional.MPI.ScatterReduceTest appDlls.Add(typeof(INameClient).Assembly.GetName().Name); appDlls.Add(typeof(INetworkService<>).Assembly.GetName().Name); - TestRun(appDlls, driverConfig, false, JavaLoggingSetting.VERBOSE); + TestRun(appDlls, merged, false, JavaLoggingSetting.VERBOSE); ValidateSuccessForLocalRuntime(numTasks); } } http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/87107120/lang/cs/Org.Apache.REEF.Tests/Network/GroupCommunicationTests.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Tests/Network/GroupCommunicationTests.cs b/lang/cs/Org.Apache.REEF.Tests/Network/GroupCommunicationTests.cs index a5297f6..8768297 100644 --- a/lang/cs/Org.Apache.REEF.Tests/Network/GroupCommunicationTests.cs +++ b/lang/cs/Org.Apache.REEF.Tests/Network/GroupCommunicationTests.cs @@ -27,6 +27,7 @@ using System.Reactive; using System.Text; using Org.Apache.REEF.Common.Tasks; using Org.Apache.REEF.Network.Group.Codec; +using Org.Apache.REEF.Network.Group.Config; using Org.Apache.REEF.Network.Group.Driver; using Org.Apache.REEF.Network.Group.Driver.Impl; using Org.Apache.REEF.Network.Group.Operators; @@ -41,8 +42,6 @@ using Org.Apache.REEF.Tang.Implementations.Configuration; using Org.Apache.REEF.Tang.Implementations.Tang; using Org.Apache.REEF.Tang.Interface; using Org.Apache.REEF.Tang.Util; -using Org.Apache.REEF.Tests.Functional.MPI; -using Org.Apache.REEF.Utilities.Logging; 
using Org.Apache.REEF.Wake.Remote; using Org.Apache.REEF.Wake.Remote.Impl; @@ -100,51 +99,21 @@ namespace Org.Apache.REEF.Tests.Network int numTasks = 3; int fanOut = 2; - IMpiDriver mpiDriver = new MpiDriver(driverId, masterTaskId, fanOut, new AvroConfigurationSerializer()); + var mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - ICommunicationGroupDriver commGroup = mpiDriver.NewCommunicationGroup( - groupName, - numTasks) - .AddBroadcast( + ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup + .AddBroadcast<int>( broadcastOperatorName, - new BroadcastOperatorSpec<int>( masterTaskId, - new IntCodec())) - .AddReduce( + new IntCodec()) + .AddReduce<int>( reduceOperatorName, - new ReduceOperatorSpec<int>( masterTaskId, new IntCodec(), - new SumFunction())) + new SumFunction()) .Build(); - List<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); - IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); - - List<IConfiguration> partialConfigs = new List<IConfiguration>(); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( - TaskConfiguration.ConfigurationModule - .Set(TaskConfiguration.Identifier, taskId) - .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) - .Build()) - .Build(); - commGroup.AddTask(taskId); - partialConfigs.Add(partialTaskConfig); - } - - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration mpiTaskConfig = mpiDriver.GetMpiTaskConfiguration(taskId); - IConfiguration mergedConf = Configurations.Merge(mpiTaskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); - - IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); - commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); - } + var commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); 
//for master task IBroadcastSender<int> broadcastSender = commGroups[0].GetBroadcastSender<int>(broadcastOperatorName); @@ -175,7 +144,7 @@ namespace Org.Apache.REEF.Tests.Network Assert.AreEqual(sum, expected); } } - + [TestMethod] public void TestScatterReduceOperators() { @@ -187,52 +156,21 @@ namespace Org.Apache.REEF.Tests.Network int numTasks = 5; int fanOut = 2; - IMpiDriver mpiDriver = new MpiDriver(driverId, masterTaskId, fanOut, new AvroConfigurationSerializer()); + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - ICommunicationGroupDriver commGroup = mpiDriver.NewCommunicationGroup( - groupName, - numTasks) - .AddScatter( + ICommunicationGroupDriver commGroup = mpiDriver.DefaultGroup + .AddScatter<int>( scatterOperatorName, - new ScatterOperatorSpec<int>( - masterTaskId, - new IntCodec())) + masterTaskId, + new IntCodec()) .AddReduce( reduceOperatorName, - new ReduceOperatorSpec<int>( masterTaskId, new IntCodec(), - new SumFunction())) + new SumFunction()) .Build(); - List<IConfiguration> partialConfigs = new List<IConfiguration>(); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( - TaskConfiguration.ConfigurationModule - .Set(TaskConfiguration.Identifier, taskId) - .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) - .Build()) - .Build(); - commGroup.AddTask(taskId); - partialConfigs.Add(partialTaskConfig); - } - - IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); - - List<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); - - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration mpiTaskConfig = mpiDriver.GetMpiTaskConfiguration(taskId); - IConfiguration mergedConf = Configurations.Merge(mpiTaskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); - - 
IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); - commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); - } + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(scatterOperatorName); IReduceReceiver<int> sumReducer = commGroups[0].GetReduceReceiver<int>(reduceOperatorName); @@ -270,18 +208,6 @@ namespace Org.Apache.REEF.Tests.Network Assert.AreEqual(sum, data.Sum()); } - private static void ScatterReceiveReduce(IScatterReceiver<int> receiver, IReduceSender<int> sumSender) - { - List<int> data1 = receiver.Receive(); - int sum1 = data1.Sum(); - sumSender.Send(sum1); - } - - private int TriangleNumber(int n) - { - return Enumerable.Range(1, n).Sum(); - } - [TestMethod] public void TestBroadcastOperator() { @@ -295,40 +221,47 @@ namespace Org.Apache.REEF.Tests.Network int value = 1337; int fanOut = 3; - IMpiDriver mpiDriver = new MpiDriver(driverId, masterTaskId, fanOut, new AvroConfigurationSerializer()); + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - var commGroup = mpiDriver.NewCommunicationGroup(groupName, numTasks) - .AddBroadcast(operatorName, new BroadcastOperatorSpec<int>(masterTaskId, new IntCodec())) + var commGroup = mpiDriver.DefaultGroup + .AddBroadcast(operatorName, masterTaskId, new IntCodec()) .Build(); - List<IConfiguration> partialConfigs = new List<IConfiguration>(); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( - TaskConfiguration.ConfigurationModule - .Set(TaskConfiguration.Identifier, taskId) - .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) - .Build()) - .Build(); - commGroup.AddTask(taskId); - partialConfigs.Add(partialTaskConfig); - } + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, 
commGroup); - IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); + IBroadcastSender<int> sender = commGroups[0].GetBroadcastSender<int>(operatorName); + IBroadcastReceiver<int> receiver1 = commGroups[1].GetBroadcastReceiver<int>(operatorName); + IBroadcastReceiver<int> receiver2 = commGroups[2].GetBroadcastReceiver<int>(operatorName); - List<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); + Assert.IsNotNull(sender); + Assert.IsNotNull(receiver1); + Assert.IsNotNull(receiver2); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration mpiTaskConfig = mpiDriver.GetMpiTaskConfiguration(taskId); - IConfiguration mergedConf = Configurations.Merge(mpiTaskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); + sender.Send(value); + Assert.AreEqual(value, receiver1.Receive()); + Assert.AreEqual(value, receiver2.Receive()); + } - IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); - commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); - } + [TestMethod] + public void TestBroadcastOperatorWithDefaultCodec() + { + NameServer nameServer = new NameServer(0); + + string groupName = "group1"; + string operatorName = "broadcast"; + string masterTaskId = "task0"; + string driverId = "Driver Id"; + int numTasks = 10; + int value = 1337; + int fanOut = 3; + + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); + + var commGroup = mpiDriver.DefaultGroup + .AddBroadcast(operatorName, masterTaskId) + .Build(); + + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); IBroadcastSender<int> sender = commGroups[0].GetBroadcastSender<int>(operatorName); IBroadcastReceiver<int> receiver1 = commGroups[1].GetBroadcastReceiver<int>(operatorName); @@ -356,40 +289,13 @@ namespace Org.Apache.REEF.Tests.Network int value3 = 99; int fanOut = 2; - 
IMpiDriver mpiDriver = new MpiDriver(driverId, masterTaskId, fanOut, new AvroConfigurationSerializer()); - - var commGroup = mpiDriver.NewCommunicationGroup(groupName, numTasks) - .AddBroadcast(operatorName, new BroadcastOperatorSpec<int>(masterTaskId, new IntCodec())) - .Build(); - - List<IConfiguration> partialConfigs = new List<IConfiguration>(); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( - TaskConfiguration.ConfigurationModule - .Set(TaskConfiguration.Identifier, taskId) - .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) - .Build()) - .Build(); - - commGroup.AddTask(taskId); - partialConfigs.Add(partialTaskConfig); - } + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); + var commGroup = mpiDriver.DefaultGroup + .AddBroadcast(operatorName, masterTaskId, new IntCodec()) + .Build(); - IList<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration mpiTaskConfig = mpiDriver.GetMpiTaskConfiguration(taskId); - IConfiguration mergedConf = Configurations.Merge(mpiTaskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); - - IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); - commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); - } + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); IBroadcastSender<int> sender = commGroups[0].GetBroadcastSender<int>(operatorName); IBroadcastReceiver<int> receiver1 = commGroups[1].GetBroadcastReceiver<int>(operatorName); @@ -418,39 +324,17 @@ namespace Org.Apache.REEF.Tests.Network string groupName = "group1"; string operatorName = "reduce"; int numTasks = 4; - 
IMpiDriver mpiDriver = new MpiDriver("driverid", "task0", 2, new AvroConfigurationSerializer()); - - var commGroup = mpiDriver.NewCommunicationGroup(groupName, numTasks) - .AddReduce(operatorName, new ReduceOperatorSpec<int>("task0", new IntCodec(), new SumFunction())) - .Build(); - - List<IConfiguration> partialConfigs = new List<IConfiguration>(); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( - TaskConfiguration.ConfigurationModule - .Set(TaskConfiguration.Identifier, taskId) - .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) - .Build()) - .Build(); - commGroup.AddTask(taskId); - partialConfigs.Add(partialTaskConfig); - } + string driverId = "driverid"; + string masterTaskId = "task0"; + int fanOut = 2; - IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); - IList<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration taskConfig = mpiDriver.GetMpiTaskConfiguration(taskId); - IConfiguration mergedConf = Configurations.Merge(taskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); + var commGroup = mpiDriver.DefaultGroup + .AddReduce(operatorName, "task0", new IntCodec(), new SumFunction()) + .Build(); - IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); - commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); - } + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); IReduceReceiver<int> receiver = commGroups[0].GetReduceReceiver<int>(operatorName); IReduceSender<int> sender1 = commGroups[1].GetReduceSender<int>(operatorName); @@ -475,39 +359,17 @@ namespace Org.Apache.REEF.Tests.Network string 
groupName = "group1"; string operatorName = "reduce"; int numTasks = 4; - IMpiDriver mpiDriver = new MpiDriver("driverid", "task0", 2, new AvroConfigurationSerializer()); - - var commGroup = mpiDriver.NewCommunicationGroup(groupName, numTasks) - .AddReduce(operatorName, new ReduceOperatorSpec<int>("task0", new IntCodec(), new SumFunction())) - .Build(); - - List<IConfiguration> partialConfigs = new List<IConfiguration>(); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( - TaskConfiguration.ConfigurationModule - .Set(TaskConfiguration.Identifier, taskId) - .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) - .Build()) - .Build(); - commGroup.AddTask(taskId); - partialConfigs.Add(partialTaskConfig); - } + string driverId = "driverid"; + string masterTaskId = "task0"; + int fanOut = 2; - IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); - IList<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration taskConfig = mpiDriver.GetMpiTaskConfiguration(taskId); - IConfiguration mergedConf = Configurations.Merge(taskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); + var commGroup = mpiDriver.DefaultGroup + .AddReduce(operatorName, "task0", new IntCodec(), new SumFunction()) + .Build(); - IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); - commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); - } + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); IReduceReceiver<int> receiver = commGroups[0].GetReduceReceiver<int>(operatorName); IReduceSender<int> sender1 = commGroups[1].GetReduceSender<int>(operatorName); 
@@ -545,40 +407,52 @@ namespace Org.Apache.REEF.Tests.Network int numTasks = 5; int fanOut = 2; - IMpiDriver mpiDriver = new MpiDriver(driverId, masterTaskId, fanOut, new AvroConfigurationSerializer()); + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - var commGroup = mpiDriver.NewCommunicationGroup(groupName, numTasks) - .AddScatter(operatorName, new ScatterOperatorSpec<int>(masterTaskId, new IntCodec())) + var commGroup = mpiDriver.DefaultGroup + .AddScatter(operatorName, masterTaskId, new IntCodec()) .Build(); - List<IConfiguration> partialConfigs = new List<IConfiguration>(); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( - TaskConfiguration.ConfigurationModule - .Set(TaskConfiguration.Identifier, taskId) - .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) - .Build()) - .Build(); - commGroup.AddTask(taskId); - partialConfigs.Add(partialTaskConfig); - } + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); - IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); + IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName); + IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName); + IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName); + IScatterReceiver<int> receiver3 = commGroups[3].GetScatterReceiver<int>(operatorName); + IScatterReceiver<int> receiver4 = commGroups[4].GetScatterReceiver<int>(operatorName); - List<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); + Assert.IsNotNull(sender); + Assert.IsNotNull(receiver1); + Assert.IsNotNull(receiver2); + Assert.IsNotNull(receiver3); + Assert.IsNotNull(receiver4); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration mpiTaskConfig = 
mpiDriver.GetMpiTaskConfiguration(taskId); - IConfiguration mergedConf = Configurations.Merge(mpiTaskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); + List<int> data = new List<int> { 1, 2, 3, 4 }; - IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); - commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); - } + sender.Send(data); + Assert.AreEqual(1, receiver1.Receive().Single()); + Assert.AreEqual(2, receiver2.Receive().Single()); + Assert.AreEqual(3, receiver3.Receive().Single()); + Assert.AreEqual(4, receiver4.Receive().Single()); + } + + [TestMethod] + public void TestScatterOperatorWithDefaultCodec() + { + string groupName = "group1"; + string operatorName = "scatter"; + string masterTaskId = "task0"; + string driverId = "Driver Id"; + int numTasks = 5; + int fanOut = 2; + + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); + + var commGroup = mpiDriver.DefaultGroup + .AddScatter(operatorName, masterTaskId) + .Build(); + + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName); IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName); @@ -611,40 +485,13 @@ namespace Org.Apache.REEF.Tests.Network int numTasks = 5; int fanOut = 2; - IMpiDriver mpiDriver = new MpiDriver(driverId, masterTaskId, fanOut, new AvroConfigurationSerializer()); + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - var commGroup = mpiDriver.NewCommunicationGroup(groupName, numTasks) - .AddScatter(operatorName, new ScatterOperatorSpec<int>(masterTaskId, new IntCodec())) + var commGroup = mpiDriver.DefaultGroup + .AddScatter(operatorName, masterTaskId, new IntCodec()) .Build(); - List<IConfiguration> partialConfigs = new List<IConfiguration>(); - for (int i = 
0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( - TaskConfiguration.ConfigurationModule - .Set(TaskConfiguration.Identifier, taskId) - .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) - .Build()) - .Build(); - commGroup.AddTask(taskId); - partialConfigs.Add(partialTaskConfig); - } - - IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); - - List<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); - - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration mpiTaskConfig = mpiDriver.GetMpiTaskConfiguration(taskId); - IConfiguration mergedConf = Configurations.Merge(mpiTaskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); - - IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); - commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); - } + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName); IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName); @@ -688,40 +535,13 @@ namespace Org.Apache.REEF.Tests.Network int numTasks = 4; int fanOut = 2; - IMpiDriver mpiDriver = new MpiDriver(driverId, masterTaskId, fanOut, new AvroConfigurationSerializer()); + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - var commGroup = mpiDriver.NewCommunicationGroup(groupName, numTasks) - .AddScatter(operatorName, new ScatterOperatorSpec<int>(masterTaskId, new IntCodec())) + var commGroup = mpiDriver.DefaultGroup + .AddScatter(operatorName, masterTaskId, new IntCodec()) .Build(); - List<IConfiguration> partialConfigs = new List<IConfiguration>(); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration 
partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( - TaskConfiguration.ConfigurationModule - .Set(TaskConfiguration.Identifier, taskId) - .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) - .Build()) - .Build(); - commGroup.AddTask(taskId); - partialConfigs.Add(partialTaskConfig); - } - - IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); - - List<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); - - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration mpiTaskConfig = mpiDriver.GetMpiTaskConfiguration(taskId); - IConfiguration mergedConf = Configurations.Merge(mpiTaskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); - - IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); - commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); - } + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName); IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName); @@ -762,41 +582,13 @@ namespace Org.Apache.REEF.Tests.Network int numTasks = 4; int fanOut = 2; - IMpiDriver mpiDriver = new MpiDriver(driverId, masterTaskId, fanOut, new AvroConfigurationSerializer()); + IMpiDriver mpiDriver = GetInstanceOfMpiDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - var commGroup = mpiDriver.NewCommunicationGroup(groupName, numTasks) - .AddScatter(operatorName, new ScatterOperatorSpec<int>(masterTaskId, new IntCodec())) + var commGroup = mpiDriver.DefaultGroup + .AddScatter(operatorName, masterTaskId, new IntCodec()) .Build(); - List<IConfiguration> partialConfigs = new List<IConfiguration>(); - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( - 
TaskConfiguration.ConfigurationModule - .Set(TaskConfiguration.Identifier, taskId) - .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) - .Build()) - .Build(); - commGroup.AddTask(taskId); - partialConfigs.Add(partialTaskConfig); - } - - IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); - - List<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); - - for (int i = 0; i < numTasks; i++) - { - string taskId = "task" + i; - IConfiguration mpiTaskConfig = mpiDriver.GetMpiTaskConfiguration(taskId); - IConfiguration mergedConf = Configurations.Merge(mpiTaskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); - - IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); - commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); - } - + List<ICommunicationGroupClient> commGroups = CommGroupClients(groupName, numTasks, mpiDriver, commGroup); IScatterSender<int> sender = commGroups[0].GetScatterSender<int>(operatorName); IScatterReceiver<int> receiver1 = commGroups[1].GetScatterReceiver<int>(operatorName); IScatterReceiver<int> receiver2 = commGroups[2].GetScatterReceiver<int>(operatorName); @@ -853,7 +645,54 @@ namespace Org.Apache.REEF.Tests.Network Assert.AreEqual(10, reduceFunction.Reduce(new int[] { 1, 2, 3, 4 })); } - private NetworkService<GroupCommunicationMessage> BuildNetworkService( + public static IMpiDriver GetInstanceOfMpiDriver(string driverId, string masterTaskId, string groupName, int fanOut, int numTasks) + { + var c = TangFactory.GetTang().NewConfigurationBuilder() + .BindStringNamedParam<MpiConfigurationOptions.DriverId>(driverId) + .BindStringNamedParam<MpiConfigurationOptions.MasterTaskId>(masterTaskId) + .BindStringNamedParam<MpiConfigurationOptions.GroupName>(groupName) + .BindIntNamedParam<MpiConfigurationOptions.FanOut>(fanOut.ToString()) + .BindIntNamedParam<MpiConfigurationOptions.NumberOfTasks>(numTasks.ToString()) + 
.BindImplementation(GenericType<IConfigurationSerializer>.Class, GenericType<AvroConfigurationSerializer>.Class) + .Build(); + + IMpiDriver mpiDriver = TangFactory.GetTang().NewInjector(c).GetInstance<MpiDriver>(); + return mpiDriver; + } + + public static List<ICommunicationGroupClient> CommGroupClients(string groupName, int numTasks, IMpiDriver mpiDriver, ICommunicationGroupDriver commGroup) + { + List<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); + IConfiguration serviceConfig = mpiDriver.GetServiceConfiguration(); + + List<IConfiguration> partialConfigs = new List<IConfiguration>(); + for (int i = 0; i < numTasks; i++) + { + string taskId = "task" + i; + IConfiguration partialTaskConfig = TangFactory.GetTang().NewConfigurationBuilder( + TaskConfiguration.ConfigurationModule + .Set(TaskConfiguration.Identifier, taskId) + .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) + .Build()) + .Build(); + commGroup.AddTask(taskId); + partialConfigs.Add(partialTaskConfig); + } + + for (int i = 0; i < numTasks; i++) + { + string taskId = "task" + i; + IConfiguration mpiTaskConfig = mpiDriver.GetMpiTaskConfiguration(taskId); + IConfiguration mergedConf = Configurations.Merge(mpiTaskConfig, partialConfigs[i], serviceConfig); + IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); + + IMpiClient mpiClient = injector.GetInstance<IMpiClient>(); + commGroups.Add(mpiClient.GetCommunicationGroup(groupName)); + } + return commGroups; + } + + public static NetworkService<GroupCommunicationMessage> BuildNetworkService( IPEndPoint nameServerEndpoint, IObserver<NsMessage<GroupCommunicationMessage>> handler) { return new NetworkService<GroupCommunicationMessage>( @@ -861,36 +700,48 @@ namespace Org.Apache.REEF.Tests.Network handler, new StringIdentifierFactory(), new GroupCommunicationMessageCodec()); } - private GroupCommunicationMessage CreateGcm(string message, string from, string to) + private GroupCommunicationMessage 
CreateGcm(string message, string from, string to) { byte[] data = Encoding.UTF8.GetBytes(message); return new GroupCommunicationMessage("g1", "op1", from, to, data, MessageType.Data); } - private class SumFunction : IReduceFunction<int> + private static void ScatterReceiveReduce(IScatterReceiver<int> receiver, IReduceSender<int> sumSender) { - [Inject] - public SumFunction() - { - } + List<int> data1 = receiver.Receive(); + int sum1 = data1.Sum(); + sumSender.Send(sum1); + } - public int Reduce(IEnumerable<int> elements) - { - return elements.Sum(); - } + public static int TriangleNumber(int n) + { + return Enumerable.Range(1, n).Sum(); } + } - private class MyTask : ITask + public class SumFunction : IReduceFunction<int> + { + [Inject] + public SumFunction() { - public void Dispose() - { - throw new NotImplementedException(); - } + } - public byte[] Call(byte[] memento) - { - throw new NotImplementedException(); - } + public int Reduce(IEnumerable<int> elements) + { + return elements.Sum(); + } + } + + public class MyTask : ITask + { + public void Dispose() + { + throw new NotImplementedException(); + } + + public byte[] Call(byte[] memento) + { + throw new NotImplementedException(); } } } \ No newline at end of file
