Skip to content

Commit

Permalink
Added error handling in license deployment. (#417)
Browse files Browse the repository at this point in the history
* Added error handling in license deployment.

* lint fixes

* nits

* lint fixes
  • Loading branch information
aniketsinghrawat authored Nov 6, 2023
1 parent 59408ae commit 00de1d5
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 19 deletions.
15 changes: 15 additions & 0 deletions weather_dl_v2/license_deployment/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def _remove_config_from_license_queue(
) -> None:
pass

@abc.abstractmethod
def _empty_license_queue(self, license_id: str) -> None:
pass

@abc.abstractmethod
def _get_partition_from_manifest(self, config_name: str) -> str:
pass
Expand Down Expand Up @@ -128,6 +132,17 @@ def _remove_config_from_license_queue(
f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}."
)

def _empty_license_queue(self, license_id: str) -> None:
    """Replace the stored queue of ``license_id`` with an empty list.

    The document named ``license_id`` in the configured queues collection is
    updated in place, and the write's update time is logged.
    """
    queues = self._get_db().collection(get_config().queues_collection)
    queue_doc = queues.document(license_id)
    result: WriteResult = queue_doc.update({"queue": []})
    logger.info(
        f"Updated {license_id} queue in 'queues' collection. Update_time: {result.update_time}."
    )


# TODO: Firestore transactional fails after reading a document 20 times with roll over.
# This happens when too many licenses try to access the same partition document.
Expand Down
77 changes: 61 additions & 16 deletions weather_dl_v2/license_deployment/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
from job_creator import create_download_job
from clients import CLIENTS
from manifest import FirestoreManifest
from util import exceptionit
from util import exceptionit, ThreadSafeDict

# Database client used by make_fetch_request to remove a config from, or
# empty, a license queue.
db_client = FirestoreClient()
secretmanager_client = secretmanager.SecretManagerServiceClient()

# Once a config's error count reaches this value, make_fetch_request removes
# the config from the license queue instead of fetching it again.
CONFIG_MAX_ERROR_COUNT = 10

def create_job(request, result):
res = {
Expand All @@ -48,27 +48,64 @@ def create_job(request, result):


@exceptionit
def make_fetch_request(request, error_map: ThreadSafeDict):
    """Fetch one partition described by ``request`` and create its download job.

    Args:
        request: Mapping that provides the "dataset", "location", "selection"
            (a JSON string), "config_name" and "username" entries read below.
        error_map: Shared per-config error counters. A config whose counter
            reaches CONFIG_MAX_ERROR_COUNT is removed from the license queue;
            a non-zero counter delays the fetch by an exponential back-off.

    Returns:
        None. On success the counter of the config is reset to 0 and
        ``create_job`` is called with the retrieval result.
    """
    client = CLIENTS[client_name](request["dataset"])
    manifest = FirestoreManifest(license_id=license_id)
    logger.info(
        f"By using {client_name} datasets, "
        f"users agree to the terms and conditions specified in {client.license_url!r}."
    )

    target = request["location"]
    selection = json.loads(request["selection"])

    logger.info(f"Fetching data for {target!r}.")

    config_name = request["config_name"]

    if not error_map.has_key(config_name):
        error_map[config_name] = 0

    # Read the counter once so the checks below agree with each other even if
    # another worker updates it in the meantime.
    error_count = error_map[config_name]

    if error_count >= CONFIG_MAX_ERROR_COUNT:
        logger.info(
            f"Error count for config {config_name} exceeded "
            f"CONFIG_MAX_ERROR_COUNT ({CONFIG_MAX_ERROR_COUNT})."
        )
        error_map.remove(config_name)
        logger.info(f"Removing config {config_name} from license queue.")
        # Remove config from this license queue.
        db_client._remove_config_from_license_queue(
            license_id=license_id, config_name=config_name
        )
        return

    # Wait for exponential time based on error count.
    if error_count > 0:
        logger.info(f"Error count for config {config_name}: {error_count}.")
        # Do not name this local `time`: that would shadow the `time` module
        # and make the sleep call below fail.
        delay_seconds = error_map.exponential_time(config_name)
        # exponential_time() returns seconds; the message reports minutes.
        logger.info(f"Sleeping for {delay_seconds / 60} mins.")
        time.sleep(delay_seconds)

    try:
        with manifest.transact(
            config_name,
            request["dataset"],
            selection,
            target,
            request["username"],
        ):
            result = client.retrieve(request["dataset"], selection, manifest)
    except Exception as e:
        # We are handling this as generic case as CDS client throws generic exceptions.

        # License expired.
        if "Access token expired" in str(e):
            logger.error(f"{license_id} expired. Emptying queue! error: {e}.")
            db_client._empty_license_queue(license_id=license_id)
            return

        # Increment error count for a config.
        logger.error(f"Partition fetching failed. Error {e}.")
        error_map.increment(config_name)
        return

    # If any partition is successful, reset the error count.
    error_map[config_name] = 0
    create_job(request, result)


Expand All @@ -90,20 +127,28 @@ def fetch_request_from_db():

def main():
    """Poll the database for requests forever and hand them to worker threads.

    A single ThreadSafeDict of per-config error counts is shared by every
    submitted ``make_fetch_request`` task.
    """
    logger.info("Started looking at the request.")
    error_map = ThreadSafeDict()
    with ThreadPoolExecutor(concurrency_limit) as executor:
        # Disclaimer: a license will always pick concurrency_limit + 1
        # partitions. One extra partition is kept in the thread pool task queue.

        while True:
            # Fetch a request from the database.
            request = fetch_request_from_db()

            if request is not None:
                executor.submit(make_fetch_request, request, error_map)
            else:
                logger.info("No request available. Waiting...")
                time.sleep(5)

            # Each license should not pick more partitions than its
            # concurrency_limit. We limit the thread pool queue size to just 1
            # to prevent the license from picking more partitions than its
            # concurrency_limit. When a worker is freed up, the task in the
            # queue is picked and the license fetches another task.
            # NOTE: `_work_queue` is a private attribute of ThreadPoolExecutor.
            while executor._work_queue.qsize() >= 1:
                logger.info("Worker busy. Waiting...")
                time.sleep(1)


Expand Down
8 changes: 5 additions & 3 deletions weather_dl_v2/license_deployment/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,14 +500,16 @@ def _update(self, download_status: DownloadStatus) -> None:
status = DownloadStatus.to_dict(download_status)
doc_id = generate_md5_hash(status["location"])

# Update document with download status
# Update document with download status.
download_doc_ref = self.root_document_for_store(doc_id)

result: WriteResult = download_doc_ref.set(status)

logger.info(
f"Firestore manifest updated. "
f"update_time={result.update_time}, "
"Firestore manifest updated. " +
f"update_time={result.update_time}, " +
f"status={status['status']} " +
f"stage={status['stage']} " +
f"filename={download_status.location}."
)

Expand Down
63 changes: 63 additions & 0 deletions weather_dl_v2/license_deployment/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from xarray.core.utils import ensure_us_time_resolution
from urllib.parse import urlparse
from google.api_core.exceptions import BadRequest
from threading import Lock

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -237,3 +238,65 @@ def download_with_aria2(url: str, path: str) -> None:
f'Failed download from server {url!r} to {path!r} due to {e.stderr.decode("utf-8")}.'
)
raise

class ThreadSafeDict:
    """A lock-guarded dict with counter and exponential back-off helpers.

    Every access to the underlying dict happens while holding ``_lock``.
    """

    def __init__(self) -> None:
        self._dict = {}
        self._lock = Lock()
        # Base delay; `exponential_time` multiplies the computed delay by 60
        # to return seconds, so this value is effectively in minutes.
        self.initial_delay = 1
        # Fraction of the current delay added on every back-off step.
        self.factor = 0.5

    def __getitem__(self, key):
        """Return the value stored for ``key``; raise KeyError if absent."""
        with self._lock:
            return self._dict[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key``."""
        with self._lock:
            self._dict[key] = value

    def __contains__(self, key) -> bool:
        """Support ``key in thread_safe_dict``."""
        with self._lock:
            return key in self._dict

    def remove(self, key):
        """Delete ``key``; raise KeyError if it is not present."""
        with self._lock:
            del self._dict[key]

    def has_key(self, key):
        """Return True if ``key`` is present, else False."""
        # Delegates to __contains__, which takes the lock itself.
        return key in self

    def increment(self, key, delta=1):
        """Add ``delta`` to the value of ``key``; do nothing if it is absent."""
        with self._lock:
            if key in self._dict:
                self._dict[key] += delta

    def decrement(self, key, delta=1):
        """Subtract ``delta`` from the value of ``key``; do nothing if absent."""
        with self._lock:
            if key in self._dict:
                self._dict[key] -= delta

    def find_exponential_delay(self, n: int) -> float:
        """Return ``initial_delay`` grown by ``factor`` of itself ``n`` times.

        With the defaults this is 1, 1.5, 2.25, ... for n = 0, 1, 2, ...
        """
        delay = self.initial_delay
        for _ in range(n):
            delay += delay * self.factor
        return delay

    def exponential_time(self, key):
        """Returns exponential time based on dict value. Time in seconds.

        Returns 0 when ``key`` is not present.
        """
        delay = 0
        with self._lock:
            if key in self._dict:
                # find_exponential_delay does not take the lock, so calling it
                # while holding the lock cannot deadlock.
                delay = self.find_exponential_delay(self._dict[key])
        return delay * 60

0 comments on commit 00de1d5

Please sign in to comment.