Skip to content

Commit

Permalink
Allow configuring single_branch (#2256)
Browse files Browse the repository at this point in the history
* Allow configuring single_branch

* Fix tests
  • Loading branch information
r4victor authored Feb 3, 2025
1 parent 59a935e commit d1be0f3
Show file tree
Hide file tree
Showing 12 changed files with 49 additions and 4 deletions.
9 changes: 7 additions & 2 deletions runner/internal/executor/repo.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ func (ex *RunExecutor) setupRepo(ctx context.Context) error {
}

func (ex *RunExecutor) prepareGit(ctx context.Context) error {
repoManager := repo.NewManager(ctx, ex.repoCredentials.CloneURL, ex.run.RepoData.RepoBranch, ex.run.RepoData.RepoHash).WithLocalPath(ex.workingDir)
repoManager := repo.NewManager(
ctx,
ex.repoCredentials.CloneURL,
ex.run.RepoData.RepoBranch,
ex.run.RepoData.RepoHash,
ex.jobSpec.SingleBranch,
).WithLocalPath(ex.workingDir)
if ex.repoCredentials != nil {
log.Trace(ctx, "Credentials is not empty")
switch ex.repoCredentials.GetProtocol() {
Expand All @@ -51,7 +57,6 @@ func (ex *RunExecutor) prepareGit(ctx context.Context) error {
if ex.repoCredentials.PrivateKey == nil {
return gerrors.Newf("private key is empty")
}
repoManager = repo.NewManager(ctx, ex.repoCredentials.CloneURL, ex.run.RepoData.RepoBranch, ex.run.RepoData.RepoHash).WithLocalPath(ex.workingDir)
repoManager.WithSSHAuth(*ex.repoCredentials.PrivateKey, "") // we don't support passphrase
default:
return gerrors.Newf("unsupported remote repo protocol: %s", ex.repoCredentials.GetProtocol())
Expand Down
4 changes: 2 additions & 2 deletions runner/internal/repo/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ type Manager struct {
hash string
}

func NewManager(ctx context.Context, url, branch, hash string) *Manager {
func NewManager(ctx context.Context, url, branch, hash string, singleBranch bool) *Manager {
ctx = log.AppendArgsCtx(ctx, "url", url, "branch", branch, "hash", hash)
m := &Manager{
ctx: ctx,
clo: git.CloneOptions{
URL: url,
RecurseSubmodules: git.DefaultSubmoduleRecursionDepth,
ReferenceName: plumbing.NewBranchReferenceName(branch),
SingleBranch: true,
SingleBranch: singleBranch,
},
hash: hash,
}
Expand Down
1 change: 1 addition & 0 deletions runner/internal/schemas/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type JobSpec struct {
Commands []string `json:"commands"`
Entrypoint []string `json:"entrypoint"`
Env map[string]string `json:"env"`
SingleBranch bool `json:"single_branch"`
MaxDuration int `json:"max_duration"`
WorkingDir *string `json:"working_dir"`
}
Expand Down
10 changes: 10 additions & 0 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ class BaseRunConfiguration(CoreModel):
description="Use image with NVIDIA CUDA Compiler (NVCC) included. Mutually exclusive with `image`"
),
]
single_branch: Annotated[
Optional[bool],
Field(
description=(
"Whether to clone and track only the current branch or all remote branches."
" Relevant only when using remote Git repos."
" Defaults to `false` for dev environments and to `true` for tasks and services"
)
),
] = None
env: Annotated[
Env,
Field(description="The mapping or the list of environment variables"),
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class JobSpec(CoreModel):
home_dir: Optional[str]
image_name: str
privileged: bool = False
single_branch: Optional[bool] = None
max_duration: Optional[int]
stop_duration: Optional[int] = None
registry_auth: Optional[RegistryAuth]
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/server/schemas/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class SubmitBody(CoreModel):
"entrypoint",
"env",
"gateway",
"single_branch",
"max_duration",
"working_dir",
}
Expand Down
10 changes: 10 additions & 0 deletions src/dstack/_internal/server/services/jobs/configurators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ async def get_job_specs(self, replica_num: int) -> List[JobSpec]:
def _shell_commands(self) -> List[str]:
pass

@abstractmethod
def _default_single_branch(self) -> bool:
pass

@abstractmethod
def _default_max_duration(self) -> Optional[int]:
pass
Expand Down Expand Up @@ -104,6 +108,7 @@ async def _get_job_spec(
image_name=self._image_name(),
user=await self._user(),
privileged=self._privileged(),
single_branch=self._single_branch(),
max_duration=self._max_duration(),
stop_duration=self._stop_duration(),
registry_auth=self._registry_auth(),
Expand Down Expand Up @@ -172,6 +177,11 @@ async def _user(self) -> Optional[UnixUser]:
def _privileged(self) -> bool:
return self.run_spec.configuration.privileged

def _single_branch(self) -> bool:
if self.run_spec.configuration.single_branch is None:
return self._default_single_branch()
return self.run_spec.configuration.single_branch

def _max_duration(self) -> Optional[int]:
if self.run_spec.merged_profile.max_duration in [None, True]:
return self._default_max_duration()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def _shell_commands(self) -> List[str]:
commands += ["tail -f /dev/null"] # idle
return commands

def _default_single_branch(self) -> bool:
return False

def _default_max_duration(self) -> Optional[int]:
return DEFAULT_MAX_DURATION_SECONDS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ class ServiceJobConfigurator(JobConfigurator):
def _shell_commands(self) -> List[str]:
return self.run_spec.configuration.commands

def _default_single_branch(self) -> bool:
return True

def _default_max_duration(self) -> Optional[int]:
return None

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ async def get_job_specs(self, replica_num: int) -> List[JobSpec]:
def _shell_commands(self) -> List[str]:
return self.run_spec.configuration.commands

def _default_single_branch(self) -> bool:
return True

def _default_max_duration(self) -> Optional[int]:
return DEFAULT_MAX_DURATION_SECONDS

Expand Down
4 changes: 4 additions & 0 deletions src/dstack/api/server/_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ def _get_run_spec_excludes(run_spec: RunSpec) -> Optional[dict]:
configuration_excludes.add("stop_duration")
if profile is not None and profile.stop_duration is None:
profile_excludes.add("stop_duration")
# client >= 0.18.40 / server <= 0.18.39 compatibility tweak
if configuration.single_branch is None:
configuration_excludes.add("single_branch")

if configuration_excludes:
spec_excludes["configuration"] = configuration_excludes
if profile_excludes:
Expand Down
4 changes: 4 additions & 0 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def get_dev_env_run_plan_dict(
"instance_types": None,
"creation_policy": None,
"instance_name": None,
"single_branch": None,
"max_duration": "off",
"stop_duration": None,
"max_price": None,
Expand Down Expand Up @@ -180,6 +181,7 @@ def get_dev_env_run_plan_dict(
"replica_num": 0,
"job_num": 0,
"jobs_per_replica": 1,
"single_branch": False,
"max_duration": None,
"stop_duration": 300,
"registry_auth": None,
Expand Down Expand Up @@ -261,6 +263,7 @@ def get_dev_env_run_dict(
"instance_types": None,
"creation_policy": None,
"instance_name": None,
"single_branch": None,
"max_duration": "off",
"stop_duration": None,
"max_price": None,
Expand Down Expand Up @@ -331,6 +334,7 @@ def get_dev_env_run_dict(
"replica_num": 0,
"job_num": 0,
"jobs_per_replica": 1,
"single_branch": False,
"max_duration": None,
"stop_duration": 300,
"registry_auth": None,
Expand Down

0 comments on commit d1be0f3

Please sign in to comment.