You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@reef.apache.org by we...@apache.org on 2015/03/06 02:55:38 UTC
[8/8] incubator-reef git commit: [REEF-118] Add Shimoga library for
elastic group communication.
[REEF-118] Add Shimoga library for elastic group communication.
Shimoga is a REEF library for elastic group communication. It provides
MPI-style operators like Broadcast and Reduce for inter-task messaging.
JIRA:
[REEF-118](https://issues.apache.org/jira/browse/REEF-118)
Pull Request:
This closes #63
Project: http://git-wip-us.apache.org/repos/asf/incubator-reef/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-reef/commit/6c6ad336
Tree: http://git-wip-us.apache.org/repos/asf/incubator-reef/tree/6c6ad336
Diff: http://git-wip-us.apache.org/repos/asf/incubator-reef/diff/6c6ad336
Branch: refs/heads/master
Commit: 6c6ad33674c6e61e44015e0632023e776b07536e
Parents: 0911c08
Author: Sergiy Matusevych <mo...@apache.org>
Authored: Thu Feb 12 14:22:58 2015 -0800
Committer: Markus Weimer <we...@apache.org>
Committed: Thu Mar 5 17:52:55 2015 -0800
----------------------------------------------------------------------
lang/java/reef-examples/pom.xml | 24 -
.../reef/examples/group/bgd/BGDClient.java | 134 +++++
.../reef/examples/group/bgd/BGDDriver.java | 376 ++++++++++++
.../reef/examples/group/bgd/BGDLocal.java | 53 ++
.../apache/reef/examples/group/bgd/BGDYarn.java | 52 ++
.../examples/group/bgd/ControlMessages.java | 30 +
.../reef/examples/group/bgd/ExampleList.java | 72 +++
.../group/bgd/LineSearchReduceFunction.java | 51 ++
.../bgd/LossAndGradientReduceFunction.java | 55 ++
.../reef/examples/group/bgd/MasterTask.java | 246 ++++++++
.../reef/examples/group/bgd/SlaveTask.java | 204 +++++++
.../reef/examples/group/bgd/data/Example.java | 52 ++
.../examples/group/bgd/data/SparseExample.java | 68 +++
.../examples/group/bgd/data/parser/Parser.java | 32 +
.../group/bgd/data/parser/SVMLightParser.java | 98 ++++
.../group/bgd/loss/LogisticLossFunction.java | 50 ++
.../examples/group/bgd/loss/LossFunction.java | 46 ++
.../bgd/loss/SquaredErrorLossFunction.java | 49 ++
.../bgd/loss/WeightedLogisticLossFunction.java | 74 +++
.../ControlMessageBroadcaster.java | 29 +
.../DescentDirectionBroadcaster.java | 29 +
.../LineSearchEvaluationsReducer.java | 29 +
.../operatornames/LossAndGradientReducer.java | 29 +
.../bgd/operatornames/MinEtaBroadcaster.java | 26 +
.../ModelAndDescentDirectionBroadcaster.java | 29 +
.../bgd/operatornames/ModelBroadcaster.java | 29 +
.../group/bgd/operatornames/package-info.java | 23 +
.../bgd/parameters/AllCommunicationGroup.java | 26 +
.../bgd/parameters/BGDControlParameters.java | 126 ++++
.../group/bgd/parameters/BGDLossType.java | 61 ++
.../group/bgd/parameters/EnableRampup.java | 29 +
.../reef/examples/group/bgd/parameters/Eps.java | 30 +
.../reef/examples/group/bgd/parameters/Eta.java | 30 +
.../group/bgd/parameters/EvaluatorMemory.java | 29 +
.../examples/group/bgd/parameters/InputDir.java | 29 +
.../group/bgd/parameters/Iterations.java | 29 +
.../examples/group/bgd/parameters/Lambda.java | 29 +
.../group/bgd/parameters/LossFunctionType.java | 30 +
.../examples/group/bgd/parameters/MinParts.java | 29 +
.../group/bgd/parameters/ModelDimensions.java | 30 +
.../group/bgd/parameters/NumSplits.java | 30 +
.../group/bgd/parameters/NumberOfReceivers.java | 30 +
.../bgd/parameters/ProbabilityOfFailure.java | 30 +
.../ProbabilityOfSuccesfulIteration.java | 30 +
.../examples/group/bgd/parameters/Timeout.java | 28 +
.../examples/group/bgd/utils/StepSizes.java | 59 ++
.../group/bgd/utils/SubConfiguration.java | 73 +++
.../group/broadcast/BroadcastDriver.java | 285 +++++++++
.../examples/group/broadcast/BroadcastREEF.java | 148 +++++
.../group/broadcast/ControlMessages.java | 26 +
.../examples/group/broadcast/MasterTask.java | 97 ++++
.../ModelReceiveAckReduceFunction.java | 39 ++
.../examples/group/broadcast/SlaveTask.java | 76 +++
.../parameters/AllCommunicationGroup.java | 30 +
.../parameters/ControlMessageBroadcaster.java | 26 +
.../group/broadcast/parameters/Dimensions.java | 30 +
.../parameters/FailureProbability.java | 30 +
.../broadcast/parameters/ModelBroadcaster.java | 26 +
.../parameters/ModelReceiveAckReducer.java | 26 +
.../broadcast/parameters/NumberOfReceivers.java | 30 +
.../utils/math/AbstractImmutableVector.java | 103 ++++
.../group/utils/math/AbstractVector.java | 61 ++
.../examples/group/utils/math/DenseVector.java | 112 ++++
.../group/utils/math/ImmutableVector.java | 78 +++
.../examples/group/utils/math/SparseVector.java | 57 ++
.../reef/examples/group/utils/math/Vector.java | 72 +++
.../examples/group/utils/math/VectorCodec.java | 70 +++
.../reef/examples/group/utils/math/Window.java | 76 +++
.../reef/examples/group/utils/timer/Timer.java | 58 ++
.../reef/examples/scheduler/Scheduler.java | 5 +-
.../utils/wake/BlockingEventHandler.java | 2 +-
.../utils/wake/LoggingEventHandler.java | 17 +-
lang/java/reef-io/pom.xml | 5 +
.../reef/io/network/group/api/GroupChanges.java | 31 +
.../network/group/api/config/OperatorSpec.java | 38 ++
.../api/driver/CommunicationGroupDriver.java | 87 +++
.../group/api/driver/GroupCommDriver.java | 76 +++
.../api/driver/GroupCommServiceDriver.java | 59 ++
.../io/network/group/api/driver/TaskNode.java | 94 +++
.../group/api/driver/TaskNodeStatus.java | 81 +++
.../io/network/group/api/driver/Topology.java | 115 ++++
.../operators/AbstractGroupCommOperator.java | 44 ++
.../network/group/api/operators/AllGather.java | 50 ++
.../network/group/api/operators/AllReduce.java | 55 ++
.../network/group/api/operators/Broadcast.java | 60 ++
.../io/network/group/api/operators/Gather.java | 64 ++
.../group/api/operators/GroupCommOperator.java | 33 ++
.../io/network/group/api/operators/Reduce.java | 99 ++++
.../group/api/operators/ReduceScatter.java | 67 +++
.../io/network/group/api/operators/Scatter.java | 74 +++
.../group/api/operators/package-info.java | 48 ++
.../group/api/task/CommGroupNetworkHandler.java | 41 ++
.../api/task/CommunicationGroupClient.java | 97 ++++
.../task/CommunicationGroupServiceClient.java | 34 ++
.../network/group/api/task/GroupCommClient.java | 42 ++
.../group/api/task/GroupCommNetworkHandler.java | 38 ++
.../io/network/group/api/task/NodeStruct.java | 42 ++
.../group/api/task/OperatorTopology.java | 58 ++
.../group/api/task/OperatorTopologyStruct.java | 73 +++
.../network/group/impl/GroupChangesCodec.java | 71 +++
.../io/network/group/impl/GroupChangesImpl.java | 45 ++
.../group/impl/GroupCommunicationMessage.java | 167 ++++++
.../impl/GroupCommunicationMessageCodec.java | 111 ++++
.../impl/config/BroadcastOperatorSpec.java | 86 +++
.../group/impl/config/ReduceOperatorSpec.java | 107 ++++
.../parameters/CommunicationGroupName.java | 28 +
.../group/impl/config/parameters/DataCodec.java | 29 +
.../impl/config/parameters/OperatorName.java | 28 +
.../config/parameters/ReduceFunctionParam.java | 29 +
.../parameters/SerializedGroupConfigs.java | 30 +
.../parameters/SerializedOperConfigs.java | 30 +
.../impl/config/parameters/TaskVersion.java | 28 +
.../config/parameters/TreeTopologyFanOut.java | 28 +
.../driver/CommunicationGroupDriverImpl.java | 451 +++++++++++++++
.../group/impl/driver/CtrlMsgSender.java | 61 ++
.../group/impl/driver/ExceptionHandler.java | 56 ++
.../network/group/impl/driver/FlatTopology.java | 307 ++++++++++
.../group/impl/driver/GroupCommDriverImpl.java | 250 ++++++++
.../impl/driver/GroupCommMessageHandler.java | 55 ++
.../group/impl/driver/GroupCommService.java | 111 ++++
.../network/group/impl/driver/IndexedMsg.java | 71 +++
.../io/network/group/impl/driver/MsgKey.java | 90 +++
.../network/group/impl/driver/TaskNodeImpl.java | 476 +++++++++++++++
.../group/impl/driver/TaskNodeStatusImpl.java | 267 +++++++++
.../io/network/group/impl/driver/TaskState.java | 23 +
.../driver/TopologyFailedEvaluatorHandler.java | 50 ++
.../impl/driver/TopologyFailedTaskHandler.java | 45 ++
.../impl/driver/TopologyMessageHandler.java | 44 ++
.../impl/driver/TopologyRunningTaskHandler.java | 44 ++
.../impl/driver/TopologyUpdateWaitHandler.java | 94 +++
.../network/group/impl/driver/TreeTopology.java | 345 +++++++++++
.../network/group/impl/driver/package-info.java | 116 ++++
.../group/impl/operators/BroadcastReceiver.java | 159 +++++
.../group/impl/operators/BroadcastSender.java | 141 +++++
.../group/impl/operators/ReduceReceiver.java | 155 +++++
.../group/impl/operators/ReduceSender.java | 161 ++++++
.../io/network/group/impl/operators/Sender.java | 59 ++
.../group/impl/task/ChildNodeStruct.java | 42 ++
.../impl/task/CommGroupNetworkHandlerImpl.java | 102 ++++
.../impl/task/CommunicationGroupClientImpl.java | 296 ++++++++++
.../group/impl/task/GroupCommClientImpl.java | 85 +++
.../impl/task/GroupCommNetworkHandlerImpl.java | 68 +++
.../io/network/group/impl/task/InitHandler.java | 54 ++
.../network/group/impl/task/NodeStructImpl.java | 98 ++++
.../group/impl/task/OperatorTopologyImpl.java | 466 +++++++++++++++
.../impl/task/OperatorTopologyStructImpl.java | 579 +++++++++++++++++++
.../group/impl/task/ParentNodeStruct.java | 45 ++
.../impl/utils/BroadcastingEventHandler.java | 44 ++
.../group/impl/utils/ConcurrentCountingMap.java | 134 +++++
.../network/group/impl/utils/CountingMap.java | 98 ++++
.../group/impl/utils/CountingSemaphore.java | 103 ++++
.../impl/utils/ResettingCountDownLatch.java | 57 ++
.../io/network/group/impl/utils/SetMap.java | 95 +++
.../reef/io/network/group/impl/utils/Utils.java | 80 +++
.../reef/io/network/group/package-info.java | 33 ++
.../reef/io/network/naming/NameServer.java | 20 +-
.../reef/io/network/naming/NameServerImpl.java | 1 -
.../org/apache/reef/io/network/util/Utils.java | 119 ++++
.../org/apache/reef/io/storage/ram/RamMap.java | 6 +-
.../src/main/proto/group_comm_protocol.proto | 64 ++
.../GroupCommunicationMessageCodecTest.java | 72 +++
.../apache/reef/io/network/util/TestUtils.java | 60 ++
.../services/network/NetworkServiceTest.java | 26 +-
163 files changed, 13180 insertions(+), 76 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/pom.xml
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/pom.xml b/lang/java/reef-examples/pom.xml
index f910a77..68c4693 100644
--- a/lang/java/reef-examples/pom.xml
+++ b/lang/java/reef-examples/pom.xml
@@ -214,30 +214,6 @@ under the License.
</plugins>
</build>
</profile>
- <profile>
- <id>MatMult</id>
- <build>
- <defaultGoal>exec:exec</defaultGoal>
- <plugins>
- <plugin>
- <groupId>org.codehaus.mojo</groupId>
- <artifactId>exec-maven-plugin</artifactId>
- <configuration>
- <executable>java</executable>
- <arguments>
- <argument>-classpath</argument>
- <classpath/>
- <argument>-Djava.util.logging.config.class=org.apache.reef.util.logging.Config
- </argument>
- <argument>-Dcom.microsoft.reef.runtime.local.folder=${project.build.directory}
- </argument>
- <argument>org.apache.reef.examples.groupcomm.matmul.MatMultREEF</argument>
- </arguments>
- </configuration>
- </plugin>
- </plugins>
- </build>
- </profile>
<profile>
<id>RetainedEval</id>
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDClient.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDClient.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDClient.java
new file mode 100644
index 0000000..84865e8
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDClient.java
@@ -0,0 +1,134 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd;
+
+import org.apache.hadoop.mapred.TextInputFormat;
+import org.apache.reef.client.DriverConfiguration;
+import org.apache.reef.client.DriverLauncher;
+import org.apache.reef.client.LauncherStatus;
+import org.apache.reef.client.REEF;
+import org.apache.reef.driver.evaluator.EvaluatorRequest;
+import org.apache.reef.examples.group.bgd.parameters.*;
+import org.apache.reef.io.data.loading.api.DataLoadingRequestBuilder;
+import org.apache.reef.io.network.group.impl.config.parameters.TreeTopologyFanOut;
+import org.apache.reef.io.network.group.impl.driver.GroupCommService;
+import org.apache.reef.tang.Configuration;
+import org.apache.reef.tang.Configurations;
+import org.apache.reef.tang.JavaConfigurationBuilder;
+import org.apache.reef.tang.Tang;
+import org.apache.reef.tang.annotations.Parameter;
+import org.apache.reef.tang.formats.CommandLine;
+import org.apache.reef.util.EnvironmentUtils;
+
+import javax.inject.Inject;
+
+/**
+ * A client to submit BGD Jobs
+ */
+public class BGDClient {
+ private final String input;
+ private final int numSplits;
+ private final int memory;
+
+ private final BGDControlParameters bgdControlParameters;
+ private final int fanOut;
+
+ @Inject
+ public BGDClient(final @Parameter(InputDir.class) String input,
+ final @Parameter(NumSplits.class) int numSplits,
+ final @Parameter(EvaluatorMemory.class) int memory,
+ final @Parameter(TreeTopologyFanOut.class) int fanOut,
+ final BGDControlParameters bgdControlParameters) {
+ this.input = input;
+ this.fanOut = fanOut;
+ this.bgdControlParameters = bgdControlParameters;
+ this.numSplits = numSplits;
+ this.memory = memory;
+ }
+
+ /**
+ * Runs BGD on the given runtime.
+ *
+ * @param runtimeConfiguration the runtime to run on.
+ * @param jobName the name of the job on the runtime.
+ * @return
+ */
+ public void submit(final Configuration runtimeConfiguration, final String jobName) throws Exception {
+ final Configuration driverConfiguration = getDriverConfiguration(jobName);
+ Tang.Factory.getTang().newInjector(runtimeConfiguration).getInstance(REEF.class).submit(driverConfiguration);
+ }
+
+ /**
+ * Runs BGD on the given runtime - with timeout.
+ *
+ * @param runtimeConfiguration the runtime to run on.
+ * @param jobName the name of the job on the runtime.
+ * @param timeout the time after which the job will be killed if not completed, in ms
+ * @return job completion status
+ */
+ public LauncherStatus run(final Configuration runtimeConfiguration,
+ final String jobName, final int timeout) throws Exception {
+ final Configuration driverConfiguration = getDriverConfiguration(jobName);
+ return DriverLauncher.getLauncher(runtimeConfiguration).run(driverConfiguration, timeout);
+ }
+
+ private final Configuration getDriverConfiguration(final String jobName) {
+ return Configurations.merge(
+ getDataLoadConfiguration(jobName),
+ GroupCommService.getConfiguration(fanOut),
+ this.bgdControlParameters.getConfiguration());
+ }
+
+ private Configuration getDataLoadConfiguration(final String jobName) {
+ final EvaluatorRequest computeRequest = EvaluatorRequest.newBuilder()
+ .setNumber(1)
+ .setMemory(memory)
+ .build();
+ final Configuration dataLoadConfiguration = new DataLoadingRequestBuilder()
+ .setMemoryMB(memory)
+ .setInputFormatClass(TextInputFormat.class)
+ .setInputPath(input)
+ .setNumberOfDesiredSplits(numSplits)
+ .setComputeRequest(computeRequest)
+ .renewFailedEvaluators(false)
+ .setDriverConfigurationModule(EnvironmentUtils
+ .addClasspath(DriverConfiguration.CONF, DriverConfiguration.GLOBAL_LIBRARIES)
+ .set(DriverConfiguration.DRIVER_MEMORY, Integer.toString(memory))
+ .set(DriverConfiguration.ON_CONTEXT_ACTIVE, BGDDriver.ContextActiveHandler.class)
+ .set(DriverConfiguration.ON_TASK_RUNNING, BGDDriver.TaskRunningHandler.class)
+ .set(DriverConfiguration.ON_TASK_FAILED, BGDDriver.TaskFailedHandler.class)
+ .set(DriverConfiguration.ON_TASK_COMPLETED, BGDDriver.TaskCompletedHandler.class)
+ .set(DriverConfiguration.DRIVER_IDENTIFIER, jobName))
+ .build();
+ return dataLoadConfiguration;
+ }
+
+ public static final BGDClient fromCommandLine(final String[] args) throws Exception {
+ final JavaConfigurationBuilder configurationBuilder = Tang.Factory.getTang().newConfigurationBuilder();
+ final CommandLine commandLine = new CommandLine(configurationBuilder)
+ .registerShortNameOfClass(InputDir.class)
+ .registerShortNameOfClass(Timeout.class)
+ .registerShortNameOfClass(EvaluatorMemory.class)
+ .registerShortNameOfClass(NumSplits.class)
+ .registerShortNameOfClass(TreeTopologyFanOut.class);
+ BGDControlParameters.registerShortNames(commandLine);
+ commandLine.processCommandLine(args);
+ return Tang.Factory.getTang().newInjector(configurationBuilder.build()).getInstance(BGDClient.class);
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDDriver.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDDriver.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDDriver.java
new file mode 100644
index 0000000..2a80581
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDDriver.java
@@ -0,0 +1,376 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd;
+
+import org.apache.reef.annotations.audience.DriverSide;
+import org.apache.reef.driver.context.ActiveContext;
+import org.apache.reef.driver.context.ServiceConfiguration;
+import org.apache.reef.driver.task.CompletedTask;
+import org.apache.reef.driver.task.FailedTask;
+import org.apache.reef.driver.task.RunningTask;
+import org.apache.reef.driver.task.TaskConfiguration;
+import org.apache.reef.evaluator.context.parameters.ContextIdentifier;
+import org.apache.reef.examples.group.bgd.data.parser.Parser;
+import org.apache.reef.examples.group.bgd.data.parser.SVMLightParser;
+import org.apache.reef.examples.group.bgd.loss.LossFunction;
+import org.apache.reef.examples.group.bgd.operatornames.*;
+import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
+import org.apache.reef.examples.group.bgd.parameters.BGDControlParameters;
+import org.apache.reef.examples.group.bgd.parameters.ModelDimensions;
+import org.apache.reef.examples.group.bgd.parameters.ProbabilityOfFailure;
+import org.apache.reef.io.data.loading.api.DataLoadingService;
+import org.apache.reef.io.network.group.api.driver.CommunicationGroupDriver;
+import org.apache.reef.io.network.group.api.driver.GroupCommDriver;
+import org.apache.reef.io.network.group.impl.config.BroadcastOperatorSpec;
+import org.apache.reef.io.network.group.impl.config.ReduceOperatorSpec;
+import org.apache.reef.io.serialization.Codec;
+import org.apache.reef.io.serialization.SerializableCodec;
+import org.apache.reef.poison.PoisonedConfiguration;
+import org.apache.reef.tang.Configuration;
+import org.apache.reef.tang.Configurations;
+import org.apache.reef.tang.Tang;
+import org.apache.reef.tang.annotations.Unit;
+import org.apache.reef.tang.exceptions.InjectionException;
+import org.apache.reef.tang.formats.ConfigurationSerializer;
+import org.apache.reef.wake.EventHandler;
+
+import javax.inject.Inject;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Driver of the BGD example.
+ * <p>
+ * It creates one communication group ({@link AllCommunicationGroup}) holding a master task and slave tasks,
+ * registers the broadcast and reduce operators those tasks use, configures newly active contexts with the
+ * group communication service and then submits a {@link MasterTask} or {@link SlaveTask} on them.
+ * The job is marked complete when the master task completes, at which point remaining contexts are closed.
+ */
+@DriverSide
+@Unit
+public class BGDDriver {
+
+  private static final Logger LOG = Logger.getLogger(BGDDriver.class.getName());
+
+  private static final Tang TANG = Tang.Factory.getTang();
+
+  /** Crash probability used by {@link #getTaskPoisonConfiguration()}. */
+  private static final double STARTUP_FAILURE_PROB = 0.01;
+
+  private final DataLoadingService dataLoadingService;
+  private final GroupCommDriver groupCommDriver;
+  private final ConfigurationSerializer confSerializer;
+  private final CommunicationGroupDriver communicationsGroup;
+  private final AtomicBoolean masterSubmitted = new AtomicBoolean(false);
+  private final AtomicInteger slaveIds = new AtomicInteger(0);
+  /** Tasks reported as running; every access synchronizes on this map. */
+  private final Map<String, RunningTask> runningTasks = new HashMap<>();
+  /** Set once the master task completes; only read or written while holding the runningTasks lock. */
+  private final AtomicBoolean jobComplete = new AtomicBoolean(false);
+  /** Decodes the list of losses returned by a completed task. */
+  private final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<>();
+  private final BGDControlParameters bgdControlParameters;
+
+  /** Id of the group communication context created on the evaluator that did not load data. */
+  private String communicationsGroupMasterContextId;
+
+  @Inject
+  public BGDDriver(final DataLoadingService dataLoadingService,
+                   final GroupCommDriver groupCommDriver,
+                   final ConfigurationSerializer confSerializer,
+                   final BGDControlParameters bgdControlParameters) {
+    this.dataLoadingService = dataLoadingService;
+    this.groupCommDriver = groupCommDriver;
+    this.confSerializer = confSerializer;
+    this.bgdControlParameters = bgdControlParameters;
+
+    // With ramp-up the group is sized for the minimum number of partitions, otherwise for all of them.
+    final int minNumOfPartitions =
+        bgdControlParameters.isRampup()
+            ? bgdControlParameters.getMinParts()
+            : dataLoadingService.getNumberOfPartitions();
+
+    // The partitions plus one participant for the master task.
+    final int numParticipants = minNumOfPartitions + 1;
+
+    this.communicationsGroup = this.groupCommDriver.newCommunicationGroup(
+        AllCommunicationGroup.class, // NAME
+        numParticipants); // Number of participants
+
+    LOG.log(Level.INFO,
+        "Obtained entire communication group: start with {0} partitions", numParticipants);
+
+    // Every operator is rooted at the master task: it broadcasts to, and reduces from, the slaves.
+    this.communicationsGroup
+        .addBroadcast(ControlMessageBroadcaster.class,
+            BroadcastOperatorSpec.newBuilder()
+                .setSenderId(MasterTask.TASK_ID)
+                .setDataCodecClass(SerializableCodec.class)
+                .build())
+        .addBroadcast(ModelBroadcaster.class,
+            BroadcastOperatorSpec.newBuilder()
+                .setSenderId(MasterTask.TASK_ID)
+                .setDataCodecClass(SerializableCodec.class)
+                .build())
+        .addReduce(LossAndGradientReducer.class,
+            ReduceOperatorSpec.newBuilder()
+                .setReceiverId(MasterTask.TASK_ID)
+                .setDataCodecClass(SerializableCodec.class)
+                .setReduceFunctionClass(LossAndGradientReduceFunction.class)
+                .build())
+        .addBroadcast(ModelAndDescentDirectionBroadcaster.class,
+            BroadcastOperatorSpec.newBuilder()
+                .setSenderId(MasterTask.TASK_ID)
+                .setDataCodecClass(SerializableCodec.class)
+                .build())
+        .addBroadcast(DescentDirectionBroadcaster.class,
+            BroadcastOperatorSpec.newBuilder()
+                .setSenderId(MasterTask.TASK_ID)
+                .setDataCodecClass(SerializableCodec.class)
+                .build())
+        .addReduce(LineSearchEvaluationsReducer.class,
+            ReduceOperatorSpec.newBuilder()
+                .setReceiverId(MasterTask.TASK_ID)
+                .setDataCodecClass(SerializableCodec.class)
+                .setReduceFunctionClass(LineSearchReduceFunction.class)
+                .build())
+        .addBroadcast(MinEtaBroadcaster.class,
+            BroadcastOperatorSpec.newBuilder()
+                .setSenderId(MasterTask.TASK_ID)
+                .setDataCodecClass(SerializableCodec.class)
+                .build())
+        .finalise();
+
+    LOG.log(Level.INFO, "Added operators to communicationsGroup");
+  }
+
+  /**
+   * Handles a new active context.
+   * <p>
+   * If the context is not yet configured with group communication, this submits the service.
+   * Otherwise it submits a task on the context.
+   */
+  final class ContextActiveHandler implements EventHandler<ActiveContext> {
+
+    @Override
+    public void onNext(final ActiveContext activeContext) {
+      LOG.log(Level.INFO, "Got active context: {0}", activeContext.getId());
+      if (jobRunning(activeContext)) {
+        if (!groupCommDriver.isConfigured(activeContext)) {
+          // The Context is not configured with the group communications service let's do that.
+          submitGroupCommunicationsService(activeContext);
+        } else {
+          // The group communications service is already active on this context. We can submit the task.
+          submitTask(activeContext);
+        }
+      }
+    }
+
+    /**
+     * Submits a group communication context and service on top of the given context.
+     * <p>
+     * Data-loaded contexts also get the {@link ExampleList} service, with {@link SVMLightParser} bound as
+     * the {@link Parser}. The context without loaded data is recorded as the master context.
+     *
+     * @param activeContext a context to be configured with group communications.
+     */
+    private void submitGroupCommunicationsService(final ActiveContext activeContext) {
+      final Configuration contextConf = groupCommDriver.getContextConfiguration();
+      final String contextId = getContextId(contextConf);
+      final Configuration serviceConf;
+      if (!dataLoadingService.isDataLoadedContext(activeContext)) {
+        communicationsGroupMasterContextId = contextId;
+        serviceConf = groupCommDriver.getServiceConfiguration();
+      } else {
+        final Configuration parsedDataServiceConf = ServiceConfiguration.CONF
+            .set(ServiceConfiguration.SERVICES, ExampleList.class)
+            .build();
+        serviceConf = TANG
+            .newConfigurationBuilder(groupCommDriver.getServiceConfiguration(), parsedDataServiceConf)
+            .bindImplementation(Parser.class, SVMLightParser.class)
+            .build();
+      }
+
+      LOG.log(Level.FINEST, "Submit GCContext conf: {0} and Service conf: {1}", new Object[]{
+          confSerializer.toString(contextConf), confSerializer.toString(serviceConf)});
+
+      activeContext.submitContextAndService(contextConf, serviceConf);
+    }
+
+    /**
+     * Adds a task to the communication group and submits it on the given, already configured, context.
+     * The first task submitted on the master context is the {@link MasterTask}; every other task is a slave.
+     */
+    private void submitTask(final ActiveContext activeContext) {
+
+      assert (groupCommDriver.isConfigured(activeContext));
+
+      final Configuration partialTaskConfiguration;
+      // masterTaskSubmitted() flips the flag, so it is only evaluated for the master context.
+      if (activeContext.getId().equals(communicationsGroupMasterContextId) && !masterTaskSubmitted()) {
+        partialTaskConfiguration = getMasterTaskConfiguration();
+        LOG.info("Submitting MasterTask conf");
+      } else {
+        // For fault-injection experiments, merge getTaskPoisonConfiguration() into this configuration.
+        partialTaskConfiguration = getSlaveTaskConfiguration(getSlaveId(activeContext));
+        LOG.info("Submitting SlaveTask conf");
+      }
+      communicationsGroup.addTask(partialTaskConfiguration);
+      final Configuration taskConfiguration = groupCommDriver.getTaskConfiguration(partialTaskConfiguration);
+      LOG.log(Level.FINEST, "{0}", confSerializer.toString(taskConfiguration));
+      activeContext.submitTask(taskConfiguration);
+    }
+
+    /**
+     * @return true if the job is still running; otherwise closes the given context and returns false.
+     */
+    private boolean jobRunning(final ActiveContext activeContext) {
+      synchronized (runningTasks) {
+        if (!jobComplete.get()) {
+          return true;
+        } else {
+          LOG.log(Level.INFO, "Job complete. Not submitting any task. Closing context {0}", activeContext);
+          activeContext.close();
+          return false;
+        }
+      }
+    }
+  }
+
+  /**
+   * Records running tasks while the job is in progress.
+   * If the job has already completed, closes the task's context instead.
+   */
+  final class TaskRunningHandler implements EventHandler<RunningTask> {
+
+    @Override
+    public void onNext(final RunningTask runningTask) {
+      synchronized (runningTasks) {
+        if (!jobComplete.get()) {
+          LOG.log(Level.INFO, "Job has not completed yet. Adding to runningTasks: {0}", runningTask);
+          runningTasks.put(runningTask.getId(), runningTask);
+        } else {
+          LOG.log(Level.INFO, "Job complete. Closing context: {0}", runningTask.getActiveContext().getId());
+          runningTask.getActiveContext().close();
+        }
+      }
+    }
+  }
+
+  /**
+   * Resubmits a failed task, with the same id and a slave task configuration, on its context.
+   * This only happens while the job is running and the context is still available.
+   */
+  final class TaskFailedHandler implements EventHandler<FailedTask> {
+
+    @Override
+    public void onNext(final FailedTask failedTask) {
+
+      final String failedTaskId = failedTask.getId();
+
+      LOG.log(Level.WARNING, "Got failed Task: {0}", failedTaskId);
+
+      if (jobRunning(failedTaskId)) {
+
+        if (!failedTask.getActiveContext().isPresent()) {
+          // No context to resubmit on (presumably the evaluator itself is gone). Forget the task so it
+          // is not closed again when the master completes.
+          LOG.log(Level.WARNING, "No active context for failed task {0}. Not resubmitting", failedTaskId);
+          synchronized (runningTasks) {
+            runningTasks.remove(failedTaskId);
+          }
+          return;
+        }
+
+        final ActiveContext activeContext = failedTask.getActiveContext().get();
+        final Configuration partialTaskConf = getSlaveTaskConfiguration(failedTaskId);
+
+        // Deliberately not calling communicationsGroup.addTask(partialTaskConf) again:
+        // submitTask() already added this task id to the group when it was first submitted.
+
+        final Configuration taskConf = groupCommDriver.getTaskConfiguration(partialTaskConf);
+        LOG.log(Level.FINEST, "Submit SlaveTask conf: {0}", confSerializer.toString(taskConf));
+
+        activeContext.submitTask(taskConf);
+      }
+    }
+
+    /**
+     * @return true if the job is still running. Otherwise removes the failed task from runningTasks,
+     * closes its context if it is still tracked, and returns false.
+     */
+    private boolean jobRunning(final String failedTaskId) {
+      synchronized (runningTasks) {
+        if (!jobComplete.get()) {
+          return true;
+        } else {
+          final RunningTask rTask = runningTasks.remove(failedTaskId);
+          LOG.log(Level.INFO, "Job has completed. Not resubmitting");
+          if (rTask != null) {
+            LOG.log(Level.INFO, "Closing activecontext");
+            rTask.getActiveContext().close();
+          } else {
+            LOG.log(Level.INFO, "Master must have closed my context");
+          }
+          return false;
+        }
+      }
+    }
+  }
+
+  /**
+   * Logs the losses returned by a completed task, closes its context and, when the master task completes,
+   * marks the job complete and closes every remaining running task's context.
+   */
+  final class TaskCompletedHandler implements EventHandler<CompletedTask> {
+
+    @Override
+    public void onNext(final CompletedTask task) {
+      LOG.log(Level.INFO, "Got CompletedTask: {0}", task.getId());
+      final byte[] retVal = task.get();
+      if (retVal != null) {
+        final List<Double> losses = BGDDriver.this.lossCodec.decode(retVal);
+        for (final Double loss : losses) {
+          LOG.log(Level.INFO, "OUT: LOSS = {0}", loss);
+        }
+      }
+      synchronized (runningTasks) {
+        LOG.log(Level.INFO, "Acquired lock on runningTasks. Removing {0}", task.getId());
+        final RunningTask rTask = runningTasks.remove(task.getId());
+        if (rTask != null) {
+          LOG.log(Level.INFO, "Closing active context: {0}", task.getActiveContext().getId());
+          task.getActiveContext().close();
+        } else {
+          LOG.log(Level.INFO, "Master must have closed active context already for task {0}", task.getId());
+        }
+
+        if (MasterTask.TASK_ID.equals(task.getId())) {
+          jobComplete.set(true);
+          LOG.log(Level.INFO, "Master(=>Job) complete. Closing other running tasks: {0}", runningTasks.values());
+          for (final RunningTask runTask : runningTasks.values()) {
+            runTask.getActiveContext().close();
+          }
+          LOG.finest("Clearing runningTasks");
+          runningTasks.clear();
+        }
+      }
+    }
+  }
+
+  /**
+   * @return Configuration for the MasterTask, merged with the BGD control parameters.
+   */
+  public Configuration getMasterTaskConfiguration() {
+    return Configurations.merge(
+        TaskConfiguration.CONF
+            .set(TaskConfiguration.IDENTIFIER, MasterTask.TASK_ID)
+            .set(TaskConfiguration.TASK, MasterTask.class)
+            .build(),
+        bgdControlParameters.getConfiguration());
+  }
+
+  /**
+   * Builds a SlaveTask configuration.
+   * <p>
+   * The per-task failure probability is {@code 1 - pSuccess^(1 / numberOfPartitions)}, where
+   * {@code pSuccess} is the configured probability of a successful iteration.
+   *
+   * @param taskId identifier of the slave task.
+   * @return Configuration for the SlaveTask
+   */
+  private Configuration getSlaveTaskConfiguration(final String taskId) {
+    final double pSuccess = bgdControlParameters.getProbOfSuccessfulIteration();
+    final int numberOfPartitions = dataLoadingService.getNumberOfPartitions();
+    final double pFailure = 1 - Math.pow(pSuccess, 1.0 / numberOfPartitions);
+    return TANG
+        .newConfigurationBuilder(
+            TaskConfiguration.CONF
+                .set(TaskConfiguration.IDENTIFIER, taskId)
+                .set(TaskConfiguration.TASK, SlaveTask.class)
+                .build())
+        .bindNamedParameter(ModelDimensions.class, String.valueOf(bgdControlParameters.getDimensions()))
+        .bindImplementation(LossFunction.class, bgdControlParameters.getLossFunction())
+        .bindNamedParameter(ProbabilityOfFailure.class, Double.toString(pFailure))
+        .build();
+  }
+
+  /**
+   * Fault-injection configuration ({@link PoisonedConfiguration}) with crash probability
+   * {@link #STARTUP_FAILURE_PROB}. It is not used by default; merge it into a slave task configuration
+   * to exercise task failures.
+   */
+  private Configuration getTaskPoisonConfiguration() {
+    return PoisonedConfiguration.TASK_CONF
+        .set(PoisonedConfiguration.CRASH_PROBABILITY, STARTUP_FAILURE_PROB)
+        .set(PoisonedConfiguration.CRASH_TIMEOUT, 1)
+        .build();
+  }
+
+  /**
+   * @return the {@link ContextIdentifier} bound in the given context configuration.
+   * @throws RuntimeException if the identifier cannot be injected.
+   */
+  private String getContextId(final Configuration contextConf) {
+    try {
+      return TANG.newInjector(contextConf).getNamedInstance(ContextIdentifier.class);
+    } catch (final InjectionException e) {
+      throw new RuntimeException("Unable to inject context identifier from context conf", e);
+    }
+  }
+
+  /**
+   * @param activeContext currently unused; ids are assigned sequentially regardless of the context.
+   * @return a fresh slave task id of the form {@code SlaveTask-N}.
+   */
+  private String getSlaveId(final ActiveContext activeContext) {
+    return "SlaveTask-" + slaveIds.getAndIncrement();
+  }
+
+  /**
+   * Atomically marks the master task as submitted.
+   *
+   * @return true if the master task had already been submitted before this call.
+   */
+  private boolean masterTaskSubmitted() {
+    return !masterSubmitted.compareAndSet(false, true);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDLocal.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDLocal.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDLocal.java
new file mode 100644
index 0000000..3a82314
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDLocal.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd;
+
+import org.apache.reef.client.LauncherStatus;
+import org.apache.reef.examples.group.utils.timer.Timer;
+import org.apache.reef.runtime.local.client.LocalRuntimeConfiguration;
+import org.apache.reef.tang.Configuration;
+
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Command-line entry point that runs the BGD example on REEF's local runtime.
+ */
+public class BGDLocal {
+
+  private static final Logger LOG = Logger.getLogger(BGDLocal.class.getName());
+
+  /** Thread count given to {@code LocalRuntimeConfiguration.NUMBER_OF_THREADS}. */
+  private static final int NUM_LOCAL_THREADS = 20;
+
+  /** Job timeout handed to {@code BGDClient.run}: ten {@code Timer.MINUTES}. */
+  private static final int TIMEOUT = 10 * Timer.MINUTES;
+
+  /**
+   * Builds a {@link BGDClient} from the command line, runs it on the local runtime and logs
+   * the resulting launcher status.
+   *
+   * @param args command-line arguments, parsed by {@code BGDClient.fromCommandLine}
+   * @throws Exception propagated from building or running the client
+   */
+  public static void main(final String[] args) throws Exception {
+    final BGDClient client = BGDClient.fromCommandLine(args);
+
+    final String threadCount = Integer.toString(NUM_LOCAL_THREADS);
+    final Configuration localRuntime = LocalRuntimeConfiguration.CONF
+        .set(LocalRuntimeConfiguration.NUMBER_OF_THREADS, threadCount)
+        .build();
+
+    final String userName = System.getProperty("user.name");
+    final String jobName = userName + "-ResourceAwareBGDLocal";
+
+    final LauncherStatus finalStatus = client.run(localRuntime, jobName, TIMEOUT);
+    LOG.log(Level.INFO, "OUT: Status = {0}", finalStatus);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDYarn.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDYarn.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDYarn.java
new file mode 100644
index 0000000..19d3b10
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/BGDYarn.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd;
+
+import org.apache.reef.client.LauncherStatus;
+import org.apache.reef.examples.group.utils.timer.Timer;
+import org.apache.reef.runtime.yarn.client.YarnClientConfiguration;
+import org.apache.reef.tang.Configuration;
+
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Command-line entry point that runs the BGD example on the YARN runtime.
+ */
+public class BGDYarn {
+
+  private static final Logger LOG = Logger.getLogger(BGDYarn.class.getName());
+
+  /** Job timeout handed to {@code BGDClient.run}: four {@code Timer.HOURS}. */
+  private static final int TIMEOUT = 4 * Timer.HOURS;
+
+  /**
+   * Builds a {@link BGDClient} from the command line, runs it on YARN (JVM heap slack 0.1)
+   * and logs the resulting launcher status.
+   *
+   * @param args command-line arguments, parsed by {@code BGDClient.fromCommandLine}
+   * @throws Exception propagated from building or running the client
+   */
+  public static void main(final String[] args) throws Exception {
+    final BGDClient client = BGDClient.fromCommandLine(args);
+
+    final String heapSlack = "0.1";
+    final Configuration yarnRuntime = YarnClientConfiguration.CONF
+        .set(YarnClientConfiguration.JVM_HEAP_SLACK, heapSlack)
+        .build();
+
+    final String userName = System.getProperty("user.name");
+    final String jobName = userName + "-BR-ResourceAwareBGD-YARN";
+
+    final LauncherStatus finalStatus = client.run(yarnRuntime, jobName, TIMEOUT);
+    LOG.log(Level.INFO, "OUT: Status = {0}", finalStatus);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ControlMessages.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ControlMessages.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ControlMessages.java
new file mode 100644
index 0000000..aeea56b
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ControlMessages.java
@@ -0,0 +1,30 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd;
+
+import java.io.Serializable;
+
+/**
+ * Commands the BGD master task broadcasts to the slave tasks; each one tells the slaves what
+ * to receive next and what to compute.
+ */
+public enum ControlMessages implements Serializable {
+  /** Receive a fresh model, then reduce loss, example count and gradient on it. */
+  ComputeGradientWithModel,
+  /** Receive a step size, apply it along the last descent direction, then reduce loss and gradient. */
+  ComputeGradientWithMinEta,
+  /** Receive a descent direction, then reduce the line-search losses. */
+  DoLineSearch,
+  /** Receive model and descent direction together, then reduce the line-search losses. */
+  DoLineSearchWithModel,
+  /** Not sent by MasterTask nor handled by SlaveTask in this example (SlaveTask ignores it). */
+  Synchronize,
+  /** Terminate the slave's control loop. */
+  Stop;
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ExampleList.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ExampleList.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ExampleList.java
new file mode 100644
index 0000000..97477a9
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/ExampleList.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.reef.examples.group.bgd.data.Example;
+import org.apache.reef.examples.group.bgd.data.parser.Parser;
+import org.apache.reef.io.data.loading.api.DataSet;
+import org.apache.reef.io.network.util.Pair;
+
+import javax.inject.Inject;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Lazily parses the records of this evaluator's {@link DataSet} into {@link Example}s.
+ * <p>
+ * The data set is parsed on the first call to {@link #getExamples()} and cached afterwards.
+ * The class is not synchronized; presumably each instance is used by a single task thread
+ * (TODO confirm).
+ */
+public class ExampleList {
+
+  private static final Logger LOG = Logger.getLogger(ExampleList.class.getName());
+
+  private final List<Example> examples = new ArrayList<>();
+  private final DataSet<LongWritable, Text> dataSet;
+  private final Parser<String> parser;
+
+  /**
+   * Whether the data set has been parsed completely. Tracked explicitly rather than inferred
+   * from {@code examples.isEmpty()}, which would re-read an empty partition on every call.
+   */
+  private boolean loaded = false;
+
+  @Inject
+  public ExampleList(final DataSet<LongWritable, Text> dataSet, final Parser<String> parser) {
+    this.dataSet = dataSet;
+    this.parser = parser;
+  }
+
+  /**
+   * Returns the parsed examples, parsing the data set on first use.
+   *
+   * @return the examples; the same list instance is returned on every call
+   */
+  public List<Example> getExamples() {
+    if (!loaded) {
+      loadData();
+    }
+    return examples;
+  }
+
+  /**
+   * Parses every record of the data set into {@link #examples}. The list is cleared first so
+   * that a retry after a failed, partial load does not append duplicates.
+   */
+  private void loadData() {
+    LOG.info("Loading data");
+    examples.clear();
+    int i = 0;
+    for (final Pair<LongWritable, Text> examplePair : dataSet) {
+      final Example example = parser.parse(examplePair.second.toString());
+      examples.add(example);
+      if (++i % 2000 == 0) {
+        LOG.log(Level.FINE, "Done parsing {0} lines", i);
+      }
+    }
+    // Only mark as loaded once the whole data set was parsed without an exception.
+    loaded = true;
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LineSearchReduceFunction.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LineSearchReduceFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LineSearchReduceFunction.java
new file mode 100644
index 0000000..9132583
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LineSearchReduceFunction.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd;
+
+import org.apache.reef.examples.group.utils.math.DenseVector;
+import org.apache.reef.examples.group.utils.math.Vector;
+import org.apache.reef.io.network.group.api.operators.Reduce;
+import org.apache.reef.io.network.util.Pair;
+
+import javax.inject.Inject;
+
+/**
+ * Reduce function for the line-search step: element-wise sums the partial loss vectors
+ * (one entry per candidate step size, as produced by SlaveTask) and adds up the example counts.
+ */
+public class LineSearchReduceFunction implements Reduce.ReduceFunction<Pair<Vector, Integer>> {
+
+  @Inject
+  public LineSearchReduceFunction() {
+  }
+
+  /**
+   * Combines partial line-search results.
+   *
+   * @param evals partial (loss vector, example count) pairs
+   * @return the summed pair; the vector is a new {@code DenseVector} copied from the first input
+   *     with the others added, or {@code null} (with count 0) when {@code evals} is empty
+   */
+  @Override
+  public Pair<Vector, Integer> apply(final Iterable<Pair<Vector, Integer>> evals) {
+    Vector summedLosses = null;
+    int totalExamples = 0;
+
+    for (final Pair<Vector, Integer> partial : evals) {
+      if (summedLosses != null) {
+        summedLosses.add(partial.first);
+      } else {
+        // Copy the first vector so the caller's input is never mutated.
+        summedLosses = new DenseVector(partial.first);
+      }
+      totalExamples += partial.second;
+    }
+
+    return new Pair<>(summedLosses, totalExamples);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LossAndGradientReduceFunction.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LossAndGradientReduceFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LossAndGradientReduceFunction.java
new file mode 100644
index 0000000..cf4d0be
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/LossAndGradientReduceFunction.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd;
+
+import org.apache.reef.examples.group.utils.math.DenseVector;
+import org.apache.reef.examples.group.utils.math.Vector;
+import org.apache.reef.io.network.group.api.operators.Reduce.ReduceFunction;
+import org.apache.reef.io.network.util.Pair;
+
+import javax.inject.Inject;
+
+/**
+ * Reduce function for the gradient step: sums the partial losses, example counts and
+ * gradient vectors sent by the slave tasks.
+ */
+public class LossAndGradientReduceFunction
+    implements ReduceFunction<Pair<Pair<Double, Integer>, Vector>> {
+
+  @Inject
+  public LossAndGradientReduceFunction() {
+  }
+
+  /**
+   * Combines partial ((loss, example count), gradient) triples.
+   *
+   * @param lags partial results from the slaves
+   * @return ((summed loss, summed count), summed gradient); the gradient is a new
+   *     {@code DenseVector} copied from the first input, or {@code null} when {@code lags} is
+   *     empty (the loss is then 0.0 and the count 0)
+   */
+  @Override
+  public Pair<Pair<Double, Integer>, Vector> apply(
+      final Iterable<Pair<Pair<Double, Integer>, Vector>> lags) {
+
+    Vector gradientSum = null;
+    double totalLoss = 0.0;
+    int totalExamples = 0;
+
+    for (final Pair<Pair<Double, Integer>, Vector> partial : lags) {
+      if (gradientSum != null) {
+        gradientSum.add(partial.second);
+      } else {
+        // Copy the first gradient so the caller's input is never mutated.
+        gradientSum = new DenseVector(partial.second);
+      }
+      final Pair<Double, Integer> lossAndCount = partial.first;
+      totalLoss += lossAndCount.first;
+      totalExamples += lossAndCount.second;
+    }
+
+    final Pair<Double, Integer> combinedLossAndCount = new Pair<>(totalLoss, totalExamples);
+    return new Pair<>(combinedLossAndCount, gradientSum);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/MasterTask.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/MasterTask.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/MasterTask.java
new file mode 100644
index 0000000..06ed5fd
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/MasterTask.java
@@ -0,0 +1,246 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd;
+
+import org.apache.reef.examples.group.bgd.operatornames.*;
+import org.apache.reef.examples.group.bgd.parameters.*;
+import org.apache.reef.examples.group.bgd.utils.StepSizes;
+import org.apache.reef.examples.group.utils.math.DenseVector;
+import org.apache.reef.examples.group.utils.math.Vector;
+import org.apache.reef.examples.group.utils.timer.Timer;
+import org.apache.reef.exception.evaluator.NetworkException;
+import org.apache.reef.io.Tuple;
+import org.apache.reef.io.network.group.api.operators.Broadcast;
+import org.apache.reef.io.network.group.api.operators.Reduce;
+import org.apache.reef.io.network.group.api.GroupChanges;
+import org.apache.reef.io.network.group.api.task.CommunicationGroupClient;
+import org.apache.reef.io.network.group.api.task.GroupCommClient;
+import org.apache.reef.io.network.util.Pair;
+import org.apache.reef.io.serialization.Codec;
+import org.apache.reef.io.serialization.SerializableCodec;
+import org.apache.reef.tang.annotations.Parameter;
+import org.apache.reef.task.Task;
+
+import javax.inject.Inject;
+import java.util.ArrayList;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * Master task of the elastic batch gradient descent (BGD) example.
+ * <p>
+ * Each iteration the master broadcasts a control message together with either the whole model
+ * or only the last step size ({@code minEta}) and reduces the slaves' summed loss, example
+ * count and gradient. It turns the averaged gradient into an L2-regularized descent direction,
+ * broadcasts that direction (with the model when required), and reduces the slaves' line-search
+ * losses, one per step size from {@link StepSizes}. It then adds the regularizer, picks the step
+ * size with the smallest objective and updates its local model.
+ * <p>
+ * After every reduce the communication group is checked for topology changes. When some exist
+ * the next broadcast carries the full model again (presumably so that slaves which joined get
+ * it). The task returns the per-iteration objective values encoded with a
+ * {@link SerializableCodec}.
+ */
+public class MasterTask implements Task {
+
+  public static final String TASK_ID = "MasterTask";
+
+  private static final Logger LOG = Logger.getLogger(MasterTask.class.getName());
+
+  private final CommunicationGroupClient communicationGroupClient;
+  private final Broadcast.Sender<ControlMessages> controlMessageBroadcaster;
+  private final Broadcast.Sender<Vector> modelBroadcaster;
+  private final Reduce.Receiver<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer;
+  private final Broadcast.Sender<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster;
+  private final Broadcast.Sender<Vector> descentDirectionBroadcaster;
+  private final Reduce.Receiver<Pair<Vector, Integer>> lineSearchEvaluationsReducer;
+  private final Broadcast.Sender<Double> minEtaBroadcaster;
+  /**
+   * When true, a result obtained although the topology changed is accepted; the operation is
+   * repeated only when the reduce returned {@code null}. When false, it is also repeated after
+   * every topology change.
+   */
+  private final boolean ignoreAndContinue;
+  private final StepSizes ts;
+  /** L2 regularization weight. */
+  private final double lambda;
+  private final int maxIters;
+  /** Objective value of every iteration, returned from {@link #call(byte[])}. */
+  final ArrayList<Double> losses = new ArrayList<>();
+  final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<>();
+  private final Vector model;
+
+  /** Whether the next broadcast must carry the full model instead of just minEta. */
+  boolean sendModel = true;
+  /** Step size chosen by the most recent line search. */
+  double minEta = 0;
+
+  @Inject
+  public MasterTask(
+      final GroupCommClient groupCommClient,
+      @Parameter(ModelDimensions.class) final int dimensions,
+      @Parameter(Lambda.class) final double lambda,
+      @Parameter(Iterations.class) final int maxIters,
+      @Parameter(EnableRampup.class) final boolean rampup,
+      final StepSizes ts) {
+
+    this.lambda = lambda;
+    this.maxIters = maxIters;
+    this.ts = ts;
+    this.ignoreAndContinue = rampup;
+    this.model = new DenseVector(dimensions);
+    this.communicationGroupClient = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class);
+    this.controlMessageBroadcaster = communicationGroupClient.getBroadcastSender(ControlMessageBroadcaster.class);
+    this.modelBroadcaster = communicationGroupClient.getBroadcastSender(ModelBroadcaster.class);
+    this.lossAndGradientReducer = communicationGroupClient.getReduceReceiver(LossAndGradientReducer.class);
+    this.modelAndDescentDirectionBroadcaster = communicationGroupClient.getBroadcastSender(ModelAndDescentDirectionBroadcaster.class);
+    this.descentDirectionBroadcaster = communicationGroupClient.getBroadcastSender(DescentDirectionBroadcaster.class);
+    this.lineSearchEvaluationsReducer = communicationGroupClient.getReduceReceiver(LineSearchEvaluationsReducer.class);
+    this.minEtaBroadcaster = communicationGroupClient.getBroadcastSender(MinEtaBroadcaster.class);
+  }
+
+  /**
+   * Runs gradient-descent iterations until {@link #converged(int, double)} holds, then
+   * broadcasts {@link ControlMessages#Stop}.
+   *
+   * @param memento ignored
+   * @return the per-iteration objective values, encoded with {@link #lossCodec}
+   * @throws Exception propagated from the group-communication operators
+   */
+  @Override
+  public byte[] call(final byte[] memento) throws Exception {
+
+    double gradientNorm = Double.MAX_VALUE;
+    for (int iteration = 1; !converged(iteration, gradientNorm); ++iteration) {
+      try (final Timer t = new Timer("Current Iteration(" + iteration + ")")) {
+        final Pair<Double, Vector> lossAndGradient = computeLossAndGradient();
+        losses.add(lossAndGradient.first);
+        final Vector descentDirection = getDescentDirection(lossAndGradient.second);
+
+        updateModel(descentDirection);
+
+        gradientNorm = descentDirection.norm2();
+      }
+    }
+    LOG.log(Level.INFO, "OUT: Stop");
+    controlMessageBroadcaster.send(ControlMessages.Stop);
+
+    for (final Double loss : losses) {
+      LOG.log(Level.INFO, "OUT: LOSS = {0}", loss);
+    }
+    return lossCodec.encode(losses);
+  }
+
+  /**
+   * Runs the distributed line search along {@code descentDirection} and moves the model by the
+   * best step: {@code model += minEta * descentDirection}.
+   */
+  private void updateModel(final Vector descentDirection) throws NetworkException, InterruptedException {
+    try (final Timer t = new Timer("GetDescentDirection + FindMinEta + UpdateModel")) {
+      final Vector lineSearchEvals = lineSearch(descentDirection);
+      minEta = findMinEta(model, descentDirection, lineSearchEvals);
+      model.multAdd(minEta, descentDirection);
+    }
+
+    LOG.log(Level.INFO, "OUT: New Model = {0}", model);
+  }
+
+  /**
+   * Broadcasts the descent direction (with the model when {@link #sendModel} is set) and
+   * reduces the slaves' line-search losses, retrying as described on {@link #ignoreAndContinue}.
+   *
+   * @return the summed losses divided by the total example count, one entry per step size
+   */
+  private Vector lineSearch(final Vector descentDirection) throws NetworkException, InterruptedException {
+    Vector lineSearchResults = null;
+    boolean allDead = false;
+    do {
+      try (final Timer t = new Timer("LineSearch - Broadcast("
+          + (sendModel ? "ModelAndDescentDirection" : "DescentDirection") + ") + Reduce(LossEvalsInLineSearch)")) {
+        if (sendModel) {
+          LOG.log(Level.INFO, "OUT: DoLineSearchWithModel");
+          controlMessageBroadcaster.send(ControlMessages.DoLineSearchWithModel);
+          modelAndDescentDirectionBroadcaster.send(new Pair<>(model, descentDirection));
+        } else {
+          LOG.log(Level.INFO, "OUT: DoLineSearch");
+          controlMessageBroadcaster.send(ControlMessages.DoLineSearch);
+          descentDirectionBroadcaster.send(descentDirection);
+        }
+        final Pair<Vector, Integer> lineSearchEvals = lineSearchEvaluationsReducer.reduce();
+        if (lineSearchEvals != null) {
+          final int numExamples = lineSearchEvals.second;
+          lineSearchResults = lineSearchEvals.first;
+          lineSearchResults.scale(1.0 / numExamples);
+          LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples);
+          LOG.log(Level.INFO, "OUT: LineSearchEvals: {0}", lineSearchResults);
+          allDead = false;
+        } else {
+          // A null reduce result means no slave contributed; the round must be repeated.
+          allDead = true;
+        }
+      }
+
+      sendModel = chkAndUpdate();
+    } while (allDead || (!ignoreAndContinue && sendModel));
+    return lineSearchResults;
+  }
+
+  /**
+   * Broadcasts the model (or only minEta) and reduces loss and gradient from the slaves,
+   * retrying as described on {@link #ignoreAndContinue}.
+   *
+   * @return (regularized objective value, gradient averaged over all examples)
+   */
+  private Pair<Double, Vector> computeLossAndGradient() throws NetworkException, InterruptedException {
+    Pair<Double, Vector> returnValue = null;
+    boolean allDead = false;
+    do {
+      try (final Timer t = new Timer("Broadcast(" + (sendModel ? "Model" : "MinEta") + ") + Reduce(LossAndGradient)")) {
+        if (sendModel) {
+          LOG.log(Level.INFO, "OUT: ComputeGradientWithModel");
+          controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithModel);
+          modelBroadcaster.send(model);
+        } else {
+          LOG.log(Level.INFO, "OUT: ComputeGradientWithMinEta");
+          controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithMinEta);
+          minEtaBroadcaster.send(minEta);
+        }
+        final Pair<Pair<Double, Integer>, Vector> lossAndGradient = lossAndGradientReducer.reduce();
+
+        if (lossAndGradient != null) {
+          final int numExamples = lossAndGradient.first.second;
+          LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples);
+          final double lossPerExample = lossAndGradient.first.first / numExamples;
+          LOG.log(Level.INFO, "OUT: Loss: {0}", lossPerExample);
+          final double objFunc = ((lambda / 2) * model.norm2Sqr()) + lossPerExample;
+          LOG.log(Level.INFO, "OUT: Objective Func Value: {0}", objFunc);
+          final Vector gradient = lossAndGradient.second;
+          gradient.scale(1.0 / numExamples);
+          LOG.log(Level.INFO, "OUT: Gradient: {0}", gradient);
+          returnValue = new Pair<>(objFunc, gradient);
+          allDead = false;
+        } else {
+          // A null reduce result means no slave contributed; the round must be repeated.
+          allDead = true;
+        }
+      }
+      sendModel = chkAndUpdate();
+    } while (allDead || (!ignoreAndContinue && sendModel));
+    return returnValue;
+  }
+
+  /**
+   * Applies pending topology changes of the communication group, if there are any.
+   *
+   * @return {@code true} if the topology changed and was updated, {@code false} otherwise
+   */
+  private boolean chkAndUpdate() {
+    final long changesStart = System.currentTimeMillis();
+    final GroupChanges changes = communicationGroupClient.getTopologyChanges();
+    final long changesEnd = System.currentTimeMillis();
+    LOG.log(Level.INFO, "OUT: Time to get TopologyChanges = " + (changesEnd - changesStart) / 1000.0 + " sec");
+    if (!changes.exist()) {
+      LOG.log(Level.INFO, "OUT: No changes in topology exist. So not updating topology");
+      return false;
+    }
+    LOG.log(Level.INFO, "OUT: There exist topology changes. Asking to update Topology");
+    final long updateStart = System.currentTimeMillis();
+    communicationGroupClient.updateTopology();
+    final long updateEnd = System.currentTimeMillis();
+    // This measures updateTopology(), not getTopologyChanges(); label it accordingly.
+    LOG.log(Level.INFO, "OUT: Time to update Topology = " + (updateEnd - updateStart) / 1000.0 + " sec");
+    return true;
+  }
+
+  /**
+   * Stopping criterion.
+   * <p>
+   * NOTE(review): iterations are numbered from 1 and this returns true once
+   * {@code iters >= maxIters}, so at most {@code maxIters - 1} iterations run. Confirm whether
+   * that is intended.
+   */
+  private boolean converged(final int iters, final double gradNorm) {
+    return iters >= maxIters || Math.abs(gradNorm) <= 1e-3;
+  }
+
+  /**
+   * Adds the regularizer {@code lambda/2 * ||w + eta*d||^2} to each line-search loss (in place)
+   * and returns the step size with the smallest regularized value.
+   *
+   * @param currentModel the model w the line search started from
+   * @param descentDir the descent direction d
+   * @param lineSearchEvals average loss per step size; overwritten with the regularized values
+   * @return the step size from {@link StepSizes#getT()} with the minimal regularized objective
+   */
+  private double findMinEta(final Vector currentModel, final Vector descentDir, final Vector lineSearchEvals) {
+    final double wNormSqr = currentModel.norm2Sqr();
+    final double dNormSqr = descentDir.norm2Sqr();
+    final double wDotd = currentModel.dot(descentDir);
+    final double[] t = ts.getT();
+    int i = 0;
+    for (final double eta : t) {
+      // ||w + eta*d||^2 expanded, so the candidate model never has to be materialized.
+      final double modelNormSqr = wNormSqr + (eta * eta) * dNormSqr + 2 * eta * wDotd;
+      final double loss = lineSearchEvals.get(i) + ((lambda / 2) * modelNormSqr);
+      lineSearchEvals.set(i, loss);
+      ++i;
+    }
+    LOG.log(Level.INFO, "OUT: Regularized LineSearchEvals: {0}", lineSearchEvals);
+    final Tuple<Integer, Double> minTup = lineSearchEvals.min();
+    LOG.log(Level.INFO, "OUT: MinTup: {0}", minTup);
+    final double minT = t[minTup.getKey()];
+    LOG.log(Level.INFO, "OUT: MinT: {0}", minT);
+    return minT;
+  }
+
+  /**
+   * Turns the averaged loss gradient into the descent direction
+   * {@code -(gradient + lambda * model)}. The given vector is modified in place and returned.
+   */
+  private Vector getDescentDirection(final Vector gradient) {
+    gradient.multAdd(lambda, model);
+    gradient.scale(-1);
+    LOG.log(Level.INFO, "OUT: DescentDirection: {0}", gradient);
+    return gradient;
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/SlaveTask.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/SlaveTask.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/SlaveTask.java
new file mode 100644
index 0000000..fadc16e
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/SlaveTask.java
@@ -0,0 +1,204 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd;
+
+import org.apache.reef.examples.group.bgd.data.Example;
+import org.apache.reef.examples.group.bgd.loss.LossFunction;
+import org.apache.reef.examples.group.bgd.operatornames.*;
+import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup;
+import org.apache.reef.examples.group.bgd.parameters.ProbabilityOfFailure;
+import org.apache.reef.examples.group.bgd.utils.StepSizes;
+import org.apache.reef.examples.group.utils.math.DenseVector;
+import org.apache.reef.examples.group.utils.math.Vector;
+import org.apache.reef.io.network.group.api.operators.Broadcast;
+import org.apache.reef.io.network.group.api.operators.Reduce;
+import org.apache.reef.io.network.group.api.task.CommunicationGroupClient;
+import org.apache.reef.io.network.group.api.task.GroupCommClient;
+import org.apache.reef.io.network.util.Pair;
+import org.apache.reef.tang.annotations.Parameter;
+import org.apache.reef.task.Task;
+
+import javax.inject.Inject;
+import java.util.List;
+import java.util.logging.Logger;
+
+/**
+ * Slave task of the elastic batch gradient descent (BGD) example.
+ * <p>
+ * After loading its local examples the task loops on {@link ControlMessages} broadcast by the
+ * master. Depending on the message it receives the model, the descent direction and/or the
+ * step size. It then computes either the summed loss and gradient or the summed line-search
+ * losses over its examples and sends the result back through a Reduce.
+ * {@link ControlMessages#Stop} ends the loop. Before each computing command the task may throw
+ * a simulated failure with the probability bound to {@link ProbabilityOfFailure}.
+ */
+public class SlaveTask implements Task {
+
+  private static final Logger LOG = Logger.getLogger(SlaveTask.class.getName());
+
+  /** Probability that {@code failPerhaps()} throws a simulated failure. */
+  private final double failureProbability;
+
+  private final CommunicationGroupClient communicationGroup;
+  private final Broadcast.Receiver<ControlMessages> controlMessageBroadcaster;
+  private final Broadcast.Receiver<Vector> modelBroadcaster;
+  private final Reduce.Sender<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer;
+  private final Broadcast.Receiver<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster;
+  private final Broadcast.Receiver<Vector> descentDirectionBroadcaster;
+  private final Reduce.Sender<Pair<Vector, Integer>> lineSearchEvaluationsReducer;
+  private final Broadcast.Receiver<Double> minEtaBroadcaster;
+  /** Local examples; populated by {@code loadData()}. */
+  private List<Example> examples = null;
+  private final ExampleList dataSet;
+  private final LossFunction lossFunction;
+  private final StepSizes ts;
+
+  /** Current model; replaced by model broadcasts and advanced in place by minEta steps. */
+  private Vector model = null;
+  /** Last descent direction received; scaled in place when a minEta step is applied. */
+  private Vector descentDirection = null;
+
+  @Inject
+  public SlaveTask(
+      final GroupCommClient groupCommClient,
+      final ExampleList dataSet,
+      final LossFunction lossFunction,
+      @Parameter(ProbabilityOfFailure.class) final double pFailure,
+      final StepSizes ts) {
+
+    this.dataSet = dataSet;
+    this.lossFunction = lossFunction;
+    this.failureProbability = pFailure;
+    LOG.info("Using pFailure=" + this.failureProbability);
+    this.ts = ts;
+
+    this.communicationGroup = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class);
+    this.controlMessageBroadcaster = communicationGroup.getBroadcastReceiver(ControlMessageBroadcaster.class);
+    this.modelBroadcaster = communicationGroup.getBroadcastReceiver(ModelBroadcaster.class);
+    this.lossAndGradientReducer = communicationGroup.getReduceSender(LossAndGradientReducer.class);
+    this.modelAndDescentDirectionBroadcaster = communicationGroup.getBroadcastReceiver(ModelAndDescentDirectionBroadcaster.class);
+    this.descentDirectionBroadcaster = communicationGroup.getBroadcastReceiver(DescentDirectionBroadcaster.class);
+    this.lineSearchEvaluationsReducer = communicationGroup.getReduceSender(LineSearchEvaluationsReducer.class);
+    this.minEtaBroadcaster = communicationGroup.getBroadcastReceiver(MinEtaBroadcaster.class);
+  }
+
+  /**
+   * Serves control messages from the master until {@link ControlMessages#Stop} arrives.
+   *
+   * @param memento ignored
+   * @return always {@code null}
+   * @throws Exception propagated from the group-communication operators, or the simulated
+   *     failure thrown by {@code failPerhaps()}
+   */
+  @Override
+  public byte[] call(final byte[] memento) throws Exception {
+    /*
+     * In the case where there will be evaluator failure and data is not in
+     * memory we want to load the data while waiting to join the communication
+     * group
+     */
+    loadData();
+
+    for (boolean repeat = true; repeat; ) {
+
+      final ControlMessages controlMessage = controlMessageBroadcaster.receive();
+      switch (controlMessage) {
+
+        case Stop:
+          repeat = false;
+          break;
+
+        case ComputeGradientWithModel:
+          failPerhaps();
+          this.model = modelBroadcaster.receive();
+          lossAndGradientReducer.send(computeLossAndGradient());
+          break;
+
+        case ComputeGradientWithMinEta: {
+          failPerhaps();
+          final double minEta = minEtaBroadcaster.receive();
+          // A minEta step needs both a model and a descent direction received earlier. Check
+          // explicitly: the former `assert`s were no-ops unless the JVM ran with -ea.
+          if (this.descentDirection == null || this.model == null) {
+            throw new IllegalStateException(
+                "ComputeGradientWithMinEta received before a model and a descent direction");
+          }
+          // model += minEta * descentDirection (descentDirection is scaled in place).
+          this.descentDirection.scale(minEta);
+          this.model.add(this.descentDirection);
+          lossAndGradientReducer.send(computeLossAndGradient());
+          break;
+        }
+
+        case DoLineSearch:
+          failPerhaps();
+          this.descentDirection = descentDirectionBroadcaster.receive();
+          lineSearchEvaluationsReducer.send(lineSearchEvals());
+          break;
+
+        case DoLineSearchWithModel: {
+          failPerhaps();
+          final Pair<Vector, Vector> modelAndDescentDir = modelAndDescentDirectionBroadcaster.receive();
+          this.model = modelAndDescentDir.first;
+          this.descentDirection = modelAndDescentDir.second;
+          lineSearchEvaluationsReducer.send(lineSearchEvals());
+          break;
+        }
+
+        default:
+          // Messages without a handler here (e.g. Synchronize) are ignored.
+          break;
+      }
+    }
+
+    return null;
+  }
+
+  /** Throws a simulated failure with probability {@link #failureProbability}. */
+  private void failPerhaps() {
+    if (Math.random() < failureProbability) {
+      throw new RuntimeException("Simulated Failure");
+    }
+  }
+
+  /**
+   * Computes, for every step size {@code eta} from {@link StepSizes#getT()}, the summed loss of
+   * the local examples under the candidate model {@code model + eta * descentDirection}.
+   *
+   * @return (losses indexed like the step sizes, number of local examples)
+   */
+  private Pair<Vector, Integer> lineSearchEvals() {
+
+    if (examples == null) {
+      loadData();
+    }
+
+    // Predictions are linear in the model, so predict(model + eta*d) equals
+    // predict(model) + eta * predict(d); compute both terms once per example.
+    final Vector modelPredictions = new DenseVector(examples.size());
+    final Vector directionPredictions = new DenseVector(examples.size());
+
+    for (int i = 0; i < examples.size(); i++) {
+      final Example example = examples.get(i);
+      modelPredictions.set(i, example.predict(model));
+      directionPredictions.set(i, example.predict(descentDirection));
+    }
+
+    final double[] t = ts.getT();
+    final Vector evaluations = new DenseVector(t.length);
+    int stepIndex = 0;
+    for (final double eta : t) {
+      double loss = 0;
+      for (int j = 0; j < examples.size(); j++) {
+        final Example example = examples.get(j);
+        final double prediction = modelPredictions.get(j) + eta * directionPredictions.get(j);
+        loss += this.lossFunction.computeLoss(example.getLabel(), prediction);
+      }
+      evaluations.set(stepIndex++, loss);
+    }
+
+    return new Pair<>(evaluations, examples.size());
+  }
+
+  /**
+   * Computes the summed loss and gradient of the local examples under the current model.
+   *
+   * @return ((summed loss, number of local examples), summed gradient)
+   */
+  private Pair<Pair<Double, Integer>, Vector> computeLossAndGradient() {
+
+    if (examples == null) {
+      loadData();
+    }
+
+    final Vector gradient = new DenseVector(model.size());
+    double loss = 0.0;
+    for (final Example example : examples) {
+      final double f = example.predict(model);
+      final double g = this.lossFunction.computeGradient(example.getLabel(), f);
+      example.addGradient(gradient, g);
+      loss += this.lossFunction.computeLoss(example.getLabel(), f);
+    }
+
+    return new Pair<>(new Pair<>(loss, examples.size()), gradient);
+  }
+
+  /** Fetches the local examples from the injected {@link ExampleList}. */
+  private void loadData() {
+    LOG.info("Loading data");
+    examples = dataSet.getExamples();
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/Example.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/Example.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/Example.java
new file mode 100644
index 0000000..2ec7146
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/Example.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd.data;
+
+import org.apache.reef.examples.group.utils.math.Vector;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for Examples for linear models.
+ */
+public interface Example extends Serializable {
+
+ /**
+ * Access to the label.
+ *
+ * @return the label
+ */
+ double getLabel();
+
+ /**
+ * Computes the prediction for this Example, given the model w.
+ * <p/>
+ * w.dot(this.getFeatures())
+ *
+ * @param w the model
+ * @return the prediction for this Example, given the model w.
+ */
+ double predict(Vector w);
+
+ /**
+ * Adds the current example's gradient to the gradientVector, assuming that
+ * the gradient with respect to the prediction is gradient.
+ */
+ void addGradient(Vector gradientVector, double gradient);
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/SparseExample.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/SparseExample.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/SparseExample.java
new file mode 100644
index 0000000..094f1d8
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/SparseExample.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd.data;
+
+import org.apache.reef.examples.group.utils.math.Vector;
+
+/**
+ * Example implementation on a index and value array.
+ */
+public final class SparseExample implements Example {
+
+ private static final long serialVersionUID = -2127500625316875426L;
+
+ private final float[] values;
+ private final int[] indices;
+ private final double label;
+
+ public SparseExample(final double label, final float[] values, final int[] indices) {
+ this.label = label;
+ this.values = values;
+ this.indices = indices;
+ }
+
+ public int getFeatureLength() {
+ return this.values.length;
+ }
+
+ @Override
+ public double getLabel() {
+ return this.label;
+ }
+
+ @Override
+ public double predict(final Vector w) {
+ double result = 0.0;
+ for (int i = 0; i < this.indices.length; ++i) {
+ result += w.get(this.indices[i]) * this.values[i];
+ }
+ return result;
+ }
+
+ @Override
+ public void addGradient(final Vector gradientVector, final double gradient) {
+ for (int i = 0; i < this.indices.length; ++i) {
+ final int index = this.indices[i];
+ final double contribution = gradient * this.values[i];
+ final double oldValue = gradientVector.get(index);
+ final double newValue = oldValue + contribution;
+ gradientVector.set(index, newValue);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/Parser.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/Parser.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/Parser.java
new file mode 100644
index 0000000..f4d8d09
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/Parser.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd.data.parser;
+
+import org.apache.reef.examples.group.bgd.data.Example;
+
+/**
+ * Parses inputs into Examples.
+ *
+ * @param <T>
+ */
+public interface Parser<T> {
+
+ public Example parse(final T input);
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/SVMLightParser.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/SVMLightParser.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/SVMLightParser.java
new file mode 100644
index 0000000..5f64606
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/data/parser/SVMLightParser.java
@@ -0,0 +1,98 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd.data.parser;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.reef.examples.group.bgd.data.Example;
+import org.apache.reef.examples.group.bgd.data.SparseExample;
+
+import javax.inject.Inject;
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * A Parser for SVMLight records
+ */
+public class SVMLightParser implements Parser<String> {
+
+ private static final Logger LOG = Logger.getLogger(SVMLightParser.class.getName());
+
+ @Inject
+ public SVMLightParser() {
+ }
+
+ @Override
+ public Example parse(final String line) {
+
+ final int entriesCount = StringUtils.countMatches(line, ":");
+ final int[] indices = new int[entriesCount];
+ final float[] values = new float[entriesCount];
+
+ final String[] entries = StringUtils.split(line, ' ');
+ String labelStr = entries[0];
+
+ final boolean pipeExists = labelStr.indexOf('|') != -1;
+ if (pipeExists) {
+ labelStr = labelStr.substring(0, labelStr.indexOf('|'));
+ }
+ double label = Double.parseDouble(labelStr);
+
+ if (label != 1) {
+ label = -1;
+ }
+
+ for (int j = 1; j < entries.length; ++j) {
+ final String x = entries[j];
+ final String[] entity = StringUtils.split(x, ':');
+ final int offset = pipeExists ? 0 : 1;
+ indices[j - 1] = Integer.parseInt(entity[0]) - offset;
+ values[j - 1] = Float.parseFloat(entity[1]);
+ }
+ return new SparseExample(label, values, indices);
+ }
+
+ public static void main(final String[] args) {
+ final Parser<String> parser = new SVMLightParser();
+ for (int i = 0; i < 10; i++) {
+ final List<SparseExample> examples = new ArrayList<>();
+ float avgFtLen = 0;
+ try (final BufferedReader br = new BufferedReader(new FileReader(
+ "C:\\Users\\shravan\\data\\splice\\hdi\\hdi_uncomp\\part-r-0000" + i))) {
+ String line = null;
+ while ((line = br.readLine()) != null) {
+ final SparseExample spEx = (SparseExample) parser.parse(line);
+ avgFtLen += spEx.getFeatureLength();
+ examples.add(spEx);
+ }
+ } catch (final IOException e) {
+ throw new RuntimeException("Exception", e);
+ }
+
+ LOG.log(Level.INFO, "OUT: {0} {1} {2}",
+ new Object[] { examples.size(), avgFtLen, avgFtLen / examples.size() });
+
+ examples.clear();
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LogisticLossFunction.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LogisticLossFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LogisticLossFunction.java
new file mode 100644
index 0000000..78eb16f
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LogisticLossFunction.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd.loss;
+
+import javax.inject.Inject;
+
+public final class LogisticLossFunction implements LossFunction {
+
+ /**
+ * Trivial constructor.
+ */
+ @Inject
+ public LogisticLossFunction() {
+ }
+
+ @Override
+ public double computeLoss(final double y, final double f) {
+ final double predictedTimesLabel = y * f;
+ return Math.log(1 + Math.exp(-predictedTimesLabel));
+ }
+
+ @Override
+ public double computeGradient(final double y, final double f) {
+ final double predictedTimesLabel = y * f;
+ return -y / (1 + Math.exp(predictedTimesLabel));
+ }
+
+ @Override
+ public String toString() {
+ return "LogisticLossFunction{}";
+ }
+}
+
+
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LossFunction.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LossFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LossFunction.java
new file mode 100644
index 0000000..e762add
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/LossFunction.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd.loss;
+
+import org.apache.reef.tang.annotations.DefaultImplementation;
+
+/**
+ * Interface for Loss Functions.
+ */
+@DefaultImplementation(SquaredErrorLossFunction.class)
+public interface LossFunction {
+
+ /**
+ * Computes the loss incurred by predicting f, if y is the true label.
+ *
+ * @param y the label
+ * @param f the prediction
+ * @return the loss incurred by predicting f, if y is the true label.
+ */
+ double computeLoss(final double y, final double f);
+
+ /**
+ * Computes the gradient with respect to f, if y is the true label.
+ *
+ * @param y the label
+ * @param f the prediction
+ * @return the gradient with respect to f
+ */
+ double computeGradient(final double y, final double f);
+}
http://git-wip-us.apache.org/repos/asf/incubator-reef/blob/6c6ad336/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/SquaredErrorLossFunction.java
----------------------------------------------------------------------
diff --git a/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/SquaredErrorLossFunction.java b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/SquaredErrorLossFunction.java
new file mode 100644
index 0000000..327f566
--- /dev/null
+++ b/lang/java/reef-examples/src/main/java/org/apache/reef/examples/group/bgd/loss/SquaredErrorLossFunction.java
@@ -0,0 +1,49 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.reef.examples.group.bgd.loss;
+
+import javax.inject.Inject;
+
+/**
+ * The Squared Error {@link LossFunction}.
+ */
+public class SquaredErrorLossFunction implements LossFunction {
+
+ /**
+ * Trivial constructor.
+ */
+ @Inject
+ public SquaredErrorLossFunction() {
+ }
+
+ @Override
+ public double computeLoss(double y, double f) {
+ return Math.pow(y - f, 2.0);
+ }
+
+ @Override
+ public double computeGradient(double y, double f) {
+ return (f - y) * 0.5;
+ }
+
+ @Override
+ public String toString() {
+ return "SquaredErrorLossFunction{}";
+ }
+}