Skip to content

Commit

Permalink
I changed the codes so that
Browse files Browse the repository at this point in the history
1) Tasks log more detailed information
2) reduce the dependency between the group communication service and the data loading service
3) reduce the number of fields using Optional
  • Loading branch information
Kijung-Shin committed Apr 28, 2015
1 parent 3237e68 commit 0478987
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 122 deletions.
11 changes: 8 additions & 3 deletions src/main/java/edu/snu/reef/flexion/core/ComputeTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import edu.snu.reef.flexion.groupcomm.interfaces.DataReduceSender;
import edu.snu.reef.flexion.groupcomm.interfaces.DataScatterReceiver;
import edu.snu.reef.flexion.groupcomm.names.*;
import org.apache.reef.driver.task.TaskConfigurationOptions;
import org.apache.reef.tang.annotations.Name;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.task.HeartBeatTriggerManager;
Expand All @@ -37,8 +38,10 @@
import java.util.logging.Logger;

public final class ComputeTask implements Task, TaskMessageSource {
public final static String TASK_ID_PREFIX = "CmpTask";
private final static Logger LOG = Logger.getLogger(ComputeTask.class.getName());
public final static String TASK_ID = "CmpTask";

private final String taskId;
private final UserComputeTask userComputeTask;
private final CommunicationGroupClient commGroup;
private final HeartBeatTriggerManager heartBeatTriggerManager;
Expand All @@ -49,17 +52,19 @@ public final class ComputeTask implements Task, TaskMessageSource {
@Inject
public ComputeTask(final GroupCommClient groupCommClient,
final UserComputeTask userComputeTask,
@Parameter(TaskConfigurationOptions.Identifier.class) String taskId,
@Parameter(CommunicationGroup.class) final String commGroupName,
final HeartBeatTriggerManager heartBeatTriggerManager) throws ClassNotFoundException {
this.userComputeTask = userComputeTask;
this.taskId = taskId;
this.commGroup = groupCommClient.getCommunicationGroup((Class<? extends Name<String>>) Class.forName(commGroupName));
this.ctrlMessageBroadcast = commGroup.getBroadcastReceiver(CtrlMsgBroadcast.class);
this.heartBeatTriggerManager = heartBeatTriggerManager;
}

@Override
public final byte[] call(final byte[] memento) throws Exception {
LOG.log(Level.INFO, "CmpTask commencing...");
LOG.log(Level.INFO, String.format("%s starting...", taskId));

userComputeTask.initialize();
int iteration=0;
Expand All @@ -86,7 +91,7 @@ private void receiveData(int iteration) throws Exception {
((DataScatterReceiver)userComputeTask).receiveScatterData(iteration,
commGroup.getScatterReceiver(DataScatter.class).receive());
}
};
}

private void sendData(int iteration) throws Exception {
if (userComputeTask.isGatherUsed()) {
Expand Down
19 changes: 10 additions & 9 deletions src/main/java/edu/snu/reef/flexion/core/ControllerTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import edu.snu.reef.flexion.groupcomm.interfaces.DataReduceReceiver;
import edu.snu.reef.flexion.groupcomm.interfaces.DataScatterSender;
import edu.snu.reef.flexion.groupcomm.names.*;
import org.apache.reef.driver.task.TaskConfigurationOptions;
import org.apache.reef.tang.annotations.Name;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.task.Task;
Expand All @@ -32,24 +33,28 @@
import java.util.logging.Logger;

public final class ControllerTask implements Task {
public final static String TASK_ID_PREFIX = "CtrlTask";
private final static Logger LOG = Logger.getLogger(ControllerTask.class.getName());
public final static String TASK_ID = "CtrlTask";

private final String taskId;
private final UserControllerTask userControllerTask;
private final CommunicationGroupClient commGroup;
private final Broadcast.Sender<CtrlMessage> ctrlMessageBroadcast;

@Inject
public ControllerTask(final GroupCommClient groupCommClient,
final UserControllerTask userControllerTask,
@Parameter(TaskConfigurationOptions.Identifier.class) String taskId,
@Parameter(CommunicationGroup.class) final String commGroupName) throws ClassNotFoundException {
this.commGroup = groupCommClient.getCommunicationGroup((Class<? extends Name<String>>) Class.forName(commGroupName));
this.userControllerTask = userControllerTask;
this.taskId = taskId;
this.ctrlMessageBroadcast = commGroup.getBroadcastSender(CtrlMsgBroadcast.class);
}

@Override
public final byte[] call(final byte[] memento) throws Exception {
LOG.log(Level.INFO, "CtrlTask commencing...");
LOG.log(Level.INFO, String.format("%s starting...", taskId));

int iteration = 0;
userControllerTask.initialize();
Expand All @@ -58,7 +63,7 @@ public final byte[] call(final byte[] memento) throws Exception {
sendData(iteration);
receiveData(iteration);
userControllerTask.run(iteration);
topologyChanged();
updateTopology();
iteration++;
}
ctrlMessageBroadcast.send(CtrlMessage.TERMINATE);
Expand All @@ -68,15 +73,11 @@ public final byte[] call(final byte[] memento) throws Exception {
}

/**
* Check if group communication topology has changed, and updates it if it has.
* @return true if topology has changed, false if not
* Update the group communication topology, if it has changed
*/
private final boolean topologyChanged() {
private final void updateTopology() {
if (commGroup.getTopologyChanges().exist()) {
commGroup.updateTopology();
return true;
} else {
return false;
}
}

Expand Down
7 changes: 6 additions & 1 deletion src/main/java/edu/snu/reef/flexion/core/CtrlMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
import java.io.Serializable;

public enum CtrlMessage implements Serializable {
TERMINATE, RUN

// run the next iteration of the main loop
RUN,

// break the main loop of @ComputeTask and run cleanup before terminating the current task
TERMINATE
}

35 changes: 25 additions & 10 deletions src/main/java/edu/snu/reef/flexion/core/FlexionDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,17 @@
import com.microsoft.reef.io.network.nggroup.impl.config.ReduceOperatorSpec;
import com.microsoft.reef.io.network.nggroup.impl.config.ScatterOperatorSpec;
import edu.snu.reef.flexion.groupcomm.names.*;
import edu.snu.reef.flexion.parameters.EvaluatorNum;
import org.apache.reef.driver.context.ActiveContext;
import org.apache.reef.driver.task.CompletedTask;
import org.apache.reef.driver.task.FailedTask;
import org.apache.reef.driver.task.TaskConfiguration;
import org.apache.reef.driver.task.TaskMessage;
import org.apache.reef.driver.task.*;
import org.apache.reef.evaluator.context.parameters.ContextIdentifier;
import org.apache.reef.io.data.loading.api.DataLoadingService;
import org.apache.reef.io.serialization.SerializableCodec;
import org.apache.reef.tang.Configuration;
import org.apache.reef.tang.Configurations;
import org.apache.reef.tang.Injector;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.tang.annotations.Unit;
import org.apache.reef.tang.exceptions.InjectionException;
import org.apache.reef.wake.EventHandler;
Expand Down Expand Up @@ -103,9 +102,16 @@ public final class FlexionDriver {
* Map to record which stage is being executed by each evaluator which is identified by context id
*/
private final Map<String, Integer> contextToStageSequence;

/**
* The number of evaluators assigned for Compute Tasks
*/
private final Integer evalNum;

private final ObjectSerializableCodec<Long> codecLong = new ObjectSerializableCodec<>();
private final UserParameters userParameters;


/**
* This class is instantiated by TANG
*
Expand All @@ -120,7 +126,8 @@ public final class FlexionDriver {
private FlexionDriver(final GroupCommDriver groupCommDriver,
final DataLoadingService dataLoadingService,
final UserJobInfo userJobInfo,
final UserParameters userParameters)
final UserParameters userParameters,
@Parameter(EvaluatorNum.class) final Integer evalNum)
throws IllegalAccessException, InstantiationException,
NoSuchMethodException, InvocationTargetException {
this.groupCommDriver = groupCommDriver;
Expand All @@ -130,6 +137,7 @@ private FlexionDriver(final GroupCommDriver groupCommDriver,
this.commGroupDriverList = new LinkedList<>();
this.contextToStageSequence = new HashMap<>();
this.userParameters = userParameters;
this.evalNum = evalNum;
initializeCommDriver();
}

Expand All @@ -140,8 +148,7 @@ private void initializeCommDriver(){
int sequence = 0;
for (StageInfo stageInfo : stageInfoList) {
CommunicationGroupDriver commGroup = groupCommDriver.newCommunicationGroup(
stageInfo.getCommGroupName(),
dataLoadingService.getNumberOfPartitions() + 1);
stageInfo.getCommGroupName(), evalNum + 1);
commGroup.addBroadcast(CtrlMsgBroadcast.class,
BroadcastOperatorSpec.newBuilder()
.setSenderId(getCtrlTaskId(sequence))
Expand Down Expand Up @@ -249,6 +256,14 @@ public void onNext(final TaskMessage message) {
}
}

final class TaskRunningHandler implements EventHandler<RunningTask> {

@Override
public void onNext(RunningTask runningTask) {
LOG.info(runningTask.getId() + " has started.");
}
}

/**
* When a certain task completes, the following task is submitted
*/
Expand All @@ -270,7 +285,7 @@ public void onNext(CompletedTask completedTask) {
}
}

final class FailedTaskHandler implements EventHandler<FailedTask> {
final class TaskFailedHandler implements EventHandler<FailedTask> {

@Override
public void onNext(FailedTask failedTask) {
Expand Down Expand Up @@ -338,11 +353,11 @@ final private boolean isCtrlTaskId(String id){
}

final private String getCtrlTaskId(int sequence) {
return ControllerTask.TASK_ID + "-" + sequence;
return ControllerTask.TASK_ID_PREFIX + "-" + sequence;
}

final private String getCmpTaskId(int sequence) {
return ComputeTask.TASK_ID + "-" + sequence;
return ComputeTask.TASK_ID_PREFIX + "-" + sequence;
}


Expand Down
5 changes: 3 additions & 2 deletions src/main/java/edu/snu/reef/flexion/core/FlexionLauncher.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public final class FlexionLauncher {
private final FlexionParameters flexionParameters;

@Inject
private FlexionLauncher(FlexionParameters flexionParameters) {
private FlexionLauncher(final FlexionParameters flexionParameters) {
this.flexionParameters = flexionParameters;
}

Expand Down Expand Up @@ -89,7 +89,8 @@ private final Configuration getDriverConfWithDataLoad() {
.set(DriverConfiguration.ON_CONTEXT_ACTIVE, FlexionDriver.ActiveContextHandler.class)
.set(DriverConfiguration.ON_TASK_MESSAGE, FlexionDriver.TaskMessageHandler.class)
.set(DriverConfiguration.ON_TASK_COMPLETED, FlexionDriver.TaskCompletedHandler.class)
.set(DriverConfiguration.ON_TASK_FAILED, FlexionDriver.FailedTaskHandler.class);
.set(DriverConfiguration.ON_TASK_RUNNING, FlexionDriver.TaskRunningHandler.class)
.set(DriverConfiguration.ON_TASK_FAILED, FlexionDriver.TaskFailedHandler.class);

final EvaluatorRequest evalRequest = EvaluatorRequest.newBuilder()
.setNumber(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ public final String getIdentifier() {

public final Configuration getDriverConf() {
Configuration driverConf = Tang.Factory.getTang().newConfigurationBuilder()
.bindNamedParameter(EvaluatorNum.class, String.valueOf(evalNum))
.bindImplementation(UserJobInfo.class, userJobInfo.getClass())
.bindImplementation(UserParameters.class, userParameters.getClass())
.build();
Expand Down
Loading

0 comments on commit 0478987

Please sign in to comment.