diff --git a/baking/src/tezos_baking/tezos_setup_wizard.py b/baking/src/tezos_baking/tezos_setup_wizard.py index ac1ecaacd..574186318 100644 --- a/baking/src/tezos_baking/tezos_setup_wizard.py +++ b/baking/src/tezos_baking/tezos_setup_wizard.py @@ -196,6 +196,21 @@ def get_node_version(): ) +def is_non_protocol_testnet(network): + return network == "mainnet" or network == "ghostnet" + + +# Starting from Nairobi protocol, the corresponding testnet +# is no longer a named network, so we need to provide the URL +# of the network configuration instead of the network name +# in 'octez-node config init' command. +def network_name_or_teztnets_url(network): + if is_non_protocol_testnet(network): + return network + else: + return f"https://teztnets.xyz/{network}" + + compatible_snapshot_version = 6 @@ -262,8 +277,7 @@ def get_snapshot_mode_query(config): dynamic_import_modes = {} for name in default_providers.keys(): - if config["snapshots"].get(name, None): - dynamic_import_modes[mk_option(name)] = mk_desc(name) + dynamic_import_modes[mk_option(name)] = mk_desc(name) import_modes = {**dynamic_import_modes, **static_import_modes} @@ -397,12 +411,13 @@ def check_blockchain_data(self): if not node_dir_config.issubset(node_dir_contents): print_and_log("The Tezos node data directory has not been configured yet.") print_and_log(" Configuring directory: " + node_dir) + network = self.config["network"] proc_call( "sudo -u tezos octez-node-" + self.config["network"] + " config init" + " --network " - + self.config["network"] + + network_name_or_teztnets_url(self.config["network"]) + " --rpc-addr " + self.config["node_rpc_addr"] ) @@ -544,7 +559,7 @@ def sum_pred(*preds): # it could happen that `snapshot_version` field is not supplied by provider # e.g. marigold snapshots don't supply it lambda major, minor, rc, snapshot_version: snapshot_version - and compatible_snapshot_version == snapshot_version + and compatible_snapshot_version - snapshot_version <= 1 ) non_rc_on_stable_pred = lambda major, minor, rc, snapshot_version: ( @@ -687,6 +702,48 @@ def get_snapshot_from_provider(self, name, url): snapshot_block_hash = self.config["snapshots"][name]["block_hash"] return (snapshot_file, snapshot_block_hash) + # check if a given provider has the compatible snapshot + # available in its metadata and return the metadata of this + # snapshot if it's available + def try_fallback_provider(self, name, url): + print(f"Getting snapshots' metadata from {name} instead...") + self.get_snapshot_metadata(name, url) + return self.config["snapshots"].get(name, None) + + # check if some of the providers has the compatible snapshot + # available in its metadadata and return the provider name + def find_fallback_provider(self, providers): + for name, url in providers.items(): + snapshot = self.try_fallback_provider(name, url) + if snapshot is not None: + return name + return None + + # tries to get the latest compatible snapshot from the given + # provider's metadata + # + # if the snapshot not found, tries to find it in other known + # providers + def get_snapshot_from_provider_with_fallback(self, name, url): + print_and_log(f"Getting snapshots' metadata from {name}...") + + self.get_snapshot_metadata(name, url) + snapshot = self.config["snapshots"].get(name, None) + + if snapshot is None: + fallback_providers = default_providers.copy() + fallback_providers.pop(name) + fallback_provider = self.find_fallback_provider(fallback_providers) + + if fallback_provider is None: + return None + else: + name = fallback_provider + + snapshot_file = self.fetch_snapshot_from_provider(name) + snapshot_block_hash = self.config["snapshots"][name]["block_hash"] + return (snapshot_file, snapshot_block_hash) + def get_snapshot_from_direct_url(self, url): try: self.query_step(snapshot_sha256_query) @@ -743,10 +800,6 @@ def import_snapshot(self): self.config["snapshots"] = {} - print_and_log("Getting snapshots' metadata from providers...") - for name, url in default_providers.items(): - self.get_snapshot_metadata(name, url) - os.makedirs(TMP_SNAPSHOT_LOCATION, exist_ok=True) else: @@ -786,10 +839,18 @@ def import_snapshot(self): else: for name, url in default_providers.items(): if name in self.config["snapshot_mode"]: - ( - snapshot_file, - snapshot_block_hash, - ) = self.get_snapshot_from_provider(name, url) + selected_provider = (name, url) + snapshot_info = self.get_snapshot_from_provider_with_fallback( + *selected_provider + ) + if snapshot_info is None: + print_and_log( + "Couldn't find available snapshot in any of the known providers.", + log=logging.warning, + colorcode=color_yellow, + ) + raise InterruptStep + (snapshot_file, snapshot_block_hash) = snapshot_info except InterruptStep: print_and_log("Getting back to the snapshot import mode step.") continue