Skip to content

Commit

Permalink
Fix snap_update_test.py after PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
pieqq committed Oct 18, 2023
1 parent 47f4cc2 commit 60351ba
Showing 1 changed file with 119 additions and 162 deletions.
281 changes: 119 additions & 162 deletions providers/base/bin/snap_update_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,255 +19,212 @@
# along with Checkbox. If not, see <http://www.gnu.org/licenses/>.

import argparse
from glob import glob
from pathlib import Path
import json
import logging
import os.path
import sys
import time

from checkbox_support.snap_utils.snapd import Snapd


def guess_snaps() -> dict:
def guess_snaps() -> list:
"""
Guess the names of the kernel, snapd and gadget snaps from installed snaps
on the system.
:return: a dict with the snap names for each snap found
:rtype: dict
:return: a list of snap names that are either kernel, snapd or gadget snaps
:rtype: list
"""
snapd = Snapd()
installed_snaps = snapd.list()
snaps = {}
for snap in installed_snaps:
if snap["type"] == "kernel":
snaps["kernel"] = snap["name"]
elif snap["type"] == "gadget":
snaps["gadget"] = snap["name"]
elif snap["type"] == "snapd":
snaps["snapd"] = snap["name"]
snaps = [
snap["name"]
for snap in Snapd().list()
if snap["type"] in ("kernel", "gadget", "snapd")
]
return snaps


def get_snap_base_rev() -> dict:
def get_snaps_base_rev() -> dict:
"""
Retrieve the name and the base revision of each snap originally installed
on the system.
:return: a dict containing the snap names and their base revisions
:rtype: dict
"""
base_snaps = glob("/var/lib/snapd/seed/snaps/*.snap")
seed_snaps_dir = Path("/var/lib/snapd/seed/snaps/")
base_snaps = seed_snaps_dir.glob("*.snap")
base_rev_info = {}
for snap_path in base_snaps:
snap_basename = os.path.basename(snap_path)
snap_name = os.path.splitext(snap_basename)[0]
snap, rev = snap_name.rsplit("_", maxsplit=1)
snap, rev = snap_path.stem.rsplit("_", maxsplit=1)
base_rev_info[snap] = rev
return base_rev_info


def get_snap_info(name) -> dict:
"""
Retrieve information such as name, type, available revisions, etc. about
a given snap.
class SnapInfo:
def __init__(self, name):
snap = Snapd().list(name)
self.name = snap["name"]
self.type = snap["type"]
self.tracking_channel = snap["tracking-channel"]
self.installed_revision = snap["revision"]
self.tracking_prefix = (self.tracking_channel.split("/")[0] + "/") if "/" in self.tracking_channel else ""
self.base_revision = get_snaps_base_rev().get(name, "")

:return: a dict with the available information
:rtype: dict
"""
snapd = Snapd()
snap_info = {}
snap = snapd.list(name)
base_revs = get_snap_base_rev()
snap_info["name"] = snap["name"]
snap_info["type"] = snap["type"]
snap_info["tracking_channel"] = snap["tracking-channel"]
snap_info["installed_revision"] = snap["revision"]
snap_info["base_revision"] = base_revs.get(name, "")
tracking = snap_info["tracking_channel"]
prefix = (tracking.split("/")[0] + "/") if "/" in tracking else ""
snap_info["tracking_prefix"] = prefix
revisions = {}
for item in Snapd().find(name, exact=True):
for channel, info in item["channels"].items():
revisions[channel] = info["revision"]

snap_additional_info = snapd.find(name, exact=True)
snap_info["revisions"] = {}
for item in snap_additional_info:
for channel, info in item["channels"].items():
snap_info["revisions"][channel] = info["revision"]
return snap_info
self.stable_revision = revisions.get(
"{}stable".format(self.tracking_prefix), ""
)
self.candidate_revision = revisions.get(
"{}candidate".format(self.tracking_prefix), ""
)
self.beta_revision = revisions.get(
"{}beta".format(self.tracking_prefix), ""
)
self.edge_revision = revisions.get(
"{}edge".format(self.tracking_prefix), ""
)

def print_as_resource(self):
print("name: {}".format(self.name))
print("type: {}".format(self.type))
print("tracking: {}".format(self.tracking_channel))
print("base_rev: {}".format(self.base_revision))
print("stable_rev: {}".format(self.stable_revision))
print("candidate_rev: {}".format(self.candidate_revision))
print("beta_rev: {}".format(self.beta_revision))
print("edge_rev: {}".format(self.edge_revision))
print("original_installed_rev: {}".format(self.installed_revision))
print()


def print_resource_info():
snaps = guess_snaps().values()
for snap in snaps:
info = get_snap_info(snap)
tracking = info["tracking_channel"]
prefix = info["tracking_prefix"]
base_rev = info.get("base_revision", "")
stable_rev = info["revisions"].get("{}stable".format(prefix), "")
cand_rev = info["revisions"].get("{}candidate".format(prefix), "")
beta_rev = info["revisions"].get("{}beta".format(prefix), "")
edge_rev = info["revisions"].get("{}edge".format(prefix), "")
installed_rev = info.get("installed_revision", "")
for snap in guess_snaps():
SnapInfo(snap).print_as_resource()

print("name: {}".format(info["name"]))
print("type: {}".format(info["type"]))
print("tracking: {}".format(tracking))
print("base_rev: {}".format(base_rev))
print("stable_rev: {}".format(stable_rev))
print("candidate_rev: {}".format(cand_rev))
print("beta_rev: {}".format(beta_rev))
print("edge_rev: {}".format(edge_rev))
print("original_installed_rev: {}".format(installed_rev))
print()
def save_change_info(path, data):
with open(path, "w") as file:
json.dump(data, file)

def load_change_info(path):
try:
with open(path, "r") as file:
data = json.load(file)
except FileNotFoundError:
logging.error("File not found: %s", path)
logging.error("Did the previous job run as expected?")
raise SystemExit(1)
return data

class SnapRefreshRevert:
def __init__(self, name, rev, info_path):
def __init__(self, name, revision, info_path):
self.snapd = Snapd()
self.snap_info = get_snap_info(name)
self.snap_info = SnapInfo(name)
self.path = info_path
self.rev = rev
self.revision = revision
self.name = name

def snap_refresh(self):
data = {}
original_revision = self.snap_info["installed_revision"]
if original_revision == self.rev:
original_revision = self.snap_info.installed_revision
if original_revision == self.revision:
logging.error(
"Trying to refresh to the same revision (%s)!", self.rev
"Trying to refresh to the same revision (%s)!", self.revision
)
return 1
raise SystemExit(1)
data["name"] = self.name
data["original_revision"] = original_revision
data["destination_revision"] = self.rev
data["destination_revision"] = self.revision
logging.info(
"Refreshing %s snap from rev %s to rev %s",
"Refreshing %s snap from revision %s to revision %s",
self.name,
original_revision,
self.rev,
self.revision,
)
r = self.snapd.refresh(
response = self.snapd.refresh(
self.name,
channel=self.snap_info["tracking_channel"],
revision=self.rev,
channel=self.snap_info.tracking_channel,
revision=self.revision,
reboot=True,
)
logging.info(
"Refreshing requested (channel %s, rev %s)",
self.snap_info["tracking_channel"],
self.rev,
"Refreshing requested (channel %s, revision %s)",
self.snap_info.tracking_channel,
self.revision,
)
with open(self.path, "w") as file:
data["refresh_id"] = r["change"]
json.dump(data, file)
data["change_id"] = response["change"]
save_change_info(self.path, data)
logging.info("Waiting for reboot...")
return 0

def verify_refresh(self):
try:
with open(self.path, "r") as file:
data = json.load(file)
except FileNotFoundError:
logging.error("File not found: %s", self.path)
logging.error("Did the previous job run as expected?")
return 1
id = data["refresh_id"]
name = data["name"]

logging.info("Checking refresh status for snap %s...", name)
start_time = time.time()
timeout = 300 # 5 minutes timeout
while True:
result = self.snapd.change(str(id))
if result == "Done":
logging.info("%s snap refresh complete", name)
break

if time.time() - start_time >= timeout:
logging.error(
"%s snap refresh did not complete within 5 minutes", name
)
return False
logging.info("Waiting for %s snap refreshing to be done...", name)
logging.info("Trying again in 10 seconds...")
time.sleep(10)

current_rev = self.snapd.list(self.snap_info["name"])["revision"]
destination_rev = data["destination_revision"]
if current_rev != destination_rev:
logging.error(
"Current revision %s is NOT equal to expected revision %s",
current_rev,
destination_rev,
)
return 1
else:
logging.info(
"PASS: current revision (%s) matches the expected revision",
current_rev,
)
return 0

def snap_revert(self):
with open(self.path, "r") as file:
data = json.load(file)
data = load_change_info(self.path)
original_rev = data["original_revision"]
destination_rev = data["destination_revision"]
logging.info(
"Reverting %s snap (from rev %s to rev %s)",
"Reverting %s snap (from revision %s to revision %s)",
self.name,
destination_rev,
original_rev,
)
r = self.snapd.revert(self.snap_info["name"], reboot=True)
response = self.snapd.revert(self.name, reboot=True)
logging.info("Reverting requested")
with open(self.path, "w") as file:
data["revert_id"] = r["change"]
json.dump(data, file)
data["change_id"] = response["change"]
save_change_info(self.path, data)
logging.info("Waiting for reboot...")

def verify_revert(self):
with open(self.path, "r") as file:
data = json.load(file)
id = data["revert_id"]
original_rev = data["original_revision"]
def verify(self, type):
if type not in ("refresh", "revert"):
raise SystemExit(
"'{}' verification unknown. Can be either 'refresh' or 'revert'.".format(
type
)
)
data = load_change_info(self.path)
id = data["change_id"]

logging.info("Checking %s snap revert status", self.name)
logging.info("Checking %s status for snap %s...", type, self.name)
start_time = time.time()
timeout = 300 # 5 minutes timeout
while True:
result = self.snapd.change(str(id))
if result == "Done":
logging.info("%s snap revert complete", self.name)
logging.info("%s snap %s complete", self.name, type)
break

if time.time() - start_time >= timeout:
logging.error(
"%s snap revert did not complete within 5 minutes",
"%s snap %s did not complete within 5 minutes",
self.name,
type,
)
return False
raise SystemExit(1)
logging.info(
"Waiting for %s snap reverting to be done...", self.name
"Waiting for %s snap %s to be done...", self.name, type
)
logging.info("Trying again in 10 seconds.")
logging.info("Trying again in 10 seconds...")
time.sleep(10)

current_rev = self.snapd.list(self.snap_info["name"])["revision"]
if current_rev != original_rev:
current_rev = self.snapd.list(self.name)["revision"]
if type == "refresh":
tested_rev = data["destination_revision"]
else:
tested_rev = data["original_revision"]
if current_rev != tested_rev:
logging.error(
"Current revision (%s) is NOT equal to original revision (%s)",
"Current revision (%s) is different from expected revision (%s)",
current_rev,
original_rev,
tested_rev,
)
return 1
raise SystemExit(1)
else:
logging.info(
"PASS: current revision (%s) matches the original revision",
"PASS: current revision (%s) matches the expected revision",
current_rev,
)
return 0


def main():
Expand Down Expand Up @@ -310,7 +267,7 @@ def main():
help="Path to the information file",
)
parser.add_argument(
"--rev",
"--revision",
help="Revision to refresh to",
)

Expand All @@ -320,16 +277,16 @@ def main():
print_resource_info()
else:
test = SnapRefreshRevert(
name=args.name, info_path=args.info_path, rev=args.rev
name=args.name, info_path=args.info_path, revision=args.revision
)
if args.refresh:
return test.snap_refresh()
test.snap_refresh()
if args.verify_refresh:
return test.verify_refresh()
test.verify("refresh")
if args.revert:
return test.snap_revert()
test.snap_revert()
if args.verify_revert:
return test.verify_revert()
test.verify("revert")


if __name__ == "__main__":
Expand Down

0 comments on commit 60351ba

Please sign in to comment.