diff --git a/src/lando/main/admin.py b/src/lando/main/admin.py index 846f6b40..f366d2db 100644 --- a/src/lando/main/admin.py +++ b/src/lando/main/admin.py @@ -1 +1,8 @@ -# Register your models here. +from django.contrib import admin + +from lando.main.models import LandingJob, Repo, Revision, Worker + +admin.site.register(LandingJob, admin.ModelAdmin) +admin.site.register(Revision, admin.ModelAdmin) +admin.site.register(Repo, admin.ModelAdmin) +admin.site.register(Worker, admin.ModelAdmin) diff --git a/src/lando/main/management/commands/__init__.py b/src/lando/main/management/commands/__init__.py index e69de29b..37a5bf65 100644 --- a/src/lando/main/management/commands/__init__.py +++ b/src/lando/main/management/commands/__init__.py @@ -0,0 +1,85 @@ +import os +import re +import subprocess +from time import sleep + +from lando.main.models import Worker + + +class WorkerMixin: + @staticmethod + def _setup_ssh(ssh_private_key: str): + """Add a given private ssh key to ssh agent. + + SSH keys are needed in order to push to repositories that have an ssh + push path. + + The private key should be passed as it is in the key file, including all + new line characters and the new line character at the end. + + Args: + ssh_private_key (str): A string representing the private SSH key file. + """ + # Set all the correct environment variables + agent_process = subprocess.run( + ["ssh-agent", "-s"], capture_output=True, universal_newlines=True + ) + + # This pattern will match keys and values, and ignore everything after the + # semicolon. For example, the output of `agent_process` is of the form: + # SSH_AUTH_SOCK=/tmp/ssh-c850kLXXOS5e/agent.120801; export SSH_AUTH_SOCK; + # SSH_AGENT_PID=120802; export SSH_AGENT_PID; + # echo Agent pid 120802; + pattern = re.compile("(.+)=([^;]*)") + for key, value in pattern.findall(agent_process.stdout): + os.environ[key] = value + + # Add private SSH key to agent + # NOTE: ssh-add seems to output everything to stderr, including upon exit 0. + add_process = subprocess.run( + ["ssh-add", "-"], + input=ssh_private_key, + capture_output=True, + universal_newlines=True, + ) + if add_process.returncode != 0: + raise Exception(add_process.stderr) + + @property + def _instance(self): + return Worker.objects.get(name=self.name) + + def _setup(self): + """Perform various setup actions.""" + if self._instance.ssh_private_key: + self._setup_ssh(self._instance.ssh_private_key) + + def _start(self, max_loops: int | None = None, *args, **kwargs): + """Run the main event loop.""" + # NOTE: The worker will exit when max_loops is reached, or when the stop + # variable is changed to True. + loops = 0 + while not self._instance.is_stopped: + if max_loops is not None and loops >= max_loops: + break + while self._instance.is_paused: + self.throttle(self._instance.sleep_seconds) + self.loop(*args, **kwargs) + loops += 1 + + self.stdout.write(f"{self} exited after {loops} loops.") + + def throttle(self, seconds: int | None = None): + """Sleep for a given number of seconds.""" + sleep(seconds if seconds is not None else self._instance.throttle_seconds) + + def start(self, max_loops: int | None = None): + """Run setup sequence and start the event loop.""" + if self._instance.is_stopped: + return + self._setup() + self._start(max_loops=max_loops) + + def loop(self, *args, **kwargs): + """The main event loop.""" + raise NotImplementedError() diff --git a/src/lando/main/management/commands/landing_worker.py b/src/lando/main/management/commands/landing_worker.py new file mode 100644 index 00000000..2784bb88 --- /dev/null +++ b/src/lando/main/management/commands/landing_worker.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import logging +from contextlib import contextmanager +from datetime import datetime +from io import StringIO + +from django.core.management.base import BaseCommand +from django.db import transaction +from lando.main.management.commands import WorkerMixin +from lando.main.models import LandingJob, LandingJobStatus + +logger = logging.getLogger(__name__) + + +@contextmanager +def job_processing(job: LandingJob): + """Mutex-like context manager that manages job processing miscellany. + + This context manager facilitates graceful worker shutdown, tracks the duration of + the current job, and commits changes to the DB at the very end. + + Args: + job: the job currently being processed + db: active database session + """ + start_time = datetime.now() + try: + yield + finally: + job.duration_seconds = (datetime.now() - start_time).seconds + + +class Command(BaseCommand, WorkerMixin): + help = "Start the landing worker." + name = "landing-worker" + + def add_arguments(self, parser): + pass + + def handle(self, *args, **options): + self.last_job_finished = None + self.start() + + def loop(self): + if self.last_job_finished is False: + logger.info("Last job did not complete, sleeping.") + self.throttle(self._instance.sleep_seconds) + + for repo in self._instance.enabled_repos: + if not repo.is_initialized: + repo.initialize() + + with transaction.atomic(): + job = LandingJob.next_job(repositories=self._instance.enabled_repos).first() + + if job is None: + self.throttle(self._instance.sleep_seconds) + return + + with job_processing(job): + job.status = LandingJobStatus.IN_PROGRESS + job.attempts += 1 + job.save() + + self.stdout.write(f"Starting landing job {job}") + self.last_job_finished = self.run_job(job) + self.stdout.write("Finished processing landing job") + + def run_job(self, job: LandingJob) -> bool: + repo = job.target_repo + repo.reset() + repo.pull() + + for revision in job.revisions.all(): + patch_buffer = StringIO(revision.patch) + repo.apply_patch(patch_buffer) + + # TODO: need to account for reverts/backouts somehow in the futue. + revision.commit_id = repo._run("rev-parse", "HEAD").stdout.strip() + revision.save() + + repo.push() + + job.status = LandingJobStatus.LANDED + job.save() diff --git a/src/lando/main/management/commands/test_command.py b/src/lando/main/management/commands/test_command.py deleted file mode 100644 index 06a504a7..00000000 --- a/src/lando/main/management/commands/test_command.py +++ /dev/null @@ -1,12 +0,0 @@ -from django.core.management.base import BaseCommand - - -class Command(BaseCommand): - help = "Test command" - - def add_arguments(self, parser): - parser.add_argument("names", nargs="+") - - def handle(self, *args, **options): - for name in options["names"]: - self.stdout.write(self.style.SUCCESS(f"Hello {name}!")) diff --git a/src/lando/main/migrations/0003_remove_landingjob_repository_name_and_more.py b/src/lando/main/migrations/0003_remove_landingjob_repository_name_and_more.py new file mode 100644 index 00000000..a28e829f --- /dev/null +++ b/src/lando/main/migrations/0003_remove_landingjob_repository_name_and_more.py @@ -0,0 +1,88 @@ +# Generated by Django 5.0rc1 on 2023-12-01 16:16 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("main", "0002_repo_worker"), + ] + + operations = [ + migrations.RemoveField( + model_name="landingjob", + name="repository_name", + ), + migrations.RemoveField( + model_name="landingjob", + name="repository_url", + ), + migrations.RemoveField( + model_name="revision", + name="patch_bytes", + ), + migrations.AddField( + model_name="landingjob", + name="target_repo", + field=models.ForeignKey( + null=True, on_delete=django.db.models.deletion.SET_NULL, to="main.repo" + ), + ), + migrations.AddField( + model_name="revision", + name="commit_id", + field=models.CharField(blank=True, max_length=40, null=True), + ), + migrations.AddField( + model_name="revision", + name="patch", + field=models.TextField(blank=True, default=""), + ), + migrations.AlterField( + model_name="landingjob", + name="landed_commit_id", + field=models.TextField(blank=True, default=""), + ), + migrations.AlterField( + model_name="landingjob", + name="requester_email", + field=models.CharField(blank=True, default="", max_length=255), + ), + migrations.AlterField( + model_name="landingjob", + name="target_commit_hash", + field=models.TextField(blank=True, default=""), + ), + migrations.AlterField( + model_name="repo", + name="system_path", + field=models.FilePathField( + allow_folders=True, + blank=True, + default="", + max_length=255, + path="/mediafiles/repos", + ), + ), + migrations.AlterField( + model_name="revision", + name="data", + field=models.JSONField(blank=True, default=dict), + ), + migrations.AlterField( + model_name="revision", + name="diff_id", + field=models.IntegerField(blank=True, null=True), + ), + migrations.AlterField( + model_name="revision", + name="patch_data", + field=models.JSONField(blank=True, default=dict), + ), + migrations.AlterField( + model_name="revision", + name="revision_id", + field=models.IntegerField(blank=True, null=True, unique=True), + ), + ] diff --git a/src/lando/main/models.py b/src/lando/main/models.py index 8c96bc33..9b05f819 100644 --- a/src/lando/main/models.py +++ b/src/lando/main/models.py @@ -4,6 +4,7 @@ import logging import os import subprocess +import tempfile from pathlib import Path from typing import ( Any, @@ -16,7 +17,7 @@ from django.utils.translation import gettext_lazy from lando import settings -from lando.utils import build_patch_for_revision +from lando.utils import GitPatchHelper, build_patch_for_revision logger = logging.getLogger(__name__) @@ -45,20 +46,23 @@ class Revision(BaseModel): """ # revision_id and diff_id map to Phabricator IDs (integers). - revision_id = models.IntegerField(null=True, unique=True) + revision_id = models.IntegerField(blank=True, null=True, unique=True) # diff_id is that of the latest diff on the revision at landing request time. It # does not track all diffs. - diff_id = models.IntegerField(null=True) + diff_id = models.IntegerField(blank=True, null=True) # The actual patch. - patch_bytes = models.BinaryField(default=b"") + patch = models.TextField(blank=True, default="") # Patch metadata, such as author, timestamp, etc... - patch_data = models.JSONField(default=dict) + patch_data = models.JSONField(blank=True, default=dict) # A general purpose data field to store arbitrary information about this revision. - data = models.JSONField(default=dict) + data = models.JSONField(blank=True, default=dict) + + # The commit ID generated by the landing worker, before pushing to remote repo. + commit_id = models.CharField(max_length=40, null=True, blank=True) def __repr__(self) -> str: """Return a human-readable representation of the instance.""" @@ -100,16 +104,10 @@ class LandingJob(BaseModel): error_breakdown = models.JSONField(null=True, blank=True, default=dict) # LDAP email of the user who requested transplant. - requester_email = models.CharField(max_length=255) - - # Lando's name for the repository. - repository_name = models.CharField(max_length=255) - - # URL of the repository revisions are to land to. - repository_url = models.TextField(default="") + requester_email = models.CharField(blank=True, default="", max_length=255) # Identifier for the most descendent commit created by this landing. - landed_commit_id = models.TextField(default="") + landed_commit_id = models.TextField(blank=True, default="") # Number of attempts made to complete the job. attempts = models.IntegerField(default=0) @@ -121,10 +119,12 @@ class LandingJob(BaseModel): duration_seconds = models.IntegerField(default=0) # Identifier of the published commit which this job should land on top of. - target_commit_hash = models.TextField(default="") + target_commit_hash = models.TextField(blank=True, default="") revisions = models.ManyToManyField(Revision) # TODO: order by index + target_repo = models.ForeignKey("Repo", on_delete=models.SET_NULL, null=True) + @classmethod def job_queue_query( cls, @@ -140,14 +140,14 @@ def job_queue_query( many seconds ago. """ applicable_statuses = ( - cls.LandingJobStatus.SUBMITTED, - cls.LandingJobStatus.IN_PROGRESS, - cls.LandingJobStatus.DEFERRED, + LandingJobStatus.SUBMITTED, + LandingJobStatus.IN_PROGRESS, + LandingJobStatus.DEFERRED, ) q = cls.objects.filter(status__in=applicable_statuses) if repositories: - q = q.filter(repository_name__in=(repositories)) + q = q.filter(target_repo__in=repositories) if grace_seconds: now = datetime.datetime.now(datetime.timezone.utc) @@ -163,17 +163,13 @@ def job_queue_query( return q @classmethod - def next_job_for_update_query( - cls, repositories: Optional[Iterable[str]] = None - ) -> QuerySet: + def next_job(cls, repositories: Optional[Iterable[str]] = None) -> QuerySet: """Return a query which selects the next job and locks the row.""" query = cls.job_queue_query(repositories=repositories) # Returned rows should be locked for updating, this ensures the next # job can be claimed. - query = query.select_for_update() - - return query + return query.select_for_update() def add_job_with_revisions(revisions: list[Revision], **params: Any) -> LandingJob: @@ -194,35 +190,70 @@ class Repo(BaseModel): is_initialized = models.BooleanField(default=False) system_path = models.FilePathField( - path=settings.REPO_ROOT, max_length=255, allow_folders=True + path=settings.REPO_ROOT, + max_length=255, + allow_folders=True, + blank=True, + default="", ) def _run(self, *args, cwd=None): cwd = cwd or self.system_path command = ["git"] + list(args) - result = subprocess.run(command, cwd=cwd) + result = subprocess.run(command, cwd=cwd, capture_output=True, text=True) return result def initialize(self): + self.refresh_from_db() + if self.is_initialized: raise self.system_path = str(Path(settings.REPO_ROOT) / self.name) - self.save() result = self._run("clone", self.pull_path, self.name, cwd=settings.REPO_ROOT) if result.returncode == 0: self.is_initialized = True + self.save() else: - raise Exception(result.returncode) - self.save() + raise Exception(f"{result.returncode}: {result.stderr}") - def update(self): + def pull(self): self._run("pull", "--all", "--prune") def reset(self, branch=None): self._run("reset", "--hard", branch or self.default_branch) self._run("clean", "--force") + def apply_patch(self, patch_buffer): + patch_helper = GitPatchHelper(patch_buffer) + self.patch_header = patch_helper.get_header + + # Import the diff to apply the changes then commit separately to + # ensure correct parsing of the commit message. + f_msg = tempfile.NamedTemporaryFile(encoding="utf-8", mode="w+") + f_diff = tempfile.NamedTemporaryFile(encoding="utf-8", mode="w+") + with f_msg, f_diff: + patch_helper.write_commit_description(f_msg) + f_msg.flush() + patch_helper.write_diff(f_diff) + f_diff.flush() + + self._run("apply", f_diff.name) + + # Commit using the extracted date, user, and commit desc. + # --landing_system is provided by the set_landing_system hgext. + date = patch_helper.get_header("Date") + user = patch_helper.get_header("From") + + self._run("add", "-A") + self._run("commit", "--date", date, "--author", user, "--file", f_msg.name) + + def last_commit_id(self): + return self._run("rev-parse", "HEAD").stdout.strip() + + def push(self): + self._run("push") + class Worker(BaseModel): name = models.CharField(max_length=255, unique=True) @@ -233,3 +264,7 @@ class Worker(BaseModel): throttle_seconds = models.IntegerField(default=10) sleep_seconds = models.IntegerField(default=10) + + @property + def enabled_repos(self): + return self.applicable_repos.all() diff --git a/src/lando/settings.py b/src/lando/settings.py index 89153287..84cad219 100644 --- a/src/lando/settings.py +++ b/src/lando/settings.py @@ -150,3 +150,5 @@ "django.contrib.auth.backends.ModelBackend", "mozilla_django_oidc.auth.OIDCAuthenticationBackend", ] + +GITHUB_ACCESS_TOKEN = os.getenv("LANDO_GITHUB_ACCESS_TOKEN")