Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve config handling #47

Merged
merged 2 commits into from
Apr 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions jupyterlab_pullrequests/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import traceback
from typing import Optional
from http import HTTPStatus

import tornado
Expand Down Expand Up @@ -198,26 +199,33 @@ def get_body_value(handler):
]


def setup_handlers(web_app: "NotebookWebApplication", config: PRConfig):
def setup_handlers(web_app: tornado.web.Application, config: PRConfig, log: Optional[logging.Logger]=None):
host_pattern = ".*$"
base_url = url_path_join(web_app.settings["base_url"], NAMESPACE)

logger = get_logger()
log = log or logging.getLogger(__name__)

manager_class = MANAGERS.get(config.provider)
if manager_class is None:
logger.error(f"No manager defined for provider '{config.provider}'.")
log.error(f"PR Manager: No manager defined for provider '{config.provider}'.")
raise NotImplementedError()
manager = manager_class(config.api_base_url, config.access_token)

web_app.add_handlers(
host_pattern,
[
(
url_path_join(base_url, pat),
handler,
{"logger": logger, "manager": manager},
)
for pat, handler in default_handlers
],
)
log.info(f"PR Manager Class {manager_class}")
try:
manager = manager_class(config)
except Exception as err:
import traceback
logging.error("PR Manager Exception", exc_info=1)
raise err

handlers = [
(
url_path_join(base_url, pat),
handler,
{"logger": log, "manager": manager},
)
for pat, handler in default_handlers
]

log.debug(f"PR Handlers: {handlers}")

web_app.add_handlers(host_pattern, handlers)
25 changes: 11 additions & 14 deletions jupyterlab_pullrequests/managers/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,20 @@
from tornado.httputil import url_concat
from tornado.web import HTTPError

from ..base import CommentReply, NewComment
from ..base import CommentReply, NewComment, PRConfig
from .manager import PullRequestsManager


class GitHubManager(PullRequestsManager):
"""Pull request manager for GitHub."""

def __init__(
self, base_api_url: str = "https://api.github.com", access_token: str = ""
) -> None:
"""
Args:
base_api_url: Base REST API url for the versioning service
access_token: Versioning service access token
"""
super().__init__(base_api_url=base_api_url, access_token=access_token)
self._pull_requests_cache = {} # Dict[str, Dict]
def __init__(self, config: PRConfig) -> None:
super().__init__(config)
self._pull_requests_cache = {}

@property
def base_api_url(self):
return self._config.api_base_url or "https://api.github.com"

@property
def per_page_argument(self) -> Optional[Tuple[str, int]]:
Expand All @@ -40,7 +37,7 @@ async def get_current_user(self) -> Dict[str, str]:
Returns:
JSON description of the user matching the access token
"""
git_url = url_path_join(self._base_api_url, "user")
git_url = url_path_join(self.base_api_url, "user")
data = await self._call_github(git_url, has_pagination=False)

return {"username": data["login"]}
Expand Down Expand Up @@ -186,7 +183,7 @@ async def list_prs(self, username: str, pr_filter: str) -> List[Dict[str, str]]:

# Use search API to find matching pull requests and return
git_url = url_path_join(
self._base_api_url, "/search/issues?q=+state:open+type:pr" + search_filter
self.base_api_url, "/search/issues?q=+state:open+type:pr" + search_filter
)

results = await self._call_github(git_url)
Expand Down Expand Up @@ -273,7 +270,7 @@ async def _call_github(
"""
headers = {
"Accept": media_type,
"Authorization": f"token {self._access_token}",
"Authorization": f"token {self._config.access_token}",
}

return await super()._call_provider(
Expand Down
30 changes: 14 additions & 16 deletions jupyterlab_pullrequests/managers/gitlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from tornado.httputil import url_concat
from tornado.web import HTTPError

from ..base import CommentReply, NewComment
from ..base import CommentReply, NewComment, PRConfig
from ..log import get_logger
from .manager import PullRequestsManager

Expand All @@ -25,22 +25,20 @@ class GitLabManager(PullRequestsManager):

MINIMAL_VERSION = "13.1" # Due to pagination https://docs.gitlab.com/ee/api/README.html#pagination

def __init__(
self, base_api_url: str = "https://gitlab.com/api/v4/", access_token: str = ""
) -> None:
"""
Args:
base_api_url: Base REST API url for the versioning service
access_token: Versioning service access token
"""
super().__init__(base_api_url=base_api_url, access_token=access_token)
def __init__(self, config: PRConfig) -> None:
super().__init__(config)

# Creating new file discussion required some commit sha's so we will cache them
self._merge_requests_cache = {} # Dict[str, Dict]
# Creating discussion on unmodified line requires to figure out the line number
# in the diff file for the original and the new file using Myers algorithm. So
# we cache the diff to speed up the process.
self._file_diff_cache = {} # Dict[Tuple[str, str], List[difflib.Match]]

@property
def base_api_url(self):
return self._config.api_base_url or "https://gitlab.com/api/v4/"

@property
def per_page_argument(self) -> Optional[Tuple[str, int]]:
"""Returns query argument to set number of items per page.
Expand All @@ -57,7 +55,7 @@ async def check_server_version(self) -> bool:
Returns:
Whether the server version is higher than the minimal supported version.
"""
url = url_path_join(self._base_api_url, "version")
url = url_path_join(self.base_api_url, "version")
data = await self._call_gitlab(url, has_pagination=False)
server_version = data.get("version", "")
is_valid = True
Expand All @@ -79,7 +77,7 @@ async def get_current_user(self) -> Dict[str, str]:
# Check server compatibility
await self.check_server_version()

git_url = url_path_join(self._base_api_url, "user")
git_url = url_path_join(self.base_api_url, "user")
data = await self._call_gitlab(git_url, has_pagination=False)

return {"username": data["username"]}
Expand Down Expand Up @@ -227,15 +225,15 @@ async def list_prs(self, username: str, pr_filter: str) -> List[Dict[str, str]]:

# Use search API to find matching pull requests and return
git_url = url_path_join(
self._base_api_url, "/merge_requests?state=opened&" + search_filter
self.base_api_url, "/merge_requests?state=opened&" + search_filter
)

results = await self._call_gitlab(git_url)

data = []
for result in results:
url = url_path_join(
self._base_api_url,
self.base_api_url,
"projects",
str(result["project_id"]),
"merge_requests",
Expand Down Expand Up @@ -374,7 +372,7 @@ async def _call_gitlab(
"""

headers = {
"Authorization": f"Bearer {self._access_token}",
"Authorization": f"Bearer {self._config.access_token}",
"Accept": "application/json",
}
return await super()._call_provider(
Expand Down Expand Up @@ -481,7 +479,7 @@ def _response_to_comment(result: Dict[str, str]) -> Dict[str, str]:
async def __get_content(self, project_id: int, filename: str, sha: str) -> str:
url = url_concat(
url_path_join(
self._base_api_url,
self.base_api_url,
"projects",
str(project_id),
"repository/files",
Expand Down
20 changes: 7 additions & 13 deletions jupyterlab_pullrequests/managers/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,19 @@

from .._version import __version__
from ..log import get_logger

from ..base import PRConfig

class PullRequestsManager(abc.ABC):
"""Abstract base class for pull requests manager."""

def __init__(self, base_api_url: str = "", access_token: str = "") -> None:
"""
Args:
base_api_url: Base REST API url for the versioning service
access_token: Versioning service access token
"""
def __init__(self, config: PRConfig) -> None:
self._config = config
self._client = tornado.httpclient.AsyncHTTPClient()
self._base_api_url = base_api_url
self._access_token = access_token

@property
def base_api_url(self) -> str:
"""The provider base REST API URL"""
return self._base_api_url
return self._config.api_base_url

@property
def log(self) -> logging.Logger:
Expand Down Expand Up @@ -142,7 +136,7 @@ async def _call_provider(
List or Dict: Create from JSON response body if load_json is True
str: Raw response body if load_json is False
"""
if not self._access_token:
if not self._config.access_token:
raise tornado.web.HTTPError(
status_code=http.HTTPStatus.BAD_REQUEST,
reason="No access token specified. Please set PRConfig.access_token in your user jupyter_server_config file.",
Expand All @@ -154,8 +148,8 @@ async def _call_provider(
headers["Content-Type"] = "application/json"
body = tornado.escape.json_encode(body)

if not url.startswith(self._base_api_url):
url = url_path_join(self._base_api_url, url)
if not url.startswith(self.base_api_url):
url = url_path_join(self.base_api_url, url)

with_pagination = False
if (
Expand Down
39 changes: 39 additions & 0 deletions jupyterlab_pullrequests/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1 +1,40 @@
import pytest

from ..base import PRConfig

# the preferred method for loading jupyter_server (because entry_points)
pytest_plugins = ["jupyter_server.pytest_plugin"]


@pytest.fixture
def pr_base_config():
return PRConfig()


@pytest.fixture
def pr_github_config(pr_base_config):
return pr_base_config()


@pytest.fixture
def pr_github_manager(pr_base_config):
from ..managers.github import GitHubManager
return GitHubManager(pr_base_config)


@pytest.fixture
def pr_valid_github_manager(pr_github_manager):
pr_github_manager._config.access_token = "valid"
return pr_github_manager


@pytest.fixture
def pr_gitlab_manger(pr_base_config):
from ..managers.gitlab import GitLabManager
return GitLabManager(pr_base_config)


@pytest.fixture
def pr_valid_gitlab_manager(pr_gitlab_manger):
pr_gitlab_manger._config.access_token = "valid"
return pr_gitlab_manger
Loading