Skip to content

Commit

Permalink
Merge pull request #746 from serokell/dmozhevitin/#738-improve-snapsh…
Browse files Browse the repository at this point in the history
…ot-search-algorithm

[#738] Further improve the algorithm of picking the latest compatible snapshot
  • Loading branch information
DMozhevitin authored Nov 17, 2023
2 parents 73a2e2d + d04995b commit aaa81d0
Showing 1 changed file with 73 additions and 12 deletions.
85 changes: 73 additions & 12 deletions baking/src/tezos_baking/tezos_setup_wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,21 @@ def get_node_version():
)


def is_non_protocol_testnet(network):
return network == "mainnet" or network == "ghostnet"


# Starting from Nairobi protocol, the corresponding testnet
# is no longer a named network, so we need to provide the URL
# of the network configuration instead of the network name
# in 'octez-node config init' command.
def network_name_or_teztnets_url(network):
if is_non_protocol_testnet(network):
return network
else:
return f"https://teztnets.xyz/{network}"


compatible_snapshot_version = 6


Expand Down Expand Up @@ -262,8 +277,7 @@ def get_snapshot_mode_query(config):
dynamic_import_modes = {}

for name in default_providers.keys():
if config["snapshots"].get(name, None):
dynamic_import_modes[mk_option(name)] = mk_desc(name)
dynamic_import_modes[mk_option(name)] = mk_desc(name)

import_modes = {**dynamic_import_modes, **static_import_modes}

Expand Down Expand Up @@ -397,12 +411,13 @@ def check_blockchain_data(self):
if not node_dir_config.issubset(node_dir_contents):
print_and_log("The Tezos node data directory has not been configured yet.")
print_and_log(" Configuring directory: " + node_dir)
network = self.config["network"]
proc_call(
"sudo -u tezos octez-node-"
+ self.config["network"]
+ " config init"
+ " --network "
+ self.config["network"]
+ network_name_or_teztnets_url(self.config["network"])
+ " --rpc-addr "
+ self.config["node_rpc_addr"]
)
Expand Down Expand Up @@ -544,7 +559,7 @@ def sum_pred(*preds):
# it could happen that `snapshot_version` field is not supplied by provider
# e.g. marigold snapshots don't supply it
lambda major, minor, rc, snapshot_version: snapshot_version
and compatible_snapshot_version == snapshot_version
and compatible_snapshot_version - snapshot_version <= 1
)

non_rc_on_stable_pred = lambda major, minor, rc, snapshot_version: (
Expand Down Expand Up @@ -687,6 +702,48 @@ def get_snapshot_from_provider(self, name, url):
snapshot_block_hash = self.config["snapshots"][name]["block_hash"]
return (snapshot_file, snapshot_block_hash)

# check if a given provider has the compatible snapshot
# available in its metadata and return the metadata of this
# snapshot if it's available
def try_fallback_provider(self, name, url):
print(f"Getting snapshots' metadata from {name} instead...")
self.get_snapshot_metadata(name, url)
return self.config["snapshots"].get(name, None)

# check if some of the providers has the compatible snapshot
# available in its metadadata and return the provider name
def find_fallback_provider(self, providers):
for name, url in providers.items():
snapshot = self.try_fallback_provider(name, url)
if snapshot is not None:
return name
return None

# tries to get the latest compatible snapshot from the given
# provider's metadata
#
# if the snapshot not found, tries to find it in other known
# providers
def get_snapshot_from_provider_with_fallback(self, name, url):
print_and_log(f"Getting snapshots' metadata from {name}...")

self.get_snapshot_metadata(name, url)
snapshot = self.config["snapshots"].get(name, None)

if snapshot is None:
fallback_providers = default_providers.copy()
fallback_providers.pop(name)
fallback_provider = self.find_fallback_provider(fallback_providers)

if fallback_provider is None:
return None
else:
name = fallback_provider

snapshot_file = self.fetch_snapshot_from_provider(name)
snapshot_block_hash = self.config["snapshots"][name]["block_hash"]
return (snapshot_file, snapshot_block_hash)

def get_snapshot_from_direct_url(self, url):
try:
self.query_step(snapshot_sha256_query)
Expand Down Expand Up @@ -743,10 +800,6 @@ def import_snapshot(self):

self.config["snapshots"] = {}

print_and_log("Getting snapshots' metadata from providers...")
for name, url in default_providers.items():
self.get_snapshot_metadata(name, url)

os.makedirs(TMP_SNAPSHOT_LOCATION, exist_ok=True)

else:
Expand Down Expand Up @@ -786,10 +839,18 @@ def import_snapshot(self):
else:
for name, url in default_providers.items():
if name in self.config["snapshot_mode"]:
(
snapshot_file,
snapshot_block_hash,
) = self.get_snapshot_from_provider(name, url)
selected_provider = (name, url)
snapshot_info = self.get_snapshot_from_provider_with_fallback(
*selected_provider
)
if snapshot_info is None:
print_and_log(
"Couldn't find available snapshot in any of the known providers.",
log=logging.warning,
colorcode=color_yellow,
)
raise InterruptStep
(snapshot_file, snapshot_block_hash) = snapshot_info
except InterruptStep:
print_and_log("Getting back to the snapshot import mode step.")
continue
Expand Down

0 comments on commit aaa81d0

Please sign in to comment.