Repository: incubator-reef Updated Branches: refs/heads/master b9686cd79 -> 75f25a267
[REEF-329] Improve Injection in Communication Group This PR improves injection in Communication Group: * Use Fork Injector in creating CommunicationGroupClient * Use Fork Injector in creating GroupCommOperators * Use configuration data bound at driver side to inject GroupCommOperators at task side instead of letting clients decide which operator instance to inject * Resolve a race condition caused by delayed message handler registration JIRA: [REEF-329](https://issues.apache.org/jira/browse/REEF-329) Pull Request: This closes #187 Project: http://git-wip-us.apache.org/repos/asf/incubator-reef/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-reef/commit/75f25a26 Tree: http://git-wip-us.apache.org/repos/asf/incubator-reef/tree/75f25a26 Diff: http://git-wip-us.apache.org/repos/asf/incubator-reef/diff/75f25a26 Branch: refs/heads/master Commit: 75f25a267ae9e597944e79c79e179e5b544d980b Parents: b9686cd Author: Julia Wang <[email protected]> Authored: Fri May 8 18:37:27 2015 -0700 Committer: Markus Weimer <[email protected]> Committed: Thu May 21 17:57:40 2015 -0700 ---------------------------------------------------------------------- .../GroupCommunicationTests.cs | 204 ++++++++++++++++++- .../Org.Apache.REEF.Network.Tests.csproj | 14 +- .../Config/GroupCommConfigurationOptions.cs | 10 + .../Driver/Impl/CommunicationGroupDriver.cs | 3 - .../Group/Operators/Impl/BroadcastReceiver.cs | 26 +-- .../Group/Operators/Impl/BroadcastSender.cs | 17 +- .../Group/Operators/Impl/ReduceReceiver.cs | 18 +- .../Group/Operators/Impl/ReduceSender.cs | 13 +- .../Group/Operators/Impl/ScatterReceiver.cs | 17 +- .../Group/Operators/Impl/ScatterSender.cs | 18 +- .../Group/Task/Impl/CommunicationGroupClient.cs | 78 ++----- .../Impl/CommunicationGroupNetworkObserver.cs | 35 +--- .../Group/Task/Impl/GroupCommClient.cs | 27 +-- .../Group/Topology/FlatTopology.cs | 17 ++ .../Group/Topology/TreeTopology.cs | 18 +- .../Injection/TestInjection.cs | 82 ++++++++ 16 files changed, 
437 insertions(+), 160 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTests.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTests.cs b/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTests.cs index 7a6b5c1..53b8cb5 100644 --- a/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTests.cs +++ b/lang/cs/Org.Apache.REEF.Network.Tests/GroupCommunication/GroupCommunicationTests.cs @@ -20,6 +20,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Globalization; using System.Linq; using System.Net; using System.Reactive; @@ -29,12 +30,14 @@ using Org.Apache.REEF.Common.Io; using Org.Apache.REEF.Common.Tasks; using Org.Apache.REEF.Examples.MachineLearning.KMeans; using Org.Apache.REEF.Examples.MachineLearning.KMeans.codecs; +using Org.Apache.REEF.Network.Examples.GroupCommunication; using Org.Apache.REEF.Network.Group.Codec; using Org.Apache.REEF.Network.Group.Config; using Org.Apache.REEF.Network.Group.Driver; using Org.Apache.REEF.Network.Group.Driver.Impl; using Org.Apache.REEF.Network.Group.Operators; using Org.Apache.REEF.Network.Group.Operators.Impl; +using Org.Apache.REEF.Network.Group.Pipelining; using Org.Apache.REEF.Network.Group.Pipelining.Impl; using Org.Apache.REEF.Network.Group.Task; using Org.Apache.REEF.Network.Group.Topology; @@ -107,7 +110,6 @@ namespace Org.Apache.REEF.Network.Tests.GroupCommunication int fanOut = 2; var groupCommunicationDriver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks); - ICommunicationGroupDriver commGroup = groupCommunicationDriver.DefaultGroup .AddBroadcast<int>( broadcastOperatorName, @@ -155,7 +157,76 
@@ namespace Org.Apache.REEF.Network.Tests.GroupCommunication Assert.AreEqual(sum, expected); } } - + + /// <summary> + /// This is to test operator injection in CommunicationGroupClient with int[] as message type + /// </summary> + [TestMethod] + public void TestGetBroadcastReduceOperatorsForIntArrayMessageType() + { + const string groupName = "group1"; + const string broadcastOperatorName = "broadcast"; + const string reduceOperatorName = "reduce"; + const string masterTaskId = "task0"; + const string driverId = "Driver Id"; + const int numTasks = 3; + const int fanOut = 2; + + IConfiguration codecConfig = CodecConfiguration<int[]>.Conf + .Set(CodecConfiguration<int[]>.Codec, GenericType<IntArrayCodec>.Class) + .Build(); + + IConfiguration reduceFunctionConfig = ReduceFunctionConfiguration<int[]>.Conf + .Set(ReduceFunctionConfiguration<int[]>.ReduceFunction, GenericType<ArraySumFunction>.Class) + .Build(); + + IConfiguration dataConverterConfig = TangFactory.GetTang().NewConfigurationBuilder( + PipelineDataConverterConfiguration<int[]>.Conf + .Set(PipelineDataConverterConfiguration<int[]>.DataConverter, + GenericType<PipelineIntDataConverter>.Class) + .Build()) + .BindNamedParameter<GroupTestConfig.ChunkSize, int>( + GenericType<GroupTestConfig.ChunkSize>.Class, + GroupTestConstants.ChunkSize.ToString(CultureInfo.InvariantCulture)) + .Build(); + + var groupCommunicationDriver = GetInstanceOfGroupCommDriver(driverId, masterTaskId, groupName, fanOut, numTasks); + ICommunicationGroupDriver commGroup = groupCommunicationDriver.DefaultGroup + .AddBroadcast<int[]>( + broadcastOperatorName, + masterTaskId, + TopologyTypes.Flat, + codecConfig, + dataConverterConfig) + .AddReduce<int[]>( + reduceOperatorName, + masterTaskId, + TopologyTypes.Flat, + codecConfig, + dataConverterConfig, + reduceFunctionConfig) + .Build(); + + var commGroups = CommGroupClients(groupName, numTasks, groupCommunicationDriver, commGroup); + + //for master task + IBroadcastSender<int[]> 
broadcastSender = commGroups[0].GetBroadcastSender<int[]>(broadcastOperatorName); + IReduceReceiver<int[]> sumReducer = commGroups[0].GetReduceReceiver<int[]>(reduceOperatorName); + + IBroadcastReceiver<int[]> broadcastReceiver1 = commGroups[1].GetBroadcastReceiver<int[]>(broadcastOperatorName); + IReduceSender<int[]> triangleNumberSender1 = commGroups[1].GetReduceSender<int[]>(reduceOperatorName); + + IBroadcastReceiver<int[]> broadcastReceiver2 = commGroups[2].GetBroadcastReceiver<int[]>(broadcastOperatorName); + IReduceSender<int[]> triangleNumberSender2 = commGroups[2].GetReduceSender<int[]>(reduceOperatorName); + + Assert.IsNotNull(broadcastSender); + Assert.IsNotNull(sumReducer); + Assert.IsNotNull(broadcastReceiver1); + Assert.IsNotNull(triangleNumberSender1); + Assert.IsNotNull(broadcastReceiver2); + Assert.IsNotNull(triangleNumberSender2); + } + [TestMethod] public void TestScatterReduceOperators() { @@ -710,7 +781,7 @@ namespace Org.Apache.REEF.Network.Tests.GroupCommunication return groupCommDriver; } - public static List<ICommunicationGroupClient> CommGroupClients(string groupName, int numTasks, IGroupCommDriver groupCommDriver, ICommunicationGroupDriver commGroup) + public static List<ICommunicationGroupClient> CommGroupClients(string groupName, int numTasks, IGroupCommDriver groupCommDriver, ICommunicationGroupDriver commGroupDriver) { List<ICommunicationGroupClient> commGroups = new List<ICommunicationGroupClient>(); IConfiguration serviceConfig = groupCommDriver.GetServiceConfiguration(); @@ -725,17 +796,24 @@ namespace Org.Apache.REEF.Network.Tests.GroupCommunication .Set(TaskConfiguration.Task, GenericType<MyTask>.Class) .Build()) .Build(); - commGroup.AddTask(taskId); + commGroupDriver.AddTask(taskId); partialConfigs.Add(partialTaskConfig); } for (int i = 0; i < numTasks; i++) { + //get task configuration at driver side string taskId = "task" + i; IConfiguration groupCommTaskConfig = groupCommDriver.GetGroupCommTaskConfiguration(taskId); 
IConfiguration mergedConf = Configurations.Merge(groupCommTaskConfig, partialConfigs[i], serviceConfig); - IInjector injector = TangFactory.GetTang().NewInjector(mergedConf); + var conf = TangFactory.GetTang() + .NewConfigurationBuilder(mergedConf) + .BindNamedParameter(typeof(GroupCommConfigurationOptions.Initialize), "false") + .Build(); + IInjector injector = TangFactory.GetTang().NewInjector(conf); + + //simulate injection at evaluator side IGroupCommClient groupCommClient = injector.GetInstance<IGroupCommClient>(); commGroups.Add(groupCommClient.GetCommunicationGroup(groupName)); } @@ -821,4 +899,120 @@ namespace Org.Apache.REEF.Network.Tests.GroupCommunication throw new NotImplementedException(); } } + + class ArraySumFunction : IReduceFunction<int[]> + { + [Inject] + private ArraySumFunction() + { + } + + public int[] Reduce(IEnumerable<int[]> elements) + { + int[] result = null; + int count = 0; + + foreach (var element in elements) + { + if (count == 0) + { + result = element.Clone() as int[]; + } + else + { + if (element.Length != result.Length) + { + throw new Exception("integer arrays are of different sizes"); + } + + for (int i = 0; i < result.Length; i++) + { + result[i] += element[i]; + } + } + count++; + } + + return result; + } + } + + class IntArrayCodec : ICodec<int[]> + { + [Inject] + private IntArrayCodec() + { + } + + public byte[] Encode(int[] obj) + { + byte[] result = new byte[sizeof(Int32) * obj.Length]; + Buffer.BlockCopy(obj, 0, result, 0, result.Length); + return result; + } + + public int[] Decode(byte[] data) + { + if (data.Length % sizeof(Int32) != 0) + { + throw new Exception("error inside integer array decoder, byte array length not a multiple of interger size"); + } + + int[] result = new int[data.Length / sizeof(Int32)]; + Buffer.BlockCopy(data, 0, result, 0, data.Length); + return result; + } + } + + class PipelineIntDataConverter : IPipelineDataConverter<int[]> + { + readonly int _chunkSize; + + [Inject] + private 
PipelineIntDataConverter([Parameter(typeof(GroupTestConfig.ChunkSize))] int chunkSize) + { + _chunkSize = chunkSize; + } + + public List<PipelineMessage<int[]>> PipelineMessage(int[] message) + { + List<PipelineMessage<int[]>> messageList = new List<PipelineMessage<int[]>>(); + int totalChunks = message.Length / _chunkSize; + + if (message.Length % _chunkSize != 0) + { + totalChunks++; + } + + int counter = 0; + for (int i = 0; i < message.Length; i += _chunkSize) + { + int[] data = new int[Math.Min(_chunkSize, message.Length - i)]; + Buffer.BlockCopy(message, i * sizeof(int), data, 0, data.Length * sizeof(int)); + + messageList.Add(counter == totalChunks - 1 + ? new PipelineMessage<int[]>(data, true) + : new PipelineMessage<int[]>(data, false)); + + counter++; + } + + return messageList; + } + + public int[] FullMessage(List<PipelineMessage<int[]>> pipelineMessage) + { + int size = pipelineMessage.Select(x => x.Data.Length).Sum(); + int[] data = new int[size]; + int offset = 0; + + foreach (var message in pipelineMessage) + { + Buffer.BlockCopy(message.Data, 0, data, offset, message.Data.Length * sizeof(int)); + offset += message.Data.Length * sizeof(int); + } + + return data; + } + } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network.Tests/Org.Apache.REEF.Network.Tests.csproj ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network.Tests/Org.Apache.REEF.Network.Tests.csproj b/lang/cs/Org.Apache.REEF.Network.Tests/Org.Apache.REEF.Network.Tests.csproj index cb2019b..eebfdaf 100644 --- a/lang/cs/Org.Apache.REEF.Network.Tests/Org.Apache.REEF.Network.Tests.csproj +++ b/lang/cs/Org.Apache.REEF.Network.Tests/Org.Apache.REEF.Network.Tests.csproj @@ -59,23 +59,27 @@ under the License. 
<Compile Include="Properties\AssemblyInfo.cs" /> </ItemGroup> <ItemGroup> - <ProjectReference Include="..\Org.Apache.REEF.Common\Org.Apache.REEF.Common.csproj"> + <ProjectReference Include="$(SolutionDir)\Org.Apache.REEF.Common\Org.Apache.REEF.Common.csproj"> <Project>{545a0582-4105-44ce-b99c-b1379514a630}</Project> <Name>Org.Apache.REEF.Common</Name> </ProjectReference> - <ProjectReference Include="..\Org.Apache.REEF.Examples\Org.Apache.REEF.Examples.csproj"> + <ProjectReference Include="$(SolutionDir)\Org.Apache.REEF.Examples\Org.Apache.REEF.Examples.csproj"> <Project>{75503f90-7b82-4762-9997-94b5c68f15db}</Project> <Name>Org.Apache.REEF.Examples</Name> </ProjectReference> - <ProjectReference Include="..\Org.Apache.REEF.Network\Org.Apache.REEF.Network.csproj"> + <ProjectReference Include="$(SolutionDir)\Org.Apache.REEF.Network.Examples\Org.Apache.REEF.Network.Examples.csproj"> + <Project>{b1b43b60-ddd0-4805-a9b4-ba84a0ccb7c7}</Project> + <Name>Org.Apache.REEF.Network.Examples</Name> + </ProjectReference> + <ProjectReference Include="$(SolutionDir)\Org.Apache.REEF.Network\Org.Apache.REEF.Network.csproj"> <Project>{883ce800-6a6a-4e0a-b7fe-c054f4f2c1dc}</Project> <Name>Org.Apache.REEF.Network</Name> </ProjectReference> - <ProjectReference Include="..\Org.Apache.REEF.Tang\Org.Apache.REEF.Tang.csproj"> + <ProjectReference Include="$(SolutionDir)\Org.Apache.REEF.Tang\Org.Apache.REEF.Tang.csproj"> <Project>{97dbb573-3994-417a-9f69-ffa25f00d2a6}</Project> <Name>Org.Apache.REEF.Tang</Name> </ProjectReference> - <ProjectReference Include="..\Org.Apache.REEF.Wake\Org.Apache.REEF.Wake.csproj"> + <ProjectReference Include="$(SolutionDir)\Org.Apache.REEF.Wake\Org.Apache.REEF.Wake.csproj"> <Project>{cdfb3464-4041-42b1-9271-83af24cd5008}</Project> <Name>Org.Apache.REEF.Wake</Name> </ProjectReference> http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Config/GroupCommConfigurationOptions.cs 
---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Config/GroupCommConfigurationOptions.cs b/lang/cs/Org.Apache.REEF.Network/Group/Config/GroupCommConfigurationOptions.cs index f86c546..50c0dd6 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Config/GroupCommConfigurationOptions.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Config/GroupCommConfigurationOptions.cs @@ -98,5 +98,15 @@ namespace Org.Apache.REEF.Network.Group.Config public class TopologyChildTaskIds : Name<ISet<string>> { } + + [NamedParameter("Type of the message")] + public class MessageType : Name<string> + { + } + + [NamedParameter("Wether or not to call topology initialize", defaultValue: "true")] + public class Initialize : Name<bool> + { + } } } http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/CommunicationGroupDriver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/CommunicationGroupDriver.cs b/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/CommunicationGroupDriver.cs index 89238f8..07581ba 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/CommunicationGroupDriver.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Driver/Impl/CommunicationGroupDriver.cs @@ -316,9 +316,6 @@ namespace Org.Apache.REEF.Network.Group.Driver.Impl { var innerConf = TangFactory.GetTang().NewConfigurationBuilder(GetOperatorConfiguration(operatorName, taskId)) - .BindNamedParameter<GroupCommConfigurationOptions.DriverId, string>( - GenericType<GroupCommConfigurationOptions.DriverId>.Class, - _driverId) .BindNamedParameter<GroupCommConfigurationOptions.OperatorName, string>( GenericType<GroupCommConfigurationOptions.OperatorName>.Class, operatorName) 
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/BroadcastReceiver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/BroadcastReceiver.cs b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/BroadcastReceiver.cs index 75ab88e..65ed6b9 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/BroadcastReceiver.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/BroadcastReceiver.cs @@ -36,23 +36,24 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl public class BroadcastReceiver<T> : IBroadcastReceiver<T> { private const int PipelineVersion = 2; - private readonly ICommunicationGroupNetworkObserver _networkHandler; private readonly OperatorTopology<PipelineMessage<T>> _topology; private static readonly Logger Logger = Logger.GetLogger(typeof(BroadcastReceiver<T>)); + /// <summary> /// Creates a new BroadcastReceiver. /// </summary> /// <param name="operatorName">The operator identifier</param> - /// <param name="groupName">The name of the CommunicationGroup that the - /// operator belongs to</param> + /// <param name="groupName">The name of the CommunicationGroup that the operator belongs to</param> + /// <param name="initialize">Require Topology Initialize to be called to wait for all task being registered. + /// Default is true. 
For unit testing, it can be set to false.</param> /// <param name="topology">The node's topology graph</param> /// <param name="networkHandler">The incoming message handler</param> - /// <param name="dataConverter">The converter used to convert original - /// message to pipelined ones and vice versa.</param> + /// <param name="dataConverter">The converter used to convert original message to pipelined ones and vice versa.</param> [Inject] - public BroadcastReceiver( + private BroadcastReceiver( [Parameter(typeof(GroupCommConfigurationOptions.OperatorName))] string operatorName, [Parameter(typeof(GroupCommConfigurationOptions.CommunicationGroupName))] string groupName, + [Parameter(typeof(GroupCommConfigurationOptions.Initialize))] bool initialize, OperatorTopology<PipelineMessage<T>> topology, ICommunicationGroupNetworkObserver networkHandler, IPipelineDataConverter<T> dataConverter) @@ -60,15 +61,16 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl OperatorName = operatorName; GroupName = groupName; Version = PipelineVersion; - - _networkHandler = networkHandler; + PipelineDataConverter = dataConverter; _topology = topology; - _topology.Initialize(); var msgHandler = Observer.Create<GroupCommunicationMessage>(message => _topology.OnNext(message)); - _networkHandler.Register(operatorName, msgHandler); + networkHandler.Register(operatorName, msgHandler); - PipelineDataConverter = dataConverter; + if (initialize) + { + topology.Initialize(); + } } /// <summary> @@ -91,7 +93,6 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl /// </summary> public IPipelineDataConverter<T> PipelineDataConverter { get; private set; } - /// <summary> /// Receive a message from parent BroadcastSender. 
/// </summary> @@ -115,6 +116,5 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl return PipelineDataConverter.FullMessage(messageList); } - } } http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/BroadcastSender.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/BroadcastSender.cs b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/BroadcastSender.cs index 21701ea..34f9cd2 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/BroadcastSender.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/BroadcastSender.cs @@ -45,31 +45,36 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl /// <param name="operatorName">The identifier for the operator</param> /// <param name="groupName">The name of the CommunicationGroup that the operator /// belongs to</param> + /// <param name="initialize">Require Topology Initialize to be called to wait for all task being registered. + /// Default is true. 
For unit testing, it can be set to false.</param> /// <param name="topology">The node's topology graph</param> /// <param name="networkHandler">The incoming message handler</param> /// <param name="dataConverter">The converter used to convert original /// message to pipelined ones and vice versa.</param> [Inject] - public BroadcastSender( + private BroadcastSender( [Parameter(typeof(GroupCommConfigurationOptions.OperatorName))] string operatorName, [Parameter(typeof(GroupCommConfigurationOptions.CommunicationGroupName))] string groupName, + [Parameter(typeof(GroupCommConfigurationOptions.Initialize))] bool initialize, OperatorTopology<PipelineMessage<T>> topology, ICommunicationGroupNetworkObserver networkHandler, IPipelineDataConverter<T> dataConverter) { + _topology = topology; OperatorName = operatorName; GroupName = groupName; Version = PipelineVersion; - - _topology = topology; - _topology.Initialize(); + PipelineDataConverter = dataConverter; var msgHandler = Observer.Create<GroupCommunicationMessage>(message => _topology.OnNext(message)); networkHandler.Register(operatorName, msgHandler); - PipelineDataConverter = dataConverter; + if (initialize) + { + topology.Initialize(); + } } - + /// <summary> /// Returns the identifier for the Group Communication operator. 
/// </summary> http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceReceiver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceReceiver.cs b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceReceiver.cs index 2242368..0c2fd94 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceReceiver.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceReceiver.cs @@ -35,7 +35,7 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl /// <typeparam name="T">The message type</typeparam> public class ReduceReceiver<T> : IReduceReceiver<T> { - private static readonly Logger Logger = Logger.GetLogger(typeof (ReduceReceiver<T>)); + private static readonly Logger Logger = Logger.GetLogger(typeof(ReduceReceiver<T>)); private const int PipelineVersion = 2; private readonly OperatorTopology<PipelineMessage<T>> _topology; private readonly PipelinedReduceFunction<T> _pipelinedReduceFunc; @@ -45,15 +45,18 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl /// </summary> /// <param name="operatorName">The name of the reduce operator</param> /// <param name="groupName">The name of the operator's CommunicationGroup</param> + /// <param name="initialize">Require Topology Initialize to be called to wait for all task being registered. + /// Default is true. 
For unit testing, it can be set to false.</param> /// <param name="topology">The task's operator topology graph</param> /// <param name="networkHandler">Handles incoming messages from other tasks</param> /// <param name="reduceFunction">The class used to aggregate all incoming messages</param> /// <param name="dataConverter">The converter used to convert original /// message to pipelined ones and vice versa.</param> [Inject] - public ReduceReceiver( - [Parameter(typeof (GroupCommConfigurationOptions.OperatorName))] string operatorName, - [Parameter(typeof (GroupCommConfigurationOptions.CommunicationGroupName))] string groupName, + private ReduceReceiver( + [Parameter(typeof(GroupCommConfigurationOptions.OperatorName))] string operatorName, + [Parameter(typeof(GroupCommConfigurationOptions.CommunicationGroupName))] string groupName, + [Parameter(typeof(GroupCommConfigurationOptions.Initialize))] bool initialize, OperatorTopology<PipelineMessage<T>> topology, ICommunicationGroupNetworkObserver networkHandler, IReduceFunction<T> reduceFunction, @@ -63,15 +66,18 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl GroupName = groupName; Version = PipelineVersion; ReduceFunction = reduceFunction; + PipelineDataConverter = dataConverter; _pipelinedReduceFunc = new PipelinedReduceFunction<T>(ReduceFunction); _topology = topology; - _topology.Initialize(); var msgHandler = Observer.Create<GroupCommunicationMessage>(message => _topology.OnNext(message)); networkHandler.Register(operatorName, msgHandler); - PipelineDataConverter = dataConverter; + if (initialize) + { + topology.Initialize(); + } } /// <summary> http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceSender.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceSender.cs b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceSender.cs 
index d61657f..cd2049b 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceSender.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ReduceSender.cs @@ -36,7 +36,7 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl /// <typeparam name="T">The message type</typeparam> public class ReduceSender<T> : IReduceSender<T> { - private static readonly Logger Logger = Logger.GetLogger(typeof (ReduceSender<T>)); + private static readonly Logger Logger = Logger.GetLogger(typeof(ReduceSender<T>)); private const int PipelineVersion = 2; private readonly OperatorTopology<PipelineMessage<T>> _topology; private readonly PipelinedReduceFunction<T> _pipelinedReduceFunc; @@ -46,15 +46,18 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl /// </summary> /// <param name="operatorName">The name of the reduce operator</param> /// <param name="groupName">The name of the reduce operator's CommunicationGroup</param> + /// <param name="initialize">Require Topology Initialize to be called to wait for all task being registered. + /// Default is true. 
For unit testing, it can be set to false.</param> /// <param name="topology">The Task's operator topology graph</param> /// <param name="networkHandler">The handler used to handle incoming messages</param> /// <param name="reduceFunction">The function used to reduce the incoming messages</param> /// <param name="dataConverter">The converter used to convert original /// message to pipelined ones and vice versa.</param> [Inject] - public ReduceSender( + private ReduceSender( [Parameter(typeof(GroupCommConfigurationOptions.OperatorName))] string operatorName, [Parameter(typeof(GroupCommConfigurationOptions.CommunicationGroupName))] string groupName, + [Parameter(typeof(GroupCommConfigurationOptions.Initialize))] bool initialize, OperatorTopology<PipelineMessage<T>> topology, ICommunicationGroupNetworkObserver networkHandler, IReduceFunction<T> reduceFunction, @@ -68,12 +71,16 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl _pipelinedReduceFunc = new PipelinedReduceFunction<T>(ReduceFunction); _topology = topology; - _topology.Initialize(); var msgHandler = Observer.Create<GroupCommunicationMessage>(message => _topology.OnNext(message)); networkHandler.Register(operatorName, msgHandler); PipelineDataConverter = dataConverter; + + if (initialize) + { + topology.Initialize(); + } } /// <summary> http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ScatterReceiver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ScatterReceiver.cs b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ScatterReceiver.cs index b40ff68..13635bb 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ScatterReceiver.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ScatterReceiver.cs @@ -35,8 +35,6 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl public class ScatterReceiver<T> : 
IScatterReceiver<T> { private const int DefaultVersion = 1; - - private readonly ICommunicationGroupNetworkObserver _networkHandler; private readonly OperatorTopology<T> _topology; /// <summary> @@ -44,25 +42,30 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl /// </summary> /// <param name="operatorName">The name of the scatter operator</param> /// <param name="groupName">The name of the operator's CommunicationGroup</param> + /// <param name="initialize">Require Topology Initialize to be called to wait for all task being registered. + /// Default is true. For unit testing, it can be set to false.</param> /// <param name="topology">The task's operator topology graph</param> /// <param name="networkHandler">Handles incoming messages from other tasks</param> [Inject] - public ScatterReceiver( + private ScatterReceiver( [Parameter(typeof(GroupCommConfigurationOptions.OperatorName))] string operatorName, [Parameter(typeof(GroupCommConfigurationOptions.CommunicationGroupName))] string groupName, + [Parameter(typeof(GroupCommConfigurationOptions.Initialize))] bool initialize, OperatorTopology<T> topology, ICommunicationGroupNetworkObserver networkHandler) { OperatorName = operatorName; GroupName = groupName; Version = DefaultVersion; - - _networkHandler = networkHandler; _topology = topology; - _topology.Initialize(); var msgHandler = Observer.Create<GroupCommunicationMessage>(message => _topology.OnNext(message)); - _networkHandler.Register(operatorName, msgHandler); + networkHandler.Register(operatorName, msgHandler); + + if (initialize) + { + topology.Initialize(); + } } /// <summary> http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ScatterSender.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ScatterSender.cs b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ScatterSender.cs index 
2c664b8..47b6f6f 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ScatterSender.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Operators/Impl/ScatterSender.cs @@ -17,6 +17,7 @@ * under the License. */ +using System; using System.Collections.Generic; using System.Reactive; using Org.Apache.REEF.Network.Group.Config; @@ -35,8 +36,6 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl public class ScatterSender<T> : IScatterSender<T> { private const int DefaultVersion = 1; - - private readonly ICommunicationGroupNetworkObserver _networkHandler; private readonly OperatorTopology<T> _topology; /// <summary> @@ -44,25 +43,30 @@ namespace Org.Apache.REEF.Network.Group.Operators.Impl /// </summary> /// <param name="operatorName">The name of the scatter operator</param> /// <param name="groupName">The name of the operator's Communication Group</param> + /// <param name="initialize">Require Topology Initialize to be called to wait for all task being registered. + /// Default is true. 
For unit testing, it can be set to false.</param> /// <param name="topology">The operator topology</param> /// <param name="networkHandler">The network handler</param> [Inject] - public ScatterSender( + private ScatterSender( [Parameter(typeof(GroupCommConfigurationOptions.OperatorName))] string operatorName, [Parameter(typeof(GroupCommConfigurationOptions.CommunicationGroupName))] string groupName, + [Parameter(typeof(GroupCommConfigurationOptions.Initialize))] bool initialize, OperatorTopology<T> topology, ICommunicationGroupNetworkObserver networkHandler) { OperatorName = operatorName; GroupName = groupName; Version = DefaultVersion; - - _networkHandler = networkHandler; _topology = topology; - _topology.Initialize(); var msgHandler = Observer.Create<GroupCommunicationMessage>(message => _topology.OnNext(message)); - _networkHandler.Register(operatorName, msgHandler); + networkHandler.Register(operatorName, msgHandler); + + if (initialize) + { + topology.Initialize(); + } } public string OperatorName { get; private set; } http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/CommunicationGroupClient.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/CommunicationGroupClient.cs b/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/CommunicationGroupClient.cs index 3fcdc2f..1155048 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/CommunicationGroupClient.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/CommunicationGroupClient.cs @@ -19,18 +19,14 @@ using System; using System.Collections.Generic; -using Org.Apache.REEF.Common.Tasks; using Org.Apache.REEF.Network.Group.Config; -using Org.Apache.REEF.Network.Group.Driver.Impl; using Org.Apache.REEF.Network.Group.Operators; using Org.Apache.REEF.Network.Group.Operators.Impl; -using Org.Apache.REEF.Network.NetworkService; using 
Org.Apache.REEF.Tang.Annotations; -using Org.Apache.REEF.Tang.Exceptions; using Org.Apache.REEF.Tang.Formats; -using Org.Apache.REEF.Tang.Implementations.Tang; using Org.Apache.REEF.Tang.Interface; using Org.Apache.REEF.Tang.Util; +using Org.Apache.REEF.Utilities.Diagnostics; using Org.Apache.REEF.Utilities.Logging; namespace Org.Apache.REEF.Network.Group.Task.Impl @@ -41,61 +37,46 @@ namespace Org.Apache.REEF.Network.Group.Task.Impl public class CommunicationGroupClient : ICommunicationGroupClient { private readonly Logger LOGGER = Logger.GetLogger(typeof(CommunicationGroupClient)); - - private readonly string _taskId; - private string _driverId; - - private readonly Dictionary<string, IInjector> _operatorInjectors; private readonly Dictionary<string, object> _operators; - private readonly NetworkService<GroupCommunicationMessage> _networkService; - private readonly IGroupCommNetworkObserver _groupCommNetworkHandler; - private readonly ICommunicationGroupNetworkObserver _commGroupNetworkHandler; /// <summary> /// Creates a new CommunicationGroupClient. 
/// </summary> - /// <param name="taskId">The identifier for this Task.</param> /// <param name="groupName">The name of the CommunicationGroup</param> - /// <param name="driverId">The identifier for the driver</param> /// <param name="operatorConfigs">The serialized operator configurations</param> /// <param name="groupCommNetworkObserver">The handler for all incoming messages /// across all Communication Groups</param> - /// <param name="networkService">The network service used to send messages.</param> /// <param name="configSerializer">Used to deserialize operator configuration.</param> + /// <param name="commGroupNetworkHandler">Communication group network observer that holds all the handlers for each operator.</param> + /// <param name="injector">injector forked from the injector that creates this instance</param> [Inject] - public CommunicationGroupClient( - [Parameter(typeof(TaskConfigurationOptions.Identifier))] string taskId, + private CommunicationGroupClient( [Parameter(typeof(GroupCommConfigurationOptions.CommunicationGroupName))] string groupName, - [Parameter(typeof(GroupCommConfigurationOptions.DriverId))] string driverId, [Parameter(typeof(GroupCommConfigurationOptions.SerializedOperatorConfigs))] ISet<string> operatorConfigs, IGroupCommNetworkObserver groupCommNetworkObserver, - NetworkService<GroupCommunicationMessage> networkService, AvroConfigurationSerializer configSerializer, - CommunicationGroupNetworkObserver commGroupNetworkHandler) + ICommunicationGroupNetworkObserver commGroupNetworkHandler, + IInjector injector) { - _taskId = taskId; - _driverId = driverId; - GroupName = groupName; - _operators = new Dictionary<string, object>(); - _operatorInjectors = new Dictionary<string, IInjector>(); - _networkService = networkService; - _groupCommNetworkHandler = groupCommNetworkObserver; - _commGroupNetworkHandler = commGroupNetworkHandler; - _groupCommNetworkHandler.Register(groupName, _commGroupNetworkHandler); + GroupName = groupName; + 
groupCommNetworkObserver.Register(groupName, commGroupNetworkHandler); - // Deserialize operator configuration and store each injector. - // When user requests the Group Communication Operator, use type information to - // create the instance. foreach (string operatorConfigStr in operatorConfigs) - { + { IConfiguration operatorConfig = configSerializer.FromString(operatorConfigStr); - IInjector injector = TangFactory.GetTang().NewInjector(operatorConfig); - string operatorName = injector.GetNamedInstance<GroupCommConfigurationOptions.OperatorName, string>( + IInjector operatorInjector = injector.ForkInjector(operatorConfig); + string operatorName = operatorInjector.GetNamedInstance<GroupCommConfigurationOptions.OperatorName, string>( GenericType<GroupCommConfigurationOptions.OperatorName>.Class); - _operatorInjectors[operatorName] = injector; + string msgType = operatorInjector.GetNamedInstance<GroupCommConfigurationOptions.MessageType, string>( + GenericType<GroupCommConfigurationOptions.MessageType>.Class); + + Type groupCommOperatorGenericInterface = typeof(IGroupCommOperator<>); + Type groupCommOperatorInterface = groupCommOperatorGenericInterface.MakeGenericType(Type.GetType(msgType)); + var operatorObj = operatorInjector.GetInstance(groupCommOperatorInterface); + _operators.Add(operatorName, operatorObj); } } @@ -185,32 +166,11 @@ namespace Org.Apache.REEF.Network.Group.Task.Impl { throw new ArgumentNullException("operatorName"); } - if (!_operatorInjectors.ContainsKey(operatorName)) - { - throw new ArgumentException("Invalid operator name, cannot create CommunicationGroupClient"); - } object op; if (!_operators.TryGetValue(operatorName, out op)) { - IInjector injector = _operatorInjectors[operatorName]; - - injector.BindVolatileParameter(GenericType<TaskConfigurationOptions.Identifier>.Class, _taskId); - injector.BindVolatileParameter(GenericType<GroupCommConfigurationOptions.CommunicationGroupName>.Class, GroupName); - 
injector.BindVolatileInstance(GenericType<ICommunicationGroupNetworkObserver>.Class, _commGroupNetworkHandler); - injector.BindVolatileInstance(GenericType<NetworkService<GroupCommunicationMessage>>.Class, _networkService); - injector.BindVolatileInstance(GenericType<ICommunicationGroupClient>.Class, this); - - try - { - op = injector.GetInstance<T>(); - _operators[operatorName] = op; - } - catch (InjectionException) - { - LOGGER.Log(Level.Error, "Cannot inject Group Communication operator: No known operator of type: {0}", typeof(T)); - throw; - } + Exceptions.Throw(new ArgumentException("Operator is not added at Driver side:" + operatorName), LOGGER); } return (T) op; http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/CommunicationGroupNetworkObserver.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/CommunicationGroupNetworkObserver.cs b/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/CommunicationGroupNetworkObserver.cs index 444c4a1..30dddcf 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/CommunicationGroupNetworkObserver.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/CommunicationGroupNetworkObserver.cs @@ -35,20 +35,14 @@ namespace Org.Apache.REEF.Network.Group.Task.Impl { private static readonly Logger LOGGER = Logger.GetLogger(typeof(CommunicationGroupNetworkObserver)); private readonly Dictionary<string, IObserver<GroupCommunicationMessage>> _handlers; - private readonly int _retryCount; - private readonly int _sleepTime; /// <summary> /// Creates a new CommunicationGroupNetworkObserver. 
/// </summary> [Inject] - public CommunicationGroupNetworkObserver( - [Parameter(typeof(GroupCommConfigurationOptions.RetryCountWaitingForHanler))] int retryCount, - [Parameter(typeof(GroupCommConfigurationOptions.SleepTimeWaitingForHandler))] int sleepTime) + public CommunicationGroupNetworkObserver() { _handlers = new Dictionary<string, IObserver<GroupCommunicationMessage>>(); - _retryCount = retryCount; - _sleepTime = sleepTime; } /// <summary> @@ -83,7 +77,7 @@ namespace Org.Apache.REEF.Network.Group.Task.Impl { string operatorName = message.OperatorName; - IObserver<GroupCommunicationMessage> handler = GetOperatorHandler(operatorName, _retryCount, _sleepTime); + IObserver<GroupCommunicationMessage> handler = GetOperatorHandler(operatorName); if (handler == null) { @@ -99,30 +93,15 @@ namespace Org.Apache.REEF.Network.Group.Task.Impl /// GetOperatorHandler for operatorName /// </summary> /// <param name="operatorName"></param> - /// <param name="retry"></param> - /// <param name="sleepTime"></param> /// <returns></returns> - private IObserver<GroupCommunicationMessage> GetOperatorHandler(string operatorName, int retry, int sleepTime) + private IObserver<GroupCommunicationMessage> GetOperatorHandler(string operatorName) { - //registration of handler might be delayed while the Network Service has received message from other servers - for (int i = 0; i < retry; i++) + IObserver<GroupCommunicationMessage> handler; + if (!_handlers.TryGetValue(operatorName, out handler)) { - if (!_handlers.ContainsKey(operatorName)) - { - LOGGER.Log(Level.Info, "handler for operator {0} has not been registered." 
+ operatorName); - Thread.Sleep(sleepTime); - } - else - { - IObserver<GroupCommunicationMessage> handler; - if (!_handlers.TryGetValue(operatorName, out handler)) - { - Exceptions.Throw(new ArgumentException("No handler registered yet with the operator name: " + operatorName), LOGGER); - } - return handler; - } + Exceptions.Throw(new ApplicationException("No handler registered yet with the operator name: " + operatorName), LOGGER); } - return null; + return handler; } public void OnError(Exception error) http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/GroupCommClient.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/GroupCommClient.cs b/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/GroupCommClient.cs index 4f0f283..4cf0e06 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/GroupCommClient.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Task/Impl/GroupCommClient.cs @@ -25,15 +25,13 @@ using Org.Apache.REEF.Network.Group.Driver.Impl; using Org.Apache.REEF.Network.NetworkService; using Org.Apache.REEF.Tang.Annotations; using Org.Apache.REEF.Tang.Formats; -using Org.Apache.REEF.Tang.Implementations.Tang; using Org.Apache.REEF.Tang.Interface; -using Org.Apache.REEF.Tang.Util; using Org.Apache.REEF.Wake.Remote.Impl; namespace Org.Apache.REEF.Network.Group.Task.Impl { /// <summary> - /// Used by Tasks to fetch CommunicationGroupClients. + /// Container of CommunicationGroupClients. /// </summary> public class GroupCommClient : IGroupCommClient { @@ -43,20 +41,20 @@ namespace Org.Apache.REEF.Network.Group.Task.Impl /// <summary> /// Creates a new GroupCommClient and registers the task ID with the Name Server. + /// Currently the GroupCommClient is injected in task constructor. When working on REEF-289, we should put the injection at a proper place.
/// </summary> /// <param name="groupConfigs">The set of serialized Group Communication configurations</param> - /// <param name="taskId">The identifier for this task</param> - /// <param name="groupCommNetworkObserver">The network handler to receive incoming messages - /// for this task</param> + /// <param name="taskId">The identifier for this task</param> /// <param name="networkService">The network service used to send messages</param> /// <param name="configSerializer">Used to deserialize Group Communication configuration</param> + /// <param name="injector">injector forked from the injector that creates this instance</param> [Inject] - public GroupCommClient( + private GroupCommClient( [Parameter(typeof(GroupCommConfigurationOptions.SerializedGroupConfigs))] ISet<string> groupConfigs, [Parameter(typeof(TaskConfigurationOptions.Identifier))] string taskId, - IGroupCommNetworkObserver groupCommNetworkObserver, NetworkService<GroupCommunicationMessage> networkService, - AvroConfigurationSerializer configSerializer) + AvroConfigurationSerializer configSerializer, + IInjector injector) { _commGroups = new Dictionary<string, ICommunicationGroupClient>(); _networkService = networkService; @@ -65,14 +63,9 @@ namespace Org.Apache.REEF.Network.Group.Task.Impl foreach (string serializedGroupConfig in groupConfigs) { IConfiguration groupConfig = configSerializer.FromString(serializedGroupConfig); - - IInjector injector = TangFactory.GetTang().NewInjector(groupConfig); - injector.BindVolatileParameter(GenericType<TaskConfigurationOptions.Identifier>.Class, taskId); - injector.BindVolatileInstance(GenericType<IGroupCommNetworkObserver>.Class, groupCommNetworkObserver); - injector.BindVolatileInstance(GenericType<NetworkService<GroupCommunicationMessage>>.Class, networkService); - - ICommunicationGroupClient commGroup = injector.GetInstance<ICommunicationGroupClient>(); - _commGroups[commGroup.GroupName] = commGroup; + IInjector groupInjector = 
injector.ForkInjector(groupConfig); + ICommunicationGroupClient commGroupClient = groupInjector.GetInstance<ICommunicationGroupClient>(); + _commGroups[commGroupClient.GroupName] = commGroupClient; } } http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Topology/FlatTopology.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Topology/FlatTopology.cs b/lang/cs/Org.Apache.REEF.Network/Group/Topology/FlatTopology.cs index e80dea6..c36f1ca 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Topology/FlatTopology.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Topology/FlatTopology.cs @@ -110,10 +110,12 @@ namespace Org.Apache.REEF.Network.Group.Topology if (taskId.Equals(broadcastSpec.SenderId)) { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<BroadcastSender<T>>.Class); + SetMessageType(typeof(BroadcastSender<T>), confBuilder); } else { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<BroadcastReceiver<T>>.Class); + SetMessageType(typeof(BroadcastReceiver<T>), confBuilder); } } else if (OperatorSpec is ReduceOperatorSpec) @@ -122,10 +124,12 @@ namespace Org.Apache.REEF.Network.Group.Topology if (taskId.Equals(reduceSpec.ReceiverId)) { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<ReduceReceiver<T>>.Class); + SetMessageType(typeof(ReduceReceiver<T>), confBuilder); } else { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<ReduceSender<T>>.Class); + SetMessageType(typeof(ReduceSender<T>), confBuilder); } } else if (OperatorSpec is ScatterOperatorSpec) @@ -134,10 +138,12 @@ namespace Org.Apache.REEF.Network.Group.Topology if (taskId.Equals(scatterSpec.SenderId)) { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<ScatterSender<T>>.Class); + 
SetMessageType(typeof(ScatterSender<T>), confBuilder); } else { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<ScatterReceiver<T>>.Class); + SetMessageType(typeof(ScatterReceiver<T>), confBuilder); } } else @@ -148,6 +154,17 @@ namespace Org.Apache.REEF.Network.Group.Topology return Configurations.Merge(confBuilder.Build(), OperatorSpec.Configiration); } + private static void SetMessageType(Type operatorType, ICsConfigurationBuilder confBuilder) + { + if (operatorType.IsGenericType) + { + var genericTypes = operatorType.GenericTypeArguments; + var msgType = genericTypes[0]; + confBuilder.BindNamedParameter<GroupCommConfigurationOptions.MessageType, string>( + GenericType<GroupCommConfigurationOptions.MessageType>.Class, msgType.AssemblyQualifiedName); + } + } + /// <summary> /// Adds a task to the topology graph. /// </summary> http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Network/Group/Topology/TreeTopology.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Topology/TreeTopology.cs b/lang/cs/Org.Apache.REEF.Network/Group/Topology/TreeTopology.cs index e8ba6c1..d6c6bc6 100644 --- a/lang/cs/Org.Apache.REEF.Network/Group/Topology/TreeTopology.cs +++ b/lang/cs/Org.Apache.REEF.Network/Group/Topology/TreeTopology.cs @@ -111,7 +111,6 @@ namespace Org.Apache.REEF.Network.Group.Topology //add parentid, if no parent, add itself ICsConfigurationBuilder confBuilder = TangFactory.GetTang().NewConfigurationBuilder() - //.BindImplementation(typeof(ICodec<T1>), OperatorSpec.Codec) .BindNamedParameter<GroupCommConfigurationOptions.TopologyRootTaskId, string>( GenericType<GroupCommConfigurationOptions.TopologyRootTaskId>.Class, parentId); @@ -130,10 +129,12 @@ namespace Org.Apache.REEF.Network.Group.Topology if (taskId.Equals(broadcastSpec.SenderId)) { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, 
GenericType<BroadcastSender<T>>.Class); + SetMessageType(typeof(BroadcastSender<T>), confBuilder); } else { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<BroadcastReceiver<T>>.Class); + SetMessageType(typeof(BroadcastReceiver<T>), confBuilder); } } else if (OperatorSpec is ReduceOperatorSpec) @@ -142,10 +143,12 @@ namespace Org.Apache.REEF.Network.Group.Topology if (taskId.Equals(reduceSpec.ReceiverId)) { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<ReduceReceiver<T>>.Class); + SetMessageType(typeof(ReduceReceiver<T>), confBuilder); } else { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<ReduceSender<T>>.Class); + SetMessageType(typeof(ReduceSender<T>), confBuilder); } } else if (OperatorSpec is ScatterOperatorSpec) @@ -154,10 +157,12 @@ namespace Org.Apache.REEF.Network.Group.Topology if (taskId.Equals(scatterSpec.SenderId)) { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<ScatterSender<T>>.Class); + SetMessageType(typeof(ScatterSender<T>), confBuilder); } else { confBuilder.BindImplementation(GenericType<IGroupCommOperator<T>>.Class, GenericType<ScatterReceiver<T>>.Class); + SetMessageType(typeof(ScatterReceiver<T>), confBuilder); } } else @@ -233,5 +238,16 @@ namespace Org.Apache.REEF.Network.Group.Topology _prev.Successor = node; _prev = node; } + + private static void SetMessageType(Type operatorType, ICsConfigurationBuilder confBuilder) + { + if (operatorType.IsGenericType) + { + var genericTypes = operatorType.GenericTypeArguments; + var msgType = genericTypes[0]; + confBuilder.BindNamedParameter<GroupCommConfigurationOptions.MessageType, string>( + GenericType<GroupCommConfigurationOptions.MessageType>.Class, msgType.AssemblyQualifiedName); + } + } } } http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/75f25a26/lang/cs/Org.Apache.REEF.Tang.Tests/Injection/TestInjection.cs 
---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Tang.Tests/Injection/TestInjection.cs b/lang/cs/Org.Apache.REEF.Tang.Tests/Injection/TestInjection.cs index 3fbb173..99bb297 100644 --- a/lang/cs/Org.Apache.REEF.Tang.Tests/Injection/TestInjection.cs +++ b/lang/cs/Org.Apache.REEF.Tang.Tests/Injection/TestInjection.cs @@ -304,6 +304,57 @@ namespace Org.Apache.REEF.Tang.Tests.Injection Assert.IsNotNull(o.ExternalObject is ExternalClass); } + + /// <summary> + /// In this test, interface is a generic of T. Implementations have different generic arguments such as int and string. + /// When doing injection, we must specify the interface with a specified argument type + /// </summary> + [TestMethod] + public void TestInjectionWithGenericArguments() + { + var c = TangFactory.GetTang().NewConfigurationBuilder() + .BindImplementation(GenericType<IMyOperator<int>>.Class, GenericType<MyOperatorImpl<int>>.Class) + .BindImplementation(GenericType<IMyOperator<string>>.Class, GenericType<MyOperatorImpl<string>>.Class) + .Build(); + + var injector = TangFactory.GetTang().NewInjector(c); + + //argument type must be specified in injection + var o1 = injector.GetInstance(typeof(IMyOperator<int>)); + var o2 = injector.GetInstance(typeof(IMyOperator<string>)); + var o3 = injector.GetInstance(typeof(MyOperatorTopology<int>)); + + Assert.IsTrue(o1 is MyOperatorImpl<int>); + Assert.IsTrue(o2 is MyOperatorImpl<string>); + Assert.IsTrue(o3 is MyOperatorTopology<int>); + } + + /// <summary> + /// In this test, interface argument type is set through Configuration. 
We can get the argument type and then + /// make the interface with the argument type on the fly so that the injection can be done + /// </summary> + [TestMethod] + public void TestInjectionWithGenericArgumentType() + { + var c = TangFactory.GetTang().NewConfigurationBuilder() + .BindImplementation(GenericType<IMyOperator<int[]>>.Class, GenericType<MyOperatorImpl<int[]>>.Class) + .BindNamedParameter(typeof(MessageType), typeof(int[]).AssemblyQualifiedName) + .Build(); + + var injector = TangFactory.GetTang().NewInjector(c); + + //get argument type from configuration + var messageTypeAsString = injector.GetNamedInstance<MessageType, string>(GenericType<MessageType>.Class); + Type messageType = Type.GetType(messageTypeAsString); + + //create interface with generic type on the fly + Type genericInterfaceType = typeof(IMyOperator<>); + Type interfaceOfMessageType = genericInterfaceType.MakeGenericType(messageType); + + var o = injector.GetInstance(interfaceOfMessageType); + + Assert.IsTrue(o is MyOperatorImpl<int[]>); + } } class AReferenceClass : IAInterface @@ -386,4 +437,35 @@ namespace Org.Apache.REEF.Tang.Tests.Injection { } } + + interface IMyOperator<T> + { + string OperatorName { get; } + } + + class MyOperatorImpl<T> : IMyOperator<T> + { + [Inject] + public MyOperatorImpl() + { + } + + string IMyOperator<T>.OperatorName + { + get { throw new NotImplementedException(); } + } + } + + [NamedParameter] + class MessageType : Name<string> + { + } + + class MyOperatorTopology<T> + { + [Inject] + public MyOperatorTopology(IMyOperator<T> op) + { + } + } } \ No newline at end of file
