From 198be4f680d9f786ce4d90af5eaf8a3419502549 Mon Sep 17 00:00:00 2001 From: Daniel Neilson Date: Fri, 1 Sep 2023 03:37:29 +0000 Subject: [PATCH] feat!: Import from internal repository This is an import of the source code that had been authored on internal repositories. --- .github/CODEOWNERS | 1 + .github/ISSUE_TEMPLATE/bug.yml | 33 + .github/ISSUE_TEMPLATE/config.yml | 1 + .github/ISSUE_TEMPLATE/doc.yml | 13 + .github/ISSUE_TEMPLATE/feature_request.yml | 17 + .github/ISSUE_TEMPLATE/maintenance.yml | 17 + .github/PULL_REQUEST_TEMPLATE.md | 15 + .github/dependabot.yml | 21 + .github/workflows/auto_approve.yml | 21 + .github/workflows/code_quality.yml | 17 + .github/workflows/release.yml | 51 ++ .github/workflows/reuse_python_build.yml | 62 ++ .gitignore | 24 + CONTRIBUTING.md | 9 +- DEVELOPMENT.md | 160 ++++ README.md | 11 +- THIRD-PARTY-LICENSES.txt | 47 ++ hatch.toml | 33 + hatch_version_hook.py | 156 ++++ pipeline/build.sh | 10 + pipeline/publish.sh | 6 + pyproject.toml | 128 +++ requirements-development.txt | 2 + requirements-testing.txt | 10 + scripts/add_copyright_headers.sh | 75 ++ src/openjd/adaptor_runtime/__init__.py | 7 + .../adaptor_runtime/_background/__init__.py | 13 + .../_background/backend_runner.py | 111 +++ .../_background/frontend_runner.py | 332 ++++++++ .../_background/http_server.py | 298 +++++++ .../_background/log_buffers.py | 174 ++++ .../adaptor_runtime/_background/model.py | 113 +++ src/openjd/adaptor_runtime/_entrypoint.py | 328 ++++++++ src/openjd/adaptor_runtime/_http/__init__.py | 6 + .../adaptor_runtime/_http/exceptions.py | 17 + .../adaptor_runtime/_http/request_handler.py | 227 +++++ src/openjd/adaptor_runtime/_http/sockets.py | 163 ++++ src/openjd/adaptor_runtime/_osname.py | 95 +++ src/openjd/adaptor_runtime/_utils/__init__.py | 7 + .../adaptor_runtime/_utils/_secure_open.py | 75 ++ .../adaptor_runtime/adaptors/__init__.py | 21 + .../adaptor_runtime/adaptors/_adaptor.py | 70 ++ .../adaptors/_adaptor_runner.py | 86 ++ 
.../adaptors/_adaptor_states.py | 61 ++ .../adaptor_runtime/adaptors/_base_adaptor.py | 231 ++++++ .../adaptors/_command_adaptor.py | 68 ++ .../adaptor_runtime/adaptors/_path_mapping.py | 139 ++++ .../adaptor_runtime/adaptors/_validator.py | 156 ++++ .../adaptors/configuration/__init__.py | 44 + .../_adaptor_configuration.schema.json | 10 + .../adaptors/configuration/_configuration.py | 206 +++++ .../configuration/_configuration_manager.py | 273 ++++++ .../adaptor_runtime/app_handlers/__init__.py | 5 + .../app_handlers/_regex_callback_handler.py | 111 +++ .../application_ipc/__init__.py | 6 + .../application_ipc/_actions_queue.py | 46 ++ .../application_ipc/_adaptor_server.py | 43 + .../application_ipc/_http_request_handler.py | 136 +++ src/openjd/adaptor_runtime/configuration.json | 4 + .../adaptor_runtime/configuration.schema.json | 14 + .../adaptor_runtime/process/__init__.py | 11 + .../adaptor_runtime/process/_logging.py | 12 + .../process/_logging_subprocess.py | 220 +++++ .../process/_managed_process.py | 71 ++ .../adaptor_runtime/process/_stream_logger.py | 62 ++ src/openjd/adaptor_runtime/py.typed | 1 + src/openjd/adaptor_runtime_client/__init__.py | 13 + src/openjd/adaptor_runtime_client/action.py | 50 ++ .../client_interface.py | 224 +++++ .../adaptor_runtime_client/connection.py | 73 ++ src/openjd/adaptor_runtime_client/py.typed | 1 + test/openjd/adaptor_runtime/__init__.py | 1 + test/openjd/adaptor_runtime/conftest.py | 27 + .../IntegCommandAdaptor.json | 1 + .../integ/IntegCommandAdaptor/__init__.py | 6 + .../integ/IntegCommandAdaptor/__main__.py | 31 + .../integ/IntegCommandAdaptor/adaptor.py | 40 + test/openjd/adaptor_runtime/integ/__init__.py | 1 + .../integ/adaptors/__init__.py | 1 + .../integ/adaptors/configuration/__init__.py | 1 + .../configuration/test_configuration.py | 68 ++ .../test_configuration_manager.py | 97 +++ .../adaptors/test_integration_adaptor.py | 104 +++ .../adaptors/test_integration_path_mapping.py | 350 ++++++++ 
.../integ/application_ipc/__init__.py | 1 + .../integ/application_ipc/fake_app_client.py | 24 + .../test_integration_adaptor_ipc.py | 185 +++++ .../integ/background/__init__.py | 1 + .../sample_adaptor/SampleAdaptor.json | 1 + .../background/sample_adaptor/__init__.py | 7 + .../background/sample_adaptor/__main__.py | 19 + .../background/sample_adaptor/adaptor.py | 28 + ...integration.background.sample_adaptor.json | 3 + .../integ/background/test_background_mode.py | 226 +++++ .../adaptor_runtime/integ/process/__init__.py | 1 + .../process/scripts/echo_sleep_n_times.sh | 10 + .../integ/process/scripts/no_sigterm.sh | 23 + .../integ/process/scripts/print_signals.sh | 17 + .../test_integration_logging_subprocess.py | 221 +++++ .../test_integration_managed_process.py | 88 ++ .../integ/test_integration_entrypoint.py | 178 ++++ .../openjd/adaptor_runtime/test_importable.py | 9 + test/openjd/adaptor_runtime/unit/__init__.py | 1 + .../adaptor_runtime/unit/adaptors/__init__.py | 1 + .../unit/adaptors/configuration/__init__.py | 1 + .../unit/adaptors/configuration/stubs.py | 72 ++ .../configuration/test_configuration.py | 284 +++++++ .../test_configuration_manager.py | 657 +++++++++++++++ .../unit/adaptors/fake_adaptor.py | 24 + .../unit/adaptors/test_adaptor.py | 76 ++ .../unit/adaptors/test_adaptor_runner.py | 181 ++++ .../unit/adaptors/test_base_adaptor.py | 383 +++++++++ .../unit/adaptors/test_basic_adaptor.py | 39 + .../unit/adaptors/test_path_mapping.py | 372 +++++++++ .../unit/application_ipc/__init__.py | 1 + .../application_ipc/test_actions_queue.py | 87 ++ .../test_adaptor_http_request_handler.py | 216 +++++ .../unit/background/__init__.py | 1 + .../unit/background/test_backend_runner.py | 174 ++++ .../unit/background/test_frontend_runner.py | 743 +++++++++++++++++ .../unit/background/test_http_server.py | 778 ++++++++++++++++++ .../unit/background/test_log_buffers.py | 181 ++++ .../unit/background/test_model.py | 51 ++ .../adaptor_runtime/unit/handlers/__init__.py 
| 1 + .../handlers/test_regex_callback_handler.py | 329 ++++++++ .../adaptor_runtime/unit/http/__init__.py | 1 + .../unit/http/test_request_handler.py | 242 ++++++ .../adaptor_runtime/unit/http/test_sockets.py | 265 ++++++ .../adaptor_runtime/unit/process/__init__.py | 1 + .../unit/process/test_logging_subprocess.py | 496 +++++++++++ .../unit/process/test_managed_process.py | 73 ++ .../unit/process/test_stream_logger.py | 132 +++ .../adaptor_runtime/unit/test_entrypoint.py | 620 ++++++++++++++ .../adaptor_runtime/unit/test_osname.py | 71 ++ .../adaptor_runtime/unit/utils/__init__.py | 1 + .../unit/utils/test_secure_open.py | 97 +++ .../openjd/adaptor_runtime_client/__init__.py | 1 + .../adaptor_runtime_client/integ/__init__.py | 1 + .../integ/fake_client.py | 36 + .../test_integration_client_interface.py | 35 + .../adaptor_runtime_client/test_importable.py | 9 + .../adaptor_runtime_client/unit/__init__.py | 1 + .../unit/test_action.py | 67 ++ .../unit/test_client_interface.py | 380 +++++++++ test/openjd/test_copyright_header.py | 79 ++ test/openjd/test_importable.py | 5 + 146 files changed, 14400 insertions(+), 10 deletions(-) create mode 100644 .github/CODEOWNERS create mode 100644 .github/ISSUE_TEMPLATE/bug.yml create mode 100644 .github/ISSUE_TEMPLATE/config.yml create mode 100644 .github/ISSUE_TEMPLATE/doc.yml create mode 100644 .github/ISSUE_TEMPLATE/feature_request.yml create mode 100644 .github/ISSUE_TEMPLATE/maintenance.yml create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .github/dependabot.yml create mode 100644 .github/workflows/auto_approve.yml create mode 100644 .github/workflows/code_quality.yml create mode 100644 .github/workflows/release.yml create mode 100644 .github/workflows/reuse_python_build.yml create mode 100644 .gitignore create mode 100644 DEVELOPMENT.md create mode 100644 THIRD-PARTY-LICENSES.txt create mode 100644 hatch.toml create mode 100644 hatch_version_hook.py create mode 100755 pipeline/build.sh create mode 
100755 pipeline/publish.sh create mode 100644 pyproject.toml create mode 100644 requirements-development.txt create mode 100644 requirements-testing.txt create mode 100755 scripts/add_copyright_headers.sh create mode 100644 src/openjd/adaptor_runtime/__init__.py create mode 100644 src/openjd/adaptor_runtime/_background/__init__.py create mode 100644 src/openjd/adaptor_runtime/_background/backend_runner.py create mode 100644 src/openjd/adaptor_runtime/_background/frontend_runner.py create mode 100644 src/openjd/adaptor_runtime/_background/http_server.py create mode 100644 src/openjd/adaptor_runtime/_background/log_buffers.py create mode 100644 src/openjd/adaptor_runtime/_background/model.py create mode 100644 src/openjd/adaptor_runtime/_entrypoint.py create mode 100644 src/openjd/adaptor_runtime/_http/__init__.py create mode 100644 src/openjd/adaptor_runtime/_http/exceptions.py create mode 100644 src/openjd/adaptor_runtime/_http/request_handler.py create mode 100644 src/openjd/adaptor_runtime/_http/sockets.py create mode 100644 src/openjd/adaptor_runtime/_osname.py create mode 100644 src/openjd/adaptor_runtime/_utils/__init__.py create mode 100644 src/openjd/adaptor_runtime/_utils/_secure_open.py create mode 100644 src/openjd/adaptor_runtime/adaptors/__init__.py create mode 100644 src/openjd/adaptor_runtime/adaptors/_adaptor.py create mode 100644 src/openjd/adaptor_runtime/adaptors/_adaptor_runner.py create mode 100644 src/openjd/adaptor_runtime/adaptors/_adaptor_states.py create mode 100644 src/openjd/adaptor_runtime/adaptors/_base_adaptor.py create mode 100644 src/openjd/adaptor_runtime/adaptors/_command_adaptor.py create mode 100644 src/openjd/adaptor_runtime/adaptors/_path_mapping.py create mode 100644 src/openjd/adaptor_runtime/adaptors/_validator.py create mode 100644 src/openjd/adaptor_runtime/adaptors/configuration/__init__.py create mode 100644 src/openjd/adaptor_runtime/adaptors/configuration/_adaptor_configuration.schema.json create mode 100644 
src/openjd/adaptor_runtime/adaptors/configuration/_configuration.py create mode 100644 src/openjd/adaptor_runtime/adaptors/configuration/_configuration_manager.py create mode 100644 src/openjd/adaptor_runtime/app_handlers/__init__.py create mode 100644 src/openjd/adaptor_runtime/app_handlers/_regex_callback_handler.py create mode 100644 src/openjd/adaptor_runtime/application_ipc/__init__.py create mode 100644 src/openjd/adaptor_runtime/application_ipc/_actions_queue.py create mode 100644 src/openjd/adaptor_runtime/application_ipc/_adaptor_server.py create mode 100644 src/openjd/adaptor_runtime/application_ipc/_http_request_handler.py create mode 100644 src/openjd/adaptor_runtime/configuration.json create mode 100644 src/openjd/adaptor_runtime/configuration.schema.json create mode 100644 src/openjd/adaptor_runtime/process/__init__.py create mode 100644 src/openjd/adaptor_runtime/process/_logging.py create mode 100644 src/openjd/adaptor_runtime/process/_logging_subprocess.py create mode 100644 src/openjd/adaptor_runtime/process/_managed_process.py create mode 100644 src/openjd/adaptor_runtime/process/_stream_logger.py create mode 100644 src/openjd/adaptor_runtime/py.typed create mode 100644 src/openjd/adaptor_runtime_client/__init__.py create mode 100644 src/openjd/adaptor_runtime_client/action.py create mode 100644 src/openjd/adaptor_runtime_client/client_interface.py create mode 100644 src/openjd/adaptor_runtime_client/connection.py create mode 100644 src/openjd/adaptor_runtime_client/py.typed create mode 100644 test/openjd/adaptor_runtime/__init__.py create mode 100644 test/openjd/adaptor_runtime/conftest.py create mode 100644 test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/IntegCommandAdaptor.json create mode 100644 test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/__init__.py create mode 100644 test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/__main__.py create mode 100644 test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/adaptor.py create 
mode 100644 test/openjd/adaptor_runtime/integ/__init__.py create mode 100644 test/openjd/adaptor_runtime/integ/adaptors/__init__.py create mode 100644 test/openjd/adaptor_runtime/integ/adaptors/configuration/__init__.py create mode 100644 test/openjd/adaptor_runtime/integ/adaptors/configuration/test_configuration.py create mode 100644 test/openjd/adaptor_runtime/integ/adaptors/configuration/test_configuration_manager.py create mode 100644 test/openjd/adaptor_runtime/integ/adaptors/test_integration_adaptor.py create mode 100644 test/openjd/adaptor_runtime/integ/adaptors/test_integration_path_mapping.py create mode 100644 test/openjd/adaptor_runtime/integ/application_ipc/__init__.py create mode 100644 test/openjd/adaptor_runtime/integ/application_ipc/fake_app_client.py create mode 100644 test/openjd/adaptor_runtime/integ/application_ipc/test_integration_adaptor_ipc.py create mode 100644 test/openjd/adaptor_runtime/integ/background/__init__.py create mode 100644 test/openjd/adaptor_runtime/integ/background/sample_adaptor/SampleAdaptor.json create mode 100644 test/openjd/adaptor_runtime/integ/background/sample_adaptor/__init__.py create mode 100644 test/openjd/adaptor_runtime/integ/background/sample_adaptor/__main__.py create mode 100644 test/openjd/adaptor_runtime/integ/background/sample_adaptor/adaptor.py create mode 100644 test/openjd/adaptor_runtime/integ/background/sample_adaptor/tests.integration.background.sample_adaptor.json create mode 100644 test/openjd/adaptor_runtime/integ/background/test_background_mode.py create mode 100644 test/openjd/adaptor_runtime/integ/process/__init__.py create mode 100755 test/openjd/adaptor_runtime/integ/process/scripts/echo_sleep_n_times.sh create mode 100755 test/openjd/adaptor_runtime/integ/process/scripts/no_sigterm.sh create mode 100755 test/openjd/adaptor_runtime/integ/process/scripts/print_signals.sh create mode 100644 test/openjd/adaptor_runtime/integ/process/test_integration_logging_subprocess.py create mode 100644 
test/openjd/adaptor_runtime/integ/process/test_integration_managed_process.py create mode 100644 test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py create mode 100644 test/openjd/adaptor_runtime/test_importable.py create mode 100644 test/openjd/adaptor_runtime/unit/__init__.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/__init__.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/configuration/__init__.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/configuration/stubs.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/configuration/test_configuration.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/configuration/test_configuration_manager.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/fake_adaptor.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/test_adaptor.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/test_adaptor_runner.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/test_base_adaptor.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/test_basic_adaptor.py create mode 100644 test/openjd/adaptor_runtime/unit/adaptors/test_path_mapping.py create mode 100644 test/openjd/adaptor_runtime/unit/application_ipc/__init__.py create mode 100644 test/openjd/adaptor_runtime/unit/application_ipc/test_actions_queue.py create mode 100644 test/openjd/adaptor_runtime/unit/application_ipc/test_adaptor_http_request_handler.py create mode 100644 test/openjd/adaptor_runtime/unit/background/__init__.py create mode 100644 test/openjd/adaptor_runtime/unit/background/test_backend_runner.py create mode 100644 test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py create mode 100644 test/openjd/adaptor_runtime/unit/background/test_http_server.py create mode 100644 test/openjd/adaptor_runtime/unit/background/test_log_buffers.py create mode 100644 test/openjd/adaptor_runtime/unit/background/test_model.py create 
mode 100644 test/openjd/adaptor_runtime/unit/handlers/__init__.py create mode 100644 test/openjd/adaptor_runtime/unit/handlers/test_regex_callback_handler.py create mode 100644 test/openjd/adaptor_runtime/unit/http/__init__.py create mode 100644 test/openjd/adaptor_runtime/unit/http/test_request_handler.py create mode 100644 test/openjd/adaptor_runtime/unit/http/test_sockets.py create mode 100644 test/openjd/adaptor_runtime/unit/process/__init__.py create mode 100644 test/openjd/adaptor_runtime/unit/process/test_logging_subprocess.py create mode 100644 test/openjd/adaptor_runtime/unit/process/test_managed_process.py create mode 100644 test/openjd/adaptor_runtime/unit/process/test_stream_logger.py create mode 100644 test/openjd/adaptor_runtime/unit/test_entrypoint.py create mode 100644 test/openjd/adaptor_runtime/unit/test_osname.py create mode 100644 test/openjd/adaptor_runtime/unit/utils/__init__.py create mode 100644 test/openjd/adaptor_runtime/unit/utils/test_secure_open.py create mode 100644 test/openjd/adaptor_runtime_client/__init__.py create mode 100644 test/openjd/adaptor_runtime_client/integ/__init__.py create mode 100644 test/openjd/adaptor_runtime_client/integ/fake_client.py create mode 100644 test/openjd/adaptor_runtime_client/integ/test_integration_client_interface.py create mode 100644 test/openjd/adaptor_runtime_client/test_importable.py create mode 100644 test/openjd/adaptor_runtime_client/unit/__init__.py create mode 100644 test/openjd/adaptor_runtime_client/unit/test_action.py create mode 100644 test/openjd/adaptor_runtime_client/unit/test_client_interface.py create mode 100644 test/openjd/test_copyright_header.py create mode 100644 test/openjd/test_importable.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..7472320 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @xxyggoqtpcmcofkc/Developers \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/bug.yml b/.github/ISSUE_TEMPLATE/bug.yml 
new file mode 100644 index 0000000..dd351d8 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug.yml @@ -0,0 +1,33 @@ +name: "\U0001F41B Bug Report" +description: Report a bug +title: "Bug: TITLE" +labels: ["bug"] +body: + - type: textarea + id: expected_behaviour + attributes: + label: Expected Behaviour + validations: + required: true + + - type: textarea + id: current_behaviour + attributes: + label: Current Behaviour + validations: + required: true + + - type: textarea + id: reproduction_steps + attributes: + label: Reproduction Steps + validations: + required: true + + - type: textarea + id: code_snippet + attributes: + label: Code Snippet + validations: + required: true + diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..ec4bb38 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: false \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/doc.yml b/.github/ISSUE_TEMPLATE/doc.yml new file mode 100644 index 0000000..883623b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/doc.yml @@ -0,0 +1,13 @@ + +name: "πŸ“• Documentation Issue" +description: Issue in the documentation +title: "Docs: TITLE" +labels: ["documentation"] +body: + - type: textarea + id: documentation_issue + attributes: + label: Documentation Issue + description: Describe the issue + validations: + required: true \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000..ed7f957 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,17 @@ +name: "\U0001F680 Feature Request" +description: Request a new feature +title: "Feature request: TITLE" +labels: ["feature"] +body: + - type: textarea + id: use_case + attributes: + label: Use Case + validations: + required: true + - type: textarea + id: proposed_solution + attributes: + label: Proposed Solution + validations: + 
required: true diff --git a/.github/ISSUE_TEMPLATE/maintenance.yml b/.github/ISSUE_TEMPLATE/maintenance.yml new file mode 100644 index 0000000..a0f98ff --- /dev/null +++ b/.github/ISSUE_TEMPLATE/maintenance.yml @@ -0,0 +1,17 @@ +name: "πŸ› οΈ Maintenance" +description: Some type of improvement +title: "Maintenance: TITLE" +labels: ["feature"] +body: + - type: textarea + id: description + attributes: + label: Description + validations: + required: true + - type: textarea + id: solution + attributes: + label: Solution + validations: + required: true diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..07e5be9 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,15 @@ +### What was the problem/requirement? (What/Why) + +### What was the solution? (How) + +### What is the impact of this change? + +### How was this change tested? + +### Was this change documented? + +### Is this a breaking change? + +---- + +*By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.* \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..ea85102 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,21 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. 
+# Please see the documentation for all configuration options: +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" + day: "monday" + commit-message: + prefix: "chore(deps):" + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "weekly" + day: "monday" + commit-message: + prefix: "chore(github):" \ No newline at end of file diff --git a/.github/workflows/auto_approve.yml b/.github/workflows/auto_approve.yml new file mode 100644 index 0000000..5fbbef7 --- /dev/null +++ b/.github/workflows/auto_approve.yml @@ -0,0 +1,21 @@ +name: Dependabot auto-approve +on: pull_request + +permissions: + pull-requests: write + +jobs: + dependabot: + runs-on: ubuntu-latest + if: ${{ github.actor == 'dependabot[bot]' }} + steps: + - name: Dependabot metadata + id: metadata + uses: dependabot/fetch-metadata@v1 + with: + github-token: "${{ secrets.GITHUB_TOKEN }}" + - name: Approve a PR + run: gh pr review --approve "$PR_URL" + env: + PR_URL: ${{ github.event.pull_request.html_url }} + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/.github/workflows/code_quality.yml b/.github/workflows/code_quality.yml new file mode 100644 index 0000000..c57fe58 --- /dev/null +++ b/.github/workflows/code_quality.yml @@ -0,0 +1,17 @@ +name: Code Quality + +on: + pull_request: + branches: [ mainline ] + workflow_call: + inputs: + branch: + required: false + type: string + +jobs: + TestPython: + name: Code Quality + uses: ./.github/workflows/reuse_python_build.yml + secrets: inherit + diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..d0a3365 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,51 @@ +name: Release + +on: + workflow_dispatch: + inputs: + 
version_to_publish: + description: "Version to be released" + required: false + +jobs: + TestMainline: + name: Test Mainline + uses: ./.github/workflows/code_quality.yml + with: + branch: mainline + secrets: inherit + + Merge: + needs: TestMainline + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v3 + with: + ref: release + fetch-depth: 0 + token: ${{ secrets.CI_TOKEN }} + - name: Set Git config + run: | + git config --local user.email "client-software-ci@amazon.com" + git config --local user.name "client-software-ci" + - name: Update Release + run: git merge --ff-only origin/mainline -v + - name: Push new release + if: ${{ inputs.version_to_publish}} + run: | + git tag -a ${{ inputs.version_to_publish }} -m "Release ${{ inputs.version_to_publish }}" + git push origin release ${{ inputs.version_to_publish }} + - name: Push post release + if: ${{ !inputs.version_to_publish}} + run: git push origin release + + TestRelease: + needs: Merge + name: Test Release + uses: ./.github/workflows/code_quality.yml + with: + branch: release + secrets: inherit + diff --git a/.github/workflows/reuse_python_build.yml b/.github/workflows/reuse_python_build.yml new file mode 100644 index 0000000..30696d2 --- /dev/null +++ b/.github/workflows/reuse_python_build.yml @@ -0,0 +1,62 @@ +name: Python Build + +on: + workflow_call: + inputs: + branch: + required: false + type: string + +jobs: + Python: + # We've seen some deadlocks in tests run on the CI. Hard-cap the job runtime + # to prevent those from running too long before being terminated. 
+ timeout-minutes: 15 + runs-on: ${{ matrix.os }} + permissions: + id-token: write + contents: read + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11'] + os: [ubuntu-latest, macOS-latest] + env: + PYTHON: ${{ matrix.python-version }} + CODEARTIFACT_REGION: "us-west-2" + CODEARTIFACT_DOMAIN: ${{ secrets.CODEARTIFACT_DOMAIN }} + CODEARTIFACT_ACCOUNT_ID: ${{ secrets.CODEARTIFACT_ACCOUNT_ID }} + CODEARTIFACT_REPOSITORY: ${{ secrets.CODEARTIFACT_REPOSITORY }} + steps: + - uses: actions/checkout@v3 + if: ${{ !inputs.branch }} + + - uses: actions/checkout@v3 + if: ${{ inputs.branch }} + with: + ref: ${{ inputs.branch }} + fetch-depth: 0 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Configure AWS credentials + uses: aws-actions/configure-aws-credentials@v3 + with: + role-to-assume: ${{ secrets.AWS_CODEARTIFACT_ROLE }} + aws-region: us-west-2 + mask-aws-account-id: true + + - name: Install Hatch + run: | + CODEARTIFACT_AUTH_TOKEN=$(aws codeartifact get-authorization-token --domain ${{ secrets.CODEARTIFACT_DOMAIN }} --domain-owner ${{ secrets.CODEARTIFACT_ACCOUNT_ID }} --query authorizationToken --output text --region us-west-2) + echo "::add-mask::$CODEARTIFACT_AUTH_TOKEN" + echo CODEARTIFACT_AUTH_TOKEN=$CODEARTIFACT_AUTH_TOKEN >> $GITHUB_ENV + pip install --upgrade -r requirements-development.txt + + - name: Run Linting + run: hatch run lint + + - name: Run Tests + run: hatch run test \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f30fb6f --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +*~ +*# +*.swp + +*.DS_Store + +__pycache__/ +*.py[cod] +*$py.class +*.egg-info/ + +/.coverage +/.coverage.* +/.cache +/.pytest_cache +/.mypy_cache +/.ruff_cache +/.attach_pid* +/.venv + +/doc/_apidoc/ +/build +/dist +_version.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c4b6a1c..0b3cabf 100644 --- 
a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,7 +6,6 @@ documentation, we greatly value feedback and contributions from our community. Please read through this document before submitting any issues or pull requests to ensure we have all the necessary information to effectively respond to your bug report or contribution. - ## Reporting Bugs/Feature Requests We welcome you to use the GitHub issue tracker to report bugs or suggest features. @@ -19,8 +18,12 @@ reported the issue. Please try to include as much information as you can. Detail * Any modifications you've made relevant to the bug * Anything unusual about your environment or deployment +## Development + +Please see [DEVELOPMENT.md](./DEVELOPMENT.md) for more information. ## Contributing via Pull Requests + Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 1. You are working against the latest source on the *main* branch. @@ -39,18 +42,20 @@ To send us a pull request, please: GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). - ## Finding contributions to work on + Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. ## Code of Conduct + This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact opensource-codeofconduct@amazon.com with any additional questions or comments. 
## Security issue notifications + If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md new file mode 100644 index 0000000..2bd0787 --- /dev/null +++ b/DEVELOPMENT.md @@ -0,0 +1,160 @@ + +# Development + +## Command Reference + +``` +# Build the package +hatch build + +# Run tests +hatch run test + +# Run linting +hatch run lint + +# Run formatting +hatch run fmt + +# Run a full test +hatch run all:test +``` + +## The Package's Public Interface + +This package is a library wherein we are explicit and intentional with what we expose as public. + +The standard convention in Python is to prefix things with an underscore character ('_') to +signify that the thing is private to the implementation, and is not intended to be used by +external consumers of the thing. + +We use this convention in this package in two ways: + +1. In filenames. + 1. Any file whose name is not prefixed with an underscore **is** a part of the public + interface of this package. The name may not change and public symbols (classes, modules, + functions, etc) defined in the file may not be moved to other files or renamed without a + major version number change. + 2. Any file whose name is prefixed with an underscore is an internal module of the package + and is not part of the public interface. These files can be renamed, refactored, have symbols + renamed, etc. Any symbol defined in one of these files that is intended to be part of this + package's public interface must be imported into an appropriate `__init__.py` file. +2. Every symbol that is defined or imported in a public module and is not intended to be part + of the module's public interface is prefixed with an underscore. 
+ +For example, a public module in this package will be defined with the following style: + +```python +# The os module is not part of this file's external interface +import os as _os + +# PublicClass is part of this file's external interface. +class PublicClass: + def publicmethod(self): + pass + + def _privatemethod(self): + pass + +# _PrivateClass is not part of this file's external interface. +class _PrivateClass: + def publicmethod(self): + pass + + def _privatemethod(self): + pass +``` + +### On `import os as _os` + +Every module/symbol that is imported into a Python module becomes a part of that module's interface. +Thus, if we have a module called `foo.py` such as: + +```python +# foo.py + +import os +``` + +Then, the `os` module becomes part of the public interface for `foo.py` and a consumer of that module +is free to do: + +```python +from foo import os +``` + +We don't want all (generally, we don't want any) of our imports to become part of the public API for +the module, so we import modules/symbols into a public module with the following style: + +```python +import os as _os +from typing import Dict as _Dict +``` + +## Use of Keyword-Only Arguments + +Another convention that we are adopting in this package is that all functions/methods that are a +part of the package's external interface should refrain from using positional-or-keyword arguments. +All arguments should be keyword-only unless the argument name has no true external meaning (e.g. +arg1, arg2, etc for `min`). Benefits of this convention are: + +1. All uses of the public APIs of this package are forced to be self-documenting; and +2. The benefits set forth in PEP 570 ( https://www.python.org/dev/peps/pep-0570/#problems-without-positional-only-parameters ). + +## Exceptions + +All functions/methods that raise an exception should have a section in their docstring that states +the exception(s) they raise. e.g. + +```py +def my_function(key, value): +"""Does something... 
+ + Raises: + KeyError: when the key is not valid + ValueError: when the value is not valid +""" +``` + +All function/method calls that can raise an exception should have a comment in the line above +that states which exception(s) can be raised. e.g. + +```py +try: + # Raises: KeyError, ValueError + my_function("key", "value") +except ValueError as e: + # Error handling... +``` + +## About the data model + +1. The data model is written using Pydantic. Pydantic provides the framework for parsing and validating + input job templates. +1. We intentionally use `Decimal` in place of `float` in our data models as `Decimal` will preserve the + precision present in the input whereas `float` will not. +1. Some classes in the model have "Definition" or "Template" version, as well as a "Target model" version. + These exist for parts of the model where instantiating a JobTemplate into a Job changes the model in + some way. These are not necessary for all parts of the model. + +## Super verbose test output + +If you find that you need much more information from a failing test (say you're debugging a +deadlocking test) then a way to get verbose output from the test is to enable Pytest +[Live Logging](https://docs.pytest.org/en/latest/how-to/logging.html#live-logs): + +1. Add a `pytest.ini` to the root directory of the repository that contains (Note: for some reason, +setting `log_cli` and `log_cli_level` in `pyproject.toml` does not work, nor does setting the options +on the command-line; if you figure out how to get it to work then please update this section): +``` +[pytest] +xfail_strict = False +log_cli = true +log_cli_level = 10 +``` +2. Modify `pyproject.toml` to set the following additional `addopts` in the `tool.pytest.ini_options` section: +``` + "-vvvvv", + "--numprocesses=1" +``` +3. Add logging statements to your tests as desired and run the test(s) that you are debugging. 
diff --git a/README.md b/README.md index 847260c..c13eb32 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,7 @@ -## My Project +# Open Job Description - Adaptor Runtime -TODO: Fill this README out! - -Be sure to: - -* Change the title in this README -* Edit your repository description on GitHub +This package provides a runtime library that can be used to implement a CLI adaptor interface around +a desired application. ## Security @@ -14,4 +10,3 @@ See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more inform ## License This project is licensed under the Apache-2.0 License. - diff --git a/THIRD-PARTY-LICENSES.txt b/THIRD-PARTY-LICENSES.txt new file mode 100644 index 0000000..2705add --- /dev/null +++ b/THIRD-PARTY-LICENSES.txt @@ -0,0 +1,47 @@ +** Python-jsonschema; version 4.19 -- https://github.com/python-jsonschema/jsonschema +Copyright (c) 2013 Julian Berman + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+ +------ + +** PyYAML; version 6.0 -- https://pyyaml.org/ +Copyright (c) 2017-2021 Ingy dΓΆt Net +Copyright (c) 2006-2016 Kirill Simonov + +Copyright (c) 2017-2021 Ingy dΓΆt Net +Copyright (c) 2006-2016 Kirill Simonov + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file diff --git a/hatch.toml b/hatch.toml new file mode 100644 index 0000000..e8d1365 --- /dev/null +++ b/hatch.toml @@ -0,0 +1,33 @@ +[envs.default] +pre-install-commands = [ + "pip install -r requirements-testing.txt" +] + +[envs.default.scripts] +sync = "pip install -r requirements-testing.txt" +test = "pytest --cov-config pyproject.toml {args:test}" +typing = "mypy {args:src test}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "style", +] +lint = [ + "style", + "typing", +] + +[[envs.all.matrix]] +python = ["3.9", "3.10", "3.11"] + +[envs.default.env-vars] +PIP_INDEX_URL="https://aws:{env:CODEARTIFACT_AUTH_TOKEN}@{env:CODEARTIFACT_DOMAIN}-{env:CODEARTIFACT_ACCOUNT_ID}.d.codeartifact.{env:CODEARTIFACT_REGION}.amazonaws.com/pypi/{env:CODEARTIFACT_REPOSITORY}/simple/" + +[envs.codebuild.scripts] +build = "hatch build" + +[envs.codebuild.env-vars] +PIP_INDEX_URL="" diff --git a/hatch_version_hook.py b/hatch_version_hook.py new file mode 100644 index 0000000..d7024ea --- /dev/null +++ b/hatch_version_hook.py @@ -0,0 +1,156 @@ +import logging +import os +import shutil +import sys + +from dataclasses import dataclass +from hatchling.builders.hooks.plugin.interface import BuildHookInterface +from typing import Any, Optional + + +_logger = logging.Logger(__name__, logging.INFO) +_stdout_handler = logging.StreamHandler(sys.stdout) +_stdout_handler.addFilter(lambda record: record.levelno <= logging.INFO) +_stderr_handler = logging.StreamHandler(sys.stderr) +_stderr_handler.addFilter(lambda record: record.levelno > logging.INFO) +_logger.addHandler(_stdout_handler) +_logger.addHandler(_stderr_handler) + + +@dataclass +class CopyConfig: + sources: list[str] + destinations: list[str] + + +class CustomBuildHookException(Exception): + pass + + +class CustomBuildHook(BuildHookInterface): + """ + A Hatch build hook that is pulled in automatically by Hatch's "custom" hook support + See: 
https://hatch.pypa.io/1.6/plugins/build-hook/custom/ + This build hook copies files from one location (sources) to another (destinations). + Config options: + - `log_level (str)`: The logging level. Any value accepted by logging.Logger.setLevel is allowed. Default is INFO. + - `copy_map (list[dict])`: A list of mappings of files to copy and the destinations to copy them into. In TOML files, + this is expressed as an array of tables. See https://toml.io/en/v1.0.0#array-of-tables + Example TOML config: + ``` + [tool.hatch.build.hooks.custom] + path = "hatch_hook.py" + log_level = "DEBUG" + [[tool.hatch.build.hooks.custom.copy_map]] + sources = [ + "_version.py", + ] + destinations = [ + "src/openjd", + ] + [[tool.hatch.build.hooks.custom.copy_map]] + sources = [ + "something_the_tests_need.py", + "something_else_the_tests_need.ini", + ] + destinations = [ + "test/openjd", + ] + ``` + """ + + REQUIRED_OPTS = [ + "copy_map", + ] + + def initialize(self, version: str, build_data: dict[str, Any]) -> None: + if not self._prepare(): + return + + for copy_cfg in self.copy_map: + _logger.info(f"Copying {copy_cfg.sources} to {copy_cfg.destinations}") + for destination in copy_cfg.destinations: + for source in copy_cfg.sources: + copy_func = shutil.copy if os.path.isfile(source) else shutil.copytree + copy_func( + os.path.join(self.root, source), + os.path.join(self.root, destination), + ) + _logger.info("Copy complete") + + def clean(self, versions: list[str]) -> None: + if not self._prepare(): + return + + for copy_cfg in self.copy_map: + _logger.info(f"Cleaning {copy_cfg.sources} from {copy_cfg.destinations}") + cleaned_count = 0 + for destination in copy_cfg.destinations: + for source in copy_cfg.sources: + source_path = os.path.join(self.root, destination, source) + remove_func = os.remove if os.path.isfile(source_path) else os.rmdir + try: + remove_func(source_path) + except FileNotFoundError: + _logger.debug(f"Skipping {source_path} because it does not exist...") + 
else: + cleaned_count += 1 + _logger.info(f"Cleaned {cleaned_count} items") + + def _prepare(self) -> bool: + missing_required_opts = [ + opt for opt in self.REQUIRED_OPTS if opt not in self.config or not self.config[opt] + ] + if missing_required_opts: + _logger.warn( + f"Required options {missing_required_opts} are missing or empty. " + "Contining without copying sources to destinations...", + file=sys.stderr, + ) + return False + + log_level = self.config.get("log_level") + if log_level: + _logger.setLevel(log_level) + + return True + + @property + def copy_map(self) -> Optional[list[CopyConfig]]: + raw_copy_map: list[dict] = self.config.get("copy_map") + if not raw_copy_map: + return None + + if not ( + isinstance(raw_copy_map, list) + and all(isinstance(copy_cfg, dict) for copy_cfg in raw_copy_map) + ): + raise CustomBuildHookException( + f'"copy_map" config option is a nonvalid type. Expected list[dict], but got {raw_copy_map}' + ) + + def verify_list_of_file_paths(file_paths: Any, config_name: str): + if not (isinstance(file_paths, list) and all(isinstance(fp, str) for fp in file_paths)): + raise CustomBuildHookException( + f'"{config_name}" config option is a nonvalid type. 
Expected list[str], but got {file_paths}' + ) + + missing_paths = [ + fp for fp in file_paths if not os.path.exists(os.path.join(self.root, fp)) + ] + if len(missing_paths) > 0: + raise CustomBuildHookException( + f'"{config_name}" config option contains some file paths that do not exist: {missing_paths}' + ) + + copy_map: list[CopyConfig] = [] + for copy_cfg in raw_copy_map: + destinations: list[str] = copy_cfg.get("destinations") + verify_list_of_file_paths(destinations, "destinations") + + sources: list[str] = copy_cfg.get("sources") + verify_list_of_file_paths(sources, "source") + + copy_map.append(CopyConfig(sources, destinations)) + + return copy_map diff --git a/pipeline/build.sh b/pipeline/build.sh new file mode 100755 index 0000000..e8014f3 --- /dev/null +++ b/pipeline/build.sh @@ -0,0 +1,10 @@ +#!/bin/sh +# Set the -e option +set -e + +pip install --upgrade pip +pip install --upgrade hatch +pip install --upgrade twine +hatch run codebuild:lint +hatch run codebuild:test +hatch run codebuild:build \ No newline at end of file diff --git a/pipeline/publish.sh b/pipeline/publish.sh new file mode 100755 index 0000000..9e9c1b5 --- /dev/null +++ b/pipeline/publish.sh @@ -0,0 +1,6 @@ +#!/bin/sh +# Set the -e option +set -e + +./pipeline/build.sh +twine upload --repository codeartifact dist/* --verbose \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..278842c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,128 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "openjd-adaptor-runtime" +dynamic = ["version"] +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.9" + +dependencies = [ + "pyyaml ~= 6.0", + "jsonschema >= 4.19.0, == 4.*", +] + +[tool.hatch.build] +artifacts = [ + "*_version.py" +] +only-pacakges = true + +[tool.hatch.version] +source = "vcs" + +[tool.hatch.version.raw-options] +version_scheme = "post-release" + 
+[tool.hatch.build.hooks.vcs] +version-file = "_version.py" + +[tool.hatch.build.hooks.custom] +path = "hatch_version_hook.py" + +[[tool.hatch.build.hooks.custom.copy_map]] +sources = [ + "_version.py", +] +destinations = [ + "src/openjd/adaptor_runtime", + "src/openjd/adaptor_runtime_client", +] + +[tool.hatch.build.targets.sdist] +packages = [ + "src/openjd", +] +only-include = [ + "src/openjd", +] + +[tool.hatch.build.targets.wheel] +packages = [ + "src/openjd", +] +only-include = [ + "src/openjd", +] + +[tool.mypy] +check_untyped_defs = false +show_error_codes = false +pretty = true +ignore_missing_imports = true +disallow_incomplete_defs = false +disallow_untyped_calls = false +show_error_context = true +strict_equality = false +python_version = 3.9 +warn_redundant_casts = true +warn_unused_configs = true +warn_unused_ignores = false +# Tell mypy that there's a namespace package at src/openjd +namespace_packages = true +explicit_package_bases = true +mypy_path = "src" + +[tool.ruff] +ignore = [ + "E501", + # Double Check if this should be fixed + "E731", +] +line-length = 100 + + +[tool.ruff.pep8-naming] +classmethod-decorators = [ + "classmethod", +] + +[tool.ruff.isort] +known-first-party = [ + "openjd", +] + +[tool.black] +line-length = 100 + +[tool.pytest.ini_options] +xfail_strict = false +addopts = [ + "-rfEx", + "--durations=5", + "--cov=src/openjd/adaptor_runtime", + "--cov=src/openjd/adaptor_runtime_client", + "--color=yes", + "--cov-report=html:build/coverage", + "--cov-report=xml:build/coverage/coverage.xml", + "--cov-report=term-missing", + "--numprocesses=auto", + "--timeout=30" +] + + +[tool.coverage.run] +branch = true +parallel = true + + +[tool.coverage.paths] +source = [ + "src/" +] + +[tool.coverage.report] +show_missing = true +fail_under = 94 diff --git a/requirements-development.txt b/requirements-development.txt new file mode 100644 index 0000000..c9ba6c2 --- /dev/null +++ b/requirements-development.txt @@ -0,0 +1,2 @@ +hatch >= 1.7.0, 
== 1.* +hatch-vcs ~= 0.3.0 \ No newline at end of file diff --git a/requirements-testing.txt b/requirements-testing.txt new file mode 100644 index 0000000..916da56 --- /dev/null +++ b/requirements-testing.txt @@ -0,0 +1,10 @@ +coverage[toml] >= 7.3, == 7.* +pytest >= 7.4, == 7.* +pytest-cov >= 4.1, == 4.* +pytest-timeout >= 2.1, == 2.* +pytest-xdist >= 3.3.1, == 3.* +black >= 23.7, == 23.* +ruff >= 0.0.286, == 0.0.* +mypy >= 1.5.1, == 1.5.* +types-PyYAML ~= 6.0 +psutil ~= 5.9.5 diff --git a/scripts/add_copyright_headers.sh b/scripts/add_copyright_headers.sh new file mode 100755 index 0000000..7992a66 --- /dev/null +++ b/scripts/add_copyright_headers.sh @@ -0,0 +1,75 @@ +#!/bin/bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +if [ $# -eq 0 ]; then + echo "Usage: add-copyright-headers ..." >&2 + exit 1 +fi + +for file in "$@"; do + if ! head -1 | grep 'Copyright ' "$file" >/dev/null; then + case "$file" in + *.java) + CONTENT=$(cat "$file") + cat > "$file" </dev/null; then + CONTENT=$(tail -n +2 "$file") + cat > "$file" < +$CONTENT +EOF + else + CONTENT=$(cat "$file") + cat > "$file" < +$CONTENT +EOF + fi + ;; + *.py) + CONTENT=$(cat "$file") + cat > "$file" < "$file" < "$file" < "$file" <&2 + exit 1 + ;; + esac + fi +done \ No newline at end of file diff --git a/src/openjd/adaptor_runtime/__init__.py b/src/openjd/adaptor_runtime/__init__.py new file mode 100644 index 0000000..e945b9a --- /dev/null +++ b/src/openjd/adaptor_runtime/__init__.py @@ -0,0 +1,7 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from ._entrypoint import EntryPoint + +__all__ = [ + "EntryPoint", +] diff --git a/src/openjd/adaptor_runtime/_background/__init__.py b/src/openjd/adaptor_runtime/_background/__init__.py new file mode 100644 index 0000000..f0d8ebe --- /dev/null +++ b/src/openjd/adaptor_runtime/_background/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from .backend_runner import BackendRunner +from .frontend_runner import FrontendRunner +from .log_buffers import InMemoryLogBuffer, FileLogBuffer, LogBufferHandler + +__all__ = [ + "BackendRunner", + "FrontendRunner", + "InMemoryLogBuffer", + "FileLogBuffer", + "LogBufferHandler", +] diff --git a/src/openjd/adaptor_runtime/_background/backend_runner.py b/src/openjd/adaptor_runtime/_background/backend_runner.py new file mode 100644 index 0000000..9fd837a --- /dev/null +++ b/src/openjd/adaptor_runtime/_background/backend_runner.py @@ -0,0 +1,111 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import json +import logging +import os +import signal +from queue import Queue +from threading import Thread +from types import FrameType +from typing import Optional + +from ..adaptors import AdaptorRunner +from .._http import SocketDirectories +from .._utils import secure_open +from .http_server import BackgroundHTTPServer +from .log_buffers import LogBuffer +from .model import ConnectionSettings +from .model import DataclassJSONEncoder + +_logger = logging.getLogger(__name__) + + +class BackendRunner: + """ + Class that runs the backend logic in background mode. + """ + + def __init__( + self, + adaptor_runner: AdaptorRunner, + connection_file_path: str, + *, + log_buffer: LogBuffer | None = None, + ) -> None: + self._adaptor_runner = adaptor_runner + self._connection_file_path = connection_file_path + self._log_buffer = log_buffer + self._http_server: Optional[BackgroundHTTPServer] = None + signal.signal(signal.SIGINT, self._sigint_handler) + signal.signal(signal.SIGTERM, self._sigint_handler) + + def _sigint_handler(self, signum: int, frame: Optional[FrameType]) -> None: + """Signal handler that is invoked when the process receives a SIGINT/SIGTERM""" + _logger.info("Interruption signal recieved.") + # OpenJD dictates that a SIGTERM/SIGINT results in a cancel workflow being + # kicked off. 
+ if self._http_server is not None: + self._http_server.submit(self._adaptor_runner._cancel, force_immediate=True) + + def run(self) -> None: + """ + Runs the backend logic for background mode. + + This function will start an HTTP server that picks an available port to listen on, write + that port to a connection file, and listens for HTTP requests until a shutdown is requested + """ + _logger.info("Running in background daemon mode.") + + queue: Queue = Queue() + + socket_path = SocketDirectories.for_os().get_process_socket_path("runtime", create_dir=True) + + try: + self._http_server = BackgroundHTTPServer( + socket_path, + self._adaptor_runner, + cancel_queue=queue, + log_buffer=self._log_buffer, + ) + except Exception as e: + _logger.error(f"Error starting in background mode: {e}") + raise + + _logger.debug(f"Listening on {socket_path}") + http_thread = Thread( + name="AdaptorRuntimeBackendHttpThread", target=self._http_server.serve_forever + ) + http_thread.start() + + try: + with secure_open(self._connection_file_path, open_mode="w") as conn_file: + json.dump( + ConnectionSettings(socket_path), + conn_file, + cls=DataclassJSONEncoder, + ) + except OSError as e: + _logger.error(f"Error writing to connection file: {e}") + _logger.info("Shutting down server...") + queue.put(True) + raise + finally: + # Block until the cancel queue has been pushed to + queue.get() + + # Shutdown the server + self._http_server.shutdown() + http_thread.join() + + # Cleanup the connection file and socket + for path in [self._connection_file_path, socket_path]: + try: + os.remove(path) + except FileNotFoundError: # pragma: no cover + pass # File is already cleaned up + except OSError as e: # pragma: no cover + _logger.warning(f"Failed to delete {path}: {e}") + + _logger.info("HTTP server has shutdown.") diff --git a/src/openjd/adaptor_runtime/_background/frontend_runner.py b/src/openjd/adaptor_runtime/_background/frontend_runner.py new file mode 100644 index 0000000..820a197 --- 
/dev/null +++ b/src/openjd/adaptor_runtime/_background/frontend_runner.py @@ -0,0 +1,332 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import http.client as http_client +import json +import logging +import os +import signal +import socket +import subprocess +import sys +import time +import urllib.parse as urllib_parse +from threading import Event +from types import FrameType +from types import ModuleType +from typing import Optional + +from ..process._logging import _ADAPTOR_OUTPUT_LEVEL +from .model import ( + AdaptorState, + AdaptorStatus, + BufferedOutput, + ConnectionSettings, + DataclassJSONEncoder, + DataclassMapper, + HeartbeatResponse, +) + +_logger = logging.getLogger(__name__) + + +class FrontendRunner: + """ + Class that runs the frontend logic in background mode. + """ + + def __init__( + self, + connection_file_path: str, + *, + timeout_s: float = 5.0, + heartbeat_interval: float = 1.0, + ) -> None: + """ + Args: + connection_file_path (str): Absolute path to the connection file. + timeout_s (float, optional): Timeout for HTTP requests, in seconds. Defaults to 5. + heartbeat_interval (float, optional): Interval between heartbeats, in seconds. + Defaults to 1. + """ + self._timeout_s = timeout_s + self._heartbeat_interval = heartbeat_interval + self._connection_file_path = connection_file_path + self._canceled = Event() + signal.signal(signal.SIGINT, self._sigint_handler) + signal.signal(signal.SIGTERM, self._sigint_handler) + + def init(self, adaptor_module: ModuleType, init_data: dict = {}) -> None: + """ + Creates the backend process then sends a heartbeat request to verify that it has started + successfully. + + Args: + adaptor_module (ModuleType): The module of the adaptor running the runtime. 
+ """ + if adaptor_module.__package__ is None: + raise Exception(f"Adaptor module is not a package: {adaptor_module}") + + if os.path.exists(self._connection_file_path): + raise FileExistsError( + "Cannot init a new backend process with an existing connection file at: " + + self._connection_file_path + ) + + _logger.info("Initializing backend process...") + args = [ + sys.executable, + "-m", + adaptor_module.__package__, + "daemon", + "_serve", + "--connection-file", + self._connection_file_path, + "--init-data", + json.dumps(init_data), + ] + try: + process = subprocess.Popen( + args, + shell=False, + close_fds=True, + start_new_session=True, + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except Exception as e: + _logger.error(f"Failed to initialize backend process: {e}") + raise + _logger.info(f"Started backend process. PID: {process.pid}") + + # Wait for backend process to create connection file + try: + _wait_for_file(self._connection_file_path, timeout_s=5) + except TimeoutError: + _logger.error( + "Backend process failed to write connection file in time at: " + + self._connection_file_path + ) + raise + + # Heartbeat to ensure backend process is listening for requests + _logger.info("Verifying connection to backend...") + self._heartbeat() + _logger.info("Connected successfully") + + def run(self, run_data: dict) -> None: + """ + Sends a run request to the backend + """ + self._send_request("PUT", "/run", json_body=run_data) + self._heartbeat_until_state_complete(AdaptorState.RUN) + + def start(self) -> None: + """ + Sends a start request to the backend + """ + self._send_request("PUT", "/start") + self._heartbeat_until_state_complete(AdaptorState.START) + + def stop(self) -> None: + """ + Sends an end request to the backend + """ + self._send_request("PUT", "/stop") + # The backend calls end then cleanup on the adaptor, so we wait until cleanup is complete. 
+ self._heartbeat_until_state_complete(AdaptorState.CLEANUP) + + def shutdown(self) -> None: + """ + Sends a shutdown request to the backend + """ + self._send_request("PUT", "/shutdown") + + def cancel(self) -> None: + """ + Sends a cancel request to the backend + """ + self._send_request("PUT", "/cancel") + self._canceled.set() + + def _heartbeat(self, ack_id: str | None = None) -> HeartbeatResponse: + """ + Sends a heartbeat request to the backend. + + Args: + ack_id (str): The heartbeat output ID to ACK. Defaults to None. + """ + params: dict[str, str] | None = {"ack_id": ack_id} if ack_id else None + response = self._send_request("GET", "/heartbeat", params=params) + + return DataclassMapper(HeartbeatResponse).map(json.load(response.fp)) + + def _heartbeat_until_state_complete(self, state: AdaptorState) -> None: + """ + Heartbeats with the backend until it transitions to the specified state and is idle. + + Args: + state (AdaptorState): The final state the adaptor should be in. + + Raises: + AdaptorFailedException: Raised when the adaptor reports a failure. + """ + failure_message = None + ack_id = None + while True: + _logger.debug("Sending heartbeat request...") + heartbeat = self._heartbeat(ack_id) + _logger.debug(f"Heartbeat response: {json.dumps(heartbeat, cls=DataclassJSONEncoder)}") + for line in heartbeat.output.output.splitlines(): + _logger.log(_ADAPTOR_OUTPUT_LEVEL, line) + + if heartbeat.failed: + failure_message = heartbeat.output.output + + ack_id = heartbeat.output.id + if ( + heartbeat.state in [state, AdaptorState.CANCELED] + and heartbeat.status == AdaptorStatus.IDLE + ): + break + else: + if not self._canceled.is_set(): + self._canceled.wait(timeout=self._heartbeat_interval) + else: + # We've been canceled. Do a small sleep to give it time to take effect. 
+ time.sleep(0.25) + + # Send one last heartbeat to ACK the previous heartbeat output if any + if ack_id != BufferedOutput.EMPTY: # pragma: no branch + _logger.debug("ACKing last heartbeat...") + heartbeat = self._heartbeat(ack_id) + + # Raise a failure exception if the adaptor failed + if failure_message: + raise AdaptorFailedException(failure_message) + + def _send_request( + self, + method: str, + path: str, + *, + params: dict | None = None, + json_body: dict | None = None, + ) -> http_client.HTTPResponse: + conn = UnixHTTPConnection(self.connection_settings.socket, timeout=self._timeout_s) + + if params: + query_str = urllib_parse.urlencode(params, doseq=True) + path = f"{path}?{query_str}" + + body = json.dumps(json_body) if json_body else None + + conn.request(method, path, body=body) + try: + response = conn.getresponse() + except http_client.HTTPException as e: + _logger.error(f"Failed to send {path} request: {e}") + raise + finally: + conn.close() + + if response.status >= 400 and response.status < 600: + errmsg = f"Received unexpected HTTP status code {response.status}: {response.reason}" + _logger.error(errmsg) + raise HTTPError(response, errmsg) + + return response + + @property + def connection_settings(self) -> ConnectionSettings: + """ + Gets the lazy-loaded connection settings. + """ + if not hasattr(self, "_connection_settings"): + self._connection_settings = _load_connection_settings(self._connection_file_path) + return self._connection_settings + + def _sigint_handler(self, signum: int, frame: Optional[FrameType]) -> None: + """Signal handler that is invoked when the process receives a SIGINT/SIGTERM""" + _logger.info("Interruption signal recieved.") + # OpenJD dictates that a SIGTERM/SIGINT results in a cancel workflow being + # kicked off. 
+ self.cancel() + + +def _load_connection_settings(path: str) -> ConnectionSettings: + try: + with open(path) as conn_file: + loaded_settings = json.load(conn_file) + except OSError as e: + _logger.error(f"Failed to open connection file: {e}") + raise + except json.JSONDecodeError as e: + _logger.error(f"Failed to decode connection file: {e}") + raise + return DataclassMapper(ConnectionSettings).map(loaded_settings) + + +def _wait_for_file(filepath: str, timeout_s: float, interval_s: float = 1) -> None: + """ + Waits for a file at the specified path to exist and to be openable. + + Args: + filepath (str): The file path to check. + timeout_s (float): The max duration to wait before timing out, in seconds. + interval_s (float, optional): The interval between checks, in seconds. Default is 0.01s. + + Raises: + TimeoutError: Raised when the file does not exist after timeout_s seconds. + """ + + def _wait(): + if time.time() - start < timeout_s: + time.sleep(interval_s) + else: + raise TimeoutError(f"Timed out after {timeout_s}s waiting for file at {filepath}") + + start = time.time() + while not os.path.exists(filepath): + _wait() + + while True: + # Wait before opening to give the backend time to open it first + _wait() + try: + open(filepath, mode="r").close() + break + except IOError: + # File is not available yet + pass + + +class AdaptorFailedException(Exception): + pass + + +class HTTPError(http_client.HTTPException): + response: http_client.HTTPResponse + + def __init__(self, response: http_client.HTTPResponse, *args: object) -> None: + super().__init__(*args) + self.response = response + + +class UnixHTTPConnection(http_client.HTTPConnection): + """ + Specialization of http.client.HTTPConnection class that uses a UNIX domain socket. 
+ """ + + def __init__(self, host, **kwargs): + self.socket_path = host + kwargs.pop("strict", None) # Removed in py3 + super(UnixHTTPConnection, self).__init__("localhost", **kwargs) + + def connect(self): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(self.timeout) + sock.connect(self.socket_path) + self.sock = sock diff --git a/src/openjd/adaptor_runtime/_background/http_server.py b/src/openjd/adaptor_runtime/_background/http_server.py new file mode 100644 index 0000000..8fba102 --- /dev/null +++ b/src/openjd/adaptor_runtime/_background/http_server.py @@ -0,0 +1,298 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import json +import logging +import re +import socketserver +import time +from concurrent.futures import Future +from concurrent.futures import ThreadPoolExecutor +from http import HTTPStatus +from queue import Queue +from typing import Callable + +from ..adaptors._adaptor_runner import _OPENJD_FAIL_STDOUT_PREFIX +from ..adaptors import AdaptorRunner +from .._http import HTTPResponse, RequestHandler, ResourceRequestHandler +from .log_buffers import LogBuffer +from .model import ( + AdaptorState, + AdaptorStatus, + BufferedOutput, + DataclassJSONEncoder, + HeartbeatResponse, +) + +_logger = logging.getLogger(__name__) + + +class AsyncFutureRunner: + """ + Class that models an asynchronous worker thread using concurrent.futures. 
+ """ + + _WAIT_FOR_START_INTERVAL = 0.01 + + def __init__(self) -> None: + self._thread_pool = ThreadPoolExecutor( + max_workers=1, thread_name_prefix="AdaptorRuntimeBackendWorkerThread" + ) + self._future: Future | None = None + + def submit(self, fn: Callable, *args, **kwargs) -> None: + if self.is_running: + raise Exception("Cannot submit new task while another task is running") + self._future = self._thread_pool.submit(fn, *args, **kwargs) + + @property + def is_running(self) -> bool: + if self._future is None: + return False + return self._future.running() + + @property + def has_started(self) -> bool: + if self._future is None: + return False # pragma: no cover + return self._future.running() or self._future.done() + + def wait_for_start(self): + """Blocks until the Future has started""" + while not self.has_started: + time.sleep(self._WAIT_FOR_START_INTERVAL) + + +class BackgroundHTTPServer(socketserver.UnixStreamServer): + """ + HTTP server for the background mode of the adaptor runtime communicating via Unix socket. + + This UnixStreamServer subclass stores the stateful information of the adaptor backend. + """ + + def __init__( + self, + socket_path: str, + adaptor_runner: AdaptorRunner, + cancel_queue: Queue, + *, + log_buffer: LogBuffer | None = None, + bind_and_activate: bool = True, + ) -> None: # pragma: no cover + super().__init__(socket_path, BackgroundRequestHandler, bind_and_activate) + self._adaptor_runner = adaptor_runner + self._cancel_queue = cancel_queue + self._future_runner = AsyncFutureRunner() + self._log_buffer = log_buffer + + def submit(self, fn: Callable, *args, force_immediate=False, **kwargs) -> HTTPResponse: + """ + Submits work to the server. + + Args: + force_immediate (bool): Force the server to immediately start the work. This work will + be performed concurrently with any ongoing work. 
class BackgroundHTTPServer(socketserver.UnixStreamServer):
    """
    HTTP server for the background mode of the adaptor runtime communicating via Unix socket.

    This UnixStreamServer subclass stores the stateful information of the adaptor backend.
    """

    def __init__(
        self,
        socket_path: str,
        adaptor_runner: AdaptorRunner,
        cancel_queue: Queue,
        *,
        log_buffer: LogBuffer | None = None,
        bind_and_activate: bool = True,
    ) -> None:  # pragma: no cover
        super().__init__(socket_path, BackgroundRequestHandler, bind_and_activate)
        self._adaptor_runner = adaptor_runner
        self._cancel_queue = cancel_queue
        self._future_runner = AsyncFutureRunner()
        self._log_buffer = log_buffer

    def submit(self, fn: Callable, *args, force_immediate=False, **kwargs) -> HTTPResponse:
        """
        Submits work to the server.

        Args:
            force_immediate (bool): Force the server to immediately start the work. This work will
                be performed concurrently with any ongoing work.

        Returns:
            HTTPResponse: 200 OK once the work has started, or 500 if submission failed.
        """
        future_runner = self._future_runner if not force_immediate else AsyncFutureRunner()
        try:
            future_runner.submit(fn, *args, **kwargs)
        except Exception as e:
            _logger.error(f"Failed to submit work: {e}")
            return HTTPResponse(HTTPStatus.INTERNAL_SERVER_ERROR, body=str(e))

        # Wait for the worker thread to start working before sending the response.
        # BUGFIX: wait on the runner the work was actually submitted to. Previously this
        # waited on self._future_runner even when force_immediate created a fresh runner,
        # so the response could be sent before the immediate task (e.g. a cancel) started.
        future_runner.wait_for_start()
        return HTTPResponse(HTTPStatus.OK)


class BackgroundRequestHandler(RequestHandler):
    """
    Class that handles HTTP requests to a BackgroundHTTPServer.

    Note: The "server" argument passed to this class must be an instance of BackgroundHTTPServer
    and the server must listen for requests using UNIX domain sockets.
    """

    def __init__(
        self, request: bytes, client_address: str, server: socketserver.BaseServer
    ) -> None:
        # Fail fast with a clear error if wired up with the wrong server class.
        if not isinstance(server, BackgroundHTTPServer):
            raise TypeError(
                "Received incompatible server class. "
                f"Expected {BackgroundHTTPServer.__name__}, but got {type(server)}"
            )
        super().__init__(
            request,
            client_address,
            server,
            BackgroundResourceRequestHandler,
        )
class HeartbeatHandler(BackgroundResourceRequestHandler):
    """
    Handler for the heartbeat resource
    """

    # Failure messages are in the form: "<level>: openjd_fail: <message>"
    _FAILURE_REGEX = f"^(?:\\w+: )?{re.escape(_OPENJD_FAIL_STDOUT_PREFIX)}"
    # Query-string key the client uses to acknowledge a received chunk.
    _ACK_ID_KEY = "ack_id"

    path: str = "/heartbeat"

    def get(self) -> HTTPResponse:
        """Report the adaptor state/status plus the next chunk of buffered output."""
        log_buffer = self.server._log_buffer
        if log_buffer is None:
            output = BufferedOutput(BufferedOutput.EMPTY, "")
        else:
            # Clear any chunk the client has acknowledged before producing the next one.
            ack_id = self._parse_ack_id()
            if ack_id:
                if log_buffer.clear(ack_id):
                    _logger.debug(f"Received ACK for chunk: {ack_id}")
                else:
                    _logger.warning(f"Received ACK for old or invalid chunk: {ack_id}")

            output = log_buffer.chunk()

        # The output is flagged as failed if any buffered line carries the fail prefix.
        failed = bool(re.search(self._FAILURE_REGEX, output.output, re.MULTILINE))

        status = (
            AdaptorStatus.WORKING if self.server._future_runner.is_running else AdaptorStatus.IDLE
        )

        heartbeat = HeartbeatResponse(
            state=self.server._adaptor_runner.state, status=status, output=output, failed=failed
        )
        return HTTPResponse(HTTPStatus.OK, json.dumps(heartbeat, cls=DataclassJSONEncoder))

    def _parse_ack_id(self) -> str | None:
        """
        Parses chunk ID ACK from the query string. Returns None if the chunk ID ACK was not found.
        """
        ack_ids = self.query_string_params.get(self._ACK_ID_KEY)
        if ack_ids is None:
            return None
        if len(ack_ids) > 1:
            raise ValueError(
                f"Expected one value for {self._ACK_ID_KEY}, but found: {len(ack_ids)}"
            )
        return ack_ids[0]
class RunHandler(BackgroundResourceRequestHandler):
    """
    Handler for the run resource.
    """

    path: str = "/run"

    def put(self) -> HTTPResponse:
        """Start a run on the worker thread; rejected while other work is in progress."""
        if self.server._future_runner.is_running:
            return HTTPResponse(HTTPStatus.BAD_REQUEST)

        run_data: dict = {}
        if self.body:
            run_data = json.loads(self.body.decode(encoding="utf-8"))

        return self.server.submit(
            self.server._adaptor_runner._run,
            run_data,
        )


class StartHandler(BackgroundResourceRequestHandler):
    """
    Handler for the start resource.
    """

    path: str = "/start"

    def put(self) -> HTTPResponse:
        """Kick off adaptor start-up; rejected while other work is in progress."""
        if self.server._future_runner.is_running:
            return HTTPResponse(HTTPStatus.BAD_REQUEST)

        return self.server.submit(self.server._adaptor_runner._start)


class StopHandler(BackgroundResourceRequestHandler):
    """
    Handler for the stop resource.
    """

    path: str = "/stop"

    def put(self) -> HTTPResponse:
        """Stop and clean up the adaptor; rejected while other work is in progress."""
        if self.server._future_runner.is_running:
            return HTTPResponse(HTTPStatus.BAD_REQUEST)

        return self.server.submit(self._stop_adaptor)

    def _stop_adaptor(self):  # pragma: no cover
        try:
            self.server._adaptor_runner._stop()
            _logger.info("Daemon background process stopped.")
        finally:
            # Cleanup must run even if _stop raises.
            self.server._adaptor_runner._cleanup()
class LogBuffer(ABC):  # pragma: no cover
    """
    Base class for a log buffer.
    """

    def __init__(self, *, formatter: logging.Formatter | None = None) -> None:
        self._formatter = formatter

    @abstractmethod
    def buffer(self, record: logging.LogRecord) -> None:
        """
        Store the log record in this buffer.
        """
        pass

    @abstractmethod
    def chunk(self) -> BufferedOutput:
        """
        Returns the currently buffered output as a BufferedOutput.
        """
        pass

    @abstractmethod
    def clear(self, chunk_id: str) -> bool:
        """
        Clears the chunk with the specified ID from this buffer. Returns True if the chunk was
        cleared, false otherwise.

        Args:
            chunk_id (str): The ID of the chunk to clear.
        """
        pass

    def _format(self, record: logging.LogRecord) -> str:
        """Render a record using the formatter, or the record's own message."""
        # BUGFIX: previously this fell back to record.msg, which is the raw
        # un-interpolated template (e.g. "x=%s"). getMessage() applies the record's
        # args so buffered output matches what a logging handler would emit.
        return self._formatter.format(record) if self._formatter else record.getMessage()

    def _create_id(self) -> str:
        # NOTE(review): time.time() has limited resolution, so two chunks created within
        # the same clock tick would share an ID. Acceptable while chunks are created at
        # most once per heartbeat — confirm if that changes.
        return str(time.time())
class InMemoryLogBuffer(LogBuffer):
    """
    In-memory log buffer implementation.

    This buffer stores a single chunk that grows until it is explicitly cleared. If a new chunk is
    created without clearing the previous one, the new chunk stores all data in the previous
    chunk, in addition to new buffered data, and replaces it.
    """

    _buffer: List[logging.LogRecord]
    _last_chunk: BufferedOutput | None

    def __init__(self, *, formatter: logging.Formatter | None = None) -> None:
        super().__init__(formatter=formatter)
        self._buffer = []
        self._last_chunk = None
        self._buffer_lock = threading.Lock()
        self._last_chunk_lock = threading.Lock()

    def buffer(self, record: logging.LogRecord) -> None:  # pragma: no cover
        # Records may arrive from any thread; serialize appends against chunk() draining.
        with self._buffer_lock:
            self._buffer.append(record)

    def chunk(self) -> BufferedOutput:
        """Drain buffered records into a new chunk, folding in any un-ACKed previous chunk."""
        chunk_id = self._create_id()
        with self._buffer_lock:
            drained = list(self._buffer)
            self._buffer.clear()

        formatted = os.linesep.join(self._format(record) for record in drained)

        with self._last_chunk_lock:
            if self._last_chunk:
                # The previous chunk was never ACKed: the new chunk supersedes it and
                # carries its output forward.
                formatted = os.linesep.join([self._last_chunk.output, formatted])
            new_chunk = BufferedOutput(chunk_id, formatted)
            self._last_chunk = new_chunk

        return new_chunk

    def clear(self, chunk_id: str) -> bool:
        """Forget the last chunk if its ID matches; True when something was cleared."""
        with self._last_chunk_lock:
            if self._last_chunk is None or self._last_chunk.id != chunk_id:
                return False
            self._last_chunk = None
            return True


@dataclass
class _FileChunk:
    # Stream offsets delimiting the current chunk's span within the backing file.
    id: str | None
    start: int
    end: int
class FileLogBuffer(LogBuffer):
    """
    Log buffer that uses a file to buffer the output.

    This buffer keeps track of a section in a file with start/end stream positions. This section
    grows until it is explicitly cleared. If a new chunk is created without clearing the previous
    one, the new chunk's section includes all data in the previous chunk's section, in addition to
    new buffered data, and replaces it.
    """

    _filepath: str
    _chunk: _FileChunk

    def __init__(self, filepath: str, *, formatter: logging.Formatter | None = None) -> None:
        super().__init__(formatter=formatter)
        self._filepath = filepath
        self._chunk = _FileChunk(id=None, start=0, end=0)
        self._file_lock = threading.Lock()
        self._chunk_lock = threading.Lock()

    def buffer(self, record: logging.LogRecord) -> None:
        """Append a formatted record to the backing file."""
        with (
            self._file_lock,
            secure_open(self._filepath, open_mode="a") as f,
        ):
            # BUGFIX: terminate each record with a newline. Previously records were
            # written back-to-back with no separator, unlike InMemoryLogBuffer which
            # joins records with a line separator.
            f.write(self._format(record) + "\n")

    def chunk(self) -> BufferedOutput:
        """Read the current chunk's span from the file and extend it to the file's end."""
        id = self._create_id()

        with (
            self._chunk_lock,
            self._file_lock,
            open(self._filepath, mode="r") as f,
        ):
            self._chunk.id = id
            f.seek(self._chunk.start)
            output = f.read()
            self._chunk.end = f.tell()

        return BufferedOutput(id, output)

    def clear(self, chunk_id: str) -> bool:
        """Advance past the chunk's span if its ID matches; True when cleared."""
        with self._chunk_lock:
            if self._chunk.id == chunk_id:
                self._chunk.start = self._chunk.end
                self._chunk.id = None
                return True

            return False


class LogBufferHandler(logging.Handler):  # pragma: no cover
    """
    Class for a handler that buffers logs.
    """

    # NOTE: the level annotation used to be logging._Level, a private name that exists
    # only in typeshed (not at runtime); int is the runtime type of logging levels.
    def __init__(self, buffer: LogBuffer, level: int = logging.NOTSET) -> None:
        super().__init__(level)
        self._buffer = buffer

    def emit(self, record: logging.LogRecord) -> None:
        self._buffer.buffer(record)
_T = TypeVar("_T")


@dataclasses.dataclass
class ConnectionSettings:
    # Path of the UNIX socket the background server listens on.
    socket: str


class AdaptorStatus(str, Enum):
    """Whether the background worker thread is currently executing work."""

    IDLE = "idle"
    WORKING = "working"


@dataclasses.dataclass
class BufferedOutput:
    """A chunk of buffered log output, identified so clients can acknowledge it."""

    # Sentinel chunk ID used when there is no log buffer configured.
    EMPTY: ClassVar[str] = "EMPTY"

    id: str
    output: str


@dataclasses.dataclass
class HeartbeatResponse:
    """Payload returned by the heartbeat resource."""

    state: AdaptorState
    status: AdaptorStatus
    output: BufferedOutput
    failed: bool = False


class DataclassJSONEncoder(json.JSONEncoder):  # pragma: no cover
    """JSON encoder that serializes dataclasses as plain dictionaries."""

    def default(self, o: Any) -> Dict:
        if not dataclasses.is_dataclass(o):
            return super().default(o)
        return dataclasses.asdict(o)


class DataclassMapper(Generic[_T]):
    """
    Class that maps a dictionary to a dataclass.

    The main reason this exists is to support nested dataclasses. Dataclasses are represented as
    dict when serialized, and when they are nested we get a nested dictionary structure. For a
    simple dataclass we can expand the dictionary into keyword arguments for the dataclass'
    __init__ function, but in a nested structure the parent's __init__ expects instances of the
    nested dataclasses, not dictionaries. This class checks each field and recursively
    instantiates nested dataclasses (and resolves Enum fields from their serialized values).
    """

    def __init__(self, cls: Type[_T]) -> None:
        super().__init__()
        self._cls = cls

    def map(self, o: Dict) -> _T:
        """Build an instance of the mapped dataclass from its dict representation."""
        kwargs: Dict = {}
        for field in dataclasses.fields(self._cls):  # type: ignore
            if field.name not in o:
                raise ValueError(f"Dataclass field {field.name} not found in dict {o}")

            value = o[field.name]
            if dataclasses.is_dataclass(field.type):
                # Nested dataclass: recurse to build it from its dict form.
                value = DataclassMapper(field.type).map(value)
            elif issubclass(field.type, Enum):
                # Enums serialize as their value; resolve back to the single matching member.
                matches = [
                    member
                    # Need to cast here for mypy
                    for member in cast(Iterable[Enum], list(field.type))
                    if member.value == value
                ]
                (value,) = matches
            kwargs[field.name] = value

        return self._cls(**kwargs)
+ +from __future__ import annotations + +import logging +import os +import signal +import sys +from argparse import ArgumentParser, Namespace +from types import FrameType as FrameType +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar + +import jsonschema +import yaml + +from .adaptors import AdaptorRunner, BaseAdaptor +from ._background import BackendRunner, FrontendRunner, InMemoryLogBuffer, LogBufferHandler +from .adaptors.configuration import ( + RuntimeConfiguration, + ConfigurationManager, +) +from ._osname import OSName + +if TYPE_CHECKING: # pragma: no cover + from .adaptors.configuration import AdaptorConfiguration + +__all__ = ["EntryPoint"] + +_U = TypeVar("_U", bound=BaseAdaptor) + +_CLI_HELP_TEXT = { + "init_data": ( + "Data to pass to the adaptor during initialization. " + "This can be a JSON string or the path to a file containing a JSON string in the format " + "file://path/to/file.json" + ), + "run_data": ( + "Data to pass to the adaptor when it is being run. " + "This can be a JSON string or the path to a file containing a JSON string in the format " + "file://path/to/file.json" + ), + "path_mapping_rules": ( + "Path mapping rules to make available to the adaptor while it's running. " + "This can be a JSON string or the path to a file containing a JSON string in the format " + "file://path/to/file.json" + ), + "show_config": ( + "When specified, the adaptor runtime configuration is printed then the program exits." + ), + "connection_file": "The file path to the connection file for use in background mode.", +} + +_DIR = os.path.dirname(os.path.realpath(__file__)) +# Keyword args to init the ConfigurationManager for the runtime. 
# Environment variable naming an additional runtime configuration file to load.
_ENV_CONFIG_PATH_PREFIX = "RUNTIME_CONFIG_PATH"
# Keyword args to init the ConfigurationManager for the runtime.
_RUNTIME_CONFIG_PATHS: dict[Any, Any] = {
    "schema_path": os.path.abspath(os.path.join(_DIR, "configuration.schema.json")),
    "default_config_path": os.path.abspath(os.path.join(_DIR, "configuration.json")),
    "system_config_path_map": {
        "Linux": os.path.abspath(
            os.path.join(
                os.path.sep,
                "etc",
                "openjd",
                "worker",
                "adaptors",
                "runtime",
                "configuration.json",
            )
        )
    },
    "user_config_rel_path": os.path.join(
        ".openjd", "worker", "adaptors", "runtime", "configuration.json"
    ),
}

_logger = logging.getLogger(__name__)


class EntryPoint:
    """
    The main entry point of the adaptor runtime.

    Wires up logging, loads the runtime configuration, parses the command line, then runs
    the adaptor either in the foreground ("run" command) or in daemon mode ("daemon").
    """

    def __init__(self, adaptor_class: Type[_U]) -> None:
        self.adaptor_class = adaptor_class
        # This will be the current AdaptorRunner when using the 'run' command, rather than
        # 'background' command
        self._adaptor_runner: Optional[AdaptorRunner] = None

    def start(self) -> None:
        """
        Starts the run of the adaptor.

        Raises:
            jsonschema.ValidationError: If the runtime configuration file fails validation.
        """
        formatter = logging.Formatter("%(levelname)s: %(message)s")
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setFormatter(formatter)

        runtime_logger = logging.getLogger(__package__)
        runtime_logger.setLevel(logging.INFO)  # Start with INFO, will get updated with config
        runtime_logger.addHandler(stream_handler)

        # The adaptor's logger is keyed off the top-level package of the adaptor class.
        adaptor_logger = logging.getLogger(self.adaptor_class.__module__.split(".")[0])
        adaptor_logger.addHandler(stream_handler)

        parsed_args = self._parse_args()

        # Prefer the --path-mapping-rules option. BUGFIX: when the option exists for the
        # subcommand but was not provided, argparse leaves the attribute as None; previously
        # that None was passed straight to the adaptor. Now we fall back to the environment
        # variable in that case too.
        # TODO: Eliminate the use of the environment variable once all users of this library have
        # been updated to use the command-line option. Default to an empty dictionary.
        path_mapping_data = getattr(parsed_args, "path_mapping_rules", None)
        if path_mapping_data is None:
            path_mapping_data = _load_data(os.environ.get("PATH_MAPPING_RULES", "{}"))

        additional_config_path = os.environ.get(_ENV_CONFIG_PATH_PREFIX)
        self.config_manager = ConfigurationManager(
            config_cls=RuntimeConfiguration,
            **_RUNTIME_CONFIG_PATHS,
            additional_config_paths=[additional_config_path] if additional_config_path else [],
        )
        try:
            self.config = self.config_manager.build_config()
        except jsonschema.ValidationError as e:
            _logger.error(f"Nonvalid runtime configuration file: {e}")
            raise
        except NotImplementedError as e:
            _logger.warning(
                f"The current system ({OSName()}) is not supported for runtime "
                f"configuration. Only the default configuration will be loaded. Full error: {e}"
            )
            # The above call to build_config() would have already successfully retrieved the
            # default config for this error to be raised, so we can assume the default config
            # is valid here.
            self.config = self.config_manager.get_default_config()

        if getattr(parsed_args, "show_config", False):
            print(yaml.dump(self.config.config, indent=2))
            return  # pragma: no cover

        init_data = getattr(parsed_args, "init_data", {})
        run_data = getattr(parsed_args, "run_data", {})
        # No subcommand defaults to the foreground "run" workflow.
        command = getattr(parsed_args, "command", None) or "run"

        adaptor: BaseAdaptor[AdaptorConfiguration] = self.adaptor_class(
            init_data, path_mapping_data=path_mapping_data
        )

        adaptor_logger.setLevel(adaptor.config.log_level)
        runtime_logger.setLevel(self.config.log_level)

        if command == "run":
            self._adaptor_runner = AdaptorRunner(adaptor=adaptor)
            # To be able to handle cancelation via a SIGTERM/SIGINT
            signal.signal(signal.SIGINT, self._sigint_handler)
            signal.signal(signal.SIGTERM, self._sigint_handler)
            try:
                self._adaptor_runner._start()
                self._adaptor_runner._run(run_data)
                self._adaptor_runner._stop()
                self._adaptor_runner._cleanup()
            except Exception as e:
                _logger.error(f"Error running the adaptor: {e}")
                try:
                    self._adaptor_runner._cleanup()
                except Exception as e:
                    _logger.error(f"Error cleaning up the adaptor: {e}")
                    raise
                raise
        elif command == "daemon":  # pragma: no branch
            connection_file = parsed_args.connection_file
            if not os.path.isabs(connection_file):
                connection_file = os.path.abspath(connection_file)
            subcommand = getattr(parsed_args, "subcommand", None)

            if subcommand == "_serve":
                # Replace stream handler with log buffer handler since output will be buffered in
                # background mode
                log_buffer = InMemoryLogBuffer(formatter=formatter)
                buffer_handler = LogBufferHandler(log_buffer)
                for logger in [runtime_logger, adaptor_logger]:
                    logger.removeHandler(stream_handler)
                    logger.addHandler(buffer_handler)

                # This process is running in background mode. Create the backend server and serve
                # forever until a shutdown is requested
                backend = BackendRunner(
                    AdaptorRunner(adaptor=adaptor),
                    connection_file,
                    log_buffer=log_buffer,
                )
                backend.run()
            else:
                # This process is running in frontend mode. Create the frontend runner and send
                # the appropriate request to the backend.
                frontend = FrontendRunner(connection_file)
                if subcommand == "start":
                    adaptor_module = sys.modules.get(self.adaptor_class.__module__)
                    if adaptor_module is None:
                        raise ModuleNotFoundError(
                            f"Adaptor module is not loaded: {self.adaptor_class.__module__}"
                        )

                    frontend.init(adaptor_module, init_data)
                    frontend.start()
                elif subcommand == "run":
                    frontend.run(run_data)
                elif subcommand == "stop":
                    frontend.stop()
                    frontend.shutdown()

    def _parse_args(self) -> Namespace:
        """Parse sys.argv, logging and re-raising any argparse failure."""
        parser = self._build_argparser()
        try:
            return parser.parse_args(sys.argv[1:])
        except Exception as e:
            _logger.error(f"Error parsing command line arguments: {e}")
            raise

    def _build_argparser(self) -> ArgumentParser:
        """Build the CLI: a top-level 'run' command plus a 'daemon' command group."""
        parser = ArgumentParser(prog="adaptor_runtime", add_help=True)
        parser.add_argument(
            "--show-config", action="store_true", help=_CLI_HELP_TEXT["show_config"]
        )

        subparser = parser.add_subparsers(dest="command", title="subcommands")

        # Shared option parsers reused by multiple subcommands.
        init_data = ArgumentParser(add_help=False)
        init_data.add_argument(
            "--init-data", default="", type=_load_data, help=_CLI_HELP_TEXT["init_data"]
        )
        run_data = ArgumentParser(add_help=False)
        run_data.add_argument(
            "--run-data", default="", type=_load_data, help=_CLI_HELP_TEXT["run_data"]
        )

        path_mapping_rules = ArgumentParser(add_help=False)
        path_mapping_rules.add_argument(
            "--path-mapping-rules",
            required=False,
            type=_load_data,
            help=_CLI_HELP_TEXT["path_mapping_rules"],
        )

        subparser.add_parser("run", parents=[init_data, path_mapping_rules, run_data])

        connection_file = ArgumentParser(add_help=False)
        connection_file.add_argument(
            "--connection-file",
            default="",
            help=_CLI_HELP_TEXT["connection_file"],
            required=True,
        )

        bg_parser = subparser.add_parser("daemon")
        bg_subparser = bg_parser.add_subparsers(
            dest="subcommand",
            title="subcommands",
            required=True,
            # Explicitly set the metavar to "hide" the "_serve" command
            metavar="{start,run,stop}",
        )

        # "Hidden" command that actually runs the adaptor runtime in background mode
        bg_subparser.add_parser("_serve", parents=[init_data, connection_file])

        bg_subparser.add_parser("start", parents=[init_data, path_mapping_rules, connection_file])
        bg_subparser.add_parser("run", parents=[run_data, connection_file])
        bg_subparser.add_parser("stop", parents=[connection_file])

        return parser

    def _sigint_handler(self, signum: int, frame: Optional[FrameType]) -> None:
        """Signal handler that is invoked when the process receives a SIGINT/SIGTERM"""
        if self._adaptor_runner is not None:
            # Fixed typo in the log message ("recieved" -> "received").
            _logger.info("Interruption signal received.")
            # OpenJD dictates that a SIGTERM/SIGINT results in a cancel workflow being
            # kicked off.
            self._adaptor_runner._cancel()


def _load_data(data: str) -> dict:
    """
    Parses an input JSON/YAML (filepath or string-encoded) into a dictionary.

    Args:
        data (str): The filepath or string representation of the JSON/YAML to parse.

    Raises:
        ValueError: Raised when the JSON/YAML is not parsed to a dictionary.
    """
    if not data:
        return {}

    try:
        loaded_data = _load_yaml_json(data)
    except OSError as e:
        _logger.error(f"Failed to open data file: {e}")
        raise
    except yaml.YAMLError as e:
        _logger.error(f"Failed to load data as JSON or YAML: {e}")
        raise

    if not isinstance(loaded_data, dict):
        raise ValueError(f"Expected loaded data to be a dict, but got {type(loaded_data)}")

    return loaded_data


def _load_yaml_json(data: str) -> Any:
    """
    Loads a YAML/JSON file/string.

    Note that yaml.safe_load() is capable of loading JSON documents.
    """
    if data.startswith("file://"):
        filepath = data[len("file://") :]
        with open(filepath) as yaml_file:
            return yaml.safe_load(yaml_file)
    return yaml.safe_load(data)
class UnsupportedPlatformException(Exception):
    """Raised when an operation is attempted on an unsupported platform"""


class NonvalidSocketPathException(Exception):
    """Raised when a socket path is not valid"""


class NoSocketPathFoundException(Exception):
    """Raised when a valid socket path could not be found"""
_logger = logging.getLogger(__name__)


class RequestHandler(server.BaseHTTPRequestHandler):
    """
    Class that handles HTTP requests to a HTTPServer.

    Note: The "server" argument passed to this class must listen for requests using UNIX domain
    sockets.
    """

    _DEFAULT_HANDLER: ResourceRequestHandler
    _HANDLER_TYPE: Type[ResourceRequestHandler]

    _handlers: dict[str, ResourceRequestHandler]

    # Socket variable set in parent class StreamRequestHandler.setup()
    connection: socket.socket

    def __init__(
        self,
        request: bytes,
        client_address: str,
        server: socketserver.BaseServer,
        handler_type: Type[ResourceRequestHandler],
    ) -> None:
        self._DEFAULT_HANDLER = _DefaultRequestHandler()
        self._HANDLER_TYPE = handler_type

        def walk_subclasses(klass: type):
            # Depth-first walk of the handler hierarchy (children yielded before parents).
            for child in klass.__subclasses__():
                yield from walk_subclasses(child)
                yield child

        self._handlers = {
            handler_cls.path: handler_cls(self)
            for handler_cls in walk_subclasses(self._HANDLER_TYPE)
            if handler_cls is not _DefaultRequestHandler
        }
        # The base __init__ actually handles the request, so it must run last.
        super().__init__(request, client_address, server)  # type: ignore

    def address_string(self) -> str:
        # Parent class assumes this is a tuple of (address, port); for UNIX sockets the
        # client address is the socket path string.
        return self.client_address  # type: ignore

    def _resolve_handler(self) -> ResourceRequestHandler:
        """Route the request on its URL path only, ignoring the query string."""
        resource = urllib_parse.urlparse(self.path).path
        return self._handlers.get(resource, self._DEFAULT_HANDLER)

    def do_GET(self) -> None:  # pragma: no cover
        self._do_request(self._resolve_handler().get)

    def do_PUT(self) -> None:  # pragma: no cover
        self._do_request(self._resolve_handler().put)

    def _do_request(self, func: Callable[[], HTTPResponse]) -> None:
        """Authenticate the peer, invoke the resource method, and send the response."""
        # First, authenticate the connecting peer
        try:
            authenticated = self._authenticate()
        except UnsupportedPlatformException as e:
            _logger.error(e)
            self._respond(HTTPResponse(HTTPStatus.INTERNAL_SERVER_ERROR))
            return

        if not authenticated:
            self._respond(HTTPResponse(HTTPStatus.UNAUTHORIZED))
            return

        # Handle the request
        try:
            response = func()
        except Exception as e:
            _logger.error(f"Failed to handle request: {e}")
            response = HTTPResponse(HTTPStatus.INTERNAL_SERVER_ERROR)

        self._respond(response)

    def _respond(self, response: HTTPResponse) -> None:
        """Serialize an HTTPResponse onto the wire."""
        if response.status < 400:
            self.send_response(response.status)
        else:
            self.send_error(response.status)
        _logger.debug(f"Sending status code {response.status} for request to {self.path}")

        payload = response.body.encode("utf-8") if response.body else None
        if payload is not None:
            self.send_header("Content-Length", str(len(payload)))
            self.end_headers()
            self.wfile.write(payload)
        else:
            self.end_headers()

    # NOTE: self.connection is set by the base class socketserver.StreamRequestHandler.
    # This class is instantiated in socketserver.BaseServer.finish_request(), where the socket
    # returned by socketserver.BaseServer.get_request() is passed as the argument for "request".
    def _authenticate(self) -> bool:
        """Allow only peer processes owned by the same UID as this process."""
        connection = self.connection
        # Verify we have a UNIX socket.
        if not (
            isinstance(connection, socket.socket)
            and connection.family == socket.AddressFamily.AF_UNIX
        ):
            raise UnsupportedPlatformException(
                "Failed to handle request because it was not made through a UNIX socket"
            )

        # Get the credentials of the peer process
        cred_buffer = connection.getsockopt(
            socket.SOL_SOCKET,
            socket.SO_PEERCRED,
            socket.CMSG_SPACE(ctypes.sizeof(UCred)),
        )
        peer_cred = UCred.from_buffer_copy(cred_buffer)

        # Only allow connections from a process running as the same user
        return peer_cred.uid == os.getuid()


class UCred(ctypes.Structure):
    """
    Represents the ucred struct returned from the SO_PEERCRED socket option.

    For more info, see SO_PASSCRED in the unix(7) man page
    """

    _fields_ = [
        ("pid", ctypes.c_int),
        ("uid", ctypes.c_int),
        ("gid", ctypes.c_int),
    ]

    def __str__(self):  # pragma: no cover
        return f"pid:{self.pid} uid:{self.uid} gid:{self.gid}"


@dataclass
class HTTPResponse:
    """
    Dataclass to model an HTTP response.
    """

    status: HTTPStatus
    body: str | None = None
# Max PID on 64-bit systems is 4194304 (2^22), i.e. at most 7 decimal digits.
_PID_MAX_LENGTH = 7
_PID_MAX_LENGTH_PADDED = _PID_MAX_LENGTH + 1  # 1 char for path separator


class SocketDirectories(abc.ABC):
    """
    Base class for determining the base directory for sockets used in the Adaptor Runtime.
    """

    @staticmethod
    def for_os(osname: OSName | None = None):  # pragma: no cover
        """
        Gets the SocketDirectories implementation for an OS.

        Args:
            osname (OSName, optional): The OS to get socket directories for.
                Defaults to the current OS.

        Raises:
            UnsupportedPlatformException: Raised when this class is requested for an unsupported
                platform.
        """
        # BUGFIX: the default used to be an eagerly-evaluated OSName() in the signature,
        # which ran at class-definition (import) time; evaluate it per call instead.
        if osname is None:
            osname = OSName()
        klass = _get_socket_directories_cls(osname)
        if not klass:
            raise UnsupportedPlatformException(osname)
        return klass()

    def get_process_socket_path(self, namespace: str | None = None, *, create_dir: bool = False):
        """
        Gets the path for this process' socket in the given namespace.

        Args:
            namespace (Optional[str]): The optional namespace (subdirectory) where the sockets go.
            create_dir (bool): Whether to create the socket directory. Default is false.

        Raises:
            NonvalidSocketPathException: Raised if the user has configured a socket base directory
                that is nonvalid
            NoSocketPathFoundException: Raised if no valid socket path could be found. This will
                not be raised if the user has configured a socket base directory.
        """
        socket_name = str(os.getpid())
        assert (
            len(socket_name) <= _PID_MAX_LENGTH
        ), f"PID too long. Only PIDs up to {_PID_MAX_LENGTH} digits are supported."

        return os.path.join(self.get_socket_dir(namespace, create=create_dir), socket_name)

    def get_socket_dir(self, namespace: str | None = None, *, create: bool = False) -> str:
        """
        Gets the base directory for sockets used in Adaptor IPC

        Tries the user's home directory first, then falls back to the system temp directory
        (which must have the sticky bit set to be usable).

        Args:
            namespace (Optional[str]): The optional namespace (subdirectory) where the sockets go
            create (bool): Whether to create the directory or not. Default is false.

        Raises:
            NonvalidSocketPathException: Raised if the user has configured a socket base directory
                that is nonvalid
            NoSocketPathFoundException: Raised if no valid socket path could be found. This will
                not be raised if the user has configured a socket base directory.
        """

        def ensure_dir(path: str) -> str:
            # Only the owner may enter/list the socket directory.
            if create:
                os.makedirs(path, mode=0o700, exist_ok=True)
            return path

        rel_path = os.path.join(".openjd", "adaptors", "sockets")
        if namespace:
            rel_path = os.path.join(rel_path, namespace)

        reasons: list[str] = []

        # First try home directory
        home_candidate = os.path.join(os.path.expanduser("~"), rel_path)
        try:
            self.verify_socket_path(home_candidate)
        except NonvalidSocketPathException as e:
            reasons.append(f"Cannot create sockets directory in the home directory because: {e}")
        else:
            return ensure_dir(home_candidate)

        # Last resort is the temp directory
        temp_root = tempfile.gettempdir()
        temp_candidate = os.path.join(temp_root, rel_path)
        try:
            self.verify_socket_path(temp_candidate)
        except NonvalidSocketPathException as e:
            reasons.append(f"Cannot create sockets directory in the temp directory because: {e}")
        else:
            # Also check that the sticky bit is set on the temp dir
            if os.stat(temp_root).st_mode & stat.S_ISVTX:
                return ensure_dir(temp_candidate)
            reasons.append(
                f"Cannot use temporary directory {temp_root} because it does not have the "
                "sticky bit (restricted deletion flag) set"
            )

        raise NoSocketPathFoundException(
            "Failed to find a suitable base directory to create sockets in for the following "
            f"reasons: {os.linesep.join(reasons)}"
        )

    @abc.abstractmethod
    def verify_socket_path(self, path: str) -> None:  # pragma: no cover
        """
        Verifies a socket path is valid.

        Raises:
            NonvalidSocketPathException: Subclasses will raise this exception if the socket path
                is not valid.
        """
        pass
+ """ + + # This is based on the max length of socket names to 108 bytes + # See unix(7) under "Address format" + _socket_path_max_length = 108 + _socket_dir_max_length = _socket_path_max_length - _PID_MAX_LENGTH_PADDED + + def verify_socket_path(self, path: str) -> None: + path_length = len(path.encode("utf-8")) + if path_length > self._socket_dir_max_length: + raise NonvalidSocketPathException( + "Socket base directory path too big. The maximum allowed size is " + f"{self._socket_dir_max_length} bytes, but the directory has a size of " + f"{path_length}: {path}" + ) + + +_os_map: dict[str, type[SocketDirectories]] = { + OSName.LINUX: LinuxSocketDirectories, +} + + +def _get_socket_directories_cls( + osname: OSName, +) -> type[SocketDirectories] | None: # pragma: no cover + return _os_map.get(osname, None) diff --git a/src/openjd/adaptor_runtime/_osname.py b/src/openjd/adaptor_runtime/_osname.py new file mode 100644 index 0000000..f3e63f0 --- /dev/null +++ b/src/openjd/adaptor_runtime/_osname.py @@ -0,0 +1,95 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import platform + + +class OSName(str): + """ + OS Name Utility Class. + + Calling the constructor without any parameters will create an OSName object initialized with the + OS python is running on (one of Linux, macOS, Windows). + + Calling the constructor with a string will result in an OSName object with the string resolved + to one of Linux, macOS, Windows. If the string could not be resolved to an OS, then a ValueError + will be raised. + + This class also has an override __eq__ which can be used to compare against string types for OS + Name equality. For example OSName('Windows') == 'nt' will evaluate to True. 
+ """ + + LINUX = "Linux" + MACOS = "macOS" + WINDOWS = "Windows" + POSIX = "Posix" + + __hash__ = str.__hash__ # needed because we define __eq__ + + def __init__(self, *args, **kw): + super().__init__() + + def __new__(cls, *args, **kw): + if len(args) > 0: + args = (OSName.resolve_os_name(args[0]), *args[1:]) + else: + args = (OSName._get_os_name(),) + return str.__new__(cls, *args, **kw) + + @staticmethod + def is_macos(name: str) -> bool: + return OSName.resolve_os_name(name) == OSName.MACOS + + @staticmethod + def is_windows(name: str) -> bool: + return OSName.resolve_os_name(name) == OSName.WINDOWS + + @staticmethod + def is_linux(name: str) -> bool: + return OSName.resolve_os_name(name) == OSName.LINUX + + @staticmethod + def is_posix(name: str) -> bool: + return ( + OSName.resolve_os_name(name) == OSName.POSIX + or OSName.is_macos(name) + or OSName.is_linux(name) + ) + + @staticmethod + def _get_os_name() -> str: + return OSName.resolve_os_name(platform.system()) + + @staticmethod + def resolve_os_name(name: str) -> str: + """ + Resolves an OS Name from an alias. 
In general this works as follows: + - macOS will resolve from: {'darwin', 'macos', 'mac', 'mac os', 'os x'} + - Windows will resolve from {'nt', 'windows'} or any string starting with 'win' like 'win32' + - Linux will resolve from any string starting with 'linux', like 'linux' or 'linux2' + """ + name = name.lower().strip() + if os_name := _osname_alias_map.get(name): + return os_name + elif name.startswith("win"): + return OSName.WINDOWS + elif name.startswith("linux"): + return OSName.LINUX + elif name.lower() == "posix": + return OSName.POSIX + else: + raise ValueError(f"The operating system '{name}' is unknown and could not be resolved.") + + def __eq__(self, __x: object) -> bool: + return OSName.resolve_os_name(self) == OSName.resolve_os_name(str(__x)) + + +_osname_alias_map: dict[str, str] = { + "darwin": OSName.MACOS, + "macos": OSName.MACOS, + "mac": OSName.MACOS, + "mac os": OSName.MACOS, + "os x": OSName.MACOS, + "nt": OSName.WINDOWS, + "windows": OSName.WINDOWS, + "posix": OSName.POSIX, +} diff --git a/src/openjd/adaptor_runtime/_utils/__init__.py b/src/openjd/adaptor_runtime/_utils/__init__.py new file mode 100644 index 0000000..6a0079e --- /dev/null +++ b/src/openjd/adaptor_runtime/_utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from ._secure_open import secure_open + +__all__ = [ + "secure_open", +] diff --git a/src/openjd/adaptor_runtime/_utils/_secure_open.py b/src/openjd/adaptor_runtime/_utils/_secure_open.py new file mode 100644 index 0000000..e23829d --- /dev/null +++ b/src/openjd/adaptor_runtime/_utils/_secure_open.py @@ -0,0 +1,75 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
import os
import stat
from contextlib import contextmanager
from typing import IO, TYPE_CHECKING, Generator, Optional

if TYPE_CHECKING:
    from _typeshed import StrOrBytesPath


@contextmanager
def secure_open(
    path: "StrOrBytesPath",
    open_mode: str,
    encoding: Optional[str] = None,
    newline: Optional[str] = None,
    mask: int = 0,
) -> Generator[IO, None, None]:
    """
    Opens a file, restricting its permissions when the file may be created.

    The OS-level open flags are inferred from the open mode (a combination of r, w, a, x, +,
    and optionally b/t). If the open_mode involves writing to the file, then the file is
    created with OWNER read/write permissions bit-wise OR'd with the mask argument provided.
    If the open_mode only involves reading the file, the permissions are not changed.

    Args:
        path (StrOrBytesPath): The path to the file to open
        open_mode (str): The string mode for opening the file. A combination of r, w, a, x, +
            (optionally with b or t)
        encoding (str, optional): The encoding of the file to open. Defaults to None.
        newline (str, optional): The newline character to use. Defaults to None.
        mask (int, optional): Additional permission bits to OR into the created file's mode.
            Defaults to 0.

    Raises:
        ValueError: If the open mode is not valid

    Yields:
        IO: The opened file object
    """
    flags = _get_flags_from_mode_str(open_mode)
    os_open_kwargs: dict = {
        "path": path,
        "flags": flags,
    }
    # O_RDONLY == 0, so any non-zero flags imply writing/creating; only then do
    # we constrain the permissions of a newly-created file.
    if flags != os.O_RDONLY:
        os_open_kwargs["mode"] = stat.S_IWUSR | stat.S_IRUSR | mask

    fd = os.open(**os_open_kwargs)  # type: ignore

    open_kwargs: dict = {}
    if encoding is not None:
        open_kwargs["encoding"] = encoding
    if newline is not None:
        open_kwargs["newline"] = newline
    # open() takes ownership of fd and closes it when the context exits
    with open(fd, open_mode, **open_kwargs) as f:  # type: ignore
        yield f


def _get_flags_from_mode_str(open_mode: str) -> int:
    """
    Translates an open()-style mode string into os.open() flag bits.

    Raises:
        ValueError: If the mode contains a character other than r, w, a, x, +, b, t.
    """
    flags = 0
    for char in open_mode:
        if char == "r":
            flags |= os.O_RDONLY
        elif char == "w":
            flags |= os.O_WRONLY | os.O_TRUNC | os.O_CREAT
        elif char == "a":
            flags |= os.O_WRONLY | os.O_APPEND | os.O_CREAT
        elif char == "x":
            flags |= os.O_EXCL | os.O_CREAT | os.O_WRONLY
        elif char == "+":
            flags |= os.O_RDWR | os.O_CREAT
        elif char == "b":
            # Binary mode requires O_BINARY on Windows; it's a no-op elsewhere.
            flags |= getattr(os, "O_BINARY", 0)
        elif char == "t":
            pass  # Text mode is the os.open default; nothing to add.
        else:
            raise ValueError(f"Nonvalid mode: '{open_mode}'")
    return flags
from abc import abstractmethod
from typing import TypeVar

from .configuration import AdaptorConfiguration
from ._base_adaptor import BaseAdaptor

__all__ = ["Adaptor"]

_T = TypeVar("_T", bound=AdaptorConfiguration)


class Adaptor(BaseAdaptor[_T]):
    """An Adaptor.

    Derived classes must override the on_run method, and may also optionally
    override the on_start, on_stop, and on_cleanup methods.
    """

    # ===============================================
    # Callbacks / virtual functions.
    # ===============================================

    def on_start(self):  # pragma: no cover
        """
        For job stickiness. Will start everything required for the Task. Will be used for all
        SubTasks.
        """
        pass

    @abstractmethod
    def on_run(self, run_data: dict):  # pragma: no cover
        """
        This will run for every task and will setup everything needed to render (including calling
        any managed processes). This will be overridden and defined in each advanced plugin.
        """
        pass

    def on_stop(self):  # pragma: no cover
        """
        For job stickiness. Will stop everything required for the Task before moving on to a new
        Task.
        """
        pass

    def on_cleanup(self):  # pragma: no cover
        """
        This callback will be any additional cleanup required by the adaptor.
        """
        pass

    # ===============================================
    # ===============================================

    def _start(self):  # pragma: no cover
        self.on_start()

    def _run(self, run_data: dict):
        """
        :param run_data: This is the data that changes between the different SubTasks. Eg. frame
            number.
        """
        self.on_run(run_data)

    def _stop(self):  # pragma: no cover
        self.on_stop()

    def _cleanup(self):  # pragma: no cover
        self.on_cleanup()
+ """ + + def __init__(self, *, adaptor: BaseAdaptor): + self.adaptor = adaptor + self.state = AdaptorState.NOT_STARTED + + def _start(self): + _logger.debug("Starting...") + self.state = AdaptorState.START + + try: + self.adaptor._start() + except Exception as e: + _fail(f"Error encountered while starting adaptor: {e}") + raise + + def _run(self, run_data: dict): + _logger.debug("Running task") + self.state = AdaptorState.RUN + + try: + self.adaptor._run(run_data) + except Exception as e: + _fail(f"Error encountered while running adaptor: {e}") + raise + + _logger.debug("Task complete") + + def _stop(self): + _logger.debug("Stopping...") + self.state = AdaptorState.STOP + + try: + self.adaptor._stop() + except Exception as e: + _fail(f"Error encountered while stopping adaptor: {e}") + raise + + def _cleanup(self): + _logger.debug("Cleaning up...") + self.state = AdaptorState.CLEANUP + + try: + self.adaptor._cleanup() + except Exception as e: + _fail(f"Error encountered while cleaning up adaptor: {e}") + raise + + _logger.debug("Cleanup complete") + + def _cancel(self): + _logger.debug("Canceling...") + self.state = AdaptorState.CANCELED + + try: + self.adaptor.cancel() + except Exception as e: + _fail(f"Error encountered while canceling the adaptor: {e}") + raise + + _logger.debug("Cancel complete") + + +def _fail(reason: str): + # TODO: Add a way to output "system" messages that ignore logging configuration. + # We don't ever want this message to get filtered out by Python's logging library. + _logger.error(f"{_OPENJD_FAIL_STDOUT_PREFIX}{reason}") diff --git a/src/openjd/adaptor_runtime/adaptors/_adaptor_states.py b/src/openjd/adaptor_runtime/adaptors/_adaptor_states.py new file mode 100644 index 0000000..e30b8f8 --- /dev/null +++ b/src/openjd/adaptor_runtime/adaptors/_adaptor_states.py @@ -0,0 +1,61 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +from abc import ABC, abstractmethod +from enum import Enum + +__all__ = [ + "AdaptorState", + "AdaptorStates", +] + + +class AdaptorState(str, Enum): + """ + Enumeration of the different states an adaptor can be in. + """ + + NOT_STARTED = "not_started" + START = "start" + RUN = "run" + STOP = "stop" + CLEANUP = "cleanup" + CANCELED = "canceled" + + +class AdaptorStates(ABC): + """ + Abstract class containing functions to transition an adaptor between states. + """ + + @abstractmethod + def _start(self): # pragma: no cover + """ + Starts the adaptor. + """ + pass + + @abstractmethod + def _run(self, run_data: dict): # pragma: no cover + """ + Runs the adaptor. + + Args: + run_data (dict): The data required to run the adaptor. + """ + pass + + @abstractmethod + def _stop(self): # pragma: no cover + """ + Stops the adaptor run. + """ + pass + + @abstractmethod + def _cleanup(self): # pragma: no cover + """ + Performs any cleanup the adaptor may need. + """ + pass diff --git a/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py b/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py new file mode 100644 index 0000000..96c8779 --- /dev/null +++ b/src/openjd/adaptor_runtime/adaptors/_base_adaptor.py @@ -0,0 +1,231 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
from __future__ import annotations

import logging
import math
import os
import sys
from dataclasses import dataclass
from types import ModuleType
from typing import Generic, Type, TypeVar

from .configuration import AdaptorConfiguration, ConfigurationManager
from .configuration._configuration_manager import (
    create_adaptor_configuration_manager as create_adaptor_configuration_manager,
)
from ._adaptor_states import AdaptorStates
from ._path_mapping import PathMappingRule

__all__ = [
    "AdaptorConfigurationOptions",
    "BaseAdaptor",
]

# The adaptor's name (upper-cased) is prepended to this value to form the full
# environment variable name, e.g. MAYAADAPTOR_CONFIG_PATH
_ENV_CONFIG_PATH_TEMPLATE = "CONFIG_PATH"
# Environment variable naming the directory containing adaptor schemas
_ENV_CONFIG_SCHEMA_PATH_PREFIX = "ADAPTOR_CONFIG_SCHEMA_PATH"

_T = TypeVar("_T", bound=AdaptorConfiguration)

_logger = logging.getLogger(__name__)


@dataclass
class AdaptorConfigurationOptions(Generic[_T]):
    """Options for adaptor configuration."""

    config_cls: Type[_T] | None = None
    """The adaptor configuration class to use."""
    config_path: str | None = None
    """The path to the adaptor configuration file."""
    schema_path: str | list[str] | None = None
    """The path to the JSON Schema file."""


class BaseAdaptor(AdaptorStates, Generic[_T]):
    """
    Base class for adaptors.
    """

    _OPENJD_PROGRESS_STDOUT_PREFIX: str = "openjd_progress: "
    _OPENJD_STATUS_STDOUT_PREFIX: str = "openjd_status: "

    def __init__(
        self,
        init_data: dict,
        *,
        config_opts: AdaptorConfigurationOptions[_T] | None = None,
        path_mapping_data: dict[str, list[dict[str, str]]] | None = None,
    ):
        """
        Args:
            init_data (dict): Data required to initialize the adaptor.
            config_opts (AdaptorConfigurationOptions[_T], optional): Options for adaptor
                configuration.
            path_mapping_data (dict, optional): Raw path mapping rules, keyed under
                "path_mapping_rules".
        """
        self.init_data = init_data
        self._config_opts = config_opts
        self._path_mapping_data: dict = path_mapping_data or {}
        self._path_mapping_rules: list[PathMappingRule] = [
            PathMappingRule.from_dict(rule=rule)
            for rule in self._path_mapping_data.get("path_mapping_rules", [])
        ]

    def on_cancel(self):  # pragma: no cover
        """
        Invoked at the end of the `cancel` method.
        """
        pass

    def cancel(self):  # pragma: no cover
        """
        Cancels the run of this adaptor.
        """
        self.on_cancel()

    @property
    def config_manager(self) -> ConfigurationManager[_T]:
        """
        Gets the lazily-loaded configuration manager for this adaptor.
        """
        if not hasattr(self, "_config_manager"):
            self._config_manager = self._load_configuration_manager()
        return self._config_manager

    @property
    def config(self) -> _T:
        """
        Gets the lazily-built configuration for this adaptor.
        """
        if not hasattr(self, "_config"):
            self._config = self.config_manager.build_config()
        return self._config

    def _load_configuration_manager(self) -> ConfigurationManager[_T]:
        """
        Loads a configuration manager using the module of this instance.

        Raises:
            KeyError: Raised when the module is not loaded.
            ValueError: Raised when the module is not a package or does not have a file path set.
        """
        module = sys.modules.get(self.__module__)
        if module is None:
            raise KeyError(f"Module not loaded: {self.__module__}")

        module_info = _ModuleInfo(module)
        if not module_info.package:
            raise ValueError(f"Module {module_info.name} is not a package")

        adaptor_name = type(self).__name__
        opts = self._config_opts
        config_cls = opts.config_cls if opts and opts.config_cls else AdaptorConfiguration
        config_path = opts.config_path if opts and opts.config_path else None
        schema_path = opts.schema_path if opts else None

        def module_dir() -> str:
            # The module's directory anchors the default config/schema locations
            if not module_info.file:
                raise ValueError(f"Module {module_info.name} does not have a file path set")
            return os.path.dirname(os.path.abspath(module_info.file))

        if not config_path:
            config_path = os.path.join(module_dir(), f"{adaptor_name}.json")
        if not schema_path:
            schema_dir = os.environ.get(_ENV_CONFIG_SCHEMA_PATH_PREFIX)
            if schema_dir:
                # Schema dir was provided, so we assume a schema file exists at that location
                schema_path = os.path.join(schema_dir, f"{adaptor_name}.schema.json")
            else:
                # Schema dir was not provided, so we only provide the default schema path if it
                # exists
                schema_path = os.path.join(module_dir(), f"{adaptor_name}.schema.json")
                schema_path = schema_path if os.path.exists(schema_path) else None

        additional_config_paths = []
        env_var_name = f"{adaptor_name.upper()}_{_ENV_CONFIG_PATH_TEMPLATE}"
        if additional_config_path := os.environ.get(env_var_name):
            _logger.info(f"Found adaptor config environment variable: {env_var_name}")
            additional_config_paths.append(additional_config_path)

        return create_adaptor_configuration_manager(
            config_cls=config_cls,
            adaptor_name=adaptor_name,
            default_config_path=config_path,
            schema_path=schema_path,
            additional_config_paths=additional_config_paths,
        )

    @classmethod
    def update_status(
        cls, *, progress: float | None = None, status_message: str | None = None
    ) -> None:
        """Using OpenJD stdout prefixes the adaptor will notify the
        Worker Agent about the progress, status message, or both"""
        if progress is None and status_message is None:
            _logger.warning("Both progress and status message were None. Ignoring status update.")
            return

        if progress is not None:
            if math.isfinite(progress):
                sys.stdout.write(f"{cls._OPENJD_PROGRESS_STDOUT_PREFIX}{progress}{os.linesep}")
                sys.stdout.flush()
            else:
                _logger.warning(
                    f"Attempted to set progress to something non-finite: {progress}. "
                    "Ignoring progress update."
                )
        if status_message is not None:
            sys.stdout.write(f"{cls._OPENJD_STATUS_STDOUT_PREFIX}{status_message}{os.linesep}")
            sys.stdout.flush()

    @property
    def path_mapping_rules(self) -> list[PathMappingRule]:
        """Returns a copy of the list of path mapping rules"""
        return self._path_mapping_rules.copy()

    def map_path(self, path: str) -> str:
        """Applies the first matching path mapping rule to the given path.
        Returns the original path if no rules matched"""
        for rule in self._path_mapping_rules:
            changed, mapped = rule.apply(path=path)
            if changed:
                return mapped
        return path


class _ModuleInfo:  # pragma: no cover
    """
    Wraps ModuleType to provide getters for magic attributes (e.g. __name__) so they
    can be mocked in unit tests, since unittest.mock does not allow some magic
    attributes to be mocked.
    """

    def __init__(self, module: ModuleType) -> None:
        self._module = module

    @property
    def package(self) -> str | None:
        return self._module.__package__

    @property
    def file(self) -> str | None:
        return self._module.__file__

    @property
    def name(self) -> str:
        return self._module.__name__
+ """ + + def __init__(self, module: ModuleType) -> None: + self._module = module + + @property + def package(self) -> str | None: + return self._module.__package__ + + @property + def file(self) -> str | None: + return self._module.__file__ + + @property + def name(self) -> str: + return self._module.__name__ diff --git a/src/openjd/adaptor_runtime/adaptors/_command_adaptor.py b/src/openjd/adaptor_runtime/adaptors/_command_adaptor.py new file mode 100644 index 0000000..0a83db7 --- /dev/null +++ b/src/openjd/adaptor_runtime/adaptors/_command_adaptor.py @@ -0,0 +1,68 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +from abc import abstractmethod +from typing import TypeVar + +from .configuration import AdaptorConfiguration +from ..process import ManagedProcess +from ._base_adaptor import BaseAdaptor + +__all__ = [ + "CommandAdaptor", +] + +_T = TypeVar("_T", bound=AdaptorConfiguration) + + +class CommandAdaptor(BaseAdaptor[_T]): + """ + Base class for command adaptors that utilize a ManagedProcess. + + Derived classes must override the get_managed_process method, and + may optionally override the on_prerun and on_postrun methods. + """ + + def _start(self): # pragma: no cover + pass + + def _run(self, run_data: dict): + process = self.get_managed_process(run_data) + + self.on_prerun() + process.run() + self.on_postrun() + + def _stop(self): # pragma: no cover + pass + + def _cleanup(self): # pragma: no cover + pass + + @abstractmethod + def get_managed_process(self, run_data: dict) -> ManagedProcess: # pragma: no cover + """ + Gets the ManagedProcess for this adaptor to run. + + Args: + run_data (dict): The data required by the ManagedProcess. + + Returns: + ManagedProcess: The ManagedProcess to run. + """ + pass + + def on_prerun(self): # pragma: no cover + """ + Method that is invoked before the ManagedProcess is run. + You can override this method to run code before the ManagedProcess is run. 
+ """ + pass + + def on_postrun(self): # pragma: no cover + """ + Method that is invoked after the ManagedProcess is run. + You can override this method to run code after the ManagedProcess is run. + """ + pass diff --git a/src/openjd/adaptor_runtime/adaptors/_path_mapping.py b/src/openjd/adaptor_runtime/adaptors/_path_mapping.py new file mode 100644 index 0000000..b5b4996 --- /dev/null +++ b/src/openjd/adaptor_runtime/adaptors/_path_mapping.py @@ -0,0 +1,139 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +from pathlib import PurePath, PurePosixPath, PureWindowsPath + +from .._osname import OSName + +__all__ = [ + "PathMappingRule", +] + + +class PathMappingRule: + """A PathMappingRule represents how to transform a valid rooted path to another + valid rooted path within or across different platforms. + + This is useful for consolidating environments that refer to physical storage with + different paths. Consider the following example: + + Given: + - A Storage Device: "SharedStorage" + - A Windows instance: "Env1" with "SharedStorage" mounted at "Z:\\movie1" + - A Linux instance: "Env2" with "SharedStorage" mounted at "/mnt/shared/movie1" + + If "Env2" wanted to perform work generated by "Env1" that references "SharedStorage" + then we need to apply the following transformation to paths that reference "SharedStorage" + on "Env2": + "Z:\\movie1" -> "/mnt/shared/movie1" + + Some nuance to consider is that paths valid on one system, may not be valid on another. ie. + - Windows paths are capitalization/directory separator agnostic, whereas Posix is not + "Z:/MOVIE1" is equivalent to "z:\\movie1" + - Posix paths can have colons (":") and backslashes ("\\") in the filename whereas + Windows cannot + - When transforming a Windows path to Posix, any backslashes (directory separators) + must be changed to forward slashes + - Spurious slashes and single dots are collapsed, but ".." 
from __future__ import annotations

from pathlib import PurePath, PurePosixPath, PureWindowsPath

from .._osname import OSName

__all__ = [
    "PathMappingRule",
]


class PathMappingRule:
    """A PathMappingRule represents how to transform a valid rooted path to another
    valid rooted path within or across different platforms.

    This is useful for consolidating environments that refer to the same physical
    storage with different paths, e.g. a share mounted at "Z:\\movie1" on a Windows
    host and at "/mnt/shared/movie1" on a Linux host: work generated on the Windows
    host can be performed on the Linux host by mapping "Z:\\movie1" ->
    "/mnt/shared/movie1".

    Some nuance to consider is that paths valid on one system may not be valid on another:
    - Windows paths are capitalization/directory separator agnostic, whereas Posix is not
      ("Z:/MOVIE1" is equivalent to "z:\\movie1")
    - Posix paths can have colons (":") and backslashes ("\\") in the filename whereas
      Windows cannot
    - When transforming a Windows path to Posix, any backslashes (directory separators)
      must be changed to forward slashes
    - Spurious slashes and single dots are collapsed, but ".." components are not,
      since this would change the meaning of a path in the face of symbolic links
    """

    _pure_source_path: PurePath
    _pure_destination_path: PurePath

    def __init__(
        self,
        *,
        source_os: str,
        source_path: str,
        destination_path: str,
        destination_os: str = OSName(),
    ):
        # Reject missing/empty required fields up front
        for label, value in (
            ("source_os", source_os),
            ("source_path", source_path),
            ("destination_path", destination_path),
        ):
            if not value:
                raise ValueError(f"{label} cannot be None or empty")

        self.source_path: str = source_path
        self.destination_path: str = destination_path
        self._source_os: str = OSName(source_os)  # Raises ValueError if not valid OS
        self._is_windows_source: bool = OSName.is_windows(self._source_os)

        self._destination_os: str = OSName(destination_os)  # Raises ValueError if not valid OS
        self._is_windows_destination: bool = OSName.is_windows(self._destination_os)

        # Use the PurePath flavor matching each side's OS so comparisons follow
        # that platform's case/separator rules
        if self._is_windows_source:
            self._pure_source_path = PureWindowsPath(self.source_path)
        else:
            self._pure_source_path = PurePosixPath(self.source_path)

        if self._is_windows_destination:
            self._pure_destination_path = PureWindowsPath(self.destination_path)
        else:
            self._pure_destination_path = PurePosixPath(self.destination_path)

    def __eq__(self, other):
        return (
            self.source_path == other.source_path
            and self.destination_path == other.destination_path
            and self._is_windows_source == other._is_windows_source
            and self._is_windows_destination == other._is_windows_destination
        )

    @staticmethod
    def from_dict(*, rule: dict[str, str]) -> PathMappingRule:
        """Builds a PathMappingRule given a dict with the fields required by __init__

        raises TypeError, ValueError: if rule is None, an empty dict, or nonvalid"""
        if not rule:
            raise ValueError("Empty path mapping rule")

        return PathMappingRule(**rule)

    def to_dict(self) -> dict[str, str]:
        """Returns a dict representation of this rule containing the fields required
        by __init__ (i.e. the inverse of from_dict)."""
        return {
            "source_os": self._source_os,
            "source_path": self.source_path,
            "destination_os": self._destination_os,
            "destination_path": self.destination_path,
        }

    def apply(self, *, path: str) -> tuple[bool, str]:
        """Applies the path mapping rule on the given path, if it matches the rule.
        Does not collapse ".." since symbolic paths could be used.

        Returns: tuple[bool, str] - indicating if the path matched the rule and the resulting
        mapped path. If it doesn't match, then it returns the original path unmodified.
        """
        pure_path = self._get_pure_path(path)
        if not self._is_match(pure_path=pure_path):
            return False, path

        return True, str(self._swap_source_for_dest(pure_path))

    def _is_match(self, *, pure_path: PurePath) -> bool:
        """Determines if the supplied path matches the path mapping rule"""
        return pure_path.is_relative_to(self._pure_source_path)

    def _get_pure_path(self, path: str) -> PurePath:
        """Assumes that the path received matches the source os of the rule"""
        if self._is_windows_source:
            return PureWindowsPath(path)
        else:
            return PurePosixPath(path)

    def _swap_source_for_dest(self, pure_path: PurePath) -> PurePath:
        """Given that pure_path matches the rule, return a PurePath where the source
        parts are swapped for the destination parts"""
        new_parts = (
            self._pure_destination_path.parts + pure_path.parts[len(self._pure_source_path.parts) :]
        )
        if self._is_windows_destination:
            return PureWindowsPath(*new_parts)
        else:
            return PurePosixPath(*new_parts)
+ +from __future__ import annotations + +import json +import jsonschema +import logging +import os +import yaml +from typing import Any + + +_logger = logging.getLogger(__name__) + + +class AdaptorDataValidators: + """ + Class that contains validators for Adaptor input data. + """ + + @classmethod + def for_adaptor(cls, schema_dir: str) -> AdaptorDataValidators: + """ + Gets the validators for the specified adaptor. + + Args: + adaptor_name (str): The name of the adaptor + """ + init_data_schema_path = os.path.join(schema_dir, "init_data.schema.json") + _logger.info("Loading 'init_data' schema from %s", init_data_schema_path) + run_data_schema_path = os.path.join(schema_dir, "run_data.schema.json") + _logger.info("Loading 'run_data' schema from %s", run_data_schema_path) + + init_data_validator = AdaptorDataValidator.from_schema_file(init_data_schema_path) + + run_data_validator = AdaptorDataValidator.from_schema_file(run_data_schema_path) + + return AdaptorDataValidators(init_data_validator, run_data_validator) + + def __init__( + self, + init_data_validator: AdaptorDataValidator, + run_data_validator: AdaptorDataValidator, + ) -> None: + self._init_data_validator = init_data_validator + self._run_data_validator = run_data_validator + + @property + def init_data(self) -> AdaptorDataValidator: + """ + Gets the validator for init_data. + """ + return self._init_data_validator + + @property + def run_data(self) -> AdaptorDataValidator: + """ + Gets the validator for run_data. + """ + return self._run_data_validator + + +class AdaptorDataValidator: + """ + Class that validates the input data for an Adaptor. + """ + + @staticmethod + def from_schema_file(schema_path: str) -> AdaptorDataValidator: + """ + Creates an AdaptorDataValidator with the JSON schema at the specified file path. + + Args: + schema_path (str): The path to the JSON schema file to use. 
+ """ + try: + with open(schema_path) as schema_file: + schema = json.load(schema_file) + except json.JSONDecodeError as e: + _logger.error(f"Failed to decode JSON schema file: {e}") + raise + except OSError as e: + _logger.error(f"Failed to open JSON schema file at {schema_path}: {e}") + raise + + if not isinstance(schema, dict): + raise ValueError(f"Expected JSON schema to be a dict, but got {type(schema)}") + + return AdaptorDataValidator(schema) + + def __init__(self, schema: dict) -> None: + self._schema = schema + + def validate(self, data: str | dict) -> None: + """ + Validates that the data adheres to the schema. + + The data argument can be one of the following: + - A string containing the data file path. Must be prefixed with "file://". + - A string-encoded version of the data. + - A dictionary containing the data. + + Args: + data (dict): The data to validate. + + Raises: + jsonschema.ValidationError: Raised when the data failed validate against the schema. + jsonschema.SchemaError: Raised when the schema itself is nonvalid. + """ + if isinstance(data, str): + data = _load_data(data) + + jsonschema.validate(data, self._schema) + + +def _load_data(data: str) -> dict: + """ + Parses an input JSON/YAML (filepath or string-encoded) into a dictionary. + + Args: + data (str): The filepath or string representation of the JSON/YAML to parse. + If this is a filepath, it must begin with "file://" + + Raises: + ValueError: Raised when the JSON/YAML is not parsed to a dictionary. + """ + try: + loaded_data = _load_yaml_json(data) + except OSError as e: + _logger.error(f"Failed to open data file: {e}") + raise + except yaml.YAMLError as e: + _logger.error(f"Failed to load data as JSON or YAML: {e}") + raise + + if not isinstance(loaded_data, dict): + raise ValueError(f"Expected loaded data to be a dict, but got {type(loaded_data)}") + + return loaded_data + + +def _load_yaml_json(data: str) -> Any: + """ + Loads a YAML/JSON file/string. 
def _load_yaml_json(data: str) -> Any:
    """
    Loads a YAML/JSON document from a string, or from a file when the string
    is prefixed with "file://".

    Note that yaml.safe_load() is capable of loading JSON documents.
    """
    prefix = "file://"
    if not data.startswith(prefix):
        return yaml.safe_load(data)

    with open(data[len(prefix) :]) as yaml_file:
        return yaml.safe_load(yaml_file)
+ +class MyConfiguration(Configuration): + + _defaults = _make_function_register_decorator() + + @classmethod + def _get_defaults_decorator(cls) -> Any: + return cls._defaults + + @property + @_defaults + def my_config_key(self) -> str: + return self._config.get("my_config_key", "default_value") +""" + +from ._configuration import ( + AdaptorConfiguration, + Configuration, + RuntimeConfiguration, +) +from ._configuration_manager import ConfigurationManager + +__all__ = ["AdaptorConfiguration", "Configuration", "ConfigurationManager", "RuntimeConfiguration"] diff --git a/src/openjd/adaptor_runtime/adaptors/configuration/_adaptor_configuration.schema.json b/src/openjd/adaptor_runtime/adaptors/configuration/_adaptor_configuration.schema.json new file mode 100644 index 0000000..2130eaf --- /dev/null +++ b/src/openjd/adaptor_runtime/adaptors/configuration/_adaptor_configuration.schema.json @@ -0,0 +1,10 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "log_level": { + "enum": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + "default": "INFO" + } + } +} diff --git a/src/openjd/adaptor_runtime/adaptors/configuration/_configuration.py b/src/openjd/adaptor_runtime/adaptors/configuration/_configuration.py new file mode 100644 index 0000000..a2d7a69 --- /dev/null +++ b/src/openjd/adaptor_runtime/adaptors/configuration/_configuration.py @@ -0,0 +1,206 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + +from __future__ import annotations + +import copy +import json +import jsonschema +import logging +from json.decoder import JSONDecodeError +from typing import Any, List, Literal, Type, TypeVar + +__all__ = [ + "Configuration", + "RuntimeConfiguration", +] + +_logger = logging.getLogger(__name__) + + +def _make_function_register_decorator(): + """ + Creates a decorator function that registers functions. 
+ + If used on a function with the @property decorator, the outermost decorator must have + "# type: ignore" to avoid a mypy error. See https://github.com/python/mypy/issues/1362 + + See the comment block at the top of this file for more details. + """ + registry = {} + + def register(fn): + registry[fn.__name__] = fn + return fn + + register.registry = registry # type: ignore + return register + + +_T = TypeVar("_T", bound="Configuration") + + +class Configuration: + """ + General class for a JSON-based configuration. + + + This class should not be instantiated directly. Use one of the following class methods to + instantiate this class: + - `Configuration.from_file` + """ + + @classmethod + def _get_defaults_decorator(cls) -> Any: # pragma: no cover + """ + Virtual class method to get the defaults decorator. Defaults to an empty defaults registry. + + Override this in a subclass and return the value from _make_function_register_decorator() + to have default values automatically applied to the .config property getter return value. + + See the comment block at the top of this file for more details. + """ + + def register(_): + pass + + register.registry = {} # type: ignore + return register + + @classmethod + def from_file( + cls: Type[_T], config_path: str, schema_path: str | List[str] | None = None + ) -> _T: + """ + Loads a Configuration from a JSON file. + + Args: + config_path (str): The path to the JSON file containing the configuration + schema_path (str, List[str], Optional): The path(s) to the JSON Schema file to validate + the configuration JSON with. If multiple are specified, they will be used in the order + they are provided. If left as None, validation will be skipped. 
+ """ + + try: + config = json.load(open(config_path)) + except OSError as e: + _logger.error(f"Failed to open configuration at {config_path}: {e}") + raise + except JSONDecodeError as e: + _logger.error(f"Failed to decode configuration at {config_path}: {e}") + raise + + if schema_path is None: + _logger.warning( + f"JSON Schema file path not provided. " + f"Configuration file {config_path} will not be validated." + ) + return cls(config) + elif not schema_path: + raise ValueError(f"Schema path cannot be an empty {type(schema_path)}") + + schema_paths = schema_path if isinstance(schema_path, list) else [schema_path] + for path in schema_paths: + try: + schema = json.load(open(path)) + except OSError as e: + _logger.error(f"Failed to open configuration schema at {path}: {e}") + raise + except JSONDecodeError as e: + _logger.error(f"Failed to decode configuration schema at {path}: {e}") + raise + + try: + jsonschema.validate(config, schema) + except jsonschema.ValidationError as e: + _logger.error( + f"Configuration file at {config_path} failed to validate against the JSON " + f"schema at {schema_path}: {e}" + ) + raise + + return cls(config) + + def __init__(self, config: dict) -> None: + self._config = config + + def override(self: _T, other: _T) -> _T: + """ + Creates a new Configuration with the configuration values in this object overriden by + another configuration. + + Args: + other (Configuration): The configuration with the override values. + + Returns: + Configuration: A new Configuration with overridden values. + """ + return self.__class__(copy.deepcopy({**self._config, **other._config})) + + @property + def config(self) -> dict: + """ + Gets the configuration dictionary with defaults applied to any missing required fields. + + See the comment block at the top of this file for more details. 
+ """ + config = copy.deepcopy(self._config) + defaults = self.__class__._get_defaults_decorator() + + for fn_name, fn in defaults.registry.items(): + if fn_name not in config: # pragma: no branch + config[fn_name] = fn(self) + + return config + + +class RuntimeConfiguration(Configuration): + """ + Configuration for the Adaptor Runtime. + """ + + _defaults = _make_function_register_decorator() + + @classmethod + def _get_defaults_decorator(cls) -> Any: # pragma: no cover + return cls._defaults + + @property # type: ignore + @_defaults + def log_level( + self, + ) -> Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: # pragma: no cover # noqa: F821 + """ + The log level that is used in the runtime. + """ + return self._config.get("log_level", "INFO") + + @property # type: ignore + @_defaults + def deactivate_telemetry(self) -> bool: # pragma: no cover + """ + Indicates whether telemetry is deactivated or not. + """ + return self._config.get("deactivate_telemetry", False) + + +class AdaptorConfiguration(Configuration): + """ + Configuration for adaptors. + """ + + _defaults = _make_function_register_decorator() + + @classmethod + def _get_defaults_decorator(cls) -> Any: # pragma: no cover + return cls._defaults + + @property # type: ignore + @_defaults + def log_level( + self, + ) -> Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: # noqa: F821 # pragma: no cover + """ + The log level that is used in this adaptor. + """ + return self._config.get("log_level", "INFO") diff --git a/src/openjd/adaptor_runtime/adaptors/configuration/_configuration_manager.py b/src/openjd/adaptor_runtime/adaptors/configuration/_configuration_manager.py new file mode 100644 index 0000000..3a95ed7 --- /dev/null +++ b/src/openjd/adaptor_runtime/adaptors/configuration/_configuration_manager.py @@ -0,0 +1,273 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +import json +import logging +import os +import posixpath +import stat +from typing import Generic, List, Type, TypeVar + +from ..._utils import secure_open +from ..._osname import OSName +from ._configuration import AdaptorConfiguration, Configuration + +__all__ = [ + "ConfigurationManager", + "create_adaptor_configuration_manager", +] + +_logger = logging.getLogger(__name__) + +_DIR = os.path.dirname(os.path.realpath(__file__)) + +_ConfigType = TypeVar("_ConfigType", bound=Configuration) +_AdaptorConfigType = TypeVar("_AdaptorConfigType", bound=AdaptorConfiguration) + + +def create_adaptor_configuration_manager( + config_cls: Type[_AdaptorConfigType], + adaptor_name: str, + default_config_path: str, + schema_path: str | List[str] | None = None, + additional_config_paths: list[str] = [], +) -> ConfigurationManager[_AdaptorConfigType]: + """ + Creates a ConfigurationManager for an adaptor. + + Args: + config_cls (Type[U], optional): The adaptor configuration class to create a configuration + manager for. + adaptor_name (str): The name of the adaptor. + default_config_path (str): The path to the adaptor's default configuration file. + schema_path (str, List[str], optional): The path(s) to a JSON Schema file(s) to validate + the configuration with. If left as None, only the base adaptor configuration values will be + validated. 
+ """ + schema_paths = [os.path.abspath(os.path.join(_DIR, "_adaptor_configuration.schema.json"))] + if isinstance(schema_path, str): + schema_paths.append(schema_path) + elif isinstance(schema_path, list): + schema_paths.extend(schema_path) + + system_config_path_map = { + "Linux": posixpath.abspath( + posixpath.join( + posixpath.sep, + "etc", + "openjd", + "adaptors", + adaptor_name, + f"{adaptor_name}.json", + ) + ) + } + user_config_rel_path = os.path.join(".openjd", "adaptors", adaptor_name, f"{adaptor_name}.json") + + return ConfigurationManager( + config_cls=config_cls, + default_config_path=default_config_path, + system_config_path_map=system_config_path_map, + user_config_rel_path=user_config_rel_path, + schema_path=schema_paths, + additional_config_paths=additional_config_paths, + ) + + +def _ensure_config_file(filepath: str, *, create: bool = False) -> bool: + """ + Ensures a config file path points to a file. + + Args: + filepath (str): The file path to validate. + create (bool): Whether to create an empty config if the file does not exist. + + Returns: + bool: True if the path points to a file, false otherwise. If create is set to True, this + function will return True if the file was created successfully, false otherwise. + """ + if not os.path.exists(filepath): + _logger.debug(f'Configuration file at "{filepath}" does not exist.') + if not create: + return False + + _logger.info(f"Creating empty configuration at {filepath}") + try: + os.makedirs(os.path.dirname(filepath), mode=stat.S_IRWXU, exist_ok=True) + with secure_open(filepath, open_mode="w") as f: + json.dump({}, f) + except OSError as e: + _logger.warning(f"Could not write empty configuration to {filepath}: {e}") + return False + else: + return True + elif not os.path.isfile(filepath): + _logger.debug(f'Configuration file at "{filepath}" is not a file.') + return False + else: + return True + + +class ConfigurationManager(Generic[_ConfigType]): + """ + Class that manages configuration. 
+ """ + + def __init__( + self, + *, + config_cls: Type[_ConfigType], + default_config_path: str, + system_config_path_map: dict, + user_config_rel_path: str, + schema_path: str | List[str] | None = None, + additional_config_paths: list[str] = [], + ) -> None: + """ + Initializes a ConfigurationManager object. + + Args: + config_cls (Type[T]): The Configuration class that this class manages. + default_config_path (str): The path to the default configuration JSON file. + system_config_path_map (dict): A dictionary containing a mapping of OS names to system + configuration file path. + user_config_rel_path (str): The path to the user configuration file relative to the + user's home directory. + schema_path (str, List[str], Optional): The path(s) to the JSON Schema file to use. + If multiple are given then they will be used in the order they are provided. + If none are given then validation will be skipped for configuration files. + additional_config_paths (list[str]): Paths to additional configuration files. These + will have the highest priority and will be applied in the order they are provided. + """ + self._config_cls = config_cls + self._schema_path = schema_path + self._default_config_path = default_config_path + self._system_config_path_map = system_config_path_map + self._user_config_rel_path = user_config_rel_path + self._additional_config_paths = additional_config_paths + + def get_default_config(self) -> _ConfigType: + """ + Gets the default configuration. + """ + if not _ensure_config_file(self._default_config_path, create=True): + _logger.warning( + f"Default configuration file at {self._default_config_path} is not a valid file. " + "Using empty configuration." + ) + return self._config_cls({}) + + return self._config_cls.from_file(self._default_config_path, self._schema_path) + + def get_system_config_path(self) -> str: + """ + Gets the system-level configuration file path. 
+ + Raises: + NotImplementedError: Raised when no mapping exists for the system config path on + the current system/OS. + """ + try: + system = OSName() + except ValueError: # Can happen on unsupported platforms like Java + raise NotImplementedError() + + path = self._system_config_path_map.get(system, None) + if path is None: + raise NotImplementedError() + + return path + + def get_system_config(self) -> _ConfigType | None: + """ + Gets the system-level configuration. Any values defined here will override the default + configuration. + """ + config_path = self.get_system_config_path() + return ( + self._config_cls.from_file(config_path, self._schema_path) + if _ensure_config_file(config_path) + else None + ) + + def get_user_config_path(self, username: str | None = None) -> str: + """ + Gets the user-level configuration file path. + + Args: + username (str, optional): The username to get the configuration for. Defaults to the + current user. + """ + user = f"~{username}" if username else "~" + + # os.path.expanduser works cross-platform (Windows & UNIX) + return os.path.expanduser(os.path.join(user, self._user_config_rel_path)) + + def get_user_config(self, username: str | None = None) -> _ConfigType | None: + """ + Gets the user-level configuration. Any values defined here will override the default and + system configuration. + + Args: + username (str, optional): The username to get the configuration for. Defaults to the + current user. + """ + config_path = self.get_user_config_path(username) + return ( + self._config_cls.from_file(config_path, self._schema_path) + if _ensure_config_file(config_path, create=True) + else None + ) + + def build_config(self, username: str | None = None) -> _ConfigType: + """ + Builds a Configuration with the default, system, and user level configuration files. + + Args: + username (str, optional): The username to use for the user-level configuration. + Defaults to the current user. 
+ """ + + def log_diffs(a: _ConfigType, b: _ConfigType): + def _config_to_set(config: _ConfigType): + # Convert inner dicts to str because elements in a set must be hashable. + # This aligns with our override logic that only overrides top-level keys. + return set( + [ + (k, v if not isinstance(v, dict) else str(v)) + for k, v in config._config.items() + ] + ) + + diffs = dict(_config_to_set(a) - _config_to_set(b)) + for k, v in diffs.items(): + _logger.info(f"Set {k} to {v}") + + config: _ConfigType = self.get_default_config() + + system_config = self.get_system_config() + if system_config: + _logger.info(f"Applying system-level configuration: {self.get_system_config_path()}") + old_config = config + config = config.override(system_config) + log_diffs(config, old_config) + + user_config = self.get_user_config(username) + if user_config: + _logger.info(f"Applying user-level configuration: {self.get_user_config_path()}") + old_config = config + config = config.override(user_config) + log_diffs(config, old_config) + + for path in self._additional_config_paths: + if not _ensure_config_file(path): + _logger.warning(f"Failed to load additional configuration: {path}. Skipping...") + continue + + _logger.info(f"Applying additional configuration: {path}") + old_config = config + config = config.override(self._config_cls.from_file(path)) + log_diffs(config, old_config) + + return config diff --git a/src/openjd/adaptor_runtime/app_handlers/__init__.py b/src/openjd/adaptor_runtime/app_handlers/__init__.py new file mode 100644 index 0000000..e7083ee --- /dev/null +++ b/src/openjd/adaptor_runtime/app_handlers/__init__.py @@ -0,0 +1,5 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
@dataclass
class RegexCallback:
    """
    Associates a callback with the regex patterns that trigger it.
    """

    regex_list: List[re.Pattern[str]]
    callback: Callable[[re.Match], None]
    exit_if_matched: bool = False
    only_run_if_first_matched: bool = False

    def __init__(
        self,
        regex_list: Sequence[re.Pattern[str]],
        callback: Callable[[re.Match], None],
        exit_if_matched: bool = False,
        only_run_if_first_matched: bool = False,
    ) -> None:
        """
        Initializes a RegexCallback.

        Args:
            regex_list (Sequence[re.Pattern[str]]): Patterns that invoke the callback when any
                single one of them matches a logged string. Stored as a new list object,
                separate from the sequence passed in.
            callback (Callable[[re.Match], None]): Called with the re.Match object produced by
                the pattern that matched the tested string.
            exit_if_matched (bool, optional): If True, the handler stops processing further
                RegexCallbacks once this one matches, preventing later callbacks from being
                invoked. Defaults to False.
            only_run_if_first_matched (bool, optional): If True, the callback only runs when
                this RegexCallback is the first to have a regex match a logged line.
                Defaults to False.
        """
        self.regex_list = list(regex_list)
        self.callback = callback
        self.exit_if_matched = exit_if_matched
        self.only_run_if_first_matched = only_run_if_first_matched

    def get_match(self, msg: str) -> re.Match | None:
        """
        Provides the match from the first regex in self.regex_list that matches a given msg.

        Args:
            msg (str): A message to test against each regex in the regex_list

        Returns:
            re.Match | None: The match object from the first regex that matched the message,
            None if no regex matched.
        """
        return next(
            (match for match in (pattern.search(msg) for pattern in self.regex_list) if match),
            None,
        )
class ActionsQueue:
    """Manages the queue of Actions for the client: enqueue at either end,
    dequeue from the front (FIFO by default)."""

    _actions_queue: Deque[Action]

    def __init__(self) -> None:
        self._actions_queue = deque()

    def enqueue_action(self, a: Action, front: bool = False) -> None:
        """Enqueues the action, at the back of the queue by default.

        Args:
            a (Action): The action to be enqueued.
            front (bool, optional): Whether we want to append to the front of the queue.
            Defaults to False.
        """
        appender = self._actions_queue.appendleft if front else self._actions_queue.append
        appender(a)

    def dequeue_action(self) -> Optional[Action]:
        """Removes and returns the action at the front of the queue, or None
        when the queue is empty."""
        try:
            return self._actions_queue.popleft()
        except IndexError:
            return None

    def __bool__(self) -> bool:
        return len(self._actions_queue) > 0

    def __len__(self) -> int:
        return len(self._actions_queue)
+ """ + + actions_queue: ActionsQueue + adaptor: BaseAdaptor + + def __init__( + self, + actions_queue: ActionsQueue, + adaptor: BaseAdaptor, + ) -> None: # pragma: no cover + socket_path = SocketDirectories.for_os().get_process_socket_path("dcc", create_dir=True) + super().__init__(socket_path, AdaptorHTTPRequestHandler) + + self.actions_queue = actions_queue + self.adaptor = adaptor + self.socket_path = socket_path + + def shutdown(self) -> None: # pragma: no cover + super().shutdown() + + try: + os.remove(self.socket_path) + except FileNotFoundError: + pass diff --git a/src/openjd/adaptor_runtime/application_ipc/_http_request_handler.py b/src/openjd/adaptor_runtime/application_ipc/_http_request_handler.py new file mode 100644 index 0000000..5918e89 --- /dev/null +++ b/src/openjd/adaptor_runtime/application_ipc/_http_request_handler.py @@ -0,0 +1,136 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import json +import sys +from http import HTTPStatus +from time import sleep +from typing import TYPE_CHECKING +from typing import Optional + +from .._http import HTTPResponse, RequestHandler, ResourceRequestHandler + +if TYPE_CHECKING: # pragma: no cover because pytest will think we should test for this. + from openjd.adaptor_runtime_client import Action + + from ._adaptor_server import AdaptorServer + + +class AdaptorHTTPRequestHandler(RequestHandler): + """This is the HTTPRequestHandler to be used by the Adaptor Server. This class is + where we will dequeue the actions from the queue and pass it in a response to a client. + """ + + server: AdaptorServer # This is here for type hinting. 
class PathMappingEndpoint(AdaptorResourceRequestHandler):
    path = "/path_mapping"

    def get(self) -> HTTPResponse:
        """
        GET Handler for the Path Mapping Endpoint.

        Maps the "path" query-string parameter through the adaptor's path
        mapping rules.

        Returns:
            HTTPResponse: A body and response code to send to the DCC Client
        """
        try:
            params = self.query_string_params
            if "path" not in params:
                return HTTPResponse(HTTPStatus.BAD_REQUEST, "Missing path in query string.")
            mapped = self.server.adaptor.map_path(params["path"][0])
            return HTTPResponse(HTTPStatus.OK, json.dumps({"path": mapped}))
        except Exception as e:
            # Any failure (including during query-string parsing) becomes a 500
            # whose body is the error text.
            return HTTPResponse(HTTPStatus.INTERNAL_SERVER_ERROR, body=str(e))
+ while action is None: + sleep(0.01) + action = self._dequeue_action() + + return HTTPResponse(HTTPStatus.OK, str(action)) + + def _dequeue_action(self) -> Optional[Action]: + """This function will dequeue the first action in the queue. + + Returns: + Action: A tuple containing the next action structured: + ("action_name", { "args1": "val1", "args2": "val2" }) + + None: If the Actions Queue is empty. + + Raises: + TypeError: If the server isn't an AdaptorServer. + """ + # This condition shouldn't matter, because we have typehinted the server above. + # This is only here for type hinting (as is the return None below). + if hasattr(self, "server") and hasattr(self.server, "actions_queue"): + return self.server.actions_queue.dequeue_action() + + print( + "ERROR: Could not retrieve the next action because the server or actions queue " + "wasn't set.", + file=sys.stderr, + flush=True, + ) + return None diff --git a/src/openjd/adaptor_runtime/configuration.json b/src/openjd/adaptor_runtime/configuration.json new file mode 100644 index 0000000..205ae71 --- /dev/null +++ b/src/openjd/adaptor_runtime/configuration.json @@ -0,0 +1,4 @@ +{ + "log_level": "INFO", + "deactivate_telemetry": false +} \ No newline at end of file diff --git a/src/openjd/adaptor_runtime/configuration.schema.json b/src/openjd/adaptor_runtime/configuration.schema.json new file mode 100644 index 0000000..2444b80 --- /dev/null +++ b/src/openjd/adaptor_runtime/configuration.schema.json @@ -0,0 +1,14 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "log_level": { + "enum": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + "default": "INFO" + }, + "deactivate_telemetry": { + "type": "boolean", + "default": false + } + } +} diff --git a/src/openjd/adaptor_runtime/process/__init__.py b/src/openjd/adaptor_runtime/process/__init__.py new file mode 100644 index 0000000..2627c4b --- /dev/null +++ b/src/openjd/adaptor_runtime/process/__init__.py @@ 
-0,0 +1,11 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from ._logging_subprocess import LoggingSubprocess +from ._managed_process import ManagedProcess +from ._stream_logger import StreamLogger + +__all__ = [ + "LoggingSubprocess", + "ManagedProcess", + "StreamLogger", +] diff --git a/src/openjd/adaptor_runtime/process/_logging.py b/src/openjd/adaptor_runtime/process/_logging.py new file mode 100644 index 0000000..2d068b1 --- /dev/null +++ b/src/openjd/adaptor_runtime/process/_logging.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import logging + +_STDOUT_LEVEL = logging.ERROR + 1 +logging.addLevelName(_STDOUT_LEVEL, "STDOUT") + +_STDERR_LEVEL = logging.ERROR + 2 +logging.addLevelName(_STDERR_LEVEL, "STDERR") + +_ADAPTOR_OUTPUT_LEVEL = logging.ERROR + 3 +logging.addLevelName(_ADAPTOR_OUTPUT_LEVEL, "ADAPTOR_OUTPUT") diff --git a/src/openjd/adaptor_runtime/process/_logging_subprocess.py b/src/openjd/adaptor_runtime/process/_logging_subprocess.py new file mode 100644 index 0000000..853e7b7 --- /dev/null +++ b/src/openjd/adaptor_runtime/process/_logging_subprocess.py @@ -0,0 +1,220 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +"""Module for the LoggingSubprocess class""" +from __future__ import annotations + +import logging +import subprocess +import sys +import uuid +from types import TracebackType +from typing import Any, Sequence, TypeVar + +from ..app_handlers import RegexHandler +from ._logging import _STDERR_LEVEL, _STDOUT_LEVEL +from ._stream_logger import StreamLogger + +__all__ = ["LoggingSubprocess"] + +_logger = logging.getLogger(__name__) + + +class LoggingSubprocess(object): + """A process whose stdout/stderr lines are sent to a configurable logger""" + + _logger: logging.Logger + _process: subprocess.Popen + _stdout_logger: StreamLogger + _stderr_logger: StreamLogger + _terminate_threads: bool + + def __init__( + self, + *, + # Required keyword-only arguments + args: Sequence[str], + # Optional keyword-only arguments + startup_directory: str | None = None, # This is None, because Popen's default is None + logger: logging.Logger = _logger, + stdout_handler: RegexHandler | None = None, + stderr_handler: RegexHandler | None = None, + encoding: str = "utf-8", + ): + if not logger: + raise ValueError("No logger specified") + if not args or len(args) < 1: + raise ValueError("Insufficient args") + + self._terminate_threads = False + self._logger = logger + + self._logger.info("Running command: %s", subprocess.list2cmdline(args)) + + # Create the subprocess + self._process = subprocess.Popen( + args, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + encoding=encoding, + cwd=startup_directory, + ) + + if not self._process.stdout: # pragma: no cover + raise RuntimeError("process stdout not set") + if not self._process.stderr: # pragma: no cover + raise RuntimeError("process stdout not set") + + # Create the stdout/stderr stream logging threads + stdout_loggers = [self._logger] + stderr_loggers = [self._logger] + proc_uuid = uuid.uuid4() # ensure loggers are unique to this process for regexhandlers + + def _register_handler(logger_name: str, handler: 
RegexHandler) -> logging.Logger: + """Registers a handler with the logger name provided and returns the logger""" + handler_logger = logging.getLogger(logger_name) + handler_logger.setLevel(1) + handler_logger.addHandler(handler) + return handler_logger + + if stdout_handler is not None: + stdout_loggers.append(_register_handler(f"stdout-{proc_uuid}", stdout_handler)) + if stderr_handler is not None: + stderr_loggers.append(_register_handler(f"stderr-{proc_uuid}", stderr_handler)) + + self._stdout_logger = StreamLogger( + name="AdaptorRuntimeStdoutLogger", + stream=self._process.stdout, + loggers=stdout_loggers, + level=_STDOUT_LEVEL, + ) + self._stderr_logger = StreamLogger( + name="AdaptorRuntimeStderrLogger", + stream=self._process.stderr, + loggers=stderr_loggers, + level=_STDERR_LEVEL, + ) + + self._stdout_logger.start() + self._stderr_logger.start() + + @property + def pid(self) -> int: + """Returns the PID of the sub-process""" + return self._process.pid + + @property + def returncode(self) -> int | None: + """ + Before accessing this property, ensure the process has been terminated (calling wait() or + terminate()). You can check is_running before accessing this value. + + :return: None if the process has not yet exited. Otherwise, it returns the exit code of the + process + """ + # poll() is required to update the returncode + # See https://docs.python.org/3/library/subprocess.html#subprocess.Popen.poll + poll_result = self._process.poll() + return poll_result + + @property + def is_running(self) -> bool: + """ + Determine whether the subprocess is running. + :return: True if it is running; False otherwise + """ + return self._process is not None and self._process.poll() is None + + def __enter__(self) -> LoggingSubprocess: + return self + + def __exit__(self, type: TypeVar, value: Any, traceback: TracebackType) -> None: + self.wait() + + def _cleanup_io_threads(self) -> None: + self._logger.debug( + "Finished terminating/waiting for the process. 
About to cleanup the IO threads." + ) + + # Wait for the logging threads to exit + self._terminate_threads = True + + self._stdout_logger.join() + if not self._process.stdout: # pragma: no cover + raise RuntimeError("process stdout not piped") + # Must be after the join; before will cause an exception due to file in use. + self._process.stdout.close() + + self._stderr_logger.join() + if not self._process.stderr: # pragma: no cover + raise RuntimeError("process stderr not piped") + # Must be after the join; before will cause an exception due to file in use. + self._process.stderr.close() + + def terminate(self, grace_time_s: float = 60) -> None: + """ + Sends a signal to soft terminate (SIGTERM) the process after the passed grace time (in + seconds). If the grace time is 0 or the process hasn't terminated after the grace period, + sending SIGKILL to interrupt/terminate the process. + """ + if not self._process or self._terminate_threads: + return + self._logger.debug(f"Asked to terminate the subprocess (pid={self._process.pid}).") + + if not self.is_running: + self._logger.info("Cannot terminate the process, because it is not running.") + return + + # If we want to stop the process immediately. + if grace_time_s == 0: + self._logger.info(f"Immediately stopping process (pid={self._process.pid}).") + self._process.kill() + self._process.wait() + else: + # On Windows, process.kill is an alias for process.terminate. We may want to replicate + # the behaviour of this function on Windows. + # Here is a blog post for reference: https://maruel.ca/post/python_windows_signal/ + if sys.platform == "win32": + raise NotImplementedError() + + self._logger.info( + f"Sending the SIGTERM signal to pid={self._process.pid} and waiting {grace_time_s}" + " seconds for it to exit." 
+ ) + self._process.terminate() # SIGTERM + + try: + self._process.wait(timeout=grace_time_s) + self._logger.info(f"Finished terminating the subprocess (pid={self._process.pid}).") + except subprocess.TimeoutExpired: + self._logger.info( + f"Process (pid={self._process.pid}) did not complete in the allotted time " + "after the SIGTERM signal, now sending the SIGKILL signal." + ) + self._process.kill() # SIGKILL, on Windows, this is an alias for terminate + self._process.wait() + self._logger.info(f"Finished killing the subprocess (pid={self._process.pid}).") + + # _process.communicate will close the _process.stdout and _process.stderr. + self._cleanup_io_threads() + + def wait(self) -> None: + """ + Waits for the process to finish. + """ + if not self._process or self._terminate_threads: + return + self._logger.info(f"Asked to wait for the subprocess (pid={self._process.pid}) to finish.") + + # Wait for the running process + if self.is_running: + if not self._process.stdin: # pragma: no cover + raise RuntimeError("process stdin not piped") + self._process.stdin.close() + + self._logger.debug(f"Telling pid {self._process.pid} to wait.") + self._process.wait() + + self._logger.info(f"Finished waiting for the subprocess (pid={self._process.pid}).") + + self._cleanup_io_threads() diff --git a/src/openjd/adaptor_runtime/process/_managed_process.py b/src/openjd/adaptor_runtime/process/_managed_process.py new file mode 100644 index 0000000..b27b79e --- /dev/null +++ b/src/openjd/adaptor_runtime/process/_managed_process.py @@ -0,0 +1,71 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +"""Module for the ManagedProcess class""" +from __future__ import annotations + +from abc import ABC as ABC, abstractmethod +from typing import List + +from ..app_handlers import RegexHandler +from ._logging_subprocess import LoggingSubprocess + +__all__ = ["ManagedProcess"] + + +class ManagedProcess(ABC): + def __init__( + self, + run_data: dict, + *, + stdout_handler: RegexHandler | None = None, + stderr_handler: RegexHandler | None = None, + ): + self.run_data = run_data + self.stdout_handler = stdout_handler + self.stderr_handler = stderr_handler + + # =============================================== + # Callbacks / virtual functions. + # =============================================== + + @abstractmethod + def get_executable(self) -> str: # pragma: no cover + """ + Return the path of the executable to run. + """ + raise NotImplementedError() + + def get_arguments(self) -> List[str]: # pragma: no cover + """ + Return the args (as a list) to be used with the executable. + """ + return [] + + def get_startup_directory(self) -> str | None: # pragma: no cover + """ + Returns The directory that the executable should be run from. + Note: Does not require that spaces be escaped + """ + # This defaults to None because that is the default for Popen. + return None + + # =============================================== + # Render Control + # =============================================== + + def run(self): + """ + Create a LoggingSubprocess to run the command. 
+ """ + exec = self.get_executable() + args = self.get_arguments() + args = [exec] + args + startup_directory = self.get_startup_directory() + + subproc = LoggingSubprocess( + args=args, + startup_directory=startup_directory, + stdout_handler=self.stdout_handler, + stderr_handler=self.stderr_handler, + ) + subproc.wait() diff --git a/src/openjd/adaptor_runtime/process/_stream_logger.py b/src/openjd/adaptor_runtime/process/_stream_logger.py new file mode 100644 index 0000000..8627106 --- /dev/null +++ b/src/openjd/adaptor_runtime/process/_stream_logger.py @@ -0,0 +1,62 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +"""Module for the StreamLogger class""" +from __future__ import annotations + +import logging +import os +from threading import Thread +from typing import IO, Sequence + + +class StreamLogger(Thread): + """A thread that reads a text stream line-by-line and logs each line to a specified logger""" + + def __init__( + self, + *args, + # Required keyword-only arguments + stream: IO[str], + loggers: Sequence[logging.Logger], + # Optional keyword-only arguments + level: int = logging.INFO, + **kwargs, + ): + super(StreamLogger, self).__init__(*args, **kwargs) + self._stream = stream + self._loggers = list(loggers) + self._level = level + + # Without setting daemon to False, we run into an issue in which all output may NOT be + # printed. From the python docs: + # > The entire Python program exits when no alive non-daemon threads are left. + # Reference: https://docs.python.org/3/library/threading.html#threading.Thread.daemon + self.daemon = False + + def _log(self, line: str, level: int | None = None): + """ + Logs a line to each logger at the provided level or self._level is no level is provided. 
+ Args: + line (str): The line to log + level (int): The level to log the line at + """ + if level is None: + level = self._level + + for logger in self._loggers: + logger.log(level, line) + + def run(self): + try: + for line in iter(self._stream.readline, ""): + line = line.rstrip(os.linesep) + self._log(line) + except ValueError as e: + if "I/O operation on closed file" in str(e): + self._log( + "The StreamLogger could not read from the stream. This is most likely because " + "the stream was closed before the stream logger.", + logging.WARNING, + ) + else: + raise diff --git a/src/openjd/adaptor_runtime/py.typed b/src/openjd/adaptor_runtime/py.typed new file mode 100644 index 0000000..7ef2116 --- /dev/null +++ b/src/openjd/adaptor_runtime/py.typed @@ -0,0 +1 @@ +# Marker file that indicates this package supports typing diff --git a/src/openjd/adaptor_runtime_client/__init__.py b/src/openjd/adaptor_runtime_client/__init__.py new file mode 100644 index 0000000..a85384b --- /dev/null +++ b/src/openjd/adaptor_runtime_client/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from .action import Action +from .client_interface import ( + HTTPClientInterface, + PathMappingRule, +) + +__all__ = [ + "Action", + "HTTPClientInterface", + "PathMappingRule", +] diff --git a/src/openjd/adaptor_runtime_client/action.py b/src/openjd/adaptor_runtime_client/action.py new file mode 100644 index 0000000..63f7c52 --- /dev/null +++ b/src/openjd/adaptor_runtime_client/action.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +import json as _json +import sys as _sys +from dataclasses import asdict as _asdict +from dataclasses import dataclass as _dataclass +from typing import Any as _Any +from typing import Dict as _Dict +from typing import Optional as _Optional + + +@_dataclass(frozen=True) +class Action: + """This is the class representation of the Actions to be performed on the DCC.""" + + name: str + args: _Optional[_Dict[str, _Any]] = None + + def __str__(self) -> str: + return _json.dumps(_asdict(self)) + + @staticmethod + def from_json_string(json_str: str) -> _Optional[Action]: + try: + ad = _json.loads(json_str) + except Exception as e: + print( + f'ERROR: Unable to convert "{json_str}" to json. The following exception was ' + f"raised:\n{e}", + file=_sys.stderr, + flush=True, + ) + return None + + try: + return Action(ad["name"], ad["args"]) + except Exception as e: + print( + f"ERROR: Unable to convert the json dictionary ({ad}) to an action. The following " + f"exception was raised:\n{e}", + file=_sys.stderr, + flush=True, + ) + return None + + @staticmethod + def from_bytes(s: bytes) -> _Optional[Action]: + return Action.from_json_string(s.decode()) diff --git a/src/openjd/adaptor_runtime_client/client_interface.py b/src/openjd/adaptor_runtime_client/client_interface.py new file mode 100644 index 0000000..9c98d27 --- /dev/null +++ b/src/openjd/adaptor_runtime_client/client_interface.py @@ -0,0 +1,224 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +import json as _json +import signal as _signal +import sys as _sys +from abc import ABC as _ABC +from abc import abstractmethod as _abstractmethod +from dataclasses import dataclass as _dataclass +from functools import lru_cache as _lru_cache +from http import HTTPStatus as _HTTPStatus +from types import FrameType as _FrameType +from typing import ( + Any as _Any, + Callable as _Callable, + Dict as _Dict, + List as _List, + Tuple as _Tuple, +) +from urllib.parse import urlencode as _urlencode + +from .action import Action as _Action +from .connection import UnixHTTPConnection as _UnixHTTPConnection + +# Set timeout to None so our requests are blocking calls with no timeout. +# See socket.settimeout +_REQUEST_TIMEOUT = None + + +# Based on adaptor runtime's PathMappingRule class +# This is needed because we cannot import from adaptor runtime directly +# due to some applications running an older Python version that can't import newer typing +@_dataclass +class PathMappingRule: + source_os: str + source_path: str + destination_path: str + destination_os: str + + +class HTTPClientInterface(_ABC): + actions: _Dict[str, _Callable[..., None]] + socket_path: str + + def __init__(self, socket_path: str) -> None: + """When the client is created, we need the port number to connect to the server. + + Args: + socket_path (str): The path to the UNIX domain socket to use. + """ + self.socket_path = socket_path + self.actions = { + "close": self.close, + } + + # NOTE: The signals SIGKILL and SIGSTOP cannot be caught, blocked, or ignored. + # Reference: https://man7.org/linux/man-pages/man7/signal.7.html + # SIGTERM graceful shutdown. + _signal.signal(_signal.SIGTERM, self.graceful_shutdown) + + def _request_next_action(self) -> _Tuple[int, str, _Action | None]: + """Sending a get request to the server on the /action endpoint. + This will be used to poll for the next action from the Adaptor server. 
+ + Returns: + _Tuple[int, str, _Action | None]: Returns the status code (int), the status reason + (str), the action if one was received (_Action | None). + """ + headers = { + "Content-type": "application/json", + } + connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT) + connection.request("GET", "/action", headers=headers) + response = connection.getresponse() + connection.close() + + action = None + if response.length: + action = _Action.from_bytes(response.read()) + return response.status, response.reason, action + + @_lru_cache(maxsize=None) + def map_path(self, path: str) -> str: + """Sending a get request to the server on the /path_mapping endpoint. + This will be used to get the Adaptor to map a given path. + + Returns: + str: The mapped path + + Raises: + RuntimeError: When the client fails to get a mapped path from the server. + """ + headers = { + "Content-type": "application/json", + } + connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT) + print(f"Requesting Path Mapping for path '{path}'.", flush=True) + connection.request("GET", "/path_mapping?" + _urlencode({"path": path}), headers=headers) + response = connection.getresponse() + connection.close() + + if response.status == _HTTPStatus.OK and response.length: + response_dict = _json.loads(response.read().decode()) + mapped_path = response_dict.get("path") + if mapped_path is not None: # pragma: no branch # HTTP 200 guarantees a mapped path + print(f"Mapped path '{path}' to '{mapped_path}'.", flush=True) + return mapped_path + reason = response.read().decode() if response.length else "" + raise RuntimeError( + f"ERROR: Failed to get a mapped path for path '{path}'. " + f"Server response: Status: {int(response.status)}, Response: '{reason}'", + ) + + @_lru_cache(maxsize=None) + def path_mapping_rules(self) -> _List[PathMappingRule]: + """Sending a get request to the server on the /path_mapping_rules endpoint. 
+ This will be used to get the Adaptor to map a given path. + + Returns: + _List[_PathMappingRule]: The list of path mapping rules + + Raises: + RuntimeError: When the client fails to get a mapped path from the server. + """ + headers = { + "Content-type": "application/json", + } + connection = _UnixHTTPConnection(self.socket_path, timeout=_REQUEST_TIMEOUT) + print("Requesting Path Mapping Rules.", flush=True) + connection.request("GET", "/path_mapping_rules", headers=headers) + response = connection.getresponse() + connection.close() + + if response.status != _HTTPStatus.OK or not response.length: + reason = response.read().decode() if response.length else "" + raise RuntimeError( + f"ERROR: Failed to get a path mapping rules. " + f"Server response: Status: {int(response.status)}, Response: '{reason}'", + ) + + try: + response_dict = _json.loads(response.read().decode()) + except _json.JSONDecodeError as e: + raise RuntimeError( + f"Expected JSON string from /path_mapping_rules endpoint, but got error: {e}", + ) + + rule_list = response_dict.get("path_mapping_rules") + if not isinstance(rule_list, list): + raise RuntimeError( + f"Expected list for path_mapping_rules, but got: {rule_list}", + ) + + rules: _List[PathMappingRule] = [] + for rule in rule_list: + try: + rules.append(PathMappingRule(**rule)) + except TypeError as e: + raise RuntimeError( + f"Expected PathMappingRule object, but got: {rule}\nAll rules: {rule_list}\nError: {e}", + ) + + return rules + + def poll(self) -> None: + """ + This function will poll the server for the next task. If the server is in between Subtasks + (no actions in the queue), a backoff function will be called to add a delay between the + requests. 
+ """ + run = True + while run: + status, reason, action = self._request_next_action() + if status == _HTTPStatus.OK: + if action is not None: + print( + f"Performing action: {action}", + flush=True, + ) + self._perform_action(action) + run = action.name != "close" + else: # Any other status or reason + print( + f"ERROR: An error was raised when trying to connect to the server: {status} " + f"{reason}", + file=_sys.stderr, + flush=True, + ) + + def _perform_action(self, a: _Action) -> None: + try: + action_func = self.actions[a.name] + except KeyError: + print( + f"ERROR: Attempted to perform the following action: {a}. But this action doesn't " + "exist in the actions dictionary.", + file=_sys.stderr, + flush=True, + ) + else: + action_func(a.args) + + @_abstractmethod + def close(self, args: _Dict[str, _Any] | None) -> None: # pragma: no cover + """This is the close function which will be called to cleanup the Application. + + Args: + args (_Dict[str, _Any] | None): The arguments (if any) required to perform the + cleanup. + """ + pass + + @_abstractmethod + def graceful_shutdown(self, signum: int, frame: _FrameType | None) -> None: # pragma: no cover + """This is the function when we cancel. This function is called when a SIGTERM signal is + received. This functions will need to be implemented for each application we want to + support because the clean up will be different for each application. + + Args: + signum (int): The signal number. + frame (_FrameType | None): The current stack frame (None or a frame object). + """ + pass diff --git a/src/openjd/adaptor_runtime_client/connection.py b/src/openjd/adaptor_runtime_client/connection.py new file mode 100644 index 0000000..5f94af3 --- /dev/null +++ b/src/openjd/adaptor_runtime_client/connection.py @@ -0,0 +1,73 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +import socket as _socket +import ctypes as _ctypes +import os as _os +from http.client import HTTPConnection as _HTTPConnection + + +class UnrecognizedBackgroundConnectionError(Exception): + pass + + +class UCred(_ctypes.Structure): + """ + Represents the ucred struct returned from the SO_PEERCRED socket option. + + For more info, see SO_PASSCRED in the unix(7) man page + """ + + _fields_ = [ + ("pid", _ctypes.c_int), + ("uid", _ctypes.c_int), + ("gid", _ctypes.c_int), + ] + + def __str__(self): # pragma: no cover + return f"pid:{self.pid} uid:{self.uid} gid:{self.gid}" + + +class UnixHTTPConnection(_HTTPConnection): # pragma: no cover + """ + Specialization of http.client.HTTPConnection class that uses a UNIX domain socket. + """ + + def __init__(self, host, **kwargs): + kwargs.pop("strict", None) # Removed in py3 + super(UnixHTTPConnection, self).__init__(host, **kwargs) + + def connect(self): + sock = _socket.socket(_socket.AF_UNIX, _socket.SOCK_STREAM) + sock.settimeout(self.timeout) + sock.connect(self.host) + self.sock = sock + + # Verify that the socket belongs to the same user + if not self._authenticate(): + sock.detach(self.sock) + raise UnrecognizedBackgroundConnectionError( + "Attempted to make a connection to a background server owned by another user." + ) + + def _authenticate(self) -> bool: + # Verify we have a UNIX socket. 
+ if not ( + isinstance(self.sock, _socket.socket) + and self.sock.family == _socket.AddressFamily.AF_UNIX + ): + raise NotImplementedError( + "Failed to handle request because it was not made through a UNIX socket" + ) + + # Get the credentials of the peer process + cred_buffer = self.sock.getsockopt( + _socket.SOL_SOCKET, + _socket.SO_PEERCRED, + _socket.CMSG_SPACE(_ctypes.sizeof(UCred)), + ) + peer_cred = UCred.from_buffer_copy(cred_buffer) + + # Only allow connections from a process running as the same user + return peer_cred.uid == _os.getuid() diff --git a/src/openjd/adaptor_runtime_client/py.typed b/src/openjd/adaptor_runtime_client/py.typed new file mode 100644 index 0000000..7ef2116 --- /dev/null +++ b/src/openjd/adaptor_runtime_client/py.typed @@ -0,0 +1 @@ +# Marker file that indicates this package supports typing diff --git a/test/openjd/adaptor_runtime/__init__.py b/test/openjd/adaptor_runtime/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/conftest.py b/test/openjd/adaptor_runtime/conftest.py new file mode 100644 index 0000000..d9aad84 --- /dev/null +++ b/test/openjd/adaptor_runtime/conftest.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import platform + +import pytest + +# List of platforms that can be used to mark tests as specific to that platform +# See [tool.pytest.ini_options] -> markers in pyproject.toml +_PLATFORMS = set( + [ + "Linux", + "Windows", + "Darwin", + ] +) + + +def pytest_runtest_setup(item: pytest.Item): + """ + Hook that is run for each test. 
+ """ + + # Skip platform-specific tests that don't apply to current platform + supported_platforms = set(_PLATFORMS).intersection(mark.name for mark in item.iter_markers()) + plat = platform.system() + if supported_platforms and plat not in supported_platforms: + pytest.skip(f"Skipping non-{plat} test: {item.name}") diff --git a/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/IntegCommandAdaptor.json b/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/IntegCommandAdaptor.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/IntegCommandAdaptor.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/__init__.py b/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/__init__.py new file mode 100644 index 0000000..bc90b20 --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/__init__.py @@ -0,0 +1,6 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from .__main__ import main +from .adaptor import IntegCommandAdaptor + +__all__ = ["IntegCommandAdaptor", "main"] diff --git a/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/__main__.py b/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/__main__.py new file mode 100644 index 0000000..b0c943c --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/__main__.py @@ -0,0 +1,31 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import logging as _logging +import sys as _sys + +from openjd.adaptor_runtime import EntryPoint as _EntryPoint + +from .adaptor import IntegCommandAdaptor + +__all__ = ["main"] +_logger = _logging.getLogger(__name__) + + +def main(): + _logger.info("About to start the IntegCommandAdaptor") + + package_name = vars(_sys.modules[__name__])["__package__"] + if not package_name: + raise RuntimeError(f"Must be run as a module. 
Do not run {__file__} directly") + + try: + _EntryPoint(IntegCommandAdaptor).start() + except Exception as e: + _logger.error(f"Entrypoint failed: {e}") + _sys.exit(1) + + _logger.info("Done IntegCommandAdaptor main") + + +if __name__ == "__main__": + main() diff --git a/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/adaptor.py b/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/adaptor.py new file mode 100644 index 0000000..9710cff --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/IntegCommandAdaptor/adaptor.py @@ -0,0 +1,40 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import os +from typing import List +from logging import getLogger +from openjd.adaptor_runtime.adaptors import CommandAdaptor +from openjd.adaptor_runtime.process import ManagedProcess + +logger = getLogger(__name__) + + +class IntegManagedProcess(ManagedProcess): + def __init__(self, run_data: dict) -> None: + super().__init__(run_data) + + def get_executable(self) -> str: + return os.path.abspath(os.path.join(os.path.sep, "bin", "echo")) + + def get_arguments(self) -> List[str]: + return self.run_data.get("args", "") + + +class IntegCommandAdaptor(CommandAdaptor): + def __init__(self, init_data: dict, path_mapping_data: dict): + super().__init__(init_data, path_mapping_data=path_mapping_data) + + def get_managed_process(self, run_data: dict) -> ManagedProcess: + return IntegManagedProcess(run_data) + + def on_prerun(self): + # Print only goes to stdout and is not captured in daemon mode. + print("prerun-print") + # Logging is captured in daemon mode. + logger.info(self.init_data.get("on_prerun", "")) + + def on_postrun(self): + # Print only goes to stdout and is not captured in daemon mode. + print("postrun-print") + # Logging is captured in daemon mode. 
+ logger.info(self.init_data.get("on_postrun", "")) diff --git a/test/openjd/adaptor_runtime/integ/__init__.py b/test/openjd/adaptor_runtime/integ/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/integ/adaptors/__init__.py b/test/openjd/adaptor_runtime/integ/adaptors/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/adaptors/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/integ/adaptors/configuration/__init__.py b/test/openjd/adaptor_runtime/integ/adaptors/configuration/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/adaptors/configuration/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/integ/adaptors/configuration/test_configuration.py b/test/openjd/adaptor_runtime/integ/adaptors/configuration/test_configuration.py new file mode 100644 index 0000000..a4e6048 --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/adaptors/configuration/test_configuration.py @@ -0,0 +1,68 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +import json +import os +import pathlib as _pathlib +import tempfile + +import pytest + +from openjd.adaptor_runtime.adaptors.configuration import Configuration + + +class TestFromFile: + """ + Integration tests for the Configuration.from_file method + """ + + def test_loads_config(self): + # GIVEN + json_schema = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": {"key": {"enum": ["value"]}}, + } + config = {"key": "value"} + + # fmt: off + with tempfile.NamedTemporaryFile(mode="w+") as schema_file, \ + tempfile.NamedTemporaryFile(mode="w+") as config_file: + json.dump(json_schema, schema_file.file) + json.dump(config, config_file.file) + schema_file.seek(0) + config_file.seek(0) + + # WHEN + result = Configuration.from_file( + config_path=config_file.name, + schema_path=schema_file.name, + ) + # fmt: on + + # THEN + assert result._config == config + + schema_file.close() + config_file.close() + + def test_raises_when_config_file_fails_to_open( + self, tmp_path: _pathlib.Path, caplog: pytest.LogCaptureFixture + ): + # GIVEN + with tempfile.NamedTemporaryFile(mode="w+") as schema_file: + json.dump({}, schema_file.file) + schema_file.seek(0) + non_existent_filepath = os.path.join(tmp_path.absolute(), "non_existent_file") + + # WHEN + with pytest.raises(OSError) as raised_err: + Configuration.from_file( + schema_path=schema_file.name, + config_path=non_existent_filepath, + ) + + # THEN + assert isinstance(raised_err.value, OSError) + assert f"Failed to open configuration at {non_existent_filepath}: " in caplog.text diff --git a/test/openjd/adaptor_runtime/integ/adaptors/configuration/test_configuration_manager.py b/test/openjd/adaptor_runtime/integ/adaptors/configuration/test_configuration_manager.py new file mode 100644 index 0000000..8c3f5eb --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/adaptors/configuration/test_configuration_manager.py @@ -0,0 +1,97 @@ +# Copyright 
Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import json +import os +import tempfile +from typing import IO + +import pytest + +from openjd.adaptor_runtime.adaptors.configuration import Configuration, ConfigurationManager + + +@pytest.mark.Linux +class TestConfigurationManagerLinux: + """ + Linux-specific integration tests for ConfigurationManager + """ + + @pytest.fixture + def json_schema_file(self): + json_schema = { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "key": {"enum": ["value"]}, + "syskey": {"type": "string"}, + "usrkey": {"type": "string"}, + }, + } + with tempfile.NamedTemporaryFile(mode="w+") as schema_file: + json.dump(json_schema, schema_file.file) + schema_file.seek(0) + yield schema_file + + def test_gets_system_config(self, json_schema_file: IO[str]): + # GIVEN + config = {"key": "value"} + + with tempfile.NamedTemporaryFile(mode="w+") as config_file: + json.dump(config, config_file.file) + config_file.seek(0) + manager = ConfigurationManager( + config_cls=Configuration, + schema_path=json_schema_file.name, + system_config_path_map={"Linux": config_file.name}, + # These fields can be empty since they will not be used in this test + default_config_path="", + user_config_rel_path="", + ) + + # WHEN + sys_config = manager.get_system_config() + + # THEN + assert sys_config is not None and sys_config._config == config + + def test_builds_config(self, json_schema_file: IO[str]): + # GIVEN + default_config = { + "key": "value", + "syskey": "value", + "usrkey": "value", + } + system_config = {"syskey": "system"} + user_config = {"usrkey": "user"} + + # fmt: off + homedir = os.path.expanduser("~") + with tempfile.NamedTemporaryFile(mode="w+") as default_config_file, \ + tempfile.NamedTemporaryFile(mode="w+") as system_config_file, \ + tempfile.NamedTemporaryFile(mode="w+", dir=homedir) as user_config_file: + json.dump(default_config, 
class TestRun:
    """
    Tests for the Adaptor._run method.
    """

    # Prefixes that BaseAdaptor.update_status emits on stdout for progress/status lines.
    _OPENJD_PROGRESS_STDOUT_PREFIX: str = "openjd_progress: "
    _OPENJD_STATUS_STDOUT_PREFIX: str = "openjd_status: "

    def test_run(self, capsys) -> None:
        """_run forwards run_data to on_run, and update_status output reaches stdout."""
        first_progress = 0.0
        first_status_message = "Starting the printing of run_data"
        second_progress = 100.0
        second_status_message = "Finished printing"

        class PrintAdaptor(Adaptor):
            """
            Test implementation of an Adaptor.
            """

            def __init__(self, init_data: dict):
                super().__init__(init_data)

            def on_run(self, run_data: dict):
                # This run function will simply print the run_data.
                self.update_status(progress=first_progress, status_message=first_status_message)
                print("run_data:")
                for key, value in run_data.items():
                    print(f"\t{key} = {value}")
                self.update_status(progress=second_progress, status_message=second_status_message)

        # GIVEN
        init_data: dict = {}
        run_data: dict = {"key1": "value1", "key2": "value2", "key3": "value3"}
        adaptor = PrintAdaptor(init_data)

        # WHEN
        adaptor._run(run_data)
        result = capsys.readouterr().out.strip()

        # THEN
        assert f"{self._OPENJD_PROGRESS_STDOUT_PREFIX}{first_progress}" in result
        assert f"{self._OPENJD_STATUS_STDOUT_PREFIX}{first_status_message}" in result
        assert f"{self._OPENJD_PROGRESS_STDOUT_PREFIX}{second_progress}" in result
        assert f"{self._OPENJD_STATUS_STDOUT_PREFIX}{second_status_message}" in result
        assert "run_data:\n\tkey1 = value1\n\tkey2 = value2\n\tkey3 = value3" in result

    def test_start_end_cleanup(self, tmpdir, capsys) -> None:
        """
        Exercises the full adaptor lifecycle: _start, _run, _stop, and _cleanup.
        """

        class FileAdaptor(Adaptor):
            def __init__(self, init_data: dict):
                super().__init__(init_data)

            def on_start(self):
                # Create the path for a temp file (file itself is created on first write)
                self.f = tmpdir.mkdir("test").join("hello.txt")

            def on_run(self, run_data: dict):
                # Write hello world to temp file
                self.f.write("Hello World from FileAdaptor!")

            def on_stop(self):
                # Read from temp file
                print(self.f.read())

            def on_cleanup(self):
                # Delete temp file and its containing directory
                path = Path(str(self.f))
                parent_dir = path.parent.absolute()
                os.remove(str(self.f))
                shutil.rmtree(parent_dir)

        init_dict: dict = {}
        fa = FileAdaptor(init_dict)

        # Creates the path for the temp file.
        fa._start()

        # Writes to the temp file
        fa._run({})

        # The file exists after writing.
        assert os.path.exists(str(fa.f))

        # Printing the contents of the file.
        fa._stop()
        assert capsys.readouterr().out.strip() == "Hello World from FileAdaptor!"

        # Deleting the file created before.
        fa._cleanup()
        assert not os.path.exists(str(fa.f))
we can access source/destination) + assert isinstance(result[0], PathMappingRule) + assert result[0].source_path == path_mapping_rules[0]["source_path"] + assert result[0].destination_path == path_mapping_rules[0]["destination_path"] + + def test_many_rules(self) -> None: + # GIVEN + path_mapping_rules = [ + { + "source_os": "linux", + "source_path": "/mnt/shared/asset_storage0", + "destination_os": "windows", + "destination_path": "Z:\\asset_storage0", + }, + { + "source_os": "linux", + "source_path": "/mnt/shared/asset_storage1", + "destination_os": "windows", + "destination_path": "Z:\\asset_storage1", + }, + ] + adaptor = FakeCommandAdaptor(path_mapping_rules) + + # WHEN + result = adaptor.path_mapping_rules + + # THEN + assert len(result) > 1 + assert len(result) == len(path_mapping_rules) + assert all(isinstance(rule, PathMappingRule) for rule in result) + + def test_get_order_is_preserved(self) -> None: + # GIVEN + rule1 = { + "source_os": "linux", + "source_path": "/mnt/shared/asset_storage1", + "destination_os": "windows", + "destination_path": "Z:\\asset_storage1", + } + rule2 = { + "source_os": "windows", + "source_path": "Z:\\asset_storage1", + "destination_os": "windows", + "destination_path": "Z:\\should\\not\\reach\\this", + } + path_mapping_rules = [rule1, rule2] + adaptor = FakeCommandAdaptor(path_mapping_rules) + expected_rules = [ + PathMappingRule.from_dict(rule=rule1), + PathMappingRule.from_dict(rule=rule2), + ] + wrong_order_rules = [expected_rules[1], expected_rules[0]] + + # WHEN + result = adaptor.path_mapping_rules + + # THEN + # All lists haves the same length + assert len(result) == len(expected_rules) + assert len(result) == len(wrong_order_rules) + # Compare the lists to ensure they have the correct order + assert result == expected_rules + assert result != wrong_order_rules + + def test_rule_list_is_read_only(self) -> None: + # GIVEN + expected: list[dict] = [] + adaptor = FakeCommandAdaptor(expected) + rules = 
class TestApplyPathMapping:
    """Tests for CommandAdaptor.map_path rule application."""

    @staticmethod
    def _assert_maps(rules: list, source_path: str, expected: str) -> None:
        # Build an adaptor configured with `rules` and verify map_path's result.
        adaptor = FakeCommandAdaptor(rules)
        assert adaptor.map_path(source_path) == expected

    def test_no_change(self) -> None:
        # With no rules configured, paths pass through untouched.
        unchanged = "/mnt/shared/asset_storage1"
        self._assert_maps([], unchanged, unchanged)

    def test_linux_to_windows(self) -> None:
        rules = [
            {
                "source_os": "linux",
                "source_path": "/mnt/shared/asset_storage1",
                "destination_os": "windows",
                "destination_path": "Z:\\asset_storage1",
            }
        ]
        self._assert_maps(
            rules,
            "/mnt/shared/asset_storage1/asset.ext",
            "Z:\\asset_storage1\\asset.ext",
        )

    def test_windows_to_linux(self) -> None:
        rules = [
            {
                "source_os": "windows",
                "source_path": "Z:\\asset_storage1",
                "destination_os": "linux",
                "destination_path": "/mnt/shared/asset_storage1",
            }
        ]
        self._assert_maps(
            rules,
            "Z:\\asset_storage1\\asset.ext",
            "/mnt/shared/asset_storage1/asset.ext",
        )

    def test_linux_to_linux(self) -> None:
        rules = [
            {
                "source_os": "linux",
                "source_path": "/mnt/shared/my_custom_path/asset_storage1",
                "destination_os": "linux",
                "destination_path": "/mnt/shared/asset_storage1",
            }
        ]
        self._assert_maps(
            rules,
            "/mnt/shared/my_custom_path/asset_storage1/asset.ext",
            "/mnt/shared/asset_storage1/asset.ext",
        )

    def test_windows_to_windows(self) -> None:
        rules = [
            {
                "source_os": "windows",
                "source_path": "Z:\\my_custom_asset_path\\asset_storage1",
                "destination_os": "windows",
                "destination_path": "Z:\\asset_storage1",
            }
        ]
        self._assert_maps(
            rules,
            "Z:\\my_custom_asset_path\\asset_storage1\\asset.ext",
            "Z:\\asset_storage1\\asset.ext",
        )

    def test_windows_capitalization_agnostic(self) -> None:
        # Windows rules must match case-insensitively.
        rules = [
            {
                "source_os": "windows",
                "source_path": "Z:\\my_custom_asset_path\\asset_storage1",
                "destination_os": "windows",
                "destination_path": "Z:\\asset_storage1",
            }
        ]
        upper_source = f"{rules[0]['source_path'].upper()}\\asset.ext"
        self._assert_maps(rules, upper_source, "Z:\\asset_storage1\\asset.ext")

    def test_windows_directory_separator_agnostic(self) -> None:
        # Windows rules must match both '/' and '\\' separators.
        rules = [
            {
                "source_os": "windows",
                "source_path": "Z:\\my_custom_asset_path\\asset_storage1",
                "destination_os": "windows",
                "destination_path": "Z:\\asset_storage1",
            }
        ]
        self._assert_maps(
            rules,
            "Z:/my_custom_asset_path/asset_storage1/asset.ext",
            "Z:\\asset_storage1\\asset.ext",
        )

    def test_multiple_rules(self) -> None:
        # The matching rule is found even when it is not the first configured.
        rules = [
            {
                "source_os": "linux",
                "source_path": "/mnt/shared/asset_storage0",
                "destination_os": "windows",
                "destination_path": "Z:\\asset_storage0",
            },
            {
                "source_os": "linux",
                "source_path": "/mnt/shared/asset_storage1",
                "destination_os": "windows",
                "destination_path": "Z:\\asset_storage1",
            },
        ]
        self._assert_maps(
            rules,
            "/mnt/shared/asset_storage1/asset.ext",
            "Z:\\asset_storage1\\asset.ext",
        )

    def test_only_first_applied(self) -> None:
        # Once a rule matches, later rules must not be applied to its output.
        rules = [
            {
                "source_os": "linux",
                "source_path": "/mnt/shared/asset_storage1",
                "destination_os": "windows",
                "destination_path": "Z:\\asset_storage1",
            },
            {
                "source_os": "windows",
                "source_path": "Z:\\asset_storage1",
                "destination_os": "windows",
                "destination_path": "Z:\\should\\not\\reach\\this",
            },
        ]
        self._assert_maps(
            rules,
            "/mnt/shared/asset_storage1/asset.ext",
            "Z:\\asset_storage1\\asset.ext",
        )

    def test_apply_order_is_preserved(self) -> None:
        # When two rules match the same source, the first one configured wins.
        rules = [
            {
                "source_os": "linux",
                "source_path": "/mnt/shared/asset_storage1",
                "destination_os": "windows",
                "destination_path": "Z:\\asset_storage1",
            },
            {
                "source_os": "linux",
                "source_path": "/mnt/shared/asset_storage1",
                "destination_os": "windows",
                "destination_path": "Z:\\should\\not\\reach\\this",
            },
        ]
        self._assert_maps(
            rules,
            "/mnt/shared/asset_storage1/asset.ext",
            "Z:\\asset_storage1\\asset.ext",
        )
class FakeAppClient(_HTTPClientInterface):
    """Minimal HTTP client used by the adaptor IPC integration tests.

    Registers one extra action ("hello_world") on top of the actions
    provided by the base client interface.
    """

    def __init__(self, socket_path: str) -> None:
        super().__init__(socket_path)
        # Make hello_world invocable by the server's action queue.
        self.actions["hello_world"] = self.hello_world

    def close(self, args: _Optional[_Dict[str, _Any]]) -> None:
        """Handle the 'close' action."""
        print("closing")

    def hello_world(self, args: _Optional[_Dict[str, _Any]]) -> None:
        """Handle the 'hello_world' action by echoing its arguments."""
        print(f"args = {args}")

    def graceful_shutdown(self):
        """Handle a graceful shutdown request."""
        print("Gracefully shutting down.")
@pytest.fixture
def adaptor():
    """Build a minimal Adaptor with two path-mapping rules.

    The second rule uses non-ASCII (emoji) path components to exercise
    Unicode handling in the mapping round trip.
    """

    class FakeAdaptor(Adaptor):
        def __init__(self, path_mapping_rules):
            super().__init__({}, path_mapping_data={"path_mapping_rules": path_mapping_rules})

        def on_run(self, run_data: dict):
            return

    path_mapping_rules = [
        {
            "source_os": "windows",
            "source_path": "Z:\\asset_storage1",
            "destination_os": "linux",
            "destination_path": "/mnt/shared/asset_storage1",
        },
        {
            "source_os": "windows",
            "source_path": "🌚\\πŸŒ’\\πŸŒ“\\πŸŒ”\\🌝\\πŸŒ–\\πŸŒ—\\🌘\\🌚",
            "destination_os": "linux",
            "destination_path": "🌝/πŸŒ–/πŸŒ—/🌘/🌚/πŸŒ’/πŸŒ“/πŸŒ”/🌝",
        },
    ]

    return FakeAdaptor(path_mapping_rules)


def start_test_server(test_server: _AdaptorServer):
    """Thread target that runs the test server until shutdown() is called.

    Args:
        test_server (_AdaptorServer): The server hosting the actions queue for the test.
    """
    # NOTE: the previous docstring documented a nonexistent `aq` parameter;
    # the function only takes the server itself.
    test_server.serve_forever()
+ """ + client.poll() + + +class TestAdaptorIPC: + """Integration tests to for the Adaptor IPC.""" + + @pytest.mark.parametrize( + argnames=("source_path", "dest_path"), + argvalues=[ + ("Z:\\asset_storage1\\somefile.png", "/mnt/shared/asset_storage1/somefile.png"), + ("🌚\\πŸŒ’\\πŸŒ“\\πŸŒ”\\🌝\\πŸŒ–\\πŸŒ—\\🌘\\🌚", "🌝/πŸŒ–/πŸŒ—/🌘/🌚/πŸŒ’/πŸŒ“/πŸŒ”/🌝"), + ], + ) + def test_map_path(self, adaptor: Adaptor, source_path: str, dest_path: str): + # GIVEN + test_server = _AdaptorServer(_ActionsQueue(), adaptor) + server_thread = _threading.Thread(target=start_test_server, args=(test_server,)) + server_thread.start() + + # Create a client passing in the port number from the server. + client = _FakeAppClient(f"{test_server.server_address}") + mapped_path = client.map_path(source_path) + + # Giving time to avoid a race condition in which we close the thread before setup. + _sleep(1) + + # Cleanup + test_server.shutdown() + server_thread.join() + + # THEN + assert mapped_path == dest_path + + @_mock.patch.object(_FakeAppClient, "close") + @_mock.patch.object(_FakeAppClient, "hello_world") + def test_action_performed( + self, mocked_hw: _mock.Mock, mocked_close: _mock.Mock, adaptor: Adaptor + ): + """This test will confirm an action was performed on the client.""" + # The argument for the hello world action. + hw_args = {"foo": "barr"} + + # Create an action queue with actions enqueued + aq = _ActionsQueue() + aq.enqueue_action(_Action("hello_world", hw_args)) + aq.enqueue_action(_Action("close")) + + # Create a server and pass the actions queue. + test_server = _AdaptorServer(aq, adaptor) + + # Create thread for the AdaptorServer. + server_thread = _threading.Thread(target=start_test_server, args=(test_server,)) + server_thread.start() + + # Create a client passing in the port number from the server. + client = _FakeAppClient(test_server.socket_path) + + # Create a thread for the client. 
+ client_thread = _threading.Thread(target=start_test_client, args=(client,)) + client_thread.start() + + # Giving time to avoid a race condition in which we close the thread before setup. + _sleep(1) + + # Cleanup + test_server.shutdown() + server_thread.join() + client_thread.join() + + # Confirming the test ran successfully. + mocked_hw.assert_called_once_with(hw_args) + mocked_close.assert_called_once() + + @_mock.patch.object(_FakeAppClient, "close") + @_mock.patch.object(_FakeAppClient, "hello_world") + def test_long_polling(self, mocked_hw: _mock.Mock, mocked_close: _mock.Mock, adaptor: Adaptor): + """This test will test long polling works as expected.""" + # The argument for the hello world action. + hw_args = {"foo": "barr"} + + # Create an action queue with actions enqueued + aq = _ActionsQueue() + aq.enqueue_action(_Action("hello_world", hw_args)) + + # Create a server and pass the actions queue. + test_server = _AdaptorServer(aq, adaptor) + + # Create thread for the AdaptorServer. + server_thread = _threading.Thread(target=start_test_server, args=(test_server,)) + server_thread.start() + + # Create a client passing in the port number from the server. + client = _FakeAppClient(test_server.socket_path) + + # Create a thread for the client. + client_thread = _threading.Thread(target=start_test_client, args=(client,)) + client_thread.start() + + # Giving time to avoid a race condition in which we close the thread before setup. + _sleep(1) + + # Confirming the test ran successfully. + mocked_hw.assert_called_once_with(hw_args) + + # Sleeping while the client is running to simulate a delay in enqueuing an action. + # We are going to sleep for less than the REQUEST_TIMEOUT. + _sleep(2) + + # Verifying close wasn't called. 
+ assert not mocked_close.called + + def enqueue_close_action(): + """This is the function to enqueue the close action.""" + aq.enqueue_action(_Action("close")) + + # Creating a thread to delay the close action to "force" long polling on the client. + close_thread = _threading.Thread(target=enqueue_close_action) + close_thread.start() + + # Cleanup + test_server.shutdown() + server_thread.join() + client_thread.join() + close_thread.join() + + # Verifying the test was successful. + mocked_close.assert_called_once() diff --git a/test/openjd/adaptor_runtime/integ/background/__init__.py b/test/openjd/adaptor_runtime/integ/background/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/background/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/integ/background/sample_adaptor/SampleAdaptor.json b/test/openjd/adaptor_runtime/integ/background/sample_adaptor/SampleAdaptor.json new file mode 100644 index 0000000..9e26dfe --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/background/sample_adaptor/SampleAdaptor.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/test/openjd/adaptor_runtime/integ/background/sample_adaptor/__init__.py b/test/openjd/adaptor_runtime/integ/background/sample_adaptor/__init__.py new file mode 100644 index 0000000..80989bf --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/background/sample_adaptor/__init__.py @@ -0,0 +1,7 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
class SampleAdaptor(Adaptor):
    """
    Adaptor class that is used for background mode integration tests.

    Each lifecycle hook just logs its name so the tests can assert, via
    captured logs, that the backend invoked it.
    """

    def __init__(self, init_data: dict, **_):
        super().__init__(init_data)

    def on_start(self):
        _logger.info("on_start")

    def on_run(self, run_data: dict):
        # Lazy %-style args: the message is only formatted if the record is emitted.
        _logger.info("on_run: %s", run_data)

    def on_stop(self):
        _logger.info("on_stop")

    def on_cleanup(self):
        _logger.info("on_cleanup")
+ +from __future__ import annotations + +import json +import os +import pathlib +import re +import sys +import time +from http import HTTPStatus +from typing import Generator +from unittest.mock import patch +from pathlib import Path + +import psutil +import pytest + +import openjd.adaptor_runtime._entrypoint as runtime_entrypoint +from openjd.adaptor_runtime._background.frontend_runner import ( + FrontendRunner, + HTTPError, + _load_connection_settings, +) + +mod_path = (Path(__file__).parent).resolve() +sys.path.append(str(mod_path)) +if (_pypath := os.environ.get("PYTHONPATH")) is not None: + os.environ["PYTHONPATH"] = ":".join((_pypath, str(mod_path))) +else: + os.environ["PYTHONPATH"] = str(mod_path) +from sample_adaptor import SampleAdaptor # noqa: E402 + + +class TestDaemonMode: + """ + Tests for background daemon mode. + """ + + @pytest.fixture(autouse=True) + def mock_runtime_logger_level(self, tmpdir: pathlib.Path): + # Setup a config file for the backend process + config = {"log_level": "DEBUG"} + config_path = os.path.join(tmpdir, "configuration.json") + with open(config_path, mode="w") as f: + json.dump(config, f) + + # Override the default config path to the one we just created + with (patch.dict(os.environ, {runtime_entrypoint._ENV_CONFIG_PATH_PREFIX: config_path}),): + yield + + @pytest.fixture + def connection_file_path(self, tmp_path: pathlib.Path) -> str: + return os.path.join(tmp_path.absolute(), "connection.json") + + @pytest.fixture + def initialized_setup( + self, + connection_file_path: str, + caplog: pytest.LogCaptureFixture, + ) -> Generator[tuple[FrontendRunner, psutil.Process], None, None]: + caplog.set_level(0) + frontend = FrontendRunner(connection_file_path) + frontend.init(sys.modules[SampleAdaptor.__module__]) + conn_settings = _load_connection_settings(connection_file_path) + + match = re.search("Started backend process. 
PID: ([0-9]+)", caplog.text) + assert match is not None + pid = int(match.group(1)) + backend_proc = psutil.Process(pid) + + yield (frontend, backend_proc) + + try: + backend_proc.kill() + except psutil.NoSuchProcess: + pass # Already stopped + try: + os.remove(conn_settings.socket) + except FileNotFoundError: + pass # Already deleted + + def test_init( + self, + initialized_setup: tuple[FrontendRunner, psutil.Process], + connection_file_path: str, + ) -> None: + # GIVEN + _, backend_proc = initialized_setup + + # THEN + assert os.path.exists(connection_file_path) + + connection_settings = _load_connection_settings(connection_file_path) + assert any( + [ + conn.laddr == connection_settings.socket + for conn in backend_proc.connections(kind="unix") + ] + ) + + def test_shutdown( + self, + initialized_setup: tuple[FrontendRunner, psutil.Process], + connection_file_path: str, + ) -> None: + # GIVEN + frontend, backend_proc = initialized_setup + conn_settings = _load_connection_settings(connection_file_path) + + # WHEN + frontend.shutdown() + + # THEN + assert all( + [ + _wait_for_file_deletion(p, timeout_s=1) + for p in [connection_file_path, conn_settings.socket] + ] + ) + + # "Assert" the process exits after requesting shutdown. + # The "assertion" fails if we time out waiting. 
+ backend_proc.wait(timeout=1) + + def test_start( + self, + initialized_setup: tuple[FrontendRunner, psutil.Process], + caplog: pytest.LogCaptureFixture, + ) -> None: + # GIVEN + frontend, _ = initialized_setup + + # WHEN + frontend.start() + + # THEN + assert "on_start" in caplog.text + + @pytest.mark.parametrize( + argnames=["run_data"], + argvalues=[ + [[{"one": 1}]], + [[{"one": 1}, {"two": 2}]], + ], + ids=["runs once", "runs consecutively"], + ) + def test_run( + self, + run_data: list[dict], + initialized_setup: tuple[FrontendRunner, psutil.Process], + caplog: pytest.LogCaptureFixture, + ) -> None: + # GIVEN + frontend, _ = initialized_setup + + for data in run_data: + # WHEN + frontend.run(data) + + # THEN + assert f"on_run: {data}" in caplog.text + + def test_stop( + self, + initialized_setup: tuple[FrontendRunner, psutil.Process], + caplog: pytest.LogCaptureFixture, + ) -> None: + # GIVEN + frontend, _ = initialized_setup + + # WHEN + frontend.stop() + + # THEN + assert "on_stop" in caplog.text + + def test_heartbeat_acks( + self, + initialized_setup: tuple[FrontendRunner, psutil.Process], + ) -> None: + # GIVEN + frontend, _ = initialized_setup + response = frontend._heartbeat() + + # WHEN + new_response = frontend._heartbeat(response.output.id) + + # THEN + assert f"Received ACK for chunk: {response.output.id}" in new_response.output.output + + class TestAuthentication: + """ + Tests for background mode authentication. + + Tests that require another OS user are in the Adaptor Runtime pipeline. 
+ """ + + def test_accepts_same_uid_process( + self, initialized_setup: tuple[FrontendRunner, psutil.Process] + ) -> None: + # GIVEN + frontend, _ = initialized_setup + + # WHEN + try: + frontend._heartbeat() + except HTTPError as e: + if e.response.status == HTTPStatus.UNAUTHORIZED: + pytest.fail("Request failed authentication when it should have succeeded") + else: + pytest.fail(f"Request failed with an unexpected status code: {e}") + else: + # THEN + # Heartbeat request went through, so auth succeeded + pass + + +def _wait_for_file_deletion(path: str, timeout_s: float) -> bool: + start = time.time() + while os.path.exists(path): + if time.time() - start < timeout_s: + time.sleep(0.01) + else: + return False + return True diff --git a/test/openjd/adaptor_runtime/integ/process/__init__.py b/test/openjd/adaptor_runtime/integ/process/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/process/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/integ/process/scripts/echo_sleep_n_times.sh b/test/openjd/adaptor_runtime/integ/process/scripts/echo_sleep_n_times.sh new file mode 100755 index 0000000..3f43a51 --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/process/scripts/echo_sleep_n_times.sh @@ -0,0 +1,10 @@ +#!/bin/bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +num2=$(($2)) +for ((i=1;i<=num2;i++)) +do + echo $1 + >&2 echo $1 + sleep 0.01 +done \ No newline at end of file diff --git a/test/openjd/adaptor_runtime/integ/process/scripts/no_sigterm.sh b/test/openjd/adaptor_runtime/integ/process/scripts/no_sigterm.sh new file mode 100755 index 0000000..90409c4 --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/process/scripts/no_sigterm.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# This script's purpose is to ignore SIGTERM so we can test whether or not the process is exited properly. 
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +trap_with_arg() { + func="$1" ; shift + for sig ; do + trap "$func $sig" "$sig" + done +} + +func_trap() { + echo "Trapped: $1" +} + +trap_with_arg func_trap INT TERM EXIT + +echo 'Starting no_sigterm.sh Script' + +while true; do + date +%F_%T + sleep 1 +done diff --git a/test/openjd/adaptor_runtime/integ/process/scripts/print_signals.sh b/test/openjd/adaptor_runtime/integ/process/scripts/print_signals.sh new file mode 100755 index 0000000..50db9b1 --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/process/scripts/print_signals.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# The purpose of this script is to print when it receives a SIGTERM then exit. + +_handler() { + echo "Trapped: $1" + exit +} + +trap '_handler TERM' SIGTERM + +echo 'Starting print_signals.sh Script' + +while true; do + date +%F_%T + sleep 1 +done diff --git a/test/openjd/adaptor_runtime/integ/process/test_integration_logging_subprocess.py b/test/openjd/adaptor_runtime/integ/process/test_integration_logging_subprocess.py new file mode 100644 index 0000000..1633f4b --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/process/test_integration_logging_subprocess.py @@ -0,0 +1,221 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +import logging +import os +import re +import time +from logging import DEBUG +from unittest import mock + +import pytest + +from openjd.adaptor_runtime.app_handlers import RegexCallback, RegexHandler +from openjd.adaptor_runtime.process import LoggingSubprocess +from openjd.adaptor_runtime.process._logging_subprocess import _STDERR_LEVEL, _STDOUT_LEVEL + + +class TestIntegrationLoggingSubprocess(object): + """Integration tests for LoggingSubprocess""" + + expected_stop_params = [ + pytest.param(0, ["Immediately stopping process (pid="], id="StopProcessImmediately"), + pytest.param( + 2, + [ + "Sending the SIGTERM signal to pid=", + "now sending the SIGKILL signal.", + ], + id="StopProcessWhenSIGTERMFails", + ), + ] + + @pytest.mark.timeout(5) + @pytest.mark.parametrize("grace_period, expected_output", expected_stop_params) + def test_stop_process(self, grace_period, expected_output, caplog: pytest.LogCaptureFixture): + """ + Testing that we stop the process immediately and after SIGTERM fails. + """ + test_file = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "scripts", "no_sigterm.sh" + ) + caplog.set_level(DEBUG) + p = LoggingSubprocess(args=[test_file]) + + # This is because we are giving the subprocess time to load and + # register the sigterm signal handler with the OS. + while "Starting no_sigterm.sh Script" not in caplog.text: + True + + p.terminate(grace_period) + + for output in expected_output: + assert output in caplog.text + + @pytest.mark.timeout(5) + def test_terminate_process(self, caplog): + """ + Testing that the process was terminated successfully. This means that the process ended + when SIGTERM was sent and SIGKILL was not needed. 
+ """ + test_file = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "scripts", "print_signals.sh" + ) + caplog.set_level(DEBUG) + p = LoggingSubprocess(args=[test_file]) + + # This is because we are giving the subprocess time to load and ignore the sigterm signal. + while "Starting print_signals.sh Script" not in caplog.text: + True + + p.terminate(5) # Sometimes, when this is 1 second the process doesn't terminate in time. + + assert ( + "Sending the SIGTERM signal to pid=" in caplog.text + ) # Asserting the SIGTERM signal was sent to the subprocess + assert ( + "Trapped: TERM" in caplog.text + ) # Asserting the SIGTERM was received by the subprocess. + assert ( + "now sending the SIGKILL signal." not in caplog.text + ) # Asserting the SIGKILL signal was not sent to the subprocess + + startup_dir_params = [ + pytest.param(None, id="DefaultBehaviour"), + pytest.param(os.path.dirname(os.path.realpath(__file__)), id="CurrentDir"), + ] + + @pytest.mark.parametrize("startup_dir", startup_dir_params) + def test_startup_directory(self, startup_dir: str | None, caplog): + caplog.set_level(logging.INFO) + + args = ["pwd"] + ls = LoggingSubprocess(args=args, startup_directory=startup_dir) + + # Sometimes we assert too quickly, so we are waiting for the pwd command to finish + # running. 
+ ls.wait() + + # Explicitly cleanup the IO threads to ensure all output is logged + ls._cleanup_io_threads() + + assert "Running command: pwd" in caplog.text + + if startup_dir is not None: + assert startup_dir in caplog.text + + def test_startup_directory_empty(self): + """When calling LoggingSubprocess with an empty cwd, FileNotFoundError will be raised.""" + args = ["pwd"] + with pytest.raises(FileNotFoundError) as excinfo: + LoggingSubprocess(args=args, startup_directory="") + + assert "[Errno 2] No such file or directory: ''" in str(excinfo.value) + + @pytest.mark.parametrize("log_level", [_STDOUT_LEVEL, _STDERR_LEVEL]) + def test_log_levels(self, log_level: int, caplog): + # GIVEN + caplog.set_level(log_level) + message = "Hello World" + + test_file = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "scripts", "echo_sleep_n_times.sh" + ) + # WHEN + p = LoggingSubprocess( + args=[test_file, message, "1"], + ) + p.wait() + + # THEN + records = caplog.get_records("call") + if log_level == _STDOUT_LEVEL: + assert any(r.message == message and r.levelno == _STDOUT_LEVEL for r in records) + else: + assert not any(r.message == message and r.levelno == _STDOUT_LEVEL for r in records) + + assert any(r.message == message and r.levelno == _STDERR_LEVEL for r in records) + + +class TestIntegrationRegexHandler(object): + """Integration tests for LoggingSubprocess""" + + invoked_regex_list = [ + pytest.param( + re.compile(".*"), + "Test output", + 5, + ), + ] + + @pytest.mark.parametrize("stdout, stderr", [(1, 0), (0, 1), (1, 1)]) + @pytest.mark.parametrize("regex, output, echo_count", invoked_regex_list) + def test_stdouthandler_invoked(self, regex, output, echo_count, stdout, stderr): + # GIVEN + callback = mock.Mock() + regex_callbacks = [RegexCallback([regex], callback)] + regex_handler = RegexHandler(regex_callbacks) + test_file = os.path.join( + os.path.abspath(os.path.dirname(__file__)), "scripts", "echo_sleep_n_times.sh" + ) + # WHEN + p = 
LoggingSubprocess(
+ args=[test_file, output, str(echo_count)],
+ stdout_handler=regex_handler if stdout else None,
+ stderr_handler=regex_handler if stderr else None,
+ )
+ p.wait()
+ time.sleep(0.01)  # magic sleep - logging handler has a delay and test can exit too fast
+
+ # THEN
+ assert callback.call_count == echo_count * (stdout + stderr)
+ assert all(c[0][0].re == regex for c in callback.call_args_list)
+
+ multiple_procs_regex_list = [
+ pytest.param(
+ re.compile(".*"),
+ "Test output",
+ 5,
+ ),
+ ]
+
+ @pytest.mark.parametrize("num_procs", [2])
+ @pytest.mark.parametrize("regex, output, echo_count", multiple_procs_regex_list)
+ def test_multiple_processes_invoked_independently(self, regex, output, echo_count, num_procs):
+ """
+ Creates a number of processes and validates that the stdout/stderr from each process does
+ not invoke a callback in a different process' logging handler.
+ """
+ # GIVEN
+
+ # Set up regex handler with a single callback for each stdout/stderr of each proc
+ callbacks = [mock.Mock() for _ in range(2 * num_procs)]
+ regex_callbacks = [RegexCallback([regex], callback) for callback in callbacks]
+ regex_handlers = [RegexHandler([regex_callback]) for regex_callback in regex_callbacks]
+
+ stdout_handlers = regex_handlers[:num_procs]
+ stderr_handlers = regex_handlers[num_procs:]
+
+ test_file = os.path.join(
+ os.path.abspath(os.path.dirname(__file__)), "scripts", "echo_sleep_n_times.sh"
+ )
+
+ # WHEN
+ procs = []
+ for i in range(num_procs):
+ procs.append(
+ LoggingSubprocess(
+ args=[test_file, output, str(echo_count)],
+ stdout_handler=stdout_handlers[i],
+ stderr_handler=stderr_handlers[i],
+ )
+ )
+
+ for proc in procs:
+ proc.wait()
+
+ # THEN
+ for callback in callbacks:
+ assert callback.call_count == echo_count
+ assert all(c[0][0].re == regex for c in callback.call_args_list) diff --git a/test/openjd/adaptor_runtime/integ/process/test_integration_managed_process.py
b/test/openjd/adaptor_runtime/integ/process/test_integration_managed_process.py new file mode 100644 index 0000000..a756b36 --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/process/test_integration_managed_process.py @@ -0,0 +1,88 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import os +import re +import time +from logging import INFO +from typing import List +from unittest import mock + +import pytest + +from openjd.adaptor_runtime.app_handlers import RegexCallback, RegexHandler +from openjd.adaptor_runtime.process import ManagedProcess + + +class TestManagedProcess(object): + """Integration tests for ManagedProcess""" + + def test_run(self, caplog): + """Testing a success case for the managed process.""" + + class FakeManagedProcess(ManagedProcess): + def __init__(self, run_data: dict): + super(FakeManagedProcess, self).__init__(run_data) + + def get_executable(self) -> str: + return "echo" + + def get_arguments(self) -> List[str]: + return ["Hello World!"] + + def get_startup_directory(self) -> str | None: + return None + + caplog.set_level(INFO) + + mp = FakeManagedProcess({}) + mp.run() + + assert "Hello World!" 
in caplog.text + + +class TestIntegrationRegexHandlerManagedProcess(object): + """Integration tests for LoggingSubprocess""" + + invoked_regex_list = [ + pytest.param( + re.compile(".*"), + "Test output", + 5, + ), + ] + + @pytest.mark.parametrize("stdout, stderr", [(1, 0), (0, 1), (1, 1)]) + @pytest.mark.parametrize("regex, output, echo_count", invoked_regex_list) + def test_regexhandler_invoked(self, regex, output, echo_count, stdout, stderr): + # GIVEN + class FakeManagedProcess(ManagedProcess): + def get_executable(self) -> str: + return os.path.join( + os.path.abspath(os.path.dirname(__file__)), "scripts", "echo_sleep_n_times.sh" + ) + + def get_arguments(self) -> List[str]: + return [output, str(echo_count)] + + def get_startup_directory(self) -> str | None: + return None + + callback = mock.Mock() + regex_callbacks = [RegexCallback([regex], callback)] + regex_handler = RegexHandler(regex_callbacks) + + # WHEN + + mp = FakeManagedProcess( + {}, + stdout_handler=regex_handler if stdout else None, + stderr_handler=regex_handler if stderr else None, + ) + mp.run() + time.sleep(0.01) # magic sleep - logging handler has a delay and test can exit too fast + + # THEN + assert callback.call_count == echo_count * (stdout + stderr) + assert all(c[0][0].re == regex for c in callback.call_args_list) diff --git a/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py b/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py new file mode 100644 index 0000000..2939946 --- /dev/null +++ b/test/openjd/adaptor_runtime/integ/test_integration_entrypoint.py @@ -0,0 +1,178 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+
+from __future__ import annotations
+
+import json
+import os
+import re
+import sys
+from logging import INFO
+from typing import cast
+from unittest.mock import patch
+from pathlib import Path
+
+import pytest
+
+import openjd.adaptor_runtime._entrypoint as runtime_entrypoint
+from openjd.adaptor_runtime import EntryPoint
+
+mod_path = (Path(__file__).parent).resolve()
+sys.path.append(str(mod_path))
+if (_pypath := os.environ.get("PYTHONPATH")) is not None:
+ os.environ["PYTHONPATH"] = os.pathsep.join((_pypath, str(mod_path)))
+else:
+ os.environ["PYTHONPATH"] = str(mod_path)
+from IntegCommandAdaptor import IntegCommandAdaptor  # noqa: E402
+
+
+class TestCommandAdaptorRun:
+ """
+ Tests for the CommandAdaptor running using the `run` command-line.
+ """
+
+ def test_runs_command_adaptor(
+ self, capfd: pytest.CaptureFixture, caplog: pytest.LogCaptureFixture
+ ):
+ # GIVEN
+ caplog.set_level(INFO)
+ test_sys_argv = [
+ "program_filename.py",
+ "run",
+ "--init-data",
+ json.dumps(
+ {
+ "on_prerun": "on_prerun",
+ "on_postrun": "on_postrun",
+ }
+ ),
+ "--run-data",
+ json.dumps({"args": ["hello world"]}),
+ ]
+ entrypoint = EntryPoint(IntegCommandAdaptor)
+
+ # WHEN
+ with (
+ patch.object(runtime_entrypoint.sys, "argv", test_sys_argv),
+ patch.object(runtime_entrypoint.logging.Logger, "setLevel"),
+ ):
+ entrypoint.start()
+
+ # THEN
+ assert "on_prerun" in caplog.text
+ assert "hello world" in caplog.text
+ assert "on_postrun" in caplog.text
+
+ # THEN
+ result = cast(str, capfd.readouterr().out)
+ assert re.match(".*prerun-print.*postrun-print.*", result, flags=re.RegexFlag.DOTALL)
+
+
+class TestCommandAdaptorDaemon:
+ """
+ Tests for the CommandAdaptor running using the `daemon` command-line.
+ """ + + def test_start_stop(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): + # GIVEN + caplog.set_level(INFO) + connection_file = tmp_path / "connection.json" + test_start_argv = [ + "program_filename.py", + "daemon", + "start", + "--connection-file", + str(connection_file), + "--init-data", + json.dumps( + { + "on_prerun": "on_prerun", + "on_postrun": "on_postrun", + } + ), + ] + test_stop_argv = [ + "program_filename.py", + "daemon", + "stop", + "--connection-file", + str(connection_file), + ] + entrypoint = EntryPoint(IntegCommandAdaptor) + + # WHEN + with ( + patch.object(runtime_entrypoint.sys, "argv", test_start_argv), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + with ( + patch.object(runtime_entrypoint.sys, "argv", test_stop_argv), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + + # THEN + assert "Initializing backend process" in caplog.text + assert "Connected successfully" in caplog.text + assert "Running in background daemon mode." in caplog.text + assert "Daemon background process stopped." 
in caplog.text + assert "on_prerun" not in caplog.text + assert "on_postrun" not in caplog.text + + def test_run(self, caplog: pytest.LogCaptureFixture, tmp_path: Path): + # GIVEN + caplog.set_level(INFO) + connection_file = tmp_path / "connection.json" + test_start_argv = [ + "program_filename.py", + "daemon", + "start", + "--connection-file", + str(connection_file), + "--init-data", + json.dumps( + { + "on_prerun": "on_prerun", + "on_postrun": "on_postrun", + } + ), + ] + test_run_argv = [ + "program_filename.py", + "daemon", + "run", + "--connection-file", + str(connection_file), + "--run-data", + json.dumps({"args": ["hello world"]}), + ] + test_stop_argv = [ + "program_filename.py", + "daemon", + "stop", + "--connection-file", + str(connection_file), + ] + entrypoint = EntryPoint(IntegCommandAdaptor) + + # WHEN + with ( + patch.object(runtime_entrypoint.sys, "argv", test_start_argv), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + with ( + patch.object(runtime_entrypoint.sys, "argv", test_run_argv), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + with ( + patch.object(runtime_entrypoint.sys, "argv", test_stop_argv), + patch.object(runtime_entrypoint.logging.Logger, "setLevel"), + ): + entrypoint.start() + + # THEN + assert "on_prerun" in caplog.text + assert "hello world" in caplog.text + assert "on_postrun" in caplog.text diff --git a/test/openjd/adaptor_runtime/test_importable.py b/test/openjd/adaptor_runtime/test_importable.py new file mode 100644 index 0000000..461ed71 --- /dev/null +++ b/test/openjd/adaptor_runtime/test_importable.py @@ -0,0 +1,9 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ + +def test_openjd_importable(): + import openjd # noqa: F401 + + +def test_importable(): + import openjd.adaptor_runtime # noqa: F401 diff --git a/test/openjd/adaptor_runtime/unit/__init__.py b/test/openjd/adaptor_runtime/unit/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/unit/adaptors/__init__.py b/test/openjd/adaptor_runtime/unit/adaptors/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/unit/adaptors/configuration/__init__.py b/test/openjd/adaptor_runtime/unit/adaptors/configuration/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/configuration/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/unit/adaptors/configuration/stubs.py b/test/openjd/adaptor_runtime/unit/adaptors/configuration/stubs.py new file mode 100644 index 0000000..c597244 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/configuration/stubs.py @@ -0,0 +1,72 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +from typing_extensions import Literal + +from openjd.adaptor_runtime.adaptors.configuration import ( + AdaptorConfiguration, + Configuration, + ConfigurationManager, + RuntimeConfiguration, +) + + +class RuntimeConfigurationStub(RuntimeConfiguration): + """ + Stub implementation of RuntimeConfiguration + """ + + def __init__(self) -> None: + super().__init__({}) + + @property + def log_level(self) -> Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: + return "DEBUG" + + @property + def deactivate_telemetry(self) -> bool: + return True + + @property + def plugin_configuration(self) -> dict | None: + return None + + +class AdaptorConfigurationStub(AdaptorConfiguration): + """ + Stub implementation of AdaptorConfiguration + """ + + def __init__(self) -> None: + super().__init__({}) + + @property + def log_level( + self, + ) -> Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] | None: + return "DEBUG" + + +class ConfigurationManagerMock(ConfigurationManager): + """ + Mock implementation of ConfigurationManager with empty defaults. + """ + + def __init__( + self, + *, + schema_path="", + default_config_path="", + system_config_path_map={}, + user_config_rel_path="", + additional_config_paths=[], + ) -> None: + super().__init__( + config_cls=Configuration, + schema_path=schema_path, + default_config_path=default_config_path, + system_config_path_map=system_config_path_map, + user_config_rel_path=user_config_rel_path, + additional_config_paths=additional_config_paths, + ) diff --git a/test/openjd/adaptor_runtime/unit/adaptors/configuration/test_configuration.py b/test/openjd/adaptor_runtime/unit/adaptors/configuration/test_configuration.py new file mode 100644 index 0000000..3deeabd --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/configuration/test_configuration.py @@ -0,0 +1,284 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +from json.decoder import JSONDecodeError +from typing import Any +from typing import List as _List +from unittest.mock import MagicMock, call, patch + +import jsonschema +import pytest + +import openjd.adaptor_runtime.adaptors.configuration._configuration as configuration +from openjd.adaptor_runtime.adaptors.configuration import ( + Configuration, +) +from openjd.adaptor_runtime.adaptors.configuration._configuration import ( + _make_function_register_decorator, +) + + +def test_make_register_decorator(): + # GIVEN + decorator = _make_function_register_decorator() + + # WHEN + @decorator + def my_func(): + pass + + # THEN + key, value = my_func.__name__, my_func + assert key in decorator.registry and value == decorator.registry[key] + + +class TestFromFile: + """ + Tests for the Configuration.from_file method + """ + + @patch.object(configuration.jsonschema, "validate") + @patch.object(configuration.json, "load") + @patch.object(configuration, "open") + def test_loads_schema( + self, mock_open: MagicMock, mock_load: MagicMock, mock_validate: MagicMock + ): + # GIVEN + schema_path = "/path/to/schema" + config_path = "/path/to/config" + schema = {"json": "schema"} + config = {"my": "config"} + mock_load.side_effect = [config, schema] + + # WHEN + result = Configuration.from_file(config_path, schema_path) + + # THEN + mock_open.assert_has_calls([call(config_path), call(schema_path)]) + assert mock_load.call_count == 2 + mock_validate.assert_called_once_with(config, schema) + assert result._config is config + + @patch.object(configuration.json, "load") + @patch.object(configuration, "open") + def test_skips_validation_when_no_schema( + self, + mock_open: MagicMock, + mock_load: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + config_path = "/path/to/config" + config = {"my": "config"} + mock_load.return_value = config + + # WHEN + result = Configuration.from_file(config_path) + + # THEN + 
mock_open.assert_called_once_with(config_path) + mock_load.assert_called_once() + assert ( + f"JSON Schema file path not provided. Configuration file {config_path} will not be " + "validated." + ) in caplog.text + assert result._config is config + + @patch.object(configuration.jsonschema, "validate") + @patch.object(configuration.json, "load") + @patch.object(configuration, "open") + def test_validates_against_multiple_schemas( + self, + mock_open: MagicMock, + mock_load: MagicMock, + mock_validate: MagicMock, + ): + # GIVEN + schema_path = "/path/to/schema" + schema = {"json": "schema"} + schema2_path = "path/2/schema" + schema2 = {"json2": "schema2"} + config_path = "/path/to/config" + config = {"my": "config"} + mock_load.side_effect = [config, schema, schema2] + + # WHEN + result = Configuration.from_file(config_path, [schema_path, schema2_path]) + + # THEN + mock_open.assert_has_calls([call(config_path), call(schema_path), call(schema2_path)]) + assert mock_load.call_count == 3 + mock_validate.assert_has_calls([call(config, schema), call(config, schema2)]) + assert result._config is config + + @pytest.mark.parametrize("schema_path", [[], ""]) + @patch.object(configuration.json, "load") + @patch.object(configuration, "open") + def test_raises_when_nonvalid_schema_path_value( + self, mock_open: MagicMock, mock_load: MagicMock, schema_path: _List | str + ): + # GIVEN + config_path = "/path/to/config" + mock_open.return_value = MagicMock() + + # WHEN + with pytest.raises(ValueError) as raised_err: + Configuration.from_file(config_path, schema_path) + + # THEN + mock_load.assert_called_once() + mock_open.assert_called_once_with(config_path) + assert raised_err.match(f"Schema path cannot be an empty {type(schema_path)}") + + @patch.object(configuration.json, "load") + @patch.object(configuration, "open") + def test_raises_when_schema_open_fails( + self, mock_open: MagicMock, mock_load: MagicMock, caplog: pytest.LogCaptureFixture + ): + # GIVEN + config_path = 
"/path/to/config" + schema_path = "/path/to/schema" + err = OSError() + mock_open.side_effect = [MagicMock(), err] + + # WHEN + with pytest.raises(OSError) as raised_err: + Configuration.from_file(config_path, schema_path) + + # THEN + mock_load.assert_called_once() + mock_open.assert_has_calls([call(config_path), call(schema_path)]) + assert raised_err.value is err + assert f"Failed to open configuration schema at {schema_path}: " in caplog.text + + @patch.object(configuration.json, "load") + @patch.object(configuration, "open") + def test_raises_when_schema_json_decode_fails( + self, mock_open: MagicMock, mock_load: MagicMock, caplog: pytest.LogCaptureFixture + ): + # GIVEN + config_path = "/path/to/config" + schema_path = "/path/to/schema" + err = JSONDecodeError("", "", 0) + mock_load.side_effect = [{}, err] + + # WHEN + with pytest.raises(JSONDecodeError) as raised_err: + Configuration.from_file(config_path, schema_path) + + # THEN + assert mock_load.call_count == 2 + mock_open.assert_has_calls([call(config_path), call(schema_path)]) + assert raised_err.value is err + assert f"Failed to decode configuration schema at {schema_path}: " in caplog.text + + @patch.object(configuration, "open") + def test_raises_when_config_open_fails( + self, mock_open: MagicMock, caplog: pytest.LogCaptureFixture + ): + # GIVEN + config_path = "/path/to/config" + err = OSError() + mock_open.side_effect = err + + # WHEN + with pytest.raises(OSError) as raised_err: + Configuration.from_file(config_path, "") + + # THEN + mock_open.assert_called_once_with(config_path) + assert raised_err.value is err + assert f"Failed to open configuration at {config_path}: " in caplog.text + + @patch.object(configuration.json, "load") + @patch.object(configuration, "open") + def test_raises_when_config_json_decode_fails( + self, mock_open: MagicMock, mock_load: MagicMock, caplog: pytest.LogCaptureFixture + ): + # GIVEN + config_path = "/path/to/config" + err = JSONDecodeError("", "", 0) + 
mock_load.side_effect = err + + # WHEN + with pytest.raises(JSONDecodeError) as raised_err: + Configuration.from_file(config_path, "") + + # THEN + mock_open.assert_called_once_with(config_path) + mock_load.assert_called_once() + assert raised_err.value is err + assert f"Failed to decode configuration at {config_path}: " in caplog.text + + @patch.object(configuration.jsonschema, "validate") + @patch.object(configuration.json, "load") + @patch.object(configuration, "open") + def test_raises_when_config_fails_jsonschema_validation( + self, + mock_open: MagicMock, + mock_load: MagicMock, + mock_validate: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + schema_path = "/path/to/schema" + config_path = "/path/to/config" + schema = {"json": "schema"} + config = {"my": "config"} + mock_load.side_effect = [config, schema] + mock_validate.side_effect = jsonschema.ValidationError("") + + # WHEN + with pytest.raises(jsonschema.ValidationError) as raised_err: + Configuration.from_file(config_path, schema_path) + + # THEN + mock_open.assert_has_calls([call(config_path), call(schema_path)]) + assert mock_load.call_count == 2 + mock_validate.assert_called_once_with(config, schema) + assert raised_err.value is mock_validate.side_effect + assert ( + f"Configuration file at {config_path} failed to validate " + f"against the JSON schema at {schema_path}: " in caplog.text + ) + + +class TestConfiguration: + """ + Tests for the base Configuration instance methods + """ + + def test_override(self): + # GIVEN + config1 = Configuration({"a": 1, "b": 2}) + config2 = Configuration({"b": 3, "c": 4}) + + # WHEN + result = config1.override(config2) + + # THEN + assert {"a": 1, "b": 3, "c": 4} == result._config + + def test_config_populates_defaults(self): + # GIVEN + class TestConfiguration(Configuration): + _defaults = _make_function_register_decorator() + + @classmethod + def _get_defaults_decorator(cls) -> Any: + return cls._defaults + + @property # type: ignore + @_defaults + 
def default_property(self) -> str: + return "default" + + initial_config = {"existing_property": "existing"} + expected = {"existing_property": "existing", "default_property": "default"} + + # WHEN + config = TestConfiguration(initial_config) + + # THEN + assert expected == config.config diff --git a/test/openjd/adaptor_runtime/unit/adaptors/configuration/test_configuration_manager.py b/test/openjd/adaptor_runtime/unit/adaptors/configuration/test_configuration_manager.py new file mode 100644 index 0000000..135260e --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/configuration/test_configuration_manager.py @@ -0,0 +1,657 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import os +import re +from typing import Generator as _Generator +from unittest.mock import MagicMock, call, mock_open, patch + +import pytest + +import openjd.adaptor_runtime._osname as osname +from openjd.adaptor_runtime.adaptors.configuration import ( + AdaptorConfiguration as _AdaptorConfiguration, + Configuration as _Configuration, +) +import openjd.adaptor_runtime.adaptors.configuration._configuration_manager as configuration_manager +from openjd.adaptor_runtime.adaptors.configuration._configuration_manager import ( + _DIR as _configuration_manager_dir, + ConfigurationManager, + _ensure_config_file, + create_adaptor_configuration_manager as _create_adaptor_configuration_manager, +) + +from .stubs import ConfigurationManagerMock + + +class TestEnsureConfigFile: + """ + Tests for the ConfigurationManager._ensure_config_file method + """ + + @pytest.fixture(autouse=True) + def mock_makedirs(self) -> _Generator[MagicMock, None, None]: + with patch.object(configuration_manager.os, "makedirs") as mock: + yield mock + + @patch.object(configuration_manager.os.path, "isfile") + @patch.object(configuration_manager.os.path, "exists") + def test_returns_true_if_valid(self, mock_exists: MagicMock, mock_isfile: MagicMock): + # GIVEN 
+ path = "my/path" + mock_exists.return_value = True + mock_isfile.return_value = True + + # WHEN + result = _ensure_config_file(path) + + # THEN + assert result + mock_exists.assert_called_once_with(path) + mock_isfile.assert_called_once_with(path) + + @patch.object(configuration_manager.os.path, "exists") + def test_returns_false_when_file_does_not_exist( + self, mock_exists: MagicMock, caplog: pytest.LogCaptureFixture + ): + # GIVEN + caplog.set_level(0) + path = "my/path" + mock_exists.return_value = False + + # WHEN + result = _ensure_config_file(path) + + # THEN + assert not result + mock_exists.assert_called_once_with(path) + assert f'Configuration file at "{path}" does not exist.' in caplog.text + + @patch.object(configuration_manager.os.path, "isfile") + @patch.object(configuration_manager.os.path, "exists") + def test_returns_false_when_path_points_to_nonfile( + self, mock_exists: MagicMock, mock_isfile: MagicMock, caplog: pytest.LogCaptureFixture + ): + # GIVEN + caplog.set_level(0) + path = "my/path" + mock_exists.return_value = True + mock_isfile.return_value = False + + # WHEN + result = _ensure_config_file(path) + + # THEN + assert not result + mock_exists.assert_called_once_with(path) + mock_isfile.assert_called_once_with(path) + assert f'Configuration file at "{path}" is not a file.' 
in caplog.text + + @pytest.mark.parametrize( + argnames=["created"], + argvalues=[[True], [False]], + ids=["created", "not created"], + ) + @patch.object(configuration_manager.json, "dump") + @patch.object(configuration_manager.os.path, "exists") + def test_create( + self, + mock_exists: MagicMock, + mock_dump: MagicMock, + created: bool, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level(0) + path = "my/path" + mock_exists.return_value = False + + open_mock: MagicMock + with patch.object(configuration_manager, "secure_open", mock_open()) as open_mock: + if not created: + open_mock.side_effect = OSError() + + # WHEN + result = _ensure_config_file(path, create=True) + + # THEN + assert result == created + mock_exists.assert_called_once_with(path) + open_mock.assert_called_once_with(path, open_mode="w") + assert f'Configuration file at "{path}" does not exist.' in caplog.text + assert f"Creating empty configuration at {path}" in caplog.text + if created: + mock_dump.assert_called_once_with({}, open_mock.return_value) + else: + assert f"Could not write empty configuration to {path}: " in caplog.text + + +class TestCreateAdaptorConfigurationManager: + """ + Tests for the create_adaptor_configuration_manager function. + """ + + def test_creates_config_manager(self): + """ + This test is fragile as it relies on the hardcoded path formats to adaptor config files. 
+ """ + # GIVEN + adaptor_name = "adaptor" + default_config_path = "/path/to/config" + + # WHEN + result = _create_adaptor_configuration_manager( + _AdaptorConfiguration, + adaptor_name, + default_config_path, + ) + + # THEN + assert result._config_cls == _AdaptorConfiguration + assert result._default_config_path == default_config_path + assert result._system_config_path_map["Linux"] == ( + f"/etc/openjd/adaptors/{adaptor_name}/{adaptor_name}.json" + ) + assert result._user_config_rel_path == os.path.join( + ".openjd", "adaptors", adaptor_name, f"{adaptor_name}.json" + ) + assert isinstance(result._schema_path, list) + assert len(result._schema_path) == 1 + assert result._schema_path[0] == os.path.abspath( + os.path.join(_configuration_manager_dir, "_adaptor_configuration.schema.json") + ) + + def test_accepts_single_schema(self): + # GIVEN + adaptor_name = "adaptor" + default_config_path = "/path/to/config" + schema_path = "/path/to/schema" + + # WHEN + result = _create_adaptor_configuration_manager( + _AdaptorConfiguration, + adaptor_name, + default_config_path, + schema_path, + ) + + # THEN + assert isinstance(result._schema_path, list) + assert len(result._schema_path) == 2 + assert result._schema_path[1] == schema_path + + def test_accepts_multiple_schemas(self): + # GIVEN + adaptor_name = "adaptor" + default_config_path = "/path/to/config" + schema_paths = [ + "/path/to/schema1", + "/path/to/schema2", + ] + + # WHEN + result = _create_adaptor_configuration_manager( + _AdaptorConfiguration, + adaptor_name, + default_config_path, + schema_paths, + ) + + # THEN + assert isinstance(result._schema_path, list) + assert len(result._schema_path) == 1 + len(schema_paths) + assert result._schema_path[1:] == schema_paths + + +class TestConfigurationManager: + """ + Tests for the base ConfigurationManager class + """ + + class TestBuildConfig: + """ + Tests for the ConfigurationManager.build_config method + + These tests mock the "Configuration.override" method to return 
an empty + Configuration and do not assert its correctness. They will only assert that the + system-level and user-level overrides are applied correctly, not the actual override logic. + The override logic is covered in the Configuration tests. + """ + + @patch.object(_Configuration, "override") + @patch.object(ConfigurationManager, "get_user_config") + @patch.object(ConfigurationManager, "get_user_config_path") + @patch.object(ConfigurationManager, "get_system_config") + @patch.object(ConfigurationManager, "get_system_config_path") + @patch.object(ConfigurationManager, "get_default_config") + def test_builds_config( + self, + mock_get_default_config: MagicMock, + mock_get_system_config_path: MagicMock, + mock_get_system_config: MagicMock, + mock_get_user_config_path: MagicMock, + mock_get_user_config: MagicMock, + mock_override: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level(0) + mock_override.side_effect = [_Configuration({}), _Configuration({})] + mock_get_default_config.return_value = _Configuration({}) + mock_get_system_config_path.return_value = "fake_system_config_path" + mock_get_system_config.return_value = _Configuration({}) + mock_get_user_config_path.return_value = "fake_user_config_path" + mock_get_user_config.return_value = _Configuration({}) + + # WHEN + ConfigurationManagerMock().build_config() + + # THEN + mock_get_default_config.assert_called_once() + mock_get_system_config.assert_called_once() + assert "Applying system-level configuration" in caplog.text + mock_get_user_config.assert_called_once_with(None) + assert "Applying user-level configuration" in caplog.text + mock_override.assert_has_calls( + [ + call(mock_get_system_config.return_value), + call(mock_get_user_config.return_value), + ] + ) + + @patch.object(_Configuration, "override") + @patch.object(ConfigurationManager, "get_user_config") + @patch.object(ConfigurationManager, "get_user_config_path") + @patch.object(ConfigurationManager, 
"get_system_config") + @patch.object(ConfigurationManager, "get_default_config") + def test_skips_when_system_config_missing( + self, + mock_get_default_config: MagicMock, + mock_get_system_config: MagicMock, + mock_get_user_config_path: MagicMock, + mock_get_user_config: MagicMock, + mock_override: MagicMock, + ): + # GIVEN + mock_override.side_effect = [_Configuration({}), _Configuration({})] + mock_get_default_config.return_value = _Configuration({}) + mock_get_system_config.return_value = None + mock_get_user_config.return_value = _Configuration({}) + mock_get_user_config_path.return_value = "fake_user_config_path" + + # WHEN + ConfigurationManagerMock().build_config() + + # THEN + mock_get_default_config.assert_called_once() + mock_get_system_config.assert_called_once() + mock_get_user_config.assert_called_once_with(None) + mock_override.assert_called_once_with(mock_get_user_config.return_value) + + @patch.object(_Configuration, "override") + @patch.object(ConfigurationManager, "get_user_config") + @patch.object(ConfigurationManager, "get_system_config") + @patch.object(ConfigurationManager, "get_system_config_path") + @patch.object(ConfigurationManager, "get_default_config") + def test_skips_user_config_when_missing( + self, + mock_get_default_config: MagicMock, + mock_get_system_config_path: MagicMock, + mock_get_system_config: MagicMock, + mock_get_user_config: MagicMock, + mock_override: MagicMock, + ): + # GIVEN + mock_override.side_effect = [_Configuration({}), _Configuration({})] + mock_get_default_config.return_value = _Configuration({}) + mock_get_system_config.return_value = _Configuration({}) + mock_get_system_config_path.return_value = "fake_system_config_path" + mock_get_user_config.return_value = None + + # WHEN + ConfigurationManagerMock().build_config() + + # THEN + mock_get_default_config.assert_called_once() + mock_get_system_config.assert_called_once() + mock_get_user_config.assert_called_once_with(None) + 
mock_override.assert_called_once_with(mock_get_system_config.return_value) + + @patch.object(configuration_manager, "_ensure_config_file", return_value=True) + @patch.object(_Configuration, "from_file") + @patch.object(_Configuration, "override") + @patch.object(ConfigurationManager, "get_user_config") + @patch.object(ConfigurationManager, "get_system_config") + @patch.object(ConfigurationManager, "get_default_config") + def test_applies_additional_config_paths( + self, + mock_get_default_config: MagicMock, + mock_get_system_config: MagicMock, + mock_get_user_config: MagicMock, + mock_override: MagicMock, + mock_from_file: MagicMock, + mock_ensure_config_file: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level(0) + mock_get_default_config.return_value = _Configuration({}) + mock_get_system_config.return_value = None + mock_get_user_config.return_value = None + + additional_config_paths = ["/config/a.json", "/config/b.json"] + additional_configs = [ + _Configuration({"log_level": "WARNING", "a": "a", "unchanged": "unchanged"}), + _Configuration({"log_level": "DEBUG", "b": "b", "unchanged": "unchanged"}), + ] + + override_retvals = [*additional_configs] + mock_override.side_effect = override_retvals + mock_from_file.side_effect = additional_configs + + manager = ConfigurationManagerMock(additional_config_paths=additional_config_paths) + + # WHEN + manager.build_config() + + # THEN + mock_get_default_config.assert_called_once() + mock_get_system_config.assert_called_once() + mock_get_user_config.assert_called_once_with(None) + + # Verify calls + path_calls = [call(path) for path in additional_config_paths] + mock_ensure_config_file.assert_has_calls(path_calls) + mock_from_file.assert_has_calls(path_calls) + mock_override.assert_has_calls([call(retval) for retval in override_retvals]) + + # Verify diffs are logged + expected_diffs = [ + # First diff is applying the entire first config since prior configs are empty + {"log_level": 
"WARNING", "a": "a", "unchanged": "unchanged"}, + # Next diff is having log_level updated and a new "b" prop added + { + "log_level": "DEBUG", + "b": "b", + }, + ] + for expected in expected_diffs: + for k, v in expected.items(): + assert f"Set {k} to {v}" in caplog.text + + @patch.object(configuration_manager, "_ensure_config_file") + @patch.object(_Configuration, "from_file") + @patch.object(_Configuration, "override") + @patch.object(ConfigurationManager, "get_user_config") + @patch.object(ConfigurationManager, "get_system_config") + @patch.object(ConfigurationManager, "get_default_config") + def test_skips_additional_config_path( + self, + mock_get_default_config: MagicMock, + mock_get_system_config: MagicMock, + mock_get_user_config: MagicMock, + mock_override: MagicMock, + mock_from_file: MagicMock, + mock_ensure_config_file: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level(0) + empty_config = _Configuration({}) + mock_get_default_config.return_value = empty_config + mock_get_system_config.return_value = None + mock_get_user_config.return_value = None + + skipped_config = "/config/skipped.json" + additional_config_paths = ["/config/used.json", skipped_config] + mock_ensure_config_file.side_effect = [True, False] + + mock_override.side_effect = [empty_config] * (2 + len(additional_config_paths)) + mock_from_file.side_effect = [empty_config] * len(additional_config_paths) + + manager = ConfigurationManagerMock(additional_config_paths=additional_config_paths) + + # WHEN + manager.build_config() + + # THEN + mock_ensure_config_file.assert_has_calls( + [call(path) for path in additional_config_paths] + ) + assert ( + f"Failed to load additional configuration: {skipped_config}. Skipping..." 
+ in caplog.text + ) + + class TestDefaultConfig: + """ + Tests for ConfigurationManager methods that get the default configuration + """ + + @patch.object(configuration_manager, "_ensure_config_file", return_value=True) + @patch.object(_Configuration, "from_file") + def test_gets_default_config( + self, + mock_from_file: MagicMock, + mock_ensure_config_file: MagicMock, + ): + # GIVEN + schema_path = "schema/path" + config_path = "config/path" + manager = ConfigurationManagerMock( + schema_path=schema_path, default_config_path=config_path + ) + + # WHEN + manager.get_default_config() + + # THEN + mock_ensure_config_file.assert_called_once_with(config_path, create=True) + mock_from_file.assert_called_once_with(config_path, schema_path) + + @patch.object(configuration_manager, "_ensure_config_file", return_value=False) + def test_warns_when_file_is_nonvalid( + self, + mock_ensure_config_file: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + schema_path = "schema/path" + config_path = "config/path" + manager = ConfigurationManagerMock( + schema_path=schema_path, default_config_path=config_path + ) + cls_mock = MagicMock() + manager._config_cls = cls_mock + + # WHEN + manager.get_default_config() + + # THEN + mock_ensure_config_file.assert_called_once_with(config_path, create=True) + assert ( + f"Default configuration file at {config_path} is not a valid file. " + "Using empty configuration." 
+ ) in caplog.text + cls_mock.assert_called_once_with({}) + + class TestSystemConfig: + """ + Tests for methods that get the system-level configuration + """ + + @patch.object(osname.platform, "system") + def test_gets_linux_path(self, mock_system: MagicMock): + # GIVEN + mock_system.return_value = "Linux" + expected = "path/to/linux/system/config" + manager = ConfigurationManagerMock( + system_config_path_map={ + "Linux": expected, + } + ) + + # WHEN + result = manager.get_system_config_path() + + # THEN + mock_system.assert_called_once() + assert result == expected + + @patch.object(osname.platform, "system") + def test_raises_on_nonvalid_os(self, mock_system: MagicMock): + """ + Validate a NotImplementedError is Raised if the OSName does not resolve. + """ + + # GIVEN + mock_system.return_value = "unsupported_os" + + # WHEN + with pytest.raises(NotImplementedError): + ConfigurationManagerMock().get_system_config_path() + + # THEN + mock_system.assert_called_once() + + @patch.object(osname, "OSName", return_value="unsupported_os") + def test_raises_on_valid_os_not_implemented(self, mock_system: MagicMock): + """ + Validate a NotImplementedError is Raised if the OSName resolves but is not supported + """ + with pytest.raises(NotImplementedError): + ConfigurationManagerMock().get_system_config_path() + + @patch.object(configuration_manager, "_ensure_config_file") + @patch.object(ConfigurationManagerMock, "get_system_config_path") + @patch.object(_Configuration, "from_file") + def test_loads_config_file( + self, + mock_from_file: MagicMock, + mock_get_system_config_path: MagicMock, + mock_ensure_config_file: MagicMock, + ): + # GIVEN + system_config_path = "system/config/path" + mock_get_system_config_path.return_value = system_config_path + mock_ensure_config_file.return_value = True + schema_path = "schema/path" + manager = ConfigurationManagerMock(schema_path=schema_path) + + # WHEN + result = manager.get_system_config() + + # THEN + 
mock_get_system_config_path.assert_called_once() + mock_ensure_config_file.assert_called_once_with(system_config_path) + mock_from_file.assert_called_once_with(system_config_path, schema_path) + assert result is not None + + @patch.object(configuration_manager, "_ensure_config_file") + @patch.object(ConfigurationManagerMock, "get_system_config_path") + def test_returns_none_when_path_is_not_file( + self, mock_get_system_config_path: MagicMock, mock_ensure_config_file: MagicMock + ): + # GIVEN + mock_ensure_config_file.return_value = False + + # WHEN + result = ConfigurationManagerMock().get_system_config() + + # THEN + mock_get_system_config_path.assert_called_once() + mock_ensure_config_file.assert_called_once() + assert result is None + + class TestUserConfig: + """ + Tests for methods that get the user-level configuration + """ + + @patch.object(configuration_manager.os.path, "expanduser") + def test_gets_path_current_user(self, mock_expanduser: MagicMock): + # GIVEN + def fake_expanduser(path: str): + return path.replace("~", "/home/currentuser") + + mock_expanduser.side_effect = fake_expanduser + expected_rel_path = "path/to/user/config" + manager = ConfigurationManagerMock(user_config_rel_path=expected_rel_path) + + # WHEN + result = manager.get_user_config_path() + + # THEN + mock_expanduser.assert_called_once_with(StringStartsWith("~")) + assert result == os.path.join("/home/currentuser", expected_rel_path) + + @patch.object(configuration_manager.os.path, "expanduser") + def test_gets_path_specific_user(self, mock_expanduser: MagicMock): + # GIVEN + def fake_expanduser(path: str): + return re.sub(r"~(.*)/(.*)", r"/home/\1/\2", path) + + mock_expanduser.side_effect = fake_expanduser + username = "username" + expected_rel_path = "path/to/user/config" + manager = ConfigurationManagerMock(user_config_rel_path=expected_rel_path) + + # WHEN + result = manager.get_user_config_path(username) + + # THEN + 
mock_expanduser.assert_called_once_with(StringStartsWith(f"~{username}")) + assert result == os.path.join(f"/home/{username}", expected_rel_path) + + @patch.object(configuration_manager, "_ensure_config_file") + @patch.object(ConfigurationManagerMock, "get_user_config_path") + @patch.object(_Configuration, "from_file") + def test_loads_config_file( + self, + mock_from_file: MagicMock, + mock_get_user_config_path: MagicMock, + mock_ensure_config_file: MagicMock, + ): + # GIVEN + user_config_path = "user/config/path" + mock_get_user_config_path.return_value = user_config_path + mock_ensure_config_file.return_value = True + schema_path = "schema/path" + manager = ConfigurationManagerMock(schema_path=schema_path) + + # WHEN + result = manager.get_user_config() + + # THEN + mock_get_user_config_path.assert_called_once_with(None) + mock_ensure_config_file.assert_called_once_with(user_config_path, create=True) + mock_from_file.assert_called_once_with(user_config_path, schema_path) + assert result is not None + + @patch.object(configuration_manager, "_ensure_config_file") + @patch.object(ConfigurationManagerMock, "get_user_config_path") + def test_returns_none_when_path_is_not_file( + self, mock_get_user_config_path: MagicMock, mock_ensure_config_file: MagicMock + ): + # GIVEN + mock_ensure_config_file.return_value = False + + # WHEN + result = ConfigurationManagerMock().get_user_config() + + # THEN + mock_get_user_config_path.assert_called_once() + mock_ensure_config_file.assert_called_once() + assert result is None + + +class StringStartsWith(str): + """ + String subclass that overrides the equality method to check if a string starts with this one. + + This is used as a "matcher" object in test assertions. 
+ """ + + def __eq__(self, other: object): + return isinstance(other, str) and other.startswith(self) diff --git a/test/openjd/adaptor_runtime/unit/adaptors/fake_adaptor.py b/test/openjd/adaptor_runtime/unit/adaptors/fake_adaptor.py new file mode 100644 index 0000000..4b51346 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/fake_adaptor.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +from openjd.adaptor_runtime.adaptors import BaseAdaptor + +__all__ = ["FakeAdaptor"] + + +class FakeAdaptor(BaseAdaptor): + def __init__(self, init_data: dict, **kwargs): + super().__init__(init_data, **kwargs) + + def _start(self): + pass + + def _run(self, run_data: dict): + pass + + def _cleanup(self): + pass + + def _stop(self): + pass diff --git a/test/openjd/adaptor_runtime/unit/adaptors/test_adaptor.py b/test/openjd/adaptor_runtime/unit/adaptors/test_adaptor.py new file mode 100644 index 0000000..de35b2d --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/test_adaptor.py @@ -0,0 +1,76 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +from unittest.mock import Mock, patch + +from openjd.adaptor_runtime.adaptors import Adaptor + + +class FakeAdaptor(Adaptor): + """ + Test implementation of a Adaptor + """ + + def __init__(self, init_data: dict): + super().__init__(init_data) + + def on_run(self, run_data: dict): + pass + + +class TestRun: + """ + Tests for the Adaptor._run method + """ + + @patch.object(FakeAdaptor, "on_run", autospec=True) + @patch.object(FakeAdaptor, "__init__", return_value=None, autospec=True) + def test_run(self, mocked_init: Mock, mocked_on_run: Mock) -> None: + # GIVEN + init_data: dict = {} + run_data: dict = {} + adaptor = FakeAdaptor(init_data) + + # WHEN + adaptor._run(run_data) + + # THEN + mocked_init.assert_called_once_with(adaptor, init_data) + mocked_on_run.assert_called_once_with(adaptor, run_data) + + @patch.object(FakeAdaptor, "on_start", autospec=True) + def test_start(self, mocked_on_start: Mock) -> None: + # GIVEN + init_data: dict = {} + adaptor = FakeAdaptor(init_data) + + # WHEN + adaptor._start() + + # THEN + mocked_on_start.assert_called_once_with(adaptor) + + @patch.object(FakeAdaptor, "on_stop", autospec=True) + def test_stop(self, mocked_on_stop: Mock) -> None: + # GIVEN + init_data: dict = {} + adaptor = FakeAdaptor(init_data) + + # WHEN + adaptor._stop() + + # THEN + mocked_on_stop.assert_called_once_with(adaptor) + + @patch.object(FakeAdaptor, "on_cleanup", autospec=True) + def test_cleanup(self, mocked_on_cleanup: Mock) -> None: + # GIVEN + init_data: dict = {} + adaptor = FakeAdaptor(init_data) + + # WHEN + adaptor._cleanup() + + # THEN + mocked_on_cleanup.assert_called_once_with(adaptor) diff --git a/test/openjd/adaptor_runtime/unit/adaptors/test_adaptor_runner.py b/test/openjd/adaptor_runtime/unit/adaptors/test_adaptor_runner.py new file mode 100644 index 0000000..8cb010f --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/test_adaptor_runner.py @@ -0,0 +1,181 @@ +# Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. + +from unittest.mock import MagicMock, patch + +import pytest + +from openjd.adaptor_runtime.adaptors import AdaptorRunner + +from .fake_adaptor import FakeAdaptor + + +class TestRun: + """ + Tests for the AdaptorRunner._run method + """ + + @patch.object(FakeAdaptor, "_run") + def test_runs_adaptor(self, adaptor_run_mock: MagicMock): + # GIVEN + run_data = {"run": "data"} + runner = FakeAdaptorRunner() + + # WHEN + runner._run(run_data) + + # THEN + adaptor_run_mock.assert_called_once_with(run_data) + + @patch.object(FakeAdaptor, "_run") + def test_run_throws(self, adaptor_run_mock: MagicMock, caplog: pytest.LogCaptureFixture): + # GIVEN + exc = Exception() + adaptor_run_mock.side_effect = exc + runner = FakeAdaptorRunner() + + # WHEN + with pytest.raises(Exception) as raised_exc: + runner._run({}) + + # THEN + assert raised_exc.value is exc + assert "Error encountered while running adaptor: " in caplog.text + + +class TestStart: + """ + Tests for the AdaptorRunner._start method + """ + + @patch.object(FakeAdaptor, "_start") + def test_starts_adaptor(self, adaptor_start_mock: MagicMock): + # GIVEN + runner = FakeAdaptorRunner() + + # WHEN + runner._start() + + # THEN + adaptor_start_mock.assert_called_once() + + @patch.object(FakeAdaptor, "_start") + def test_start_throws(self, adaptor_start_mock: MagicMock, caplog: pytest.LogCaptureFixture): + # GIVEN + exc = Exception() + adaptor_start_mock.side_effect = exc + runner = FakeAdaptorRunner() + + # WHEN + with pytest.raises(Exception) as raised_exc: + runner._start() + + # THEN + assert raised_exc.value is exc + assert "Error encountered while starting adaptor: " in caplog.text + adaptor_start_mock.assert_called_once() + + +class TestStop: + """ + Tests for the AdaptorRunner._stop method + """ + + @patch.object(FakeAdaptor, "_stop") + def test_stops_adaptor(self, adaptor_end_mock: MagicMock): + # GIVEN + runner = FakeAdaptorRunner() + + # WHEN + runner._stop() + + # THEN + 
adaptor_end_mock.assert_called_once() + + @patch.object(FakeAdaptor, "_stop") + def test_stop_throws(self, adaptor_end_mock: MagicMock, caplog: pytest.LogCaptureFixture): + # GIVEN + exc = Exception() + adaptor_end_mock.side_effect = exc + runner = FakeAdaptorRunner() + + # WHEN + with pytest.raises(Exception) as raised_exc: + runner._stop() + + # THEN + assert raised_exc.value is exc + assert "Error encountered while stopping adaptor: " in caplog.text + adaptor_end_mock.assert_called_once() + + +class TestCleanup: + """ + Tests for the AdaptorRunner._cleanup method + """ + + @patch.object(FakeAdaptor, "_cleanup") + def test_cleanup_adaptor(self, adaptor_cleanup_mock: MagicMock): + # GIVEN + runner = FakeAdaptorRunner() + + # WHEN + runner._cleanup() + + # THEN + adaptor_cleanup_mock.assert_called_once() + + @patch.object(FakeAdaptor, "_cleanup") + def test_cleanup_throws( + self, adaptor_cleanup_mock: MagicMock, caplog: pytest.LogCaptureFixture + ): + # GIVEN + exc = Exception() + adaptor_cleanup_mock.side_effect = exc + runner = FakeAdaptorRunner() + + # WHEN + with pytest.raises(Exception) as raised_exc: + runner._cleanup() + + # THEN + assert raised_exc.value is exc + assert "Error encountered while cleaning up adaptor: " in caplog.text + adaptor_cleanup_mock.assert_called_once() + + +class TestCancel: + """ + Tests for the AdaptorRunner._cancel method + """ + + @patch.object(FakeAdaptor, "cancel") + def test_cancel_adaptor(self, adaptor_cancel_mock: MagicMock): + # GIVEN + runner = FakeAdaptorRunner() + + # WHEN + runner._cancel() + + # THEN + adaptor_cancel_mock.assert_called_once() + + @patch.object(FakeAdaptor, "cancel") + def test_cancel_throws(self, adaptor_cancel_mock: MagicMock, caplog: pytest.LogCaptureFixture): + # GIVEN + exc = Exception() + adaptor_cancel_mock.side_effect = exc + runner = FakeAdaptorRunner() + + # WHEN + with pytest.raises(Exception) as raised_exc: + runner._cancel() + + # THEN + assert raised_exc.value is exc + assert "Error 
encountered while canceling the adaptor: " in caplog.text + adaptor_cancel_mock.assert_called_once() + + +class FakeAdaptorRunner(AdaptorRunner): + def __init__(self): + super().__init__(adaptor=FakeAdaptor({})) diff --git a/test/openjd/adaptor_runtime/unit/adaptors/test_base_adaptor.py b/test/openjd/adaptor_runtime/unit/adaptors/test_base_adaptor.py new file mode 100644 index 0000000..2505b45 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/test_base_adaptor.py @@ -0,0 +1,383 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import os +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from pytest import param + +import openjd.adaptor_runtime.adaptors._base_adaptor as base_adaptor +from openjd.adaptor_runtime.adaptors._base_adaptor import ( + _ENV_CONFIG_PATH_TEMPLATE, + _ENV_CONFIG_SCHEMA_PATH_PREFIX, + AdaptorConfigurationOptions, + BaseAdaptor, + _ModuleInfo, +) +from openjd.adaptor_runtime.adaptors.configuration import AdaptorConfiguration + +from .fake_adaptor import FakeAdaptor + + +class TestConfigProperty: + """ + Tests for the configuration property in BaseAdaptor. 
+ """ + + @patch.object(BaseAdaptor, "_load_configuration_manager") + def test_lazily_loads_config(self, mock_load_config_manager: MagicMock): + # GIVEN + mock_config = MagicMock() + mock_config_manager = MagicMock() + mock_config_manager.build_config.return_value = mock_config + mock_load_config_manager.return_value = mock_config_manager + adaptor = FakeAdaptor({}) + + mock_load_config_manager.assert_not_called() + mock_config_manager.build_config.assert_not_called() + + # WHEN + config = adaptor.config + + # THEN + mock_load_config_manager.assert_called_once() + mock_config_manager.build_config.assert_called_once() + assert config is mock_config + + @patch.object(BaseAdaptor, "_load_configuration_manager") + def test_uses_loaded_config(self, mock_load_config_manager: MagicMock): + # GIVEN + mock_config = MagicMock() + mock_config_manager = MagicMock() + mock_config_manager.build_config.return_value = mock_config + mock_load_config_manager.return_value = mock_config_manager + adaptor = FakeAdaptor({}) + + # Get the property to lazily load the config + adaptor.config + + # WHEN + config = adaptor.config + + # THEN + mock_load_config_manager.assert_called_once() + mock_config_manager.build_config.assert_called_once() + assert config is mock_config + + @patch.object(BaseAdaptor, "_load_configuration_manager") + def test_lazily_loads_config_manager(self, mock_load_config_manager: MagicMock): + # GIVEN + mock_config_manager = MagicMock() + mock_load_config_manager.return_value = mock_config_manager + adaptor = FakeAdaptor({}) + + mock_load_config_manager.assert_not_called() + + # WHEN + config_manager = adaptor.config_manager + + # THEN + mock_load_config_manager.assert_called_once() + assert config_manager is mock_config_manager + + @patch.object(BaseAdaptor, "_load_configuration_manager") + def test_uses_loaded_config_manager(self, mock_load_config_manager: MagicMock): + # GIVEN + mock_config_manager = MagicMock() + mock_load_config_manager.return_value = 
mock_config_manager + adaptor = FakeAdaptor({}) + + # Get the property to lazily load the config manager + adaptor.config_manager + + # WHEN + config_manager = adaptor.config_manager + + # THEN + mock_load_config_manager.assert_called_once() + assert config_manager is mock_config_manager + + +class TestConfigLoading: + """ + Tests for the adaptor configuration loading logic. + """ + + @pytest.mark.parametrize( + argnames="schema_exists", + argvalues=[True, False], + ids=["Default schema exists", "Default schema does not exist"], + ) + @patch.object(base_adaptor, "create_adaptor_configuration_manager") + @patch.object(_ModuleInfo, "package", new_callable=PropertyMock) + @patch.object(_ModuleInfo, "file", new_callable=PropertyMock) + @patch.object(base_adaptor.os.path, "abspath") + @patch.object(base_adaptor.os.path, "exists") + def test_loads_config_from_module_info( + self, + mock_exists: MagicMock, + mock_abspath: MagicMock, + mock_file: MagicMock, + mock_package: MagicMock, + mock_create_adaptor_config_manager: MagicMock, + schema_exists: bool, + ): + # GIVEN + package = "openjd_fake_adaptor" + mock_package.return_value = package + mock_file.return_value = "path/to/file" + + module_path = "/root/dir/module.py" + mock_abspath.return_value = module_path + mock_exists.return_value = schema_exists + adaptor = FakeAdaptor({}) + adaptor_name = "FakeAdaptor" + + # WHEN + with patch.dict(base_adaptor.sys.modules, {adaptor.__module__: MagicMock()}): + adaptor._load_configuration_manager() + + # THEN + config_path = os.path.join(os.path.dirname(module_path), f"{adaptor_name}.json") + schema_path = os.path.join(os.path.dirname(module_path), f"{adaptor_name}.schema.json") + mock_create_adaptor_config_manager.assert_called_once_with( + config_cls=AdaptorConfiguration, + adaptor_name=adaptor_name, + default_config_path=config_path, + schema_path=schema_path if schema_exists else None, + additional_config_paths=[], + ) + assert mock_file.call_count == 4 + assert 
mock_package.call_count == 1 + assert mock_abspath.call_count == 2 + mock_exists.assert_called_once_with(schema_path) + + @patch.object(base_adaptor, "create_adaptor_configuration_manager") + @patch.object(_ModuleInfo, "file", new_callable=PropertyMock) + @patch.object(_ModuleInfo, "package", new_callable=PropertyMock) + def test_loads_config_from_environment_variables( + self, + mock_package: MagicMock, + mock_file: MagicMock, + mock_create_adaptor_config_manager: MagicMock, + ): + # GIVEN + package = "openjd_fake_adaptor" + mock_package.return_value = package + adaptor = FakeAdaptor({}) + adaptor_name = "FakeAdaptor" + additional_config_path = f"/path/to/additional/config/{adaptor_name}.json" + config_path = f"/path/to/config/{adaptor_name}.json" + schema_path = f"/path/to/schema/{adaptor_name}.schema.json" + + mock_file.return_value = config_path + + # WHEN + with patch.dict(base_adaptor.sys.modules, {adaptor.__module__: MagicMock()}): + with patch.dict( + base_adaptor.os.environ, + { + f"FAKEADAPTOR_{_ENV_CONFIG_PATH_TEMPLATE}": additional_config_path, + _ENV_CONFIG_SCHEMA_PATH_PREFIX: os.path.dirname(schema_path), + }, + ): + adaptor._load_configuration_manager() + + # THEN + mock_create_adaptor_config_manager.assert_called_once_with( + config_cls=AdaptorConfiguration, + adaptor_name=adaptor_name, + default_config_path=config_path, + schema_path=schema_path, + additional_config_paths=[additional_config_path], + ) + assert mock_package.call_count == 1 + + @patch.object(base_adaptor, "create_adaptor_configuration_manager") + @patch.object(_ModuleInfo, "package", new_callable=PropertyMock) + def test_loads_config_from_options( + self, mock_package: MagicMock, mock_create_adaptor_config_manager: MagicMock + ): + # GIVEN + config_path = "/path/to/config" + schema_path = "/path/to/schema" + package = "openjd_fake_adaptor" + mock_package.return_value = package + adaptor = FakeAdaptor( + {}, + config_opts=AdaptorConfigurationOptions( + config_cls=None, + 
config_path=config_path, + schema_path=schema_path, + ), + ) + adaptor_name = "FakeAdaptor" + + # WHEN + with patch.dict(base_adaptor.sys.modules, {adaptor.__module__: MagicMock()}): + adaptor._load_configuration_manager() + + # THEN + mock_create_adaptor_config_manager.assert_called_once_with( + config_cls=AdaptorConfiguration, + adaptor_name=adaptor_name, + default_config_path=config_path, + schema_path=schema_path, + additional_config_paths=[], + ) + assert mock_package.call_count == 1 + + def test_raises_when_module_not_loaded(self): + # GIVEN + adaptor = FakeAdaptor({}) + module_name = adaptor.__module__ + + # WHEN + with patch.dict(base_adaptor.sys.modules, {module_name: None}): + with pytest.raises(KeyError) as raised_err: + adaptor._load_configuration_manager() + + # THEN + assert raised_err.match(f"Module not loaded: {module_name}") + + @patch.object(_ModuleInfo, "name", new_callable=PropertyMock) + @patch.object(_ModuleInfo, "package", new_callable=PropertyMock) + def test_raises_when_module_not_package( + self, + mock_package: MagicMock, + mock_name: MagicMock, + ): + # GIVEN + adaptor = FakeAdaptor({}) + module_name = adaptor.__module__ + mock_name.return_value = module_name + mock_package.return_value = None + + # WHEN + with patch.dict(base_adaptor.sys.modules, {module_name: MagicMock()}): + with pytest.raises(ValueError) as raised_err: + adaptor._load_configuration_manager() + + # THEN + assert raised_err.match(f"Module {module_name} is not a package") + + @patch.object(_ModuleInfo, "name", new_callable=PropertyMock) + @patch.object(_ModuleInfo, "package", new_callable=PropertyMock) + @patch.object(_ModuleInfo, "file", new_callable=PropertyMock) + def test_raises_when_module_no_filepath( + self, + mock_file: MagicMock, + mock_package: MagicMock, + mock_name: MagicMock, + ): + # GIVEN + adaptor = FakeAdaptor({}) + module_name = adaptor.__module__ + mock_name.return_value = module_name + mock_package.return_value = "package" + mock_file.return_value = 
None + + # WHEN + with patch.dict(base_adaptor.sys.modules, {module_name: MagicMock()}): + with pytest.raises(ValueError) as raised_err: + adaptor._load_configuration_manager() + + # THEN + assert mock_package.call_count == 1 + mock_file.assert_called_once() + assert raised_err.match(f"Module {module_name} does not have a file path set") + + +class TestStatusUpdate: + """Tests for sending status updates""" + + _OPENJD_PROGRESS_STDOUT_PREFIX: str = "openjd_progress: " + _OPENJD_STATUS_STDOUT_PREFIX: str = "openjd_status: " + + @pytest.mark.parametrize( + "progress", + [ + param(-10000.0), + param(0), + param(0.0), + param(33.3333333333333), + param(50.0), + param(100.0), + param(100000.0), + param(1e5), + param(1e-5), + ], + ) + def test_progress_update(self, capsys, progress: float): + """Tests just updating the progress""" + # GIVEN + expected = f"{self._OPENJD_PROGRESS_STDOUT_PREFIX}{progress}" + + # WHEN + BaseAdaptor.update_status(progress=progress) + + # THEN + assert expected in capsys.readouterr().out + + @pytest.mark.parametrize( + "status_message", + [ + param("my epic new status message"), + param("33.33333"), + param(""), + ], + ) + def test_status_message_update(self, capsys, status_message: str): + """Tests just updating the status message""" + # GIVEN + expected = f"{self._OPENJD_STATUS_STDOUT_PREFIX}{status_message}" + + # WHEN + BaseAdaptor.update_status(status_message=status_message) + + # THEN + assert expected in capsys.readouterr().out + + @pytest.mark.parametrize( + "progress,status_message", + [ + param(-350.0, "...negative progress?"), + param(0.0, ""), + param(10.0, "making some progress"), + param(33.33333, "33.33333"), + param(100.0, "just finished!"), + param( + 100000.0, + "this farm accepts a lot of progress and this is a really long status message", + ), + ], + ) + def test_status_update(self, capsys, progress: float, status_message: str): + """Tests updating both progress and status messages""" + # GIVEN + expected_progress = 
f"{self._OPENJD_PROGRESS_STDOUT_PREFIX}{progress}" + expected_status_message = f"{self._OPENJD_STATUS_STDOUT_PREFIX}{status_message}" + + # WHEN + BaseAdaptor.update_status(progress=progress, status_message=status_message) + + # THEN + result = capsys.readouterr().out + assert expected_progress in result + assert expected_status_message in result + + def test_ignore_status_update(self, capsys): + """Tests we don't send any message if there's nothing to report""" + # GIVEN + expected = "" # nothing was captured in stdout + + # WHEN + BaseAdaptor.update_status(progress=None, status_message=None) + BaseAdaptor.update_status(progress=float("NaN")) + BaseAdaptor.update_status(progress=float("inf")) + + # THEN + result = capsys.readouterr().out + assert expected == result diff --git a/test/openjd/adaptor_runtime/unit/adaptors/test_basic_adaptor.py b/test/openjd/adaptor_runtime/unit/adaptors/test_basic_adaptor.py new file mode 100644 index 0000000..8988b80 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/test_basic_adaptor.py @@ -0,0 +1,39 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from openjd.adaptor_runtime.adaptors import CommandAdaptor +from openjd.adaptor_runtime.process import ManagedProcess + + +class FakeCommandAdaptor(CommandAdaptor): + """ + Test implementation of a CommandAdaptor + """ + + def __init__(self, init_data: dict): + super().__init__(init_data) + + def get_managed_process(self, run_data: dict) -> ManagedProcess: + return MagicMock() + + +class TestRun: + """ + Tests for the CommandAdaptor.run method + """ + + @patch.object(FakeCommandAdaptor, "get_managed_process") + def test_runs_managed_process(self, get_managed_process_mock: MagicMock): + # GIVEN + run_data = {"run": "data"} + adaptor = FakeCommandAdaptor({}) + + # WHEN + adaptor._run(run_data) + + # THEN + get_managed_process_mock.assert_called_once_with(run_data) + get_managed_process_mock.return_value.run.assert_called_once() diff --git a/test/openjd/adaptor_runtime/unit/adaptors/test_path_mapping.py b/test/openjd/adaptor_runtime/unit/adaptors/test_path_mapping.py new file mode 100644 index 0000000..4a7b1d7 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/adaptors/test_path_mapping.py @@ -0,0 +1,372 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
from pathlib import PurePosixPath, PureWindowsPath

import pytest

from openjd.adaptor_runtime.adaptors import PathMappingRule


@pytest.mark.parametrize(
    "rule",
    [
        pytest.param({"source_os": "", "source_path": "", "destination_path": ""}),
        pytest.param({"source_os": None, "source_path": None, "destination_path": None}),
        pytest.param({"source_os": "", "source_path": None, "destination_path": ""}),
        pytest.param({"source_os": "", "source_path": "C:/", "destination_path": "/mnt/"}),
        pytest.param({"source_os": "windows", "source_path": "", "destination_path": "/mnt/"}),
        pytest.param({"source_os": "windows", "source_path": "C:/", "destination_path": ""}),
        pytest.param({"source_os": "nonvalid", "source_path": "C:/", "destination_path": "/mnt/"}),
        pytest.param(
            {
                "source_os": "windows",
                "destination_os": "nonvalid",
                "source_path": "C:/",
                "destination_path": "/mnt/",
            }
        ),
    ],
)
def test_bad_args(rule):
    """Empty/None fields and unknown OS names are rejected by both constructors."""
    # WHEN/THEN
    with pytest.raises(ValueError):
        PathMappingRule(**rule)

    with pytest.raises(ValueError):
        PathMappingRule.from_dict(rule=rule)


@pytest.mark.parametrize(
    "rule",
    [
        pytest.param({}),
        pytest.param(None),
    ],
)
def test_no_args(rule):
    """Missing argument dicts fail: TypeError on **-expansion, ValueError via from_dict."""
    # WHEN/THEN
    with pytest.raises(TypeError):
        PathMappingRule(**rule)

    with pytest.raises(ValueError):
        PathMappingRule.from_dict(rule=rule)


def test_good_args():
    """A fully-specified rule is accepted by both construction paths."""
    # GIVEN
    rule = {
        "source_os": "windows",
        "destination_os": "windows",
        "source_path": "Y:/movie1",
        "destination_path": "Z:/movie2",
    }

    # THEN: neither constructor raises.
    PathMappingRule(**rule)
    PathMappingRule.from_dict(rule=rule)


@pytest.mark.parametrize(
    "path",
    [
        pytest.param("/usr"),
        pytest.param("/usr/assets"),
        pytest.param("/usr/scene/maya.mb"),
        pytest.param("/usr/scene/../../../.."),
        pytest.param("/usr/scene/../../../../who/knows/where/we/are"),
        pytest.param("/usr/scene/symbolic_path/../who/knows/where/we/are"),
    ],
)
def test_path_mapping_linux_is_match(path):
    """Paths at or under the POSIX source prefix match, even with '..' traversal."""
    # GIVEN
    mapping = PathMappingRule(source_os="linux", source_path="/usr", destination_path="/mnt/shared")

    # WHEN/THEN
    assert mapping._is_match(pure_path=PurePosixPath(path))


@pytest.mark.parametrize(
    "path",
    [
        pytest.param(""),
        pytest.param("/"),
        pytest.param("/Usr/Movie1"),
        pytest.param("/usr/movie1"),
        pytest.param("/usr\\Movie1"),
        pytest.param("\\usr\\Movie1"),
        pytest.param("/usr/Movie1a"),
    ],
)
def test_path_mapping_linux_is_not_match(path):
    """POSIX matching is case-sensitive, separator-sensitive, and prefix-exact."""
    # GIVEN
    mapping = PathMappingRule(
        source_os="linux", source_path="/usr/Movie1", destination_path="/mnt/shared/Movie1"
    )

    # WHEN/THEN
    assert not mapping._is_match(pure_path=PurePosixPath(path))


@pytest.mark.parametrize(
    "path",
    [
        pytest.param("Z:\\Movie1"),
        pytest.param("z:\\movie1"),
        pytest.param("z:/movie1"),
        pytest.param("z:/movie1/assets"),
        pytest.param("z:/movie1/assets/texture.png"),
        pytest.param("Z://////Movie1"),
        pytest.param("Z:\\\\\\Movie1"),
    ],
)
def test_path_mapping_windows_is_match(path):
    """Windows matching ignores case, separator style, and repeated separators."""
    # GIVEN
    mapping = PathMappingRule(
        source_os="windows", source_path="Z:\\Movie1", destination_path="/mnt/shared"
    )

    # WHEN/THEN
    assert mapping._is_match(pure_path=PureWindowsPath(path))


@pytest.mark.parametrize(
    "path",
    [
        pytest.param("C:\\Movie1"),
        pytest.param("Z:\\"),
        pytest.param("Z:\\Movie1a"),
    ],
)
def test_path_mapping_windows_is_not_match(path):
    """Different drives, bare roots, and longer leaf names do not match."""
    # GIVEN
    mapping = PathMappingRule(
        source_os="windows", source_path="Z:\\Movie1", destination_path="/mnt/shared"
    )

    # WHEN/THEN
    assert not mapping._is_match(pure_path=PureWindowsPath(path))


class TestApplyPathMapping:
    """Tests for PathMappingRule.apply."""

    def test_no_change(self):
        """A path outside the source prefix comes back unchanged with False."""
        # GIVEN
        mapping = PathMappingRule.from_dict(
            rule={
                "source_os": "linux",
                "source_path": "/mnt/shared/asset_storage2",
                "destination_os": "linux",
                "destination_path": "/mnt/shared/movie2",
            }
        )
        path = "/usr/assets/no_mapping.png"

        # WHEN
        changed, mapped = mapping.apply(path=path)

        # THEN
        assert (changed, mapped) == (False, path)
"destination_path": "/mnt/shared/movie2", + } + ) + path = "/usr/assets/no_mapping.png" + expected = False, path + + # WHEN + result = rule.apply(path=path) + + # THEN + assert result == expected + + def test_linux_to_windows(self): + # GIVEN + rule = PathMappingRule.from_dict( + rule={ + "source_os": "linux", + "source_path": "/mnt/shared/asset_storage1", + "destination_os": "windows", + "destination_path": "Z:\\asset_storage1", + } + ) + path = "/mnt/shared/asset_storage1/asset.ext" + expected = True, "Z:\\asset_storage1\\asset.ext" + + # WHEN + result = rule.apply(path=path) + + # THEN + assert result == expected + + def test_windows_to_linux(self): + # GIVEN + rule = PathMappingRule.from_dict( + rule={ + "source_os": "windows", + "source_path": "Z:\\asset_storage1", + "destination_os": "linux", + "destination_path": "/mnt/shared/asset_storage1", + } + ) + path = "Z:\\asset_storage1\\asset.ext" + expected = True, "/mnt/shared/asset_storage1/asset.ext" + + # WHEN + result = rule.apply(path=path) + + # THEN + assert result == expected + + def test_linux_to_linux(self): + # GIVEN + rule = PathMappingRule.from_dict( + rule={ + "source_os": "linux", + "source_path": "/mnt/shared/my_custom_path/asset_storage1", + "destination_os": "linux", + "destination_path": "/mnt/shared/asset_storage1", + } + ) + + path = "/mnt/shared/my_custom_path/asset_storage1/asset.ext" + expected = True, "/mnt/shared/asset_storage1/asset.ext" + + # WHEN + result = rule.apply(path=path) + + # THEN + assert result == expected + + def test_windows_to_windows(self): + # GIVEN + rule = rule = PathMappingRule.from_dict( + rule={ + "source_os": "windows", + "source_path": "Z:\\my_custom_asset_path\\asset_storage1", + "destination_os": "windows", + "destination_path": "Z:\\asset_storage1", + } + ) + path = "Z:\\my_custom_asset_path\\asset_storage1\\asset.ext" + expected = True, "Z:\\asset_storage1\\asset.ext" + + # WHEN + result = rule.apply(path=path) + + # THEN + assert result == expected + + def 
test_windows_capitalization_agnostic(self): + # GIVEN + rule = PathMappingRule.from_dict( + rule={ + "source_os": "windows", + "source_path": "Z:\\my_custom_asset_path\\asset_storage1", + "destination_os": "windows", + "destination_path": "Z:\\asset_storage1", + } + ) + path = f"{rule.source_path.upper()}\\asset.ext" + expected = True, "Z:\\asset_storage1\\asset.ext" + + # WHEN + result = rule.apply(path=path) + + # THEN + assert result == expected + + def test_windows_directory_separator_agnostic(self): + # GIVEN + rule = PathMappingRule.from_dict( + rule={ + "source_os": "windows", + "source_path": "Z:\\my_custom_asset_path\\asset_storage1", + "destination_os": "windows", + "destination_path": "Z:\\asset_storage1", + } + ) + path = "Z:/my_custom_asset_path/asset_storage1/asset.ext" + expected = True, "Z:\\asset_storage1\\asset.ext" + + # WHEN + result = rule.apply(path=path) + + # THEN + assert result == expected + + def test_windows_directory_separator_agnostic_inverted(self): + # GIVEN + rule = PathMappingRule.from_dict( + rule={ + "source_os": "windows", + "source_path": "Z:/my_custom_asset_path/asset_storage1", + "destination_os": "windows", + "destination_path": "Z:\\asset_storage1", + } + ) + path = "Z:\\my_custom_asset_path\\asset_storage1\\asset.ext" + expected = True, "Z:\\asset_storage1\\asset.ext" + + # WHEN + result = rule.apply(path=path) + + # THEN + assert result == expected + + def test_starts_with_partial_match(self): + # GIVEN + rule = PathMappingRule.from_dict( + rule={ + "source_os": "linux", + "source_path": "a/b", + "destination_os": "linux", + "destination_path": "/c", + } + ) + path = "/a/bc/asset.ext" + expected = False, path + + # WHEN + result = rule.apply(path=path) + + # THEN + assert result == expected + + def test_partial_match(self): + # GIVEN + rule = PathMappingRule.from_dict( + rule={ + "source_os": "linux", + "source_path": "/bar/baz", + "destination_os": "linux", + "destination_path": "/bla", + } + ) + path = "/foo/bar/baz" + 
expected = False, path + + # WHEN + result = rule.apply(path=path) + + # THEN + assert result == expected + + def test_to_dict(self): + # GIVEN + rule_dict = { + "source_os": "linux", + "source_path": "/bar/baz", + "destination_os": "linux", + "destination_path": "/bla", + } + rule = PathMappingRule.from_dict(rule=rule_dict) + + # WHEN + result = rule.to_dict() + + # THEN + assert result == rule_dict diff --git a/test/openjd/adaptor_runtime/unit/application_ipc/__init__.py b/test/openjd/adaptor_runtime/unit/application_ipc/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/application_ipc/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/unit/application_ipc/test_actions_queue.py b/test/openjd/adaptor_runtime/unit/application_ipc/test_actions_queue.py new file mode 100644 index 0000000..36b8312 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/application_ipc/test_actions_queue.py @@ -0,0 +1,87 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from collections import deque as _deque + +from openjd.adaptor_runtime_client import Action as _Action + +from openjd.adaptor_runtime.application_ipc import ActionsQueue as _ActionsQueue + + +class TestActionsQueue: + def test_actions_queue(self) -> None: + """Testing that we can enqueue correctly.""" + aq = _ActionsQueue() + + # Confirming the actions queue has been initialized. + assert aq._actions_queue == _deque() + + # Testing enqueue_action works as expected. + aq.enqueue_action(_Action("a1")) + aq.enqueue_action(_Action("a2")) + aq.enqueue_action(_Action("a3")) + + # Asserting actions were enqueued in order. 
+ assert len(aq) == 3 + assert aq.dequeue_action() == _Action("a1") + assert aq.dequeue_action() == _Action("a2") + assert aq.dequeue_action() == _Action("a3") + assert aq.dequeue_action() is None + + def test_actions_queue_append_start(self) -> None: + aq = _ActionsQueue() + + # Testing enqueue_action works as expected. + aq.enqueue_action(_Action("a1")) + aq.enqueue_action(_Action("a4"), front=True) + + # Asserting actions were enqueued in order. + assert len(aq) == 2 + assert aq.dequeue_action() == _Action("a4") + assert aq.dequeue_action() == _Action("a1") + assert aq.dequeue_action() is None + + def test_len(self) -> None: + """Testing that our overriden __len__ works as expected.""" + aq = _ActionsQueue() + + # Starting off with an empty queue. + assert len(aq) == 0 + + # Adding 1 item to the queue. + aq.enqueue_action(_Action("a1")) + assert len(aq) == 1 + + # Adding a second item to the queue. + aq.enqueue_action(_Action("a2")) + assert len(aq) == 2 + + # Removing the first items from the queue. + aq.dequeue_action() + assert len(aq) == 1 + + # Removing the last from the queue. + aq.dequeue_action() + assert len(aq) == 0 + + def test_bool(self) -> None: + """Testing that our overriden __bool__ works as expected.""" + aq = _ActionsQueue() + + # Starting off with an empty queue. + assert not bool(aq) + + # Adding 1 item to the queue. + aq.enqueue_action(_Action("a1")) + assert bool(aq) + + # Adding a second item to the queue. + aq.enqueue_action(_Action("a2")) + assert bool(aq) + + # Removing the first items from the queue. + aq.dequeue_action() + assert bool(aq) + + # Removing the last from the queue. 
+ aq.dequeue_action() + assert not bool(aq) diff --git a/test/openjd/adaptor_runtime/unit/application_ipc/test_adaptor_http_request_handler.py b/test/openjd/adaptor_runtime/unit/application_ipc/test_adaptor_http_request_handler.py new file mode 100644 index 0000000..55ab875 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/application_ipc/test_adaptor_http_request_handler.py @@ -0,0 +1,216 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import json +from http import HTTPStatus as _HTTPStatus +from unittest.mock import MagicMock, Mock, PropertyMock, patch +from urllib.parse import urlencode + +from _pytest.capture import CaptureFixture as _CaptureFixture +from openjd.adaptor_runtime_client import Action as _Action + +from openjd.adaptor_runtime.application_ipc import ActionsQueue as _ActionsQueue +from openjd.adaptor_runtime.application_ipc import AdaptorServer as _AdaptorServer +from openjd.adaptor_runtime.application_ipc._http_request_handler import ( + ActionEndpoint, +) +from openjd.adaptor_runtime.application_ipc._http_request_handler import ( + AdaptorHTTPRequestHandler as _AdaptorHTTPRequestHandler, +) +from openjd.adaptor_runtime.application_ipc._http_request_handler import ( + PathMappingEndpoint, +) +from openjd.adaptor_runtime.application_ipc._http_request_handler import ( + PathMappingRulesEndpoint, +) +from openjd.adaptor_runtime._http.request_handler import HTTPResponse + +from ..adaptors.fake_adaptor import FakeAdaptor + + +class TestPathMappingEndpoint: + @patch.object(PathMappingEndpoint, "query_string_params", new_callable=PropertyMock) + def test_get_internal_error(self, mock_qsp): + # GIVEN + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=_AdaptorServer) + mock_request_handler.server = mock_server + mock_qsp.side_effect = Exception("Something bad happened") + handler = PathMappingEndpoint(mock_request_handler) + + # WHEN + response = handler.get() + + # THEN + assert response == HTTPResponse( + 
_HTTPStatus.INTERNAL_SERVER_ERROR, str(mock_qsp.side_effect) + ) + + def test_get_no_params_returns_bad_request(self): + # GIVEN + adaptor = FakeAdaptor({}) + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=_AdaptorServer) + mock_server.adaptor = adaptor + mock_request_handler.server = mock_server + mock_request_handler.path = "localhost:8080/path_mapping" + + handler = PathMappingEndpoint(mock_request_handler) + + # WHEN + response = handler.get() + + # THEN + assert response == HTTPResponse(_HTTPStatus.BAD_REQUEST, "Missing path in query string.") + + def test_get_returns_mapped_path(self): + # GIVEN + SOURCE_PATH = "Z:\\asset_storage1" + DEST_PATH = "/mnt/shared/asset_storage1" + adaptor = FakeAdaptor( + {}, + path_mapping_data={ + "path_mapping_rules": [ + { + "source_os": "windows", + "source_path": SOURCE_PATH, + "destination_os": "linux", + "destination_path": DEST_PATH, + } + ] + }, + ) + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=_AdaptorServer) + mock_server.adaptor = adaptor + mock_request_handler.server = mock_server + mock_request_handler.path = "localhost:8080/path_mapping?" 
+ urlencode( + {"path": SOURCE_PATH + "\\somefile.png"} + ) + + handler = PathMappingEndpoint(mock_request_handler) + + # WHEN + response = handler.get() + + # THEN + assert response == HTTPResponse( + _HTTPStatus.OK, json.dumps({"path": DEST_PATH + "/somefile.png"}) + ) + + +class TestPathMappingRulesEndpoint: + def test_get_returns_rules(self): + # GIVEN + SOURCE_PATH = "Z:\\asset_storage1" + DEST_PATH = "/mnt/shared/asset_storage1" + rules = { + "source_os": "Windows", + "source_path": SOURCE_PATH, + "destination_os": "Linux", + "destination_path": DEST_PATH, + } + adaptor = FakeAdaptor( + {}, + path_mapping_data={"path_mapping_rules": [rules]}, + ) + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=_AdaptorServer) + mock_server.adaptor = adaptor + mock_request_handler.server = mock_server + mock_request_handler.path = "localhost:8080/path_mapping_rules" + + handler = PathMappingRulesEndpoint(mock_request_handler) + + # WHEN + response = handler.get() + + # THEN + assert response == HTTPResponse(_HTTPStatus.OK, json.dumps({"path_mapping_rules": [rules]})) + + +class TestActionEndpoint: + def test_get_returns_action(self): + # GIVEN + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=_AdaptorServer) + mock_server.actions_queue = _ActionsQueue() + mock_request_handler.server = mock_server + + a1 = _Action("a1", {"arg1": "val1"}) + mock_server.actions_queue.enqueue_action(a1) + + handler = ActionEndpoint(mock_request_handler) + + # WHEN + response = handler.get() + + # THEN + assert response == HTTPResponse(_HTTPStatus.OK, str(a1)) + + def test_dequeue_no_action(self) -> None: + # GIVEN + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=_AdaptorServer) + mock_server.actions_queue = _ActionsQueue() + mock_request_handler.server = mock_server + + handler = ActionEndpoint(mock_request_handler) + + # WHEN + action = handler._dequeue_action() + + # THEN + assert action is None + + 
@patch.object(_AdaptorHTTPRequestHandler, "__init__", return_value=None) + def test_dequeue_action(self, mocked_init: Mock) -> None: + # GIVEN + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=_AdaptorServer) + mock_server.actions_queue = _ActionsQueue() + mock_request_handler.server = mock_server + + handler = ActionEndpoint(mock_request_handler) + + a1 = _Action("a1", {"arg1": "val1"}) + mock_server.actions_queue.enqueue_action(a1) + + # WHEN + action = handler._dequeue_action() + + # THEN + assert action == a1 + + def test_dequeue_action_no_server(self, capsys: _CaptureFixture) -> None: + # GIVEN + mock_request_handler = MagicMock() + mock_request_handler.server = None + handler = ActionEndpoint(mock_request_handler) + + # WHEN + action = handler._dequeue_action() + + # THEN + assert action is None + assert ( + "Could not retrieve the next action because the server or actions queue" + " wasn't set." in capsys.readouterr().err + ) + + def test_dequeue_action_no_queue(self, capsys: _CaptureFixture) -> None: + # GIVEN + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=_AdaptorServer) + mock_request_handler.server = mock_server + + handler = ActionEndpoint(mock_request_handler) + + # WHEN + action = handler._dequeue_action() + + # THEN + assert action is None + assert ( + "Could not retrieve the next action because the server or actions queue" + " wasn't set." in capsys.readouterr().err + ) diff --git a/test/openjd/adaptor_runtime/unit/background/__init__.py b/test/openjd/adaptor_runtime/unit/background/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/background/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
diff --git a/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py b/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py new file mode 100644 index 0000000..ed8c380 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/background/test_backend_runner.py @@ -0,0 +1,174 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import json +import os +import pathlib +import signal +from typing import Generator +from unittest.mock import MagicMock, Mock, call, mock_open, patch + +import pytest + +import openjd.adaptor_runtime._background.backend_runner as backend_runner +from openjd.adaptor_runtime._background.backend_runner import BackendRunner +from openjd.adaptor_runtime._background.model import ConnectionSettings, DataclassJSONEncoder + + +class TestBackendRunner: + """ + Tests for the BackendRunner class + """ + + @pytest.fixture(autouse=True) + def socket_path(self, tmp_path: pathlib.Path) -> Generator[str, None, None]: + with patch.object(backend_runner.SocketDirectories, "get_process_socket_path") as mock: + path = os.path.join(tmp_path, "socket", "1234") + mock.return_value = path + + yield path + + try: + os.remove(path) + except FileNotFoundError: + pass + + @pytest.fixture(autouse=True) + def mock_server_cls(self) -> Generator[MagicMock, None, None]: + with patch.object(backend_runner, "BackgroundHTTPServer", autospec=True) as mock: + yield mock + + @patch.object(backend_runner.json, "dump") + @patch.object(backend_runner.os, "remove") + @patch.object(backend_runner, "Queue") + @patch.object(backend_runner, "Thread") + def test_run( + self, + mock_thread: MagicMock, + mock_queue: MagicMock, + mock_os_remove: MagicMock, + mock_json_dump: MagicMock, + mock_server_cls: MagicMock, + socket_path: str, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level("DEBUG") + conn_file_path = "/path/to/conn_file" + connection_settings = {"socket": socket_path} + adaptor_runner = Mock() + runner = 
BackendRunner(adaptor_runner, conn_file_path) + + # WHEN + open_mock: MagicMock + with patch.object( + backend_runner, "secure_open", mock_open(read_data=json.dumps(connection_settings)) + ) as open_mock: + runner.run() + + # THEN + assert caplog.messages == [ + "Running in background daemon mode.", + f"Listening on {socket_path}", + "HTTP server has shutdown.", + ] + mock_server_cls.assert_called_once_with( + socket_path, + adaptor_runner, + mock_queue.return_value, + log_buffer=None, + ) + mock_thread.assert_called_once() + mock_thread.return_value.start.assert_called_once() + open_mock.assert_called_once_with(conn_file_path, open_mode="w") + mock_json_dump.assert_called_once_with( + ConnectionSettings(socket_path), + open_mock.return_value, + cls=DataclassJSONEncoder, + ) + mock_thread.return_value.join.assert_called_once() + mock_os_remove.assert_has_calls([call(conn_file_path), call(socket_path)]) + + def test_run_raises_when_http_server_fails_to_start( + self, + mock_server_cls: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level("DEBUG") + exc = Exception() + mock_server_cls.side_effect = exc + runner = BackendRunner(Mock(), "") + + # WHEN + with pytest.raises(Exception) as raised_exc: + runner.run() + + # THEN + assert raised_exc.value is exc + assert caplog.messages == [ + "Running in background daemon mode.", + "Error starting in background mode: ", + ] + + @patch.object(backend_runner, "secure_open") + @patch.object(backend_runner.os, "remove") + @patch.object(backend_runner, "Queue") + @patch.object(backend_runner, "Thread") + def test_run_raises_when_writing_connection_file_fails( + self, + mock_thread: MagicMock, + mock_queue: MagicMock, + mock_os_remove: MagicMock, + open_mock: MagicMock, + socket_path: str, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level("DEBUG") + err = OSError() + open_mock.side_effect = err + conn_file_path = "/path/to/conn_file" + adaptor_runner = Mock() + runner = 
BackendRunner(adaptor_runner, conn_file_path) + + # WHEN + with pytest.raises(OSError) as raised_err: + runner.run() + + # THEN + assert raised_err.value is err + mock_queue.return_value.put.assert_called_once_with(True) + assert caplog.messages == [ + "Running in background daemon mode.", + f"Listening on {socket_path}", + "Error writing to connection file: ", + "Shutting down server...", + "HTTP server has shutdown.", + ] + mock_thread.assert_called_once() + mock_thread.return_value.start.assert_called_once() + open_mock.assert_called_once_with(conn_file_path, open_mode="w") + mock_thread.return_value.join.assert_called_once() + mock_os_remove.assert_has_calls([call(conn_file_path), call(socket_path)]) + + @patch.object(backend_runner.signal, "signal") + def test_signal_hook(self, signal_mock: MagicMock) -> None: + # Test that we create the signal hook, and that it initiates a cancelation + # as expected. + + # GIVEN + conn_file_path = "/path/to/conn_file" + adaptor_runner = Mock() + runner = BackendRunner(adaptor_runner, conn_file_path) + server_mock = MagicMock() + submit_mock = MagicMock() + server_mock.submit = submit_mock + runner._http_server = server_mock + + # WHEN + runner._sigint_handler(MagicMock(), MagicMock()) + + # THEN + signal_mock.assert_any_call(signal.SIGINT, runner._sigint_handler) + signal_mock.assert_any_call(signal.SIGTERM, runner._sigint_handler) + submit_mock.assert_called_with(adaptor_runner._cancel, force_immediate=True) diff --git a/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py b/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py new file mode 100644 index 0000000..114e586 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/background/test_frontend_runner.py @@ -0,0 +1,743 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +import http.client as http_client +import json +import re +import signal +import subprocess +import sys +from types import ModuleType +from typing import Generator +from unittest.mock import MagicMock, PropertyMock, call, mock_open, patch + +import pytest + +import openjd.adaptor_runtime._background.frontend_runner as frontend_runner +from openjd.adaptor_runtime.adaptors import AdaptorState +from openjd.adaptor_runtime._background.frontend_runner import ( + AdaptorFailedException, + FrontendRunner, + HTTPError, + _load_connection_settings, + _wait_for_file, +) +from openjd.adaptor_runtime._background.model import ( + AdaptorStatus, + BufferedOutput, + ConnectionSettings, + DataclassMapper, + HeartbeatResponse, +) + + +class TestFrontendRunner: + """ + Tests for the FrontendRunner class + """ + + @pytest.fixture + def socket_path(self) -> str: + return "/path/to/socket" + + @pytest.fixture(autouse=True) + def mock_connection_settings(self, socket_path: str) -> Generator[MagicMock, None, None]: + with patch.object(FrontendRunner, "connection_settings", new_callable=PropertyMock) as mock: + mock.return_value = ConnectionSettings(socket_path) + yield mock + + class TestInit: + """ + Tests for the FrontendRunner.init method + """ + + @patch.object(frontend_runner.sys, "argv") + @patch.object(frontend_runner.sys, "executable") + @patch.object(frontend_runner.json, "dumps") + @patch.object(FrontendRunner, "_heartbeat") + @patch.object(frontend_runner, "_wait_for_file") + @patch.object(frontend_runner.subprocess, "Popen") + @patch.object(frontend_runner.os.path, "exists") + def test_initializes_backend_process( + self, + mock_exists: MagicMock, + mock_Popen: MagicMock, + mock_wait_for_file: MagicMock, + mock_heartbeat: MagicMock, + mock_json_dumps: MagicMock, + mock_sys_executable: MagicMock, + mock_sys_argv: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level("DEBUG") + mock_json_dumps.return_value = "test" + mock_exists.return_value = 
False + pid = 123 + mock_Popen.return_value.pid = pid + mock_sys_executable.return_value = "executable" + mock_sys_argv.return_value = [] + adaptor_module = ModuleType("") + adaptor_module.__package__ = "package" + conn_file_path = "/path" + init_data = {"init": "data"} + runner = FrontendRunner(conn_file_path) + + # WHEN + runner.init(adaptor_module, init_data) + + # THEN + assert caplog.messages == [ + "Initializing backend process...", + f"Started backend process. PID: {pid}", + "Verifying connection to backend...", + "Connected successfully", + ] + mock_exists.assert_called_once_with(conn_file_path) + mock_Popen.assert_called_once_with( + [ + sys.executable, + "-m", + adaptor_module.__package__, + "daemon", + "_serve", + "--connection-file", + conn_file_path, + "--init-data", + json.dumps(init_data), + ], + shell=False, + close_fds=True, + start_new_session=True, + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + mock_wait_for_file.assert_called_once_with(conn_file_path, timeout_s=5) + mock_heartbeat.assert_called_once() + + def test_raises_when_adaptor_module_not_package(self): + # GIVEN + adaptor_module = ModuleType("") + adaptor_module.__package__ = None + runner = FrontendRunner("") + + # WHEN + with pytest.raises(Exception) as raised_exc: + runner.init(adaptor_module) + + # THEN + assert raised_exc.match(f"Adaptor module is not a package: {adaptor_module}") + + @patch.object(frontend_runner.os.path, "exists") + def test_raises_when_connection_file_exists( + self, + mock_exists: MagicMock, + ): + # GIVEN + mock_exists.return_value = True + adaptor_module = ModuleType("") + adaptor_module.__package__ = "package" + conn_file_path = "/path" + runner = FrontendRunner(conn_file_path) + + # WHEN + with pytest.raises(FileExistsError) as raised_err: + runner.init(adaptor_module) + + # THEN + assert raised_err.match( + "Cannot init a new backend process with an existing connection file at: " + + conn_file_path + ) + 
mock_exists.assert_called_once_with(conn_file_path) + + @patch.object(frontend_runner.subprocess, "Popen") + @patch.object(frontend_runner.os.path, "exists") + def test_raises_when_failed_to_create_backend_process( + self, + mock_exists: MagicMock, + mock_Popen: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level("DEBUG") + exc = Exception() + mock_Popen.side_effect = exc + mock_exists.return_value = False + adaptor_module = ModuleType("") + adaptor_module.__package__ = "package" + conn_file_path = "/path" + runner = FrontendRunner(conn_file_path) + + # WHEN + with pytest.raises(Exception) as raised_exc: + runner.init(adaptor_module) + + # THEN + assert raised_exc.value is exc + assert caplog.messages == [ + "Initializing backend process...", + "Failed to initialize backend process: ", + ] + mock_exists.assert_called_once_with(conn_file_path) + mock_Popen.assert_called_once() + + @patch.object(frontend_runner, "_wait_for_file") + @patch.object(frontend_runner.subprocess, "Popen") + @patch.object(frontend_runner.os.path, "exists") + def test_raises_when_connection_file_wait_times_out( + self, + mock_exists: MagicMock, + mock_Popen: MagicMock, + mock_wait_for_file: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level("DEBUG") + err = TimeoutError() + mock_wait_for_file.side_effect = err + mock_exists.return_value = False + pid = 123 + mock_Popen.return_value.pid = pid + adaptor_module = ModuleType("") + adaptor_module.__package__ = "package" + conn_file_path = "/path" + runner = FrontendRunner(conn_file_path) + + # WHEN + with pytest.raises(TimeoutError) as raised_err: + runner.init(adaptor_module) + + # THEN + assert raised_err.value is err + print(caplog.messages) + assert caplog.messages == [ + "Initializing backend process...", + f"Started backend process. 
PID: {pid}", + f"Backend process failed to write connection file in time at: {conn_file_path}", + ] + mock_exists.assert_called_once_with(conn_file_path) + mock_Popen.assert_called_once() + mock_wait_for_file.assert_called_once_with(conn_file_path, timeout_s=5) + + class TestHeartbeat: + """ + Tests for the FrontendRunner._heartbeat method + """ + + @patch.object(frontend_runner.json, "load") + @patch.object(DataclassMapper, "map") + @patch.object(FrontendRunner, "_send_request") + def test_sends_heartbeat( + self, + mock_send_request: MagicMock, + mock_map: MagicMock, + mock_json_load: MagicMock, + ): + # GIVEN + mock_response = mock_send_request.return_value + runner = FrontendRunner("") + + # WHEN + response = runner._heartbeat() + + # THEN + assert response is mock_map.return_value + mock_json_load.assert_called_once_with(mock_response.fp) + mock_map.assert_called_once_with(mock_json_load.return_value) + mock_send_request.assert_called_once_with("GET", "/heartbeat", params=None) + + @patch.object(frontend_runner.json, "load") + @patch.object(DataclassMapper, "map") + @patch.object(FrontendRunner, "_send_request") + def test_sends_heartbeat_with_ack_id( + self, + mock_send_request: MagicMock, + mock_map: MagicMock, + mock_json_load: MagicMock, + ): + # GIVEN + ack_id = "ack_id" + mock_response = mock_send_request.return_value + runner = FrontendRunner("") + + # WHEN + response = runner._heartbeat(ack_id) + + # THEN + assert response is mock_map.return_value + mock_json_load.assert_called_once_with(mock_response.fp) + mock_map.assert_called_once_with(mock_json_load.return_value) + mock_send_request.assert_called_once_with( + "GET", "/heartbeat", params={"ack_id": ack_id} + ) + + class TestHeartbeatUntilComplete: + """ + Tests for FrontendRunner._heartbeat_until_state_complete + """ + + @patch.object(FrontendRunner, "_heartbeat") + @patch("openjd.adaptor_runtime._background.frontend_runner.Event") + def test_heartbeats_until_complete( + self, mock_event_class: 
MagicMock, mock_heartbeat: MagicMock + ): + # GIVEN + state = AdaptorState.RUN + ack_id = "id" + mock_heartbeat.side_effect = [ + HeartbeatResponse( + state=state, + status=status, + output=BufferedOutput(id=ack_id, output="output"), + ) + # Working -> Idle -> Idle (for final ACK heartbeat) + for status in [AdaptorStatus.WORKING, AdaptorStatus.IDLE, AdaptorStatus.IDLE] + ] + mock_event = MagicMock() + mock_event_class.return_value = mock_event + mock_event.wait = MagicMock() + mock_event.is_set = MagicMock(return_value=False) + heartbeat_interval = 1 + runner = FrontendRunner("", heartbeat_interval=heartbeat_interval) + + # WHEN + runner._heartbeat_until_state_complete(state) + + # THEN + mock_heartbeat.assert_has_calls([call(None), call(ack_id)]) + mock_event.wait.assert_called_once_with(timeout=heartbeat_interval) + + @patch.object(FrontendRunner, "_heartbeat") + def test_raises_when_adaptor_fails(self, mock_heartbeat: MagicMock) -> None: + # GIVEN + state = AdaptorState.RUN + ack_id = "id" + failure_message = "failed" + mock_heartbeat.side_effect = [ + HeartbeatResponse( + state=state, + status=AdaptorStatus.IDLE, + output=BufferedOutput(id=ack_id, output=failure_message), + failed=True, + ), + HeartbeatResponse( + state=state, + status=AdaptorStatus.IDLE, + output=BufferedOutput(id="id2", output="output2"), + failed=False, + ), + ] + runner = FrontendRunner("") + + # WHEN + with pytest.raises(AdaptorFailedException) as raised_exc: + runner._heartbeat_until_state_complete(state) + + # THEN + mock_heartbeat.assert_has_calls([call(None), call(ack_id)]) + assert raised_exc.match(failure_message) + + class TestShutdown: + """ + Tests for the FrontendRunner.shutdown method + """ + + @patch.object(FrontendRunner, "_send_request") + def test_sends_shutdown(self, mock_send_request: MagicMock): + # GIVEN + runner = FrontendRunner("") + + # WHEN + runner.shutdown() + + # THEN + mock_send_request.assert_called_once_with("PUT", "/shutdown") + + class TestRun: + """ + Tests 
for the FrontendRunner.run method + """ + + @patch.object(FrontendRunner, "_heartbeat_until_state_complete") + @patch.object(FrontendRunner, "_send_request") + def test_sends_run( + self, + mock_send_request: MagicMock, + mock_heartbeat_until_state_complete: MagicMock, + ): + # GIVEN + run_data = {"run": "data"} + runner = FrontendRunner("") + + # WHEN + runner.run(run_data) + + # THEN + mock_send_request.assert_called_once_with("PUT", "/run", json_body=run_data) + mock_heartbeat_until_state_complete.assert_called_once_with(AdaptorState.RUN) + + class TestStart: + """ + Tests for the FrontendRunner.start method + """ + + @patch.object(FrontendRunner, "_heartbeat_until_state_complete") + @patch.object(FrontendRunner, "_send_request") + def test_sends_start( + self, + mock_send_request: MagicMock, + mock_heartbeat_until_state_complete: MagicMock, + ): + # GIVEN + runner = FrontendRunner("") + + # WHEN + runner.start() + + # THEN + mock_send_request.assert_called_once_with("PUT", "/start") + mock_heartbeat_until_state_complete.assert_called_once_with(AdaptorState.START) + + class TestEnd: + """ + Tests for the FrontendRunner.end method + """ + + @patch.object(FrontendRunner, "_heartbeat_until_state_complete") + @patch.object(FrontendRunner, "_send_request") + def test_sends_end( + self, + mock_send_request: MagicMock, + mock_heartbeat_until_state_complete: MagicMock, + ): + # GIVEN + runner = FrontendRunner("") + + # WHEN + runner.stop() + + # THEN + mock_send_request.assert_called_once_with("PUT", "/stop") + mock_heartbeat_until_state_complete.assert_called_once_with(AdaptorState.CLEANUP) + + class TestCancel: + """ + Tests for the FrontendRunner.cancel method + """ + + @patch.object(FrontendRunner, "_send_request") + def test_sends_cancel( + self, + mock_send_request: MagicMock, + ): + # GIVEN + runner = FrontendRunner("") + + # WHEN + runner.cancel() + + # THEN + mock_send_request.assert_called_once_with("PUT", "/cancel") + + class TestSendRequest: + """ + Tests 
for the FrontendRunner._send_request method + """ + + @pytest.fixture + def mock_response(self) -> MagicMock: + return MagicMock() + + @pytest.fixture + def mock_getresponse(self, mock_response: MagicMock) -> Generator[MagicMock, None, None]: + with patch.object(frontend_runner.UnixHTTPConnection, "getresponse") as mock: + mock.return_value = mock_response + mock_response.status = 200 + yield mock + + @patch.object(frontend_runner.UnixHTTPConnection, "request") + def test_sends_request(self, mock_request: MagicMock, mock_getresponse: MagicMock): + # GIVEN + method = "GET" + path = "/path" + conn_file_path = "/conn/file/path" + runner = FrontendRunner(conn_file_path) + + # WHEN + response = runner._send_request(method, path) + + # THEN + mock_request.assert_called_once_with( + method, + path, + body=None, + ) + mock_getresponse.assert_called_once() + assert response is mock_getresponse.return_value + + @patch.object(frontend_runner.UnixHTTPConnection, "request") + def test_raises_when_request_fails( + self, + mock_request: MagicMock, + mock_getresponse: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + exc = http_client.HTTPException() + mock_getresponse.side_effect = exc + method = "GET" + path = "/path" + conn_file_path = "/conn/file/path" + runner = FrontendRunner(conn_file_path) + + # WHEN + with pytest.raises(http_client.HTTPException) as raised_exc: + runner._send_request(method, path) + + # THEN + assert raised_exc.value is exc + assert f"Failed to send {path} request: " in caplog.text + mock_request.assert_called_once_with( + method, + path, + body=None, + ) + mock_getresponse.assert_called_once() + + @patch.object(frontend_runner.UnixHTTPConnection, "request") + def test_raises_when_error_response_received( + self, + mock_request: MagicMock, + mock_getresponse: MagicMock, + mock_response: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + mock_response.status = 500 + mock_response.reason = "Something went wrong" + method = "GET" + 
path = "/path" + conn_file_path = "/conn/file/path" + runner = FrontendRunner(conn_file_path) + + # WHEN + with pytest.raises(HTTPError) as raised_err: + runner._send_request(method, path) + + # THEN + errmsg = f"Received unexpected HTTP status code {mock_response.status}: " + str( + mock_response.reason + ) + assert errmsg in caplog.text + assert raised_err.match(re.escape(errmsg)) + mock_request.assert_called_once_with( + method, + path, + body=None, + ) + mock_getresponse.assert_called_once() + + @patch.object(frontend_runner.UnixHTTPConnection, "request") + def test_formats_query_string(self, mock_request: MagicMock, mock_getresponse: MagicMock): + # GIVEN + method = "GET" + path = "/path" + conn_file_path = "/conn/file/path" + params = {"first param": 1, "second_param": ["one", "two three"]} + runner = FrontendRunner(conn_file_path) + + # WHEN + response = runner._send_request(method, path, params=params) + + # THEN + mock_request.assert_called_once_with( + method, + f"{path}?first+param=1&second_param=one&second_param=two+three", + body=None, + ) + mock_getresponse.assert_called_once() + assert response is mock_getresponse.return_value + + @patch.object(frontend_runner.UnixHTTPConnection, "request") + def test_sends_body(self, mock_request: MagicMock, mock_getresponse: MagicMock): + # GIVEN + method = "GET" + path = "/path" + conn_file_path = "/conn/file/path" + json = {"the": "body"} + runner = FrontendRunner(conn_file_path) + + # WHEN + response = runner._send_request(method, path, json_body=json) + + # THEN + mock_request.assert_called_once_with( + method, + path, + body='{"the": "body"}', + ) + mock_getresponse.assert_called_once() + assert response is mock_getresponse.return_value + + class TestSignalHandling: + @patch.object(FrontendRunner, "cancel") + @patch.object(frontend_runner.signal, "signal") + def test_hook(self, signal_mock: MagicMock, cancel_mock: MagicMock) -> None: + # Test that we create the signal hook, and that it initiates a cancelation 
+ # as expected. + + # GIVEN + conn_file_path = "/path/to/conn_file" + runner = FrontendRunner(conn_file_path) + + # WHEN + runner._sigint_handler(MagicMock(), MagicMock()) + + # THEN + signal_mock.assert_any_call(signal.SIGINT, runner._sigint_handler) + signal_mock.assert_any_call(signal.SIGTERM, runner._sigint_handler) + cancel_mock.assert_called_once() + + +class TestLoadConnectionSettings: + """ + Tests for the _load_connection_settings method + """ + + @patch.object(DataclassMapper, "map") + def test_loads_settings( + self, + mock_map: MagicMock, + ): + # GIVEN + filepath = "/path" + connection_settings = {"port": 123} + + # WHEN + with patch.object( + frontend_runner, "open", mock_open(read_data=json.dumps(connection_settings)) + ): + _load_connection_settings(filepath) + + # THEN + mock_map.assert_called_once_with(connection_settings) + + @patch.object(frontend_runner, "open") + def test_raises_when_file_open_fails( + self, + open_mock: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + filepath = "/path" + err = OSError() + open_mock.side_effect = err + + # WHEN + with pytest.raises(OSError) as raised_err: + _load_connection_settings(filepath) + + # THEN + assert raised_err.value is err + assert "Failed to open connection file: " in caplog.text + + @patch.object(frontend_runner.json, "load") + def test_raises_when_json_decode_fails( + self, + mock_json_load: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + filepath = "/path" + err = json.JSONDecodeError("", "", 0) + mock_json_load.side_effect = err + + # WHEN + with pytest.raises(json.JSONDecodeError) as raised_err: + with patch.object(frontend_runner, "open", mock_open()): + _load_connection_settings(filepath) + + # THEN + assert raised_err.value is err + assert "Failed to decode connection file: " in caplog.text + + +class TestWaitForFile: + """ + Tests for the _wait_for_file method + """ + + @patch.object(frontend_runner, "open") + @patch.object(frontend_runner.time, "time") + 
@patch.object(frontend_runner.time, "sleep") + @patch.object(frontend_runner.os.path, "exists") + def test_waits_for_file( + self, + mock_exists: MagicMock, + mock_sleep: MagicMock, + mock_time: MagicMock, + open_mock: MagicMock, + ): + # GIVEN + filepath = "/path" + timeout = sys.float_info.max + interval = 0.01 + mock_time.side_effect = [1, 2, 3, 4] + mock_exists.side_effect = [False, True] + err = IOError() + open_mock.side_effect = [err, MagicMock()] + + # WHEN + _wait_for_file(filepath, timeout, interval) + + # THEN + assert mock_time.call_count == 4 + mock_exists.assert_has_calls([call(filepath)] * 2) + mock_sleep.assert_has_calls([call(interval)] * 3) + open_mock.assert_has_calls([call(filepath, mode="r")] * 2) + + @patch.object(frontend_runner.time, "time") + @patch.object(frontend_runner.time, "sleep") + @patch.object(frontend_runner.os.path, "exists") + def test_raises_when_timeout_reached( + self, + mock_exists: MagicMock, + mock_sleep: MagicMock, + mock_time: MagicMock, + ): + # GIVEN + filepath = "/path" + timeout = 0 + interval = 0.01 + mock_time.side_effect = [1, 2] + mock_exists.side_effect = [False] + + # WHEN + with pytest.raises(TimeoutError) as raised_err: + _wait_for_file(filepath, timeout, interval) + + # THEN + assert raised_err.match(f"Timed out after {timeout}s waiting for file at {filepath}") + assert mock_time.call_count == 2 + mock_exists.assert_called_once_with(filepath) + mock_sleep.assert_not_called() + + +@patch.object(frontend_runner, "_load_connection_settings") +def test_connection_settings_lazy_loads(mock_load_connection_settings: MagicMock): + # GIVEN + filepath = "/path" + expected = ConnectionSettings("/socket") + mock_load_connection_settings.return_value = expected + runner = FrontendRunner(filepath) + + # Assert the internal connection settings var is not set yet + assert not hasattr(runner, "_connection_settings") + + # WHEN + actual = runner.connection_settings + + # THEN + assert actual is expected + assert 
runner._connection_settings is expected diff --git a/test/openjd/adaptor_runtime/unit/background/test_http_server.py b/test/openjd/adaptor_runtime/unit/background/test_http_server.py new file mode 100644 index 0000000..a26a006 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/background/test_http_server.py @@ -0,0 +1,778 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import json +import os +import socketserver +from http import HTTPStatus +from queue import Queue +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest + +import openjd.adaptor_runtime._background.http_server as http_server +from openjd.adaptor_runtime.adaptors import AdaptorRunner +from openjd.adaptor_runtime.adaptors._adaptor_runner import _OPENJD_FAIL_STDOUT_PREFIX +from openjd.adaptor_runtime._background.http_server import ( + AsyncFutureRunner, + BackgroundHTTPServer, + BackgroundRequestHandler, + BackgroundResourceRequestHandler, + CancelHandler, + StopHandler, + HeartbeatHandler, + RunHandler, + ShutdownHandler, + StartHandler, + ThreadPoolExecutor, +) +from openjd.adaptor_runtime._background.log_buffers import InMemoryLogBuffer +from openjd.adaptor_runtime._background.model import AdaptorState, BufferedOutput + + +@pytest.fixture +def fake_server() -> socketserver.BaseServer: + class FakeServer(socketserver.BaseServer): + def __init__(self) -> None: + pass + + return FakeServer() + + +@pytest.fixture +def fake_request_handler() -> BackgroundRequestHandler: + class FakeBackgroundRequestHandler(BackgroundRequestHandler): + path: str = "/fake" + + def __init__(self) -> None: + pass + + return FakeBackgroundRequestHandler() + + +class TestAsyncFutureRunner: + """ + Tests for the AsyncFutureRunner class + """ + + @patch.object(ThreadPoolExecutor, "submit") + def test_submit(self, mock_submit: MagicMock): + # GIVEN + mock_fn = MagicMock() + args = ("hello", "world") + kwargs = {"hello": "world"} + runner = AsyncFutureRunner() + + # WHEN + 
runner.submit(mock_fn, *args, **kwargs) + + # THEN + mock_submit.assert_called_once_with(mock_fn, *args, **kwargs) + + @patch.object(AsyncFutureRunner, "is_running", new_callable=PropertyMock) + def test_submit_raises_if_running(self, mock_is_running: MagicMock): + # GIVEN + mock_is_running.return_value = True + runner = AsyncFutureRunner() + + # WHEN + with pytest.raises(Exception) as raised_exc: + runner.submit(print) + + # THEN + mock_is_running.assert_called_once() + assert raised_exc.match("Cannot submit new task while another task is running") + + @pytest.mark.parametrize( + argnames=["running"], + argvalues=[[True], [False]], + ids=["Running", "Not running"], + ) + def test_is_running_reflects_future(self, running: bool): + # GIVEN + mock_future = MagicMock() + mock_future.running.return_value = running + runner = AsyncFutureRunner() + runner._future = mock_future + + # WHEN + is_running = runner.is_running + + # THEN + assert is_running == running + mock_future.running.assert_called_once() + + @pytest.mark.parametrize( + argnames=["running", "done", "expected"], + argvalues=[ + [True, True, True], + [True, False, True], + [False, True, True], + [False, False, False], + ], + ids=[ + "running and done", + "running and not done", + "not running and done", + "not running and not done", + ], + ) + def test_has_started_reflects_future(self, running: bool, done: bool, expected: bool): + # GIVEN + mock_future = MagicMock() + mock_future.running.return_value = running + mock_future.done.return_value = done + runner = AsyncFutureRunner() + runner._future = mock_future + + # WHEN + has_started = runner.has_started + + # THEN + assert has_started == expected + mock_future.running.assert_called_once() + # Only assert done called if the OR expression was not short-circuited + if not running: + mock_future.done.assert_called_once() + + @patch.object(http_server.time, "sleep") + @patch.object(AsyncFutureRunner, "has_started", new_callable=PropertyMock) + def 
test_wait_for_start(self, mock_has_started, mock_sleep): + # GIVEN + mock_has_started.side_effect = [False, True] + runner = AsyncFutureRunner() + + # WHEN + runner.wait_for_start() + + # THEN + mock_sleep.assert_called_once_with(AsyncFutureRunner._WAIT_FOR_START_INTERVAL) + + +class TestBackgroundHTTPServer: + """ + Tests for the BackgroundHTTPServer class + """ + + class TestSubmit: + """ + Tests for the BackgroundHTTPServer.submit method + """ + + def test_submits_work(self): + # GIVEN + def my_fn(): + pass + + args = ("one", "two") + kwargs = {"three": 3, "four": 4} + + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_future_runner = MagicMock() + mock_server._future_runner = mock_future_runner + + # WHEN + result = BackgroundHTTPServer.submit(mock_server, my_fn, *args, **kwargs) + + # THEN + mock_future_runner.submit.assert_called_once_with(my_fn, *args, **kwargs) + mock_future_runner.wait_for_start.assert_called_once() + assert result.status == HTTPStatus.OK + + def test_returns_500_if_fails_to_submit_work(self, caplog: pytest.LogCaptureFixture): + # GIVEN + def my_fn(): + pass + + args = ("one", "two") + kwargs = {"three": 3, "four": 4} + + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_future_runner = MagicMock() + exc = Exception() + mock_future_runner.submit.side_effect = exc + mock_server._future_runner = mock_future_runner + + # WHEN + result = BackgroundHTTPServer.submit(mock_server, my_fn, *args, **kwargs) + + # THEN + mock_future_runner.submit.assert_called_once_with(my_fn, *args, **kwargs) + assert result.status == HTTPStatus.INTERNAL_SERVER_ERROR + assert result.body == str(exc) + assert "Failed to submit work: " in caplog.text + + +class TestBackgroundRequestHandler: + """ + Tests for the BackgroundRequestHandler class + """ + + def test_init_raises_when_server_is_incompatible(self, fake_server: socketserver.BaseServer): + # WHEN + with pytest.raises(TypeError) as raised_err: + BackgroundRequestHandler("".encode("utf-8"), "", 
fake_server) + + assert raised_err.match( + f"Received incompatible server class. Expected {BackgroundHTTPServer.__name__}, " + f"but got {type(fake_server)}" + ) + + +class TestBackgroundResourceRequestHandler: + """ + Tests for the RequestHandler class + """ + + def test_server_property_raises( + self, + fake_server: socketserver.BaseServer, + fake_request_handler: BackgroundRequestHandler, + ): + # GIVEN + class FakeRequestHandler(BackgroundResourceRequestHandler): + def __init__(self, handler: BackgroundRequestHandler) -> None: + self.handler = handler + + fake_request_handler.server = fake_server + handler = FakeRequestHandler(fake_request_handler) + + # WHEN + with pytest.raises(TypeError) as raised_err: + handler.server + + # THEN + assert raised_err.match( + f"Incompatible HTTP server class. Expected {BackgroundHTTPServer.__name__}, got: " + + type(fake_server).__name__ + ) + + +class TestHeartbeatHandler: + """ + Tests for the HeartbeatHandler class + """ + + @pytest.mark.parametrize( + argnames=[ + "is_running", + ], + argvalues=[ + [True], + [False], + ], + ids=["working", "idle"], + ) + def test_returns_adaptor_status( + self, + fake_request_handler: BackgroundRequestHandler, + is_running: bool, + ): + # GIVEN + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_server._log_buffer = None + + mock_server._future_runner = MagicMock() + mock_server._future_runner.is_running = is_running + + mock_server._adaptor_runner = MagicMock() + mock_server._adaptor_runner.state = AdaptorState.NOT_STARTED + + fake_request_handler.server = mock_server + handler = HeartbeatHandler(fake_request_handler) + + # WHEN + response = handler.get() + + # THEN + expected_status = "working" if is_running else "idle" + assert response.status == HTTPStatus.OK + assert response.body == json.dumps( + { + "state": "not_started", + "status": expected_status, + "output": { + "id": BufferedOutput.EMPTY, + "output": "", + }, + "failed": False, + } + ) + + 
@patch.object(HeartbeatHandler, "_parse_ack_id") + @patch.object(InMemoryLogBuffer, "chunk") + def test_gets_log_buffer_chunk( + self, + mock_chunk: MagicMock, + mock_parse_ack_id: MagicMock, + fake_request_handler: BackgroundRequestHandler, + ): + # GIVEN + mock_parse_ack_id.return_value = None + expected_output = BufferedOutput("id", "output") + mock_chunk.return_value = expected_output + + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_server._log_buffer = InMemoryLogBuffer() + + mock_server._future_runner = MagicMock() + mock_server._future_runner.is_running = True + + mock_server._adaptor_runner = MagicMock() + mock_server._adaptor_runner.state = AdaptorState.RUN + + fake_request_handler.server = mock_server + handler = HeartbeatHandler(fake_request_handler) + + # WHEN + response = handler.get() + + # THEN + mock_parse_ack_id.assert_called_once() + mock_chunk.assert_called_once() + assert response.status == HTTPStatus.OK + assert response.body == json.dumps( + { + "state": "run", + "status": "working", + "output": { + "id": expected_output.id, + "output": expected_output.output, + }, + "failed": False, + } + ) + + @pytest.mark.parametrize( + argnames=["valid_ack_id"], + argvalues=[[True], [False]], + ids=["Valid ACK ID", "Nonvalid ACK ID"], + ) + @patch.object(HeartbeatHandler, "_parse_ack_id") + @patch.object(InMemoryLogBuffer, "clear") + @patch.object(InMemoryLogBuffer, "chunk") + def test_processes_ack_id( + self, + mock_chunk: MagicMock, + mock_clear: MagicMock, + mock_parse_ack_id: MagicMock, + valid_ack_id: bool, + fake_request_handler: BackgroundRequestHandler, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + caplog.set_level(0) + expected_ack_id = "ack_id" + mock_parse_ack_id.return_value = expected_ack_id + expected_output = BufferedOutput("id", "output") + mock_chunk.return_value = expected_output + mock_clear.return_value = valid_ack_id + + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_server._log_buffer = 
InMemoryLogBuffer() + + mock_server._future_runner = MagicMock() + mock_server._future_runner.is_running = True + + mock_server._adaptor_runner = MagicMock() + mock_server._adaptor_runner.state = AdaptorState.RUN + + fake_request_handler.server = mock_server + handler = HeartbeatHandler(fake_request_handler) + + # WHEN + response = handler.get() + + # THEN + mock_parse_ack_id.assert_called_once() + mock_chunk.assert_called_once() + mock_clear.assert_called_once_with(expected_ack_id) + if valid_ack_id: + assert f"Received ACK for chunk: {expected_ack_id}" in caplog.text + else: + assert f"Received ACK for old or invalid chunk: {expected_ack_id}" in caplog.text + assert response.status == HTTPStatus.OK + assert response.body == json.dumps( + { + "state": "run", + "status": "working", + "output": { + "id": expected_output.id, + "output": expected_output.output, + }, + "failed": False, + } + ) + + @patch.object(HeartbeatHandler, "_parse_ack_id") + @patch.object(InMemoryLogBuffer, "chunk") + def test_sets_failed_if_adaptor_fails( + self, + mock_chunk: MagicMock, + mock_parse_ack_id: MagicMock, + fake_request_handler: BackgroundRequestHandler, + ) -> None: + # GIVEN + mock_parse_ack_id.return_value = None + expected_output = BufferedOutput( + "id", + os.linesep.join( + ["INFO: regular message", f"ERROR: {_OPENJD_FAIL_STDOUT_PREFIX}failure message"] + ), + ) + mock_chunk.return_value = expected_output + + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_server._log_buffer = InMemoryLogBuffer() + + mock_server._future_runner = MagicMock() + mock_server._future_runner.is_running = True + + mock_server._adaptor_runner = MagicMock() + mock_server._adaptor_runner.state = AdaptorState.RUN + + fake_request_handler.server = mock_server + handler = HeartbeatHandler(fake_request_handler) + + # WHEN + response = handler.get() + + # THEN + mock_parse_ack_id.assert_called_once() + mock_chunk.assert_called_once() + assert response.status == HTTPStatus.OK + assert response.body 
== json.dumps( + { + "state": "run", + "status": "working", + "output": { + "id": expected_output.id, + "output": expected_output.output, + }, + "failed": True, + } + ) + + class TestParseAckId: + """ + Tests for the HeartbeatHandler._parse_ack_id method + """ + + @patch("urllib.parse.urlparse") + @patch("urllib.parse.parse_qs") + def test_parses_ack_id( + self, + mock_parse_qs: MagicMock, + mock_urlparse: MagicMock, + ): + # GIVEN + ack_id = "123" + parsed_qs = {HeartbeatHandler._ACK_ID_KEY: [ack_id]} + mock_url = MagicMock() + mock_urlparse.return_value = mock_url + mock_parse_qs.return_value = parsed_qs + + mock_handler = MagicMock() + handler = HeartbeatHandler(mock_handler) + + # WHEN + result = handler._parse_ack_id() + + # THEN + mock_urlparse.assert_called_once_with(mock_handler.path) + mock_parse_qs.assert_called_once_with(mock_url.query) + assert ack_id == result + + @patch("urllib.parse.urlparse") + @patch("urllib.parse.parse_qs") + def test_returns_none_if_ack_id_not_found( + self, + mock_parse_qs: MagicMock, + mock_urlparse: MagicMock, + ): + # GIVEN + mock_url = MagicMock() + mock_urlparse.return_value = mock_url + mock_parse_qs.return_value = {} + + mock_handler = MagicMock() + handler = HeartbeatHandler(mock_handler) + + # WHEN + result = handler._parse_ack_id() + + # THEN + mock_urlparse.assert_called_once_with(mock_handler.path) + mock_parse_qs.assert_called_once_with(mock_url.query) + assert result is None + + @patch("urllib.parse.urlparse") + @patch("urllib.parse.parse_qs") + def test_raises_if_more_than_one_ack_id( + self, + mock_parse_qs: MagicMock, + mock_urlparse: MagicMock, + ): + # GIVEN + ack_id = "123" + parsed_qs = {HeartbeatHandler._ACK_ID_KEY: [ack_id, ack_id]} + mock_url = MagicMock() + mock_urlparse.return_value = mock_url + mock_parse_qs.return_value = parsed_qs + + mock_handler = MagicMock() + handler = HeartbeatHandler(mock_handler) + + # WHEN + with pytest.raises(ValueError) as raised_err: + handler._parse_ack_id() + + # THEN + 
mock_urlparse.assert_called_once_with(mock_handler.path) + mock_parse_qs.assert_called_once_with(mock_url.query) + assert raised_err.match( + f"Expected one value for {HeartbeatHandler._ACK_ID_KEY}, but found: 2" + ) + + +class TestShutdownHandler: + """ + Tests for the ShutdownHandler class + """ + + def test_signals_to_the_server_thread(self): + # GIVEN + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_cancel_queue = MagicMock(spec=Queue) + mock_server._cancel_queue = mock_cancel_queue + mock_request_handler.server = mock_server + handler = ShutdownHandler(mock_request_handler) + + # WHEN + response = handler.put() + + # THEN + mock_cancel_queue.put.assert_called_once_with(True) + assert response.status == HTTPStatus.OK + assert response.body is None + + +class TestRunHandler: + """ + Tests for the RunHandler. + """ + + @patch("json.loads") + def test_submits_adaptor_run_to_worker(self, mock_loads: MagicMock): + # GIVEN + content_length = 123 + run_data = {"run": "data"} + str_run_data = json.dumps(run_data) + mock_loads.return_value = run_data + + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_future_runner = MagicMock() + mock_future_runner.is_running = False + mock_server._future_runner = mock_future_runner + mock_server._adaptor_runner = MagicMock() + + mock_handler = MagicMock() + mock_handler.headers = {"Content-Length": str(content_length)} + mock_handler.rfile.read.return_value = str_run_data.encode("utf-8") + mock_handler.server = mock_server + handler = RunHandler(mock_handler) + + # WHEN + result = handler.put() + + # THEN + mock_handler.rfile.read.assert_called_once_with(content_length) + mock_loads.assert_called_once_with(str_run_data) + mock_server.submit.assert_called_once_with( + mock_server._adaptor_runner._run, + run_data, + ) + assert result is mock_server.submit.return_value + + def test_returns_400_if_busy(self): + # GIVEN + mock_server = MagicMock(spec=BackgroundHTTPServer) + 
mock_future_runner = MagicMock() + mock_future_runner.is_running = True + mock_server._future_runner = mock_future_runner + + mock_handler = MagicMock() + mock_handler.server = mock_server + handler = RunHandler(mock_handler) + + # WHEN + result = handler.put() + + # THEN + assert result.status == HTTPStatus.BAD_REQUEST + + +class TestStartHandler: + """ + Tests for the StartHandler class + """ + + def test_put_starts_adaptor_runner(self): + # GIVEN + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_server._adaptor_runner = MagicMock(spec=AdaptorRunner) + mock_request_handler.server = mock_server + + mock_future_runner = MagicMock() + mock_future_runner.is_running = False + mock_server._future_runner = mock_future_runner + + handler = StartHandler(mock_request_handler) + + # WHEN + response = handler.put() + + # THEN + mock_server.submit.assert_called_once_with(mock_server._adaptor_runner._start) + assert response is mock_server.submit.return_value + + def test_returns_400_if_busy(self): + # GIVEN + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_future_runner = MagicMock() + mock_future_runner.is_running = True + mock_server._future_runner = mock_future_runner + + mock_handler = MagicMock() + mock_handler.server = mock_server + handler = StartHandler(mock_handler) + + # WHEN + result = handler.put() + + # THEN + assert result.status == HTTPStatus.BAD_REQUEST + + +class TestStopHandler: + """ + Tests for the StopHandler class + """ + + def test_put_ends_adaptor_runner(self): + # GIVEN + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_server._adaptor_runner = MagicMock(spec=AdaptorRunner) + mock_request_handler.server = mock_server + + mock_future_runner = MagicMock() + mock_future_runner.is_running = False + mock_server._future_runner = mock_future_runner + + handler = StopHandler(mock_request_handler) + + # WHEN + response = handler.put() + + # THEN + 
mock_server.submit.assert_called_once_with(handler._stop_adaptor) + assert response is mock_server.submit.return_value + + def test_returns_400_if_busy(self): + # GIVEN + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_future_runner = MagicMock() + mock_future_runner.is_running = True + mock_server._future_runner = mock_future_runner + + mock_handler = MagicMock() + mock_handler.server = mock_server + handler = StopHandler(mock_handler) + + # WHEN + result = handler.put() + + # THEN + assert result.status == HTTPStatus.BAD_REQUEST + + +class TestCancelHandler: + """ + Tests for the CancelHandler class + """ + + def test_put_cancels_adaptor_runner(self): + # GIVEN + mock_request_handler = MagicMock() + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_server._adaptor_runner = MagicMock(spec=AdaptorRunner) + mock_server._adaptor_runner.state = AdaptorState.RUN + mock_request_handler.server = mock_server + + mock_future_runner = MagicMock() + mock_future_runner.is_running = True + mock_server._future_runner = mock_future_runner + + handler = CancelHandler(mock_request_handler) + + # WHEN + response = handler.put() + + # THEN + mock_server.submit.assert_called_once_with( + mock_server._adaptor_runner._cancel, + force_immediate=True, + ) + assert response is mock_server.submit.return_value + + def test_returns_immediately_if_future_not_running(self): + # GIVEN + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_future_runner = MagicMock() + mock_future_runner.is_running = False + mock_server._future_runner = mock_future_runner + + mock_handler = MagicMock() + mock_handler.server = mock_server + handler = CancelHandler(mock_handler) + + # WHEN + result = handler.put() + + # THEN + assert result.status == HTTPStatus.OK + assert result.body == "No action required" + + @pytest.mark.parametrize( + argnames=["state"], + argvalues=[ + [AdaptorState.NOT_STARTED], + [AdaptorState.STOP], + [AdaptorState.CLEANUP], + [AdaptorState.CANCELED], + ], + 
ids=["NOT_STARTED", "STOP", "CLEANUP", "CANCELED"], + ) + def test_returns_immediately_if_adaptor_not_cancelable(self, state: AdaptorState): + # GIVEN + mock_server = MagicMock(spec=BackgroundHTTPServer) + mock_future_runner = MagicMock() + mock_future_runner.is_running = True + mock_server._future_runner = mock_future_runner + + mock_adaptor_runner = MagicMock() + mock_adaptor_runner.state = state + mock_server._adaptor_runner = mock_adaptor_runner + + mock_handler = MagicMock() + mock_handler.server = mock_server + handler = CancelHandler(mock_handler) + + # WHEN + result = handler.put() + + # THEN + assert result.status == HTTPStatus.OK + assert result.body == "No action required" diff --git a/test/openjd/adaptor_runtime/unit/background/test_log_buffers.py b/test/openjd/adaptor_runtime/unit/background/test_log_buffers.py new file mode 100644 index 0000000..a14e6a8 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/background/test_log_buffers.py @@ -0,0 +1,181 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import logging +import os +from typing import Tuple +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +import openjd.adaptor_runtime._background.log_buffers as log_buffers +from openjd.adaptor_runtime._background.log_buffers import ( + FileLogBuffer, + InMemoryLogBuffer, + LogBuffer, +) +from openjd.adaptor_runtime._background.model import BufferedOutput + + +@pytest.fixture(autouse=True) +def mocked_chunk_id(): + with patch.object(LogBuffer, "_create_id") as mock_create_id: + chunk_id = "id" + mock_create_id.return_value = chunk_id + yield (chunk_id, mock_create_id) + + +class TestInMemoryLogBuffer: + """ + Tests for InMemoryLogBuffer. 
+ """ + + @patch.object(LogBuffer, "_format") + def test_chunk_creates_new_chunk( + self, + mock_format: MagicMock, + mocked_chunk_id: Tuple[str, MagicMock], + ): + # GIVEN + chunk_id, mock_create_id = mocked_chunk_id + mock_format.return_value = "output" + buffer = InMemoryLogBuffer() + buffer._buffer = [MagicMock()] + + # WHEN + output = buffer.chunk() + + # THEN + assert output.id == chunk_id + assert output.output == mock_format.return_value + mock_create_id.assert_called_once() + assert len(buffer._buffer) == 0 + assert buffer._last_chunk == output + + @patch.object(LogBuffer, "_format") + def test_chunk_uses_last_chunk( + self, + mock_format: MagicMock, + mocked_chunk_id: Tuple[str, MagicMock], + ): + # GIVEN + chunk_id, mock_create_id = mocked_chunk_id + mock_format.return_value = "output" + buffer = InMemoryLogBuffer() + buffer._buffer = [MagicMock()] + last_chunk = BufferedOutput("id", "last_chunk") + buffer._last_chunk = last_chunk + + # WHEN + output = buffer.chunk() + + # THEN + assert output.id == chunk_id + assert output.output == os.linesep.join([last_chunk.output, mock_format.return_value]) + mock_create_id.assert_called_once() + assert len(buffer._buffer) == 0 + assert buffer._last_chunk == output + + def test_clear_clears_chunk(self): + # GIVEN + last_chunk = BufferedOutput("id", "last_chunk") + buffer = InMemoryLogBuffer() + buffer._last_chunk = last_chunk + + # WHEN + cleared = buffer.clear(last_chunk.id) + + # THEN + assert cleared + assert buffer._last_chunk is None + + def test_clear_no_op_if_wrong_id(self): + # GIVEN + last_chunk = BufferedOutput("id", "last_chunk") + buffer = InMemoryLogBuffer() + buffer._last_chunk = last_chunk + + # WHEN + cleared = buffer.clear("wrong_id") + + # THEN + assert not cleared + assert buffer._last_chunk == last_chunk + + +class TestFileLogBuffer: + """ + Tests for the FileLogBuffer class + """ + + def test_buffer(self) -> None: + # GIVEN + filepath = "/filepath" + mock_record = 
MagicMock(spec=logging.LogRecord) + mock_record.msg = "hello world" + buffer = FileLogBuffer(filepath) + + # WHEN + open_mock: MagicMock + with patch.object(log_buffers, "secure_open", mock_open()) as open_mock: + buffer.buffer(mock_record) + + # THEN + open_mock.assert_called_once_with(filepath, open_mode="a") + handle = open_mock.return_value + handle.write.assert_called_once_with(mock_record.msg) + + def test_chunk(self, mocked_chunk_id: Tuple[str, MagicMock]) -> None: + # GIVEN + chunk_id, mock_create_id = mocked_chunk_id + filepath = "/filepath" + data = "hello world" + end_pos = len(data) + buffer = FileLogBuffer(filepath) + + # WHEN + open_mock: MagicMock + with patch("builtins.open", mock_open(read_data=data)) as open_mock: + open_mock.return_value.tell.return_value = end_pos + output = buffer.chunk() + + # THEN + mock_create_id.assert_called_once() + open_mock.assert_called_once_with(filepath, mode="r") + handle = open_mock.return_value + handle.seek.assert_called_once_with(buffer._chunk.start) + handle.read.assert_called_once() + handle.tell.assert_called_once() + assert buffer._chunk.end == end_pos + assert buffer._chunk.id == chunk_id + assert output.id == chunk_id + assert output.output == data + + def test_clear(self) -> None: + # GIVEN + chunk_id = "id" + end_pos = 123 + buffer = FileLogBuffer("") + buffer._chunk.id = chunk_id + buffer._chunk.end = end_pos + + # WHEN + cleared = buffer.clear(chunk_id) + + # THEN + assert cleared + assert buffer._chunk.start == end_pos + assert buffer._chunk.id is None + + def test_clear_no_op_if_wrong_id(self) -> None: + # GIVEN + buffer = FileLogBuffer("") + buffer._chunk.id = "id" + buffer._chunk.end = 1 + + # WHEN + cleared = buffer.clear("wrong_id") + + # THEN + assert not cleared + assert buffer._chunk.id == "id" + assert buffer._chunk.start == 0 diff --git a/test/openjd/adaptor_runtime/unit/background/test_model.py b/test/openjd/adaptor_runtime/unit/background/test_model.py new file mode 100644 index 
0000000..3a74b44 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/background/test_model.py @@ -0,0 +1,51 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import dataclasses + +import pytest + +from openjd.adaptor_runtime._background.model import DataclassMapper + + +# Define two dataclasses to use for tests +@dataclasses.dataclass +class Inner: + key: str + + +@dataclasses.dataclass +class Outer: + outer_key: str + inner: Inner + + +class TestDataclassMapper: + """ + Tests for the DataclassMapper class + """ + + def test_maps_nested_dataclass(self): + # GIVEN + input = {"outer_key": "outer_value", "inner": {"key": "value"}} + mapper = DataclassMapper(Outer) + + # WHEN + result = mapper.map(input) + + # THEN + assert isinstance(result, Outer) + assert isinstance(result.inner, Inner) + assert result.outer_key == "outer_value" + assert result.inner.key == "value" + + def test_raises_when_field_is_missing(self): + # GIVEN + input = {"outer_key": "outer_value"} + mapper = DataclassMapper(Outer) + + # WHEN + with pytest.raises(ValueError) as raised_err: + mapper.map(input) + + # THEN + assert raised_err.match("Dataclass field inner not found in dict " + str(input)) diff --git a/test/openjd/adaptor_runtime/unit/handlers/__init__.py b/test/openjd/adaptor_runtime/unit/handlers/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/handlers/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/unit/handlers/test_regex_callback_handler.py b/test/openjd/adaptor_runtime/unit/handlers/test_regex_callback_handler.py new file mode 100644 index 0000000..21d4735 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/handlers/test_regex_callback_handler.py @@ -0,0 +1,329 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +import logging +import re +from typing import Dict, List, Tuple +from unittest.mock import Mock + +import pytest + +from openjd.adaptor_runtime.app_handlers import RegexCallback, RegexHandler + + +class TestLoggingRegexHandler: + """ + Tests for the RegexHandler when using the logging library + """ + + invoked_regex_list = [ + pytest.param( + [re.compile(".*")], + 0, + "Test input", + ["Test input", ""], + id="Match everything regex call once", + ), + pytest.param( + [re.compile(r"input")], 0, "Test input", ["input"], id="Match a word regex call once" + ), + pytest.param( + [re.compile(r"\w+")], + 0, + "Test input", + ["Test", "input"], + id="Match multiple words regex call once", + ), + pytest.param( + [re.compile("b"), re.compile("s")], + 1, + "Test input", + ["s"], + id="Multiple regexes single match call once", + ), + pytest.param( + [re.compile("t"), re.compile("s")], + 0, + "Test input", + ["t", "t"], + id="Multiple regexes multiple match call once", + ), + pytest.param( + [re.compile("test", flags=re.IGNORECASE)], + 0, + "Test input", + ["Test"], + id="Ignore case regex", + ), + ] + + @pytest.mark.parametrize( + "regex_list, match_regex_index, input, find_all_results", invoked_regex_list + ) + def test_regex_handler_invoked( + self, + regex_list: List[re.Pattern], + match_regex_index: int, + input: str, + find_all_results: List[str], + ): + # GIVEN + callback_mock = Mock().callback + regex_callback = RegexCallback(regex_list, callback_mock) + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + handler = RegexHandler([regex_callback]) + stdout_logger.addHandler(handler) + + # WHEN + stdout_logger.info(input) + + # THEN + callback_mock.assert_called_once() + assert callback_mock.call_args[0][0].re == regex_list[match_regex_index] + assert regex_list[match_regex_index].findall(input) == find_all_results + + noninvoked_regex_list = [ + pytest.param([re.compile("(?!)")], "Test input", id="Match nothing regex"), + 
pytest.param([re.compile(r"a")], "Test input", id="Single letter match nothing regex"), + ] + + @pytest.mark.parametrize("regex_list, input", noninvoked_regex_list) + def test_regex_handler_not_invoked( + self, + regex_list: List[re.Pattern], + input: str, + ): + # GIVEN + callback_mock = Mock().callback + regex_callback = RegexCallback(regex_list, callback_mock) + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + handler = RegexHandler([regex_callback]) + stdout_logger.addHandler(handler) + + # WHEN + stdout_logger.info(input) + + # THEN + callback_mock.assert_not_called() + assert regex_list[0].findall(input) == [] + + multiple_regex_list = [ + pytest.param( + [[re.compile("Test")], [re.compile("Input", flags=re.IGNORECASE)]], + "Test input", + id="Match twice regexes", + ), + pytest.param( + [[re.compile("T")], [re.compile("e")], [re.compile("s")], [re.compile("t")]], + "Test input", + id="Single letter match four times", + ), + pytest.param( + [[re.compile("T"), re.compile("e")], [re.compile("s"), re.compile("t")]], + "Test input", + id="Multiple callbacks with multiple matching regexes match twice", + ), + ] + + @pytest.mark.parametrize("regex_lists, input", multiple_regex_list) + def test_multiple_callbacks( + self, + regex_lists: List[List[re.Pattern]], + input: str, + ): + # GIVEN + callback_mock = Mock().callback + regex_callbacks = [RegexCallback(regex_list, callback_mock) for regex_list in regex_lists] + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + handler = RegexHandler(regex_callbacks) + stdout_logger.addHandler(handler) + + # WHEN + stdout_logger.info(input) + + # THEN + assert callback_mock.call_count == len(regex_lists) + assert all( + c[0][0].re == patterns[0] + for c, patterns in zip(callback_mock.call_args_list, regex_lists) + ) + for regex_list in regex_lists: + assert regex_list[0].search(input) + + multiple_loggers = [ + pytest.param( + { + 
logging.getLogger("info_logger"): [re.compile("INFO: "), re.compile("STDOUT: ")], + logging.getLogger("error_logger"): [re.compile("ERROR: ")], + }, + "Test input", + ), + ] + + @pytest.mark.parametrize("loggers, input", multiple_loggers) + def test_multiple_loggers( + self, + loggers: Dict[logging.Logger, List[re.Pattern]], + input: str, + ): + # GIVEN + regex_callbacks = {} + for logger, regex_list in loggers.items(): + logger.setLevel(logging.INFO) + callback_mock = Mock() + regex_callback = RegexCallback(regex_list, callback_mock) + handler = RegexHandler([regex_callback]) + logger.addHandler(handler) + regex_callbacks[callback_mock] = [pattern for pattern in regex_list] + + # WHEN + for patterns in regex_callbacks.values(): + for logger in loggers.keys(): + for pattern in patterns: + logger.info(f"{pattern.pattern}: {input}") + + # THEN + for callback_mock, patterns in regex_callbacks.items(): + print(callback_mock.call_args_list) + assert callback_mock.call_count == len(patterns) + assert all( + c[0][0].re == pattern for c, pattern in zip(callback_mock.call_args_list, patterns) + ) + + exit_if_matched_regexes = [ + pytest.param( + [(re.compile("Test"), True, True), (re.compile("input"), False, True)], + "Test input", + id="Multiple matches, exit_if_matched first", + ), + pytest.param( + [ + (re.compile("T"), False, True), + (re.compile("e"), True, True), + (re.compile("s"), False, True), + (re.compile("t"), False, True), + ], + "Test input", + id="Single letter match four times, exit_if_matched second", + ), + pytest.param( + [ + (re.compile("a"), True, False), + (re.compile("b"), False, False), + (re.compile("c"), True, True), + (re.compile("d"), False, True), + ], + "cd", + id="Multiple matched, multiple exit_if_matched", + ), + ] + + @pytest.mark.parametrize("regex_list, input", exit_if_matched_regexes) + @pytest.mark.parametrize("exit_if_matched", [False, True]) + def test_exit_if_matched( + self, + exit_if_matched: bool, + regex_list: 
List[Tuple[re.Pattern, bool, bool]], + input: str, + ): + # GIVEN + callback_mock = Mock().callback + regex_callbacks = [ + RegexCallback( + [pattern], callback_mock, exit_if_matched=exit_if_matched and exit_this_pattern + ) + for (pattern, exit_this_pattern, _) in regex_list + ] + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + handler = RegexHandler(regex_callbacks) + stdout_logger.addHandler(handler) + + # WHEN + stdout_logger.info(input) + + # THEN + if not exit_if_matched: + patterns = [pattern for pattern, _, matches in regex_list if matches] + assert callback_mock.call_count == len(patterns) + assert all( + c[0][0].re == pattern for c, pattern in zip(callback_mock.call_args_list, patterns) + ) + else: + patterns = [] + for pattern, exit_this_pattern, matches_input in regex_list: + if matches_input: + patterns.append(pattern) + if exit_this_pattern: + break + assert callback_mock.call_count == len(patterns) + assert all( + c[0][0].re == pattern for c, pattern in zip(callback_mock.call_args_list, patterns) + ) + + only_run_if_first_regexes = [ + pytest.param( + [re.compile("Test"), re.compile("input")], 0, "Test input", id="Match twice regexes" + ), + pytest.param( + [re.compile("Test"), re.compile("input")], + 1, + "Test input", + id="Match twice regexes only run first", + ), + pytest.param( + [re.compile("T"), re.compile("e"), re.compile("s"), re.compile("t")], + 0, + "Test input", + id="Single letter match four times", + ), + pytest.param( + [re.compile("T"), re.compile("e"), re.compile("s"), re.compile("t")], + 3, + "Test input", + id="Single letter match four times don't run last", + ), + ] + + @pytest.mark.parametrize("regex_list, first_match_index, input", only_run_if_first_regexes) + @pytest.mark.parametrize("only_run_if_first", [False, True]) + def test_only_run_if_first_matched( + self, + only_run_if_first: bool, + regex_list: List[re.Pattern], + first_match_index: int, + input: str, + ): + # GIVEN + callback_mock 
= Mock().callback + regex_callbacks = [ + RegexCallback( + [pattern], + callback_mock, + only_run_if_first_matched=only_run_if_first and i == first_match_index, + ) + for i, pattern in enumerate(regex_list) + ] + stdout_logger = logging.getLogger("stdout") + stdout_logger.setLevel(logging.INFO) + handler = RegexHandler(regex_callbacks) + stdout_logger.addHandler(handler) + + # WHEN + stdout_logger.info(input) + + # THEN + if not only_run_if_first or first_match_index == 0: + assert callback_mock.call_count == len(regex_list) + assert all( + c[0][0].re == pattern + for c, pattern in zip(callback_mock.call_args_list, regex_list) + ) + else: + patterns = [pattern for i, pattern in enumerate(regex_list) if i != first_match_index] + assert callback_mock.call_count == len(regex_list) - 1 + assert all( + c[0][0].re == pattern for c, pattern in zip(callback_mock.call_args_list, patterns) + ) diff --git a/test/openjd/adaptor_runtime/unit/http/__init__.py b/test/openjd/adaptor_runtime/unit/http/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/http/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/unit/http/test_request_handler.py b/test/openjd/adaptor_runtime/unit/http/test_request_handler.py new file mode 100644 index 0000000..725bbbe --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/http/test_request_handler.py @@ -0,0 +1,242 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +import socket +from http import HTTPStatus +from unittest.mock import MagicMock, Mock, patch + +import pytest + +import openjd.adaptor_runtime._http.request_handler as request_handler +from openjd.adaptor_runtime._background.http_server import BackgroundRequestHandler +from openjd.adaptor_runtime._http.request_handler import ( + HTTPResponse, + RequestHandler, + UnsupportedPlatformException, +) + + +@pytest.fixture +def fake_request_handler() -> BackgroundRequestHandler: + class FakeBackgroundRequestHandler(BackgroundRequestHandler): + path: str = "/fake" + + def __init__(self) -> None: + pass + + def _authenticate(self) -> bool: + return True + + return FakeBackgroundRequestHandler() + + +class TestRequestHandler: + """ + Tests for the RequestHandler class. + """ + + @patch.object(BackgroundRequestHandler, "_respond") + def test_do_request( + self, + mock_respond: MagicMock, + fake_request_handler: BackgroundRequestHandler, + ): + # GIVEN + func = Mock() + + # WHEN + fake_request_handler._do_request(func) + + # THEN + func.assert_called_once() + mock_respond.assert_called_once_with(func.return_value) + + @patch.object(BackgroundRequestHandler, "_respond") + def test_do_request_responds_with_error_when_request_handler_raises( + self, + mock_respond: MagicMock, + fake_request_handler: BackgroundRequestHandler, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + func = Mock() + exc = Exception() + func.side_effect = exc + + # WHEN + fake_request_handler._do_request(func) + + # THEN + func.assert_called_once() + assert "Failed to handle request: " in caplog.text + mock_respond.assert_called_once_with(HTTPResponse(HTTPStatus.INTERNAL_SERVER_ERROR)) + + @patch.object(BackgroundRequestHandler, "send_response") + @patch.object(BackgroundRequestHandler, "end_headers") + def test_respond_with_success( + self, + mock_end_headers: MagicMock, + mock_send_response: MagicMock, + fake_request_handler: BackgroundRequestHandler, + ): + # GIVEN + response = 
HTTPResponse(HTTPStatus.OK) + + # WHEN + fake_request_handler._respond(response) + + # THEN + mock_send_response.assert_called_once_with(response.status) + mock_end_headers.assert_called_once() + + @patch.object(BackgroundRequestHandler, "send_error") + @patch.object(BackgroundRequestHandler, "end_headers") + def test_respond_with_error( + self, + mock_end_headers: MagicMock, + mock_send_error: MagicMock, + fake_request_handler: BackgroundRequestHandler, + ): + # GIVEN + response = HTTPResponse(HTTPStatus.INTERNAL_SERVER_ERROR) + + # WHEN + fake_request_handler._respond(response) + + # THEN + mock_send_error.assert_called_once_with(response.status) + mock_end_headers.assert_called_once() + + @patch.object(BackgroundRequestHandler, "send_header") + @patch.object(BackgroundRequestHandler, "send_response") + @patch.object(BackgroundRequestHandler, "end_headers") + def test_respond_with_body( + self, + mock_end_headers: MagicMock, + mock_send_response: MagicMock, + mock_send_header: MagicMock, + fake_request_handler: BackgroundRequestHandler, + ): + # GIVEN + mock_wfile = MagicMock() + fake_request_handler.wfile = mock_wfile + body = "hello world" + response = HTTPResponse(HTTPStatus.OK, body) + + # WHEN + fake_request_handler._respond(response) + + # THEN + mock_send_response.assert_called_once_with(response.status) + mock_end_headers.assert_called_once() + mock_send_header.assert_called_once_with("Content-Length", str(len(body.encode("utf-8")))) + mock_wfile.write.assert_called_once_with(body.encode("utf-8")) + + +class TestAuthentication: + """ + Tests for the RequestHandler authentication + """ + + class TestAuthenticate: + """ + Tests for the RequestHandler._authenticate() method + """ + + @pytest.fixture + def mock_handler(self) -> MagicMock: + mock_socket = MagicMock(spec=socket.socket) + mock_socket.family = socket.AddressFamily.AF_UNIX + + mock_handler = MagicMock(spec=RequestHandler) + mock_handler.connection = mock_socket + + return mock_handler + + 
@patch.object(request_handler.os, "getuid") + @patch.object(request_handler.UCred, "from_buffer_copy") + def test_accepts_same_uid( + self, mock_from_buffer_copy: MagicMock, mock_getuid: MagicMock, mock_handler: MagicMock + ) -> None: + # GIVEN + # Set the UID of the mocked calling process == our mocked UID + mock_from_buffer_copy.return_value.uid = mock_getuid.return_value + + # WHEN + result = RequestHandler._authenticate(mock_handler) + + # THEN + assert result + + @patch.object(request_handler.os, "getuid") + @patch.object(request_handler.UCred, "from_buffer_copy") + def test_rejects_different_uid( + self, mock_from_buffer_copy: MagicMock, mock_getuid: MagicMock, mock_handler: MagicMock + ) -> None: + # GIVEN + mock_getuid.return_value = 1 + mock_from_buffer_copy.return_value.uid = 2 + + # WHEN + result = RequestHandler._authenticate(mock_handler) + + # THEN + assert not result + + def test_raises_if_not_on_unix_socket(self, mock_handler: MagicMock) -> None: + # GIVEN + mock_handler.connection.family = socket.AddressFamily.AF_INET + + # WHEN + with pytest.raises(UnsupportedPlatformException) as raised_exc: + RequestHandler._authenticate(mock_handler) + + # THEN + assert raised_exc.match( + "Failed to handle request because it was not made through a UNIX socket" + ) + + class TestDoRequest: + """ + Tests for the RequestHandler._do_request() method + """ + + def test_does_request_after_auth_succeeds(self) -> None: + # GIVEN + mock_handler = MagicMock(spec=RequestHandler) + mock_handler._authenticate.return_value = True + mock_func = Mock() + + # WHEN + RequestHandler._do_request(mock_handler, mock_func) + + # THEN + mock_handler._authenticate.assert_called_once() + mock_func.assert_called_once() + + def test_responds_with_unauthorized_after_auth_fails(self): + # GIVEN + mock_handler = MagicMock(spec=RequestHandler) + mock_handler._authenticate.return_value = False + + # WHEN + RequestHandler._do_request(mock_handler, Mock()) + + # THEN + 
mock_handler._authenticate.assert_called_once() + mock_handler._respond.assert_called_once_with(HTTPResponse(HTTPStatus.UNAUTHORIZED)) + + def test_responds_with_500_for_unsupported_platform(self, caplog: pytest.LogCaptureFixture): + # GIVEN + mock_handler = MagicMock(spec=RequestHandler) + exc = UnsupportedPlatformException("not UNIX") + mock_handler._authenticate.side_effect = exc + + # WHEN + RequestHandler._do_request(mock_handler, Mock()) + + # THEN + mock_handler._authenticate.assert_called_once() + assert str(exc) in caplog.text + mock_handler._respond.assert_called_once_with( + HTTPResponse(HTTPStatus.INTERNAL_SERVER_ERROR) + ) diff --git a/test/openjd/adaptor_runtime/unit/http/test_sockets.py b/test/openjd/adaptor_runtime/unit/http/test_sockets.py new file mode 100644 index 0000000..68a218e --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/http/test_sockets.py @@ -0,0 +1,265 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +import os +import re +import stat +from typing import Generator +from unittest.mock import ANY, MagicMock, call, patch + +import pytest + +import openjd.adaptor_runtime._http.sockets as sockets +from openjd.adaptor_runtime._http.sockets import ( + LinuxSocketDirectories, + NonvalidSocketPathException, + NoSocketPathFoundException, + SocketDirectories, +) + + +class SocketDirectoriesStub(SocketDirectories): + def verify_socket_path(self, path: str) -> None: + pass + + +class TestSocketDirectories: + class TestGetProcessSocketPath: + """ + Tests for SocketDirectories.get_process_socket_path() + """ + + @pytest.fixture + def socket_dir(self) -> str: + return "/path/to/socket/dir" + + @pytest.fixture(autouse=True) + def mock_socket_dir(self, socket_dir: str) -> Generator[MagicMock, None, None]: + with patch.object(SocketDirectories, "get_socket_dir") as mock: + mock.return_value = socket_dir + yield mock + + @pytest.mark.parametrize( + argnames=["create_dir"], + argvalues=[[True], [False]], + ids=["creates dir", 
"does not create dir"], + ) + @patch.object(sockets.os, "getpid", return_value=1234) + def test_gets_path( + self, + mock_getpid: MagicMock, + socket_dir: str, + mock_socket_dir: MagicMock, + create_dir: bool, + ) -> None: + # GIVEN + namespace = "my-namespace" + subject = SocketDirectoriesStub() + + # WHEN + result = subject.get_process_socket_path(namespace, create_dir=create_dir) + + # THEN + assert result == os.path.join(socket_dir, str(mock_getpid.return_value)) + mock_getpid.assert_called_once() + mock_socket_dir.assert_called_once_with(namespace, create=create_dir) + + @patch.object(sockets.os, "getpid", return_value="a" * (sockets._PID_MAX_LENGTH + 1)) + def test_asserts_max_pid_length(self, mock_getpid: MagicMock): + # GIVEN + subject = SocketDirectoriesStub() + + # WHEN + with pytest.raises(AssertionError) as raised_err: + subject.get_process_socket_path() + + # THEN + assert raised_err.match( + f"PID too long. Only PIDs up to {sockets._PID_MAX_LENGTH} digits are supported." + ) + mock_getpid.assert_called_once() + + class TestGetSocketDir: + """ + Tests for SocketDirectories.get_socket_dir() + """ + + @pytest.fixture(autouse=True) + def mock_makedirs(self) -> Generator[MagicMock, None, None]: + with patch.object(sockets.os, "makedirs") as mock: + yield mock + + @pytest.fixture + def home_dir(self) -> str: + return os.path.join("home", "user") + + @pytest.fixture(autouse=True) + def mock_expanduser(self, home_dir: str) -> Generator[MagicMock, None, None]: + with patch.object(sockets.os.path, "expanduser", return_value=home_dir) as mock: + yield mock + + @pytest.fixture + def temp_dir(self) -> str: + return "tmp" + + @pytest.fixture(autouse=True) + def mock_gettempdir(self, temp_dir: str) -> Generator[MagicMock, None, None]: + with patch.object(sockets.tempfile, "gettempdir", return_value=temp_dir) as mock: + yield mock + + def test_gets_home_dir( + self, + mock_expanduser: MagicMock, + home_dir: str, + ) -> None: + # GIVEN + subject = 
SocketDirectoriesStub() + + # WHEN + result = subject.get_socket_dir() + + # THEN + mock_expanduser.assert_called_once_with("~") + assert result.startswith(home_dir) + + @patch.object(sockets.os, "stat") + @patch.object(SocketDirectoriesStub, "verify_socket_path") + def test_gets_temp_dir( + self, + mock_verify_socket_path: MagicMock, + mock_stat: MagicMock, + mock_gettempdir: MagicMock, + temp_dir: str, + ) -> None: + # GIVEN + exc = NonvalidSocketPathException() + mock_verify_socket_path.side_effect = [exc, None] # Raise exc only once + mock_stat.return_value.st_mode = stat.S_ISVTX + subject = SocketDirectoriesStub() + + # WHEN + result = subject.get_socket_dir() + + # THEN + mock_gettempdir.assert_called_once() + mock_verify_socket_path.assert_has_calls( + [ + call(ANY), # home dir + call(result), # temp dir + ] + ) + mock_stat.assert_called_once_with(temp_dir) + + @pytest.mark.parametrize( + argnames=["create"], + argvalues=[[True], [False]], + ids=["created", "not created"], + ) + def test_create_dir(self, mock_makedirs: MagicMock, create: bool) -> None: + # GIVEN + subject = SocketDirectoriesStub() + + # WHEN + result = subject.get_socket_dir(create=create) + + # THEN + if create: + mock_makedirs.assert_called_once_with(result, mode=0o700, exist_ok=True) + else: + mock_makedirs.assert_not_called() + + def test_uses_namespace(self) -> None: + # GIVEN + namespace = "my-namespace" + subject = SocketDirectoriesStub() + + # WHEN + result = subject.get_socket_dir(namespace) + + # THEN + assert result.endswith(namespace) + + @patch.object(SocketDirectoriesStub, "verify_socket_path") + def test_raises_when_no_valid_dir_found(self, mock_verify_socket_path: MagicMock) -> None: + # GIVEN + mock_verify_socket_path.side_effect = NonvalidSocketPathException() + subject = SocketDirectoriesStub() + + # WHEN + with pytest.raises(NoSocketPathFoundException) as raised_exc: + subject.get_socket_dir() + + # THEN + assert raised_exc.match( + "Failed to find a suitable base 
directory to create sockets in for the following " + "reasons: " + ) + assert mock_verify_socket_path.call_count == 2 + + @patch.object(SocketDirectoriesStub, "verify_socket_path") + @patch.object(sockets.os, "stat") + def test_raises_when_no_tmpdir_sticky_bit( + self, + mock_stat: MagicMock, + mock_verify_socket_path: MagicMock, + temp_dir: str, + ) -> None: + # GIVEN + mock_verify_socket_path.side_effect = [NonvalidSocketPathException(), None] + mock_stat.return_value.st_mode = 0 + subject = SocketDirectoriesStub() + + # WHEN + with pytest.raises(NoSocketPathFoundException) as raised_exc: + subject.get_socket_dir() + + # THEN + assert raised_exc.match( + re.escape( + f"Cannot use temporary directory {temp_dir} because it does not have the " + "sticky bit (restricted deletion flag) set" + ) + ) + + +class TestLinuxSocketDirectories: + @pytest.mark.parametrize( + argnames=["path"], + argvalues=[ + ["a"], + ["a" * 100], + ], + ids=["one byte", "100 bytes"], + ) + def test_accepts_paths_within_100_bytes(self, path: str): + """ + Verifies the function accepts paths up to 100 bytes (108 byte max - 8 byte padding + for socket name portion (path sep + PID)) + """ + # GIVEN + subject = LinuxSocketDirectories() + + try: + # WHEN + subject.verify_socket_path(path) + except NonvalidSocketPathException as e: + pytest.fail(f"verify_socket_path raised an error when it should not have: {e}") + else: + # THEN + pass # success + + def test_rejects_paths_over_100_bytes(self): + # GIVEN + length = 101 + path = "a" * length + subject = LinuxSocketDirectories() + + # WHEN + with pytest.raises(NonvalidSocketPathException) as raised_exc: + subject.verify_socket_path(path) + + # THEN + assert raised_exc.match( + "Socket base directory path too big. 
The maximum allowed size is " + f"{subject._socket_dir_max_length} bytes, but the directory has a size of " + f"{length}: {path}" + ) diff --git a/test/openjd/adaptor_runtime/unit/process/__init__.py b/test/openjd/adaptor_runtime/unit/process/__init__.py new file mode 100644 index 0000000..8d929cc --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/process/__init__.py @@ -0,0 +1 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. diff --git a/test/openjd/adaptor_runtime/unit/process/test_logging_subprocess.py b/test/openjd/adaptor_runtime/unit/process/test_logging_subprocess.py new file mode 100644 index 0000000..90ed0a1 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/process/test_logging_subprocess.py @@ -0,0 +1,496 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +"""Tests for StreamLogger""" +from __future__ import annotations + +import subprocess +from logging import INFO +from typing import List +from unittest import mock + +import pytest + +import openjd.adaptor_runtime.process._logging_subprocess as logging_subprocess +from openjd.adaptor_runtime.process import LoggingSubprocess + + +class TestLoggingSubprocess(object): + """Tests for LoggingSubprocess""" + + @pytest.fixture() + def mock_popen(self): + with mock.patch.object(logging_subprocess.subprocess, "Popen") as popen_mock: + yield popen_mock + + @pytest.fixture(autouse=True) + def mock_stream_logger(self): + with mock.patch.object(logging_subprocess, "StreamLogger") as stream_logger: + yield stream_logger + + def test_args_validation(self, mock_popen: mock.Mock): + """Tests that passing no args raises an Exception""" + # GIVEN + args: List[str] = [] + logger = mock.Mock() + + # THEN + with pytest.raises(ValueError, match="Insufficient args"): + LoggingSubprocess(args=args, logger=logger) + + mock_popen.assert_not_called() + + def test_logging_validation(self, mock_popen: mock.Mock): + """Tests that passing no logger raises an Exception""" + # GIVEN + 
args = ["cat", "foo.txt"] + logger = None + + # THEN + with pytest.raises(ValueError, match="No logger specified"): + LoggingSubprocess(args=args, logger=logger) # type: ignore[arg-type] + + mock_popen.assert_not_called() + + def test_process_creation(self, mock_popen: mock.Mock, mock_stream_logger: mock.Mock): + # GIVEN + args = ["cat", "foo.txt"] + logger = mock.Mock() + stdout_logger_mock = mock.Mock() + stderr_logger_mock = mock.Mock() + mock_stream_logger.side_effect = [stdout_logger_mock, stderr_logger_mock] + + # WHEN + LoggingSubprocess(args=args, logger=logger) + + # EXPECT + mock_popen.assert_called_with( + args, + encoding="utf-8", + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=None, + ) + + def test_is_running(self, mock_popen: mock.Mock): + # GIVEN + proc_obj = mock.Mock() + proc_obj.poll.return_value = None + mock_popen.return_value = proc_obj + args = ["cat", "foo.txt"] + logger = mock.Mock() + + # WHEN + subject = LoggingSubprocess(args=args, logger=logger) + + # THEN + assert subject.is_running + proc_obj.poll.return_value = True + assert not subject.is_running + + def test_wait(self, mock_popen: mock.Mock, mock_stream_logger: mock.Mock): + # GIVEN + # mock stdout and stderr StreamLogger instances + stdout_logger = mock.Mock() + stderr_logger = mock.Mock() + mock_stream_logger.side_effect = [stdout_logger, stderr_logger] + # mock subprocess.Popen return value + proc = mock.Mock() + proc.poll.return_value = None + mock_popen.return_value = proc + args = ["cat", "foo.txt"] + logger = mock.Mock() + subject = LoggingSubprocess(args=args, logger=logger) + + # WHEN + subject.wait() + + # THEN + proc.wait.assert_called_once() + stdout_logger.join.assert_called_once() + stderr_logger.join.assert_called_once() + proc.stdout.close.assert_called_once() + proc.stderr.close.assert_called_once() + + @mock.patch.object(logging_subprocess.sys, "platform", "win32") + def test_terminate_fails_on_windows(self, mock_popen: 
mock.Mock): + args = ["cat", "foo.txt"] + proc = mock.Mock() + proc.poll.return_value = None + proc.pid = 1 + mock_popen.return_value = proc + logger = mock.Mock() + subject = LoggingSubprocess(args=args, logger=logger) + + with pytest.raises(NotImplementedError): + subject.terminate() + + def test_terminate_no_process(self, mock_popen: mock.Mock, mock_stream_logger: mock.Mock): + # GIVEN + # mock stdout and stderr StreamLogger instances + stdout_logger = mock.Mock() + stderr_logger = mock.Mock() + mock_stream_logger.side_effect = [stdout_logger, stderr_logger] + # mock subprocess.Popen return value + proc = mock.Mock() + proc.poll.return_value = False + proc.pid = 1 + mock_popen.return_value = proc + args = ["cat", "foo.txt"] + logger = mock.Mock() + subject = LoggingSubprocess(args=args, logger=logger) + + # WHEN + subject.terminate() + + # THEN + proc.terminate.assert_not_called() + proc.kill.assert_not_called() + proc.stdout.close.assert_not_called() + proc.stderr.close.assert_not_called() + + def test_terminate_no_grace(self, mock_popen: mock.Mock, mock_stream_logger: mock.Mock): + # GIVEN + # mock stdout and stderr StreamLogger instances + stdout_logger = mock.Mock() + stderr_logger = mock.Mock() + mock_stream_logger.side_effect = [stdout_logger, stderr_logger] + # mock subprocess.Popen return value + proc = mock.Mock() + proc.poll.return_value = None + proc.pid = 1 + mock_popen.return_value = proc + args = ["cat", "foo.txt"] + logger = mock.Mock() + subject = LoggingSubprocess(args=args, logger=logger) + + # WHEN + subject.terminate(0) + + # THEN + proc.terminate.assert_not_called() + proc.kill.assert_called_once() + proc.wait.assert_called_once() + stdout_logger.join.assert_called_once() + stderr_logger.join.assert_called_once() + proc.stdout.close.assert_called_once() + proc.stderr.close.assert_called_once() + + def test_terminate(self, mock_popen: mock.Mock, mock_stream_logger: mock.Mock): + # GIVEN + # mock stdout and stderr StreamLogger instances + 
stdout_logger = mock.Mock() + stderr_logger = mock.Mock() + mock_stream_logger.side_effect = [stdout_logger, stderr_logger] + # mock subprocess.Popen return value + proc = mock.Mock() + proc.poll.return_value = None + proc.pid = 1 + mock_popen.return_value = proc + args = ["cat", "foo.txt"] + logger = mock.Mock() + subject = LoggingSubprocess(args=args, logger=logger) + + # WHEN + subject.terminate() + + # THEN + proc.terminate.assert_called_once() + proc.kill.assert_not_called() + proc.wait.assert_called_once() + stdout_logger.join.assert_called_once() + stderr_logger.join.assert_called_once() + proc.stdout.close.assert_called_once() + proc.stderr.close.assert_called_once() + + def test_stop_after_terminate_timeout( + self, mock_popen: mock.Mock, mock_stream_logger: mock.Mock + ): + # GIVEN + args = ["cat", "foo.txt"] + timeout = 2 + # mock stdout and stderr StreamLogger instances + stdout_logger = mock.Mock() + stderr_logger = mock.Mock() + mock_stream_logger.side_effect = [stdout_logger, stderr_logger] + # mock subprocess.Popen return value + proc = mock.Mock() + proc.poll.return_value = None + proc.pid = 1 + + # When a subprocess doesn't terminate in the alloted time, it throws a TimeoutExpired + # exception. When this exception is thrown we send the SIGKILL signal, so we are + # simulating that here. 
+ proc.wait.side_effect = [subprocess.TimeoutExpired(args, timeout), mock.DEFAULT] + + mock_popen.return_value = proc + logger = mock.Mock() + subject = LoggingSubprocess(args=args, logger=logger) + + # WHEN + subject.terminate(timeout) + + # THEN + proc.terminate.assert_called_once() + proc.kill.assert_called_once() + assert proc.wait.call_count == 2 + stdout_logger.join.assert_called_once() + stderr_logger.join.assert_called_once() + proc.stdout.close.assert_called_once() + proc.stderr.close.assert_called_once() + + def test_wait_multiple(self, mock_popen: mock.Mock, mock_stream_logger: mock.Mock): + # GIVEN + # mock stdout and stderr StreamLogger instances + stdout_logger = mock.Mock() + stderr_logger = mock.Mock() + mock_stream_logger.side_effect = [stdout_logger, stderr_logger] + # mock subprocess.Popen return value + proc = mock.Mock() + proc.poll.return_value = None + mock_popen.return_value = proc + args = ["cat", "foo.txt"] + logger = mock.Mock() + subject = LoggingSubprocess(args=args, logger=logger) + + subject.wait() + + proc.wait.assert_called_once() + stdout_logger.join.assert_called_once() + stderr_logger.join.assert_called_once() + proc.stdout.close.assert_called_once() + proc.stderr.close.assert_called_once() + + # clear tracked mock calls + mocks_to_reset = ( + proc.wait, + stdout_logger.join, + stderr_logger.join, + proc.stdout.close, + proc.stderr.close, + ) + for mock_to_reset in mocks_to_reset: + mock_to_reset.reset_mock() + + # WHEN + subject.wait() + + # THEN + proc.wait.assert_not_called() + stdout_logger.join.assert_not_called() + stderr_logger.join.assert_not_called() + proc.stdout.close.assert_not_called() + proc.stderr.close.assert_not_called() + + def test_terminate_multiple(self, mock_popen: mock.Mock, mock_stream_logger: mock.Mock): + # GIVEN + # mock stdout and stderr StreamLogger instances + stdout_logger = mock.Mock() + stderr_logger = mock.Mock() + mock_stream_logger.side_effect = [stdout_logger, stderr_logger] + # mock 
subprocess.Popen return value + proc = mock.Mock() + proc.poll.return_value = None + mock_popen.return_value = proc + args = ["cat", "foo.txt"] + logger = mock.Mock() + subject = LoggingSubprocess(args=args, logger=logger) + + subject.terminate() + + proc.terminate.assert_called_once() + proc.kill.assert_not_called() + proc.wait.assert_called_once() + stdout_logger.join.assert_called_once() + stderr_logger.join.assert_called_once() + proc.stdout.close.assert_called_once() + proc.stderr.close.assert_called_once() + + # clear tracked mock calls + mocks_to_reset = ( + proc.terminate, + proc.wait, + stdout_logger.join, + stderr_logger.join, + proc.stdout.close, + proc.stderr.close, + ) + for mock_to_reset in mocks_to_reset: + mock_to_reset.reset_mock() + + # WHEN + subject.terminate() + + # THEN + proc.terminate.assert_not_called() + proc.kill.assert_not_called() + proc.wait.assert_not_called() + stdout_logger.join.assert_not_called() + stderr_logger.join.assert_not_called() + proc.stdout.close.assert_not_called() + proc.stderr.close.assert_not_called() + + def test_stop_multiple(self, mock_popen: mock.Mock, mock_stream_logger: mock.Mock): + # GIVEN + args = ["cat", "foo.txt"] + timeout = 2 + # mock stdout and stderr StreamLogger instances + stdout_logger = mock.Mock() + stderr_logger = mock.Mock() + mock_stream_logger.side_effect = [stdout_logger, stderr_logger] + # mock subprocess.Popen return value + proc = mock.Mock() + proc.poll.return_value = None + proc.wait.side_effect = [subprocess.TimeoutExpired(args, timeout), mock.DEFAULT] + mock_popen.return_value = proc + logger = mock.Mock() + subject = LoggingSubprocess(args=args, logger=logger) + + subject.terminate(timeout) + + proc.terminate.assert_called_once() + proc.kill.assert_called_once() + assert proc.wait.call_count == 2 + stdout_logger.join.assert_called_once() + stderr_logger.join.assert_called_once() + proc.stdout.close.assert_called_once() + proc.stderr.close.assert_called_once() + + # clear tracked mock 
calls + mocks_to_reset = ( + proc.terminate, + proc.kill, + proc.wait, + stdout_logger.join, + stderr_logger.join, + proc.stdout.close, + proc.stderr.close, + ) + for mock_to_reset in mocks_to_reset: + mock_to_reset.reset_mock() + + # WHEN + subject.terminate(timeout) + + # THEN + proc.terminate.assert_not_called() + proc.kill.assert_not_called() + proc.wait.assert_not_called() + stdout_logger.join.assert_not_called() + stderr_logger.join.assert_not_called() + proc.stdout.close.assert_not_called() + proc.stderr.close.assert_not_called() + + def test_context_manager(self): + # GIVEN + args = ["cat", "foo.txt"] + logger = mock.Mock() + + # WHEN + subject = LoggingSubprocess(args=args, logger=logger) + with mock.patch.object(subject, "wait", wraps=subject.wait) as wait_spy: + with subject as mgr_yield_value: + wait_spy.assert_not_called() + + wait_spy.assert_called_once() + assert mgr_yield_value is subject + + def test_pid(self, mock_popen: mock.Mock): + # GIVEN + # mock subprocess.Popen return value + pid = 123 + proc = mock.Mock() + proc.pid = pid + mock_popen.return_value = proc + args = ["cat", "foo.txt"] + logger = mock.Mock() + + # WHEN + subject = LoggingSubprocess(args=args, logger=logger) + + # THEN + assert subject.pid == pid + + def test_returncode_success(self, mock_popen: mock.Mock): + # GIVEN + # mock subprocess.Popen return value + returncode = 1 + proc = mock.Mock() + proc.poll.return_value = returncode + mock_popen.return_value = proc + args = ["cat", "foo.txt"] + logger = mock.Mock() + + # WHEN + subject = LoggingSubprocess(args=args, logger=logger) + + # THEN + assert subject.returncode == returncode + proc.poll.assert_called_once() + + def test_returncode_subproc_running(self, mock_popen: mock.Mock): + # GIVEN + + # mock subprocess.Popen return value + proc = mock.Mock() + mock_popen.return_value = proc + # Popen.poll() returns None when the subprocess is still running + proc.poll.return_value = None + + args = ["cat", "foo.txt"] + logger = 
mock.Mock() + subject = LoggingSubprocess(args=args, logger=logger) + + assert subject.returncode is None + + def test_command_printed(self, mock_popen: mock.Mock, caplog): + caplog.set_level(INFO) + + # mock subprocess.Popen return value + proc = mock.Mock() + mock_popen.return_value = proc + # Popen.poll() returns None when the subprocess is still running + proc.poll.return_value = None + + args = ["cat", "foo.txt"] + LoggingSubprocess(args=args) + + assert "Running command: cat foo.txt" in caplog.text + + @mock.patch.object(logging_subprocess.subprocess, "Popen", autospec=True) + def test_startup_directory_default(self, mock_popen_autospec: mock.Mock): + # mock subprocess.Popen return value + proc = mock.Mock() + mock_popen_autospec.return_value = proc + + args = ["cat", "foo.txt"] + LoggingSubprocess(args=args) + + # cwd will equal the startup direcotry, since that is None by default, + # we expect cwd to be None. + mock_popen_autospec.assert_called_once_with( + args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + cwd=None, + ) + + @mock.patch.object(logging_subprocess.subprocess, "Popen", autospec=True) + def test_start_directory(self, mock_popen_autospec: mock.Mock): + # mock subprocess.Popen return value + proc = mock.Mock() + mock_popen_autospec.return_value = proc + + args = ["cat", "foo.txt"] + LoggingSubprocess(args=args, startup_directory="startup_dir") + + mock_popen_autospec.assert_called_once_with( + args, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8", + cwd="startup_dir", + ) diff --git a/test/openjd/adaptor_runtime/unit/process/test_managed_process.py b/test/openjd/adaptor_runtime/unit/process/test_managed_process.py new file mode 100644 index 0000000..d653d55 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/process/test_managed_process.py @@ -0,0 +1,73 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ +from typing import List, Optional +from unittest import mock + +import pytest + +import openjd.adaptor_runtime.process._logging_subprocess as logging_subprocess +import openjd.adaptor_runtime.process._managed_process as managed_process +from openjd.adaptor_runtime.process import ManagedProcess + + +class TestManagedProcess(object): + """Unit tests for ManagedProcess""" + + @pytest.fixture(autouse=True) + def mock_popen(self): + with mock.patch.object(logging_subprocess.subprocess, "Popen") as popen_mock: + yield popen_mock + + @pytest.fixture(autouse=True) + def mock_stream_logger(self): + with mock.patch.object(logging_subprocess, "StreamLogger") as stream_logger: + stdout_logger_mock = mock.Mock() + stderr_logger_mock = mock.Mock() + stream_logger.side_effect = [stdout_logger_mock, stderr_logger_mock] + yield stream_logger + + startup_dirs = [ + pytest.param("", ["Hello World!"], "/path/for/startup", id="EmptyExecutable"), + pytest.param("echo", ["Hello World!"], "/path/for/startup", id="EchoExecutable"), + pytest.param("echo", [""], "/path/for/startup", id="EmptyArguments"), + pytest.param("echo", ["Hello World!"], "", id="EmptyStartupDir"), + pytest.param("echo", ["Hello World!"], None, id="NoStartupDir"), + pytest.param( + "echo", + ["Hello World!"], + "/path/for/startup", + id="RandomStartupDir", + ), + ] + + @pytest.mark.parametrize("executable, arguments, startup_dir", startup_dirs) + @mock.patch.object(managed_process, "LoggingSubprocess", autospec=True) + def test_run( + self, + mock_LoggingSubprocess: mock.Mock, + executable: str, + arguments: List[str], + startup_dir: str, + ): + class FakeManagedProcess(ManagedProcess): + def __init__(self, run_data: dict): + super(FakeManagedProcess, self).__init__(run_data) + + def get_executable(self) -> str: + return executable + + def get_arguments(self) -> List[str]: + return arguments + + def get_startup_directory(self) -> Optional[str]: + return startup_dir + + mp = FakeManagedProcess({}) + mp.run() + + 
mock_LoggingSubprocess.assert_called_once_with( + args=[executable] + arguments, + startup_directory=startup_dir, + stdout_handler=None, + stderr_handler=None, + ) diff --git a/test/openjd/adaptor_runtime/unit/process/test_stream_logger.py b/test/openjd/adaptor_runtime/unit/process/test_stream_logger.py new file mode 100644 index 0000000..23ee124 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/process/test_stream_logger.py @@ -0,0 +1,132 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +"""Tests for LoggingSubprocess""" +from __future__ import annotations + +import logging +import os +from typing import List +from unittest import mock + +import pytest + +import openjd.adaptor_runtime.process._stream_logger as stream_logger +from openjd.adaptor_runtime.process._stream_logger import StreamLogger + + +class TestStreamLogger(object): + """Tests for StreamLogger""" + + @pytest.fixture(autouse=True) + def mock_thread(self): + with mock.patch.object(stream_logger, "Thread") as mock_thread: + yield mock_thread + + def test_not_daemon_default(self): + # GIVEN + stream = mock.Mock() + logger = mock.Mock() + + # WHEN + subject = StreamLogger(stream=stream, loggers=[logger]) + + # THEN + assert not subject.daemon + + @pytest.mark.parametrize( + ("lines",), + ( + (["foo", "bar"],), + (["foo"],), + ([],), + ), + ) + def test_level_info_default(self, lines: List[str]): + # GIVEN + # stream.readline() includes newline characters + readline_returns = [f"{line}{os.linesep}" for line in lines] + stream = mock.Mock() + stream.closed = False + # stream.readline() returns an empty string on EOF + stream.readline.side_effect = readline_returns + [""] + logger = mock.Mock() + subject = StreamLogger(stream=stream, loggers=[logger]) + + # WHEN + subject.run() + + # THEN + logger.log.assert_has_calls([mock.call(logging.INFO, line) for line in lines]) + + def test_supplied_logging_level(self): + # GIVEN + level = logging.CRITICAL + log_line = "foo" + # 
stream.readline() includes newline characters + readline_returns = [f"{log_line}{os.linesep}"] + stream = mock.Mock() + stream.closed = False + # stream.readline() returns an empty string on EOF + stream.readline.side_effect = readline_returns + [""] + logger = mock.Mock() + subject = StreamLogger(stream=stream, loggers=[logger], level=level) + + # WHEN + subject.run() + + # THEN + logger.log.assert_has_calls([mock.call(level, log_line)]) + + def test_multiple_loggers(self): + # GIVEN + level = logging.INFO + log_line = "foo" + loggers = [mock.Mock() for _ in range(5)] + # stream.readline() includes newline characters + readline_returns = [f"{log_line}{os.linesep}"] + stream = mock.Mock() + stream.closed = False + # stream.readline() returns an empty string on EOF + stream.readline.side_effect = readline_returns + [""] + subject = StreamLogger(stream=stream, loggers=loggers, level=level) + + # WHEN + subject.run() + + # THEN + for logger in loggers: + logger.log.assert_has_calls([mock.call(level, log_line)]) + + def test_readline_failure_raises(self): + # GIVEN + err = ValueError() + stream = mock.Mock() + stream.readline.side_effect = err + subject = StreamLogger(stream=stream, loggers=[mock.Mock()]) + + # WHEN + with pytest.raises(ValueError) as raised_err: + subject.run() + + # THEN + assert raised_err.value is err + stream.readline.assert_called_once() + + def test_io_failure_logs_error(self): + # GIVEN + err = ValueError("I/O operation on closed file") + stream = mock.Mock() + stream.readline.side_effect = err + logger = mock.Mock() + subject = StreamLogger(stream=stream, loggers=[logger]) + + # WHEN + subject.run() + + # THEN + stream.readline.assert_called_once() + logger.log.assert_called_once_with( + stream_logger.logging.WARNING, + "The StreamLogger could not read from the stream. 
This is most likely because " + "the stream was closed before the stream logger.", + ) diff --git a/test/openjd/adaptor_runtime/unit/test_entrypoint.py b/test/openjd/adaptor_runtime/unit/test_entrypoint.py new file mode 100644 index 0000000..3747676 --- /dev/null +++ b/test/openjd/adaptor_runtime/unit/test_entrypoint.py @@ -0,0 +1,620 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from __future__ import annotations + +import argparse +import json +import signal +from unittest.mock import ANY, MagicMock, Mock, PropertyMock, mock_open, patch + +import jsonschema +import pytest +import yaml + +import openjd.adaptor_runtime._entrypoint as runtime_entrypoint +from openjd.adaptor_runtime import EntryPoint +from openjd.adaptor_runtime.adaptors.configuration import ( + ConfigurationManager, + RuntimeConfiguration, +) +from openjd.adaptor_runtime.adaptors import BaseAdaptor +from openjd.adaptor_runtime._background import BackendRunner, FrontendRunner +from openjd.adaptor_runtime._osname import OSName +from openjd.adaptor_runtime._entrypoint import _load_data + +from .adaptors.fake_adaptor import FakeAdaptor +from .adaptors.configuration.stubs import AdaptorConfigurationStub, RuntimeConfigurationStub + + +@pytest.fixture(autouse=True) +def mock_configuration(): + with patch.object( + ConfigurationManager, "build_config", return_value=RuntimeConfigurationStub() + ): + yield + + +@pytest.fixture(autouse=True) +def mock_logging(): + with ( + patch.object( + BaseAdaptor, + "config", + new_callable=PropertyMock(return_value=AdaptorConfigurationStub()), + ), + ): + yield + + +@pytest.fixture(autouse=True) +def mock_getLogger(): + with patch.object(runtime_entrypoint.logging, "getLogger"): + yield + + +@pytest.fixture +def mock_adaptor_cls(): + mock_adaptor_cls = MagicMock() + mock_adaptor_cls.return_value.config = AdaptorConfigurationStub() + return mock_adaptor_cls + + +class TestStart: + """ + Tests for the EntryPoint.start method + """ + + 
@patch.object(EntryPoint, "_parse_args") + def test_creates_adaptor_with_init_data( + self, _parse_args_mock: MagicMock, mock_adaptor_cls: MagicMock + ): + # GIVEN + init_data = {"init": "data"} + _parse_args_mock.return_value = argparse.Namespace(init_data=init_data) + entrypoint = EntryPoint(mock_adaptor_cls) + + # WHEN + entrypoint.start() + + # THEN + _parse_args_mock.assert_called_once() + mock_adaptor_cls.assert_called_once_with(init_data, path_mapping_data={}) + + @patch.object(EntryPoint, "_parse_args") + def test_creates_adaptor_with_path_mapping( + self, _parse_args_mock: MagicMock, mock_adaptor_cls: MagicMock + ): + # GIVEN + init_data = {"init": "data"} + path_mapping_rules = {"path_mapping_rules": "data"} + _parse_args_mock.return_value = argparse.Namespace( + init_data=init_data, path_mapping_rules=path_mapping_rules + ) + entrypoint = EntryPoint(mock_adaptor_cls) + + # WHEN + entrypoint.start() + + # THEN + _parse_args_mock.assert_called_once() + mock_adaptor_cls.assert_called_once_with(init_data, path_mapping_data=path_mapping_rules) + + @patch.object(EntryPoint, "_parse_args") + @patch.object(FakeAdaptor, "_cleanup") + @patch.object(FakeAdaptor, "_start") + def test_raises_adaptor_exception( + self, + mock_start: MagicMock, + mock_cleanup: MagicMock, + mock_parse_args: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + mock_start.side_effect = Exception() + mock_parse_args.return_value = argparse.Namespace(command="run") + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + with pytest.raises(Exception) as raised_exc: + entrypoint.start() + + # THEN + assert raised_exc.value is mock_start.side_effect + assert "Error running the adaptor: " in caplog.text + mock_start.assert_called_once() + mock_cleanup.assert_called_once() + + @patch.object(EntryPoint, "_parse_args") + @patch.object(FakeAdaptor, "_cleanup") + @patch.object(FakeAdaptor, "_start") + def test_raises_adaptor_cleanup_exception( + self, + mock_start: MagicMock, + mock_cleanup: 
MagicMock, + mock_parse_args: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + mock_start.side_effect = Exception() + mock_cleanup.side_effect = Exception() + mock_parse_args.return_value = argparse.Namespace(command="run") + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + with pytest.raises(Exception) as raised_exc: + entrypoint.start() + + # THEN + assert raised_exc.value is mock_cleanup.side_effect + assert "Error running the adaptor: " in caplog.text + assert "Error cleaning up the adaptor: " in caplog.text + mock_start.assert_called_once() + mock_cleanup.assert_called_once() + + @patch.object(argparse.ArgumentParser, "parse_args") + def test_raises_argparse_exception( + self, mock_parse_args: MagicMock, caplog: pytest.LogCaptureFixture + ): + # GIVEN + mock_parse_args.side_effect = Exception() + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + with pytest.raises(Exception) as raised_exc: + entrypoint.start() + + # THEN + assert raised_exc.value is mock_parse_args.side_effect + assert "Error parsing command line arguments: " in caplog.text + + @patch.object(ConfigurationManager, "build_config") + def test_raises_jsonschema_validation_err( + self, + mock_build_config: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + mock_build_config.side_effect = jsonschema.ValidationError("") + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + with pytest.raises(jsonschema.ValidationError) as raised_err: + entrypoint.start() + + # THEN + mock_build_config.assert_called_once() + assert raised_err.value is mock_build_config.side_effect + assert "Nonvalid runtime configuration file: " in caplog.text + + @patch.object(ConfigurationManager, "get_default_config") + @patch.object(ConfigurationManager, "build_config") + def test_uses_default_config_on_unsupported_system( + self, + mock_build_config: MagicMock, + mock_get_default_config: MagicMock, + caplog: pytest.LogCaptureFixture, + ): + # GIVEN + mock_build_config.side_effect = NotImplementedError() 
+ mock_get_default_config.return_value = RuntimeConfigurationStub() + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + entrypoint.start() + + # THEN + mock_build_config.assert_called_once() + mock_get_default_config.assert_called_once() + assert entrypoint.config is mock_get_default_config.return_value + assert f"The current system ({OSName()}) is not supported for runtime " + "configuration. Only the default configuration will be loaded. Full error: " in caplog.text + + @patch.object(EntryPoint, "_parse_args") + @patch.object(ConfigurationManager, "build_config") + @patch.object(RuntimeConfiguration, "config", new_callable=PropertyMock) + @patch.object(runtime_entrypoint, "print") + def test_shows_config( + self, + print_spy: MagicMock, + mock_config: MagicMock, + mock_build_config: MagicMock, + mock_parse_args: MagicMock, + ): + # GIVEN + config = {"key": "value"} + mock_parse_args.return_value = argparse.Namespace(show_config=True) + mock_config.return_value = config + mock_build_config.return_value = RuntimeConfiguration({}) + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + entrypoint.start() + + # THEN + mock_parse_args.assert_called_once() + mock_build_config.assert_called_once() + mock_config.assert_called_once() + print_spy.assert_called_once_with(yaml.dump(config, indent=2)) + + @patch.object(EntryPoint, "_parse_args") + def test_runs_in_run_mode(self, _parse_args_mock: MagicMock, mock_adaptor_cls: MagicMock): + # GIVEN + init_data = {"init": "data"} + run_data = {"run": "data"} + _parse_args_mock.return_value = argparse.Namespace( + command="run", + init_data=init_data, + run_data=run_data, + ) + entrypoint = EntryPoint(mock_adaptor_cls) + + # WHEN + entrypoint.start() + + # THEN + _parse_args_mock.assert_called_once() + mock_adaptor_cls.assert_called_once_with(init_data, path_mapping_data=ANY) + + mock_adaptor_cls.return_value._start.assert_called_once() + mock_adaptor_cls.return_value._run.assert_called_once_with(run_data) + 
mock_adaptor_cls.return_value._stop.assert_called_once() + mock_adaptor_cls.return_value._cleanup.assert_called_once() + + @patch.object(runtime_entrypoint, "AdaptorRunner") + @patch.object(EntryPoint, "_parse_args") + @patch.object(runtime_entrypoint.signal, "signal") + def test_runmode_signal_hook( + self, + signal_mock: MagicMock, + _parse_args_mock: MagicMock, + mock_adaptor_runner: MagicMock, + mock_adaptor_cls: MagicMock, + ): + # GIVEN + init_data = {"init": "data"} + run_data = {"run": "data"} + _parse_args_mock.return_value = argparse.Namespace( + command="run", + init_data=init_data, + run_data=run_data, + ) + entrypoint = EntryPoint(mock_adaptor_cls) + + # WHEN + entrypoint.start() + entrypoint._sigint_handler(MagicMock(), MagicMock()) + + # THEN + signal_mock.assert_any_call(signal.SIGINT, entrypoint._sigint_handler) + signal_mock.assert_any_call(signal.SIGTERM, entrypoint._sigint_handler) + mock_adaptor_runner.return_value._cancel.assert_called_once() + + @patch.object(runtime_entrypoint, "InMemoryLogBuffer") + @patch.object(runtime_entrypoint, "AdaptorRunner") + @patch.object(EntryPoint, "_parse_args") + @patch.object(BackendRunner, "run") + @patch.object(BackendRunner, "__init__", return_value=None) + def test_runs_background_serve( + self, + mock_init: MagicMock, + mock_run: MagicMock, + _parse_args_mock: MagicMock, + mock_adaptor_runner: MagicMock, + mock_log_buffer: MagicMock, + mock_adaptor_cls: MagicMock, + ): + # GIVEN + init_data = {"init": "data"} + conn_file = "/path/to/conn_file" + _parse_args_mock.return_value = argparse.Namespace( + command="daemon", + subcommand="_serve", + init_data=init_data, + connection_file=conn_file, + ) + entrypoint = EntryPoint(mock_adaptor_cls) + + # WHEN + entrypoint.start() + + # THEN + _parse_args_mock.assert_called_once() + mock_adaptor_cls.assert_called_once_with(init_data, path_mapping_data=ANY) + mock_adaptor_runner.assert_called_once_with( + adaptor=mock_adaptor_cls.return_value, + ) + 
mock_init.assert_called_once_with( + mock_adaptor_runner.return_value, + conn_file, + log_buffer=mock_log_buffer.return_value, + ) + mock_run.assert_called_once() + + @patch.object(runtime_entrypoint, "AdaptorRunner") + @patch.object(EntryPoint, "_parse_args") + @patch.object(BackendRunner, "run") + @patch.object(BackendRunner, "__init__", return_value=None) + @patch.object(runtime_entrypoint.signal, "signal") + def test_background_serve_no_signal_hook( + self, + signal_mock: MagicMock, + mock_init: MagicMock, + mock_run: MagicMock, + _parse_args_mock: MagicMock, + mock_adaptor_cls: MagicMock, + ): + # GIVEN + init_data = {"init": "data"} + conn_file = "/path/to/conn_file" + _parse_args_mock.return_value = argparse.Namespace( + command="daemon", + subcommand="_serve", + init_data=init_data, + connection_file=conn_file, + ) + entrypoint = EntryPoint(mock_adaptor_cls) + + # WHEN + entrypoint.start() + + # THEN + signal_mock.assert_not_called() + + @patch.object(EntryPoint, "_parse_args") + @patch.object(FrontendRunner, "__init__", return_value=None) + def test_background_start_raises_when_adaptor_module_not_loaded( + self, + mock_magic_init: MagicMock, + _parse_args_mock: MagicMock, + ): + # GIVEN + conn_file = "/path/to/conn_file" + _parse_args_mock.return_value = argparse.Namespace( + command="daemon", + subcommand="start", + connection_file=conn_file, + ) + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + with patch.dict(runtime_entrypoint.sys.modules, {FakeAdaptor.__module__: None}): + with pytest.raises(ModuleNotFoundError) as raised_err: + entrypoint.start() + + # THEN + assert raised_err.match(f"Adaptor module is not loaded: {FakeAdaptor.__module__}") + _parse_args_mock.assert_called_once() + mock_magic_init.assert_called_once_with(conn_file) + + @patch.object(EntryPoint, "_parse_args") + @patch.object(FrontendRunner, "__init__", return_value=None) + @patch.object(FrontendRunner, "init") + @patch.object(FrontendRunner, "start") + def 
test_runs_background_start( + self, + mock_start: MagicMock, + mock_magic_init: MagicMock, + mock_magic_start: MagicMock, + _parse_args_mock: MagicMock, + ): + # GIVEN + conn_file = "/path/to/conn_file" + _parse_args_mock.return_value = argparse.Namespace( + command="daemon", + subcommand="start", + connection_file=conn_file, + ) + mock_adaptor_module = Mock() + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + with patch.dict( + runtime_entrypoint.sys.modules, {FakeAdaptor.__module__: mock_adaptor_module} + ): + entrypoint.start() + + # THEN + _parse_args_mock.assert_called_once() + mock_magic_init.assert_called_once_with(mock_adaptor_module, {}) + mock_magic_start.assert_called_once_with(conn_file) + mock_start.assert_called_once_with() + + @patch.object(EntryPoint, "_parse_args") + @patch.object(FrontendRunner, "__init__", return_value=None) + @patch.object(FrontendRunner, "shutdown") + @patch.object(FrontendRunner, "stop") + def test_runs_background_stop( + self, + mock_end: MagicMock, + mock_shutdown: MagicMock, + mock_magic_init: MagicMock, + _parse_args_mock: MagicMock, + ): + # GIVEN + conn_file = "/path/to/conn_file" + _parse_args_mock.return_value = argparse.Namespace( + command="daemon", + subcommand="stop", + connection_file=conn_file, + ) + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + entrypoint.start() + + # THEN + _parse_args_mock.assert_called_once() + mock_magic_init.assert_called_once_with(conn_file) + mock_end.assert_called_once() + mock_shutdown.assert_called_once_with() + + @patch.object(EntryPoint, "_parse_args") + @patch.object(FrontendRunner, "__init__", return_value=None) + @patch.object(FrontendRunner, "run") + def test_runs_background_run( + self, + mock_run: MagicMock, + mock_magic_init: MagicMock, + _parse_args_mock: MagicMock, + ): + # GIVEN + conn_file = "/path/to/conn_file" + run_data = {"run": "data"} + _parse_args_mock.return_value = argparse.Namespace( + command="daemon", + subcommand="run", + connection_file=conn_file, + 
run_data=run_data, + ) + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + entrypoint.start() + + # THEN + _parse_args_mock.assert_called_once() + mock_magic_init.assert_called_once_with(conn_file) + mock_run.assert_called_once_with(run_data) + + @patch.object(EntryPoint, "_parse_args") + @patch.object(FrontendRunner, "__init__", return_value=None) + @patch.object(FrontendRunner, "run") + @patch.object(runtime_entrypoint.signal, "signal") + def test_background_no_signal_hook( + self, + signal_mock: MagicMock, + mock_run: MagicMock, + mock_magic_init: MagicMock, + _parse_args_mock: MagicMock, + ): + # GIVEN + conn_file = "/path/to/conn_file" + run_data = {"run": "data"} + _parse_args_mock.return_value = argparse.Namespace( + command="daemon", + subcommand="run", + connection_file=conn_file, + run_data=run_data, + ) + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + entrypoint.start() + + # THEN + signal_mock.assert_not_called() + + @patch.object(EntryPoint, "_parse_args") + @patch.object(FrontendRunner, "__init__", return_value=None) + def test_makes_connection_file_path_absolute( + self, + mock_init: MagicMock, + _parse_args_mock: MagicMock, + ): + # GIVEN + conn_file = "relpath" + _parse_args_mock.return_value = argparse.Namespace( + command="daemon", + subcommand="", + connection_file=conn_file, + ) + + entrypoint = EntryPoint(FakeAdaptor) + + # WHEN + mock_isabs: MagicMock + with ( + patch.object(runtime_entrypoint.os.path, "isabs", return_value=False) as mock_isabs, + patch.object(runtime_entrypoint.os.path, "abspath") as mock_abspath, + ): + entrypoint.start() + + # THEN + _parse_args_mock.assert_called_once() + mock_isabs.assert_called_once_with(conn_file) + mock_abspath.assert_called_once_with(conn_file) + mock_init.assert_called_once_with(mock_abspath.return_value) + + +class TestLoadData: + """ + Tests for the _load_data method + """ + + def test_defaults_to_dict(self): + assert _load_data("") == {} + + @pytest.mark.parametrize( + argnames=["input", 
"expected"], + argvalues=[ + [json.dumps({"hello": "world"}), {"hello": "world"}], + [yaml.dump({"hello": "world"}), {"hello": "world"}], + ], + ids=["JSON", "YAML"], + ) + def test_accepts_string(self, input: str, expected: dict, caplog: pytest.LogCaptureFixture): + # WHEN + output = _load_data(input) + + # THEN + assert output == expected + + @pytest.mark.parametrize( + argnames=["input", "expected"], + argvalues=[ + [json.dumps({"hello": "world"}), {"hello": "world"}], + [yaml.dump({"hello": "world"}), {"hello": "world"}], + ], + ids=["JSON", "YAML"], + ) + def test_accepts_file(self, input: str, expected: dict): + # GIVEN + filepath = "/my/file" + file_uri = f"file://{filepath}" + + # WHEN + open_mock: MagicMock + with patch.object(runtime_entrypoint, "open", mock_open(read_data=input)) as open_mock: + output = _load_data(file_uri) + + # THEN + assert output == expected + open_mock.assert_called_once_with(filepath) + + @patch.object(runtime_entrypoint, "open") + def test_raises_on_os_error(self, mock_open: MagicMock, caplog: pytest.LogCaptureFixture): + # GIVEN + filepath = "/my/file.txt" + file_uri = f"file://{filepath}" + mock_open.side_effect = OSError() + + # WHEN + with pytest.raises(OSError) as raised_err: + _load_data(file_uri) + + # THEN + assert raised_err.value is mock_open.side_effect + mock_open.assert_called_once_with(filepath) + assert "Failed to open data file: " in caplog.text + + def test_raises_when_parsing_fails(self, caplog: pytest.LogCaptureFixture): + # GIVEN + input = "@" + + # WHEN + with pytest.raises(yaml.YAMLError): + _load_data(input) + + # THEN + assert "Failed to load data as JSON or YAML: " in caplog.text + + def test_raises_on_nonvalid_parsed_data_type(self): + # GIVEN + input = "input" + + # WHEN + with pytest.raises(ValueError) as raised_err: + _load_data(input) + + # THEN + assert raised_err.match(f"Expected loaded data to be a dict, but got {type(input)}") diff --git a/test/openjd/adaptor_runtime/unit/test_osname.py 
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# ==== file: test/openjd/adaptor_runtime/unit/test_osname.py ====

from typing import Callable
from unittest.mock import Mock, patch

import pytest

import openjd.adaptor_runtime._osname as osname
from openjd.adaptor_runtime._osname import OSName


class TestOSName:
    @pytest.mark.parametrize("platform", ["Windows", "Darwin", "Linux"])
    @patch.object(osname, "platform")
    def test_empty_init_returns_osname(self, mock_platform: Mock, platform: str):
        """OSName() with no argument resolves to the current platform."""
        # GIVEN
        mock_platform.system.return_value = platform

        # WHEN
        # FIX: renamed local from `osname` — it shadowed the module alias imported above.
        result = OSName()

        # THEN
        assert isinstance(result, OSName)
        if platform == "Darwin":
            assert str(result) == OSName.MACOS
        else:
            assert str(result) == platform

    alias_params = [
        pytest.param(
            (
                "Darwin",
                "darwin",
                "MacOS",
                "macos",
                "mac",
                "Mac",
                "mac os",
                "MAC OS",
                "os x",
                "OS X",
            ),
            OSName.MACOS,
            OSName.is_macos,
            id="macOS",
        ),
        pytest.param(
            ("Windows", "win", "win32", "nt", "windows"),
            OSName.WINDOWS,
            OSName.is_windows,
            id="windows",
        ),
        pytest.param(("linux", "linux2"), OSName.LINUX, OSName.is_linux, id="linux"),
        pytest.param(("posix", "Posix", "POSIX"), OSName.POSIX, OSName.is_posix, id="posix"),
    ]

    @pytest.mark.parametrize("aliases, expected, is_os_func", alias_params)
    def test_aliases(self, aliases: list[str], expected: str, is_os_func: Callable):
        """Every alias resolves to its canonical OS name and satisfies the matching is_* helper."""
        for alias in aliases:
            # WHEN
            # FIX: renamed local from `osname` to avoid shadowing the module alias.
            resolved = OSName(alias)

            # THEN
            assert isinstance(
                resolved, OSName
            ), f"OSName('{alias}') did not return object of type OSName"
            assert str(resolved) == expected, f"OSName('{alias}') did not resolve to '{expected}'"
            assert (
                resolved == alias
            ), f"OSName.__eq__ failed comparison with OSName('{alias}') and '{alias}'"
            assert is_os_func(alias), f"OSName.is_{expected.lower()}() failed for '{alias}'"
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# ==== file: test/openjd/adaptor_runtime/unit/utils/test_secure_open.py ====
# (the accompanying utils/__init__.py contains only the copyright header)

import os
import stat
from unittest.mock import mock_open, patch

import pytest

from openjd.adaptor_runtime._utils import secure_open

READ_FLAGS = os.O_RDONLY
WRITE_FLAGS = os.O_WRONLY | os.O_TRUNC | os.O_CREAT
APPEND_FLAGS = os.O_WRONLY | os.O_APPEND | os.O_CREAT
EXCL_FLAGS = os.O_EXCL | os.O_CREAT | os.O_WRONLY
UPDATE_FLAGS = os.O_RDWR | os.O_CREAT

# Maps each open-mode character to the os.open flag set it implies.
FLAG_DICT = {
    "r": READ_FLAGS,
    "w": WRITE_FLAGS,
    "a": APPEND_FLAGS,
    "x": EXCL_FLAGS,
    "+": UPDATE_FLAGS,
    "": 0,
}


@pytest.mark.parametrize(
    argnames=["path", "open_mode", "mask", "expected_os_open_kwargs"],
    argvalues=[
        (
            "/path/to/file",
            base + plus,
            extra_mask,
            {
                "path": "/path/to/file",
                "flags": FLAG_DICT[base] | FLAG_DICT[plus],
                "mode": stat.S_IWUSR | stat.S_IRUSR | extra_mask,
            },
        )
        for base in ("r", "w", "a", "x")
        for plus in ("", "+")
        for extra_mask in (stat.S_IRGRP | stat.S_IWGRP, 0)
    ],
)
@patch.object(os, "open")
def test_secure_open(mock_os_open, path, open_mode, mask, expected_os_open_kwargs):
    """os.open receives the flags and permission bits implied by the mode and optional mask."""
    # WHEN
    with patch("builtins.open", mock_open()) as patched_open:
        mask_kwargs = {"mask": mask} if mask else {}
        with secure_open(path, open_mode, **mask_kwargs):
            pass

    # THEN
    if open_mode == "r":
        # A pure read never creates the file, so no permission mode is passed through.
        del expected_os_open_kwargs["mode"]
    mock_os_open.assert_called_once_with(**expected_os_open_kwargs)
    patched_open.assert_called_once_with(mock_os_open.return_value, open_mode)


@pytest.mark.parametrize(
    argnames=["path", "open_mode", "encoding", "newline"],
    argvalues=[
        ("/path/to/file", "w", enc, nl)
        for enc in ("utf-8", "utf-16", None)
        for nl in ("\n", "\r\n", None)
    ],
)
@patch.object(os, "open")
def test_secure_open_passes_open_kwargs(mock_os_open, path, open_mode, encoding, newline):
    """encoding/newline keyword arguments are forwarded untouched to builtins.open."""
    # WHEN
    passthrough = {
        key: value for key, value in (("encoding", encoding), ("newline", newline)) if value
    }

    with patch("builtins.open", mock_open()) as patched_open:
        with secure_open(path, open_mode, **passthrough):
            pass

    # THEN
    patched_open.assert_called_once_with(mock_os_open.return_value, open_mode, **passthrough)


def test_raises_when_nonvalid_mode():
    """An unrecognized mode string is rejected with ValueError."""
    # WHEN
    with pytest.raises(ValueError) as exc_info:
        with secure_open("/path/to/file", "something"):
            pass

    # THEN
    assert str(exc_info.value) == "Nonvalid mode: 'something'"
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# ==== file: test/openjd/adaptor_runtime_client/integ/fake_client.py ====

from __future__ import annotations

from time import sleep as _sleep
from types import FrameType as _FrameType
from typing import Any as _Any
from typing import Dict as _Dict
from typing import Optional as _Optional

from openjd.adaptor_runtime_client import HTTPClientInterface as _HTTPClientInterface


class FakeClient(_HTTPClientInterface):
    """Minimal client run as a subprocess by the integration tests to verify that
    SIGTERM is routed to graceful_shutdown."""

    # Set to True by graceful_shutdown to break the run() polling loop.
    shutdown: bool

    def __init__(self, port: str) -> None:
        super().__init__(port)
        self.shutdown = False

    def close(self, args: _Optional[_Dict[str, _Any]]) -> None:
        print("closing")

    def graceful_shutdown(self, signum: int, frame: _Optional[_FrameType]) -> None:
        # Signal handler installed by the client interface; flips the loop flag so
        # run() returns and the process exits cleanly.
        print("Received SIGTERM signal.")
        self.shutdown = True

    def run(self) -> None:
        # Poll until a SIGTERM flips the shutdown flag.
        # FIX: removed an unused loop counter from the original.
        while not self.shutdown:
            _sleep(0.25)


test_client = FakeClient("1234")
test_client.run()
# ==== file: test/openjd/adaptor_runtime_client/integ/test_integration_client_interface.py ====

from __future__ import annotations

import os as _os
import subprocess as _subprocess
import sys as _sys
from time import sleep as _sleep


class TestIntegrationClientInterface:
    """These are the integration tests for the client interface."""
    # FIX: removed a stray quote character that was inside the original docstring.

    def test_graceful_shutdown(self) -> None:
        """Terminating a running client triggers its graceful_shutdown handler."""
        client_subprocess = _subprocess.Popen(
            [
                # FIX: use the current interpreter rather than whatever `python`
                # resolves to on PATH (which may be absent or a different version).
                _sys.executable,
                _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "fake_client.py"),
            ],
            stdin=_subprocess.PIPE,
            stderr=_subprocess.PIPE,
            stdout=_subprocess.PIPE,
            encoding="utf-8",
        )

        # To avoid a race condition, giving some extra time for the logging subprocess to start.
        _sleep(0.5)
        client_subprocess.terminate()

        # To avoid a race condition, giving some extra time for the log to be updated after
        # receiving the signal.
        _sleep(0.5)

        out, _ = client_subprocess.communicate()

        assert "Received SIGTERM signal." in out


# ==== file: test/openjd/adaptor_runtime_client/test_importable.py ====


def test_openjd_importable():
    import openjd  # noqa: F401


def test_importable():
    import openjd.adaptor_runtime  # noqa: F401


# ==== file: test/openjd/adaptor_runtime_client/unit/__init__.py ====
# (contains only the Amazon copyright header)
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# ==== file: test/openjd/adaptor_runtime_client/unit/test_action.py ====

import json as _json
from dataclasses import asdict as _asdict

import pytest
from _pytest.capture import CaptureFixture as _CaptureFixture

from openjd.adaptor_runtime_client import Action as _Action


class TestAction:
    def test_action_dict_cast(self) -> None:
        """Tests the action can be converted to a dictionary we expect."""
        name = "test"
        args = None
        expected_dict = {"name": name, "args": args}

        a = _Action(name)

        assert _asdict(a) == expected_dict

    def test_action_to_from_string(self) -> None:
        """Test that the action can be turned into a string and a string can be converted to an
        action."""
        name = "test"
        args = None
        expected_dict_str = _json.dumps({"name": name, "args": args})

        a = _Action(name)

        # Testing the action can be converted to a string as expected
        assert str(a) == expected_dict_str

        # Testing that we can convert bytes to an Action
        # This also tests Action.from_json_string.
        a2 = _Action.from_bytes(expected_dict_str.encode())

        assert a2 is not None
        if a2 is not None:  # This is just for mypy
            assert a.name == a2.name
            assert a.args == a2.args

    json_errors = [
        pytest.param(
            "action_1",
            'Unable to convert "action_1" to json. The following exception was raised:',
            id="NonvalidJSON",
        ),
        pytest.param(
            '{"foo": "bar"}',
            "Unable to convert the json dictionary ({'foo': 'bar'}) to an action. The following "
            "exception was raised:",
            id="NonvalidKeys",
        ),
    ]

    @pytest.mark.parametrize("json_str, expected_error", json_errors)
    def test_action_from_nonvalid_string(
        self, json_str: str, expected_error: str, capsys: _CaptureFixture
    ) -> None:
        """Testing that exceptions were raised properly when attempting to convert a string to an
        action."""
        a = _Action.from_json_string(json_str)

        assert a is None
        assert expected_error in capsys.readouterr().err


# ==== file: test/openjd/adaptor_runtime_client/unit/test_client_interface.py ====

from http import HTTPStatus
import json
from types import FrameType as _FrameType
from typing import (
    Any as _Any,
    Dict as _Dict,
    List as _List,
    Optional as _Optional,
)
from unittest import mock
from urllib.parse import urlencode

from openjd.adaptor_runtime_client import (
    HTTPClientInterface as _HTTPClientInterface,
    PathMappingRule as _PathMappingRule,
)


class FakeClient(_HTTPClientInterface):
    """Since we need to override the DCC Client (because it's an interface).
    We are going to use this FakeClient for our testing.
    """

    def __init__(self, socket_path: str) -> None:
        super().__init__(socket_path)
        self.actions.update({"hello_world": self.hello_world})

    def hello_world(self, args: _Optional[_Dict[str, _Any]]) -> None:
        print(f"args = {args}")

    def graceful_shutdown(self, signum: int, frame: _Optional[_FrameType]) -> None:
        print("Received SIGTERM signal.")

    # This function needs to be overridden.
    def close(self, args: _Optional[_Dict[str, _Any]]) -> None:
        pass


class TestClientInterface:
    @pytest.mark.parametrize(
        argnames=("original_path", "new_path"),
        argvalues=[
            ("original/path", "new/path"),
            ("🌚\\πŸŒ’\\πŸŒ“\\πŸŒ”\\🌝\\πŸŒ–\\πŸŒ—\\🌘\\🌚", "🌝/πŸŒ–/πŸŒ—/🌘/🌚/πŸŒ’/πŸŒ“/πŸŒ”/🌝"),
        ],
    )
    @mock.patch("http.client.HTTPConnection.close")
    @mock.patch("http.client.HTTPConnection.request")
    @mock.patch("http.client.HTTPConnection.getresponse")
    def test_map_path(
        self,
        mocked_HTTPConnection_getresponse: mock.Mock,
        mocked_HTTPConnection_request: mock.Mock,
        mocked_HTTPConnection_close: mock.Mock,
        original_path: str,
        new_path: str,
    ) -> None:
        """map_path sends the path to /path_mapping and returns the server's mapping."""
        # GIVEN
        mocked_response = mock.Mock()
        mocked_response.status = 200
        mocked_response.read.return_value = json.dumps({"path": new_path}).encode("utf-8")
        mocked_response.length = len(mocked_response.read.return_value)
        mocked_HTTPConnection_getresponse.return_value = mocked_response

        dcc_client = FakeClient(socket_path="socket_path")

        # WHEN
        mapped = dcc_client.map_path(original_path)

        # THEN
        assert mapped == new_path
        mocked_HTTPConnection_request.assert_has_calls(
            [
                mock.call(
                    "GET",
                    "/path_mapping?" + urlencode({"path": original_path}),
                    headers={"Content-type": "application/json"},
                ),
            ]
        )
        mocked_HTTPConnection_close.assert_has_calls(
            [
                mock.call(),
            ]
        )

    @pytest.mark.parametrize(
        argnames=("rules"),
        argvalues=[
            (
                [
                    {
                        "source_os": "one",
                        "source_path": "here",
                        "destination_os": "two",
                        "destination_path": "there",
                    }
                ]
            ),
        ],
    )
    @mock.patch("http.client.HTTPConnection.close")
    @mock.patch("http.client.HTTPConnection.request")
    @mock.patch("http.client.HTTPConnection.getresponse")
    def test_path_mapping_rules(
        self,
        mocked_HTTPConnection_getresponse: mock.Mock,
        mocked_HTTPConnection_request: mock.Mock,
        mocked_HTTPConnection_close: mock.Mock,
        rules: _List[_Any],
    ) -> None:
        """path_mapping_rules deserializes the server's rule list into PathMappingRule objects."""
        # GIVEN
        mocked_response = mock.Mock()
        mocked_response.status = 200
        mocked_response.read.return_value = json.dumps({"path_mapping_rules": rules}).encode(
            "utf-8"
        )
        mocked_response.length = len(mocked_response.read.return_value)
        mocked_HTTPConnection_getresponse.return_value = mocked_response

        dcc_client = FakeClient(socket_path="socket_path")

        # WHEN
        expected = dcc_client.path_mapping_rules()

        # THEN
        assert len(expected) == len(rules)
        # FIX: replaced an index-based range(len(...)) loop with zip.
        for received_rule, rule_dict in zip(expected, rules):
            assert _PathMappingRule(**rule_dict) == received_rule

        mocked_HTTPConnection_request.assert_has_calls(
            [
                mock.call(
                    "GET",
                    "/path_mapping_rules",
                    headers={"Content-type": "application/json"},
                ),
            ]
        )
        mocked_HTTPConnection_close.assert_has_calls(
            [
                mock.call(),
            ]
        )

    @mock.patch("http.client.HTTPConnection.close")
    @mock.patch("http.client.HTTPConnection.request")
    @mock.patch("http.client.HTTPConnection.getresponse")
    def test_path_mapping_rules_throws_nonvalid_json(
        self,
        mock_getresponse: mock.Mock,
        mock_request: mock.Mock,
        mock_close: mock.Mock,
    ):
        # GIVEN
        mock_response = mock.Mock()
        mock_response.status = HTTPStatus.OK
        mock_response.read.return_value = "bad json".encode("utf-8")
        mock_getresponse.return_value = mock_response
        client = FakeClient(socket_path="socket_path")

        # WHEN
        with pytest.raises(RuntimeError) as raised_err:
            client.path_mapping_rules()

        # THEN
        assert "Expected JSON string from /path_mapping_rules endpoint, but got error: " in str(
            raised_err.value
        )
        mock_request.assert_called_once_with(
            "GET", "/path_mapping_rules", headers={"Content-type": "application/json"}
        )
        mock_getresponse.assert_called_once()
        mock_close.assert_called_once()

    @mock.patch("http.client.HTTPConnection.close")
    @mock.patch("http.client.HTTPConnection.request")
    @mock.patch("http.client.HTTPConnection.getresponse")
    def test_path_mapping_rules_throws_not_list(
        self,
        mock_getresponse: mock.Mock,
        mock_request: mock.Mock,
        mock_close: mock.Mock,
    ):
        # GIVEN
        response_val = {"path_mapping_rules": "this-is-not-a-list"}
        mock_response = mock.Mock()
        mock_response.status = HTTPStatus.OK
        mock_response.read.return_value = json.dumps(response_val).encode("utf-8")
        mock_getresponse.return_value = mock_response
        client = FakeClient(socket_path="socket_path")

        # WHEN
        with pytest.raises(RuntimeError) as raised_err:
            client.path_mapping_rules()

        # THEN
        assert (
            f"Expected list for path_mapping_rules, but got: {response_val['path_mapping_rules']}"
            in str(raised_err.value)
        )
        mock_request.assert_called_once_with(
            "GET", "/path_mapping_rules", headers={"Content-type": "application/json"}
        )
        mock_getresponse.assert_called_once()
        mock_close.assert_called_once()

    @mock.patch("http.client.HTTPConnection.close")
    @mock.patch("http.client.HTTPConnection.request")
    @mock.patch("http.client.HTTPConnection.getresponse")
    def test_path_mapping_rules_throws_not_path_mapping_rule(
        self,
        mock_getresponse: mock.Mock,
        mock_request: mock.Mock,
        mock_close: mock.Mock,
    ):
        # GIVEN
        response_val = {"path_mapping_rules": ["not-a-rule-dict"]}
        mock_response = mock.Mock()
        mock_response.status = HTTPStatus.OK
        mock_response.read.return_value = json.dumps(response_val).encode("utf-8")
        mock_getresponse.return_value = mock_response
        client = FakeClient(socket_path="socket_path")

        # WHEN
        with pytest.raises(RuntimeError) as raised_err:
            client.path_mapping_rules()

        # THEN
        assert (
            f"Expected PathMappingRule object, but got: not-a-rule-dict\nAll rules: {response_val['path_mapping_rules']}"
            in str(raised_err.value)
        )
        mock_request.assert_called_once_with(
            "GET", "/path_mapping_rules", headers={"Content-type": "application/json"}
        )
        mock_getresponse.assert_called_once()
        mock_close.assert_called_once()

    @mock.patch("http.client.HTTPConnection.close")
    @mock.patch("http.client.HTTPConnection.request")
    @mock.patch("http.client.HTTPConnection.getresponse")
    def test_map_path_error(
        self,
        mocked_HTTPConnection_getresponse: mock.Mock,
        mocked_HTTPConnection_request: mock.Mock,
        mocked_HTTPConnection_close: mock.Mock,
    ) -> None:
        """A non-200 response from /path_mapping is surfaced as a RuntimeError."""
        # GIVEN
        ORIGINAL_PATH = "some/path"
        REASON = "Could not process request."
        mocked_response = mock.Mock()
        mocked_response.status = 500
        mocked_response.read.return_value = REASON.encode("utf-8")
        mocked_response.length = len(mocked_response.read.return_value)
        mocked_HTTPConnection_getresponse.return_value = mocked_response

        dcc_client = FakeClient(socket_path="socket_path")

        # WHEN
        with pytest.raises(RuntimeError) as exc_info:
            dcc_client.map_path(ORIGINAL_PATH)

        # THEN
        mocked_HTTPConnection_request.assert_has_calls(
            [
                mock.call(
                    "GET",
                    "/path_mapping?" + urlencode({"path": ORIGINAL_PATH}),
                    headers={"Content-type": "application/json"},
                ),
            ]
        )
        mocked_HTTPConnection_close.assert_has_calls(
            [
                mock.call(),
            ]
        )
        assert str(exc_info.value) == (
            f"ERROR: Failed to get a mapped path for path '{ORIGINAL_PATH}'. "
            f"Server response: Status: {mocked_response.status}, Response: '{REASON}'"
        )

    @mock.patch("http.client.HTTPConnection.close")
    @mock.patch("http.client.HTTPConnection.request")
    @mock.patch("http.client.HTTPConnection.getresponse")
    def test_request_next_action(
        self,
        mocked_HTTPConnection_getresponse: mock.Mock,
        mocked_HTTPConnection_request: mock.Mock,
        mocked_HTTPConnection_close: mock.Mock,
    ) -> None:
        """_request_next_action returns (status, reason, action); action is None for an
        empty response body and a deserialized Action otherwise."""
        mocked_response = mock.Mock()
        mocked_response.status = "mocked_status"
        mocked_response.reason = "mocked_reason"
        mocked_response.length = None

        mocked_HTTPConnection_getresponse.return_value = mocked_response

        socket_path = "socket_path"
        dcc_client = FakeClient(socket_path)
        assert dcc_client.socket_path == socket_path
        status, reason, action = dcc_client._request_next_action()

        assert action is None

        a1 = _Action("a1")
        bytes_a1 = bytes(str(a1), "utf-8")

        mocked_response.read.return_value = bytes_a1
        mocked_response.length = len(bytes_a1)

        status, reason, action = dcc_client._request_next_action()
        mocked_HTTPConnection_request.assert_has_calls(
            [
                mock.call("GET", "/action", headers={"Content-type": "application/json"}),
                mock.call("GET", "/action", headers={"Content-type": "application/json"}),
            ]
        )
        mocked_HTTPConnection_close.assert_has_calls(
            [
                mock.call(),
                mock.call(),
            ]
        )

        assert status == "mocked_status"
        assert reason == "mocked_reason"
        assert str(action) == str(a1)

    @mock.patch.object(_HTTPClientInterface, "_perform_action")
    def test_poll(self, mocked_perform_action: mock.Mock, capsys: _CaptureFixture) -> None:
        """poll keeps requesting actions, performs each one, reports non-200 statuses,
        and stops after the 'close' action."""
        a1 = _Action("render", {"arg1": "val1"})
        a2 = _Action("close")

        with mock.patch.object(
            _HTTPClientInterface,
            "_request_next_action",
            side_effect=[
                (404, "Not found", a1),
                (200, "OK", None),
                (200, "OK", a1),
                (200, "OK", a2),
            ],
        ):
            dcc_client = FakeClient(socket_path="socket_path")
            dcc_client.poll()

            mocked_perform_action.assert_has_calls([mock.call(a1), mock.call(a2)])

        assert (
            "An error was raised when trying to connect to the server: 404 Not found\n"
            in capsys.readouterr().err
        )

    def test_perform_action(self) -> None:
        a1 = _Action("hello_world", {"arg1": "Hello!", "arg2": "How are you?"})

        with mock.patch.object(FakeClient, "hello_world") as mocked_hello_world:
            dcc_client = FakeClient(socket_path="socket_path")
            dcc_client._perform_action(a1)

        mocked_hello_world.assert_called_once_with(a1.args)

    def test_perform_nonvalid_action(self, capsys: _CaptureFixture) -> None:
        a2 = _Action("nonvalid")
        dcc_client = FakeClient(socket_path="socket_path")
        dcc_client._perform_action(a2)

        assert (
            capsys.readouterr().err
            == f"ERROR: Attempted to perform the following action: {a2}. But this action doesn't "
            "exist in the actions dictionary.\n"
        )


# ==== file: test/openjd/test_copyright_header.py ====

import re
from pathlib import Path

# For distributed open source and proprietary code, we must include a copyright header in
# every source file:
_copyright_header_re = re.compile(
    r"Copyright Amazon\.com, Inc\. or its affiliates\. All Rights Reserved\.", re.IGNORECASE
)
_generated_by_scm = re.compile(r"# file generated by setuptools_scm", re.IGNORECASE)


def _check_file(filename: Path) -> None:
    """Raise if *filename* lacks the copyright header within its first 10 lines."""
    # FIX: the original error messages were f-strings that interpolated nothing
    # (a literal "(unknown)" where the offending path belongs); report the filename.
    with open(filename) as infile:
        lines_read = 0
        for line in infile:
            if _copyright_header_re.search(line):
                return  # success
            lines_read += 1
            if lines_read > 10:
                raise Exception(
                    f"Could not find a valid Amazon.com copyright header in the top of {filename}."
                    " Please add one."
                )
        else:
            # __init__.py files are usually empty, this is to catch that.
            raise Exception(
                f"Could not find a valid Amazon.com copyright header in the top of {filename}."
                " Please add one."
            )


def _is_version_file(filename: Path) -> bool:
    """True if *filename* is a setuptools_scm-generated _version.py (exempt from the check)."""
    if filename.name != "_version.py":
        return False
    with open(filename) as infile:
        lines_read = 0
        for line in infile:
            if _generated_by_scm.search(line):
                return True
            lines_read += 1
            if lines_read > 10:
                break
    return False


def test_copyright_headers():
    """Verifies every .py file has an Amazon copyright header."""
    root_project_dir = Path(__file__)
    # The root of the project is the directory that contains the test directory.
    while not (root_project_dir / "test").exists():
        root_project_dir = root_project_dir.parent
    # Choose only a few top level directories to test.
    # That way we don't snag any virtual envs a developer might create, at the risk of missing
    # some top level .py files.
    # Additionally, ignore any files in the `node_modules` directory that we use in the VS Code
    # extension.
    top_level_dirs = [
        "src",
        "test",
        "scripts",
        "testing_containers",
        "openjdvscode!(/node_modules)",
    ]
    file_count = 0
    for top_level_dir in top_level_dirs:
        for glob_pattern in ("**/*.py", "**/*.sh", "**/Dockerfile", "**/*.ts"):
            for path in Path(root_project_dir / top_level_dir).glob(glob_pattern):
                print(path)
                if not _is_version_file(path):
                    _check_file(path)
                    file_count += 1

    print(f"test_copyright_headers checked {file_count} files successfully.")
    assert file_count > 0, "Test misconfiguration"


if __name__ == "__main__":
    test_copyright_headers()


# ==== file: test/openjd/test_importable.py ====


def test_openjd_importable():
    import openjd  # noqa: F401