diff --git a/tony-core/src/main/java/com/linkedin/tony/Constants.java b/tony-core/src/main/java/com/linkedin/tony/Constants.java index 27458c37..7665fa38 100644 --- a/tony-core/src/main/java/com/linkedin/tony/Constants.java +++ b/tony-core/src/main/java/com/linkedin/tony/Constants.java @@ -158,6 +158,9 @@ public class Constants { public static final String VCORES = "vcores"; public static final String GPUS = "gpus"; + public static final String ALLOCATION_TAGS = "allocation-tags"; + public static final String PLACEMENT_SPEC = "placement-spec"; + // pid environment variable set by YARN public static final String JVM_PID = "JVM_PID"; diff --git a/tony-core/src/main/java/com/linkedin/tony/HadoopCompatibleAdapter.java b/tony-core/src/main/java/com/linkedin/tony/HadoopCompatibleAdapter.java index 9fe3d40b..8af958dc 100644 --- a/tony-core/src/main/java/com/linkedin/tony/HadoopCompatibleAdapter.java +++ b/tony-core/src/main/java/com/linkedin/tony/HadoopCompatibleAdapter.java @@ -18,21 +18,37 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; -import org.apache.commons.lang.StringUtils; +import com.linkedin.tony.models.JobContainerRequest; +import com.linkedin.tony.util.Utils; +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.Container; +import org.apache.hadoop.yarn.api.records.ExecutionTypeRequest; +import org.apache.hadoop.yarn.api.records.Priority; import org.apache.hadoop.yarn.api.records.Resource; import org.apache.hadoop.yarn.api.records.ResourceInformation; +import org.apache.hadoop.yarn.client.api.AMRMClient; +import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync; import org.apache.hadoop.yarn.util.resource.ResourceUtils; public final class HadoopCompatibleAdapter { private static final Log LOG = LogFactory.getLog(HadoopCompatibleAdapter.class); + private static final AtomicLong ALLOCATE_ID_COUNTER = new AtomicLong(1); + private HadoopCompatibleAdapter() { } @@ -157,4 +173,61 @@ public static boolean existGPUResource() { return false; } } + + public static void constructAndAddSchedulingRequest(AMRMClientAsync amRMClient, + JobContainerRequest containerRequest) { + try { + List reqs = new ArrayList<>(); + Object schedReq = constructSchedulingRequest(containerRequest); + LOG.info("Request schedling containers ask: " + schedReq); + for (int i = 0; i < containerRequest.getNumInstances(); i++) { + reqs.add(schedReq); + } + Method addMethod = Arrays.stream(amRMClient.getClass().getMethods()) + .filter(x -> x.getName().equals("addSchedulingRequests") && x.getParameterCount() == 1) + .findFirst().get(); + addMethod.invoke(amRMClient, reqs); + } catch (Exception e) { + throw new RuntimeException("Errors on adding scheduing request.", e); + } + } + + private static Object constructSchedulingRequest(JobContainerRequest containerRequest) { + try { + Priority priority = Priority.newInstance(containerRequest.getPriority()); + Resource capability = Resource.newInstance((int) containerRequest.getMemory(), containerRequest.getVCores()); + if (containerRequest.getGPU() > 0) { + Utils.setCapabilityGPU(capability, containerRequest.getGPU()); + } + Set allocationTags = CollectionUtils.isEmpty(containerRequest.getAllocationTags()) + ? Collections.singleton("") : new HashSet<>(containerRequest.getAllocationTags()); + + Class placementConstraintCls = + Class.forName("org.apache.hadoop.yarn.util.constraint.PlacementConstraintParser"); + Method parseMethod = placementConstraintCls.getMethod("parseExpression", String.class); + + Object parsedObj = parseMethod.invoke(placementConstraintCls, containerRequest.getPlacementSpec()); + Class abstractConstraintCls = + Class.forName("org.apache.hadoop.yarn.api.resource.PlacementConstraint$AbstractConstraint"); + + Object placementConstraintObj = abstractConstraintCls.getMethod("build").invoke(parsedObj); + + Class resourceSizingCls = Class.forName("org.apache.hadoop.yarn.api.records.ResourceSizing"); + Method resourceSizingMethod = Arrays.stream(resourceSizingCls.getMethods()) + .filter(x -> x.getName().equals("newInstance") && x.getParameterCount() == 1).findFirst().get(); + Object resourceSizingObj = resourceSizingMethod.invoke(null, capability); + + Class schedulingReqCls = Class.forName("org.apache.hadoop.yarn.api.records.SchedulingRequest"); + Method newInstanceMethod = Arrays.stream(schedulingReqCls.getMethods()) + .filter(x -> x.getName().equals("newInstance") && x.getParameterCount() == 6).findFirst().get(); + + Object schedReq = newInstanceMethod.invoke(null, ALLOCATE_ID_COUNTER.incrementAndGet(), priority, + ExecutionTypeRequest.newInstance(), allocationTags, + resourceSizingObj, placementConstraintObj); + + return schedReq; + } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + throw new RuntimeException("Errors on constructing scheduling requests of Yarn.", e); + } + } } diff --git a/tony-core/src/main/java/com/linkedin/tony/TaskScheduler.java b/tony-core/src/main/java/com/linkedin/tony/TaskScheduler.java index fcda834b..3ba8513a 100644 --- a/tony-core/src/main/java/com/linkedin/tony/TaskScheduler.java +++ b/tony-core/src/main/java/com/linkedin/tony/TaskScheduler.java @@ -36,7 +36,6 @@ public class TaskScheduler { // job with dependency -> (dependent job name, number of instances for that job) private Map> taskDependencyMap = new HashMap<>(); private Map localResources; - private Map> jobTypeToContainerRequestsMap = new HashMap<>(); private Map> jobTypeToContainerResources; boolean dependencyCheckPassed = true; @@ -90,16 +89,20 @@ boolean checkDependencySatisfied(JobContainerRequest request) { } private void scheduleJob(JobContainerRequest request) { - AMRMClient.ContainerRequest containerAsk = Utils.setupContainerRequestForRM(request); - String jobName = request.getJobName(); - if (!jobTypeToContainerRequestsMap.containsKey(jobName)) { - jobTypeToContainerRequestsMap.put(jobName, new ArrayList<>()); - jobTypeToContainerResources.put(jobName, getContainerResources(jobName)); - } - jobTypeToContainerRequestsMap.get(request.getJobName()).add(containerAsk); - for (int i = 0; i < request.getNumInstances(); i++) { - amRMClient.addContainerRequest(containerAsk); + if (request.getPlacementSpec() != null) { + // this should use newer api of Yarn with this placement constraint feature, + // only be supported in hadoop 3.2.x + HadoopCompatibleAdapter.constructAndAddSchedulingRequest(amRMClient, request); + } else { + AMRMClient.ContainerRequest containerAsk = Utils.setupContainerRequestForRM(request); + for (int i = 0; i < request.getNumInstances(); i++) { + amRMClient.addContainerRequest(containerAsk); + } } + + String jobName = request.getJobName(); + jobTypeToContainerResources.putIfAbsent(jobName, getContainerResources(jobName)); + session.addNumExpectedTask(request.getNumInstances()); } diff --git a/tony-core/src/main/java/com/linkedin/tony/TonyClient.java b/tony-core/src/main/java/com/linkedin/tony/TonyClient.java index d51856af..764e13fa 100644 --- a/tony-core/src/main/java/com/linkedin/tony/TonyClient.java +++ b/tony-core/src/main/java/com/linkedin/tony/TonyClient.java @@ -53,7 +53,7 @@ import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; import org.apache.commons.io.FileUtils; -import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; diff --git a/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java b/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java index dd0b3d5b..a643f9f8 100644 --- a/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java +++ b/tony-core/src/main/java/com/linkedin/tony/TonyConfigurationKeys.java @@ -241,6 +241,22 @@ public static String getResourcesKey(String jobName) { return String.format(TONY_PREFIX + "%s.resources", jobName); } + public static String getPlacementSpecKey(String jobName) { + return String.format(TONY_PREFIX + "%s.placement-spec", jobName); + } + + public static String getAllocationSpecKey(String jobName) { + return String.format(TONY_PREFIX + "%s.allocation-tags", jobName); + } + + public static String getContainerPlacementSpecKey() { + return TONY_PREFIX + "containers.placement-spec"; + } + + public static String getContainerAllocationTagsKey() { + return TONY_PREFIX + "containers.allocation-tags"; + } + // Resources for all containers public static String getContainerResourcesKey() { return TONY_PREFIX + "containers.resources"; diff --git a/tony-core/src/main/java/com/linkedin/tony/TonySession.java b/tony-core/src/main/java/com/linkedin/tony/TonySession.java index bd0cd9d1..9194148e 100644 --- a/tony-core/src/main/java/com/linkedin/tony/TonySession.java +++ b/tony-core/src/main/java/com/linkedin/tony/TonySession.java @@ -21,7 +21,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; -import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; diff --git a/tony-core/src/main/java/com/linkedin/tony/models/JobContainerRequest.java b/tony-core/src/main/java/com/linkedin/tony/models/JobContainerRequest.java index 7829170b..e6b00303 100644 --- a/tony-core/src/main/java/com/linkedin/tony/models/JobContainerRequest.java +++ b/tony-core/src/main/java/com/linkedin/tony/models/JobContainerRequest.java @@ -17,6 +17,25 @@ public class JobContainerRequest { private String nodeLabelsExpression; private List dependsOn; + private String placementSpec; + private List allocationTags; + + public JobContainerRequest(String jobName, int numInstances, long memory, int vCores, int gpu, int priority, + String nodeLabelsExpression, final List dependsOn, String placementSpec, + List allocationTags) { + this.numInstances = numInstances; + this.memory = memory; + this.vCores = vCores; + this.priority = priority; + this.gpu = gpu; + this.jobName = jobName; + this.nodeLabelsExpression = nodeLabelsExpression; + this.dependsOn = dependsOn; + this.placementSpec = placementSpec; + this.allocationTags = allocationTags; + } + + public JobContainerRequest(String jobName, int numInstances, long memory, int vCores, int gpu, int priority, String nodeLabelsExpression, final List dependsOn) { this.numInstances = numInstances; @@ -60,4 +79,12 @@ public String getNodeLabelsExpression() { public final List getDependsOn() { return dependsOn; } + + public String getPlacementSpec() { + return placementSpec; + } + + public List getAllocationTags() { + return allocationTags; + } } diff --git a/tony-core/src/main/java/com/linkedin/tony/runtime/HorovodRuntime.java b/tony-core/src/main/java/com/linkedin/tony/runtime/HorovodRuntime.java index 3907896f..6158ca22 100644 --- a/tony-core/src/main/java/com/linkedin/tony/runtime/HorovodRuntime.java +++ b/tony-core/src/main/java/com/linkedin/tony/runtime/HorovodRuntime.java @@ -25,7 +25,7 @@ import java.util.Map; import java.util.stream.Collectors; -import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.StringUtils; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.FinalApplicationStatus; diff --git a/tony-core/src/main/java/com/linkedin/tony/runtime/MLGenericRuntime.java b/tony-core/src/main/java/com/linkedin/tony/runtime/MLGenericRuntime.java index 226ecdbc..7b613fbe 100644 --- a/tony-core/src/main/java/com/linkedin/tony/runtime/MLGenericRuntime.java +++ b/tony-core/src/main/java/com/linkedin/tony/runtime/MLGenericRuntime.java @@ -23,7 +23,7 @@ import java.util.stream.Collectors; import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.FinalApplicationStatus; diff --git a/tony-core/src/main/java/com/linkedin/tony/util/Utils.java b/tony-core/src/main/java/com/linkedin/tony/util/Utils.java index fce7ff88..deccca88 100644 --- a/tony-core/src/main/java/com/linkedin/tony/util/Utils.java +++ b/tony-core/src/main/java/com/linkedin/tony/util/Utils.java @@ -409,6 +409,17 @@ public static Map parseContainerRequests(Configurat TonyConfigurationKeys.DEFAULT_VCORES); int gpus = conf.getInt(TonyConfigurationKeys.getResourceKey(jobName, Constants.GPUS), TonyConfigurationKeys.DEFAULT_GPUS); + + String placementSpec = conf.get( + TonyConfigurationKeys.getPlacementSpecKey(jobName), + conf.get(TonyConfigurationKeys.getContainerPlacementSpecKey()) + ); + String[] allocationTagsArr = conf.getStrings( + TonyConfigurationKeys.getAllocationSpecKey(jobName), + conf.getStrings(TonyConfigurationKeys.getContainerAllocationTagsKey()) + ); + List allocationTags = allocationTagsArr == null ? null : new ArrayList<>(Arrays.asList(allocationTagsArr)); + if (gpus > 0 && !HadoopCompatibleAdapter.existGPUResource()) { throw new RuntimeException(String.format("User requested %d GPUs for job '%s' but GPU is not available on the cluster. ", gpus, jobName)); @@ -431,7 +442,7 @@ public static Map parseContainerRequests(Configurat // We rely on unique priority behavior to match allocation request to task in Hadoop 2.7 containerRequests.put(jobName, new JobContainerRequest(jobName, numInstances, memory, vCores, gpus, priority, - nodeLabel, dependsOn)); + nodeLabel, dependsOn, placementSpec, allocationTags)); priority++; } } diff --git a/tony-core/src/test/java/com/linkedin/tony/TestPortAllocation.java b/tony-core/src/test/java/com/linkedin/tony/TestPortAllocation.java index 7a19429c..c62a9d7a 100644 --- a/tony-core/src/test/java/com/linkedin/tony/TestPortAllocation.java +++ b/tony-core/src/test/java/com/linkedin/tony/TestPortAllocation.java @@ -12,7 +12,7 @@ import java.io.IOException; import java.net.BindException; import java.time.Duration; -import org.apache.commons.lang.SystemUtils; +import org.apache.commons.lang3.SystemUtils; import org.testng.annotations.Test; public class TestPortAllocation { diff --git a/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java b/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java index 3e1980c1..567f3213 100644 --- a/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java +++ b/tony-core/src/test/java/com/linkedin/tony/util/TestUtils.java @@ -120,6 +120,39 @@ public void testParseContainerRequestsShouldFail() { Utils.parseContainerRequests(conf); } + @Test + public void testParsePlacementSpecAndAllocationTags() { + Configuration conf = new Configuration(); + conf.addResource("tony-default.xml"); + conf.setInt("tony.worker.instances", 1); + + // case1: set nothing + Map containerRequests = Utils.parseContainerRequests(conf); + assertNull(containerRequests.get("worker").getPlacementSpec()); + assertNull(containerRequests.get("worker").getAllocationTags()); + + // case2: set all + conf = new Configuration(); + conf.addResource("tony-default.xml"); + conf.setInt("tony.worker.instances", 1); + conf.setStrings("tony.worker.placement-spec", "java=true"); + conf.setStrings("tony.worker.allocation-tags", "tony"); + containerRequests = Utils.parseContainerRequests(conf); + + assertEquals("java=true", containerRequests.get("worker").getPlacementSpec()); + assertEquals(Arrays.asList("tony"), containerRequests.get("worker").getAllocationTags()); + + // case3: set nothing for job, but it will fallback to container setting + conf = new Configuration(); + conf.addResource("tony-default.xml"); + conf.setInt("tony.worker.instances", 1); + conf.setStrings("tony.containers.placement-spec", "java=true"); + conf.setStrings("tony.worker.allocation-tags", "tony"); + containerRequests = Utils.parseContainerRequests(conf); + assertEquals("java=true", containerRequests.get("worker").getPlacementSpec()); + assertEquals(Arrays.asList("tony"), containerRequests.get("worker").getAllocationTags()); + } + @Test public void testIsArchive() { ClassLoader classLoader = getClass().getClassLoader();