From 408d9e6c61a3f2d965562aa8c5ef856d7e2ed919 Mon Sep 17 00:00:00 2001 From: Rich Scott Date: Fri, 7 Feb 2025 13:07:28 -0700 Subject: [PATCH 1/2] Scala client protobuf generation and build container (Take 2) (#4192) * Update magefiles for Scala build tasks Signed-off-by: Rich Scott * Add support for Scala client to Armada. These changes were authored by Clif Houck and Rich Scott , with additional suggestions by Enrico Minack . These changes add a new Docker container for compiling the .proto definitions to Scala language bindings, using the ScalaPB protobuf compiler (https://scalapb.github.io/), to prepare for running Spark on top of Armada (as opposed to Kubernetes). A Docker container is built on an Ubuntu base, with the Scala 2 compiler and the `sbt` Scala build tool. Generated `xxxx.scala` protobuf bindings are written in `client/scala/scala-armada-client`, and will eventually be compiled and packaged in a .jar file. Just as the existing Python client does, we do not commit the generated artifacts to the repo (see `client/scala/scala-armada-client/.gitignore`). See client/scala/examples/ for examples on using this. Fixes https://github.com/G-Research/spark/issues/21 Signed-off-by: Rich Scott * Simplify construction of ArmadaClient Signed-off-by: Rich Scott * Build examples projects with manual copying client jar * Add mavenLocal resolver * Remove superfluous directives in build.sbt; remove unneeded examples Signed-off-by: Rich Scott * Start client method names with lowercase chars Signed-off-by: Rich Scott * Change a var to val in Scala client test; add note about codegen Signed-off-by: Rich Scott --------- Signed-off-by: Rich Scott Co-authored-by: Enrico Minack --- build/scala-client/Dockerfile | 25 +++ client/scala/examples/.gitignore | 36 ++++ client/scala/examples/README.md | 11 ++ client/scala/examples/build.sbt | 26 +++ .../scala/examples/project/build.properties | 1 + client/scala/examples/project/plugins.sbt | 7 + .../io/armadaproject/armada/Health.scala | 12 ++ .../scala/io/armadaproject/armada/Job.scala | 54 +++++ client/scala/scala-armada-client/.gitignore | 36 ++++ .../scala/scala-armada-client/.scalafmt.conf | 3 + client/scala/scala-armada-client/README.md | 5 + client/scala/scala-armada-client/build.sbt | 25 +++ .../project/build.properties | 1 + .../scala-armada-client/project/plugins.sbt | 7 + .../armadaproject/armada/ArmadaClient.scala | 64 ++++++ .../src/test/scala/ArmadaClientSuite.scala | 187 ++++++++++++++++++ magefiles/config.go | 1 + magefiles/main.go | 2 +- magefiles/scala.go | 36 ++++ scripts/build-scala-client.sh | 59 ++++++ 20 files changed, 597 insertions(+), 1 deletion(-) create mode 100644 build/scala-client/Dockerfile create mode 100644 client/scala/examples/.gitignore create mode 100644 client/scala/examples/README.md create mode 100644 client/scala/examples/build.sbt create mode 100644 client/scala/examples/project/build.properties create mode 100644 client/scala/examples/project/plugins.sbt create mode 100644 client/scala/examples/src/main/scala/io/armadaproject/armada/Health.scala create mode 100644 client/scala/examples/src/main/scala/io/armadaproject/armada/Job.scala create mode 100644 client/scala/scala-armada-client/.gitignore create mode 100644 client/scala/scala-armada-client/.scalafmt.conf create mode 100644 client/scala/scala-armada-client/README.md create mode 100644 client/scala/scala-armada-client/build.sbt create mode 100644 client/scala/scala-armada-client/project/build.properties create mode 100644 
client/scala/scala-armada-client/project/plugins.sbt create mode 100644 client/scala/scala-armada-client/src/main/scala/io/armadaproject/armada/ArmadaClient.scala create mode 100644 client/scala/scala-armada-client/src/test/scala/ArmadaClientSuite.scala create mode 100644 magefiles/scala.go create mode 100755 scripts/build-scala-client.sh diff --git a/build/scala-client/Dockerfile b/build/scala-client/Dockerfile new file mode 100644 index 00000000000..e8ceb749ac3 --- /dev/null +++ b/build/scala-client/Dockerfile @@ -0,0 +1,25 @@ +FROM ubuntu:24.04 + +ARG SCALA_VERSION=2.13.15 +ARG SBT_VERSION=1.10.7 + +LABEL org.opencontainers.image.authors="G-Research Open-Source Software" +LABEL org.opencontainers.image.licenses="Apache-2.0" +LABEL org.opencontainers.image.ref.name="Ubuntu Scala Image" +LABEL org.opencontainers.image.version="" + +RUN set -ex && \ + apt update && \ + apt install -y apt-utils && \ + apt install -y bash curl && \ + curl -s -O https://downloads.lightbend.com/scala/${SCALA_VERSION}/scala-${SCALA_VERSION}.deb && \ + apt install -y ./scala-${SCALA_VERSION}.deb && \ + curl -s -L -O https://github.com/sbt/sbt/releases/download/v${SBT_VERSION}/sbt-${SBT_VERSION}.tgz && \ + tar -C / -xzvf ./sbt-${SBT_VERSION}.tgz && \ + apt-get clean && \ + rm -rf scala-${SCALA_VERSION}.deb sbt-${SBT_VERSION}.tgz /var/lib/apt/lists/* + +COPY scripts/build-scala-client.sh / +RUN chmod +x /build-scala-client.sh + +ENTRYPOINT [ "/build-scala-client.sh" ] diff --git a/client/scala/examples/.gitignore b/client/scala/examples/.gitignore new file mode 100644 index 00000000000..ea350d48edc --- /dev/null +++ b/client/scala/examples/.gitignore @@ -0,0 +1,36 @@ +# macOS +.DS_Store + +# sbt specific +dist/* +# During development of this client, you may want to +# comment out the next three lines, so IDEs/LSPs can +# find and use the protoc-generated scala code. 
+target/ +lib_managed/ +src_managed/ + +project/boot/ +project/plugins/project/ +project/local-plugins.sbt +.history +.ensime +.ensime_cache/ +.sbt-scripted/ +local.sbt + +# Bloop +.bsp + +# VS Code +.vscode/ + +# Metals +.bloop/ +.metals/ +metals.sbt + +# IDEA +.idea +.idea_modules +/.worksheet/ diff --git a/client/scala/examples/README.md b/client/scala/examples/README.md new file mode 100644 index 00000000000..05e02f39037 --- /dev/null +++ b/client/scala/examples/README.md @@ -0,0 +1,11 @@ +To build the example Armada client programs, build the Scala Armada client first: +``` +cd client/scala/scala-armada-client/ +sbt publishLocal +``` + +Then compile the examples project: +``` +cd client/scala/examples/ +sbt compile +``` diff --git a/client/scala/examples/build.sbt b/client/scala/examples/build.sbt new file mode 100644 index 00000000000..2b02515d466 --- /dev/null +++ b/client/scala/examples/build.sbt @@ -0,0 +1,26 @@ +val scala2Version = "2.13.15" + +lazy val root = project + .in(file(".")) + .settings( + name := "Scala Armada Client", + version := "0.1.0-SNAPSHOT", + + scalaVersion := scala2Version, + + libraryDependencies += "io.armadaproject.armada" %% "scala-armada-client" % "0.1.0-SNAPSHOT", + libraryDependencies += "org.scalameta" %% "munit" % "1.0.0" % Test + ) + +Compile / PB.targets := Seq( + scalapb.gen() -> (Compile / sourceManaged).value +) + +// Additional directories to search for imports: +Compile / PB.protoSources ++= Seq(file("./proto")) + +libraryDependencies ++= Seq( + "com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion +) + +resolvers += Resolver.mavenLocal diff --git a/client/scala/examples/project/build.properties b/client/scala/examples/project/build.properties new file mode 100644 index 00000000000..73df629ac1a --- /dev/null +++ b/client/scala/examples/project/build.properties @@ -0,0 +1 @@ +sbt.version=1.10.7 diff --git a/client/scala/examples/project/plugins.sbt b/client/scala/examples/project/plugins.sbt new file mode 100644 index 00000000000..ab620c5f468 --- /dev/null +++ b/client/scala/examples/project/plugins.sbt @@ -0,0 +1,7 @@ +addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.7") + +libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.11.13" + +libraryDependencies ++= Seq( + "com.google.protobuf" % "protobuf-java" % "3.13.0" % "protobuf" +) diff --git a/client/scala/examples/src/main/scala/io/armadaproject/armada/Health.scala b/client/scala/examples/src/main/scala/io/armadaproject/armada/Health.scala new file mode 100644 index 00000000000..edbf036b955 --- /dev/null +++ b/client/scala/examples/src/main/scala/io/armadaproject/armada/Health.scala @@ -0,0 +1,12 @@ +import io.armadaproject.armada.ArmadaClient + +object Main { + def main(args: Array[String]): Unit = { + val host = "localhost" + val port = 30002 + + val ac = ArmadaClient(host, port) + val status = ac.eventHealth() + println(status) + } +} diff --git a/client/scala/examples/src/main/scala/io/armadaproject/armada/Job.scala b/client/scala/examples/src/main/scala/io/armadaproject/armada/Job.scala new file mode 100644 index 00000000000..3b071129334 --- /dev/null +++ b/client/scala/examples/src/main/scala/io/armadaproject/armada/Job.scala @@ -0,0 +1,54 @@ +package io.armadaproject.armada + +import k8s.io.api.core.v1.generated.{Container, PodSpec, ResourceRequirements} +import k8s.io.apimachinery.pkg.api.resource.generated.Quantity + +object Main { + val host = "localhost" + val port = 30002 + + def main(args: Array[String]): Unit = { + val 
sleepContainer = Container() + .withName("ls") + .withImagePullPolicy("IfNotPresent") + .withImage("alpine:3.10") + .withCommand(Seq("ls")) + .withArgs( + Seq( + "-c", + "ls -l; sleep 30; date; echo '========'; ls -l; sleep 10; date" + ) + ) + .withResources( + ResourceRequirements( + limits = Map( + "memory" -> Quantity(Option("10Mi")), + "cpu" -> Quantity(Option("100m")) + ), + requests = Map( + "memory" -> Quantity(Option("10Mi")), + "cpu" -> Quantity(Option("100m")) + ) + ) + ) + + val podSpec = PodSpec() + .withTerminationGracePeriodSeconds(0) + .withRestartPolicy("Never") + .withContainers(Seq(sleepContainer)) + + val testJob = api.submit + .JobSubmitRequestItem() + .withPriority(0) + .withNamespace("personal-anonymous") + .withPodSpec(podSpec) + + val ac = ArmadaClient("localhost", 30002) + val response = ac.submitJobs("testQueue", "testJobSetId", List(testJob)) + + println(s"Job Submit Response") + for (respItem <- response.jobResponseItems) { + println(s"JobID: ${respItem.jobId} Error: ${respItem.error} ") + } + } +} diff --git a/client/scala/scala-armada-client/.gitignore b/client/scala/scala-armada-client/.gitignore new file mode 100644 index 00000000000..ea350d48edc --- /dev/null +++ b/client/scala/scala-armada-client/.gitignore @@ -0,0 +1,36 @@ +# macOS +.DS_Store + +# sbt specific +dist/* +# During development of this client, you may want to +# comment out the next three lines, so IDEs/LSPs can +# find and use the protoc-generated scala code. +target/ +lib_managed/ +src_managed/ + +project/boot/ +project/plugins/project/ +project/local-plugins.sbt +.history +.ensime +.ensime_cache/ +.sbt-scripted/ +local.sbt + +# Bloop +.bsp + +# VS Code +.vscode/ + +# Metals +.bloop/ +.metals/ +metals.sbt + +# IDEA +.idea +.idea_modules +/.worksheet/ diff --git a/client/scala/scala-armada-client/.scalafmt.conf b/client/scala/scala-armada-client/.scalafmt.conf new file mode 100644 index 00000000000..19189510e0c --- /dev/null +++ b/client/scala/scala-armada-client/.scalafmt.conf @@ -0,0 +1,3 @@ +version = 3.8.3 + +runner.dialect = scala213 diff --git a/client/scala/scala-armada-client/README.md b/client/scala/scala-armada-client/README.md new file mode 100644 index 00000000000..424f74872b7 --- /dev/null +++ b/client/scala/scala-armada-client/README.md @@ -0,0 +1,5 @@ +## sbt project compiled with Scala 2 + +### Usage + +You can build the Scala Armada client with `sbt package`. The jar can be found in `target/scala-2.13/`. 
diff --git a/client/scala/scala-armada-client/build.sbt b/client/scala/scala-armada-client/build.sbt new file mode 100644 index 00000000000..93e27000aa5 --- /dev/null +++ b/client/scala/scala-armada-client/build.sbt @@ -0,0 +1,25 @@ +val scala2Version = "2.13.15" + +lazy val root = project + .in(file(".")) + .settings( + organization := "io.armadaproject.armada", + name := "Scala-Armada-Client", + version := "0.1.0-SNAPSHOT", + + scalaVersion := scala2Version, + + libraryDependencies += "org.scalameta" %% "munit" % "1.0.0" % Test + ) + +Compile / PB.targets := Seq( + scalapb.gen() -> (Compile / sourceManaged).value +) + +// Additional directories to search for imports: +Compile / PB.protoSources ++= Seq(file("./proto")) + +libraryDependencies ++= Seq( + "io.grpc" % "grpc-netty" % scalapb.compiler.Version.grpcJavaVersion, + "com.thesamet.scalapb" %% "scalapb-runtime-grpc" % scalapb.compiler.Version.scalapbVersion +) diff --git a/client/scala/scala-armada-client/project/build.properties b/client/scala/scala-armada-client/project/build.properties new file mode 100644 index 00000000000..73df629ac1a --- /dev/null +++ b/client/scala/scala-armada-client/project/build.properties @@ -0,0 +1 @@ +sbt.version=1.10.7 diff --git a/client/scala/scala-armada-client/project/plugins.sbt b/client/scala/scala-armada-client/project/plugins.sbt new file mode 100644 index 00000000000..ab620c5f468 --- /dev/null +++ b/client/scala/scala-armada-client/project/plugins.sbt @@ -0,0 +1,7 @@ +addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.7") + +libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "0.11.13" + +libraryDependencies ++= Seq( + "com.google.protobuf" % "protobuf-java" % "3.13.0" % "protobuf" +) diff --git a/client/scala/scala-armada-client/src/main/scala/io/armadaproject/armada/ArmadaClient.scala b/client/scala/scala-armada-client/src/main/scala/io/armadaproject/armada/ArmadaClient.scala new file mode 100644 index 00000000000..5f125be1186 --- /dev/null +++ b/client/scala/scala-armada-client/src/main/scala/io/armadaproject/armada/ArmadaClient.scala @@ -0,0 +1,64 @@ +package io.armadaproject.armada + +import api.job.{JobStatusRequest, JobStatusResponse, JobsGrpc} +import api.event.EventGrpc +import api.submit.{SubmitGrpc, JobSubmitRequest, JobSubmitResponse, JobSubmitRequestItem, + Queue, QueueDeleteRequest, QueueGetRequest} +import api.health.HealthCheckResponse +import api.submit.Job +import k8s.io.api.core.v1.generated.{Container, PodSpec, ResourceRequirements} +import k8s.io.apimachinery.pkg.api.resource.generated.Quantity +import com.google.protobuf.empty.Empty +import io.grpc.{ManagedChannelBuilder, ManagedChannel} + +class ArmadaClient(channel: ManagedChannel) { + def submitJobs(queue: String, jobSetId: String, jobRequestItems: Seq[JobSubmitRequestItem]): JobSubmitResponse = { + val blockingStub = SubmitGrpc.blockingStub(channel) + blockingStub.submitJobs(JobSubmitRequest(queue, jobSetId, jobRequestItems)) + } + + def getJobStatus(jobId: String): JobStatusResponse = { + val blockingStub = JobsGrpc.blockingStub(channel) + blockingStub.getJobStatus(JobStatusRequest(jobIds = Seq(jobId))) + } + + def eventHealth(): HealthCheckResponse.ServingStatus = { + val blockingStub = EventGrpc.blockingStub(channel) + blockingStub.health(Empty()).status + } + + def submitHealth(): HealthCheckResponse.ServingStatus = { + val blockingStub = SubmitGrpc.blockingStub(channel) + blockingStub.health(Empty()).status + } + + def createQueue(name: String): Unit = { + val blockingStub = 
SubmitGrpc.blockingStub(channel) + val q = api.submit.Queue().withName(name).withPriorityFactor(1) + blockingStub.createQueue(q) + } + + def deleteQueue(name: String): Unit = { + val qReq = QueueDeleteRequest(name) + val blockingStub = SubmitGrpc.blockingStub(channel) + blockingStub.deleteQueue(qReq) + } + + def getQueue(name: String): Queue = { + val qReq = QueueGetRequest(name) + val blockingStub = SubmitGrpc.blockingStub(channel) + blockingStub.getQueue(qReq) + } +} + +object ArmadaClient { + // TODO: SSL + def apply(channel: ManagedChannel): ArmadaClient = { + new ArmadaClient(channel) + } + + def apply(host: String, port: Int): ArmadaClient = { + val channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build + ArmadaClient(channel) + } +} diff --git a/client/scala/scala-armada-client/src/test/scala/ArmadaClientSuite.scala b/client/scala/scala-armada-client/src/test/scala/ArmadaClientSuite.scala new file mode 100644 index 00000000000..c533a8d2900 --- /dev/null +++ b/client/scala/scala-armada-client/src/test/scala/ArmadaClientSuite.scala @@ -0,0 +1,187 @@ +package io.armadaproject.armada + +import io.armadaproject.armada.ArmadaClient +import api.submit.{SubmitGrpc, CancellationResult, Queue, BatchQueueCreateResponse, + StreamingQueueMessage, JobReprioritizeResponse, JobSubmitResponse, + BatchQueueUpdateResponse, JobSubmitResponseItem, JobSubmitRequestItem, + JobState, JobSetCancelRequest, JobCancelRequest, QueueDeleteRequest, + QueueGetRequest, StreamingQueueGetRequest, JobPreemptRequest, + JobReprioritizeRequest, JobSubmitRequest, QueueList} +import api.job.{JobRunState, JobsGrpc, JobErrorsResponse, JobDetailsRequest, + JobRunDetailsResponse, JobDetailsResponse, + JobStatusUsingExternalJobUriRequest, JobStatusResponse, + JobErrorsRequest, JobRunDetailsRequest, JobStatusRequest} +import com.google.protobuf.empty.Empty +import api.health.HealthCheckResponse +import api.event.{EventGrpc, EventStreamMessage, JobSetRequest, WatchRequest} +import io.grpc.stub.StreamObserver +import io.grpc.{Server, ServerBuilder} + +import scala.concurrent.Future +import scala.util.Random + +private class EventMockServer extends EventGrpc.Event { + override def health(empty: Empty): scala.concurrent.Future[HealthCheckResponse] = { + Future.successful(HealthCheckResponse(HealthCheckResponse.ServingStatus.SERVING)) + } + + override def getJobSetEvents(request: JobSetRequest, + responseObserver: io.grpc.stub.StreamObserver[EventStreamMessage]): Unit = { + // TODO: fill-in + } + + override def watch(request: WatchRequest, responseObserver: io.grpc.stub.StreamObserver[EventStreamMessage]): Unit = { + // TODO: fill-in + } +} + +private class SubmitMockServer extends SubmitGrpc.Submit { + def cancelJobSet(request: JobSetCancelRequest): scala.concurrent.Future[com.google.protobuf.empty.Empty] = { + Future.successful(new Empty) + } + + def cancelJobs(request: JobCancelRequest): scala.concurrent.Future[CancellationResult] = { + Future.successful(new CancellationResult) + } + + def createQueue(request: Queue): scala.concurrent.Future[com.google.protobuf.empty.Empty] = { + Future.successful(new Empty) + } + + def createQueues(request: QueueList): scala.concurrent.Future[BatchQueueCreateResponse] = { + Future.successful(new BatchQueueCreateResponse) + } + + def deleteQueue(request: QueueDeleteRequest): scala.concurrent.Future[com.google.protobuf.empty.Empty] = { + Future.successful(new Empty) + } + def getQueue(request: QueueGetRequest): scala.concurrent.Future[Queue] = { + Future.successful(new Queue) + 
} + + def getQueues(request: StreamingQueueGetRequest, responseObserver: io.grpc.stub.StreamObserver[StreamingQueueMessage]): Unit = { + Future.successful(new StreamingQueueMessage) + } + + def health(request: com.google.protobuf.empty.Empty): scala.concurrent.Future[HealthCheckResponse] = { + Future.successful(HealthCheckResponse(HealthCheckResponse.ServingStatus.SERVING)) + } + + def preemptJobs(request: JobPreemptRequest): scala.concurrent.Future[com.google.protobuf.empty.Empty] = { + Future.successful(new Empty) + } + + def reprioritizeJobs(request: JobReprioritizeRequest): scala.concurrent.Future[JobReprioritizeResponse] = { + Future.successful(new JobReprioritizeResponse) + } + + def submitJobs(request: JobSubmitRequest): scala.concurrent.Future[JobSubmitResponse] = { + Future.successful((new JobSubmitResponse(List(JobSubmitResponseItem("fakeJobId"))))) + } + + def updateQueue(request: Queue): scala.concurrent.Future[com.google.protobuf.empty.Empty] = { + Future.successful(new Empty) + } + + def updateQueues(request: QueueList): scala.concurrent.Future[BatchQueueUpdateResponse] = { + Future.successful(new BatchQueueUpdateResponse) + } + +} + +private class JobsMockServer extends JobsGrpc.Jobs { + def getJobDetails(request: JobDetailsRequest): scala.concurrent.Future[JobDetailsResponse] = { + Future.successful(new JobDetailsResponse) + } + + def getJobErrors(request: JobErrorsRequest): scala.concurrent.Future[JobErrorsResponse] = { + Future.successful(new JobErrorsResponse) + } + + def getJobRunDetails(request: JobRunDetailsRequest): scala.concurrent.Future[JobRunDetailsResponse] = { + Future.successful(new JobRunDetailsResponse) + } + + def getJobStatus(request: JobStatusRequest): scala.concurrent.Future[JobStatusResponse] = { + val response = new JobStatusResponse(Map("fakeJobId" -> JobState.RUNNING)) + Future.successful(response) + } + + def getJobStatusUsingExternalJobUri(request: JobStatusUsingExternalJobUriRequest): scala.concurrent.Future[JobStatusResponse] = { + Future.successful(new JobStatusResponse) + } +} + +// For more information on writing tests, see +// https://scalameta.org/munit/docs/getting-started.html +class ArmadaClientSuite extends munit.FunSuite { + val testPort = 12345 + val mockEventServer = new Fixture[Server]("Event GRPC Mock Server") { + private var server: Server = null + def apply() = server + override def beforeAll(): Unit = { + import scala.concurrent.ExecutionContext + server = ServerBuilder + .forPort(testPort) + .addService(EventGrpc.bindService(new EventMockServer, ExecutionContext.global)) + .addService(SubmitGrpc.bindService(new SubmitMockServer, ExecutionContext.global)) + .addService(JobsGrpc.bindService(new JobsMockServer, ExecutionContext.global)) + .build() + .start() + } + override def afterAll(): Unit = { + server.shutdown() + } + } + + override def munitFixtures = List(mockEventServer) + + test("ArmadaClient.EventHealth()") { + val ac = ArmadaClient("localhost", testPort) + val status = ac.EventHealth() + assertEquals(status, HealthCheckResponse.ServingStatus.SERVING) + } + + test("ArmadaClient.SubmitHealth()") { + val ac = ArmadaClient("localhost", testPort) + val status = ac.SubmitHealth() + assertEquals(status, HealthCheckResponse.ServingStatus.SERVING) + } + + test("ArmadaClient.SubmitJobs()") { + val ac = ArmadaClient("localhost", testPort) + val response = ac.SubmitJobs("testQueue", "testJobSetId", List(new JobSubmitRequestItem())) + assertEquals(response.jobResponseItems(0), JobSubmitResponseItem("fakeJobId")) + } + + 
test("ArmadaClient.GetJobStatus()") { + val ac = ArmadaClient("localhost", testPort) + val response = ac.GetJobStatus("fakeJobId") + assert(response.jobStates("fakeJobId").isRunning) + } + + // Queue tests currently disabled - Armada mock server does not implement full queue + // state so these fail when running with mock; they pass with a real Armada instance + // test("test queue existence, creation, deletion") { + // val ac = new ArmadaClient(ArmadaClient.GetChannel("localhost", testPort)) + // val qName = "test-queue-" + Random.alphanumeric.take(8).mkString + // var q: Queue = new Queue() + + // // queue should not exist yet + // intercept[io.grpc.StatusRuntimeException] { + // q = ac.getQueue(qName) + // } + // assertNotEquals(q.name, qName) + + // ac.createQueue(qName) + // q = ac.getQueue(qName) + // assertEquals(q.name, qName) + + // ac.deleteQueue(qName) + // q = new Queue() + // intercept[io.grpc.StatusRuntimeException] { + // q = ac.getQueue(qName) + // } + // assertNotEquals(q.name, qName) + // } +} diff --git a/magefiles/config.go b/magefiles/config.go index 7f6ca05b0f2..f7048617f75 100644 --- a/magefiles/config.go +++ b/magefiles/config.go @@ -9,6 +9,7 @@ import ( type BuildConfig struct { DockerRegistries map[string]string `json:"dockerRegistries"` PythonBuilderBaseImage string `json:"pythonBuilderBaseImage"` + ScalaBuilderBaseImage string `json:"scalaBuilderBaseImage"` } func getBuildConfig() (BuildConfig, error) { diff --git a/magefiles/main.go b/magefiles/main.go index 7c77a25e564..d39b893acf0 100644 --- a/magefiles/main.go +++ b/magefiles/main.go @@ -205,7 +205,7 @@ func LocalDev(arg string) error { case "minimal": mg.Deps(mg.F(goreleaserMinimalRelease, "bundle"), Kind, downloadDependencyImages) case "full": - mg.Deps(BuildPython, mg.F(BuildDockers, "bundle, lookout-bundle"), Kind, downloadDependencyImages) + mg.Deps(BuildPython, BuildScala, mg.F(BuildDockers, "bundle, lookout-bundle"), Kind, downloadDependencyImages) case "no-build", "debug": mg.Deps(Kind, downloadDependencyImages) default: diff --git a/magefiles/scala.go b/magefiles/scala.go new file mode 100644 index 00000000000..ca6e2c95914 --- /dev/null +++ b/magefiles/scala.go @@ -0,0 +1,36 @@ +package main + +import ( + "fmt" + "os" + + "github.com/magefile/mage/mg" +) + +// Build armada Scala client. 
+func BuildScala() error { + mg.Deps(BootstrapProto) + + buildConfig, err := getBuildConfig() + if err != nil { + return err + } + + err = dockerBuildImage(NewDockerBuildConfig(buildConfig.ScalaBuilderBaseImage), + "armada-scala-client-builder", "./build/scala-client/Dockerfile") + if err != nil { + return err + } + + wd, err := os.Getwd() + if err != nil { + return err + } + + return dockerRun("run", + "-u", fmt.Sprintf("%d:%d", os.Getuid(), os.Getgid()), + "--rm", + "-v", fmt.Sprintf("%s:/build", wd), + "-w", "/build", + "armada-scala-client-builder") +} diff --git a/scripts/build-scala-client.sh b/scripts/build-scala-client.sh new file mode 100755 index 00000000000..b0f1322bc46 --- /dev/null +++ b/scripts/build-scala-client.sh @@ -0,0 +1,59 @@ +#!/bin/bash +# This script is intended to be run under the docker container +# in the root dir of the Armada repo + +export PATH=/sbt/bin:$PATH + +ROOT=$(pwd) +SDIR=client/scala/scala-armada-client + +rm -rf $ROOT/$SDIR/proto +mkdir -p $ROOT/$SDIR/proto + +cd proto +for pfile in \ + google/api/annotations.proto \ + google/api/http.proto \ + google/protobuf/*.proto \ + github.com/gogo/protobuf/gogoproto/gogo.proto \ + k8s.io/api/core/v1/generated.proto \ + k8s.io/apimachinery/pkg/api/resource/generated.proto \ + k8s.io/apimachinery/pkg/apis/meta/v1/generated.proto \ + k8s.io/apimachinery/pkg/runtime/generated.proto \ + k8s.io/apimachinery/pkg/runtime/schema/generated.proto \ + k8s.io/apimachinery/pkg/util/intstr/generated.proto \ + k8s.io/api/networking/v1/generated.proto +do + dir=$(dirname $pfile) + mkdir -p $ROOT/$SDIR/proto/$dir + cp $pfile $ROOT/$SDIR/proto/$dir/ +done + +cd .. +for pfile in \ + pkg/api/event.proto pkg/api/submit.proto pkg/api/health.proto pkg/api/job.proto pkg/api/binoculars/binoculars.proto +do + dir=$(dirname $pfile) + mkdir -p $ROOT/$SDIR/proto/$dir + cp $pfile $ROOT/$SDIR/proto/$dir/ +done + +# the sbt config in the $SDIR directory causes the following commands to generate +# the scala protobuf files into $SDIR/target/scala-2.13/src_managed/main/ + +cd $ROOT/$SDIR +sbt clean && \ +sbt -Dsbt.io.implicit.relative.glob.conversion=allow compile && \ +sbt -Dsbt.io.implicit.relative.glob.conversion=allow package + +if [ $? -eq 0 ]; then + jarfile=$(find . -type f -name 'scala-armada-client*.jar') + if [[ $jarfile ]]; then + echo "" > /dev/stderr + jarfile=$(echo $jarfile | sed -e s%^\./%%) + echo "Armada Scala client jar file written to $SDIR/$jarfile" > /dev/stderr + fi +else + echo "sbt build exited with exit code $?" 
> /dev/stderr +fi + From eb1324ccef6348542b798f72302655ba69c13039 Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Sun, 9 Feb 2025 17:45:43 +0000 Subject: [PATCH 2/2] Scheduler To Use internaltypes.JobSchedulingInfo (#4182) * wip * converted to scheduling info * fix tests --- .../internaltypes/job_scheduling_info.go | 114 +++++++++++++ internal/scheduler/jobdb/job.go | 35 ++-- internal/scheduler/jobdb/job_test.go | 93 ++++------ internal/scheduler/jobdb/jobdb.go | 26 ++- internal/scheduler/jobdb/jobdb_test.go | 161 ++++++------------ internal/scheduler/jobdb/reconciliation.go | 25 ++- internal/scheduler/jobdb/test_utils.go | 31 ++-- internal/scheduler/metrics.go | 2 +- internal/scheduler/nodedb/nodematching.go | 8 +- .../scheduler/nodedb/nodematching_test.go | 66 +++---- internal/scheduler/scheduler.go | 10 +- internal/scheduler/scheduler_test.go | 31 ++-- internal/scheduler/scheduling/context/job.go | 2 +- .../scheduler/scheduling/jobiteration_test.go | 17 +- .../scheduling/preemption_description_test.go | 15 +- internal/scheduler/simulator/simulator.go | 6 +- .../scheduler/testfixtures/testfixtures.go | 75 +++----- 17 files changed, 362 insertions(+), 355 deletions(-) create mode 100644 internal/scheduler/internaltypes/job_scheduling_info.go diff --git a/internal/scheduler/internaltypes/job_scheduling_info.go b/internal/scheduler/internaltypes/job_scheduling_info.go new file mode 100644 index 00000000000..39ec45a6794 --- /dev/null +++ b/internal/scheduler/internaltypes/job_scheduling_info.go @@ -0,0 +1,114 @@ +package internaltypes + +import ( + "time" + + "github.com/gogo/protobuf/proto" + "github.com/pkg/errors" + "golang.org/x/exp/maps" + v1 "k8s.io/api/core/v1" + + armadaslices "github.com/armadaproject/armada/internal/common/slices" + "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" +) + +// JobSchedulingInfo is a minimal representation of job requirements that the scheduler uses for scheduling +type JobSchedulingInfo struct { + Lifetime uint32 + PriorityClassName string + SubmitTime time.Time + Priority uint32 + PodRequirements *PodRequirements + Version uint32 +} + +func (j *JobSchedulingInfo) DeepCopy() *JobSchedulingInfo { + return &JobSchedulingInfo{ + Lifetime: j.Lifetime, + PriorityClassName: j.PriorityClassName, + SubmitTime: j.SubmitTime, + Priority: j.Priority, + PodRequirements: j.PodRequirements.DeepCopy(), + Version: j.Version, + } +} + +// PodRequirements captures the scheduling requirements specific to a pod. 
+type PodRequirements struct { + NodeSelector map[string]string + Affinity *v1.Affinity + Tolerations []v1.Toleration + Annotations map[string]string + ResourceRequirements v1.ResourceRequirements +} + +func (p *PodRequirements) GetAffinityNodeSelector() *v1.NodeSelector { + affinity := p.Affinity + if affinity == nil { + return nil + } + nodeAffinity := affinity.NodeAffinity + if nodeAffinity == nil { + return nil + } + return nodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution +} + +func (p *PodRequirements) DeepCopy() *PodRequirements { + clonedResourceRequirements := proto.Clone(&p.ResourceRequirements).(*v1.ResourceRequirements) + return &PodRequirements{ + NodeSelector: maps.Clone(p.NodeSelector), + Affinity: proto.Clone(p.Affinity).(*v1.Affinity), + Annotations: maps.Clone(p.Annotations), + Tolerations: armadaslices.Map(p.Tolerations, func(t v1.Toleration) v1.Toleration { + cloned := proto.Clone(&t).(*v1.Toleration) + return *cloned + }), + ResourceRequirements: *clonedResourceRequirements, + } +} + +func FromSchedulerObjectsJobSchedulingInfo(j *schedulerobjects.JobSchedulingInfo) (*JobSchedulingInfo, error) { + podRequirements := j.GetPodRequirements() + if podRequirements == nil { + return nil, errors.Errorf("job must have pod requirements") + } + return &JobSchedulingInfo{ + Lifetime: j.Lifetime, + PriorityClassName: j.PriorityClassName, + SubmitTime: j.SubmitTime, + Priority: j.Priority, + PodRequirements: &PodRequirements{ + NodeSelector: podRequirements.NodeSelector, + Affinity: podRequirements.Affinity, + Tolerations: podRequirements.Tolerations, + Annotations: podRequirements.Annotations, + ResourceRequirements: podRequirements.ResourceRequirements, + }, + Version: j.Version, + }, nil +} + +func ToSchedulerObjectsJobSchedulingInfo(j *JobSchedulingInfo) *schedulerobjects.JobSchedulingInfo { + podRequirements := j.PodRequirements + return &schedulerobjects.JobSchedulingInfo{ + Lifetime: j.Lifetime, + PriorityClassName: j.PriorityClassName, + SubmitTime: j.SubmitTime, + Priority: j.Priority, + ObjectRequirements: []*schedulerobjects.ObjectRequirements{ + { + Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ + PodRequirements: &schedulerobjects.PodRequirements{ + NodeSelector: podRequirements.NodeSelector, + Affinity: podRequirements.Affinity, + Tolerations: podRequirements.Tolerations, + Annotations: podRequirements.Annotations, + ResourceRequirements: podRequirements.ResourceRequirements, + }, + }, + }, + }, + Version: j.Version, + } +} diff --git a/internal/scheduler/jobdb/job.go b/internal/scheduler/jobdb/job.go index a1b264bb62a..86f583be597 100644 --- a/internal/scheduler/jobdb/job.go +++ b/internal/scheduler/jobdb/job.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - "github.com/gogo/protobuf/proto" "github.com/hashicorp/go-multierror" "github.com/pkg/errors" "golang.org/x/exp/maps" @@ -51,7 +50,7 @@ type Job struct { // The current version of the queued state. queuedVersion int32 // Scheduling requirements of this job. - jobSchedulingInfo *schedulerobjects.JobSchedulingInfo + jobSchedulingInfo *internaltypes.JobSchedulingInfo // All resource requirements, including floating resources, for this job allResourceRequirements internaltypes.ResourceList // Kubernetes (i.e. non-floating) resource requirements of this job @@ -264,17 +263,12 @@ func (job *Job) Assert() error { func (job *Job) ensureJobSchedulingInfoFieldsInitialised() { // Initialise the annotation and nodeSelector maps if nil. // Since those need to be mutated in-place. 
- if job.jobSchedulingInfo != nil { - for _, req := range job.jobSchedulingInfo.ObjectRequirements { - if podReq := req.GetPodRequirements(); podReq != nil { - if podReq.Annotations == nil { - podReq.Annotations = make(map[string]string) - } - if podReq.NodeSelector == nil { - podReq.NodeSelector = make(map[string]string) - } - } - } + podReq := job.jobSchedulingInfo.PodRequirements + if podReq.Annotations == nil { + podReq.Annotations = make(map[string]string) + } + if podReq.NodeSelector == nil { + podReq.NodeSelector = make(map[string]string) } } @@ -459,7 +453,7 @@ func (job *Job) WithRequestedPriority(priority uint32) *Job { } // JobSchedulingInfo returns the scheduling requirements associated with the job -func (job *Job) JobSchedulingInfo() *schedulerobjects.JobSchedulingInfo { +func (job *Job) JobSchedulingInfo() *internaltypes.JobSchedulingInfo { return job.jobSchedulingInfo } @@ -538,8 +532,8 @@ func (job *Job) KubernetesResourceRequirements() internaltypes.ResourceList { } // PodRequirements returns the pod requirements of the Job -func (job *Job) PodRequirements() *schedulerobjects.PodRequirements { - return job.jobSchedulingInfo.GetPodRequirements() +func (job *Job) PodRequirements() *internaltypes.PodRequirements { + return job.jobSchedulingInfo.PodRequirements } // Queued returns true if the job should be considered by the scheduler for assignment or false otherwise. @@ -658,11 +652,10 @@ func (job *Job) HasRuns() bool { } func (job *Job) ValidateResourceRequests() error { - pr := job.jobSchedulingInfo.GetPodRequirements() + pr := job.jobSchedulingInfo.PodRequirements if pr == nil { return nil } - req := pr.ResourceRequirements.Requests if req == nil { return nil @@ -801,13 +794,13 @@ func (job *Job) Validated() bool { return job.validated } -// Does this job request any floating resources? +// RequestsFloatingResources returns true if this job requests any floating resources func (job *Job) RequestsFloatingResources() bool { return !job.AllResourceRequirements().OfType(internaltypes.Floating).AllZero() } // WithJobSchedulingInfo returns a copy of the job with the job scheduling info updated. 
-func (job *Job) WithJobSchedulingInfo(jobSchedulingInfo *schedulerobjects.JobSchedulingInfo) (*Job, error) { +func (job *Job) WithJobSchedulingInfo(jobSchedulingInfo *internaltypes.JobSchedulingInfo) (*Job, error) { j := copyJob(*job) j.jobSchedulingInfo = jobSchedulingInfo j.ensureJobSchedulingInfoFieldsInitialised() @@ -823,7 +816,7 @@ func (job *Job) WithJobSchedulingInfo(jobSchedulingInfo *schedulerobjects.JobSch func (job *Job) DeepCopy() *Job { j := copyJob(*job) - j.jobSchedulingInfo = proto.Clone(job.JobSchedulingInfo()).(*schedulerobjects.JobSchedulingInfo) + j.jobSchedulingInfo = job.jobSchedulingInfo.DeepCopy() j.ensureJobSchedulingInfoFieldsInitialised() j.schedulingKey = SchedulingKeyFromJob(j.jobDb.schedulingKeyGenerator, j) diff --git a/internal/scheduler/jobdb/job_test.go b/internal/scheduler/jobdb/job_test.go index 49f195112b5..90db15bb4a8 100644 --- a/internal/scheduler/jobdb/job_test.go +++ b/internal/scheduler/jobdb/job_test.go @@ -3,34 +3,26 @@ package jobdb import ( "testing" - v1 "k8s.io/api/core/v1" - k8sResource "k8s.io/apimachinery/pkg/api/resource" - - "github.com/gogo/protobuf/proto" "github.com/google/uuid" "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" + k8sResource "k8s.io/apimachinery/pkg/api/resource" "github.com/armadaproject/armada/internal/common/types" - "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" + "github.com/armadaproject/armada/internal/scheduler/internaltypes" ) -var jobSchedulingInfo = &schedulerobjects.JobSchedulingInfo{ - ObjectRequirements: []*schedulerobjects.ObjectRequirements{ - { - Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ - PodRequirements: &schedulerobjects.PodRequirements{ - ResourceRequirements: v1.ResourceRequirements{ - Requests: v1.ResourceList{ - "cpu": k8sResource.MustParse("1"), - "storage-connections": k8sResource.MustParse("1"), - }, - }, - Annotations: map[string]string{ - "foo": "bar", - }, - }, +var jobSchedulingInfo = &internaltypes.JobSchedulingInfo{ + PodRequirements: &internaltypes.PodRequirements{ + ResourceRequirements: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + "cpu": k8sResource.MustParse("1"), + "storage-connections": k8sResource.MustParse("1"), }, }, + Annotations: map[string]string{ + "foo": "bar", + }, }, } @@ -344,23 +336,17 @@ func TestJob_DeepCopy(t *testing.T) { } func TestJob_TestWithJobSchedulingInfo(t *testing.T) { - newSchedInfo := &schedulerobjects.JobSchedulingInfo{ - ObjectRequirements: []*schedulerobjects.ObjectRequirements{ - { - Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ - PodRequirements: &schedulerobjects.PodRequirements{ - ResourceRequirements: v1.ResourceRequirements{ - Requests: v1.ResourceList{ - "cpu": k8sResource.MustParse("2"), - "storage-connections": k8sResource.MustParse("2"), - }, - }, - Annotations: map[string]string{ - "fish": "chips", - }, - }, + newSchedInfo := &internaltypes.JobSchedulingInfo{ + PodRequirements: &internaltypes.PodRequirements{ + ResourceRequirements: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + "cpu": k8sResource.MustParse("2"), + "storage-connections": k8sResource.MustParse("2"), }, }, + Annotations: map[string]string{ + "fish": "chips", + }, }, } newJob := JobWithJobSchedulingInfo(baseJob, newSchedInfo) @@ -379,18 +365,12 @@ func TestJob_TestWithJobSchedulingInfo(t *testing.T) { } func TestRequestsFloatingResources(t *testing.T) { - noFloatingResourcesJob := JobWithJobSchedulingInfo(baseJob, &schedulerobjects.JobSchedulingInfo{ - 
ObjectRequirements: []*schedulerobjects.ObjectRequirements{ - { - Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ - PodRequirements: &schedulerobjects.PodRequirements{ - ResourceRequirements: v1.ResourceRequirements{ - Requests: v1.ResourceList{ - "cpu": k8sResource.MustParse("1"), - "storage-connections": k8sResource.MustParse("0"), - }, - }, - }, + noFloatingResourcesJob := JobWithJobSchedulingInfo(baseJob, &internaltypes.JobSchedulingInfo{ + PodRequirements: &internaltypes.PodRequirements{ + ResourceRequirements: v1.ResourceRequirements{ + Requests: v1.ResourceList{ + "cpu": k8sResource.MustParse("1"), + "storage-connections": k8sResource.MustParse("0"), }, }, }, @@ -400,20 +380,13 @@ func TestRequestsFloatingResources(t *testing.T) { } func TestJobSchedulingInfoFieldsInitialised(t *testing.T) { - infoWithNilFields := &schedulerobjects.JobSchedulingInfo{ - ObjectRequirements: []*schedulerobjects.ObjectRequirements{ - { - Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ - PodRequirements: &schedulerobjects.PodRequirements{}, - }, - }, - }, + infoWithNilFields := &internaltypes.JobSchedulingInfo{ + PodRequirements: &internaltypes.PodRequirements{}, } - infoWithNilFieldsCopy := proto.Clone(infoWithNilFields).(*schedulerobjects.JobSchedulingInfo) - assert.NotNil(t, infoWithNilFields.GetPodRequirements()) - assert.Nil(t, infoWithNilFields.GetPodRequirements().NodeSelector) - assert.Nil(t, infoWithNilFields.GetPodRequirements().Annotations) + infoWithNilFieldsCopy := infoWithNilFields.DeepCopy() + assert.Nil(t, infoWithNilFields.PodRequirements.NodeSelector) + assert.Nil(t, infoWithNilFields.PodRequirements.Annotations) job, err := jobDb.NewJob("test-job", "test-jobSet", "test-queue", 2, 0.0, infoWithNilFieldsCopy, true, 0, false, false, false, 3, false, []string{}) assert.Nil(t, err) @@ -421,7 +394,7 @@ func TestJobSchedulingInfoFieldsInitialised(t *testing.T) { assert.NotNil(t, job.Annotations()) // Copy again here, as the fields get mutated so we want a clean copy - infoWithNilFieldsCopy2 := proto.Clone(infoWithNilFields).(*schedulerobjects.JobSchedulingInfo) + infoWithNilFieldsCopy2 := infoWithNilFields updatedJob := JobWithJobSchedulingInfo(baseJob, infoWithNilFieldsCopy2) assert.NotNil(t, updatedJob.NodeSelector()) assert.NotNil(t, updatedJob.Annotations()) diff --git a/internal/scheduler/jobdb/jobdb.go b/internal/scheduler/jobdb/jobdb.go index 9252aef11ba..790c0a4cb0d 100644 --- a/internal/scheduler/jobdb/jobdb.go +++ b/internal/scheduler/jobdb/jobdb.go @@ -176,7 +176,7 @@ func (jobDb *JobDb) NewJob( queue string, priority uint32, bidPrice float64, - schedulingInfo *schedulerobjects.JobSchedulingInfo, + schedulingInfo *internaltypes.JobSchedulingInfo, queued bool, queuedVersion int32, cancelRequested bool, @@ -220,12 +220,12 @@ func (jobDb *JobDb) NewJob( return job, nil } -func (jobDb *JobDb) getResourceRequirements(schedulingInfo *schedulerobjects.JobSchedulingInfo) internaltypes.ResourceList { +func (jobDb *JobDb) getResourceRequirements(schedulingInfo *internaltypes.JobSchedulingInfo) internaltypes.ResourceList { return jobDb.resourceListFactory.FromJobResourceListIgnoreUnknown(safeGetRequirements(schedulingInfo)) } -func safeGetRequirements(schedulingInfo *schedulerobjects.JobSchedulingInfo) map[string]k8sResource.Quantity { - pr := schedulingInfo.GetPodRequirements() +func safeGetRequirements(schedulingInfo *internaltypes.JobSchedulingInfo) map[string]k8sResource.Quantity { + pr := schedulingInfo.PodRequirements if pr == nil { return 
map[string]k8sResource.Quantity{} } @@ -238,18 +238,14 @@ func safeGetRequirements(schedulingInfo *schedulerobjects.JobSchedulingInfo) map return adapters.K8sResourceListToMap(req) } -func (jobDb *JobDb) internJobSchedulingInfoStrings(info *schedulerobjects.JobSchedulingInfo) *schedulerobjects.JobSchedulingInfo { - for _, requirement := range info.ObjectRequirements { - if podRequirement := requirement.GetPodRequirements(); podRequirement != nil { - for k, v := range podRequirement.Annotations { - podRequirement.Annotations[jobDb.stringInterner.Intern(k)] = jobDb.stringInterner.Intern(v) - } +func (jobDb *JobDb) internJobSchedulingInfoStrings(info *internaltypes.JobSchedulingInfo) *internaltypes.JobSchedulingInfo { + pr := info.PodRequirements + for k, v := range pr.Annotations { + pr.Annotations[jobDb.stringInterner.Intern(k)] = jobDb.stringInterner.Intern(v) + } - for k, v := range podRequirement.NodeSelector { - podRequirement.NodeSelector[jobDb.stringInterner.Intern(k)] = jobDb.stringInterner.Intern(v) - } - podRequirement.PreemptionPolicy = jobDb.stringInterner.Intern(podRequirement.PreemptionPolicy) - } + for k, v := range pr.NodeSelector { + pr.NodeSelector[jobDb.stringInterner.Intern(k)] = jobDb.stringInterner.Intern(v) } return info } diff --git a/internal/scheduler/jobdb/jobdb_test.go b/internal/scheduler/jobdb/jobdb_test.go index 125d8efdf8f..71f4a103735 100644 --- a/internal/scheduler/jobdb/jobdb_test.go +++ b/internal/scheduler/jobdb/jobdb_test.go @@ -5,7 +5,6 @@ import ( "sort" "testing" - "github.com/gogo/protobuf/proto" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -18,6 +17,7 @@ import ( "github.com/armadaproject/armada/internal/common/stringinterner" "github.com/armadaproject/armada/internal/common/types" "github.com/armadaproject/armada/internal/common/util" + "github.com/armadaproject/armada/internal/scheduler/internaltypes" "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" ) @@ -283,18 +283,12 @@ func TestJobDb_TestBatchDelete(t *testing.T) { } func TestJobDb_SchedulingKeyIsPopulated(t *testing.T) { - podRequirements := &schedulerobjects.PodRequirements{ + podRequirements := &internaltypes.PodRequirements{ NodeSelector: map[string]string{"foo": "bar"}, } - jobSchedulingInfo := &schedulerobjects.JobSchedulingInfo{ + jobSchedulingInfo := &internaltypes.JobSchedulingInfo{ PriorityClassName: "foo", - ObjectRequirements: []*schedulerobjects.ObjectRequirements{ - { - Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ - PodRequirements: podRequirements, - }, - }, - }, + PodRequirements: podRequirements, } jobDb := NewTestJobDb() job, err := jobDb.NewJob("jobId", "jobSet", "queue", 1, 0.0, jobSchedulingInfo, false, 0, false, false, false, 2, false, []string{}) @@ -304,14 +298,14 @@ func TestJobDb_SchedulingKeyIsPopulated(t *testing.T) { func TestJobDb_SchedulingKey(t *testing.T) { tests := map[string]struct { - podRequirementsA *schedulerobjects.PodRequirements + podRequirementsA *internaltypes.PodRequirements priorityClassNameA string - podRequirementsB *schedulerobjects.PodRequirements + podRequirementsB *internaltypes.PodRequirements priorityClassNameB string equal bool }{ "annotations does not affect key": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -328,7 +322,6 @@ func TestJobDb_SchedulingKey(t *testing.T) { "fish": "chips", "salt": 
"pepper", }, - PreemptionPolicy: "abc", ResourceRequirements: v1.ResourceRequirements{ Limits: map[v1.ResourceName]resource.Quantity{ "cpu": resource.MustParse("1"), @@ -342,7 +335,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -358,62 +351,6 @@ func TestJobDb_SchedulingKey(t *testing.T) { "foo": "bar", "fish": "chips", }, - PreemptionPolicy: "abc", - ResourceRequirements: v1.ResourceRequirements{ - Limits: map[v1.ResourceName]resource.Quantity{ - "cpu": resource.MustParse("1"), - "memory": resource.MustParse("2"), - "nvidia.com/gpu": resource.MustParse("3"), - }, - Requests: map[v1.ResourceName]resource.Quantity{ - "cpu": resource.MustParse("2"), - "memory": resource.MustParse("2"), - "nvidia.com/gpu": resource.MustParse("2"), - }, - }, - }, - equal: true, - }, - "preemptionPolicy does not affect key": { - podRequirementsA: &schedulerobjects.PodRequirements{ - NodeSelector: map[string]string{ - "property1": "value1", - "property3": "value3", - }, - Tolerations: []v1.Toleration{{ - Key: "a", - Operator: "b", - Value: "b", - Effect: "d", - TolerationSeconds: pointer.Int64(1), - }}, - PreemptionPolicy: "abc", - ResourceRequirements: v1.ResourceRequirements{ - Limits: map[v1.ResourceName]resource.Quantity{ - "cpu": resource.MustParse("1"), - "memory": resource.MustParse("2"), - "nvidia.com/gpu": resource.MustParse("3"), - }, - Requests: map[v1.ResourceName]resource.Quantity{ - "cpu": resource.MustParse("2"), - "memory": resource.MustParse("2"), - "nvidia.com/gpu": resource.MustParse("2"), - }, - }, - }, - podRequirementsB: &schedulerobjects.PodRequirements{ - NodeSelector: map[string]string{ - "property1": "value1", - "property3": "value3", - }, - Tolerations: []v1.Toleration{{ - Key: "a", - Operator: "b", - Value: "b", - Effect: "d", - TolerationSeconds: pointer.Int64(1), - }}, - PreemptionPolicy: "abcdef", ResourceRequirements: v1.ResourceRequirements{ Limits: map[v1.ResourceName]resource.Quantity{ "cpu": resource.MustParse("1"), @@ -430,7 +367,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: true, }, "limits does not affect key": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -442,7 +379,6 @@ func TestJobDb_SchedulingKey(t *testing.T) { Effect: "d", TolerationSeconds: pointer.Int64(1), }}, - PreemptionPolicy: "abc", ResourceRequirements: v1.ResourceRequirements{ Limits: map[v1.ResourceName]resource.Quantity{ "cpu": resource.MustParse("1"), @@ -456,7 +392,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -468,7 +404,6 @@ func TestJobDb_SchedulingKey(t *testing.T) { Effect: "d", TolerationSeconds: pointer.Int64(1), }}, - PreemptionPolicy: "abcdef", ResourceRequirements: v1.ResourceRequirements{ Limits: map[v1.ResourceName]resource.Quantity{ "cpu": resource.MustParse("1"), @@ -485,7 +420,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: true, }, "priority": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ 
-505,7 +440,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -528,7 +463,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: true, }, "zero request does not affect key": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -548,7 +483,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property3": "value3", "property1": "value1", @@ -572,7 +507,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: true, }, "nodeSelector key": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -592,7 +527,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property3": "value3", "property1": "value1", @@ -615,7 +550,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, "nodeSelector value": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -635,7 +570,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property3": "value3", "property1": "value1-2", @@ -657,12 +592,12 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, "nodeSelector different keys, same values": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "my-cool-label": "value", }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "my-other-cool-label": "value", }, @@ -670,7 +605,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: false, }, "toleration key": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -690,7 +625,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property3": "value3", "property1": "value1", @@ -712,7 +647,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, "toleration operator": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -732,7 +667,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property3": "value3", "property1": "value1", @@ -754,7 +689,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, "toleration 
value": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -774,7 +709,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property3": "value3", "property1": "value1", @@ -796,7 +731,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, "toleration effect": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -816,7 +751,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property3": "value3", "property1": "value1", @@ -838,7 +773,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, "toleration tolerationSeconds": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -858,7 +793,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property3": "value3", "property1": "value1", @@ -881,7 +816,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: true, }, "key ordering": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property1": "value1", "property3": "value3", @@ -901,7 +836,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ NodeSelector: map[string]string{ "property3": "value3", "property1": "value1", @@ -924,7 +859,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: true, }, "affinity PodAffinity ignored": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -987,7 +922,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { PodAntiAffinity: nil, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -1053,7 +988,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: true, }, "affinity NodeAffinity MatchExpressions": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -1079,7 +1014,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -1108,7 +1043,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: false, }, "affinity NodeAffinity MatchFields": { - 
podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -1134,7 +1069,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -1163,7 +1098,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: false, }, "affinity NodeAffinity multiple MatchFields": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -1189,7 +1124,7 @@ func TestJobDb_SchedulingKey(t *testing.T) { }, }, }, - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -1223,13 +1158,13 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: false, }, "priority class names equal": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ ResourceRequirements: v1.ResourceRequirements{ Requests: map[v1.ResourceName]resource.Quantity{"cpu": resource.MustParse("2")}, }, }, priorityClassNameA: "my-cool-priority-class", - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ ResourceRequirements: v1.ResourceRequirements{ Requests: map[v1.ResourceName]resource.Quantity{"cpu": resource.MustParse("2")}, }, @@ -1238,13 +1173,13 @@ func TestJobDb_SchedulingKey(t *testing.T) { equal: true, }, "priority class names different": { - podRequirementsA: &schedulerobjects.PodRequirements{ + podRequirementsA: &internaltypes.PodRequirements{ ResourceRequirements: v1.ResourceRequirements{ Requests: map[v1.ResourceName]resource.Quantity{"cpu": resource.MustParse("2")}, }, }, priorityClassNameA: "my-cool-priority-class", - podRequirementsB: &schedulerobjects.PodRequirements{ + podRequirementsB: &internaltypes.PodRequirements{ ResourceRequirements: v1.ResourceRequirements{ Requests: map[v1.ResourceName]resource.Quantity{"cpu": resource.MustParse("2")}, }, @@ -1257,14 +1192,14 @@ func TestJobDb_SchedulingKey(t *testing.T) { t.Run(name, func(t *testing.T) { skg := schedulerobjects.NewSchedulingKeyGenerator() - jobSchedulingInfoA := proto.Clone(jobSchedulingInfo).(*schedulerobjects.JobSchedulingInfo) + jobSchedulingInfoA := jobSchedulingInfo.DeepCopy() jobSchedulingInfoA.PriorityClassName = tc.priorityClassNameA - jobSchedulingInfoA.ObjectRequirements[0].Requirements = &schedulerobjects.ObjectRequirements_PodRequirements{PodRequirements: tc.podRequirementsA} + jobSchedulingInfoA.PodRequirements = tc.podRequirementsA jobA := JobWithJobSchedulingInfo(baseJob, jobSchedulingInfoA) - jobSchedulingInfoB := proto.Clone(jobSchedulingInfo).(*schedulerobjects.JobSchedulingInfo) + jobSchedulingInfoB := jobSchedulingInfo.DeepCopy() jobSchedulingInfoB.PriorityClassName = tc.priorityClassNameB - jobSchedulingInfoB.ObjectRequirements[0].Requirements = &schedulerobjects.ObjectRequirements_PodRequirements{PodRequirements: tc.podRequirementsB} + jobSchedulingInfoB.PodRequirements = tc.podRequirementsB jobB := 
JobWithJobSchedulingInfo(baseJob, jobSchedulingInfoB)
 
 			schedulingKeyA := SchedulingKeyFromJob(skg, jobA)
diff --git a/internal/scheduler/jobdb/reconciliation.go b/internal/scheduler/jobdb/reconciliation.go
index 0d275b0072a..38c1ebb6b4d 100644
--- a/internal/scheduler/jobdb/reconciliation.go
+++ b/internal/scheduler/jobdb/reconciliation.go
@@ -7,6 +7,7 @@ import (
 	armadamath "github.com/armadaproject/armada/internal/common/math"
 	armadaslices "github.com/armadaproject/armada/internal/common/slices"
 	"github.com/armadaproject/armada/internal/scheduler/database"
+	"github.com/armadaproject/armada/internal/scheduler/internaltypes"
 	"github.com/armadaproject/armada/internal/scheduler/schedulerobjects"
 	"github.com/armadaproject/armada/pkg/api"
 )
@@ -139,15 +140,20 @@ func (jobDb *JobDb) reconcileJobDifferences(job *Job, jobRepoJob *database.Job,
 		job = job.WithRequestedPriority(uint32(jobRepoJob.Priority))
 	}
 	if uint32(jobRepoJob.SchedulingInfoVersion) > job.JobSchedulingInfo().Version {
-		schedulingInfo := &schedulerobjects.JobSchedulingInfo{}
-		if err = proto.Unmarshal(jobRepoJob.SchedulingInfo, schedulingInfo); err != nil {
+		schedulingInfoProto := &schedulerobjects.JobSchedulingInfo{}
+		if err = proto.Unmarshal(jobRepoJob.SchedulingInfo, schedulingInfoProto); err != nil {
 			err = errors.Wrapf(err, "error unmarshalling scheduling info for job %s", jobRepoJob.JobID)
-			return
+			return jst, err
+		}
+		schedulingInfo, err := internaltypes.FromSchedulerObjectsJobSchedulingInfo(schedulingInfoProto)
+		if err != nil {
+			err = errors.Wrapf(err, "error converting scheduler info for job %s", jobRepoJob.JobID)
+			return jst, err
 		}
 		job, err = job.WithJobSchedulingInfo(schedulingInfo)
 		if err != nil {
 			err = errors.Wrapf(err, "error unmarshalling scheduling info for job %s", jobRepoJob.JobID)
-			return
+			return jst, err
 		}
 	}
 	if jobRepoJob.QueuedVersion > job.QueuedVersion() {
@@ -250,9 +256,14 @@ func (jobDb *JobDb) enforceTerminalStateExclusivity(jobRun *JobRun, rst *RunStat
 
 // schedulerJobFromDatabaseJob creates a new scheduler job from a database job.
 func (jobDb *JobDb) schedulerJobFromDatabaseJob(dbJob *database.Job) (*Job, error) {
-	schedulingInfo := &schedulerobjects.JobSchedulingInfo{}
-	if err := proto.Unmarshal(dbJob.SchedulingInfo, schedulingInfo); err != nil {
-		return nil, errors.Wrapf(err, "error unmarshalling scheduling info for job %s", dbJob.JobID)
+	schedulingInfoProto := &schedulerobjects.JobSchedulingInfo{}
+	if err := proto.Unmarshal(dbJob.SchedulingInfo, schedulingInfoProto); err != nil {
+		return nil, errors.WithMessagef(err, "error unmarshalling scheduling info for job %s", dbJob.JobID)
+	}
+
+	schedulingInfo, err := internaltypes.FromSchedulerObjectsJobSchedulingInfo(schedulingInfoProto)
+	if err != nil {
+		return nil, errors.WithMessagef(err, "error converting scheduling info for job %s", dbJob.JobID)
 	}
 
 	job, err := jobDb.NewJob(
diff --git a/internal/scheduler/jobdb/test_utils.go b/internal/scheduler/jobdb/test_utils.go
index de39431b3b9..c8083c6b5f7 100644
--- a/internal/scheduler/jobdb/test_utils.go
+++ b/internal/scheduler/jobdb/test_utils.go
@@ -5,7 +5,6 @@ import (
 
 	schedulerconfiguration "github.com/armadaproject/armada/internal/scheduler/configuration"
 	"github.com/armadaproject/armada/internal/scheduler/internaltypes"
-	"github.com/armadaproject/armada/internal/scheduler/schedulerobjects"
 )
 
 var testResourceListFactory = makeTestResourceListFactory()
@@ -33,21 +32,21 @@ func getTestFloatingResourceTypes() []schedulerconfiguration.FloatingResourceCon
 	}
 }
 
-func WithJobDbJobPodRequirements(job *Job, reqs *schedulerobjects.PodRequirements) *Job {
-	return JobWithJobSchedulingInfo(job, &schedulerobjects.JobSchedulingInfo{
-		PriorityClassName: job.JobSchedulingInfo().PriorityClassName,
-		SubmitTime: job.JobSchedulingInfo().SubmitTime,
-		ObjectRequirements: []*schedulerobjects.ObjectRequirements{
-			{
-				Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{
-					PodRequirements: reqs,
-				},
-			},
-		},
-	})
-}
-
-func JobWithJobSchedulingInfo(job *Job, jobSchedulingInfo *schedulerobjects.JobSchedulingInfo) *Job {
+//func WithJobDbJobPodRequirements(job *Job, reqs *schedulerobjects.PodRequirements) *Job {
+//	return JobWithJobSchedulingInfo(job, &schedulerobjects.JobSchedulingInfo{
+//		PriorityClassName: job.JobSchedulingInfo().PriorityClassName,
+//		SubmitTime: job.JobSchedulingInfo().SubmitTime,
+//		ObjectRequirements: []*schedulerobjects.ObjectRequirements{
+//			{
+//				Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{
+//					PodRequirements: reqs,
+//				},
+//			},
+//		},
+//	})
+//}
+
+func JobWithJobSchedulingInfo(job *Job, jobSchedulingInfo *internaltypes.JobSchedulingInfo) *Job {
 	j, err := job.WithJobSchedulingInfo(jobSchedulingInfo)
 	if err != nil {
 		panic(err)
diff --git a/internal/scheduler/metrics.go b/internal/scheduler/metrics.go
index a64f040efb5..f51f82b7573 100644
--- a/internal/scheduler/metrics.go
+++ b/internal/scheduler/metrics.go
@@ -177,7 +177,7 @@ func (c *MetricsCollector) updateQueueMetrics(ctx *armadacontext.Context) ([]pro
 
 		pools := job.ResolvedPools()
 		priorityClass := job.JobSchedulingInfo().PriorityClassName
-		resourceRequirements := job.JobSchedulingInfo().GetObjectRequirements()[0].GetPodRequirements().GetResourceRequirements().Requests
+		resourceRequirements := job.JobSchedulingInfo().PodRequirements.ResourceRequirements.Requests
 		jobResources := make(map[string]float64)
 		for key, value := range resourceRequirements {
 			jobResources[string(key)] = resource.QuantityAsFloat64(value)
diff --git a/internal/scheduler/nodedb/nodematching.go
b/internal/scheduler/nodedb/nodematching.go index 0a1dffc7dfe..37462e0389a 100644 --- a/internal/scheduler/nodedb/nodematching.go +++ b/internal/scheduler/nodedb/nodematching.go @@ -125,7 +125,7 @@ func (err *InsufficientResources) String() string { // If the requirements are not met, it returns the reason for why. // If the requirements can't be parsed, an error is returned. func NodeTypeJobRequirementsMet(nodeType *internaltypes.NodeType, jctx *schedulercontext.JobSchedulingContext) (bool, PodRequirementsNotMetReason) { - matches, reason := TolerationRequirementsMet(nodeType, jctx.AdditionalTolerations, jctx.PodRequirements.GetTolerations()) + matches, reason := TolerationRequirementsMet(nodeType, jctx.AdditionalTolerations, jctx.PodRequirements.Tolerations) if !matches { return matches, reason } @@ -135,7 +135,7 @@ func NodeTypeJobRequirementsMet(nodeType *internaltypes.NodeType, jctx *schedule return matches, reason } - return NodeSelectorRequirementsMet(nodeType.GetLabelValue, nodeType.GetUnsetIndexedLabelValue, jctx.PodRequirements.GetNodeSelector()) + return NodeSelectorRequirementsMet(nodeType.GetLabelValue, nodeType.GetUnsetIndexedLabelValue, jctx.PodRequirements.NodeSelector) } // JobRequirementsMet determines whether a job can be scheduled onto this node. @@ -159,7 +159,7 @@ func JobRequirementsMet(node *internaltypes.Node, priority int32, jctx *schedule // StaticJobRequirementsMet checks if a job can be scheduled onto this node, // accounting for taints, node selectors, node affinity, and total resources available on the node. func StaticJobRequirementsMet(node *internaltypes.Node, jctx *schedulercontext.JobSchedulingContext) (bool, PodRequirementsNotMetReason, error) { - matches, reason := NodeTolerationRequirementsMet(node, jctx.AdditionalTolerations, jctx.PodRequirements.GetTolerations()) + matches, reason := NodeTolerationRequirementsMet(node, jctx.AdditionalTolerations, jctx.PodRequirements.Tolerations) if !matches { return matches, reason, nil } @@ -169,7 +169,7 @@ func StaticJobRequirementsMet(node *internaltypes.Node, jctx *schedulercontext.J return matches, reason, nil } - matches, reason = NodeSelectorRequirementsMet(node.GetLabelValue, nil, jctx.PodRequirements.GetNodeSelector()) + matches, reason = NodeSelectorRequirementsMet(node.GetLabelValue, nil, jctx.PodRequirements.NodeSelector) if !matches { return matches, reason, nil } diff --git a/internal/scheduler/nodedb/nodematching_test.go b/internal/scheduler/nodedb/nodematching_test.go index 0decf8d9fe4..a3bf9c25404 100644 --- a/internal/scheduler/nodedb/nodematching_test.go +++ b/internal/scheduler/nodedb/nodematching_test.go @@ -18,13 +18,13 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { tests := map[string]struct { node *internaltypes.Node - req *schedulerobjects.PodRequirements + req *internaltypes.PodRequirements priority int32 expectSuccess bool }{ "nil taints and labels": { node: makeTestNodeTaintsLabels(nil, nil), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ @@ -48,7 +48,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { }, "no taints or labels": { node: makeTestNodeTaintsLabels(make([]v1.Taint, 0), make(map[string]string)), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ @@ -75,7 +75,7 @@ func 
TestNodeSchedulingRequirementsMet(t *testing.T) { []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, nil, ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, }, expectSuccess: true, @@ -85,7 +85,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, nil, ), - req: &schedulerobjects.PodRequirements{}, + req: &internaltypes.PodRequirements{}, expectSuccess: false, }, "matched node affinity": { @@ -93,7 +93,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { nil, map[string]string{"bar": "bar"}, ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -116,7 +116,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { }, "unmatched node affinity": { node: makeTestNodeTaintsLabels(nil, nil), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -142,7 +142,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, map[string]string{"bar": "bar"}, ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ @@ -169,7 +169,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, map[string]string{"bar": "bar"}, ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ RequiredDuringSchedulingIgnoredDuringExecution: &v1.NodeSelector{ @@ -195,7 +195,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, nil, ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, Affinity: &v1.Affinity{ NodeAffinity: &v1.NodeAffinity{ @@ -222,14 +222,14 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { nil, map[string]string{"bar": "bar"}, ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ NodeSelector: map[string]string{"bar": "bar"}, }, expectSuccess: true, }, "unmatched node selector": { node: makeTestNodeTaintsLabels(nil, nil), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ NodeSelector: map[string]string{"bar": "bar"}, }, expectSuccess: false, @@ -239,7 +239,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, map[string]string{"bar": "bar"}, ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, NodeSelector: map[string]string{"bar": "bar"}, }, @@ -250,7 +250,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, map[string]string{"bar": "bar"}, ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ NodeSelector: map[string]string{"bar": 
"bar"}, }, expectSuccess: false, @@ -260,7 +260,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, nil, ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, NodeSelector: map[string]string{"bar": "bar"}, }, @@ -282,7 +282,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { }, ), ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ ResourceRequirements: v1.ResourceRequirements{ Requests: v1.ResourceList{ "cpu": resource.MustParse("1"), @@ -307,7 +307,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { }, ), ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ ResourceRequirements: v1.ResourceRequirements{ Requests: v1.ResourceList{ "cpu": resource.MustParse("1"), @@ -338,7 +338,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { }, ), ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ ResourceRequirements: v1.ResourceRequirements{ Requests: v1.ResourceList{ "cpu": resource.MustParse("1"), @@ -369,7 +369,7 @@ func TestNodeSchedulingRequirementsMet(t *testing.T) { }, ), ), - req: &schedulerobjects.PodRequirements{ + req: &internaltypes.PodRequirements{ ResourceRequirements: v1.ResourceRequirements{ Requests: v1.ResourceList{ "cpu": resource.MustParse("1"), @@ -408,13 +408,13 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { Labels map[string]string IndexedTaints map[string]bool IndexedLabels map[string]bool - Req *schedulerobjects.PodRequirements + Req *internaltypes.PodRequirements ExpectSuccess bool }{ "nil taints and labels": { Taints: nil, Labels: nil, - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, NodeSelector: map[string]string{"bar": "bar"}, }, @@ -423,7 +423,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { "no taints or labels": { Taints: make([]v1.Taint, 0), Labels: make(map[string]string), - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, NodeSelector: map[string]string{"bar": "bar"}, }, @@ -432,7 +432,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { "tolerated taints": { Taints: []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, Labels: nil, - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, }, ExpectSuccess: true, @@ -440,21 +440,21 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { "untolerated taints": { Taints: []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, Labels: nil, - Req: &schedulerobjects.PodRequirements{}, + Req: &internaltypes.PodRequirements{}, ExpectSuccess: false, }, "untolerated non-indexed taint": { Taints: []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, Labels: nil, IndexedTaints: make(map[string]bool), - Req: &schedulerobjects.PodRequirements{}, + Req: &internaltypes.PodRequirements{}, ExpectSuccess: true, }, "matched node selector": { Taints: nil, Labels: map[string]string{"bar": "bar"}, IndexedLabels: map[string]bool{"bar": true}, - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ NodeSelector: map[string]string{"bar": "bar"}, }, ExpectSuccess: true, @@ 
-463,7 +463,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { Taints: nil, Labels: nil, IndexedLabels: map[string]bool{"bar": true}, - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, NodeSelector: map[string]string{"bar": "bar"}, }, @@ -473,7 +473,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { Taints: nil, Labels: map[string]string{"bar": "baz"}, IndexedLabels: map[string]bool{"bar": true}, - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ NodeSelector: map[string]string{"bar": "bar"}, }, ExpectSuccess: false, @@ -481,7 +481,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { "missing label": { Taints: nil, Labels: nil, - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ NodeSelector: map[string]string{"bar": "bar"}, }, ExpectSuccess: true, @@ -490,7 +490,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { Taints: []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, Labels: map[string]string{"bar": "bar"}, IndexedLabels: map[string]bool{"bar": true}, - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, NodeSelector: map[string]string{"bar": "bar"}, }, @@ -500,7 +500,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { Taints: []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, Labels: map[string]string{"bar": "bar"}, IndexedLabels: map[string]bool{"bar": true}, - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ NodeSelector: map[string]string{"bar": "bar"}, }, ExpectSuccess: false, @@ -509,7 +509,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { Taints: []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, Labels: map[string]string{"bar": "baz"}, IndexedLabels: map[string]bool{"bar": true}, - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, NodeSelector: map[string]string{"bar": "bar"}, }, @@ -519,7 +519,7 @@ func TestNodeTypeSchedulingRequirementsMet(t *testing.T) { Taints: []v1.Taint{{Key: "foo", Value: "foo", Effect: v1.TaintEffectNoSchedule}}, Labels: nil, IndexedLabels: map[string]bool{"bar": true}, - Req: &schedulerobjects.PodRequirements{ + Req: &internaltypes.PodRequirements{ Tolerations: []v1.Toleration{{Key: "foo", Value: "foo"}}, NodeSelector: map[string]string{"bar": "bar"}, }, diff --git a/internal/scheduler/scheduler.go b/internal/scheduler/scheduler.go index 356817ddaef..fac52e687af 100644 --- a/internal/scheduler/scheduler.go +++ b/internal/scheduler/scheduler.go @@ -4,7 +4,6 @@ import ( "fmt" "time" - "github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/types" "github.com/google/uuid" "github.com/pkg/errors" @@ -15,6 +14,7 @@ import ( "github.com/armadaproject/armada/internal/common/armadacontext" protoutil "github.com/armadaproject/armada/internal/common/proto" "github.com/armadaproject/armada/internal/scheduler/database" + "github.com/armadaproject/armada/internal/scheduler/internaltypes" "github.com/armadaproject/armada/internal/scheduler/jobdb" "github.com/armadaproject/armada/internal/scheduler/kubernetesobjects/affinity" "github.com/armadaproject/armada/internal/scheduler/leader" @@ -414,10 +414,10 @@ func (s *Scheduler) syncState(ctx *armadacontext.Context, initial bool) 
([]*jobd
 	return jobDbJobs, jsts, nil
 }
 
-func (s *Scheduler) createSchedulingInfoWithNodeAntiAffinityForAttemptedRuns(job *jobdb.Job) (*schedulerobjects.JobSchedulingInfo, error) {
-	newSchedulingInfo := proto.Clone(job.JobSchedulingInfo()).(*schedulerobjects.JobSchedulingInfo)
+func (s *Scheduler) createSchedulingInfoWithNodeAntiAffinityForAttemptedRuns(job *jobdb.Job) (*internaltypes.JobSchedulingInfo, error) {
+	newSchedulingInfo := job.JobSchedulingInfo().DeepCopy()
 	newSchedulingInfo.Version = job.JobSchedulingInfo().Version + 1
-	podRequirements := newSchedulingInfo.GetPodRequirements()
+	podRequirements := newSchedulingInfo.PodRequirements
 	if podRequirements == nil {
 		return nil, errors.Errorf("no pod scheduling requirement found for job %s", job.Id())
 	}
@@ -726,7 +726,7 @@ func (s *Scheduler) generateUpdateMessagesFromJob(ctx *armadacontext.Context, jo
 				Event: &armadaevents.EventSequence_Event_JobRequeued{
 					JobRequeued: &armadaevents.JobRequeued{
 						JobId: job.Id(),
-						SchedulingInfo: job.JobSchedulingInfo(),
+						SchedulingInfo: internaltypes.ToSchedulerObjectsJobSchedulingInfo(job.JobSchedulingInfo()),
 						UpdateSequenceNumber: job.QueuedVersion(),
 					},
 				},
diff --git a/internal/scheduler/scheduler_test.go b/internal/scheduler/scheduler_test.go
index f369680d24c..04c9c71df21 100644
--- a/internal/scheduler/scheduler_test.go
+++ b/internal/scheduler/scheduler_test.go
@@ -123,7 +123,7 @@ var queuedJob = testfixtures.NewJob(
 	"testQueue",
 	uint32(10),
 	0.0,
-	schedulingInfo,
+	toInternalSchedulingInfo(schedulingInfo),
 	true,
 	0,
 	false,
@@ -139,7 +139,7 @@ var leasedJob = testfixtures.NewJob(
 	"testQueue",
 	0,
 	0.0,
-	schedulingInfo,
+	toInternalSchedulingInfo(schedulingInfo),
 	false,
 	1,
 	false,
@@ -155,7 +155,7 @@ var preemptibleLeasedJob = testfixtures.NewJob(
 	"testQueue",
 	0,
 	0.0,
-	preemptibleSchedulingInfo,
+	toInternalSchedulingInfo(preemptibleSchedulingInfo),
 	false,
 	1,
 	false,
@@ -171,7 +171,7 @@ var cancelledJob = testfixtures.NewJob(
 	"testQueue",
 	0,
 	0.0,
-	schedulingInfo,
+	toInternalSchedulingInfo(schedulingInfo),
 	false,
 	1,
 	true,
@@ -187,7 +187,7 @@ var returnedOnceLeasedJob = testfixtures.NewJob(
 	"testQueue",
 	uint32(10),
 	0.0,
-	schedulingInfo,
+	toInternalSchedulingInfo(schedulingInfo),
 	false,
 	3,
 	false,
@@ -246,7 +246,7 @@ var leasedFailFastJob = testfixtures.NewJob(
 	"testQueue",
 	uint32(10),
 	0.0,
-	failFastSchedulingInfo,
+	toInternalSchedulingInfo(failFastSchedulingInfo),
 	false,
 	1,
 	false,
@@ -269,7 +269,7 @@ var (
 		"testQueue",
 		uint32(10),
 		0.0,
-		schedulingInfo,
+		toInternalSchedulingInfo(schedulingInfo),
 		true,
 		2,
 		false,
@@ -925,8 +925,7 @@ func TestScheduler_TestCycle(t *testing.T) {
 				assert.Equal(t, expectedPriority, job.Priority())
 			}
 			if len(tc.expectedNodeAntiAffinities) > 0 {
-				assert.Len(t, job.JobSchedulingInfo().ObjectRequirements, 1)
-				affinity := job.JobSchedulingInfo().ObjectRequirements[0].GetPodRequirements().Affinity
+				affinity := job.JobSchedulingInfo().PodRequirements.Affinity
 				assert.NotNil(t, affinity)
 				expectedAffinity := createAntiAffinity(t, nodeIdLabel, tc.expectedNodeAntiAffinities)
 				assert.Equal(t, expectedAffinity, affinity)
@@ -1020,7 +1019,7 @@ func TestRun(t *testing.T) {
 	wg.Add(1)
 	sched.onCycleCompleted = func() { wg.Done() }
 	jobId := util.NewULID()
-	jobRepo.updatedJobs = []database.Job{{JobID: jobId, Queue: "testQueue", Queued: true, Validated: true}}
+	jobRepo.updatedJobs = []database.Job{{JobID: jobId, Queue: "testQueue", Queued: true, Validated: true, SchedulingInfo: schedulingInfoBytes}}
 	schedulingAlgo.jobsToSchedule = []string{jobId}
 	testClock.Step(10 * time.Second)
 	wg.Wait()
@@ -1327,7 +1326,7 @@ func TestScheduler_TestSyncState(t *testing.T) {
 				},
 			},
 			expectedUpdatedJobs: []*jobdb.Job{
-				jobdb.JobWithJobSchedulingInfo(leasedJob, updatedSchedulingInfo).
+				jobdb.JobWithJobSchedulingInfo(leasedJob, toInternalSchedulingInfo(updatedSchedulingInfo)).
 					WithQueued(true).
 					WithQueuedVersion(3),
 			},
@@ -1791,7 +1790,7 @@ func jobDbJobFromDbJob(resourceListFactory *internaltypes.ResourceListFactory, j
 		job.Queue,
 		uint32(job.Priority),
 		job.BidPrice,
-		&schedulingInfo,
+		toInternalSchedulingInfo(&schedulingInfo),
 		job.Queued,
 		job.QueuedVersion,
 		job.CancelRequested,
@@ -2874,3 +2873,11 @@ func fixInsertJobsDbOp(dbOp scheduleringester.InsertJobs) scheduleringester.Inse
 	}
 	return dbOp
 }
+
+func toInternalSchedulingInfo(j *schedulerobjects.JobSchedulingInfo) *internaltypes.JobSchedulingInfo {
+	internalJsi, err := internaltypes.FromSchedulerObjectsJobSchedulingInfo(j)
+	if err != nil {
+		panic(err)
+	}
+	return internalJsi
+}
diff --git a/internal/scheduler/scheduling/context/job.go b/internal/scheduler/scheduling/context/job.go
index 81ea403c5ce..8be8afa404d 100644
--- a/internal/scheduler/scheduling/context/job.go
+++ b/internal/scheduler/scheduling/context/job.go
@@ -36,7 +36,7 @@ type JobSchedulingContext struct {
 	Job *jobdb.Job
 	// Scheduling requirements of this job.
 	// We currently require that each job contains exactly one pod spec.
-	PodRequirements *schedulerobjects.PodRequirements
+	PodRequirements *internaltypes.PodRequirements
 	// Resource requirements in an efficient internaltypes.ResourceList
 	KubernetesResourceRequirements internaltypes.ResourceList
 	// Node selectors to consider in addition to those included with the PodRequirements.
diff --git a/internal/scheduler/scheduling/jobiteration_test.go b/internal/scheduler/scheduling/jobiteration_test.go
index f74444ab1b2..c4bd4403630 100644
--- a/internal/scheduler/scheduling/jobiteration_test.go
+++ b/internal/scheduler/scheduling/jobiteration_test.go
@@ -10,20 +10,21 @@ import (
 
 	"github.com/armadaproject/armada/internal/common/armadacontext"
 	"github.com/armadaproject/armada/internal/common/util"
+	"github.com/armadaproject/armada/internal/scheduler/internaltypes"
 	"github.com/armadaproject/armada/internal/scheduler/jobdb"
-	"github.com/armadaproject/armada/internal/scheduler/schedulerobjects"
 	schedulercontext "github.com/armadaproject/armada/internal/scheduler/scheduling/context"
 	"github.com/armadaproject/armada/internal/scheduler/testfixtures"
 )
 
 func TestInMemoryJobRepository(t *testing.T) {
+	emptyRequirements := &internaltypes.PodRequirements{}
 	jobs := []*jobdb.Job{
-		testfixtures.TestJob("A", util.ULID(), "armada-default", nil).WithCreated(3).WithPriority(1),
-		testfixtures.TestJob("A", util.ULID(), "armada-default", nil).WithCreated(1).WithPriority(1),
-		testfixtures.TestJob("A", util.ULID(), "armada-default", nil).WithCreated(2).WithPriority(1),
-		testfixtures.TestJob("A", util.ULID(), "armada-default", nil).WithCreated(0).WithPriority(3),
-		testfixtures.TestJob("A", util.ULID(), "armada-default", nil).WithCreated(0).WithPriority(0),
-		testfixtures.TestJob("A", util.ULID(), "armada-default", nil).WithCreated(0).WithPriority(2),
+		testfixtures.TestJob("A", util.ULID(), "armada-default", emptyRequirements).WithCreated(3).WithPriority(1),
+		testfixtures.TestJob("A", util.ULID(), "armada-default", emptyRequirements).WithCreated(1).WithPriority(1),
+		testfixtures.TestJob("A", util.ULID(), "armada-default", emptyRequirements).WithCreated(2).WithPriority(1),
+		testfixtures.TestJob("A", util.ULID(), "armada-default",
emptyRequirements).WithCreated(0).WithPriority(3), + testfixtures.TestJob("A", util.ULID(), "armada-default", emptyRequirements).WithCreated(0).WithPriority(0), + testfixtures.TestJob("A", util.ULID(), "armada-default", emptyRequirements).WithCreated(0).WithPriority(2), } jctxs := make([]*schedulercontext.JobSchedulingContext, len(jobs)) for i, job := range jobs { @@ -269,6 +270,6 @@ func (repo *mockJobRepository) GetJobIterator(ctx *armadacontext.Context, queue return NewQueuedJobsIterator(ctx, queue, testfixtures.TestPool, repo, jobdb.FairShareOrder) } -func jobFromPodSpec(queue string, req *schedulerobjects.PodRequirements) *jobdb.Job { +func jobFromPodSpec(queue string, req *internaltypes.PodRequirements) *jobdb.Job { return testfixtures.TestJob(queue, util.ULID(), "armada-default", req) } diff --git a/internal/scheduler/scheduling/preemption_description_test.go b/internal/scheduler/scheduling/preemption_description_test.go index c0ba13a325a..4ee416499dd 100644 --- a/internal/scheduler/scheduling/preemption_description_test.go +++ b/internal/scheduler/scheduling/preemption_description_test.go @@ -4,11 +4,12 @@ import ( "fmt" "testing" + "github.com/armadaproject/armada/internal/scheduler/internaltypes" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/armadaproject/armada/internal/scheduler/jobdb" - "github.com/armadaproject/armada/internal/scheduler/schedulerobjects" "github.com/armadaproject/armada/internal/scheduler/scheduling/context" "github.com/armadaproject/armada/internal/scheduler/testfixtures" "github.com/armadaproject/armada/internal/server/configuration" @@ -123,15 +124,9 @@ func makeJob(t *testing.T, jobId string, isGang bool) *jobdb.Job { if isGang { annotations[configuration.GangIdAnnotation] = "gang" } - schedulingInfo := &schedulerobjects.JobSchedulingInfo{ - ObjectRequirements: []*schedulerobjects.ObjectRequirements{ - { - Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ - PodRequirements: &schedulerobjects.PodRequirements{ - Annotations: annotations, - }, - }, - }, + schedulingInfo := &internaltypes.JobSchedulingInfo{ + PodRequirements: &internaltypes.PodRequirements{ + Annotations: annotations, }, } diff --git a/internal/scheduler/simulator/simulator.go b/internal/scheduler/simulator/simulator.go index 98ccdd4f143..123ff100da7 100644 --- a/internal/scheduler/simulator/simulator.go +++ b/internal/scheduler/simulator/simulator.go @@ -777,7 +777,11 @@ func (s *Simulator) handleEventSequence(_ *armadacontext.Context, es *armadaeven } func (s *Simulator) handleSubmitJob(txn *jobdb.Txn, e *armadaevents.SubmitJob, time time.Time, eventSequence *armadaevents.EventSequence) (*jobdb.Job, bool, error) { - schedulingInfo, err := scheduleringester.SchedulingInfoFromSubmitJob(e, time) + schedulingInfoProto, err := scheduleringester.SchedulingInfoFromSubmitJob(e, time) + if err != nil { + return nil, false, err + } + schedulingInfo, err := internaltypes.FromSchedulerObjectsJobSchedulingInfo(schedulingInfoProto) if err != nil { return nil, false, err } diff --git a/internal/scheduler/testfixtures/testfixtures.go b/internal/scheduler/testfixtures/testfixtures.go index d27755b2650..7f2c324e4c7 100644 --- a/internal/scheduler/testfixtures/testfixtures.go +++ b/internal/scheduler/testfixtures/testfixtures.go @@ -9,7 +9,6 @@ import ( "sync/atomic" "time" - "github.com/gogo/protobuf/proto" "github.com/google/uuid" "github.com/oklog/ulid" "golang.org/x/exp/maps" @@ -145,7 +144,7 @@ func NewJob( queue string, priority uint32, price 
float64, - schedulingInfo *schedulerobjects.JobSchedulingInfo, + schedulingInfo *internaltypes.JobSchedulingInfo, queued bool, queuedVersion int32, cancelRequested bool, @@ -402,13 +401,11 @@ func WithNodeAffinityJobs(nodeSelectorTerms []v1.NodeSelectorTerm, jobs []*jobdb func WithRequestsJobs(rl schedulerobjects.ResourceList, jobs []*jobdb.Job) []*jobdb.Job { newJobs := make([]*jobdb.Job, len(jobs)) for i, job := range jobs { - newSchedInfo := proto.Clone(job.JobSchedulingInfo()).(*schedulerobjects.JobSchedulingInfo) - for _, newReq := range newSchedInfo.GetObjectRequirements() { - maps.Copy( - newReq.GetPodRequirements().ResourceRequirements.Requests, - schedulerobjects.V1ResourceListFromResourceList(rl), - ) - } + newSchedInfo := job.JobSchedulingInfo().DeepCopy() + maps.Copy( + newSchedInfo.PodRequirements.ResourceRequirements.Requests, + schedulerobjects.V1ResourceListFromResourceList(rl), + ) newJob, err := job.WithJobSchedulingInfo(newSchedInfo) if err != nil { panic(err) @@ -420,18 +417,14 @@ func WithRequestsJobs(rl schedulerobjects.ResourceList, jobs []*jobdb.Job) []*jo func WithNodeSelectorJobs(selector map[string]string, jobs []*jobdb.Job) []*jobdb.Job { for _, job := range jobs { - for _, req := range job.JobSchedulingInfo().GetObjectRequirements() { - req.GetPodRequirements().NodeSelector = maps.Clone(selector) - } + job.JobSchedulingInfo().PodRequirements.NodeSelector = maps.Clone(selector) } return jobs } func WithNodeSelectorJob(selector map[string]string, job *jobdb.Job) *jobdb.Job { job = job.DeepCopy() - for _, req := range job.JobSchedulingInfo().GetObjectRequirements() { - req.GetPodRequirements().NodeSelector = maps.Clone(selector) - } + job.JobSchedulingInfo().PodRequirements.NodeSelector = maps.Clone(selector) return job } @@ -463,12 +456,10 @@ func WithNodeUniformityGangAnnotationsJobs(jobs []*jobdb.Job, nodeUniformityLabe func WithAnnotationsJobs(annotations map[string]string, jobs []*jobdb.Job) []*jobdb.Job { for _, job := range jobs { - for _, req := range job.JobSchedulingInfo().GetObjectRequirements() { - if req.GetPodRequirements().Annotations == nil { - req.GetPodRequirements().Annotations = make(map[string]string) - } - maps.Copy(req.GetPodRequirements().Annotations, annotations) + if job.PodRequirements().Annotations == nil { + job.PodRequirements().Annotations = make(map[string]string) } + maps.Copy(job.PodRequirements().Annotations, annotations) } return jobs } @@ -531,7 +522,7 @@ func extractPriority(priorityClassName string) int32 { return priorityClass.Priority } -func TestJob(queue string, jobId ulid.ULID, priorityClassName string, req *schedulerobjects.PodRequirements) *jobdb.Job { +func TestJob(queue string, jobId ulid.ULID, priorityClassName string, req *internaltypes.PodRequirements) *jobdb.Job { created := jobTimestamp.Add(1) submitTime := time.Time{}.Add(time.Millisecond * time.Duration(created)) job, _ := JobDb.NewJob( @@ -541,16 +532,10 @@ func TestJob(queue string, jobId ulid.ULID, priorityClassName string, req *sched // This is the per-queue priority of this job, which is unrelated to `priorityClassName`. 
1000, 0.0, - &schedulerobjects.JobSchedulingInfo{ + &internaltypes.JobSchedulingInfo{ PriorityClassName: priorityClassName, SubmitTime: submitTime, - ObjectRequirements: []*schedulerobjects.ObjectRequirements{ - { - Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ - PodRequirements: req, - }, - }, - }, + PodRequirements: req, }, false, 0, @@ -589,23 +574,23 @@ func Test1GpuJob(queue string, priorityClassName string) *jobdb.Job { return TestJob(queue, jobId, priorityClassName, Test1GpuPodReqs(queue, jobId, extractPriority(priorityClassName))) } -func N1CpuPodReqs(queue string, priority int32, n int) []*schedulerobjects.PodRequirements { - rv := make([]*schedulerobjects.PodRequirements, n) +func N1CpuPodReqs(queue string, priority int32, n int) []*internaltypes.PodRequirements { + rv := make([]*internaltypes.PodRequirements, n) for i := 0; i < n; i++ { rv[i] = Test1Cpu4GiPodReqs(queue, util.ULID(), priority) } return rv } -func TestPodReqs(queue string, jobId ulid.ULID, priority int32, requests v1.ResourceList) *schedulerobjects.PodRequirements { - return &schedulerobjects.PodRequirements{ +func TestPodReqs(queue string, jobId ulid.ULID, priority int32, requests v1.ResourceList) *internaltypes.PodRequirements { + return &internaltypes.PodRequirements{ ResourceRequirements: v1.ResourceRequirements{Requests: requests}, Annotations: make(map[string]string), NodeSelector: make(map[string]string), } } -func Test1Cpu4GiPodReqs(queue string, jobId ulid.ULID, priority int32) *schedulerobjects.PodRequirements { +func Test1Cpu4GiPodReqs(queue string, jobId ulid.ULID, priority int32) *internaltypes.PodRequirements { return TestPodReqs( queue, jobId, @@ -617,7 +602,7 @@ func Test1Cpu4GiPodReqs(queue string, jobId ulid.ULID, priority int32) *schedule ) } -func Test1Cpu16GiPodReqs(queue string, jobId ulid.ULID, priority int32) *schedulerobjects.PodRequirements { +func Test1Cpu16GiPodReqs(queue string, jobId ulid.ULID, priority int32) *internaltypes.PodRequirements { return TestPodReqs( queue, jobId, @@ -629,7 +614,7 @@ func Test1Cpu16GiPodReqs(queue string, jobId ulid.ULID, priority int32) *schedul ) } -func Test16Cpu128GiPodReqs(queue string, jobId ulid.ULID, priority int32) *schedulerobjects.PodRequirements { +func Test16Cpu128GiPodReqs(queue string, jobId ulid.ULID, priority int32) *internaltypes.PodRequirements { req := TestPodReqs( queue, jobId, @@ -648,7 +633,7 @@ func Test16Cpu128GiPodReqs(queue string, jobId ulid.ULID, priority int32) *sched return req } -func Test32Cpu256GiPodReqs(queue string, jobId ulid.ULID, priority int32) *schedulerobjects.PodRequirements { +func Test32Cpu256GiPodReqs(queue string, jobId ulid.ULID, priority int32) *internaltypes.PodRequirements { req := TestPodReqs( queue, jobId, @@ -667,7 +652,7 @@ func Test32Cpu256GiPodReqs(queue string, jobId ulid.ULID, priority int32) *sched return req } -func Test1GpuPodReqs(queue string, jobId ulid.ULID, priority int32) *schedulerobjects.PodRequirements { +func Test1GpuPodReqs(queue string, jobId ulid.ULID, priority int32) *internaltypes.PodRequirements { req := TestPodReqs( queue, jobId, @@ -687,8 +672,8 @@ func Test1GpuPodReqs(queue string, jobId ulid.ULID, priority int32) *schedulerob return req } -func TestUnitReqs() *schedulerobjects.PodRequirements { - return &schedulerobjects.PodRequirements{ +func TestUnitReqs() *internaltypes.PodRequirements { + return &internaltypes.PodRequirements{ ResourceRequirements: v1.ResourceRequirements{ Requests: v1.ResourceList{ "cpu": resource.MustParse("1"), @@ -859,16 +844,10 
@@ func TestQueuedJobDbJob() *jobdb.Job { TestQueue, 0, 0.0, - &schedulerobjects.JobSchedulingInfo{ + &internaltypes.JobSchedulingInfo{ PriorityClassName: TestDefaultPriorityClass, SubmitTime: BaseTime, - ObjectRequirements: []*schedulerobjects.ObjectRequirements{ - { - Requirements: &schedulerobjects.ObjectRequirements_PodRequirements{ - PodRequirements: TestUnitReqs(), - }, - }, - }, + PodRequirements: TestUnitReqs(), }, true, 0,