diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index a492e1cba8..45febd4a34 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -21,10 +21,12 @@ from flytekit.clis.sdk_in_container.serve import serve from flytekit.clis.sdk_in_container.utils import ErrorHandlingCommand, validate_package from flytekit.clis.version import info -from flytekit.configuration.file import FLYTECTL_CONFIG_ENV_VAR, FLYTECTL_CONFIG_ENV_VAR_OVERRIDE +from flytekit.configuration.file import FLYTECTL_CONFIG_ENV_VAR_OVERRIDE from flytekit.configuration.internal import LocalSDK from flytekit.configuration.plugin import get_plugin -from flytekit.loggers import logger + +# from flytekit.configuration.file import FLYTECTL_CONFIG_ENV_VAR, FLYTECTL_CONFIG_ENV_VAR_OVERRIDE +# from flytekit.loggers import logger @click.group("pyflyte", invoke_without_command=True, cls=ErrorHandlingCommand) @@ -66,11 +68,12 @@ def main(ctx, pkgs: typing.List[str], config: str, verbose: int): if config: ctx.obj[CTX_CONFIG_FILE] = config cfg = configuration.ConfigFile(config) + # Temporarily commented out to ensure proper output format when using --quiet flag in pyflyte register # Set here so that if someone has Config.auto() in their user code, the config here will get used. - if FLYTECTL_CONFIG_ENV_VAR in os.environ: - logger.info( - f"Config file arg {config} will override env var {FLYTECTL_CONFIG_ENV_VAR}: {os.environ[FLYTECTL_CONFIG_ENV_VAR]}" - ) + # if FLYTECTL_CONFIG_ENV_VAR in os.environ: + # logger.info( + # f"Config file arg {config} will override env var {FLYTECTL_CONFIG_ENV_VAR}: {os.environ[FLYTECTL_CONFIG_ENV_VAR]}" + # ) os.environ[FLYTECTL_CONFIG_ENV_VAR_OVERRIDE] = config if not pkgs: pkgs = LocalSDK.WORKFLOW_PACKAGES.read(cfg) diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 57a9b58448..06801dfd4e 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -33,6 +33,9 @@ the root of your project, it finds the first folder that does not have a ``__init__.py`` file. """ +_original_secho = click.secho +_original_log_level = logger.level + @click.command("register", help=_register_help) @project_option_dec @@ -142,6 +145,20 @@ help="Skip errors during registration. This is useful when registering multiple packages and you want to skip " "errors for some packages.", ) +@click.option( + "--summary-format", + "-f", + required=False, + type=click.Choice(["json", "yaml"], case_sensitive=False), + default=None, + help="Output format for registration summary. Lists registered workflows, tasks, and launch plans. 'json' and 'yaml' supported.", +) +@click.option( + "--quiet", + is_flag=True, + default=False, + help="Suppress output messages, only displaying errors.", +) @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) @click.pass_context def register( @@ -162,12 +179,25 @@ def register( activate_launchplans: bool, env: typing.Optional[typing.Dict[str, str]], skip_errors: bool, + summary_format: typing.Optional[str], + quiet: bool, ): """ see help """ + + if summary_format is not None: + quiet = True + + if quiet: + # Mute all secho output through monkey patching + click.secho = lambda *args, **kw: None + # Output only log at ERROR or CRITICAL level + logger.setLevel("ERROR") + # Set the relevant copy option if non_fast is set, this enables the individual file listing behavior # that the copy flag uses. + if non_fast: click.secho("The --non-fast flag is deprecated, please use --copy none instead", fg="yellow") if "--copy" in sys.argv: @@ -195,39 +225,46 @@ def register( "Missing argument 'PACKAGE_OR_MODULE...', at least one PACKAGE_OR_MODULE is required but multiple can be passed", ) - # Use extra images in the config file if that file exists - config_file = ctx.obj.get(constants.CTX_CONFIG_FILE) - if config_file: - image_config = patch_image_config(config_file, image_config) - - click.secho( - f"Running pyflyte register from {os.getcwd()} " - f"with images {image_config} " - f"and image destination folder {destination_dir} " - f"on {len(package_or_module)} package(s) {package_or_module}", - dim=True, - ) - - # Create and save FlyteRemote, - remote = get_and_save_remote_with_click_context(ctx, project, domain, data_upload_location="flyte://data") - click.secho(f"Registering against {remote.config.platform.endpoint}") - repo.register( - project, - domain, - image_config, - output, - destination_dir, - service_account, - raw_data_prefix, - version, - deref_symlinks, - copy_style=copy, - package_or_module=package_or_module, - remote=remote, - env=env, - dry_run=dry_run, - activate_launchplans=activate_launchplans, - skip_errors=skip_errors, - show_files=show_files, - verbosity=ctx.obj[constants.CTX_VERBOSE], - ) + try: + # Use extra images in the config file if that file exists + config_file = ctx.obj.get(constants.CTX_CONFIG_FILE) + if config_file: + image_config = patch_image_config(config_file, image_config) + + click.secho( + f"Running pyflyte register from {os.getcwd()} " + f"with images {image_config} " + f"and image destination folder {destination_dir} " + f"on {len(package_or_module)} package(s) {package_or_module}", + dim=True, + ) + + # Create and save FlyteRemote, + remote = get_and_save_remote_with_click_context(ctx, project, domain, data_upload_location="flyte://data") + click.secho(f"Registering against {remote.config.platform.endpoint}") + repo.register( + project, + domain, + image_config, + output, + destination_dir, + service_account, + raw_data_prefix, + version, + deref_symlinks, + copy_style=copy, + package_or_module=package_or_module, + remote=remote, + env=env, + summary_format=summary_format, + quiet=quiet, + dry_run=dry_run, + activate_launchplans=activate_launchplans, + skip_errors=skip_errors, + show_files=show_files, + verbosity=ctx.obj[constants.CTX_VERBOSE], + ) + finally: + # Restore original secho + click.secho = _original_secho + logger.setLevel(_original_log_level) diff --git a/flytekit/clis/sdk_in_container/utils.py b/flytekit/clis/sdk_in_container/utils.py index 9a5fdaa890..d5365eaa86 100644 --- a/flytekit/clis/sdk_in_container/utils.py +++ b/flytekit/clis/sdk_in_container/utils.py @@ -71,7 +71,8 @@ def validate_package(ctx, param, values): pkgs.extend(val.split(",")) else: pkgs.append(val) - logger.debug(f"Using packages: {pkgs}") + # Temporarily commented out to ensure proper output format when using --quiet flag in pyflyte register + # logger.debug(f"Using packages: {pkgs}") return pkgs diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index e2e46f49d3..807d05b5c7 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -1,12 +1,15 @@ import asyncio import functools +import json import os import tarfile import tempfile import typing +from contextlib import contextmanager from pathlib import Path import click +import yaml from rich import print as rprint from flytekit.configuration import FastSerializationSettings, ImageConfig, SerializationSettings @@ -22,6 +25,9 @@ from flytekit.tools.serialize_helpers import get_registrable_entities, persist_registrable_entities from flytekit.tools.translator import FlyteControlPlaneEntity, Options +_original_secho = click.secho +_original_log_level = logger.level + class NoSerializableEntitiesError(Exception): pass @@ -237,6 +243,20 @@ def print_registration_status( rprint(f"[{color}]{state_ind} {name}: {i.name} (Failed)") +@contextmanager +def temporary_secho(): + """ + Temporarily restores the original click.secho function. + Useful when you need to temporarily disable quiet mode. + """ + current_secho = click.secho + try: + click.secho = _original_secho + yield + finally: + click.secho = current_secho + + def register( project: str, domain: str, @@ -251,6 +271,8 @@ def register( remote: FlyteRemote, copy_style: CopyFileDetection, env: typing.Optional[typing.Dict[str, str]], + summary_format: typing.Optional[str], + quiet: bool = False, dry_run: bool = False, activate_launchplans: bool = False, skip_errors: bool = False, @@ -261,119 +283,162 @@ def register( Temporarily, for fast register, specify both the fast arg as well as copy_style. fast == True with copy_style == None means use the old fast register tar'ring method. """ - detected_root = find_common_root(package_or_module) - click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow") - - # Create serialization settings - # Todo: Rely on default Python interpreter for now, this will break custom Spark containers - serialization_settings = SerializationSettings( - project=project, - domain=domain, - version=version, - image_config=image_config, - fast_serialization_settings=None, # should probably add incomplete fast settings - env=env, - ) - if not version and copy_style == CopyFileDetection.NO_COPY: - click.secho("Version is required.", fg="red") - return + # Mute all secho output through monkey patching + if quiet: + click.secho = lambda *args, **kw: None + logger.setLevel("ERROR") + + try: + detected_root = find_common_root(package_or_module) + click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow") + + # Create serialization settings + # Todo: Rely on default Python interpreter for now, this will break custom Spark containers + serialization_settings = SerializationSettings( + project=project, + domain=domain, + version=version, + image_config=image_config, + fast_serialization_settings=None, # should probably add incomplete fast settings + env=env, + ) + + if not version and copy_style == CopyFileDetection.NO_COPY: + click.secho("Version is required.", fg="red") + return - b = serialization_settings.new_builder() - serialization_settings = b.build() + b = serialization_settings.new_builder() + serialization_settings = b.build() - options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix) + options = Options.default_from(k8s_service_account=service_account, raw_data_prefix=raw_data_prefix) - # Load all the entities - FlyteContextManager.push_context(remote.context) - serialization_settings.git_repo = _get_git_repo_url(str(detected_root)) - pkgs_and_modules = list_packages_and_modules(detected_root, list(package_or_module)) + # Load all the entities + FlyteContextManager.push_context(remote.context) + serialization_settings.git_repo = _get_git_repo_url(str(detected_root)) + pkgs_and_modules = list_packages_and_modules(detected_root, list(package_or_module)) - # NB: The change here is that the loading of user code _cannot_ depend on fast register information (the computed - # version, upload native url, hash digest, etc.). - serialize_load_only(pkgs_and_modules, serialization_settings, str(detected_root)) + # NB: The change here is that the loading of user code _cannot_ depend on fast register information (the computed + # version, upload native url, hash digest, etc.). + serialize_load_only(pkgs_and_modules, serialization_settings, str(detected_root)) - # Fast registration is handled after module loading - if copy_style != CopyFileDetection.NO_COPY: - md5_bytes, native_url = remote.fast_package( - detected_root, - deref_symlinks, - output, - options=fast_registration.FastPackageOptions([], copy_style=copy_style, show_files=show_files), + # Fast registration is handled after module loading + if copy_style != CopyFileDetection.NO_COPY: + md5_bytes, native_url = remote.fast_package( + detected_root, + deref_symlinks, + output, + options=fast_registration.FastPackageOptions([], copy_style=copy_style, show_files=show_files), + ) + # update serialization settings from fast register output + fast_serialization_settings = FastSerializationSettings( + enabled=True, + destination_dir=destination_dir, + distribution_location=native_url, + ) + serialization_settings.fast_serialization_settings = fast_serialization_settings + if not version: + version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa + serialization_settings.version = version + click.secho(f"Computed version is {version}", fg="yellow") + + registrable_entities = serialize_get_control_plane_entities( + serialization_settings, str(detected_root), options, is_registration=True ) - # update serialization settings from fast register output - fast_serialization_settings = FastSerializationSettings( - enabled=True, - destination_dir=destination_dir, - distribution_location=native_url, + + click.secho( + f"Serializing and registering {len(registrable_entities)} flyte entities", + fg="green", ) - serialization_settings.fast_serialization_settings = fast_serialization_settings - if not version: - version = remote._version_from_hash(md5_bytes, serialization_settings, service_account, raw_data_prefix) # noqa - serialization_settings.version = version - click.secho(f"Computed version is {version}", fg="yellow") - - registrable_entities = serialize_get_control_plane_entities( - serialization_settings, str(detected_root), options, is_registration=True - ) - click.secho( - f"Serializing and registering {len(registrable_entities)} flyte entities", - fg="green", - ) - FlyteContextManager.pop_context() - if len(registrable_entities) == 0: - click.secho("No Flyte entities were detected. Aborting!", fg="red") - return + FlyteContextManager.pop_context() + if len(registrable_entities) == 0: + click.secho("No Flyte entities were detected. Aborting!", fg="red") + return - def _raw_register(cp_entity: FlyteControlPlaneEntity): - is_lp = False - if isinstance(cp_entity, launch_plan.LaunchPlan): - og_id = cp_entity.id - is_lp = True - else: - og_id = cp_entity.template.id - try: - if not dry_run: - try: - i = remote.raw_register( - cp_entity, serialization_settings, version=version, create_default_launchplan=False - ) - console_url = remote.generate_console_url(i) - print_activation_message = False - if is_lp: - if activate_launchplans: - remote.activate_launchplan(i) - print_activation_message = True - if cp_entity.should_auto_activate: - print_activation_message = True - print_registration_status( - i, console_url=console_url, verbosity=verbosity, activation=print_activation_message - ) - - except Exception as e: - if not skip_errors: - raise e - print_registration_status(og_id, success=False) + def _raw_register(cp_entity: FlyteControlPlaneEntity): + is_lp = False + if isinstance(cp_entity, launch_plan.LaunchPlan): + og_id = cp_entity.id + is_lp = True else: - print_registration_status(og_id, dry_run=True) - except RegistrationSkipped: - print_registration_status(og_id, success=False) - - async def _register(entities: typing.List[task.TaskSpec]): - loop = asyncio.get_running_loop() - tasks = [] - for entity in entities: - tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity))) - await asyncio.gather(*tasks) - return - - # concurrent register - cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities)) - asyncio.run(_register(cp_task_entities)) - # serial register - cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), registrable_entities)) - for entity in cp_other_entities: - _raw_register(entity) - - click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green") + og_id = cp_entity.template.id + + result = { + "id": og_id.name, + "type": og_id.resource_type_name(), + "version": og_id.version, + "status": "skipped", # default status + } + + try: + if not dry_run: + try: + i = remote.raw_register( + cp_entity, serialization_settings, version=version, create_default_launchplan=False + ) + console_url = remote.generate_console_url(i) + print_activation_message = False + if is_lp: + if activate_launchplans: + remote.activate_launchplan(i) + print_activation_message = True + if cp_entity.should_auto_activate: + print_activation_message = True + if not quiet: + print_registration_status( + i, console_url=console_url, verbosity=verbosity, activation=print_activation_message + ) + result["status"] = "success" + + except Exception as e: + if not skip_errors: + raise e + if not quiet: + print_registration_status(og_id, success=False) + result["status"] = "failed" + else: + if not quiet: + print_registration_status(og_id, dry_run=True) + except RegistrationSkipped: + if not quiet: + print_registration_status(og_id, success=False) + result["status"] = "skipped" + + return result + + async def _register(entities: typing.List[task.TaskSpec]): + loop = asyncio.get_running_loop() + tasks = [] + for entity in entities: + tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity))) + results = await asyncio.gather(*tasks) + return results + + # concurrent register + cp_task_entities = list(filter(lambda x: isinstance(x, task.TaskSpec), registrable_entities)) + task_results = asyncio.run(_register(cp_task_entities)) + # serial register + cp_other_entities = list(filter(lambda x: not isinstance(x, task.TaskSpec), registrable_entities)) + other_results = [] + for entity in cp_other_entities: + other_results.append(_raw_register(entity)) + + all_results = task_results + other_results + + click.secho(f"Successfully registered {len(registrable_entities)} entities", fg="green") + + if summary_format is not None: + supported_format = {"json", "yaml"} + if summary_format not in supported_format: + raise ValueError(f"Unsupported file format: {summary_format}") + + with temporary_secho(): + if summary_format == "json": + click.secho(json.dumps(all_results, indent=2)) + elif summary_format == "yaml": + click.secho(yaml.dump(all_results)) + finally: + # Restore original secho + click.secho = _original_secho + logger.setLevel(_original_log_level) diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index ec14aa8227..35e8604d15 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -1,6 +1,8 @@ import os import shutil import subprocess +import json +import yaml import mock import pytest @@ -15,6 +17,7 @@ from flytekit.core import context_manager from flytekit.core.context_manager import FlyteContextManager from flytekit.remote.remote import FlyteRemote +from flytekit.loggers import logging sample_file_contents = """ from flytekit import task, workflow @@ -163,3 +166,108 @@ def test_non_fast_register_require_version(mock_client, mock_remote): result = runner.invoke(pyflyte.main, ["register", "--non-fast", "core3"]) assert result.exit_code == 1 shutil.rmtree("core3") + + +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_register_registrated_summary_json(mock_client, mock_remote): + ctx = FlyteContextManager.current_context() + mock_remote._client = mock_client + mock_remote.return_value.context = ctx + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core5", exist_ok=True) + with open(os.path.join("core5", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + + result = runner.invoke( + pyflyte.main, + ["register", "--summary-format", "json", "core5"] + ) + assert result.exit_code == 0 + try: + summary_data = json.loads(result.output) + except json.JSONDecodeError as e: + pytest.fail(f"Failed to parse registration summary JSON: {e}") + except Exception as e: + pytest.fail(f"Unexpected error while parsing registration summary: {e}") + assert isinstance(summary_data, list) + assert len(summary_data) > 0 + for entry in summary_data: + assert "id" in entry + assert "type" in entry + assert "version" in entry + assert "status" in entry + shutil.rmtree("core5") + +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_register_registrated_summary_yaml(mock_client, mock_remote): + ctx = FlyteContextManager.current_context() + mock_remote._client = mock_client + mock_remote.return_value.context = ctx + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core6", exist_ok=True) + with open(os.path.join("core6", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + + result = runner.invoke( + pyflyte.main, + ["register", "--summary-format", "yaml", "core6"] + ) + assert result.exit_code == 0 + try: + summary_data = yaml.safe_load(result.output) + except yaml.YAMLError as e: + pytest.fail(f"Failed to parse YAML output: {e}") + assert isinstance(summary_data, list) + assert len(summary_data) > 0 + for entry in summary_data: + assert "id" in entry + assert "type" in entry + assert "version" in entry + assert "status" in entry + + shutil.rmtree("core6") + + +@mock.patch("flytekit.configuration.plugin.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_register_quiet(mock_client, mock_remote): + ctx = FlyteContextManager.current_context() + mock_remote._client = mock_client + mock_remote.return_value.context = ctx + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core7", exist_ok=True) + with open(os.path.join("core7", "sample.py"), "w") as f: + f.write(sample_file_contents) + f.close() + result = runner.invoke( + pyflyte.main, + ["register", "--quiet", "core7"] + ) + assert result.exit_code == 0 + assert result.output == "" + + shutil.rmtree("core7")