[REEF-150] Adding group communication to REEF .Net. This is to port Group Communication to REEF .Net in Apache Git. Add source code, examples and tests. Updated namespace to follow conventions.
JIRA: REEF-150. (https://issues.apache.org/jira/browse/REEF-150) Author: Julia Wang Email: [email protected] Project: http://git-wip-us.apache.org/repos/asf/incubator-reef/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-reef/commit/0292caf1 Tree: http://git-wip-us.apache.org/repos/asf/incubator-reef/tree/0292caf1 Diff: http://git-wip-us.apache.org/repos/asf/incubator-reef/diff/0292caf1 Branch: refs/heads/master Commit: 0292caf1437d61586d4e7ba1370710be833f5292 Parents: 7edb857 Author: Julia Wang <[email protected]> Authored: Tue Feb 10 19:16:14 2015 -0800 Committer: tmajest <[email protected]> Committed: Wed Feb 11 14:55:10 2015 -0800 ---------------------------------------------------------------------- .../MachineLearning/KMeans/Centroids.cs | 44 ++ .../MachineLearning/KMeans/Constants.cs | 35 ++ .../KMeans/Contracts/CentroidsContract.cs | 49 ++ .../KMeans/Contracts/DataVectorContract.cs | 52 ++ .../KMeans/Contracts/PartialMeanContract.cs | 48 ++ .../Contracts/ProcessedResultsContract.cs | 57 +++ .../KMeans/DataPartitionCache.cs | 104 ++++ .../MachineLearning/KMeans/DataVector.cs | 260 ++++++++++ .../KMeans/KMeansConfiguratioinOptions.cs | 46 ++ .../KMeans/KMeansDriverHandlers.cs | 191 ++++++++ .../MachineLearning/KMeans/KMeansMasterTask.cs | 155 ++++++ .../MachineLearning/KMeans/KMeansSlaveTask.cs | 118 +++++ .../MachineLearning/KMeans/LegacyKMeansTask.cs | 113 +++++ .../MachineLearning/KMeans/PartialMean.cs | 124 +++++ .../MachineLearning/KMeans/ProcessedResults.cs | 54 +++ .../KMeans/codecs/CentroidsCodec.cs | 49 ++ .../KMeans/codecs/DataVectorCodec.cs | 46 ++ .../KMeans/codecs/ProcessedResultsCodec.cs | 57 +++ .../Org.Apache.REEF.Examples.csproj | 19 + .../Group/Codec/GcmMessageProto.cs | 76 +++ .../Codec/GroupCommunicationMessageCodec.cs | 77 +++ .../Group/Config/MpiConfigurationOptions.cs | 71 +++ .../Group/Driver/ICommunicationGroupDriver.cs | 89 ++++ .../Group/Driver/IMpiDriver.cs | 93 ++++ .../Driver/Impl/CommunicationGroupDriver.cs | 260 
++++++++++ .../Driver/Impl/GroupCommunicationMessage.cs | 107 ++++ .../Group/Driver/Impl/MessageType.cs | 30 ++ .../Group/Driver/Impl/MpiDriver.cs | 239 +++++++++ .../Group/Driver/Impl/TaskStarter.cs | 135 ++++++ .../Group/Operators/IBroadcastReceiver.cs | 40 ++ .../Group/Operators/IBroadcastSender.cs | 40 ++ .../Group/Operators/IMpiOperator.cs | 49 ++ .../Group/Operators/IOperatorSpec.cs | 36 ++ .../Group/Operators/IReduceFunction.cs | 41 ++ .../Group/Operators/IReduceReceiver.cs | 46 ++ .../Group/Operators/IReduceSender.cs | 40 ++ .../Group/Operators/IScatterReceiver.cs | 41 ++ .../Group/Operators/IScatterSender.cs | 60 +++ .../Operators/Impl/BroadcastOperatorSpec.cs | 50 ++ .../Group/Operators/Impl/BroadcastReceiver.cs | 92 ++++ .../Group/Operators/Impl/BroadcastSender.cs | 98 ++++ .../Group/Operators/Impl/ReduceFunction.cs | 62 +++ .../Group/Operators/Impl/ReduceOperatorSpec.cs | 62 +++ .../Group/Operators/Impl/ReduceReceiver.cs | 100 ++++ .../Group/Operators/Impl/ReduceSender.cs | 97 ++++ .../Group/Operators/Impl/ScatterOperatorSpec.cs | 58 +++ .../Group/Operators/Impl/ScatterReceiver.cs | 101 ++++ .../Group/Operators/Impl/ScatterSender.cs | 112 +++++ .../Group/Operators/Impl/Sender.cs | 74 +++ .../Group/Task/ICommunicationGroupClient.cs | 90 ++++ .../Task/ICommunicationGroupNetworkObserver.cs | 49 ++ .../Group/Task/IMpiClient.cs | 44 ++ .../Group/Task/IMpiNetworkObserver.cs | 50 ++ .../Group/Task/Impl/CommunicationGroupClient.cs | 219 +++++++++ .../Impl/CommunicationGroupNetworkObserver.cs | 93 ++++ .../Group/Task/Impl/MpiClient.cs | 108 +++++ .../Group/Task/Impl/MpiNetworkObserver.cs | 109 +++++ .../Group/Task/Impl/NodeStruct.cs | 67 +++ .../Group/Task/Impl/OperatorTopology.cs | 484 +++++++++++++++++++ .../Group/Topology/FlatTopology.cs | 201 ++++++++ .../Group/Topology/ITopology.cs | 41 ++ .../Group/Topology/TaskNode.cs | 69 +++ .../Org.Apache.REEF.Network.csproj | 43 ++ .../Functional/ML/KMeans/TestKMeans.cs | 163 +++++++ .../BroadcastReduceDriver.cs | 187 
+++++++ .../BroadcastReduceTest/BroadcastReduceTest.cs | 83 ++++ .../MPI/BroadcastReduceTest/MasterTask.cs | 89 ++++ .../MPI/BroadcastReduceTest/SlaveTask.cs | 80 +++ .../Functional/MPI/MpiTestConfig.cs | 36 ++ .../Functional/MPI/MpiTestConstants.cs | 33 ++ .../MPI/ScatterReduceTest/MasterTask.cs | 72 +++ .../ScatterReduceTest/ScatterReduceDriver.cs | 168 +++++++ .../MPI/ScatterReduceTest/ScatterReduceTest.cs | 80 +++ .../MPI/ScatterReduceTest/SlaveTask.cs | 67 +++ .../Org.Apache.REEF.Tests.csproj | 11 + 75 files changed, 6833 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Centroids.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Centroids.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Centroids.cs new file mode 100644 index 0000000..0ceb4ed --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Centroids.cs @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +using System.Collections.Generic; +using Org.Apache.REEF.Examples.MachineLearning.KMeans.codecs; +using Org.Apache.REEF.Utilities; + +namespace Org.Apache.REEF.Examples.MachineLearning.KMeans +{ + public class Centroids + { + public Centroids(List<DataVector> points) + { + Points = points; + } + + public List<DataVector> Points { get; set; } + + /// <summary> + /// helper function mostly used for logging + /// </summary> + /// <returns>the serialized string</returns> + public override string ToString() + { + return ByteUtilities.ByteArrarysToString(new CentroidsCodec().Encode(this)); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Constants.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Constants.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Constants.cs new file mode 100644 index 0000000..526031e --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Constants.cs @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +namespace Org.Apache.REEF.Examples.MachineLearning.KMeans +{ + public class Constants + { + public const string KMeansExecutionBaseDirectory = @"KMeans"; + public const string DataDirectory = "data"; + public const string PartialMeanFilePrefix = "partialMeans_"; + public const string CentroidsFile = "centroids"; + public const string MasterTaskId = "KMeansMasterTaskId"; + public const string SlaveTaskIdPrefix = "KMeansSlaveTask_"; + public const string KMeansCommunicationGroupName = "KMeansBroadcastReduceGroup"; + public const string CentroidsBroadcastOperatorName = "CentroidsBroadcast"; + public const string ControlMessageBroadcastOperatorName = "ControlMessageBroadcast"; + public const string MeansReduceOperatorName = "MeansReduce"; + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/CentroidsContract.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/CentroidsContract.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/CentroidsContract.cs new file mode 100644 index 0000000..8ea56c9 --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/CentroidsContract.cs @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System.Collections.Generic; +using System.Linq; +using System.Runtime.Serialization; + +namespace Org.Apache.REEF.Examples.MachineLearning.KMeans.Contracts +{ + [DataContract] + public class CentroidsContract + { + [DataMember] + public List<DataVectorContract> DataVectorContracts { get; set; } + + public static CentroidsContract Create(Centroids centroids) + { + List<DataVectorContract> dataVectorContracts = + centroids.Points.Select(DataVectorContract.Create).ToList(); + + return new CentroidsContract { DataVectorContracts = dataVectorContracts }; + } + + public Centroids ToCentroids() + { + List<DataVector> dataVectors = DataVectorContracts + .Select(dv => dv.ToDataVector()) + .ToList(); + + return new Centroids(dataVectors); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/DataVectorContract.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/DataVectorContract.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/DataVectorContract.cs new file mode 100644 index 0000000..1e41944 --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/DataVectorContract.cs @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System.Collections.Generic; +using System.Runtime.Serialization; + +namespace Org.Apache.REEF.Examples.MachineLearning.KMeans.Contracts +{ + [DataContract] + public class DataVectorContract + { + [DataMember] + public List<float> Data { get; private set; } + + [DataMember] + public int Label { get; private set; } + + [DataMember] + public int Dimension { get; private set; } + + public static DataVectorContract Create(DataVector dataVector) + { + return new DataVectorContract() + { + Data = dataVector.Data, + Label = dataVector.Label, + Dimension = dataVector.Dimension + }; + } + + public DataVector ToDataVector() + { + return new DataVector(Data, Label); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/PartialMeanContract.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/PartialMeanContract.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/PartialMeanContract.cs new file mode 100644 index 0000000..b62f8ff --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/PartialMeanContract.cs @@ -0,0 +1,48 @@ +/** + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System.Runtime.Serialization; + +namespace Org.Apache.REEF.Examples.MachineLearning.KMeans.Contracts +{ + [DataContract] + public class PartialMeanContract + { + [DataMember] + public DataVectorContract DataVectContract { get; set; } + + [DataMember] + public int Size { get; set; } + + public static PartialMeanContract Create(PartialMean partialMean) + { + return new PartialMeanContract + { + DataVectContract = DataVectorContract.Create(partialMean.Mean), + Size = partialMean.Size + }; + } + + public PartialMean ToPartialMean() + { + DataVector dataVector = DataVectContract.ToDataVector(); + return new PartialMean(dataVector, Size); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/ProcessedResultsContract.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/ProcessedResultsContract.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/ProcessedResultsContract.cs new file mode 100644 index 0000000..c78ee3c --- /dev/null +++ 
b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/Contracts/ProcessedResultsContract.cs @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +using System.Collections.Generic; +using System.Linq; +using System.Runtime.Serialization; + +namespace Org.Apache.REEF.Examples.MachineLearning.KMeans.Contracts +{ + [DataContract] + public class ProcessedResultsContract + { + [DataMember] + public List<PartialMeanContract> PartialMeanContracts { get; set; } + + [DataMember] + public float Loss { get; set; } + + public static ProcessedResultsContract Create(ProcessedResults obj) + { + List<PartialMeanContract> partialMeansContracts = obj.Means + .Select(PartialMeanContract.Create) + .ToList(); + + return new ProcessedResultsContract + { + PartialMeanContracts = partialMeansContracts, + Loss = obj.Loss + }; + } + + public ProcessedResults ToProcessedResults() + { + List<PartialMean> partialMeans = PartialMeanContracts + .Select(contract => contract.ToPartialMean()) + .ToList(); + + return new ProcessedResults(partialMeans, Loss); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/DataPartitionCache.cs 
---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/DataPartitionCache.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/DataPartitionCache.cs new file mode 100644 index 0000000..91c3ed0 --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/DataPartitionCache.cs @@ -0,0 +1,104 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using Org.Apache.REEF.Common.Services; +using Org.Apache.REEF.Tang.Annotations; +using Org.Apache.REEF.Utilities.Logging; + +namespace Org.Apache.REEF.Examples.MachineLearning.KMeans +{ + // TODO: we should outsource some of the functionalites to a data loader implemenation + public class DataPartitionCache : IService + { + private static readonly Logger _Logger = Logger.GetLogger(typeof(DataPartitionCache)); + + [Inject] + public DataPartitionCache( + [Parameter(Value = typeof(PartitionIndex))] int partition, + [Parameter(Value = typeof(KMeansConfiguratioinOptions.ExecutionDirectory))] string executionDirectory) + { + Partition = partition; + if (Partition < 0) + { + _Logger.Log(Level.Info, "no data to load since partition = " + Partition); + } + else + { + _Logger.Log(Level.Info, "loading data for partition " + Partition); + DataVectors = loadData(partition, executionDirectory); + } + } + + public List<DataVector> DataVectors { get; set; } + + public int Partition { get; set; } + + // read initial data from file and marked it as unlabeled (not associated with any centroid) + public static List<DataVector> ReadDataFile(string path, char seperator = ',') + { + List<DataVector> data = new List<DataVector>(); + FileStream file = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read); + using (StreamReader reader = new StreamReader(file)) + { + while (!reader.EndOfStream) + { + string line = reader.ReadLine(); + if (!string.IsNullOrWhiteSpace(line)) + { + data.Add(DataVector.FromString(line)); + } + } + reader.Close(); + } + + return data; + } + + public void LabelData(Centroids centroids) + { + foreach (DataVector vector in DataVectors) + { + float minimumDistance = float.MaxValue; + foreach (DataVector centroid in centroids.Points) + { + float d = vector.DistanceTo(centroid); + if (d < minimumDistance) + { + vector.Label = centroid.Label; + minimumDistance = d; + 
} + } + } + } + + private List<DataVector> loadData(int partition, string executionDirectory) + { + string file = Path.Combine(executionDirectory, Constants.DataDirectory, partition.ToString(CultureInfo.InvariantCulture)); + return ReadDataFile(file); + } + + [NamedParameter("Data partition index", "partition", "")] + public class PartitionIndex : Name<int> + { + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/DataVector.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/DataVector.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/DataVector.cs new file mode 100644 index 0000000..fb6a16e --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/DataVector.cs @@ -0,0 +1,260 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Linq; + +namespace Org.Apache.REEF.Examples.MachineLearning.KMeans +{ + public class DataVector + { + public DataVector(int dimension, int label = -1) + { + Dimension = dimension; + Data = Enumerable.Repeat((float)0, Dimension).ToList(); + Label = label; + } + + // unlablered data + public DataVector(List<float> data) : this(data, -1) + { + } + + public DataVector(List<float> data, int label) + { + if (data == null || data.Count == 0) + { + throw new ArgumentNullException("data"); + } + Dimension = data.Count; + Data = data; + Label = label; + } + + public List<float> Data { get; set; } + + public int Label { get; set; } + + public int Dimension { get; set; } + + public static float TotalDistance(List<DataVector> list1, List<DataVector> list2) + { + if (list1 == null || list2 == null || list1.Count == 0 || list2.Count == 0) + { + throw new ArgumentException("one of the input list is null or empty"); + } + if (list1.Count != list2.Count) + { + throw new ArgumentException("list 1's dimensionality does not mach list 2."); + } + float distance = 0; + for (int i = 0; i < list1.Count; i++) + { + distance += list1[i].DistanceTo(list2[i]); + } + return distance; + } + + public static DataVector Mean(List<DataVector> vectors) + { + if (vectors == null || vectors.Count == 0) + { + throw new ArgumentNullException("vectors"); + } + DataVector mean = new DataVector(vectors[0].Dimension); + for (int i = 0; i < vectors.Count; i++) + { + mean = mean.Add(vectors[i], ignoreLabel: true); + } + return mean.Normalize(vectors.Count); + } + + // shuffle data and write them to different partions (different files on disk for now) + public static List<DataVector> ShuffleDataAndGetInitialCentriods(string originalDataFile, int partitionsNum, int clustersNum, string executionDirectory) + { + List<DataVector> data = DataPartitionCache.ReadDataFile(originalDataFile); + // 
shuffle, not truely random, but sufficient for our purpose + data = data.OrderBy(a => Guid.NewGuid()).ToList(); + string dataDirectory = Path.Combine(executionDirectory, Constants.DataDirectory); + // clean things up first + if (Directory.Exists(dataDirectory)) + { + Directory.Delete(dataDirectory, true); + } + Directory.CreateDirectory(dataDirectory); + + int residualCount = data.Count; + int batchSize = data.Count / partitionsNum; + for (int i = 0; i < partitionsNum; i++) + { + int linesCount = residualCount > batchSize ? batchSize : residualCount; + using (StreamWriter writer = new StreamWriter( + File.OpenWrite(Path.Combine(executionDirectory, Constants.DataDirectory, i.ToString(CultureInfo.InvariantCulture))))) + { + for (int j = i * batchSize; j < (i * batchSize) + linesCount; j++) + { + writer.WriteLine(data[j].ToString()); + } + writer.Close(); + } + } + return InitializeCentroids(clustersNum, data, executionDirectory); + } + + public static void WriteToCentroidFile(List<DataVector> centroids, string executionDirectory) + { + string centroidFile = Path.Combine(executionDirectory, Constants.CentroidsFile); + File.Delete(centroidFile); + using (StreamWriter writer = new StreamWriter(File.OpenWrite(centroidFile))) + { + foreach (DataVector dataVector in centroids) + { + writer.WriteLine(dataVector.ToString()); + } + writer.Close(); + } + } + + // TODO: replace with proper deserialization + public static DataVector FromString(string str) + { + if (string.IsNullOrWhiteSpace(str)) + { + throw new ArgumentException("str"); + } + string[] dataAndLable = str.Split(';'); + if (dataAndLable == null || dataAndLable.Length > 2) + { + throw new ArgumentException("Cannot deserialize DataVector from string " + str); + } + int label = -1; + if (dataAndLable.Length == 2) + { + label = int.Parse(dataAndLable[1], CultureInfo.InvariantCulture); + } + List<float> data = dataAndLable[0].Split(',').Select(float.Parse).ToList(); + return new DataVector(data, label); + } + + // by 
default use squared euclidean disatance + // a naive implemenation without considering things like data normalization or overflow + // and it is not particular about efficiency + public float DistanceTo(DataVector other) + { + VectorsArithmeticPrecondition(other); + float d = 0; + for (int i = 0; i < Data.Count; i++) + { + float diff = Data[i] - other.Data[i]; + d += diff * diff; + } + return d; + } + + public float DistanceTo(List<DataVector> list) + { + float distance = 0; + for (int i = 0; i < list.Count; i++) + { + distance += this.DistanceTo(list[i]); + } + return distance; + } + + public DataVector Add(DataVector other, bool ignoreLabel = false) + { + VectorsArithmeticPrecondition(other); + if (!ignoreLabel) + { + if (Label != other.Label) + { + throw new InvalidOperationException("by default cannot apply addition operation on data of different labels."); + } + } + List<float> sumData = new List<float>(Data); + for (int i = 0; i < Data.Count; i++) + { + sumData[i] += other.Data[i]; + } + return new DataVector(sumData, ignoreLabel ? 
-1 : Label); + } + + public DataVector Normalize(float normalizationFactor) + { + if (normalizationFactor == 0) + { + throw new InvalidOperationException("normalizationFactor is zero"); + } + DataVector result = new DataVector(Data, Label); + for (int i = 0; i < Data.Count; i++) + { + result.Data[i] /= normalizationFactor; + } + return result; + } + + public DataVector MultiplyScalar(float scalar) + { + DataVector result = new DataVector(Data, Label); + for (int i = 0; i < Data.Count; i++) + { + result.Data[i] *= scalar; + } + return result; + } + + // TODO: replace with proper serialization + public override string ToString() + { + return string.Join(",", Data.Select(i => i.ToString(CultureInfo.InvariantCulture)).ToArray()) + ";" + Label; + } + + // normally centroids are picked as random points from the vector space + // here we just pick K random data samples + private static List<DataVector> InitializeCentroids(int clustersNum, List<DataVector> data, string executionDirectory) + { + // again we used the not-some-random guid trick, + // not truly random and not quite efficient, but easy to implement as v1 + List<DataVector> centroids = data.OrderBy(a => Guid.NewGuid()).Take(clustersNum).ToList(); + + // add label to centroids + for (int i = 0; i < centroids.Count; i++) + { + centroids[i].Label = i; + } + WriteToCentroidFile(centroids, executionDirectory); + return centroids; + } + + private void VectorsArithmeticPrecondition(DataVector other) + { + if (other == null || other.Data == null) + { + throw new ArgumentNullException("other"); + } + if (Data.Count != other.Data.Count) + { + throw new InvalidOperationException("vector dimentionality mismatch"); + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansConfiguratioinOptions.cs ---------------------------------------------------------------------- diff --git 
a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansConfiguratioinOptions.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansConfiguratioinOptions.cs new file mode 100644 index 0000000..aeac77d --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansConfiguratioinOptions.cs @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +using Org.Apache.REEF.Tang.Annotations; + +namespace Org.Apache.REEF.Examples.MachineLearning.KMeans +{ + public class KMeansConfiguratioinOptions + { + [NamedParameter("Number of clusters", "K", "0")] + public class K : Name<int> + { + } + + /// <summary> + /// This is for loading the initial data samples from file + /// currently it is assumed to load from local disk, we can easily extend this to + /// be an url that point to cloud storage, and have things downloaded from blob storage instead + /// </summary> + [NamedParameter("Directory for storing all execution data", "DD")] + public class ExecutionDirectory : Name<string> + { + } + + [NamedParameter("Number of Evaluators")] + public class TotalNumEvaluators : Name<int> + { + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansDriverHandlers.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansDriverHandlers.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansDriverHandlers.cs new file mode 100644 index 0000000..235a268 --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansDriverHandlers.cs @@ -0,0 +1,191 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using Org.Apache.REEF.Common.Io;
using Org.Apache.REEF.Common.Services;
using Org.Apache.REEF.Common.Tasks;
using Org.Apache.REEF.Driver;
using Org.Apache.REEF.Driver.Bridge;
using Org.Apache.REEF.Driver.Context;
using Org.Apache.REEF.Driver.Evaluator;
using Org.Apache.REEF.Examples.MachineLearning.KMeans.codecs;
using Org.Apache.REEF.Network.Group.Driver;
using Org.Apache.REEF.Network.Group.Driver.Impl;
using Org.Apache.REEF.Network.Group.Operators.Impl;
using Org.Apache.REEF.Network.NetworkService;
using Org.Apache.REEF.Network.NetworkService.Codec;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Tang.Formats;
using Org.Apache.REEF.Tang.Implementations.Configuration;
using Org.Apache.REEF.Tang.Implementations.Tang;
using Org.Apache.REEF.Tang.Interface;
using Org.Apache.REEF.Tang.Util;
using Org.Apache.REEF.Utilities.Logging;

namespace Org.Apache.REEF.Examples.MachineLearning.KMeans
{
    /// <summary>
    /// Driver-side handlers for the KMeans example. The constructor prepares the input data
    /// and the communication group (two broadcasts and one reduce, all rooted at the master task);
    /// the OnNext overloads then request evaluators, submit a context plus services to each
    /// allocated evaluator, and queue a master or slave task on each active context.
    /// </summary>
    public class KMeansDriverHandlers :
        IStartHandler,
        IObserver<IEvaluatorRequestor>,
        IObserver<IAllocatedEvaluator>,
        IObserver<IActiveContext>
    {
        private static readonly Logger _Logger = Logger.GetLogger(typeof(KMeansDriverHandlers));
        private readonly object _lockObj = new object();
        private readonly string _executionDirectory;

        // TODO: we may want to make this injectable
        private readonly int _partitionsNumber = 2;
        private readonly int _clustersNumber = 3;
        private readonly int _totalEvaluators;
        private readonly IMpiDriver _mpiDriver;
        private readonly ICommunicationGroupDriver _commGroup;
        private readonly TaskStarter _mpiTaskStarter;

        // Next data partition index to hand out; guarded by _lockObj.
        private int _partitionIndex = 0;

        /// <summary>
        /// Shuffles the input data into partitions, picks the initial centroids and
        /// builds the communication group used by the master and slave tasks.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// The DataFile command line argument is not of the form DataFile:&lt;file name&gt;.
        /// </exception>
        [Inject]
        public KMeansDriverHandlers()
        {
            Identifier = "KMeansDriverId";
            _executionDirectory = Path.Combine(Directory.GetCurrentDirectory(), Constants.KMeansExecutionBaseDirectory, Guid.NewGuid().ToString("N").Substring(0, 4));
            string dataFile = GetDataFileName(ClrHandlerHelper.GetCommandLineArguments());
            DataVector.ShuffleDataAndGetInitialCentriods(
                Path.Combine(Directory.GetCurrentDirectory(), "reef", "global", dataFile),
                _partitionsNumber,
                _clustersNumber,
                _executionDirectory);

            // One evaluator per data partition plus one for the master task.
            _totalEvaluators = _partitionsNumber + 1;
            _mpiDriver = new MpiDriver(Identifier, Constants.MasterTaskId, new AvroConfigurationSerializer());

            _commGroup = _mpiDriver.NewCommunicationGroup(
                Constants.KMeansCommunicationGroupName,
                _totalEvaluators)
                    .AddBroadcast(Constants.CentroidsBroadcastOperatorName, new BroadcastOperatorSpec<Centroids>(Constants.MasterTaskId, new CentroidsCodec()))
                    .AddBroadcast(Constants.ControlMessageBroadcastOperatorName, new BroadcastOperatorSpec<ControlMessage>(Constants.MasterTaskId, new ControlMessageCodec()))
                    .AddReduce(Constants.MeansReduceOperatorName, new ReduceOperatorSpec<ProcessedResults>(Constants.MasterTaskId, new ProcessedResultsCodec(), new KMeansMasterTask.AggregateMeans()))
                    .Build();
            _mpiTaskStarter = new TaskStarter(_mpiDriver, _totalEvaluators);

            CreateClassHierarchy();
        }

        /// <summary>
        /// Identifier of this driver; also passed to the MpiDriver as the driver id.
        /// </summary>
        public string Identifier { get; set; }

        /// <summary>
        /// Requests all evaluators needed by the job in a single request.
        /// </summary>
        /// <param name="evalutorRequestor">Requestor used to submit the evaluator request.</param>
        public void OnNext(IEvaluatorRequestor evalutorRequestor)
        {
            int evaluatorsNumber = _totalEvaluators;
            int memory = 2048;
            int core = 1;
            EvaluatorRequest request = new EvaluatorRequest(evaluatorsNumber, memory, core);

            evalutorRequestor.Submit(request);
        }

        /// <summary>
        /// Submits the context and the services (group communication plus data partition cache)
        /// to a newly allocated evaluator. The master context gets partition index -1; every
        /// other context gets the next free partition index.
        /// </summary>
        /// <param name="allocatedEvaluator">The evaluator that was just allocated.</param>
        public void OnNext(IAllocatedEvaluator allocatedEvaluator)
        {
            IConfiguration contextConfiguration = _mpiDriver.GetContextConfiguration();

            int partitionNum;
            if (_mpiDriver.IsMasterContextConfiguration(contextConfiguration))
            {
                partitionNum = -1;
            }
            else
            {
                // Allocation events may arrive concurrently; hand out each index exactly once.
                lock (_lockObj)
                {
                    partitionNum = _partitionIndex;
                    _partitionIndex++;
                }
            }

            IConfiguration gcServiceConfiguration = _mpiDriver.GetServiceConfiguration();

            IConfiguration commonServiceConfiguration = TangFactory.GetTang().NewConfigurationBuilder(gcServiceConfiguration)
                .BindNamedParameter<DataPartitionCache.PartitionIndex, int>(GenericType<DataPartitionCache.PartitionIndex>.Class, partitionNum.ToString(CultureInfo.InvariantCulture))
                .BindNamedParameter<KMeansConfiguratioinOptions.ExecutionDirectory, string>(GenericType<KMeansConfiguratioinOptions.ExecutionDirectory>.Class, _executionDirectory)
                .BindNamedParameter<KMeansConfiguratioinOptions.TotalNumEvaluators, int>(GenericType<KMeansConfiguratioinOptions.TotalNumEvaluators>.Class, _totalEvaluators.ToString(CultureInfo.InvariantCulture))
                .BindNamedParameter<KMeansConfiguratioinOptions.K, int>(GenericType<KMeansConfiguratioinOptions.K>.Class, _clustersNumber.ToString(CultureInfo.InvariantCulture))
                .Build();

            IConfiguration dataCacheServiceConfiguration = ServiceConfiguration.ConfigurationModule
                .Set(ServiceConfiguration.Services, GenericType<DataPartitionCache>.Class)
                .Build();

            allocatedEvaluator.SubmitContextAndService(contextConfiguration, Configurations.Merge(commonServiceConfiguration, dataCacheServiceConfiguration));
        }

        /// <summary>
        /// Registers a master or slave task with the communication group and queues it on
        /// the task starter for the given context.
        /// </summary>
        /// <param name="activeContext">The context that just became active.</param>
        public void OnNext(IActiveContext activeContext)
        {
            bool isMaster = _mpiDriver.IsMasterTaskContext(activeContext);

            // The id registered with the communication group and the id in the task
            // configuration must match, so compute it exactly once.
            string taskId = isMaster
                ? Constants.MasterTaskId
                : Constants.SlaveTaskIdPrefix + activeContext.Id;

            IConfiguration taskConfiguration;
            if (isMaster)
            {
                // Configure Master Task
                taskConfiguration = TaskConfiguration.ConfigurationModule
                    .Set(TaskConfiguration.Identifier, taskId)
                    .Set(TaskConfiguration.Task, GenericType<KMeansMasterTask>.Class)
                    .Build();
            }
            else
            {
                // Configure Slave Task
                taskConfiguration = TaskConfiguration.ConfigurationModule
                    .Set(TaskConfiguration.Identifier, taskId)
                    .Set(TaskConfiguration.Task, GenericType<KMeansSlaveTask>.Class)
                    .Build();
            }

            _commGroup.AddTask(taskId);
            _mpiTaskStarter.QueueTask(taskConfiguration, activeContext);
        }

        /// <summary>Not implemented; always throws.</summary>
        /// <exception cref="NotImplementedException">Always.</exception>
        public void OnError(Exception error)
        {
            throw new NotImplementedException();
        }

        /// <summary>Not implemented; always throws.</summary>
        /// <exception cref="NotImplementedException">Always.</exception>
        public void OnCompleted()
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// Extracts the file name from the single "DataFile:&lt;file name&gt;" command line argument.
        /// Only the first ':' separates key from value, so a value that itself contains ':'
        /// (for example a rooted Windows path) is returned intact.
        /// </summary>
        /// <param name="arguments">The command line arguments passed to the driver.</param>
        /// <returns>The part of the DataFile argument after the first ':'.</returns>
        /// <exception cref="InvalidOperationException">
        /// There is not exactly one argument starting with "DataFile" (thrown by Single).
        /// </exception>
        /// <exception cref="ArgumentException">The argument has no ':' or an empty value.</exception>
        private static string GetDataFileName(ISet<string> arguments)
        {
            string argument = arguments.Single(a => a.StartsWith("DataFile", StringComparison.Ordinal));
            string[] parts = argument.Split(new[] { ':' }, 2);
            if (parts.Length != 2 || string.IsNullOrWhiteSpace(parts[1]))
            {
                throw new ArgumentException("Expected a command line argument of the form DataFile:<file name>, but got: " + argument);
            }
            return parts[1];
        }

        /// <summary>
        /// Generates the class hierarchy for the assemblies that contain the driver, task,
        /// KMeans example, name client and network service types.
        /// </summary>
        private void CreateClassHierarchy()
        {
            HashSet<string> clrDlls = new HashSet<string>();
            clrDlls.Add(typeof(IDriver).Assembly.GetName().Name);
            clrDlls.Add(typeof(ITask).Assembly.GetName().Name);
            clrDlls.Add(typeof(LegacyKMeansTask).Assembly.GetName().Name);
            clrDlls.Add(typeof(INameClient).Assembly.GetName().Name);
            clrDlls.Add(typeof(INetworkService<>).Assembly.GetName().Name);

            ClrHandlerHelper.GenerateClassHierarchy(clrDlls);
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansMasterTask.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansMasterTask.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansMasterTask.cs new file mode 100644 index 0000000..3dd7adb --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansMasterTask.cs @@ -0,0 +1,155 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using Org.Apache.REEF.Common.Tasks;
using Org.Apache.REEF.Network.Group.Operators;
using Org.Apache.REEF.Network.Group.Task;
using Org.Apache.REEF.Network.NetworkService;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Utilities.Logging;

namespace Org.Apache.REEF.Examples.MachineLearning.KMeans
{
    /// <summary>
    /// Master task of the KMeans example. It broadcasts the current centroids to the slave
    /// tasks, receives the reduced partial means and loss, and repeats until the loss stops
    /// changing.
    /// </summary>
    public class KMeansMasterTask : ITask
    {
        private static readonly Logger _logger = Logger.GetLogger(typeof(KMeansMasterTask));

        private readonly ICommunicationGroupClient _commGroup;
        private readonly IBroadcastSender<Centroids> _dataBroadcastSender;
        private readonly IBroadcastSender<ControlMessage> _controlBroadcastSender;
        private readonly IReduceReceiver<ProcessedResults> _meansReducerReceiver;
        private readonly string _kMeansExecutionDirectory;

        private int _iteration = 0;
        private Centroids _centroids;
        private bool _isInitialIteration;

        /// <summary>
        /// Looks up the broadcast senders and the reduce receiver of the KMeans communication group.
        /// </summary>
        /// <param name="totalNumEvaluators">Total number of evaluators; must be greater than 1.</param>
        /// <param name="executionDirectory">Directory that contains the initial centroids file.</param>
        /// <param name="mpiClient">Client used to obtain the communication group.</param>
        /// <exception cref="ArgumentException"><paramref name="totalNumEvaluators"/> is 1 or less.</exception>
        [Inject]
        public KMeansMasterTask(
            [Parameter(typeof(KMeansConfiguratioinOptions.TotalNumEvaluators))] int totalNumEvaluators,
            [Parameter(Value = typeof(KMeansConfiguratioinOptions.ExecutionDirectory))] string executionDirectory,
            IMpiClient mpiClient)
        {
            using (_logger.LogFunction("KMeansMasterTask"))
            {
                if (totalNumEvaluators <= 1)
                {
                    throw new ArgumentException("There must be more than 1 Evaluators in total, but the total evaluators number provided is " + totalNumEvaluators);
                }
                _commGroup = mpiClient.GetCommunicationGroup(Constants.KMeansCommunicationGroupName);
                _dataBroadcastSender = _commGroup.GetBroadcastSender<Centroids>(Constants.CentroidsBroadcastOperatorName);
                _meansReducerReceiver = _commGroup.GetReduceReceiver<ProcessedResults>(Constants.MeansReduceOperatorName);
                _controlBroadcastSender = _commGroup.GetBroadcastSender<ControlMessage>(Constants.ControlMessageBroadcastOperatorName);
                _kMeansExecutionDirectory = executionDirectory;
                _isInitialIteration = true;
            }
        }

        /// <summary>
        /// Runs the clustering loop: reads the initial centroids from the execution directory,
        /// then alternates between broadcasting centroids and reducing the slaves' results.
        /// A STOP control message is broadcast before the method returns or throws.
        /// </summary>
        /// <param name="memento">Not used.</param>
        /// <returns>Always null.</returns>
        /// <exception cref="InvalidOperationException">The reduced loss increased between iterations.</exception>
        public byte[] Call(byte[] memento)
        {
            // TODO: this belongs to dedicated dataloader layer, will refactor once we have that
            string centroidFile = Path.Combine(_kMeansExecutionDirectory, Constants.CentroidsFile);
            _centroids = new Centroids(DataPartitionCache.ReadDataFile(centroidFile));

            float loss = float.MaxValue;

            while (true)
            {
                if (_isInitialIteration)
                {
                    // broadcast initial centroids to all slave nodes
                    _logger.Log(Level.Info, "Broadcasting INITIAL centroids to all slave nodes: " + _centroids);
                    _isInitialIteration = false;
                }
                else
                {
                    ProcessedResults results = _meansReducerReceiver.Reduce();
                    _centroids = new Centroids(results.Means.Select(m => m.Mean).ToList());
                    _logger.Log(Level.Info, "Broadcasting new centroids to all slave nodes: " + _centroids);
                    float newLoss = results.Loss;
                    _logger.Log(Level.Info, string.Format(CultureInfo.InvariantCulture, "The new loss value {0} at iteration {1} ", newLoss, _iteration));
                    if (newLoss > loss)
                    {
                        // Release the slaves before failing so they do not wait for a broadcast forever.
                        _controlBroadcastSender.Send(ControlMessage.STOP);
                        throw new InvalidOperationException(
                            string.Format(CultureInfo.InvariantCulture, "The new loss {0} is larger than previous loss {1}, while loss function must be monotonically decreasing across iterations", newLoss, loss));
                    }
                    else if (newLoss.Equals(loss))
                    {
                        // NOTE(review): convergence relies on exact float equality of two reduced sums;
                        // confirm the reduce order is deterministic, otherwise consider a tolerance.
                        _logger.Log(Level.Info, string.Format(CultureInfo.InvariantCulture, "KMeans clustering has converged with a loss value of {0} at iteration {1} ", newLoss, _iteration));
                        break;
                    }
                    else
                    {
                        loss = newLoss;
                    }
                }
                _controlBroadcastSender.Send(ControlMessage.RECEIVE);
                _dataBroadcastSender.Send(_centroids);
                _iteration++;
            }
            _controlBroadcastSender.Send(ControlMessage.STOP);
            return null;
        }

        /// <summary>No resources are released here.</summary>
        public void Dispose()
        {
        }

        /// <summary>
        /// Reduce function that merges the per-partition results into one mean per cluster
        /// and sums the per-partition losses.
        /// </summary>
        public class AggregateMeans : IReduceFunction<ProcessedResults>
        {
            [Inject]
            public AggregateMeans()
            {
            }

            /// <summary>
            /// Combines the given results into a single result.
            /// </summary>
            /// <param name="elements">The results to combine.</param>
            /// <returns>
            /// One partial mean per label 0..max label, each carrying the total number of data points
            /// it was computed from, together with the sum of all losses.
            /// </returns>
            /// <exception cref="InvalidOperationException">No partial mean was supplied at all.</exception>
            /// <exception cref="ArgumentException">Some label below the maximum label has no partial mean.</exception>
            public ProcessedResults Reduce(IEnumerable<ProcessedResults> elements)
            {
                List<PartialMean> aggregatedMeans = new List<PartialMean>();
                List<PartialMean> totalList = new List<PartialMean>();
                float aggregatedLoss = 0;

                foreach (var element in elements)
                {
                    totalList.AddRange(element.Means);
                    aggregatedLoss += element.Loss;
                }

                // we infer the value of K from the labeled data
                int clustersNum = totalList.Max(p => p.Mean.Label) + 1;
                for (int i = 0; i < clustersNum; i++)
                {
                    List<PartialMean> means = totalList.Where(m => m.Mean.Label == i).ToList();

                    // The size of an aggregated mean is the number of data points behind it, not the
                    // number of partial means merged; otherwise combining this result again would
                    // weight it incorrectly.
                    int totalSize = means.Sum(m => m.Size);
                    aggregatedMeans.Add(new PartialMean(PartialMean.AggreatedMean(means), totalSize));
                }

                ProcessedResults returnMeans = new ProcessedResults(aggregatedMeans, aggregatedLoss);

                _logger.Log(Level.Info, "The true means aggregated by the reduce function: " + returnMeans);
                return returnMeans;
            }
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansSlaveTask.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansSlaveTask.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansSlaveTask.cs new file mode 100644 index 0000000..a36fbcb --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/KMeansSlaveTask.cs @@ -0,0 +1,118 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using System.Collections.Generic;
using System.Linq;
using Org.Apache.REEF.Common.Tasks;
using Org.Apache.REEF.Network.Group.Operators;
using Org.Apache.REEF.Network.Group.Task;
using Org.Apache.REEF.Network.NetworkService;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Utilities.Logging;

namespace Org.Apache.REEF.Examples.MachineLearning.KMeans
{
    /// <summary>
    /// Slave task of the KMeans example. For every set of centroids received from the master it
    /// labels its data partition, computes one partial mean per cluster plus the partition's
    /// loss, and sends them back through the reduce operator.
    /// </summary>
    public class KMeansSlaveTask : ITask
    {
        private static readonly Logger _logger = Logger.GetLogger(typeof(KMeansSlaveTask));
        private readonly int _clustersNum;
        private readonly IMpiClient _mpiClient;
        private readonly ICommunicationGroupClient _commGroup;
        private readonly IBroadcastReceiver<Centroids> _dataBroadcastReceiver;
        private readonly IBroadcastReceiver<ControlMessage> _controlBroadcastReceiver;
        private readonly IReduceSender<ProcessedResults> _partialMeansSender;
        private readonly DataPartitionCache _dataPartition;

        /// <summary>
        /// Looks up the broadcast receivers and the reduce sender of the KMeans communication group.
        /// </summary>
        /// <param name="dataPartition">The data partition this task works on.</param>
        /// <param name="clustersNumber">
        /// Number of clusters, injected from <see cref="KMeansConfiguratioinOptions.K"/>. It was
        /// previously bound to TotalNumEvaluators, which only equals K by coincidence
        /// (the driver uses 2 partitions + 1 master and K = 3).
        /// </param>
        /// <param name="mpiClient">Client used to obtain the communication group.</param>
        [Inject]
        public KMeansSlaveTask(
            DataPartitionCache dataPartition,
            [Parameter(typeof(KMeansConfiguratioinOptions.K))] int clustersNumber,
            IMpiClient mpiClient)
        {
            using (_logger.LogFunction("KMeansSlaveTask::KMeansSlaveTask"))
            {
                _dataPartition = dataPartition;
                _mpiClient = mpiClient;
                _clustersNum = clustersNumber;
                _commGroup = _mpiClient.GetCommunicationGroup(Constants.KMeansCommunicationGroupName);
                _dataBroadcastReceiver = _commGroup.GetBroadcastReceiver<Centroids>(Constants.CentroidsBroadcastOperatorName);
                _partialMeansSender = _commGroup.GetReduceSender<ProcessedResults>(Constants.MeansReduceOperatorName);
                _controlBroadcastReceiver = _commGroup.GetBroadcastReceiver<ControlMessage>(Constants.ControlMessageBroadcastOperatorName);
            }
        }

        /// <summary>
        /// Processes centroid broadcasts until the master broadcasts a STOP control message.
        /// </summary>
        /// <param name="memento">Not used.</param>
        /// <returns>Always null.</returns>
        public byte[] Call(byte[] memento)
        {
            while (true)
            {
                if (_controlBroadcastReceiver.Receive() == ControlMessage.STOP)
                {
                    break;
                }
                Centroids centroids = _dataBroadcastReceiver.Receive();

                // we compute the loss here before data is relabeled, this does not reflect the latest clustering result at the end of current iteration,
                // but it will save another round of group communications in each iteration
                _logger.Log(Level.Info, "Received centroids from master: " + centroids);
                _dataPartition.LabelData(centroids);
                ProcessedResults partialMeans = new ProcessedResults(ComputePartialMeans(), ComputeLossFunction(centroids, _dataPartition.DataVectors));
                _logger.Log(Level.Info, "Sending partial means: " + partialMeans);
                _partialMeansSender.Send(partialMeans);
            }

            return null;
        }

        /// <summary>Disposes the group communication client.</summary>
        public void Dispose()
        {
            _mpiClient.Dispose();
        }

        /// <summary>
        /// Computes, for each label 0..K-1, the mean of the partition's vectors carrying that
        /// label and the number of such vectors. A label with no vectors gets a freshly
        /// constructed vector of the partition's dimension and size 0.
        /// </summary>
        /// <returns>Exactly K partial means, ordered by label.</returns>
        private List<PartialMean> ComputePartialMeans()
        {
            _logger.Log(Level.Verbose, "cluster number " + _clustersNum);
            List<PartialMean> partialMeans = new List<PartialMean>(_clustersNum);
            for (int i = 0; i < _clustersNum; i++)
            {
                List<DataVector> slices = _dataPartition.DataVectors.Where(d => d.Label == i).ToList();
                DataVector average = new DataVector(_dataPartition.DataVectors[0].Dimension);

                // "> 0", not "> 1": a cluster holding exactly one vector must report that vector as
                // its mean; otherwise the placeholder vector would be sent with size 1 and skew the
                // aggregated centroid.
                if (slices.Count > 0)
                {
                    average = DataVector.Mean(slices);
                }
                average.Label = i;

                PartialMean partialMean = new PartialMean(average, slices.Count);
                partialMeans.Add(partialMean);
                _logger.Log(Level.Info, "Adding to partial means list: " + partialMean);
            }
            return partialMeans;
        }

        /// <summary>
        /// Sums, over all centroids, the distance from the centroid to the vectors that carry its label.
        /// </summary>
        /// <param name="centroids">The centroids received from the master.</param>
        /// <param name="labeledData">The partition's vectors after labeling.</param>
        /// <returns>The loss contribution of this partition.</returns>
        private float ComputeLossFunction(Centroids centroids, List<DataVector> labeledData)
        {
            float d = 0;
            for (int i = 0; i < centroids.Points.Count; i++)
            {
                DataVector centroid = centroids.Points[i];
                List<DataVector> slice = labeledData.Where(v => v.Label == centroid.Label).ToList();
                d += centroid.DistanceTo(slice);
            }
            return d;
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/LegacyKMeansTask.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/LegacyKMeansTask.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/LegacyKMeansTask.cs new file mode 100644 index 0000000..b674d84 --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/LegacyKMeansTask.cs @@ -0,0 +1,113 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Org.Apache.REEF.Tang.Annotations; + +namespace Org.Apache.REEF.Examples.MachineLearning.KMeans +{ + /// <summary> + /// This is the legacy KmeansTask implmented when group communications are not available + /// It is still being used for plain KMeans without REEF, we probably want to refactor it later + /// to reflect that + /// </summary> + public class LegacyKMeansTask + { + private int _clustersNum; + private DataPartitionCache _dataPartition; + private string _kMeansExecutionDirectory; + + private Centroids _centroids; + private List<PartialMean> _partialMeans; + + [Inject] + public LegacyKMeansTask( + DataPartitionCache dataPartition, + [Parameter(Value = typeof(KMeansConfiguratioinOptions.K))] int clustersNumber, + [Parameter(Value = typeof(KMeansConfiguratioinOptions.ExecutionDirectory))] string executionDirectory) + { + _dataPartition = dataPartition; + _clustersNum = clustersNumber; + _kMeansExecutionDirectory = executionDirectory; + if (_centroids == null) + { + string centroidFile = Path.Combine(_kMeansExecutionDirectory, Constants.CentroidsFile); + _centroids = new Centroids(DataPartitionCache.ReadDataFile(centroidFile)); + } + } + + public static float ComputeLossFunction(List<DataVector> centroids, List<DataVector> labeledData) + { + float d = 0; + for (int i = 0; i < centroids.Count; i++) + { + DataVector centroid = centroids[i]; + List<DataVector> slice = labeledData.Where(v => v.Label == centroid.Label).ToList(); + d += centroid.DistanceTo(slice); + } + return d; + } + + public byte[] CallWithWritingToFileSystem(byte[] memento) + { + string centroidFile = Path.Combine(_kMeansExecutionDirectory, Constants.CentroidsFile); + _centroids = new Centroids(DataPartitionCache.ReadDataFile(centroidFile)); + + _dataPartition.LabelData(_centroids); + _partialMeans = ComputePartialMeans(); + + // should be replaced with MPI + using (StreamWriter writer = new StreamWriter( + 
File.OpenWrite(Path.Combine(_kMeansExecutionDirectory, Constants.DataDirectory, Constants.PartialMeanFilePrefix + _dataPartition.Partition)))) + { + for (int i = 0; i < _partialMeans.Count; i++) + { + writer.WriteLine(_partialMeans[i].ToString()); + } + writer.Close(); + } + + return null; + } + + public List<PartialMean> ComputePartialMeans() + { + List<PartialMean> partialMeans = new PartialMean[_clustersNum].ToList(); + for (int i = 0; i < _clustersNum; i++) + { + List<DataVector> slices = _dataPartition.DataVectors.Where(d => d.Label == i).ToList(); + DataVector average = new DataVector(_dataPartition.DataVectors[0].Dimension); + + if (slices.Count > 1) + { + average = DataVector.Mean(slices); + } + average.Label = i; + partialMeans[i] = new PartialMean(average, slices.Count); + } + return partialMeans; + } + + public void Dispose() + { + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/PartialMean.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/PartialMean.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/PartialMean.cs new file mode 100644 index 0000000..6f44167 --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/PartialMean.cs @@ -0,0 +1,124 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;

namespace Org.Apache.REEF.Examples.MachineLearning.KMeans
{
    /// <summary>
    /// A mean vector together with the number of data points it was computed from.
    /// The string form is "&lt;vector&gt;#&lt;size&gt;".
    /// </summary>
    public class PartialMean
    {
        public PartialMean(DataVector vector, int size)
        {
            Mean = vector;
            Size = size;
        }

        public PartialMean()
        {
        }

        /// <summary>The mean vector; its label identifies the cluster.</summary>
        public DataVector Mean { get; set; }

        /// <summary>Number of data points behind <see cref="Mean"/>.</summary>
        public int Size { get; set; }

        /// <summary>
        /// Parses the "&lt;vector&gt;#&lt;size&gt;" form produced by <see cref="ToString"/>.
        /// </summary>
        /// <param name="str">The serialized partial mean.</param>
        /// <returns>The deserialized partial mean.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="str"/> is null, blank, or does not consist of exactly two '#'-separated parts.
        /// </exception>
        public static PartialMean FromString(string str)
        {
            if (string.IsNullOrWhiteSpace(str))
            {
                throw new ArgumentException("Cannot deserialize PartialMean from a null or blank string", "str");
            }

            // Split never returns null, so only the number of parts needs checking.
            string[] parts = str.Split('#');
            if (parts.Length != 2)
            {
                throw new ArgumentException("Cannot deserialize PartialMean from string " + str);
            }
            return new PartialMean(DataVector.FromString(parts[0]), int.Parse(parts[1], CultureInfo.InvariantCulture));
        }

        /// <summary>
        /// Combines partial means that share a label into a single mean, weighting each by its size.
        /// </summary>
        /// <param name="means">The partial means to combine; must contain at least one element.</param>
        /// <returns>The combined mean vector.</returns>
        /// <exception cref="ArgumentException">
        /// <paramref name="means"/> is null or empty, or the elements carry different labels.
        /// </exception>
        public static DataVector AggreatedMean(List<PartialMean> means)
        {
            if (means == null || means.Count == 0)
            {
                throw new ArgumentException("Cannot aggregate a null or empty list of partial means", "means");
            }
            PartialMean mean = means[0];
            for (int i = 1; i < means.Count; i++)
            {
                mean = mean.CombinePartialMean(means[i]);
            }
            return mean.Mean;
        }

        /// <summary>
        /// Reads the partial-mean file of every partition from the execution directory and
        /// combines the partial means per cluster label.
        /// </summary>
        /// <param name="partitionsNum">Number of partitions, i.e. number of files to read.</param>
        /// <param name="clustersNum">Number of clusters; only the first this-many lines of each file are used.</param>
        /// <param name="executionDirectory">Directory that contains the data directory.</param>
        /// <returns>One combined mean per label 0..clustersNum-1.</returns>
        /// <exception cref="ArgumentException">A label has no partial mean in any file.</exception>
        public static List<DataVector> AggregateTrueMeansToFileSystem(int partitionsNum, int clustersNum, string executionDirectory)
        {
            List<PartialMean> partialMeans = new List<PartialMean>();
            for (int i = 0; i < partitionsNum; i++)
            {
                // should be replaced with MPI
                string path = Path.Combine(executionDirectory, Constants.DataDirectory, Constants.PartialMeanFilePrefix + i.ToString(CultureInfo.InvariantCulture));

                // Both the stream and the reader are owned by using blocks so the file handle is
                // released even if constructing the reader fails.
                using (FileStream file = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read))
                using (StreamReader reader = new StreamReader(file))
                {
                    // Lines beyond the first clustersNum are ignored, so stop reading there.
                    for (int index = 0; index < clustersNum && !reader.EndOfStream; index++)
                    {
                        partialMeans.Add(PartialMean.FromString(reader.ReadLine()));
                    }
                }
            }
            List<DataVector> newCentroids = new List<DataVector>();
            for (int i = 0; i < clustersNum; i++)
            {
                List<PartialMean> means = partialMeans.Where(m => m.Mean.Label == i).ToList();
                newCentroids.Add(PartialMean.AggreatedMean(means));
            }
            return newCentroids;
        }

        public override string ToString()
        {
            return Mean.ToString() + "#" + Size;
        }

        /// <summary>
        /// Returns the size-weighted combination of this partial mean and <paramref name="other"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
        /// <exception cref="ArgumentException">The two partial means carry different labels.</exception>
        private PartialMean CombinePartialMean(PartialMean other)
        {
            // Validate before allocating the result.
            if (other == null)
            {
                throw new ArgumentNullException("other");
            }
            if (Mean.Label != other.Mean.Label)
            {
                throw new ArgumentException("cannot combine partial means with different labels");
            }

            PartialMean aggreatedMean = new PartialMean();
            aggreatedMean.Size = Size + other.Size;
            aggreatedMean.Mean = (Mean.MultiplyScalar(Size).Add(other.Mean.MultiplyScalar(other.Size))).Normalize(aggreatedMean.Size);
            return aggreatedMean;
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/ProcessedResults.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/ProcessedResults.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/ProcessedResults.cs new file mode 100644 index 0000000..3e3394f --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/ProcessedResults.cs @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using System.Collections.Generic;
using Org.Apache.REEF.Examples.MachineLearning.KMeans.codecs;
using Org.Apache.REEF.Utilities;

namespace Org.Apache.REEF.Examples.MachineLearning.KMeans
{
    /// <summary>
    /// A list of partial means together with a partial loss. The same type is also used for
    /// the "whole" means aggregated from all partial means.
    /// </summary>
    public class ProcessedResults
    {
        public ProcessedResults(List<PartialMean> means, float loss)
        {
            Loss = loss;
            Means = means;
        }

        /// <summary>
        /// The means carried by this result, one entry per cluster.
        /// </summary>
        public List<PartialMean> Means { get; set; }

        /// <summary>
        /// The loss computed for the current slice of data.
        /// </summary>
        public float Loss { get; set; }

        /// <summary>
        /// Helper mostly used for logging: encodes this instance with
        /// <see cref="ProcessedResultsCodec"/> and renders the bytes as a string.
        /// </summary>
        /// <returns>The serialized string.</returns>
        public override string ToString()
        {
            byte[] encoded = new ProcessedResultsCodec().Encode(this);
            return ByteUtilities.ByteArrarysToString(encoded);
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/codecs/CentroidsCodec.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/codecs/CentroidsCodec.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/codecs/CentroidsCodec.cs new file mode 100644 index 
0000000..19480b6 --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/codecs/CentroidsCodec.cs @@ -0,0 +1,49 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using Org.Apache.REEF.Examples.MachineLearning.KMeans.Contracts;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Utilities;
using Org.Apache.REEF.Wake.Remote;

namespace Org.Apache.REEF.Examples.MachineLearning.KMeans.codecs
{
    /// <summary>
    /// Codec that converts Centroids objects to and from bytes by way of
    /// <see cref="CentroidsContract"/> and the Avro helpers in AvroUtils.
    /// </summary>
    public class CentroidsCodec : ICodec<Centroids>
    {
        [Inject]
        public CentroidsCodec()
        {
        }

        /// <summary>Serializes the given centroids.</summary>
        /// <param name="centroids">The centroids to serialize.</param>
        /// <returns>The Avro-serialized contract.</returns>
        public byte[] Encode(Centroids centroids)
        {
            return AvroUtils.AvroSerialize(CentroidsContract.Create(centroids));
        }

        /// <summary>Deserializes centroids from the given bytes.</summary>
        /// <param name="data">Bytes produced by <see cref="Encode"/>.</param>
        /// <returns>The deserialized centroids.</returns>
        public Centroids Decode(byte[] data)
        {
            return AvroUtils.AvroDeserialize<CentroidsContract>(data).ToCentroids();
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/codecs/DataVectorCodec.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/codecs/DataVectorCodec.cs b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/codecs/DataVectorCodec.cs new file mode 100644 index 0000000..820162e --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/codecs/DataVectorCodec.cs @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
// File: lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/codecs/DataVectorCodec.cs
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using Org.Apache.REEF.Examples.MachineLearning.KMeans.Contracts;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Utilities;
using Org.Apache.REEF.Wake.Remote;

namespace Org.Apache.REEF.Examples.MachineLearning.KMeans.codecs
{
    /// <summary>
    /// ICodec implementation that turns DataVector objects into byte arrays and back.
    /// Both directions go through the DataVectorContract type and the AvroUtils helpers.
    /// </summary>
    public class DataVectorCodec : ICodec<DataVector>
    {
        /// <summary>
        /// Parameterless constructor, marked with [Inject].
        /// </summary>
        [Inject]
        public DataVectorCodec()
        {
        }

        /// <summary>
        /// Wraps the vector in a DataVectorContract and serializes that contract
        /// with AvroUtils.AvroSerialize.
        /// </summary>
        /// <param name="obj">The DataVector to encode.</param>
        /// <returns>The bytes produced by AvroUtils.AvroSerialize for the contract.</returns>
        public byte[] Encode(DataVector obj)
        {
            DataVectorContract wireForm = DataVectorContract.Create(obj);
            byte[] payload = AvroUtils.AvroSerialize(wireForm);
            return payload;
        }

        /// <summary>
        /// Deserializes the bytes into a DataVectorContract with AvroUtils.AvroDeserialize
        /// and converts the contract back with ToDataVector.
        /// </summary>
        /// <param name="data">The bytes to decode.</param>
        /// <returns>The DataVector returned by DataVectorContract.ToDataVector.</returns>
        public DataVector Decode(byte[] data)
        {
            return AvroUtils.AvroDeserialize<DataVectorContract>(data).ToDataVector();
        }
    }
}
// File: lang/cs/Org.Apache.REEF.Examples/MachineLearning/KMeans/codecs/ProcessedResultsCodec.cs
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Utilities;
using Org.Apache.REEF.Wake.Remote;

namespace Org.Apache.REEF.Examples.MachineLearning.KMeans.codecs
{
    /// <summary>
    /// ICodec implementation for ProcessedResults that uses a plain-text layout:
    /// the loss, a '+', then the string forms of the partial means joined by '@'.
    /// TODO: use proper avro schema to do encode/decode
    /// </summary>
    public class ProcessedResultsCodec : ICodec<ProcessedResults>
    {
        /// <summary>
        /// Parameterless constructor, marked with [Inject].
        /// </summary>
        [Inject]
        public ProcessedResultsCodec()
        {
        }

        /// <summary>
        /// Builds the text "loss+mean@mean@..." and converts it to bytes with
        /// ByteUtilities.StringToByteArrays. Each mean is rendered with its ToString().
        /// </summary>
        /// <param name="results">The results to encode.</param>
        /// <returns>The byte form of the text layout described above.</returns>
        public byte[] Encode(ProcessedResults results)
        {
            // Decode parses the loss with the invariant culture, so it has to be written
            // with the invariant culture too; otherwise a culture that uses ',' as the
            // decimal separator produces text that Decode misreads.
            string lossText = Convert.ToString(results.Loss, CultureInfo.InvariantCulture);
            string meansText = string.Join("@", results.Means.Select(m => m.ToString()));

            // NOTE(review): a '+' inside the loss or inside a mean (for example exponent
            // notation such as "1E+10") still collides with the field separator and makes
            // Decode reject the payload. Fixing that needs a format change (see TODO above).
            return ByteUtilities.StringToByteArrays(lossText + "+" + meansText);
        }

        /// <summary>
        /// Parses the text layout written by Encode back into a ProcessedResults.
        /// </summary>
        /// <param name="data">The bytes to decode.</param>
        /// <returns>A ProcessedResults built from the parsed means and loss.</returns>
        /// <exception cref="ArgumentException">
        /// The decoded text does not consist of exactly two '+'-separated parts.
        /// </exception>
        public ProcessedResults Decode(byte[] data)
        {
            // Convert once and reuse the text for both parsing and the error message.
            string text = ByteUtilities.ByteArrarysToString(data);
            string[] parts = text.Split('+');
            if (parts.Length != 2)
            {
                throw new ArgumentException("cannot deserialize from " + text);
            }

            float loss = float.Parse(parts[0], CultureInfo.InvariantCulture);

            // Encode writes an empty means section for an empty Means collection;
            // map that back to an empty list instead of parsing "" as a PartialMean.
            List<PartialMean> means = parts[1].Length == 0
                ? new List<PartialMean>()
                : parts[1].Split('@').Select(PartialMean.FromString).ToList();

            return new ProcessedResults(means, loss);
        }
    }
}
<ItemGroup> <Reference Include="System" /> <Reference Include="System.Core" /> + <Reference Include="System.Runtime.Serialization" /> <Reference Include="System.Xml.Linq" /> <Reference Include="System.Data.DataSetExtensions" /> <Reference Include="Microsoft.CSharp" /> @@ -95,6 +96,24 @@ under the License. <Compile Include="HelloCLRBridge\Handlers\HelloStartHandler.cs" /> <Compile Include="HelloCLRBridge\Handlers\HelloTaskMessageHandler.cs" /> <Compile Include="HelloCLRBridge\HelloTraceListener.cs" /> + <Compile Include="MachineLearning\KMeans\Centroids.cs" /> + <Compile Include="MachineLearning\KMeans\codecs\CentroidsCodec.cs" /> + <Compile Include="MachineLearning\KMeans\codecs\DataVectorCodec.cs" /> + <Compile Include="MachineLearning\KMeans\codecs\ProcessedResultsCodec.cs" /> + <Compile Include="MachineLearning\KMeans\Constants.cs" /> + <Compile Include="MachineLearning\KMeans\Contracts\CentroidsContract.cs" /> + <Compile Include="MachineLearning\KMeans\Contracts\DataVectorContract.cs" /> + <Compile Include="MachineLearning\KMeans\Contracts\PartialMeanContract.cs" /> + <Compile Include="MachineLearning\KMeans\Contracts\ProcessedResultsContract.cs" /> + <Compile Include="MachineLearning\KMeans\DataPartitionCache.cs" /> + <Compile Include="MachineLearning\KMeans\DataVector.cs" /> + <Compile Include="MachineLearning\KMeans\KMeansConfiguratioinOptions.cs" /> + <Compile Include="MachineLearning\KMeans\KMeansDriverHandlers.cs" /> + <Compile Include="MachineLearning\KMeans\KMeansMasterTask.cs" /> + <Compile Include="MachineLearning\KMeans\KMeansSlaveTask.cs" /> + <Compile Include="MachineLearning\KMeans\LegacyKMeansTask.cs" /> + <Compile Include="MachineLearning\KMeans\PartialMean.cs" /> + <Compile Include="MachineLearning\KMeans\ProcessedResults.cs" /> <Compile Include="RetainedEvalCLRBridge\Handlers\RetainedEvalActiveContextHandler.cs" /> <Compile Include="RetainedEvalCLRBridge\Handlers\RetainedEvalAllocatedEvaluatorHandler.cs" /> <Compile 
Include="RetainedEvalCLRBridge\Handlers\RetainedEvalEvaluatorRequestorHandler.cs" /> http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/0292caf1/lang/cs/Org.Apache.REEF.Network/Group/Codec/GcmMessageProto.cs ---------------------------------------------------------------------- diff --git a/lang/cs/Org.Apache.REEF.Network/Group/Codec/GcmMessageProto.cs b/lang/cs/Org.Apache.REEF.Network/Group/Codec/GcmMessageProto.cs new file mode 100644 index 0000000..8a5a726 --- /dev/null +++ b/lang/cs/Org.Apache.REEF.Network/Group/Codec/GcmMessageProto.cs @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
// File: lang/cs/Org.Apache.REEF.Network/Group/Codec/GcmMessageProto.cs
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Org.Apache.REEF.Network.Group.Driver;
using Org.Apache.REEF.Network.Group.Driver.Impl;
using ProtoBuf;

namespace Org.Apache.REEF.Network.Group.Codec
{
    /// <summary>
    /// [ProtoContract] counterpart of GroupCommunicationMessage. It carries the same
    /// six values as numbered proto members and converts to and from the message type.
    /// </summary>
    [ProtoContract]
    public class GcmMessageProto
    {
        /// <summary>Proto member 1; mirrors GroupCommunicationMessage.Data.</summary>
        [ProtoMember(1)]
        public byte[][] Data { get; set; }

        /// <summary>Proto member 2; mirrors GroupCommunicationMessage.OperatorName.</summary>
        [ProtoMember(2)]
        public string OperatorName { get; set; }

        /// <summary>Proto member 3; mirrors GroupCommunicationMessage.GroupName.</summary>
        [ProtoMember(3)]
        public string GroupName { get; set; }

        /// <summary>Proto member 4; mirrors GroupCommunicationMessage.Source.</summary>
        [ProtoMember(4)]
        public string Source { get; set; }

        /// <summary>Proto member 5; mirrors GroupCommunicationMessage.Destination.</summary>
        [ProtoMember(5)]
        public string Destination { get; set; }

        /// <summary>Proto member 6; mirrors GroupCommunicationMessage.MsgType.</summary>
        [ProtoMember(6)]
        public MessageType MsgType { get; set; }

        /// <summary>
        /// Copies every field of a GroupCommunicationMessage into a new GcmMessageProto.
        /// </summary>
        /// <param name="gcm">The message to copy from.</param>
        /// <returns>A new GcmMessageProto holding the message's values.</returns>
        public static GcmMessageProto Create(GroupCommunicationMessage gcm)
        {
            GcmMessageProto proto = new GcmMessageProto();
            proto.Data = gcm.Data;
            proto.OperatorName = gcm.OperatorName;
            proto.GroupName = gcm.GroupName;
            proto.Source = gcm.Source;
            proto.Destination = gcm.Destination;
            proto.MsgType = gcm.MsgType;
            return proto;
        }

        /// <summary>
        /// Builds a GroupCommunicationMessage from this instance's values.
        /// </summary>
        /// <returns>A new GroupCommunicationMessage.</returns>
        public GroupCommunicationMessage ToGcm()
        {
            // Constructor argument order: group, operator, source, destination, data, type.
            return new GroupCommunicationMessage(
                this.GroupName, this.OperatorName, this.Source, this.Destination, this.Data, this.MsgType);
        }
    }
}
// File: lang/cs/Org.Apache.REEF.Network/Group/Codec/GroupCommunicationMessageCodec.cs
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Org.Apache.REEF.Network.Group.Codec;
using Org.Apache.REEF.Network.Group.Driver;
using Org.Apache.REEF.Network.Group.Driver.Impl;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Wake.Remote;
using ProtoBuf;

namespace Org.Apache.REEF.Network.Group.Codec
{
    /// <summary>
    /// ICodec implementation for GroupCommunicationMessage. Messages are mapped to
    /// GcmMessageProto and run through the ProtoBuf Serializer.
    /// </summary>
    public class GroupCommunicationMessageCodec : ICodec<GroupCommunicationMessage>
    {
        /// <summary>
        /// Parameterless constructor, marked with [Inject].
        /// </summary>
        [Inject]
        public GroupCommunicationMessageCodec()
        {
        }

        /// <summary>
        /// Converts the message to a GcmMessageProto and serializes it with Protobuf
        /// into an in-memory stream.
        /// </summary>
        /// <param name="obj">The message to serialize.</param>
        /// <returns>The contents of the stream after serialization.</returns>
        public byte[] Encode(GroupCommunicationMessage obj)
        {
            GcmMessageProto wireMessage = GcmMessageProto.Create(obj);
            using (MemoryStream buffer = new MemoryStream())
            {
                Serializer.Serialize<GcmMessageProto>(buffer, wireMessage);
                byte[] payload = buffer.ToArray();
                return payload;
            }
        }

        /// <summary>
        /// Reads a GcmMessageProto from the bytes with Protobuf and converts it back
        /// to a GroupCommunicationMessage.
        /// </summary>
        /// <param name="data">The bytes to deserialize.</param>
        /// <returns>The message produced by GcmMessageProto.ToGcm.</returns>
        public GroupCommunicationMessage Decode(byte[] data)
        {
            using (MemoryStream input = new MemoryStream(data))
            {
                return Serializer.Deserialize<GcmMessageProto>(input).ToGcm();
            }
        }
    }
}
