From 2d2fb9552afb367d526f593e6e6868102c041d59 Mon Sep 17 00:00:00 2001
From: Matthias Veit
Date: Thu, 1 Aug 2024 17:09:28 +0200
Subject: [PATCH] Move benchmark creation and time series to post-process

Benchmark report creation and time series snapshots no longer run per
account in CollectSingle.post_process, which is reduced to read_results.
Instead, they run in PostCollect once for the whole tenant, after
deferred edges have been merged and security reports have been
synchronized. To support this, CoreClient.timeseries_snapshot now takes
an optional account_id: if none is given, the snapshot spans all
accounts.
---
 collect_single/collect_single.py | 36 +++---------------------------------
 collect_single/core_client.py    | 10 ++++++----
 collect_single/post_collect.py   | 22 ++++++++++++++++++++++
 tests/collect_and_sync_test.py   |  3 ++-
 tests/core_client_test.py        |  4 ++--
 5 files changed, 35 insertions(+), 40 deletions(-)

diff --git a/collect_single/collect_single.py b/collect_single/collect_single.py
index 7d9d526..5e256e8 100644
--- a/collect_single/collect_single.py
+++ b/collect_single/collect_single.py
@@ -23,9 +23,8 @@
 from itertools import takewhile
 from pathlib import Path
 from typing import List, Tuple, Dict
-from typing import Optional, Any, cast
+from typing import Optional, cast
 
-import yaml
 from arango.cursor import Cursor
 from fixcloudutils.logging import setup_logger
 from fixcloudutils.redis.event_stream import Json
@@ -40,7 +39,6 @@
 from redis.asyncio import Redis
 
 from collect_single.job import Job
-from collect_single.model import MetricQuery
 from collect_single.process import ProcessWrapper
 
 log = logging.getLogger("fix.coordinator")
@@ -68,17 +66,6 @@ def __init__(
         self.logging_context = logging_context
         self.task_id: Optional[str] = None
         self.worker_connected = asyncio.Event()
-        self.metrics: List[MetricQuery] = []
-
-    async def start(self) -> Any:
-        await super().start()
-        self.metrics = self.load_metrics()
-
-    @staticmethod
-    def load_metrics() -> List[MetricQuery]:
-        with open(Path(__file__).parent / "metrics.yaml", "r") as f:
-            yml = yaml.safe_load(f)
-            return [MetricQuery.from_json(k, v) for k, v in yml.items() if "search" in v]
 
     async def listen_to_events_until_collect_done(self) -> bool:
         async for event in self.core_client.client.events():
@@ -118,32 +105,15 @@ async def start_collect(self) -> None:
         else:
             raise Exception("Could not start collect workflow")
 
-    async def post_process(self) -> Tuple[Json, List[str]]:
+    async def read_results(self) -> Tuple[Json, List[str]]:
         # get information about all accounts that have been collected/updated
         account_info = await self.core_client.account_info(self.account_id)
         # check if there were errors
         messages = await self.core_client.workflow_log(self.task_id) if self.task_id else []
-        # post process the data, if something has been collected
-        if account_info and (account_id := self.account_id):
-            # synchronize the security section
-            benchmarks = await self.core_client.list_benchmarks(providers=[self.cloud] if self.cloud else None)
-            if benchmarks:
-                await self.core_client.create_benchmark_reports([account_id], benchmarks, self.task_id)
-            # create metrics
-            for metric in self.metrics:
-                res = await self.core_client.timeseries_snapshot(metric, account_id)
-                if res:
-                    log.info(f"Created timeseries snapshot: {metric.name} created {res} entries")
-            # downsample all timeseries
-            ds = await self.core_client.timeseries_downsample()
-            log.info(f"Sampled down all timeseries. Result: {ds}")
-        else:
-            raise ValueError("No account info found. Give up!")
Give up!") - return account_info, messages async def send_result_events(self, read_from_process: bool, error_messages: Optional[List[str]] = None) -> None: - account_info, messages = await self.post_process() if read_from_process else ({}, error_messages or []) + account_info, messages = await self.read_results() if read_from_process else ({}, error_messages or []) # send a collect done event for the tenant await self.collect_done_publisher.publish( "collect-done", diff --git a/collect_single/core_client.py b/collect_single/core_client.py index 3caa939..dded71e 100644 --- a/collect_single/core_client.py +++ b/collect_single/core_client.py @@ -104,10 +104,12 @@ async def wait_for_worker_subscribed(self) -> List[Subscriber]: log.info("Wait for worker to connect.") await asyncio.sleep(1) - async def timeseries_snapshot(self, metric: MetricQuery, account_id: str) -> int: - query = query_parser.parse_query(metric.search).combine( - Query.by(P("/ancestors.account.reported.id").eq(account_id)) - ) + async def timeseries_snapshot(self, metric: MetricQuery, account_id: Optional[str] = None) -> int: + # create query + query = query_parser.parse_query(metric.search) + if account_id: + query = query.combine(Query.by(P("/ancestors.account.reported.id").eq(account_id))) + # create command command = f"timeseries snapshot --name {metric.name} " if metric.factor: command += f"--avg-factor {metric.factor} " diff --git a/collect_single/post_collect.py b/collect_single/post_collect.py index bc7cafa..009a3ed 100644 --- a/collect_single/post_collect.py +++ b/collect_single/post_collect.py @@ -7,16 +7,20 @@ import sys from argparse import ArgumentParser, Namespace from itertools import takewhile +from pathlib import Path +from time import sleep from typing import List, Dict from typing import Optional import cattrs +import yaml from attr import define from fixcloudutils.logging import setup_logger from fixcloudutils.util import utc, utc_str from redis.asyncio.client import Redis from collect_single.job import Job +from collect_single.model import MetricQuery from collect_single.process import ProcessWrapper log = logging.getLogger("fix.coordinator") @@ -39,6 +43,12 @@ def __init__( self.core_args = core_args self.logging_context = logging_context + @staticmethod + def load_metrics() -> List[MetricQuery]: + with open(Path(__file__).parent / "metrics.yaml", "r") as f: + yml = yaml.safe_load(f) + return [MetricQuery.from_json(k, v) for k, v in yml.items() if "search" in v] + async def send_result_events(self, exception: Optional[Exception] = None) -> None: # send a collect done event for the tenant await self.collect_done_publisher.publish( @@ -53,6 +63,16 @@ async def send_result_events(self, exception: Optional[Exception] = None) -> Non }, ) + async def create_timeseries(self) -> None: + # create metrics + for metric in self.load_metrics(): + res = await self.core_client.timeseries_snapshot(metric) + if res: + log.info(f"Created timeseries snapshot: {metric.name} created {res} entries") + # downsample all timeseries + ds = await self.core_client.timeseries_downsample() + log.info(f"Sampled down all timeseries. 
Result: {ds}") + async def merge_deferred_edges(self) -> None: await self.core_client.merge_deferred_edges([ac.task_id for ac in self.accounts_collected]) @@ -75,6 +95,9 @@ async def sync(self) -> None: log.info("All deferred edges have been updated.") await self.security_report() log.info("Security reports have been synchronized.") + sleep(10) # wait for the view to become ready + await self.create_timeseries() + log.info("Time series have been updated.") await asyncio.wait_for(self.send_result_events(), 600) # wait up to 10 minutes except Exception as ex: log.info(f"Got Exception during sync: {ex}") diff --git a/tests/collect_and_sync_test.py b/tests/collect_and_sync_test.py index dc2b3da..e20ecbb 100644 --- a/tests/collect_and_sync_test.py +++ b/tests/collect_and_sync_test.py @@ -20,6 +20,7 @@ from fixcore.query import query_parser from collect_single.collect_single import CollectSingle +from collect_single.post_collect import PostCollect @pytest.mark.asyncio @@ -51,7 +52,7 @@ async def test_migrate_ts(collect_and_sync: CollectSingle) -> None: def test_load_metrics() -> None: - metrics = CollectSingle.load_metrics() + metrics = PostCollect.load_metrics() assert len(metrics) >= 14 for query in metrics: # make sure the query parser does not explode diff --git a/tests/core_client_test.py b/tests/core_client_test.py index a7d3ff7..ea511c2 100644 --- a/tests/core_client_test.py +++ b/tests/core_client_test.py @@ -5,8 +5,8 @@ import pytest -from collect_single.collect_single import CollectSingle from collect_single.core_client import CoreClient +from collect_single.post_collect import PostCollect @pytest.fixture @@ -74,7 +74,7 @@ async def test_wait_for_collect_task_to_finish(core_client: CoreClient) -> None: async def test_timeseries_snapshot(core_client: CoreClient) -> None: accounts = [a async for a in core_client.client.search_list("is(aws_account) limit 1")] single = accounts[0]["reported"]["id"] - for query in CollectSingle.load_metrics(): + for query in PostCollect.load_metrics(): res = await core_client.timeseries_snapshot(query, single) assert res >= 0