From a065d5824cf30d0a07cac34726c878b22633f928 Mon Sep 17 00:00:00 2001 From: Matthias Veit Date: Mon, 15 Jul 2024 18:09:54 +0200 Subject: [PATCH 1/3] [feat] Add post-collect --- Dockerfile | 4 +- collect_single/__main__.py | 131 ---------------- ...{collect_and_sync.py => collect_single.py} | 148 +++++++++++++----- collect_single/core_client.py | 18 ++- collect_single/job.py | 53 +++++++ collect_single/post_collect.py | 131 ++++++++++++++++ collect_single_shim | 4 - dispatch_executable_shim.sh | 20 +++ pyproject.toml | 3 +- tests/collect_and_sync_test.py | 8 +- tests/conftest.py | 6 +- tests/core_client_test.py | 6 +- 12 files changed, 345 insertions(+), 187 deletions(-) delete mode 100644 collect_single/__main__.py rename collect_single/{collect_and_sync.py => collect_single.py} (67%) create mode 100644 collect_single/job.py create mode 100644 collect_single/post_collect.py delete mode 100755 collect_single_shim create mode 100755 dispatch_executable_shim.sh diff --git a/Dockerfile b/Dockerfile index 29a3b11..51230d7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,7 +6,7 @@ ADD . /single_coordinator RUN . /usr/local/fix-venv-python3/bin/activate && pip install /single_coordinator && rm -rf /single_coordinator # Add shim and create symlink -COPY collect_single_shim /usr/local/bin/collect_single_shim +COPY dispatch_executable_shim.sh /usr/local/bin/dispatch_executable_shim RUN chmod 755 /usr/local/bin/collect_single_shim && ln -s /usr/local/bin/collect_single_shim /usr/bin/collect_single -ENTRYPOINT ["/bin/dumb-init", "--", "/usr/local/sbin/bootstrap", "/usr/bin/collect_single"] +ENTRYPOINT ["/bin/dumb-init", "--", "/usr/local/sbin/bootstrap", "/usr/bin/dispatch_executable_shim"] diff --git a/collect_single/__main__.py b/collect_single/__main__.py deleted file mode 100644 index 5fd94fc..0000000 --- a/collect_single/__main__.py +++ /dev/null @@ -1,131 +0,0 @@ -# fix-collect-single -# Copyright (C) 2023 Some Engineering -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . 
- -import asyncio -import logging -import os -import sys -from argparse import Namespace, ArgumentParser -from datetime import timedelta -from itertools import takewhile -from pathlib import Path -from typing import List, Tuple, Dict - -from fixcloudutils.logging import setup_logger -from fixcloudutils.util import utc -from redis.asyncio import Redis - -from collect_single.collect_and_sync import CollectAndSync - -log = logging.getLogger("fix.coordinator") - - -def kv_pairs(s: str) -> Tuple[str, str]: - return tuple(s.split("=", maxsplit=1)) # type: ignore - - -async def startup( - args: Namespace, core_args: List[str], worker_args: List[str], logging_context: Dict[str, str] -) -> None: - redis_args = {} - if args.redis_password: - redis_args["password"] = args.redis_password - if args.redis_url.startswith("rediss://") and args.ca_cert: - redis_args["ssl_ca_certs"] = args.ca_cert - async with Redis.from_url(args.redis_url, decode_responses=True, **redis_args) as redis: - - async def collect_and_sync(send_on_failed: bool) -> bool: - async with CollectAndSync( - redis=redis, - tenant_id=args.tenant_id, - cloud=args.cloud, - account_id=args.account_id, - job_id=args.job_id, - core_args=core_args, - worker_args=worker_args, - push_gateway_url=args.push_gateway_url, - logging_context=logging_context, - ) as cas: - return await cas.sync(send_on_failed) - - if retry := args.retry_failed_for: - log.info(f"Collect job with retry enabled for {retry}.") - has_result = False - deadline = utc() + retry - while not has_result and utc() < deadline: - # collect and do not send a message in the failing case - has_result = await collect_and_sync(False) - if not has_result: - log.info("Failed collect with retry enabled. Retrying in 30s.") - await asyncio.sleep(30) - # if we come here without a result, collect and also send a message in the failing case - if not has_result: - log.info("Last attempt to collect with retry enabled.") - await collect_and_sync(True) - else: - await collect_and_sync(True) - - -def main() -> None: - # 3 argument sets delimited by "---": --- --- - # coordinator --main-arg1 --main-arg2 --- --core-arg1 --core-arg2 --- --worker-arg1 --worker-arg2 - args = iter(sys.argv[1:]) - coordinator_args = list(takewhile(lambda x: x != "---", args)) - core_args = list(takewhile(lambda x: x != "---", args)) - worker_args = list(args) - # handle coordinator args - parser = ArgumentParser() - parser.add_argument( - "--write", - type=kv_pairs, - help="Write config files in home dir from env vars. Format: --write path/in/home/dir=env-var-name", - default=[], - action="append", - ) - parser.add_argument("--job-id", required=True, help="Job Id of the coordinator") - parser.add_argument("--tenant-id", required=True, help="Id of the tenant") - parser.add_argument("--account-id", help="Id of the account") - parser.add_argument("--cloud", help="Cloud provider.") - parser.add_argument("--redis-url", default="redis://localhost:6379/0", help="Redis host.") - parser.add_argument("--redis-password", default=os.environ.get("REDIS_PASSWORD"), help="Redis password") - parser.add_argument("--push-gateway-url", help="Prometheus push gateway url") - parser.add_argument("--ca-cert", help="Path to CA cert file") - parser.add_argument( - "--retry-failed-for", type=lambda x: timedelta(seconds=float(x)), help="Seconds to retry failed jobs." 
- ) - parsed = parser.parse_args(coordinator_args) - - # setup logging - logging_context = dict(job_id=parsed.job_id, workspace_id=parsed.tenant_id, cloud_account_id=parsed.account_id) - setup_logger("collect-single", get_logging_context=lambda: {"process": "coordinator", **logging_context}) - - # write config files from env vars - env_vars = {k.lower(): v for k, v in os.environ.items()} - for home_path, env_var_name in parsed.write: - path = (Path.home() / Path(home_path)).absolute() - content = env_vars.get(env_var_name.lower()) - assert content is not None, f"Env var {env_var_name} not found" - log.info(f"Writing file: {path} from env var: {env_var_name}") - path.parent.mkdir(parents=True, exist_ok=True) - with open(path, "w+") as f: - f.write(content) - - log.info(f"Coordinator args:({coordinator_args}) Core args:({core_args}) Worker args:({worker_args})") - asyncio.run(startup(parsed, core_args, worker_args, logging_context)) - - -if __name__ == "__main__": - main() diff --git a/collect_single/collect_and_sync.py b/collect_single/collect_single.py similarity index 67% rename from collect_single/collect_and_sync.py rename to collect_single/collect_single.py index c8f787c..7d9d526 100644 --- a/collect_single/collect_and_sync.py +++ b/collect_single/collect_single.py @@ -16,33 +16,37 @@ import asyncio import logging +import os +import sys +from argparse import Namespace, ArgumentParser from datetime import timedelta +from itertools import takewhile from pathlib import Path -from typing import List, Optional, Any, Tuple, Dict, cast +from typing import List, Tuple, Dict +from typing import Optional, Any, cast import yaml from arango.cursor import Cursor -from fixcloudutils.redis.event_stream import RedisStreamPublisher, Json +from fixcloudutils.logging import setup_logger +from fixcloudutils.redis.event_stream import Json from fixcloudutils.redis.lock import Lock -from fixcloudutils.redis.pub_sub import RedisPubSubPublisher -from fixcloudutils.service import Service -from fixcloudutils.util import utc, utc_str +from fixcloudutils.util import utc +from fixcloudutils.util import utc_str from fixcore.core_config import parse_config from fixcore.db.async_arangodb import AsyncArangoDB from fixcore.db.db_access import DbAccess from fixcore.db.timeseriesdb import TimeSeriesDB from fixcore.system_start import parse_args as core_parse_args from redis.asyncio import Redis -import prometheus_client -from collect_single.core_client import CoreClient +from collect_single.job import Job from collect_single.model import MetricQuery from collect_single.process import ProcessWrapper log = logging.getLogger("fix.coordinator") -class CollectAndSync(Service): +class CollectSingle(Job): def __init__( self, *, @@ -55,36 +59,20 @@ def __init__( worker_args: List[str], logging_context: Dict[str, str], push_gateway_url: Optional[str] = None, - core_url: str = "http://localhost:8980", ) -> None: - self.redis = redis - self.tenant_id = tenant_id + super().__init__(redis=redis, job_id=job_id, tenant_id=tenant_id, push_gateway_url=push_gateway_url) self.cloud = cloud self.account_id = account_id - self.job_id = job_id self.core_args = ["--no-scheduling", "--ignore-interrupted-tasks"] + core_args self.worker_args = worker_args self.logging_context = logging_context - self.core_client = CoreClient(core_url) self.task_id: Optional[str] = None - self.push_gateway_url = push_gateway_url - publisher = "collect-and-sync" - self.progress_update_publisher = RedisPubSubPublisher(redis, f"tenant-events::{tenant_id}", publisher) 
- self.collect_done_publisher = RedisStreamPublisher(redis, "collect-events", publisher) - self.started_at = utc() self.worker_connected = asyncio.Event() self.metrics: List[MetricQuery] = [] async def start(self) -> Any: - await self.progress_update_publisher.start() - await self.collect_done_publisher.start() + await super().start() self.metrics = self.load_metrics() - # note: the client is not started (core is not running and no certificate required) - - async def stop(self) -> None: - await self.core_client.stop() - await self.progress_update_publisher.stop() - await self.collect_done_publisher.stop() @staticmethod def load_metrics() -> List[MetricQuery]: @@ -140,7 +128,7 @@ async def post_process(self) -> Tuple[Json, List[str]]: # synchronize the security section benchmarks = await self.core_client.list_benchmarks(providers=[self.cloud] if self.cloud else None) if benchmarks: - await self.core_client.create_benchmark_reports(account_id, benchmarks, self.task_id) + await self.core_client.create_benchmark_reports([account_id], benchmarks, self.task_id) # create metrics for metric in self.metrics: res = await self.core_client.timeseries_snapshot(metric, account_id) @@ -170,14 +158,6 @@ async def send_result_events(self, read_from_process: bool, error_messages: Opti }, ) - async def push_metrics(self) -> None: - if gateway := self.push_gateway_url: - # Possible future option: retrieve metrics from core and worker and push them to prometheus - prometheus_client.push_to_gateway( - gateway=gateway, job="collect_single", registry=prometheus_client.REGISTRY - ) - log.info("Metrics pushed to gateway") - async def migrate_ts_data(self) -> None: ts_with_account = "for doc in ts filter doc.group.account!=null" update = ( @@ -243,3 +223,101 @@ async def sync(self, send_on_failed: bool) -> bool: await asyncio.wait_for(self.send_result_events(False, [str(ex)]), 600) # wait up to 10 minutes result_send = True return result_send + + +async def startup( + args: Namespace, core_args: List[str], worker_args: List[str], logging_context: Dict[str, str] +) -> None: + redis_args = {} + if args.redis_password: + redis_args["password"] = args.redis_password + if args.redis_url.startswith("rediss://") and args.ca_cert: + redis_args["ssl_ca_certs"] = args.ca_cert + async with Redis.from_url(args.redis_url, decode_responses=True, **redis_args) as redis: + + async def collect_and_sync(send_on_failed: bool) -> bool: + async with CollectSingle( + redis=redis, + tenant_id=args.tenant_id, + cloud=args.cloud, + account_id=args.account_id, + job_id=args.job_id, + core_args=core_args, + worker_args=worker_args, + push_gateway_url=args.push_gateway_url, + logging_context=logging_context, + ) as cas: + return await cas.sync(send_on_failed) + + if retry := args.retry_failed_for: + log.info(f"Collect job with retry enabled for {retry}.") + has_result = False + deadline = utc() + retry + while not has_result and utc() < deadline: + # collect and do not send a message in the failing case + has_result = await collect_and_sync(False) + if not has_result: + log.info("Failed collect with retry enabled. 
Retrying in 30s.") + await asyncio.sleep(30) + # if we come here without a result, collect and also send a message in the failing case + if not has_result: + log.info("Last attempt to collect with retry enabled.") + await collect_and_sync(True) + else: + await collect_and_sync(True) + + +def kv_pairs(s: str) -> Tuple[str, str]: + return tuple(s.split("=", maxsplit=1)) # type: ignore + + +def main() -> None: + # 3 argument sets delimited by "---": --- --- + # coordinator --main-arg1 --main-arg2 --- --core-arg1 --core-arg2 --- --worker-arg1 --worker-arg2 + args = iter(sys.argv[1:]) + coordinator_args = list(takewhile(lambda x: x != "---", args)) + core_args = list(takewhile(lambda x: x != "---", args)) + worker_args = list(args) + # handle coordinator args + parser = ArgumentParser() + parser.add_argument( + "--write", + type=kv_pairs, + help="Write config files in home dir from env vars. Format: --write path/in/home/dir=env-var-name", + default=[], + action="append", + ) + parser.add_argument("--job-id", required=True, help="Job Id of the coordinator") + parser.add_argument("--tenant-id", required=True, help="Id of the tenant") + parser.add_argument("--account-id", help="Id of the account") + parser.add_argument("--cloud", help="Cloud provider.") + parser.add_argument("--redis-url", default="redis://localhost:6379/0", help="Redis host.") + parser.add_argument("--redis-password", default=os.environ.get("REDIS_PASSWORD"), help="Redis password") + parser.add_argument("--push-gateway-url", help="Prometheus push gateway url") + parser.add_argument("--ca-cert", help="Path to CA cert file") + parser.add_argument( + "--retry-failed-for", type=lambda x: timedelta(seconds=float(x)), help="Seconds to retry failed jobs." + ) + parsed = parser.parse_args(coordinator_args) + + # setup logging + logging_context = dict(job_id=parsed.job_id, workspace_id=parsed.tenant_id, cloud_account_id=parsed.account_id) + setup_logger("collect-single", get_logging_context=lambda: {"process": "coordinator", **logging_context}) + + # write config files from env vars + env_vars = {k.lower(): v for k, v in os.environ.items()} + for home_path, env_var_name in parsed.write: + path = (Path.home() / Path(home_path)).absolute() + content = env_vars.get(env_var_name.lower()) + assert content is not None, f"Env var {env_var_name} not found" + log.info(f"Writing file: {path} from env var: {env_var_name}") + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w+") as f: + f.write(content) + + log.info(f"Coordinator args:({coordinator_args}) Core args:({core_args}) Worker args:({worker_args})") + asyncio.run(startup(parsed, core_args, worker_args, logging_context)) + + +if __name__ == "__main__": + main() diff --git a/collect_single/core_client.py b/collect_single/core_client.py index 62c433e..3caa939 100644 --- a/collect_single/core_client.py +++ b/collect_single/core_client.py @@ -71,11 +71,14 @@ async def list_benchmarks(self, *, providers: Optional[List[str]] = None) -> Lis else: raise AttributeError(await response.text()) - async def create_benchmark_reports(self, account_id: str, benchmarks: List[str], task_id: Optional[str]) -> None: + async def create_benchmark_reports( + self, account_ids: List[str], benchmarks: List[str], run_id: Optional[str] + ) -> None: bn = " ".join(benchmarks) - run_id = task_id or str(uuid.uuid4()) - command = f"report benchmark run {bn} --accounts {account_id} --sync-security-section --run-id {run_id} | count" - log.info(f"Create reports for following benchmarks: {bn} for accounts: 
{account_id}. Command: {command}") + an = " ".join(account_ids) + rid = run_id or str(uuid.uuid4()) + command = f"report benchmark run {bn} --accounts {an} --sync-security-section --run-id {rid} | count" + log.info(f"Create reports for following benchmarks: {bn} for accounts: {an}. Command: {command}") async for _ in self.client.cli_execute(command, headers={"Accept": "application/json"}): pass # ignore the result @@ -138,3 +141,10 @@ async def account_id_by_name(self) -> Dict[Optional[str], Optional[str]]: value_in_path(r, "reported.name"): value_in_path(r, "reported.id") async for r in self.client.cli_execute("search is(account) | dump") } + + async def merge_deferred_edges(self, task_ids: List[str], *, graph: str = "fix") -> None: + response = await self.client._post(f"/graph/{graph}/merge/deferred_edges", json=task_ids) + if response.status_code == 200: + return await response.json() # type: ignore + else: + raise AttributeError(await response.text()) diff --git a/collect_single/job.py b/collect_single/job.py new file mode 100644 index 0000000..5d8e647 --- /dev/null +++ b/collect_single/job.py @@ -0,0 +1,53 @@ +import logging +from abc import ABC +from typing import Optional, Any + +import prometheus_client +from fixcloudutils.redis.event_stream import RedisStreamPublisher +from fixcloudutils.redis.pub_sub import RedisPubSubPublisher +from fixcloudutils.service import Service +from fixcloudutils.util import utc +from redis.asyncio.client import Redis + +from collect_single.core_client import CoreClient + +log = logging.getLogger("fix.coordinator") + + +class Job(Service, ABC): + def __init__( + self, + *, + redis: Redis, # type: ignore + job_id: str, + tenant_id: str, + push_gateway_url: Optional[str] = None, + core_url: str = "http://localhost:8980", + ) -> None: + self.redis = redis + self.job_id = job_id + self.tenant_id = tenant_id + self.push_gateway_url = push_gateway_url + self.core_client = CoreClient(core_url) + publisher = "collect-and-sync" + self.progress_update_publisher = RedisPubSubPublisher(redis, f"tenant-events::{tenant_id}", publisher) + self.collect_done_publisher = RedisStreamPublisher(redis, "collect-events", publisher) + self.started_at = utc() + + async def start(self) -> Any: + await self.progress_update_publisher.start() + await self.collect_done_publisher.start() + # note: the client is not started (core is not running and no certificate required) + + async def stop(self) -> None: + await self.progress_update_publisher.stop() + await self.collect_done_publisher.stop() + await self.core_client.stop() + + async def push_metrics(self) -> None: + if gateway := self.push_gateway_url: + # Possible future option: retrieve metrics from core and worker and push them to prometheus + prometheus_client.push_to_gateway( + gateway=gateway, job="collect_single", registry=prometheus_client.REGISTRY + ) + log.info("Metrics pushed to gateway") diff --git a/collect_single/post_collect.py b/collect_single/post_collect.py new file mode 100644 index 0000000..95d1aec --- /dev/null +++ b/collect_single/post_collect.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import asyncio +import logging +import os +import sys +from argparse import ArgumentParser, Namespace +from itertools import takewhile +from typing import List, Dict +from typing import Optional + +from attr import define +from fixcloudutils.logging import setup_logger +from fixcloudutils.util import utc, utc_str +from redis.asyncio.client import Redis + +from collect_single.job import Job +from 
collect_single.process import ProcessWrapper + +log = logging.getLogger("fix.coordinator") + + +class PostCollect(Job): + def __init__( + self, + *, + redis: Redis, # type: ignore + tenant_id: str, + job_id: str, + accounts_collected: List[AccountCollected], + core_args: List[str], + logging_context: Dict[str, str], + push_gateway_url: Optional[str] = None, + ) -> None: + super().__init__(redis=redis, job_id=job_id, tenant_id=tenant_id, push_gateway_url=push_gateway_url) + self.accounts_collected = accounts_collected + self.core_args = core_args + self.logging_context = logging_context + + async def send_result_events(self, exception: Optional[Exception] = None) -> None: + # send a collect done event for the tenant + await self.collect_done_publisher.publish( + "post-collect-done", + { + "job_id": self.job_id, + "tenant_id": self.tenant_id, + "started_at": utc_str(self.started_at), + "duration": int((utc() - self.started_at).total_seconds()), + "success": exception is None, + "exception": str(exception) if exception else None, + }, + ) + + async def merge_deferred_edges(self) -> None: + await self.core_client.merge_deferred_edges([ac.task_id for ac in self.accounts_collected]) + + async def security_report(self) -> None: + for acc in self.accounts_collected: + benchmarks = await self.core_client.list_benchmarks(providers=[acc.cloud]) + if benchmarks: + await self.core_client.create_benchmark_reports([acc.account_id], benchmarks, acc.task_id) + + async def sync(self) -> None: + try: + if self.accounts_collected: # Don't do anything if no accounts were collected + async with ProcessWrapper(["fixcore", *self.core_args], self.logging_context): + log.info("Core started.") + await asyncio.wait_for(self.core_client.wait_connected(), timeout=60) + log.info("Core Client connected.") + await self.merge_deferred_edges() + log.info("All deferred edges have been updated.") + await self.security_report() + log.info("Security reports have been synchronized.") + await asyncio.wait_for(self.send_result_events(), 600) # wait up to 10 minutes + except Exception as ex: + log.info(f"Got Exception during sync: {ex}") + await asyncio.wait_for(self.send_result_events(ex), 600) # wait up to 10 minutes + + +async def startup(args: Namespace, core_args: List[str], logging_context: Dict[str, str]) -> None: + redis_args = {} + if args.redis_password: + redis_args["password"] = args.redis_password + if args.redis_url.startswith("rediss://") and args.ca_cert: + redis_args["ssl_ca_certs"] = args.ca_cert + async with Redis.from_url(args.redis_url, decode_responses=True, **redis_args) as redis: + async with PostCollect( + redis=redis, + tenant_id=args.tenant_id, + job_id=args.job_id, + core_args=core_args, + accounts_collected=args.accounts_collected, + logging_context=logging_context, + push_gateway_url=args.push_gateway_url, + ) as post_collect: + await post_collect.sync() + + +@define +class AccountCollected: + cloud: str + account_id: str + task_id: str + + @staticmethod + def from_string(s: str) -> "AccountCollected": + return AccountCollected(*s.split(":")) + + +def main() -> None: + args = iter(sys.argv[1:]) + post_process_args = list(takewhile(lambda x: x != "---", args)) + core_args = list(takewhile(lambda x: x != "---", args)) + parser = ArgumentParser() + parser.add_argument("--job-id", required=True, help="Job Id of the coordinator") + parser.add_argument("--tenant-id", required=True, help="Id of the tenant") + parser.add_argument("--accounts-collected", required=True, nargs="+", 
type=AccountCollected.from_string) + parser.add_argument("--redis-url", default="redis://localhost:6379/0", help="Redis host.") + parser.add_argument("--redis-password", default=os.environ.get("REDIS_PASSWORD"), help="Redis password") + parser.add_argument("--push-gateway-url", help="Prometheus push gateway url") + parser.add_argument("--ca-cert", help="Path to CA cert file") + parsed = parser.parse_args(post_process_args) + + # setup logging + logging_context = dict(job_id=parsed.job_id, workspace_id=parsed.tenant_id) + setup_logger("post-collect", get_logging_context=lambda: {"process": "post-collect", **logging_context}) + asyncio.run(startup(parsed, core_args, logging_context)) + + +if __name__ == "__main__": + main() diff --git a/collect_single_shim b/collect_single_shim deleted file mode 100755 index 314a7a4..0000000 --- a/collect_single_shim +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash -. /usr/local/etc/fixinventory/defaults -. /usr/local/fix-venv-python3/bin/activate -exec collect_single "$@" diff --git a/dispatch_executable_shim.sh b/dispatch_executable_shim.sh new file mode 100755 index 0000000..53e42bd --- /dev/null +++ b/dispatch_executable_shim.sh @@ -0,0 +1,20 @@ +#!/bin/bash +. /usr/local/etc/fixinventory/defaults +. /usr/local/fix-venv-python3/bin/activate + +# dispatch to the correct console script +case $1 in + collect) + shift + exec collect_single "$@" + ;; + post-collect) + shift + exec post_collect "$@" + ;; + *) + # backward compatibility: delegate to collect_single + exec collect_single "$@" + ;; +esac + diff --git a/pyproject.toml b/pyproject.toml index 03b0ce7..509cbce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,8 @@ test = [ [project.scripts] -collect_single = "collect_single.__main__:main" +collect_single = "collect_single.collect_single:main" +post_collect = "collect_single.post_collect:main" [build-system] requires = ["setuptools>=67.8.0", "wheel>=0.40.0", "build>=0.10.0"] diff --git a/tests/collect_and_sync_test.py b/tests/collect_and_sync_test.py index 9b1ba21..dc2b3da 100644 --- a/tests/collect_and_sync_test.py +++ b/tests/collect_and_sync_test.py @@ -19,7 +19,7 @@ from fixclient.async_client import FixInventoryClient from fixcore.query import query_parser -from collect_single.collect_and_sync import CollectAndSync +from collect_single.collect_single import CollectSingle @pytest.mark.asyncio @@ -40,18 +40,18 @@ async def test_client(core_client: FixInventoryClient) -> None: @pytest.mark.asyncio @pytest.mark.skip(reason="Only for manual testing") -async def test_collect_and_sync(collect_and_sync: CollectAndSync) -> None: +async def test_collect_and_sync(collect_and_sync: CollectSingle) -> None: await collect_and_sync.send_result_events(True) @pytest.mark.asyncio @pytest.mark.skip(reason="Only for manual testing") -async def test_migrate_ts(collect_and_sync: CollectAndSync) -> None: +async def test_migrate_ts(collect_and_sync: CollectSingle) -> None: await collect_and_sync.migrate_ts_data() def test_load_metrics() -> None: - metrics = CollectAndSync.load_metrics() + metrics = CollectSingle.load_metrics() assert len(metrics) >= 14 for query in metrics: # make sure the query parser does not explode diff --git a/tests/conftest.py b/tests/conftest.py index ce07103..2038a6c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,7 @@ from redis.backoff import ExponentialBackoff from fixclient.async_client import FixInventoryClient -from collect_single.collect_and_sync import CollectAndSync +from collect_single.collect_single import 
CollectSingle @fixture @@ -46,8 +46,8 @@ async def core_client() -> AsyncIterator[FixInventoryClient]: @fixture -async def collect_and_sync(redis: Redis) -> AsyncIterator[CollectAndSync]: # type: ignore - async with CollectAndSync( +async def collect_and_sync(redis: Redis) -> AsyncIterator[CollectSingle]: # type: ignore + async with CollectSingle( redis=redis, tenant_id="tenant_id", cloud="aws", diff --git a/tests/core_client_test.py b/tests/core_client_test.py index 5fce3be..a7d3ff7 100644 --- a/tests/core_client_test.py +++ b/tests/core_client_test.py @@ -5,7 +5,7 @@ import pytest -from collect_single.collect_and_sync import CollectAndSync +from collect_single.collect_single import CollectSingle from collect_single.core_client import CoreClient @@ -53,7 +53,7 @@ async def test_create_benchmark_report(core_client: CoreClient) -> None: accounts = [a async for a in core_client.client.search_list("is(aws_account) limit 1")] single = accounts[0]["reported"]["id"] task_id = str(uuid.uuid4()) - await core_client.create_benchmark_reports(single, ["aws_cis_1_5"], task_id) + await core_client.create_benchmark_reports([single], ["aws_cis_1_5"], task_id) res = [ a async for a in core_client.client.cli_execute( @@ -74,7 +74,7 @@ async def test_wait_for_collect_task_to_finish(core_client: CoreClient) -> None: async def test_timeseries_snapshot(core_client: CoreClient) -> None: accounts = [a async for a in core_client.client.search_list("is(aws_account) limit 1")] single = accounts[0]["reported"]["id"] - for query in CollectAndSync.load_metrics(): + for query in CollectSingle.load_metrics(): res = await core_client.timeseries_snapshot(query, single) assert res >= 0 From 7d3e136b98449b282a71acba3d775a9189d86247 Mon Sep 17 00:00:00 2001 From: Matthias Veit Date: Tue, 16 Jul 2024 09:33:43 +0200 Subject: [PATCH 2/3] fix dockerfile --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 51230d7..2715df7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,6 @@ RUN . 
/usr/local/fix-venv-python3/bin/activate && pip install /single_coordinato # Add shim and create symlink COPY dispatch_executable_shim.sh /usr/local/bin/dispatch_executable_shim -RUN chmod 755 /usr/local/bin/collect_single_shim && ln -s /usr/local/bin/collect_single_shim /usr/bin/collect_single +RUN chmod 755 /usr/local/bin/dispatch_executable_shim && ln -s /usr/local/bin/dispatch_executable_shim /usr/bin/dispatch_executable -ENTRYPOINT ["/bin/dumb-init", "--", "/usr/local/sbin/bootstrap", "/usr/bin/dispatch_executable_shim"] +ENTRYPOINT ["/bin/dumb-init", "--", "/usr/local/sbin/bootstrap", "/usr/bin/dispatch_executable"] From 8f5f3a85b5def53565af8afebac3c2901416db46 Mon Sep 17 00:00:00 2001 From: Matthias Veit Date: Tue, 16 Jul 2024 11:04:21 +0200 Subject: [PATCH 3/3] use json as format --- collect_single/post_collect.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/collect_single/post_collect.py b/collect_single/post_collect.py index 95d1aec..bc7cafa 100644 --- a/collect_single/post_collect.py +++ b/collect_single/post_collect.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import json import logging import os import sys @@ -9,6 +10,7 @@ from typing import List, Dict from typing import Optional +import cattrs from attr import define from fixcloudutils.logging import setup_logger from fixcloudutils.util import utc, utc_str @@ -63,6 +65,8 @@ async def security_report(self) -> None: async def sync(self) -> None: try: if self.accounts_collected: # Don't do anything if no accounts were collected + aids = ", ".join([ac.account_id for ac in self.accounts_collected]) + log.info(f"Sync tenant {self.tenant_id}: with {len(self.accounts_collected)} accounts: {aids}.") async with ProcessWrapper(["fixcore", *self.core_args], self.logging_context): log.info("Core started.") await asyncio.wait_for(self.core_client.wait_connected(), timeout=60) @@ -103,8 +107,8 @@ class AccountCollected: task_id: str @staticmethod - def from_string(s: str) -> "AccountCollected": - return AccountCollected(*s.split(":")) + def from_string(json_str: str) -> List[AccountCollected]: + return cattrs.structure(json.loads(json_str), List[AccountCollected]) def main() -> None: @@ -114,7 +118,9 @@ def main() -> None: parser = ArgumentParser() parser.add_argument("--job-id", required=True, help="Job Id of the coordinator") parser.add_argument("--tenant-id", required=True, help="Id of the tenant") - parser.add_argument("--accounts-collected", required=True, nargs="+", type=AccountCollected.from_string) + parser.add_argument( + "--accounts-collected", default=os.environ.get("ACCOUNTS_COLLECTED"), type=AccountCollected.from_string + ) parser.add_argument("--redis-url", default="redis://localhost:6379/0", help="Redis host.") parser.add_argument("--redis-password", default=os.environ.get("REDIS_PASSWORD"), help="Redis password") parser.add_argument("--push-gateway-url", help="Prometheus push gateway url")
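
Note on the new `--accounts-collected` format (illustrative, not part of the patch): after the third commit the argument is a single JSON document rather than space-separated values, and it can also be supplied through the `ACCOUNTS_COLLECTED` environment variable. Below is a minimal sketch of the expected payload and how `AccountCollected.from_string` structures it, assuming the field names from the attrs class in the patch (`cloud`, `account_id`, `task_id`); the concrete values are made up.

```python
import json
from typing import List

import cattrs
from attr import define


@define
class AccountCollected:
    cloud: str
    account_id: str
    task_id: str

    @staticmethod
    def from_string(json_str: str) -> List["AccountCollected"]:
        # Parse the JSON list and let cattrs build typed AccountCollected instances.
        return cattrs.structure(json.loads(json_str), List[AccountCollected])


# Hypothetical payload as a dispatcher might hand it over via ACCOUNTS_COLLECTED:
payload = '[{"cloud": "aws", "account_id": "123456789012", "task_id": "task-1"}]'
print(AccountCollected.from_string(payload))
# -> [AccountCollected(cloud='aws', account_id='123456789012', task_id='task-1')]
```

With the dispatch shim from the first two commits in place, a post-collect run would then be started roughly as `dispatch_executable post-collect --job-id <id> --tenant-id <id> --accounts-collected '<json>' --- <core args>`; the exact invocation comes from the dispatching service and is not part of this patch.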