Skip to content

Commit

Permalink
Merge pull request #382 from kingosticks/fix-startup-time
Browse files Browse the repository at this point in the history
Fix startup time (Fixes #343)
  • Loading branch information
kingosticks authored Mar 13, 2024
2 parents 01f191a + 836f089 commit c8f028b
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 115 deletions.
54 changes: 36 additions & 18 deletions src/mopidy_spotify/playlists.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import threading

from mopidy import backend
from mopidy.core import listener

from mopidy_spotify import translator, utils

Expand All @@ -11,20 +13,17 @@ class SpotifyPlaylistsProvider(backend.PlaylistsProvider):
def __init__(self, backend):
# NOTE(review): rendered diff — indentation was lost and removed/added lines
# appear together; do not treat this span as runnable as-is.
self._backend = backend
# Request timeout taken from the extension's config section.
self._timeout = self._backend._config["spotify"]["timeout"]
# NOTE(review): `_loaded` is the pre-change line (removed by this commit);
# the mutex below is its replacement.
self._loaded = False
# Guards refresh() so only one background playlist refresh runs at a time.
self._refresh_mutex = threading.Lock()

def as_list(self):
"""Return playlist refs for all of the user's playlists."""
with utils.time_logger("playlists.as_list()", logging.DEBUG):
# NOTE(review): the `_loaded` guard below is the pre-change code removed
# by this commit; the post-change body goes straight to the listing.
if not self._loaded:
return []

return list(self._get_flattened_playlist_refs())

# NOTE(review): the next two lines are the old and new signatures of the same
# method — this commit adds a keyword-only `refresh` flag.
def _get_flattened_playlist_refs(self):
def _get_flattened_playlist_refs(self, *, refresh=False):
# Without a logged-in web client there is nothing to list.
if not self._backend._web_client.logged_in:
return []

# Old call (no refresh) followed by its replacement passing the flag through.
user_playlists = self._backend._web_client.get_user_playlists()
user_playlists = self._backend._web_client.get_user_playlists(refresh=refresh)
return translator.to_playlist_refs(
user_playlists, self._backend._web_client.user_id
)
Expand All @@ -48,18 +47,37 @@ def _get_playlist(self, uri, *, as_items=False):
def refresh(self):
"""Refresh the user's playlists; the new version does the slow track
loading in a daemon thread instead of blocking the caller."""
if not self._backend._web_client.logged_in:
return

# NOTE(review): lines from here to `self._loaded = True` are the
# pre-change synchronous implementation removed by this commit.
logger.info("Refreshing Spotify playlists")

with utils.time_logger("playlists.refresh()", logging.DEBUG):
self._backend._web_client.clear_cache()
count = 0
for playlist_ref in self._get_flattened_playlist_refs():
self._get_playlist(playlist_ref.uri)
count = count + 1
logger.info(f"Refreshed {count} Spotify playlists")

self._loaded = True
# Post-change implementation: take the mutex non-blockingly so overlapping
# refresh requests are ignored rather than queued.
if not self._refresh_mutex.acquire(blocking=False):
logger.info("Refreshing Spotify playlists already in progress")
return
try:
# refresh=True forces the playlist listing itself to be re-fetched.
uris = [ref.uri for ref in self._get_flattened_playlist_refs(refresh=True)]
logger.info(f"Refreshing {len(uris)} Spotify playlists in background")
# The worker thread releases the mutex when it finishes.
threading.Thread(
target=self._refresh_tracks,
args=(uris,),
daemon=True,
).start()
except Exception:
logger.exception("Error occurred while refreshing Spotify playlists")
# Only released here on failure to start the worker; on success the
# worker's finally-block releases it.
self._refresh_mutex.release()

def _refresh_tracks(self, playlist_uris):
"""Warm the cache by looking up each playlist URI, then notify listeners.

Runs on the daemon thread started by refresh(); the caller must already
hold self._refresh_mutex, which is released here when done.
"""
if not self._refresh_mutex.locked():
logger.error("Lock must be held before calling this method")
return []
try:
with utils.time_logger("playlists._refresh_tracks()", logging.DEBUG):
# lookup() returning a truthy playlist counts as a successful refresh.
refreshed = [uri for uri in playlist_uris if self.lookup(uri)]
logger.info(f"Refreshed {len(refreshed)} Spotify playlists")

# Tell Mopidy core that playlists are now available.
# NOTE(review): indentation was lost in this rendering — cannot tell
# whether this send() is inside or after the time_logger block.
listener.CoreListener.send("playlists_loaded")
except Exception:
logger.exception("Error occurred while refreshing Spotify playlists tracks")
else:
return refreshed  # For test
finally:
self._refresh_mutex.release()

def create(self, name):
pass # TODO
Expand Down
85 changes: 52 additions & 33 deletions src/mopidy_spotify/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import re
import threading
import time
import urllib.parse
from dataclasses import dataclass
Expand Down Expand Up @@ -64,6 +65,9 @@ def __init__( # noqa: PLR0913

self._headers = {"Content-Type": "application/json"}
self._session = utils.get_requests_session(proxy_config or {})
# TODO: Move _cache_mutex to the object it actually protects.
self._cache_mutex = threading.Lock() # Protects get() cache param.
self._refresh_mutex = threading.Lock() # Protects _headers and _expires.

def get(self, path, cache=None, *args, **kwargs):
"""Fetch `path` from the web API, consulting and updating `cache`.

NOTE(review): rendered diff with collapsed regions ("Expand All" markers)
— parts of this method are hidden and old/new lines are interleaved.
"""
if self._authorization_failed:
Expand All @@ -75,21 +79,22 @@ def get(self, path, cache=None, *args, **kwargs):

_trace(f"Get '{path}'")

# Old kwarg (boolean ignore_expiry) replaced by the richer expiry_strategy.
ignore_expiry = kwargs.pop("ignore_expiry", False)
expiry_strategy = kwargs.pop("expiry_strategy", None)
if cache is not None and path in cache:
cached_result = cache.get(path)
if cached_result.still_valid(ignore_expiry=ignore_expiry):
if cached_result.still_valid(expiry_strategy=expiry_strategy):
return cached_result
# Stale hit: send its ETag headers for a conditional request.
kwargs.setdefault("headers", {}).update(cached_result.etag_headers)

# TODO: Factor this out once we add more methods.
# TODO: Don't silently error out.
# Old unguarded token refresh (removed) ...
try:
if self._should_refresh_token():
self._refresh_token()
except OAuthTokenRefreshError as e:
logger.error(e)  # noqa: TRY400
return WebResponse(None, None)
# ... replaced by the same logic serialized under _refresh_mutex.
with self._refresh_mutex:
try:
if self._should_refresh_token():
self._refresh_token()
except OAuthTokenRefreshError as e:
logger.error(e)  # noqa: TRY400
return WebResponse(None, None)

# Make sure our headers always override user supplied ones.
kwargs.setdefault("headers", {}).update(self._headers)
Expand All @@ -102,11 +107,12 @@ def get(self, path, cache=None, *args, **kwargs):
)
return WebResponse(None, None)

# Old unsynchronized cache write (removed) ...
if self._should_cache_response(cache, result):
previous_result = cache.get(path)
if previous_result and previous_result.updated(result):
result = previous_result
cache[path] = result
# ... replaced by the same update guarded by _cache_mutex.
with self._cache_mutex:
if self._should_cache_response(cache, result):
previous_result = cache.get(path)
if previous_result and previous_result.updated(result):
result = previous_result
cache[path] = result

return result

Expand All @@ -115,11 +121,16 @@ def _should_cache_response(self, cache, response):

def _should_refresh_token(self):
# TODO: Add jitter to margin?
if not self._refresh_mutex.locked():
raise OAuthTokenRefreshError("Lock must be held before calling.")
return not self._auth or time.time() > self._expires - self._margin

def _refresh_token(self):
logger.debug(f"Fetching OAuth token from {self._refresh_url}")

if not self._refresh_mutex.locked():
raise OAuthTokenRefreshError("Lock must be held before calling.")

data = {"grant_type": "client_credentials"}
result = self._request_with_retries(
"POST", self._refresh_url, auth=self._auth, data=data
Expand Down Expand Up @@ -259,6 +270,12 @@ def _parse_retry_after(self, response):
return max(0, seconds)


@unique
class ExpiryStrategy(Enum):
# Overrides normal expiry evaluation in WebResponse.still_valid():
# FORCE_FRESH treats cached data as valid regardless of age;
# FORCE_EXPIRED treats it as stale (presumably forcing a re-fetch — see get()).
FORCE_FRESH = "force-fresh"
FORCE_EXPIRED = "force-expired"


class WebResponse(dict):
def __init__( # noqa: PLR0913
self,
Expand Down Expand Up @@ -330,19 +347,20 @@ def _parse_etag(response):

return None

# NOTE(review): rendered diff — the old boolean-flag version and the new
# ExpiryStrategy version of this method are interleaved below.
def still_valid(self, *, ignore_expiry=False):
if ignore_expiry:
result = True
status = "forced"
elif self._expires >= time.time():
result = True
status = "fresh"
# New signature: a strategy enum replaces the ignore_expiry boolean.
def still_valid(self, *, expiry_strategy=None):
if expiry_strategy is None:
# Normal path: compare the stored expiry timestamp against now.
if self._expires >= time.time():
valid = True
status = "fresh"
else:
valid = False
status = "expired"
else:
result = False
status = "expired"
self._from_cache = result
# Strategy path: FORCE_FRESH => valid, FORCE_EXPIRED => not valid.
valid = expiry_strategy is ExpiryStrategy.FORCE_FRESH
status = expiry_strategy.value
# Record whether this cached entry was served as valid.
self._from_cache = valid
_trace("Cached data %s for %s", status, self)
return result
return valid

@property
def status_unchanged(self):
Expand Down Expand Up @@ -434,8 +452,13 @@ def login(self):
def logged_in(self):
    """Report whether login has completed (a Spotify user id is known)."""
    if self.user_id is None:
        return False
    return True

# NOTE(review): old and new signatures of the same generator are interleaved —
# this commit adds a keyword-only `refresh` flag.
def get_user_playlists(self):
pages = self.get_all(f"users/{self.user_id}/playlists", params={"limit": 50})
def get_user_playlists(self, *, refresh=False):
"""Yield the items of every page of the user's playlists.

With refresh=True the cached listing is treated as expired so it is
re-fetched from the API.
"""
expiry_strategy = ExpiryStrategy.FORCE_EXPIRED if refresh else None
pages = self.get_all(
f"users/{self.user_id}/playlists",
params={"limit": 50},
expiry_strategy=expiry_strategy,
)
for page in pages:
yield from page.get("items", [])

Expand All @@ -446,7 +469,9 @@ def _with_all_tracks(self, obj, params=None):
track_pages = self.get_all(
tracks_path,
params=params,
ignore_expiry=obj.status_unchanged,
expiry_strategy=(
ExpiryStrategy.FORCE_FRESH if obj.status_unchanged else None
),
)

more_tracks = []
Expand Down Expand Up @@ -527,12 +552,6 @@ def get_track(self, web_link):

return self.get_one(f"tracks/{web_link.id}", params={"market": "from_token"})

def clear_cache(
    self,
    extra_expiry=None,  # noqa: ARG002
):
    """Discard every cached web response held by this client.

    ``extra_expiry`` is accepted for interface compatibility and ignored.
    """
    self._cache.clear()


@unique
class LinkType(Enum):
Expand Down
16 changes: 16 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import threading


class ThreadJoiner:
    """Context manager that joins threads spawned inside its ``with`` block.

    On exit, every thread that did not exist when the block was entered is
    joined with a per-thread ``timeout``; if any of them is still alive
    afterwards, a RuntimeError naming all stragglers is raised.

    Fixes over the original: ``__enter__`` now returns ``self`` so the
    ``with ThreadJoiner() as joiner:`` form works, and ALL new threads are
    joined before raising instead of bailing out on the first straggler.
    """

    def __init__(self, timeout: float = 1):
        # Maximum seconds to wait for each newly spawned thread.
        self.timeout = timeout

    def __enter__(self):
        # Snapshot existing threads so only ones started inside the block
        # are joined on exit.
        self.before = set(threading.enumerate())
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        new_threads = set(threading.enumerate()) - self.before
        stragglers = []
        for thread in new_threads:
            thread.join(timeout=self.timeout)
            if thread.is_alive():
                stragglers.append(thread)
        if stragglers:
            raise RuntimeError(f"Timeout joining threads: {stragglers}")
        # Returning None lets any exception from the with-block propagate.
15 changes: 10 additions & 5 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from mopidy_spotify import backend, library, playlists
from mopidy_spotify.backend import SpotifyPlaybackProvider

from tests import ThreadJoiner


def get_backend(config):
obj = backend.SpotifyBackend(config=config, audio=None)
Expand Down Expand Up @@ -58,7 +60,8 @@ def test_on_start_configures_proxy(web_mock, config):
"password": "s3cret",
}
backend = get_backend(config)
backend.on_start()
with ThreadJoiner():
backend.on_start()

assert True

Expand All @@ -74,7 +77,8 @@ def test_on_start_configures_web_client(web_mock, config):
config["spotify"]["client_secret"] = "AbCdEfG"

backend = get_backend(config)
backend.on_start()
with ThreadJoiner():
backend.on_start()

web_mock.SpotifyOAuthClient.assert_called_once_with(
client_id="1234567",
Expand All @@ -92,12 +96,13 @@ def test_on_start_logs_in(web_mock, config):

def test_on_start_refreshes_playlists(web_mock, config, caplog):
backend = get_backend(config)
# NOTE(review): the bare on_start() call is the pre-change line; the new
# version wraps it in ThreadJoiner so background refresh threads finish
# before the assertions run.
backend.on_start()
with ThreadJoiner():
backend.on_start()

client_mock = web_mock.SpotifyOAuthClient.return_value
# Old assertion (no args checked) followed by its stricter replacement
# verifying the new refresh=True keyword.
client_mock.get_user_playlists.assert_called_once()
client_mock.get_user_playlists.assert_called_once_with(refresh=True)
assert "Refreshing 0 Spotify playlists in background" in caplog.text
assert "Refreshed 0 Spotify playlists" in caplog.text
# NOTE(review): pre-change assertion — `_loaded` is removed by this commit.
assert backend.playlists._loaded


def test_on_start_doesnt_refresh_playlists_if_not_allowed(web_mock, config, caplog):
Expand Down
Loading

0 comments on commit c8f028b

Please sign in to comment.