Skip to content

Commit

Permalink
fix profile use switching (#17300)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Feb 27, 2025
1 parent 29d8afd commit ec16d74
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 70 deletions.
61 changes: 33 additions & 28 deletions src/prefect/cli/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from prefect.client.base import determine_server_type
from prefect.client.orchestration import ServerType, get_client
from prefect.context import use_profile
from prefect.exceptions import ObjectNotFound
from prefect.settings import ProfilesCollection
from prefect.utilities.collections import AutoEnum

Expand Down Expand Up @@ -293,13 +292,16 @@ def show_profile_changes(
@profile_app.command()
def populate_defaults():
"""Populate the profiles configuration with default base profiles, preserving existing user profiles."""
user_path = prefect.settings.PREFECT_PROFILES_PATH.value()
default_profiles = prefect.settings.profiles._read_profiles_from(
prefect.settings.DEFAULT_PROFILES_PATH
from prefect.settings.profiles import (
_read_profiles_from, # type: ignore[reportPrivateUsage]
_write_profiles_to, # type: ignore[reportPrivateUsage]
)

user_path = prefect.settings.PREFECT_PROFILES_PATH.value()
default_profiles = _read_profiles_from(prefect.settings.DEFAULT_PROFILES_PATH)

if user_path.exists():
user_profiles = prefect.settings.profiles._read_profiles_from(user_path)
user_profiles = _read_profiles_from(user_path)

if not show_profile_changes(user_profiles, default_profiles):
return
Expand All @@ -324,7 +326,7 @@ def populate_defaults():
if name not in user_profiles:
user_profiles.add_profile(profile)

prefect.settings.profiles._write_profiles_to(user_path, user_profiles)
_write_profiles_to(user_path, user_profiles)
app.console.print(f"\nProfiles updated in [green]{user_path}[/green]")
app.console.print(
"\nUse with [green]prefect profile use[/green] [blue][PROFILE-NAME][/blue]"
Expand All @@ -348,27 +350,32 @@ class ConnectionStatus(AutoEnum):
async def check_server_connection() -> ConnectionStatus:
httpx_settings = dict(timeout=3)
try:
# attempt to infer Cloud 2.0 API from the connection URL
cloud_client = get_cloud_client(
httpx_settings=httpx_settings, infer_cloud_url=True
)
async with cloud_client:
await cloud_client.api_healthcheck()
return ConnectionStatus.CLOUD_CONNECTED
except CloudUnauthorizedError:
# if the Cloud 2.0 API exists and fails to authenticate, notify the user
return ConnectionStatus.CLOUD_UNAUTHORIZED
except ObjectNotFound:
# if the route does not exist, attempt to connect as a hosted Prefect
# instance
# First determine the server type based on the URL
server_type = determine_server_type()

# Only try to connect to Cloud if the URL looks like a Cloud URL
if server_type == ServerType.CLOUD:
try:
cloud_client = get_cloud_client(
httpx_settings=httpx_settings, infer_cloud_url=True
)
async with cloud_client:
await cloud_client.api_healthcheck()
return ConnectionStatus.CLOUD_CONNECTED
except CloudUnauthorizedError:
# if the Cloud API exists and fails to authenticate, notify the user
return ConnectionStatus.CLOUD_UNAUTHORIZED
except (httpx.HTTPStatusError, Exception):
return ConnectionStatus.CLOUD_ERROR

# For non-Cloud URLs, try to connect as a hosted Prefect instance
if server_type == ServerType.EPHEMERAL:
return ConnectionStatus.EPHEMERAL
elif server_type == ServerType.UNCONFIGURED:
return ConnectionStatus.UNCONFIGURED

# Try to connect to the server
try:
# inform the user if Prefect API endpoints exist, but there are
# connection issues
server_type = determine_server_type()
if server_type == ServerType.EPHEMERAL:
return ConnectionStatus.EPHEMERAL
elif server_type == ServerType.UNCONFIGURED:
return ConnectionStatus.UNCONFIGURED
client = get_client(httpx_settings=httpx_settings)
async with client:
connect_error = await client.api_healthcheck()
Expand All @@ -378,8 +385,6 @@ async def check_server_connection() -> ConnectionStatus:
return ConnectionStatus.SERVER_CONNECTED
except Exception:
return ConnectionStatus.SERVER_ERROR
except httpx.HTTPStatusError:
return ConnectionStatus.CLOUD_ERROR
except TypeError:
# if no Prefect API URL has been set, httpx will throw a TypeError
try:
Expand Down
11 changes: 6 additions & 5 deletions src/prefect/settings/profiles.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import inspect
import warnings
from pathlib import Path
Expand All @@ -9,7 +11,6 @@
Iterable,
Iterator,
Optional,
Set,
Union,
)

Expand All @@ -32,8 +33,8 @@


def _cast_settings(
settings: Union[Dict[Union[str, Setting], Any], Any],
) -> Dict[Setting, Any]:
settings: dict[str | Setting, Any] | Any,
) -> dict[Setting, Any]:
"""For backwards compatibility, allow either Settings objects as keys or string references to settings."""
if not isinstance(settings, dict):
raise ValueError("Settings must be a dictionary.")
Expand Down Expand Up @@ -63,7 +64,7 @@ class Profile(BaseModel):
)

name: str
settings: Annotated[Dict[Setting, Any], BeforeValidator(_cast_settings)] = Field(
settings: Annotated[dict[Setting, Any], BeforeValidator(_cast_settings)] = Field(
default_factory=dict
)
source: Optional[Path] = None
Expand Down Expand Up @@ -114,7 +115,7 @@ def __init__(
self.active_name = active

@property
def names(self) -> Set[str]:
def names(self) -> set[str]:
"""
Return a set of profile names in this collection.
"""
Expand Down
104 changes: 67 additions & 37 deletions tests/cli/test_profile.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import shutil
from pathlib import Path
from uuid import uuid4

import pytest
import respx
from httpx import Response

from prefect.cli.profile import show_profile_changes
from prefect.client.cloud import CloudUnauthorizedError
from prefect.context import use_profile
from prefect.settings import (
DEFAULT_PROFILES_PATH,
Expand All @@ -19,12 +21,14 @@
save_profiles,
temporary_settings,
)
from prefect.settings.profiles import _read_profiles_from
from prefect.settings.profiles import (
_read_profiles_from, # type: ignore[reportPrivateUsage]
)
from prefect.testing.cli import invoke_and_assert


@pytest.fixture(autouse=True)
def temporary_profiles_path(tmp_path):
def temporary_profiles_path(tmp_path: Path):
path = tmp_path / "profiles.toml"
with temporary_settings({PREFECT_PROFILES_PATH: path}):
yield path
Expand All @@ -41,7 +45,7 @@ def test_use_profile_unknown_key():
class TestChangingProfileAndCheckingServerConnection:
@pytest.fixture
def profiles(self):
prefect_cloud_api_url = "https://mock-cloud.prefect.io/api"
prefect_cloud_api_url = "https://api.prefect.cloud/api"
prefect_cloud_server_api_url = (
f"{prefect_cloud_api_url}/accounts/{uuid4()}/workspaces/{uuid4()}"
)
Expand Down Expand Up @@ -79,39 +83,58 @@ def profiles(self):

@pytest.fixture
def authorized_cloud(self):
# attempts to reach the Cloud 2 workspaces endpoint implies a good connection
# attempts to reach the Cloud API implies a good connection
# to Prefect Cloud as opposed to a hosted Prefect server instance
with respx.mock(using="httpx") as respx_mock:
authorized = respx_mock.get(
"https://mock-cloud.prefect.io/api/me/workspaces",
with respx.mock(using="httpx", assert_all_called=False) as respx_mock:
# Mock the health endpoint for cloud
health = respx_mock.get(
"https://api.prefect.cloud/api/health",
).mock(return_value=Response(200, json={}))

# Keep the workspaces endpoint mock for backward compatibility
respx_mock.get(
"https://api.prefect.cloud/api/me/workspaces",
).mock(return_value=Response(200, json=[]))

yield authorized
yield health

@pytest.fixture
def unauthorized_cloud(self):
# requests to cloud with an invalid key will result in a 401 response
with respx.mock(using="httpx") as respx_mock:
unauthorized = respx_mock.get(
"https://mock-cloud.prefect.io/api/me/workspaces",
).mock(return_value=Response(401, json={}))
with respx.mock(using="httpx", assert_all_called=False) as respx_mock:
# Mock the health endpoint for cloud
health = respx_mock.get(
"https://api.prefect.cloud/api/health",
).mock(side_effect=CloudUnauthorizedError("Invalid API key"))

yield unauthorized
# Keep the workspaces endpoint mock for backward compatibility
respx_mock.get(
"https://api.prefect.cloud/api/me/workspaces",
).mock(side_effect=CloudUnauthorizedError("Invalid API key"))

yield health

@pytest.fixture
def unhealthy_cloud(self):
# Cloud may respond with a 500 error when having connection issues
with respx.mock(using="httpx") as respx_mock:
unhealthy_cloud = respx_mock.get(
"https://mock-cloud.prefect.io/api/me/workspaces",
).mock(return_value=Response(500, json={}))
# requests to cloud with an invalid key will result in a 401 response
with respx.mock(using="httpx", assert_all_called=False) as respx_mock:
# Mock the health endpoint for cloud with an error
unhealthy = respx_mock.get(
"https://api.prefect.cloud/api/health",
).mock(side_effect=self.connection_error)

yield unhealthy_cloud
# Keep the workspaces endpoint mock for backward compatibility
respx_mock.get(
"https://api.prefect.cloud/api/me/workspaces",
).mock(side_effect=self.connection_error)

yield unhealthy

@pytest.fixture
def hosted_server_has_no_cloud_api(self):
# if the API URL points to a hosted Prefect server instance, no Cloud API will be found
with respx.mock(using="httpx") as respx_mock:
with respx.mock(using="httpx", assert_all_called=False) as respx_mock:
# We don't need to mock the cloud API endpoint anymore since we check server type first
hosted = respx_mock.get(
"https://hosted-server.prefect.io/api/me/workspaces",
).mock(return_value=Response(404, json={}))
Expand All @@ -120,7 +143,7 @@ def hosted_server_has_no_cloud_api(self):

@pytest.fixture
def healthy_hosted_server(self):
with respx.mock(using="httpx") as respx_mock:
with respx.mock(using="httpx", assert_all_called=False) as respx_mock:
hosted = respx_mock.get(
"https://hosted-server.prefect.io/api/health",
).mock(return_value=Response(200, json={}))
Expand All @@ -132,14 +155,15 @@ def connection_error(self, *args):

@pytest.fixture
def unhealthy_hosted_server(self):
with respx.mock(using="httpx") as respx_mock:
with respx.mock(using="httpx", assert_all_called=False) as respx_mock:
badly_hosted = respx_mock.get(
"https://hosted-server.prefect.io/api/health",
).mock(side_effect=self.connection_error)

yield badly_hosted

def test_authorized_cloud_connection(self, authorized_cloud, profiles):
@pytest.mark.usefixtures("authorized_cloud")
def test_authorized_cloud_connection(self, profiles: ProfilesCollection):
save_profiles(profiles)
invoke_and_assert(
["profile", "use", "prefect-cloud"],
Expand All @@ -152,7 +176,8 @@ def test_authorized_cloud_connection(self, authorized_cloud, profiles):
profiles = load_profiles()
assert profiles.active_name == "prefect-cloud"

def test_unauthorized_cloud_connection(self, unauthorized_cloud, profiles):
@pytest.mark.usefixtures("unauthorized_cloud")
def test_unauthorized_cloud_connection(self, profiles: ProfilesCollection):
save_profiles(profiles)
invoke_and_assert(
["profile", "use", "prefect-cloud-with-invalid-key"],
Expand All @@ -166,7 +191,8 @@ def test_unauthorized_cloud_connection(self, unauthorized_cloud, profiles):
profiles = load_profiles()
assert profiles.active_name == "prefect-cloud-with-invalid-key"

def test_unhealthy_cloud_connection(self, unhealthy_cloud, profiles):
@pytest.mark.usefixtures("unhealthy_cloud")
def test_unhealthy_cloud_connection(self, profiles: ProfilesCollection):
save_profiles(profiles)
invoke_and_assert(
["profile", "use", "prefect-cloud"],
Expand All @@ -177,9 +203,8 @@ def test_unhealthy_cloud_connection(self, unhealthy_cloud, profiles):
profiles = load_profiles()
assert profiles.active_name == "prefect-cloud"

def test_using_hosted_server(
self, hosted_server_has_no_cloud_api, healthy_hosted_server, profiles
):
@pytest.mark.usefixtures("hosted_server_has_no_cloud_api", "healthy_hosted_server")
def test_using_hosted_server(self, profiles: ProfilesCollection):
save_profiles(profiles)
invoke_and_assert(
["profile", "use", "hosted-server"],
Expand All @@ -192,9 +217,10 @@ def test_using_hosted_server(
profiles = load_profiles()
assert profiles.active_name == "hosted-server"

def test_unhealthy_hosted_server(
self, hosted_server_has_no_cloud_api, unhealthy_hosted_server, profiles
):
@pytest.mark.usefixtures(
"hosted_server_has_no_cloud_api", "unhealthy_hosted_server"
)
def test_unhealthy_hosted_server(self, profiles: ProfilesCollection):
save_profiles(profiles)
invoke_and_assert(
["profile", "use", "hosted-server"],
Expand All @@ -205,7 +231,7 @@ def test_unhealthy_hosted_server(
profiles = load_profiles()
assert profiles.active_name == "hosted-server"

def test_using_ephemeral_server(self, profiles):
def test_using_ephemeral_server(self, profiles: ProfilesCollection):
save_profiles(profiles)
invoke_and_assert(
["profile", "use", "ephemeral"],
Expand Down Expand Up @@ -506,7 +532,9 @@ def test_rename_profile_changes_active_profile():
assert profiles.active_name == "bar"


def test_rename_profile_warns_on_environment_variable_active_profile(monkeypatch):
def test_rename_profile_warns_on_environment_variable_active_profile(
monkeypatch: pytest.MonkeyPatch,
):
save_profiles(
ProfilesCollection(
profiles=[
Expand Down Expand Up @@ -581,7 +609,7 @@ def test_inspect_profile_without_settings():


class TestProfilesPopulateDefaults:
def test_populate_defaults(self, temporary_profiles_path):
def test_populate_defaults(self, temporary_profiles_path: Path):
default_profiles = _read_profiles_from(DEFAULT_PROFILES_PATH)

assert not temporary_profiles_path.exists()
Expand Down Expand Up @@ -612,7 +640,9 @@ def test_populate_defaults(self, temporary_profiles_path):
for name in default_profiles.names:
assert populated_profiles[name].settings == default_profiles[name].settings

def test_populate_defaults_with_existing_profiles(self, temporary_profiles_path):
def test_populate_defaults_with_existing_profiles(
self, temporary_profiles_path: Path
):
existing_profiles = ProfilesCollection(
profiles=[Profile(name="existing", settings={PREFECT_API_KEY: "test_key"})],
active="existing",
Expand Down Expand Up @@ -644,7 +674,7 @@ def test_populate_defaults_with_existing_profiles(self, temporary_profiles_path)
assert "existing" in backup_profiles.names
assert backup_profiles["existing"].settings == {PREFECT_API_KEY: "test_key"}

def test_populate_defaults_no_changes_needed(self, temporary_profiles_path):
def test_populate_defaults_no_changes_needed(self, temporary_profiles_path: Path):
shutil.copy(DEFAULT_PROFILES_PATH, temporary_profiles_path)

invoke_and_assert(
Expand All @@ -657,7 +687,7 @@ def test_populate_defaults_no_changes_needed(self, temporary_profiles_path):

assert temporary_profiles_path.read_text() == DEFAULT_PROFILES_PATH.read_text()

def test_show_profile_changes(self, capsys):
def test_show_profile_changes(self, capsys: pytest.CaptureFixture[str]):
default_profiles = ProfilesCollection(
profiles=[
Profile(
Expand Down

0 comments on commit ec16d74

Please sign in to comment.