http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/WeightedLogisticLossFunction.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/WeightedLogisticLossFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/WeightedLogisticLossFunction.java new file mode 100644 index 0000000..2ad11fc --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/WeightedLogisticLossFunction.java @@ -0,0 +1,74 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.loss; + +import javax.inject.Inject; + +public final class WeightedLogisticLossFunction implements LossFunction { + + private static final double POS = 0.0025; + private static final double NEG = 0.9975; + + private final double posWeight; + private final double negWeight; + + /** + * Trivial constructor. 
+ */ + @Inject + public WeightedLogisticLossFunction() { + this.posWeight = (this.POS + this.NEG) / (2 * this.POS); + this.negWeight = (this.POS + this.NEG) / (2 * this.NEG); + } + + @Override + public double computeLoss(double y, double f) { + + final double predictedTimesLabel = y * f; + final double weight = y == -1 ? this.negWeight : this.posWeight; + + if (predictedTimesLabel >= 0) { + return weight * Math.log(1 + Math.exp(-predictedTimesLabel)); + } else { + return weight * (-predictedTimesLabel + Math.log(1 + Math.exp(predictedTimesLabel))); + } + } + + @Override + public double computeGradient(double y, double f) { + + final double predictedTimesLabel = y * f; + final double weight = y == -1 ? this.negWeight : this.posWeight; + + final double probability; + if (predictedTimesLabel >= 0) { + probability = 1 / (1 + Math.exp(-predictedTimesLabel)); + } else { + final double ExpVal = Math.exp(predictedTimesLabel); + probability = ExpVal / (1 + ExpVal); + } + + return (probability - 1) * y * weight; + } + + @Override + public String toString() { + return "WeightedLogisticLossFunction{}"; + } +}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ControlMessageBroadcaster.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ControlMessageBroadcaster.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ControlMessageBroadcaster.java new file mode 100644 index 0000000..d7013fb --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ControlMessageBroadcaster.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.operatornames; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * Used to identify the broadcast operator for control flow messages. 
+ */ +@NamedParameter() +public final class ControlMessageBroadcaster implements Name<String> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/DescentDirectionBroadcaster.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/DescentDirectionBroadcaster.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/DescentDirectionBroadcaster.java new file mode 100644 index 0000000..1d8a148 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/DescentDirectionBroadcaster.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.operatornames; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * Name of the broadcast operator used to send descent directions during linesearch. 
+ */ +@NamedParameter() +public final class DescentDirectionBroadcaster implements Name<String> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/LineSearchEvaluationsReducer.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/LineSearchEvaluationsReducer.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/LineSearchEvaluationsReducer.java new file mode 100644 index 0000000..411fd17 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/LineSearchEvaluationsReducer.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.operatornames; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * Name of the reducer used to aggregate line search results. 
+ */ +@NamedParameter() +public final class LineSearchEvaluationsReducer implements Name<String> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/LossAndGradientReducer.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/LossAndGradientReducer.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/LossAndGradientReducer.java new file mode 100644 index 0000000..91261ce --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/LossAndGradientReducer.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.operatornames; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * Name used for the Reduce operator for loss and gradient aggregation. 
+ */ +@NamedParameter() +public final class LossAndGradientReducer implements Name<String> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/MinEtaBroadcaster.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/MinEtaBroadcaster.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/MinEtaBroadcaster.java new file mode 100644 index 0000000..eb230d2 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/MinEtaBroadcaster.java @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.operatornames; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +@NamedParameter() +public final class MinEtaBroadcaster implements Name<String> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ModelAndDescentDirectionBroadcaster.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ModelAndDescentDirectionBroadcaster.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ModelAndDescentDirectionBroadcaster.java new file mode 100644 index 0000000..755cc00 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ModelAndDescentDirectionBroadcaster.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.operatornames; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * Name of the broadcast operator used to send a model and descent direction during line search. + */ +@NamedParameter() +public final class ModelAndDescentDirectionBroadcaster implements Name<String> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ModelBroadcaster.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ModelBroadcaster.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ModelBroadcaster.java new file mode 100644 index 0000000..8a1aae0 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/ModelBroadcaster.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.operatornames; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * The name of the broadcast operator used for model broadcasts. + */ +@NamedParameter() +public final class ModelBroadcaster implements Name<String> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/package-info.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/package-info.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/package-info.java new file mode 100644 index 0000000..5c364cc --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/operatornames/package-info.java @@ -0,0 +1,23 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Parameter names used to identify the various operators used in BGD. 
+ */ +package org.apache.reef.examples.group.bgd.operatornames; http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/AllCommunicationGroup.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/AllCommunicationGroup.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/AllCommunicationGroup.java new file mode 100644 index 0000000..dd173a3 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/AllCommunicationGroup.java @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +@NamedParameter() +public final class AllCommunicationGroup implements Name<String> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/BGDControlParameters.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/BGDControlParameters.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/BGDControlParameters.java new file mode 100644 index 0000000..ab60d16 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/BGDControlParameters.java @@ -0,0 +1,126 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.examples.group.bgd.loss.LossFunction; +import org.apache.reef.tang.Configuration; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.annotations.Parameter; +import org.apache.reef.tang.formats.CommandLine; + +import javax.inject.Inject; + +public final class BGDControlParameters { + + private final int dimensions; + private final double lambda; + private final double eps; + private final int iters; + private final int minParts; + private final boolean rampup; + + private final double eta; + private final double probOfSuccessfulIteration; + private final BGDLossType lossType; + + @Inject + public BGDControlParameters( + final @Parameter(ModelDimensions.class) int dimensions, + final @Parameter(Lambda.class) double lambda, + final @Parameter(Eps.class) double eps, + final @Parameter(Eta.class) double eta, + final @Parameter(ProbabilityOfSuccesfulIteration.class) double probOfSuccessfulIteration, + final @Parameter(Iterations.class) int iters, + final @Parameter(EnableRampup.class) boolean rampup, + final @Parameter(MinParts.class) int minParts, + final BGDLossType lossType) { + this.dimensions = dimensions; + this.lambda = lambda; + this.eps = eps; + this.eta = eta; + this.probOfSuccessfulIteration = probOfSuccessfulIteration; + this.iters = iters; + this.rampup = rampup; + this.minParts = minParts; + this.lossType = lossType; + } + + public Configuration getConfiguration() { + return Tang.Factory.getTang().newConfigurationBuilder() + .bindNamedParameter(ModelDimensions.class, Integer.toString(this.dimensions)) + .bindNamedParameter(Lambda.class, Double.toString(this.lambda)) + .bindNamedParameter(Eps.class, Double.toString(this.eps)) + .bindNamedParameter(Eta.class, Double.toString(this.eta)) + .bindNamedParameter(ProbabilityOfSuccesfulIteration.class, Double.toString(probOfSuccessfulIteration)) + .bindNamedParameter(Iterations.class, Integer.toString(this.iters)) + 
.bindNamedParameter(EnableRampup.class, Boolean.toString(this.rampup)) + .bindNamedParameter(MinParts.class, Integer.toString(this.minParts)) + .bindNamedParameter(LossFunctionType.class, lossType.lossFunctionString()) + .build(); + } + + public static CommandLine registerShortNames(final CommandLine commandLine) { + return commandLine + .registerShortNameOfClass(ModelDimensions.class) + .registerShortNameOfClass(Lambda.class) + .registerShortNameOfClass(Eps.class) + .registerShortNameOfClass(Eta.class) + .registerShortNameOfClass(ProbabilityOfSuccesfulIteration.class) + .registerShortNameOfClass(Iterations.class) + .registerShortNameOfClass(EnableRampup.class) + .registerShortNameOfClass(MinParts.class) + .registerShortNameOfClass(LossFunctionType.class); + } + + public int getDimensions() { + return this.dimensions; + } + + public double getLambda() { + return this.lambda; + } + + public double getEps() { + return this.eps; + } + + public double getEta() { + return this.eta; + } + + public double getProbOfSuccessfulIteration() { + return probOfSuccessfulIteration; + } + + public int getIters() { + return this.iters; + } + + public int getMinParts() { + return this.minParts; + } + + public boolean isRampup() { + return this.rampup; + } + + public Class<? 
extends LossFunction> getLossFunction() { + return this.lossType.getLossFunction(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/BGDLossType.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/BGDLossType.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/BGDLossType.java new file mode 100644 index 0000000..e20dcaf --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/BGDLossType.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.examples.group.bgd.loss.LogisticLossFunction; +import org.apache.reef.examples.group.bgd.loss.LossFunction; +import org.apache.reef.examples.group.bgd.loss.SquaredErrorLossFunction; +import org.apache.reef.examples.group.bgd.loss.WeightedLogisticLossFunction; +import org.apache.reef.tang.annotations.Parameter; + +import javax.inject.Inject; +import java.util.HashMap; +import java.util.Map; + +public class BGDLossType { + + private static final Map<String, Class<? extends LossFunction>> LOSS_FUNCTIONS = + new HashMap<String, Class<? extends LossFunction>>() {{ + put("logLoss", LogisticLossFunction.class); + put("weightedLogLoss", WeightedLogisticLossFunction.class); + put("squaredError", SquaredErrorLossFunction.class); + }}; + + private final Class<? extends LossFunction> lossFunction; + + private final String lossFunctionStr; + + @Inject + public BGDLossType(@Parameter(LossFunctionType.class) final String lossFunctionStr) { + this.lossFunctionStr = lossFunctionStr; + this.lossFunction = LOSS_FUNCTIONS.get(lossFunctionStr); + if (this.lossFunction == null) { + throw new RuntimeException("Specified loss function type: " + lossFunctionStr + + " is not implemented. Supported types are logLoss|weightedLogLoss|squaredError"); + } + } + + public Class<? 
extends LossFunction> getLossFunction() { + return this.lossFunction; + } + + public String lossFunctionString() { + return lossFunctionStr; + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/EnableRampup.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/EnableRampup.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/EnableRampup.java new file mode 100644 index 0000000..cc388a9 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/EnableRampup.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * Whether ramp-up is enabled. 
+ */ +@NamedParameter(doc = "Should we ram-up?", short_name = "rampup", default_value = "false") +public final class EnableRampup implements Name<Boolean> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Eps.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Eps.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Eps.java new file mode 100644 index 0000000..05d7b84 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Eps.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * Break criterion for the optimizer. If the progress in mean loss between + * two iterations is less than this, the optimization stops. 
+ */ +@NamedParameter(short_name = "eps", default_value = "1e-6") +public final class Eps implements Name<Double> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Eta.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Eta.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Eta.java new file mode 100644 index 0000000..59ef312 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Eta.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * Break criterion for the optimizer. If the progress in mean loss between + * two iterations is less than this, the optimization stops. NOTE(review): this text is identical to the Eps description and may have been copied; confirm what eta controls. 
+ */ +@NamedParameter(short_name = "eta", default_value = "0.01") +public final class Eta implements Name<Double> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/EvaluatorMemory.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/EvaluatorMemory.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/EvaluatorMemory.java new file mode 100644 index 0000000..7b1015e --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/EvaluatorMemory.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * The memory used for each Evaluator. In MB. + */ +@NamedParameter(short_name = "memory", default_value = "1024", doc = "The memory used for each Evaluator. 
In MB.") +public final class EvaluatorMemory implements Name<Integer> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/InputDir.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/InputDir.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/InputDir.java new file mode 100644 index 0000000..9ad5656 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/InputDir.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * The input folder of the learner. 
+ */ +@NamedParameter(short_name = "input") +public final class InputDir implements Name<String> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Iterations.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Iterations.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Iterations.java new file mode 100644 index 0000000..79530be --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Iterations.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * Maximum Number of Iterations. 
+ */ +@NamedParameter(doc = "Number of iterations", short_name = "iterations", default_value = "100") +public final class Iterations implements Name<Integer> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Lambda.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Lambda.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Lambda.java new file mode 100644 index 0000000..c278ca9 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Lambda.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * The regularization constant + */ +@NamedParameter(doc = "The regularization constant", short_name = "lambda", default_value = "1e-4") +public final class Lambda implements Name<Double> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/LossFunctionType.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/LossFunctionType.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/LossFunctionType.java new file mode 100644 index 0000000..e740c95 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/LossFunctionType.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * The loss function to be used: logLoss, weightedLogLoss or squaredError. + */ +@NamedParameter(doc = "Loss Function to be used: logLoss|weightedLogLoss|squaredError", short_name = "loss", default_value = "logLoss") +public class LossFunctionType implements Name<String> { + +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/MinParts.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/MinParts.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/MinParts.java new file mode 100644 index 0000000..6488c56 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/MinParts.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * Minimum number of partitions. + */ +@NamedParameter(doc = "Min Number of partitions", short_name = "minparts", default_value = "2") +public final class MinParts implements Name<Integer> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ModelDimensions.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ModelDimensions.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ModelDimensions.java new file mode 100644 index 0000000..cd19085 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ModelDimensions.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * The dimensionality of the model learned. + */ +@NamedParameter(doc = "Model dimensions", short_name = "dim") +public class ModelDimensions implements Name<Integer> { + +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/NumSplits.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/NumSplits.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/NumSplits.java new file mode 100644 index 0000000..dbbbead --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/NumSplits.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * + */ +// TODO: Document +@NamedParameter(short_name = "splits", default_value = "5") +public final class NumSplits implements Name<Integer> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/NumberOfReceivers.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/NumberOfReceivers.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/NumberOfReceivers.java new file mode 100644 index 0000000..bedfb5a --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/NumberOfReceivers.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * The number of receivers for the operators. + */ +@NamedParameter(doc = "The number of receivers for the operators") +public class NumberOfReceivers implements Name<Integer> { + +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ProbabilityOfFailure.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ProbabilityOfFailure.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ProbabilityOfFailure.java new file mode 100644 index 0000000..53fc5db --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ProbabilityOfFailure.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * The probability of failure. + * Defaults to 0.01. + */ +@NamedParameter(default_value = "0.01") +public final class ProbabilityOfFailure implements Name<Double> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ProbabilityOfSuccesfulIteration.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ProbabilityOfSuccesfulIteration.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ProbabilityOfSuccesfulIteration.java new file mode 100644 index 0000000..243b82c --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/ProbabilityOfSuccesfulIteration.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + +/** + * The probability that an iteration is successful. + * Defaults to 0.5. + */ +@NamedParameter(short_name = "psuccess", default_value = "0.5") +public final class ProbabilityOfSuccesfulIteration implements Name<Double> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Timeout.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Timeout.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Timeout.java new file mode 100644 index 0000000..8332514 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/parameters/Timeout.java @@ -0,0 +1,28 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.parameters; + +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; + + +// TODO: Document +@NamedParameter(short_name = "timeout", default_value = "2") +public final class Timeout implements Name<Integer> { +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/utils/StepSizes.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/utils/StepSizes.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/utils/StepSizes.java new file mode 100644 index 0000000..adcb2ce --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/utils/StepSizes.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.reef.examples.group.bgd.utils; + +import javax.inject.Inject; +import java.util.Arrays; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class StepSizes { + + private static final Logger LOG = Logger.getLogger(StepSizes.class.getName()); + + private final double[] t; + private final int gridSize = 21; + + @Inject + public StepSizes() { + this.t = new double[gridSize]; + final int mid = (gridSize / 2); + t[mid] = 1; + for (int i = mid - 1; i >= 0; i--) { + t[i] = t[i + 1] / 2.0; + } + for (int i = mid + 1; i < gridSize; i++) { + t[i] = t[i - 1] * 2.0; + } + } + + public double[] getT() { + return t; + } + + public int getGridSize() { + return gridSize; + } + + public static void main(final String[] args) { + // TODO Auto-generated method stub + final StepSizes t = new StepSizes(); + LOG.log(Level.INFO, "OUT: {0}", Arrays.toString(t.getT())); + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/utils/SubConfiguration.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/utils/SubConfiguration.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/utils/SubConfiguration.java new file mode 100644 index 0000000..4d74356 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/utils/SubConfiguration.java @@ -0,0 +1,73 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.bgd.utils; + +import org.apache.reef.driver.task.TaskConfiguration; +import org.apache.reef.driver.task.TaskConfigurationOptions; +import org.apache.reef.examples.group.bgd.MasterTask; +import org.apache.reef.tang.Configuration; +import org.apache.reef.tang.Injector; +import org.apache.reef.tang.JavaConfigurationBuilder; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.exceptions.InjectionException; +import org.apache.reef.tang.formats.AvroConfigurationSerializer; +import org.apache.reef.tang.formats.ConfigurationSerializer; + +import java.util.logging.Level; +import java.util.logging.Logger; + +public class SubConfiguration { + + private static final Logger LOG = Logger.getLogger(SubConfiguration.class.getName()); + + @SafeVarargs + public static Configuration from( + final Configuration baseConf, final Class<? extends Name<?>>... classes) { + + final Injector injector = Tang.Factory.getTang().newInjector(baseConf); + final JavaConfigurationBuilder confBuilder = Tang.Factory.getTang().newConfigurationBuilder(); + + for (final Class<? extends Name<?>> clazz : classes) { + try { + confBuilder.bindNamedParameter(clazz, + injector.getNamedInstance((Class<? 
extends Name<Object>>) clazz).toString()); + } catch (final InjectionException ex) { + final String msg = "Exception while creating subconfiguration"; + LOG.log(Level.WARNING, msg, ex); + throw new RuntimeException(msg, ex); + } + } + + return confBuilder.build(); + } + + public static void main(final String[] args) throws InjectionException { + + final Configuration conf = TaskConfiguration.CONF + .set(TaskConfiguration.IDENTIFIER, "TASK") + .set(TaskConfiguration.TASK, MasterTask.class) + .build(); + + final ConfigurationSerializer confSerizalizer = new AvroConfigurationSerializer(); + final Configuration subConf = SubConfiguration.from(conf, TaskConfigurationOptions.Identifier.class); + LOG.log(Level.INFO, "OUT: Base conf:\n{0}", confSerizalizer.toString(conf)); + LOG.log(Level.INFO, "OUT: Sub conf:\n{0}", confSerizalizer.toString(subConf)); + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/BroadcastDriver.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/BroadcastDriver.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/BroadcastDriver.java new file mode 100644 index 0000000..5505859 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/BroadcastDriver.java @@ -0,0 +1,285 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.broadcast; + +import org.apache.reef.annotations.audience.DriverSide; +import org.apache.reef.driver.context.ActiveContext; +import org.apache.reef.driver.context.ClosedContext; +import org.apache.reef.driver.context.ContextConfiguration; +import org.apache.reef.driver.evaluator.AllocatedEvaluator; +import org.apache.reef.driver.evaluator.EvaluatorRequest; +import org.apache.reef.driver.evaluator.EvaluatorRequestor; +import org.apache.reef.driver.task.FailedTask; +import org.apache.reef.driver.task.TaskConfiguration; +import org.apache.reef.evaluator.context.parameters.ContextIdentifier; +import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster; +import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup; +import org.apache.reef.examples.group.bgd.parameters.ModelDimensions; +import org.apache.reef.examples.group.broadcast.parameters.ModelBroadcaster; +import org.apache.reef.examples.group.broadcast.parameters.ModelReceiveAckReducer; +import org.apache.reef.examples.group.broadcast.parameters.NumberOfReceivers; +import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver; +import org.apache.reef.io.network.group.api.driver.GroupCommDriver; +import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec; +import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec; +import org.apache.reef.io.serialization.SerializableCodec; +import org.apache.reef.poison.PoisonedConfiguration; +import org.apache.reef.tang.Configuration; +import 
org.apache.reef.tang.Injector; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.annotations.Parameter; +import org.apache.reef.tang.annotations.Unit; +import org.apache.reef.tang.exceptions.InjectionException; +import org.apache.reef.tang.formats.ConfigurationSerializer; +import org.apache.reef.wake.EventHandler; +import org.apache.reef.wake.time.event.StartTime; + +import javax.inject.Inject; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Driver of the broadcast example: requests evaluators, submits an id context to each of them, + * layers the group communication context and service on top of that context, and then submits + * one MasterTask and the SlaveTasks, all of which are added to the AllCommunicationGroup. + */ +@DriverSide +@Unit +public class BroadcastDriver { + + private static final Logger LOG = Logger.getLogger(BroadcastDriver.class.getName()); + + private final AtomicBoolean masterSubmitted = new AtomicBoolean(false); + private final AtomicInteger slaveIds = new AtomicInteger(0); + + private final GroupCommDriver groupCommDriver; + private final CommunicationGroupDriver allCommGroup; + private final ConfigurationSerializer confSerializer; + private final int dimensions; + private final EvaluatorRequestor requestor; + private final int numberOfReceivers; + private final AtomicInteger numberOfAllocatedEvaluators; + + /** + * Id of the first group communication context that was configured; the MasterTask is submitted + * on the context with this id. Volatile because it is written by one ContextActiveHandler + * invocation and read by later ones, which may run on different threads. + */ + private volatile String groupCommConfiguredMasterId; + + /** + * Creates the AllCommunicationGroup sized for all receivers and the master, and registers the + * control message broadcast, the model broadcast and the acknowledgement reduce on it. + * + * @param requestor used to request the evaluators. + * @param groupCommDriver source of the communication group and of the group communication configurations. + * @param confSerializer used to log the submitted configurations. + * @param dimensions value bound to ModelDimensions in every task configuration. + * @param numberOfReceivers number of receivers; one more evaluator than this is requested. + */ + @Inject + public BroadcastDriver( + final EvaluatorRequestor requestor, + final GroupCommDriver groupCommDriver, + final ConfigurationSerializer confSerializer, + final @Parameter(ModelDimensions.class) int dimensions, + final @Parameter(NumberOfReceivers.class) int numberOfReceivers) { + + this.requestor = requestor; + this.groupCommDriver = groupCommDriver; + this.confSerializer = confSerializer; + this.dimensions = dimensions; + this.numberOfReceivers = numberOfReceivers; + this.numberOfAllocatedEvaluators = new AtomicInteger(numberOfReceivers + 1); + + this.allCommGroup = this.groupCommDriver.newCommunicationGroup( + AllCommunicationGroup.class, numberOfReceivers + 1);
+ + LOG.info("Obtained all communication group"); + + this.allCommGroup + .addBroadcast(ControlMessageBroadcaster.class, + BroadcastOperatorSpec.newBuilder() + .setSenderId(MasterTask.TASK_ID) + .setDataCodecClass(SerializableCodec.class) + .build()) + .addBroadcast(ModelBroadcaster.class, + BroadcastOperatorSpec.newBuilder() + .setSenderId(MasterTask.TASK_ID) + .setDataCodecClass(SerializableCodec.class) + .build()) + .addReduce(ModelReceiveAckReducer.class, + ReduceOperatorSpec.newBuilder() + .setReceiverId(MasterTask.TASK_ID) + .setDataCodecClass(SerializableCodec.class) + .setReduceFunctionClass(ModelReceiveAckReduceFunction.class) + .build()) + .finalise(); + + LOG.info("Added operators to allCommGroup"); + } + + /** + * Handles the StartTime event: requests one more evaluator than numberOfReceivers + * (the receivers plus the evaluator that runs the master task) in a single request. + */ + final class StartHandler implements EventHandler<StartTime> { + @Override + public void onNext(final StartTime startTime) { + final int numEvals = BroadcastDriver.this.numberOfReceivers + 1; + LOG.log(Level.FINE, "Requesting {0} evaluators", numEvals); + BroadcastDriver.this.requestor.submit(EvaluatorRequest.newBuilder() + .setNumber(numEvals) + .setMemory(2048) + .build()); + } + } + + /** + * Handles AllocatedEvaluator: submits a context whose id is "BroadcastContext-" followed by + * the current value of a counter that is decremented on every allocation. + */ + final class EvaluatorAllocatedHandler implements EventHandler<AllocatedEvaluator> { + @Override + public void onNext(final AllocatedEvaluator allocatedEvaluator) { + LOG.log(Level.INFO, "Submitting an id context to AllocatedEvaluator: {0}", allocatedEvaluator); + final Configuration contextConfiguration = ContextConfiguration.CONF + .set(ContextConfiguration.IDENTIFIER, "BroadcastContext-" + + BroadcastDriver.this.numberOfAllocatedEvaluators.getAndDecrement()) + .build(); + allocatedEvaluator.submitContext(contextConfiguration); + } + } + + /** + * Handles FailedTask: resubmits a SlaveTask under the id of the failed task on the active + * context of that task, this time with crash probability 0. The task is deliberately not + * added to the communication group again. + */ + public class FailedTaskHandler implements EventHandler<FailedTask> { + + @Override + public void onNext(final FailedTask failedTask) { + + LOG.log(Level.FINE, "Got failed Task: {0}",
failedTask.getId()); + + final ActiveContext activeContext = failedTask.getActiveContext().get(); + final Configuration partialTaskConf = Tang.Factory.getTang() + .newConfigurationBuilder( + TaskConfiguration.CONF + .set(TaskConfiguration.IDENTIFIER, failedTask.getId()) + .set(TaskConfiguration.TASK, SlaveTask.class) + .build(), + PoisonedConfiguration.TASK_CONF + .set(PoisonedConfiguration.CRASH_PROBABILITY, "0") + .set(PoisonedConfiguration.CRASH_TIMEOUT, "1") + .build()) + .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions)) + .build(); + + // Do not add the task back: + // allCommGroup.addTask(partialTaskConf); + + final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); + LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf)); + + activeContext.submitTask(taskConf); + } + } + + /** + * Handles ActiveContext. A context that is not yet configured for group communication gets + * the group communication context and service submitted on top of it; a configured context + * gets the MasterTask (once, on the context whose id was stored first) or a SlaveTask. + */ + public class ContextActiveHandler implements EventHandler<ActiveContext> { + + private final AtomicBoolean storeMasterId = new AtomicBoolean(false); + + @Override + public void onNext(final ActiveContext activeContext) { + + LOG.log(Level.FINE, "Got active context: {0}", activeContext.getId()); + + /* + * The active context can be either from data loading service or after network + * service has loaded contexts. So check if the GroupCommDriver knows if it was + * configured by one of the communication groups.
+ */ + if (groupCommDriver.isConfigured(activeContext)) { + + if (activeContext.getId().equals(groupCommConfiguredMasterId) && !masterTaskSubmitted()) { + + final Configuration partialTaskConf = Tang.Factory.getTang() + .newConfigurationBuilder( + TaskConfiguration.CONF + .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID) + .set(TaskConfiguration.TASK, MasterTask.class) + .build()) + .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions)) + .build(); + + allCommGroup.addTask(partialTaskConf); + + final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); + LOG.log(Level.FINER, "Submit MasterTask conf: {0}", confSerializer.toString(taskConf)); + + activeContext.submitTask(taskConf); + + } else { + + final Configuration partialTaskConf = Tang.Factory.getTang() + .newConfigurationBuilder( + TaskConfiguration.CONF + .set(TaskConfiguration.IDENTIFIER, getSlaveId(activeContext)) + .set(TaskConfiguration.TASK, SlaveTask.class) + .build(), + PoisonedConfiguration.TASK_CONF + .set(PoisonedConfiguration.CRASH_PROBABILITY, "0.4") + .set(PoisonedConfiguration.CRASH_TIMEOUT, "1") + .build()) + .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions)) + .build(); + + allCommGroup.addTask(partialTaskConf); + + final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf); + LOG.log(Level.FINER, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf)); + + activeContext.submitTask(taskConf); + } + } else { + + final Configuration contextConf = groupCommDriver.getContextConfiguration(); + final String contextId = contextId(contextConf); + + if (storeMasterId.compareAndSet(false, true)) { + groupCommConfiguredMasterId = contextId; + } + + final Configuration serviceConf = groupCommDriver.getServiceConfiguration(); + LOG.log(Level.FINER, "Submit GCContext conf: {0}", confSerializer.toString(contextConf)); + LOG.log(Level.FINER, "Submit Service conf: {0}",
confSerializer.toString(serviceConf)); + + activeContext.submitContextAndService(contextConf, serviceConf); + } + } + + /** + * Reads the ContextIdentifier bound in the given context configuration. + * + * @throws RuntimeException wrapping the InjectionException if the identifier cannot be injected. + */ + private String contextId(final Configuration contextConf) { + try { + final Injector injector = Tang.Factory.getTang().newInjector(contextConf); + return injector.getNamedInstance(ContextIdentifier.class); + } catch (final InjectionException e) { + throw new RuntimeException("Unable to inject context identifier from context conf", e); + } + } + + /** + * Returns the next slave task id, "SlaveTask-" followed by a running counter. + * The context argument is not used. + */ + private String getSlaveId(final ActiveContext activeContext) { + return "SlaveTask-" + slaveIds.getAndIncrement(); + } + + /** + * Returns false for exactly one caller and true for every later one. Note the side effect: + * the call itself flips the masterSubmitted flag, so it claims the submission rather than + * merely querying it. + */ + private boolean masterTaskSubmitted() { + return !masterSubmitted.compareAndSet(false, true); + } + } + + /** + * Handles ClosedContext: closes the parent context of the closed context, if there is one. + */ + public class ContextCloseHandler implements EventHandler<ClosedContext> { + + @Override + public void onNext(final ClosedContext closedContext) { + LOG.log(Level.FINE, "Got closed context: {0}", closedContext.getId()); + final ActiveContext parentContext = closedContext.getParentContext(); + if (parentContext != null) { + LOG.log(Level.FINE, "Closing parent context: {0}", parentContext.getId()); + parentContext.close(); + } + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/BroadcastREEF.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/BroadcastREEF.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/BroadcastREEF.java new file mode 100644 index 0000000..cac6ccd --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/BroadcastREEF.java @@ -0,0 +1,148 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.broadcast; + +import org.apache.reef.annotations.audience.ClientSide; +import org.apache.reef.client.DriverConfiguration; +import org.apache.reef.client.DriverLauncher; +import org.apache.reef.client.LauncherStatus; +import org.apache.reef.examples.group.bgd.parameters.ModelDimensions; +import org.apache.reef.examples.group.broadcast.parameters.NumberOfReceivers; +import org.apache.reef.io.network.group.impl.driver.GroupCommService; +import org.apache.reef.runtime.local.client.LocalRuntimeConfiguration; +import org.apache.reef.runtime.yarn.client.YarnClientConfiguration; +import org.apache.reef.tang.Configuration; +import org.apache.reef.tang.Injector; +import org.apache.reef.tang.JavaConfigurationBuilder; +import org.apache.reef.tang.Tang; +import org.apache.reef.tang.annotations.Name; +import org.apache.reef.tang.annotations.NamedParameter; +import org.apache.reef.tang.exceptions.InjectionException; +import org.apache.reef.tang.formats.AvroConfigurationSerializer; +import org.apache.reef.tang.formats.CommandLine; +import org.apache.reef.util.EnvironmentUtils; + +import java.io.IOException; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Client of the broadcast example: parses the command line, picks the local or the YARN + * runtime, builds the driver configuration for BroadcastDriver and launches the job. + */ +@ClientSide +public class BroadcastREEF { + private static final Logger LOG = Logger.getLogger(BroadcastREEF.class.getName()); + + /** + * Value set for NUMBER_OF_THREADS when running on the local runtime. + */ + private static final String NUM_LOCAL_THREADS = "20";
+ + /** + * Number of milliseconds to wait for the job to complete. + */ + private static final int JOB_TIMEOUT = 2 * 60 * 1000; + + /** + * Command line parameter = true to run locally, or false to run on YARN. + */ + @NamedParameter(doc = "Whether or not to run on the local runtime", short_name = "local", default_value = "true") + public static final class Local implements Name<Boolean> { + } + + /** + * Named parameter with the short name "input". + * NOTE(review): it is neither registered in parseCommandLine nor read anywhere in this class. + */ + @NamedParameter(short_name = "input") + public static final class InputDir implements Name<String> { + } + + private static boolean local; + private static int dimensions; + private static int numberOfReceivers; + + /** + * Parses the command line, accepting the short names of Local, ModelDimensions and NumberOfReceivers. + * + * @param aArgs the command line arguments. + * @return the configuration built from the parsed arguments. + * @throws RuntimeException wrapping the IOException if the command line cannot be processed. + */ + private static Configuration parseCommandLine(final String[] aArgs) { + final JavaConfigurationBuilder cb = Tang.Factory.getTang().newConfigurationBuilder(); + try { + final CommandLine cl = new CommandLine(cb); + cl.registerShortNameOfClass(Local.class); + cl.registerShortNameOfClass(ModelDimensions.class); + cl.registerShortNameOfClass(NumberOfReceivers.class); + cl.processCommandLine(aArgs); + } catch (final IOException ex) { + final String msg = "Unable to parse command line"; + LOG.log(Level.SEVERE, msg, ex); + throw new RuntimeException(msg, ex); + } + return cb.build(); + } + + /** + * Copies the parameters parsed from the command line into the static fields + * (local, dimensions, numberOfReceivers) that the other methods of this class read. + * + * @param commandLineConf the configuration returned by parseCommandLine. + * @throws InjectionException if one of the named parameters cannot be injected. + */ + private static void storeCommandLineArgs( + final Configuration commandLineConf) throws InjectionException { + final Injector injector = Tang.Factory.getTang().newInjector(commandLineConf); + local = injector.getNamedInstance(Local.class); + dimensions = injector.getNamedInstance(ModelDimensions.class); + numberOfReceivers = injector.getNamedInstance(NumberOfReceivers.class); + } + + /** + * Builds the runtime configuration: the local runtime with NUM_LOCAL_THREADS threads when + * the local flag is set, the YARN client configuration otherwise. + * + * @return (immutable) TANG Configuration object.
+ */ + private static Configuration getRunTimeConfiguration() { + final Configuration runtimeConfiguration; + if (local) { + LOG.log(Level.INFO, "Running Broadcast example using group API on the local runtime"); + runtimeConfiguration = LocalRuntimeConfiguration.CONF + .set(LocalRuntimeConfiguration.NUMBER_OF_THREADS, NUM_LOCAL_THREADS) + .build(); + } else { + LOG.log(Level.INFO, "Running Broadcast example using group API on YARN"); + runtimeConfiguration = YarnClientConfiguration.CONF.build(); + } + return runtimeConfiguration; + } + + /** + * Builds the driver configuration (the BroadcastDriver handlers merged with the group + * communication service configuration and the ModelDimensions and NumberOfReceivers + * values), logs it, and runs the job, waiting up to JOB_TIMEOUT milliseconds. + * Despite the BGD in its name, this method launches the broadcast example. + * + * @param runtimeConfiguration the runtime to launch the driver on. + * @return the status returned by the driver launcher. + * @throws InjectionException not thrown directly here. NOTE(review): presumably from the launcher; confirm. + */ + public static LauncherStatus runBGDReef( + final Configuration runtimeConfiguration) throws InjectionException { + + final Configuration driverConfiguration = EnvironmentUtils + .addClasspath(DriverConfiguration.CONF, DriverConfiguration.GLOBAL_LIBRARIES) + .set(DriverConfiguration.ON_DRIVER_STARTED, BroadcastDriver.StartHandler.class) + .set(DriverConfiguration.ON_EVALUATOR_ALLOCATED, BroadcastDriver.EvaluatorAllocatedHandler.class) + .set(DriverConfiguration.ON_CONTEXT_ACTIVE, BroadcastDriver.ContextActiveHandler.class) + .set(DriverConfiguration.ON_CONTEXT_CLOSED, BroadcastDriver.ContextCloseHandler.class) + .set(DriverConfiguration.ON_TASK_FAILED, BroadcastDriver.FailedTaskHandler.class) + .set(DriverConfiguration.DRIVER_IDENTIFIER, "BroadcastDriver") + .build(); + + final Configuration groupCommServConfiguration = GroupCommService.getConfiguration(); + + final Configuration mergedDriverConfiguration = Tang.Factory.getTang() + .newConfigurationBuilder(groupCommServConfiguration, driverConfiguration) + .bindNamedParameter(ModelDimensions.class, Integer.toString(dimensions)) + .bindNamedParameter(NumberOfReceivers.class, Integer.toString(numberOfReceivers)) + .build(); + + LOG.info(new AvroConfigurationSerializer().toString(mergedDriverConfiguration)); + + return DriverLauncher.getLauncher(runtimeConfiguration).run(mergedDriverConfiguration, JOB_TIMEOUT); + } + + /** + * Entry point: parses and stores the command line arguments, selects the runtime, + * runs the job and logs the resulting launcher status. + * + * @param args the command line arguments. + * @throws InjectionException if the command line values or the job configuration cannot be injected. + */ + public static void main(final String[] args) throws
InjectionException { + final Configuration commandLineConf = parseCommandLine(args); + storeCommandLineArgs(commandLineConf); + final Configuration runtimeConfiguration = getRunTimeConfiguration(); + final LauncherStatus state = runBGDReef(runtimeConfiguration); + LOG.log(Level.INFO, "REEF job completed: {0}", state); + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/ControlMessages.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/ControlMessages.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/ControlMessages.java new file mode 100644 index 0000000..dc51076 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/ControlMessages.java @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.reef.examples.group.broadcast; + +import java.io.Serializable; + +/** + * Control messages that the MasterTask of the broadcast example broadcasts to the other tasks. + */ +public enum ControlMessages implements Serializable { + /** Broadcast by the MasterTask before each model broadcast. */ + ReceiveModel, + /** Broadcast by the MasterTask once, after its last iteration. */ + Stop +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/MasterTask.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/MasterTask.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/MasterTask.java new file mode 100644 index 0000000..b5627d2 --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/MasterTask.java @@ -0,0 +1,97 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.reef.examples.group.broadcast; + +import org.apache.reef.examples.group.bgd.operatornames.ControlMessageBroadcaster; +import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup; +import org.apache.reef.examples.group.bgd.parameters.ModelDimensions; +import org.apache.reef.examples.group.broadcast.parameters.ModelBroadcaster; +import org.apache.reef.examples.group.broadcast.parameters.ModelReceiveAckReducer; +import org.apache.reef.examples.group.utils.math.DenseVector; +import org.apache.reef.examples.group.utils.math.Vector; +import org.apache.reef.io.network.group.api.operators.Broadcast; +import org.apache.reef.io.network.group.api.operators.Reduce; +import org.apache.reef.io.network.group.api.GroupChanges; +import org.apache.reef.io.network.group.api.task.CommunicationGroupClient; +import org.apache.reef.io.network.group.api.task.GroupCommClient; +import org.apache.reef.tang.annotations.Parameter; +import org.apache.reef.task.Task; +import org.mortbay.log.Log; + +import javax.inject.Inject; +import java.util.logging.Level; +import java.util.logging.Logger; + +public class MasterTask implements Task { + + public static final String TASK_ID = "MasterTask"; + + private static final Logger LOG = Logger.getLogger(MasterTask.class.getName()); + + private final CommunicationGroupClient communicationGroupClient; + private final Broadcast.Sender<ControlMessages> controlMessageBroadcaster; + private final Broadcast.Sender<Vector> modelBroadcaster; + private final Reduce.Receiver<Boolean> modelReceiveAckReducer; + + private final int dimensions; + + @Inject + public MasterTask( + final GroupCommClient groupCommClient, + final @Parameter(ModelDimensions.class) int dimensions) { + + this.dimensions = dimensions; + + this.communicationGroupClient = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class); + this.controlMessageBroadcaster = communicationGroupClient.getBroadcastSender(ControlMessageBroadcaster.class); + 
this.modelBroadcaster = communicationGroupClient.getBroadcastSender(ModelBroadcaster.class); + this.modelReceiveAckReducer = communicationGroupClient.getReduceReceiver(ModelReceiveAckReducer.class); + } + + @Override + public byte[] call(final byte[] memento) throws Exception { + + final Vector model = new DenseVector(dimensions); + final long time1 = System.currentTimeMillis(); + final int numIters = 10; + + for (int i = 0; i < numIters; i++) { + + controlMessageBroadcaster.send(ControlMessages.ReceiveModel); + modelBroadcaster.send(model); + modelReceiveAckReducer.reduce(); + + final GroupChanges changes = communicationGroupClient.getTopologyChanges(); + if (changes.exist()) { + Log.info("There exist topology changes. Asking to update Topology"); + communicationGroupClient.updateTopology(); + } else { + Log.info("No changes in topology exist. So not updating topology"); + } + } + + final long time2 = System.currentTimeMillis(); + LOG.log(Level.FINE, "Broadcasting vector of dimensions {0} took {1} secs", + new Object[]{dimensions, (time2 - time1) / (numIters * 1000.0)}); + + controlMessageBroadcaster.send(ControlMessages.Stop); + + return null; + } +} http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/ModelReceiveAckReduceFunction.java ---------------------------------------------------------------------- diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/ModelReceiveAckReduceFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/ModelReceiveAckReduceFunction.java new file mode 100644 index 0000000..e549cbc --- /dev/null +++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/broadcast/ModelReceiveAckReduceFunction.java @@ -0,0 +1,39 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.reef.examples.group.broadcast; + +import org.apache.reef.io.network.group.api.operators.Reduce.ReduceFunction; + +import javax.inject.Inject; + +/** + * Reduce function over Boolean acknowledgements that always yields true, + * whatever values it is given. + */ +public class ModelReceiveAckReduceFunction implements ReduceFunction<Boolean> { + + /** + * Injectable no-argument constructor; the function keeps no state. + */ + @Inject + public ModelReceiveAckReduceFunction() { + } + + /** + * Combines the acknowledgements. The elements are not inspected. + * + * @param elements the values to reduce; ignored. + * @return always true. + */ + @Override + public Boolean apply(final Iterable<Boolean> elements) { + return true; + } + +}
