Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add options to output registered entity summary #3028

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
26 changes: 26 additions & 0 deletions flytekit/clis/sdk_in_container/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@
help="Skip errors during registration. This is useful when registering multiple packages and you want to skip "
"errors for some packages.",
)
@click.option(
"--summary-format",
"-f",
required=False,
type=click.Choice(["json", "yaml"], case_sensitive=False),
default=None,
help="Output format for registration summary. Lists registered workflows, tasks, and launch plans. 'json' and 'yaml' supported.",
)
@click.option(
"--quiet",
is_flag=True,
default=False,
help="Suppress output messages, only displaying errors.",
)
@click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1)
@click.pass_context
def register(
Expand All @@ -162,12 +176,15 @@ def register(
activate_launchplans: bool,
env: typing.Optional[typing.Dict[str, str]],
skip_errors: bool,
summary_format: typing.Optional[str],
quiet: bool,
):
"""
see help
"""
# Set the relevant copy option if non_fast is set, this enables the individual file listing behavior
# that the copy flag uses.

if non_fast:
click.secho("The --non-fast flag is deprecated, please use --copy none instead", fg="yellow")
if "--copy" in sys.argv:
Expand Down Expand Up @@ -195,6 +212,13 @@ def register(
"Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed",
)

if summary_format is not None:
quiet = True

# mutes all secho
if quiet:
click.secho = lambda *args, **kw: None

# Use extra images in the config file if that file exists
config_file = ctx.obj.get(constants.CTX_CONFIG_FILE)
if config_file:
Expand Down Expand Up @@ -225,6 +249,8 @@ def register(
package_or_module=package_or_module,
remote=remote,
env=env,
summary_format=summary_format,
quiet=quiet,
dry_run=dry_run,
activate_launchplans=activate_launchplans,
skip_errors=skip_errors,
Expand Down
77 changes: 67 additions & 10 deletions flytekit/tools/repo.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import asyncio
import functools
import json
import os
import tarfile
import tempfile
import typing
from contextlib import contextmanager
from pathlib import Path

import click
import yaml
from rich import print as rprint

from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings
Expand All @@ -22,6 +25,8 @@
from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities
from flytekit.tools.translator import FlyteControlPlaneEntity, Options

original_secho = click.secho


class NoSerializableEntitiesError(Exception):
    """Exception type signalling that there are no serializable entities.

    NOTE(review): the raise sites are outside this excerpt; the description is
    taken from the class name only.
    """
Expand Down Expand Up @@ -237,6 +242,20 @@ def print_registration_status(
rprint(f"[{color}]{state_ind} {name}: {i.name} (Failed)")


@contextmanager
def temporary_secho():
    """
    Rebind ``click.secho`` to ``original_secho`` for the duration of the
    ``with`` body, then reinstate whatever ``click.secho`` was bound to on
    entry (for example a silenced stand-in installed for quiet mode).
    """
    secho_on_entry = click.secho
    click.secho = original_secho
    try:
        yield
    finally:
        # Put the caller's binding back even if the body raised.
        click.secho = secho_on_entry


def register(
project: str,
domain: str,
Expand All @@ -251,6 +270,8 @@ def register(
remote: FlyteRemote,
copy_style: CopyFileDetection,
env: typing.Optional[typing.Dict[str, str]],
summary_format: typing.Optional[str],
quiet: bool = False,
dry_run: bool = False,
activate_launchplans: bool = False,
skip_errors: bool = False,
Expand All @@ -261,6 +282,10 @@ def register(
Temporarily, for fast register, specify both the fast arg as well as copy_style.
fast == True with copy_style == None means use the old fast register tar'ring method.
"""

if quiet:
click.secho = lambda *args, **kw: None

detected_root = find_common_root(package_or_module)
click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow")

Expand Down Expand Up @@ -316,6 +341,7 @@ def register(
registrable_entities = serialize_get_control_plane_entities(
serialization_settings, str(detected_root), options, is_registration=True
)

click.secho(
f"Serializing and registering {len(registrable_entities)} flyte entities",
fg="green",
Expand All @@ -333,6 +359,14 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity):
is_lp = True
else:
og_id = cp_entity.template.id

result = {
"id": og_id.name,
"type": og_id.resource_type_name(),
"version": og_id.version,
"status": "skipped", # default status
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider enhancing registration result information

Consider adding more detailed status information in the result dictionary. The current status field only captures high-level states ('skipped', 'success', 'failed'). Additional fields like error_message and timestamp could provide more context for debugging and monitoring.

Code suggestion
Check the AI-generated fix before applying
Suggested change
result = {
"id": og_id.name,
"type": og_id.resource_type_name(),
"version": og_id.version,
"status": "skipped", # default status
}
result = {
"id": og_id.name,
"type": og_id.resource_type_name(),
"version": og_id.version,
"status": "skipped", # default status
"timestamp": datetime.datetime.now().isoformat(),
"error_message": "",
"details": {}
}

Code Review Run #9a3edb


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged


try:
if not dry_run:
try:
Expand All @@ -347,33 +381,56 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity):
print_activation_message = True
if cp_entity.should_auto_activate:
print_activation_message = True
print_registration_status(
i, console_url=console_url, verbosity=verbosity, activation=print_activation_message
)
if not quiet:
print_registration_status(
i, console_url=console_url, verbosity=verbosity, activation=print_activation_message
)
result["status"] = "success"

except Exception as e:
if not skip_errors:
raise e
print_registration_status(og_id, success=False)
if not quiet:
print_registration_status(og_id, success=False)
result["status"] = "failed"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it fails, what will the values of other keys be?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The values of the other keys (id, type, version) are computed before registration, so they will not be empty even if the registration fails.

The values are from:

result = {
"id": og_id.name,
"type": og_id.resource_type_name(),
"version": og_id.version,
"status": "skipped", # default status
}
where og_id is the id of the entity's template / entity itself.

else:
print_registration_status(og_id, dry_run=True)
if not quiet:
print_registration_status(og_id, dry_run=True)
except RegistrationSkipped:
print_registration_status(og_id, success=False)
if not quiet:
print_registration_status(og_id, success=False)
result["status"] = "skipped"

return result

async def _register(entities: typing.List[task.TaskSpec]):
loop = asyncio.get_running_loop()
tasks = []
for entity in entities:
tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity)))
await asyncio.gather(*tasks)
return
results = await asyncio.gather(*tasks)
return results

# concurrent register
cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities))
asyncio.run(_register(cp_task_entities))
task_results = asyncio.run(_register(cp_task_entities))
# serial register
cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), registrable_entities))
other_results = []
for entity in cp_other_entities:
_raw_register(entity)
other_results.append(_raw_register(entity))

all_results = task_results + other_results

click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green")

if summary_format is not None:
supported_format = {"json", "yaml"}
if summary_format not in supported_format:
raise ValueError(f"Unsupported file format: {summary_format}")

with temporary_secho():
if summary_format == "json":
click.secho(json.dumps(all_results, indent=2))
elif summary_format == "yaml":
click.secho(yaml.dump(all_results))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider safer click.secho state management

Consider using a context manager for temporarily restoring click.secho instead of directly manipulating it. The current approach with temporary_secho() could lead to inconsistent state if an exception occurs. A similar issue was also found in flytekit/clis/sdk_in_container/register.py (line 219-220).

Code suggestion
Check the AI-generated fix before applying
 @@ -246,11 +246,12 @@
 -def temporary_secho():
 +@contextmanager
 +def temporary_secho():
      current_secho = click.secho
      try:
          click.secho = original_secho
          yield
      finally:
          click.secho = current_secho

Code Review Run #5721dc


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

98 changes: 98 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_register.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import shutil
import subprocess
import json
import yaml

import mock
import pytest
Expand Down Expand Up @@ -163,3 +165,99 @@ def test_non_fast_register_require_version(mock_client, mock_remote):
result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core3"])
assert result.exit_code == 1
shutil.rmtree("core3")


@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_registrated_summary_json(mock_client, mock_remote):
ctx = FlyteContextManager.current_context()
mock_remote._client = mock_client
mock_remote.return_value.context = ctx
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"
runner = CliRunner()
context_manager.FlyteEntities.entities.clear()

with runner.isolated_filesystem():
out = subprocess.run(["git", "init"], capture_output=True)
assert out.returncode == 0
os.makedirs("core5", exist_ok=True)
with open(os.path.join("core5", "sample.py"), "w") as f:
f.write(sample_file_contents)
f.close()

result = runner.invoke(
pyflyte.main,
["register", "--summary-format", "json", "core5"]
)
assert result.exit_code == 0
summary_data = json.loads(result.output)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding JSON validation check

Consider validating the result.output before parsing it as JSON to handle potential invalid JSON gracefully.

Code suggestion
Check the AI-generated fix before applying
Suggested change
summary_data = json.loads(result.output)
try:
summary_data = json.loads(result.output)
except json.JSONDecodeError as e:
pytest.fail(f"Failed to parse registration summary JSON: {e}")
except Exception as e:
pytest.fail(f"Unexpected error while parsing registration summary: {e}")

Code Review Run #5721dc


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

assert isinstance(summary_data, list)
assert len(summary_data) > 0
for entry in summary_data:
assert "id" in entry
assert "type" in entry
assert "version" in entry
assert "status" in entry
shutil.rmtree("core5")

@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_registrated_summary_yaml(mock_client, mock_remote):
ctx = FlyteContextManager.current_context()
mock_remote._client = mock_client
mock_remote.return_value.context = ctx
mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"
runner = CliRunner()
context_manager.FlyteEntities.entities.clear()

with runner.isolated_filesystem():
out = subprocess.run(["git", "init"], capture_output=True)
assert out.returncode == 0
os.makedirs("core6", exist_ok=True)
with open(os.path.join("core6", "sample.py"), "w") as f:
f.write(sample_file_contents)
f.close()

result = runner.invoke(
pyflyte.main,
["register", "--summary-format", "yaml", "core6"]
)
assert result.exit_code == 0
summary_data = yaml.safe_load(result.output)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding YAML parse error handling

Consider adding error handling when parsing YAML from result.output. The yaml.safe_load() could raise yaml.YAMLError if the output is not valid YAML.

Code suggestion
Check the AI-generated fix before applying
Suggested change
summary_data = yaml.safe_load(result.output)
try:
summary_data = yaml.safe_load(result.output)
except yaml.YAMLError as e:
pytest.fail(f"Failed to parse YAML output: {e}")

Code Review Run #5721dc


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

assert isinstance(summary_data, list)
assert len(summary_data) > 0
for entry in summary_data:
assert "id" in entry
assert "type" in entry
assert "version" in entry
assert "status" in entry

shutil.rmtree("core6")

@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote)
@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient)
def test_register_quiet(mock_client, mock_remote):
    """``pyflyte register --quiet`` exits with code 0 and produces no output."""
    ctx = FlyteContextManager.current_context()
    mock_remote._client = mock_client
    mock_remote.return_value.context = ctx
    mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash"
    mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url"
    runner = CliRunner()
    # Empty the global entity registry before invoking the CLI
    # (presumably to avoid leakage from other tests -- TODO confirm).
    context_manager.FlyteEntities.entities.clear()
    with runner.isolated_filesystem():
        out = subprocess.run(["git", "init"], capture_output=True)
        assert out.returncode == 0
        os.makedirs("core7", exist_ok=True)
        # The ``with`` statement closes the file on exit, so no explicit
        # ``f.close()`` is required.
        with open(os.path.join("core7", "sample.py"), "w") as f:
            f.write(sample_file_contents)
        result = runner.invoke(
            pyflyte.main,
            ["register", "--quiet", "core7"],
        )
        assert result.exit_code == 0
        # Quiet mode must not emit anything on the captured output stream.
        assert result.output == ""

        shutil.rmtree("core7")
Loading