diff --git a/Dockerfile b/Dockerfile
index 29a3b11..2715df7 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -6,7 +6,7 @@ ADD . /single_coordinator
RUN . /usr/local/fix-venv-python3/bin/activate && pip install /single_coordinator && rm -rf /single_coordinator
# Add shim and create symlink
-COPY collect_single_shim /usr/local/bin/collect_single_shim
-RUN chmod 755 /usr/local/bin/collect_single_shim && ln -s /usr/local/bin/collect_single_shim /usr/bin/collect_single
+COPY dispatch_executable_shim.sh /usr/local/bin/dispatch_executable_shim
+RUN chmod 755 /usr/local/bin/dispatch_executable_shim && ln -s /usr/local/bin/dispatch_executable_shim /usr/bin/dispatch_executable
-ENTRYPOINT ["/bin/dumb-init", "--", "/usr/local/sbin/bootstrap", "/usr/bin/collect_single"]
+ENTRYPOINT ["/bin/dumb-init", "--", "/usr/local/sbin/bootstrap", "/usr/bin/dispatch_executable"]
diff --git a/collect_single/__main__.py b/collect_single/__main__.py
deleted file mode 100644
index 5fd94fc..0000000
--- a/collect_single/__main__.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# fix-collect-single
-# Copyright (C) 2023 Some Engineering
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see <https://www.gnu.org/licenses/>.
-
-import asyncio
-import logging
-import os
-import sys
-from argparse import Namespace, ArgumentParser
-from datetime import timedelta
-from itertools import takewhile
-from pathlib import Path
-from typing import List, Tuple, Dict
-
-from fixcloudutils.logging import setup_logger
-from fixcloudutils.util import utc
-from redis.asyncio import Redis
-
-from collect_single.collect_and_sync import CollectAndSync
-
-log = logging.getLogger("fix.coordinator")
-
-
-def kv_pairs(s: str) -> Tuple[str, str]:
- return tuple(s.split("=", maxsplit=1)) # type: ignore
-
-
-async def startup(
- args: Namespace, core_args: List[str], worker_args: List[str], logging_context: Dict[str, str]
-) -> None:
- redis_args = {}
- if args.redis_password:
- redis_args["password"] = args.redis_password
- if args.redis_url.startswith("rediss://") and args.ca_cert:
- redis_args["ssl_ca_certs"] = args.ca_cert
- async with Redis.from_url(args.redis_url, decode_responses=True, **redis_args) as redis:
-
- async def collect_and_sync(send_on_failed: bool) -> bool:
- async with CollectAndSync(
- redis=redis,
- tenant_id=args.tenant_id,
- cloud=args.cloud,
- account_id=args.account_id,
- job_id=args.job_id,
- core_args=core_args,
- worker_args=worker_args,
- push_gateway_url=args.push_gateway_url,
- logging_context=logging_context,
- ) as cas:
- return await cas.sync(send_on_failed)
-
- if retry := args.retry_failed_for:
- log.info(f"Collect job with retry enabled for {retry}.")
- has_result = False
- deadline = utc() + retry
- while not has_result and utc() < deadline:
- # collect and do not send a message in the failing case
- has_result = await collect_and_sync(False)
- if not has_result:
- log.info("Failed collect with retry enabled. Retrying in 30s.")
- await asyncio.sleep(30)
- # if we come here without a result, collect and also send a message in the failing case
- if not has_result:
- log.info("Last attempt to collect with retry enabled.")
- await collect_and_sync(True)
- else:
- await collect_and_sync(True)
-
-
-def main() -> None:
-    # 3 argument sets delimited by "---": <coordinator args> --- <core args> --- <worker args>
- # coordinator --main-arg1 --main-arg2 --- --core-arg1 --core-arg2 --- --worker-arg1 --worker-arg2
- args = iter(sys.argv[1:])
- coordinator_args = list(takewhile(lambda x: x != "---", args))
- core_args = list(takewhile(lambda x: x != "---", args))
- worker_args = list(args)
- # handle coordinator args
- parser = ArgumentParser()
- parser.add_argument(
- "--write",
- type=kv_pairs,
- help="Write config files in home dir from env vars. Format: --write path/in/home/dir=env-var-name",
- default=[],
- action="append",
- )
- parser.add_argument("--job-id", required=True, help="Job Id of the coordinator")
- parser.add_argument("--tenant-id", required=True, help="Id of the tenant")
- parser.add_argument("--account-id", help="Id of the account")
- parser.add_argument("--cloud", help="Cloud provider.")
- parser.add_argument("--redis-url", default="redis://localhost:6379/0", help="Redis host.")
- parser.add_argument("--redis-password", default=os.environ.get("REDIS_PASSWORD"), help="Redis password")
- parser.add_argument("--push-gateway-url", help="Prometheus push gateway url")
- parser.add_argument("--ca-cert", help="Path to CA cert file")
- parser.add_argument(
- "--retry-failed-for", type=lambda x: timedelta(seconds=float(x)), help="Seconds to retry failed jobs."
- )
- parsed = parser.parse_args(coordinator_args)
-
- # setup logging
- logging_context = dict(job_id=parsed.job_id, workspace_id=parsed.tenant_id, cloud_account_id=parsed.account_id)
- setup_logger("collect-single", get_logging_context=lambda: {"process": "coordinator", **logging_context})
-
- # write config files from env vars
- env_vars = {k.lower(): v for k, v in os.environ.items()}
- for home_path, env_var_name in parsed.write:
- path = (Path.home() / Path(home_path)).absolute()
- content = env_vars.get(env_var_name.lower())
- assert content is not None, f"Env var {env_var_name} not found"
- log.info(f"Writing file: {path} from env var: {env_var_name}")
- path.parent.mkdir(parents=True, exist_ok=True)
- with open(path, "w+") as f:
- f.write(content)
-
- log.info(f"Coordinator args:({coordinator_args}) Core args:({core_args}) Worker args:({worker_args})")
- asyncio.run(startup(parsed, core_args, worker_args, logging_context))
-
-
-if __name__ == "__main__":
- main()
diff --git a/collect_single/collect_and_sync.py b/collect_single/collect_single.py
similarity index 67%
rename from collect_single/collect_and_sync.py
rename to collect_single/collect_single.py
index c8f787c..7d9d526 100644
--- a/collect_single/collect_and_sync.py
+++ b/collect_single/collect_single.py
@@ -16,33 +16,37 @@
import asyncio
import logging
+import os
+import sys
+from argparse import Namespace, ArgumentParser
from datetime import timedelta
+from itertools import takewhile
from pathlib import Path
-from typing import List, Optional, Any, Tuple, Dict, cast
+from typing import List, Tuple, Dict
+from typing import Optional, Any, cast
import yaml
from arango.cursor import Cursor
-from fixcloudutils.redis.event_stream import RedisStreamPublisher, Json
+from fixcloudutils.logging import setup_logger
+from fixcloudutils.redis.event_stream import Json
from fixcloudutils.redis.lock import Lock
-from fixcloudutils.redis.pub_sub import RedisPubSubPublisher
-from fixcloudutils.service import Service
-from fixcloudutils.util import utc, utc_str
+from fixcloudutils.util import utc
+from fixcloudutils.util import utc_str
from fixcore.core_config import parse_config
from fixcore.db.async_arangodb import AsyncArangoDB
from fixcore.db.db_access import DbAccess
from fixcore.db.timeseriesdb import TimeSeriesDB
from fixcore.system_start import parse_args as core_parse_args
from redis.asyncio import Redis
-import prometheus_client
-from collect_single.core_client import CoreClient
+from collect_single.job import Job
from collect_single.model import MetricQuery
from collect_single.process import ProcessWrapper
log = logging.getLogger("fix.coordinator")
-class CollectAndSync(Service):
+class CollectSingle(Job):
def __init__(
self,
*,
@@ -55,36 +59,20 @@ def __init__(
worker_args: List[str],
logging_context: Dict[str, str],
push_gateway_url: Optional[str] = None,
- core_url: str = "http://localhost:8980",
) -> None:
- self.redis = redis
- self.tenant_id = tenant_id
+ super().__init__(redis=redis, job_id=job_id, tenant_id=tenant_id, push_gateway_url=push_gateway_url)
self.cloud = cloud
self.account_id = account_id
- self.job_id = job_id
self.core_args = ["--no-scheduling", "--ignore-interrupted-tasks"] + core_args
self.worker_args = worker_args
self.logging_context = logging_context
- self.core_client = CoreClient(core_url)
self.task_id: Optional[str] = None
- self.push_gateway_url = push_gateway_url
- publisher = "collect-and-sync"
- self.progress_update_publisher = RedisPubSubPublisher(redis, f"tenant-events::{tenant_id}", publisher)
- self.collect_done_publisher = RedisStreamPublisher(redis, "collect-events", publisher)
- self.started_at = utc()
self.worker_connected = asyncio.Event()
self.metrics: List[MetricQuery] = []
async def start(self) -> Any:
- await self.progress_update_publisher.start()
- await self.collect_done_publisher.start()
+ await super().start()
self.metrics = self.load_metrics()
- # note: the client is not started (core is not running and no certificate required)
-
- async def stop(self) -> None:
- await self.core_client.stop()
- await self.progress_update_publisher.stop()
- await self.collect_done_publisher.stop()
@staticmethod
def load_metrics() -> List[MetricQuery]:
@@ -140,7 +128,7 @@ async def post_process(self) -> Tuple[Json, List[str]]:
# synchronize the security section
benchmarks = await self.core_client.list_benchmarks(providers=[self.cloud] if self.cloud else None)
if benchmarks:
- await self.core_client.create_benchmark_reports(account_id, benchmarks, self.task_id)
+ await self.core_client.create_benchmark_reports([account_id], benchmarks, self.task_id)
# create metrics
for metric in self.metrics:
res = await self.core_client.timeseries_snapshot(metric, account_id)
@@ -170,14 +158,6 @@ async def send_result_events(self, read_from_process: bool, error_messages: Opti
},
)
- async def push_metrics(self) -> None:
- if gateway := self.push_gateway_url:
- # Possible future option: retrieve metrics from core and worker and push them to prometheus
- prometheus_client.push_to_gateway(
- gateway=gateway, job="collect_single", registry=prometheus_client.REGISTRY
- )
- log.info("Metrics pushed to gateway")
-
async def migrate_ts_data(self) -> None:
ts_with_account = "for doc in ts filter doc.group.account!=null"
update = (
@@ -243,3 +223,101 @@ async def sync(self, send_on_failed: bool) -> bool:
await asyncio.wait_for(self.send_result_events(False, [str(ex)]), 600) # wait up to 10 minutes
result_send = True
return result_send
+
+
+async def startup(
+ args: Namespace, core_args: List[str], worker_args: List[str], logging_context: Dict[str, str]
+) -> None:
+ redis_args = {}
+ if args.redis_password:
+ redis_args["password"] = args.redis_password
+ if args.redis_url.startswith("rediss://") and args.ca_cert:
+ redis_args["ssl_ca_certs"] = args.ca_cert
+ async with Redis.from_url(args.redis_url, decode_responses=True, **redis_args) as redis:
+
+ async def collect_and_sync(send_on_failed: bool) -> bool:
+ async with CollectSingle(
+ redis=redis,
+ tenant_id=args.tenant_id,
+ cloud=args.cloud,
+ account_id=args.account_id,
+ job_id=args.job_id,
+ core_args=core_args,
+ worker_args=worker_args,
+ push_gateway_url=args.push_gateway_url,
+ logging_context=logging_context,
+ ) as cas:
+ return await cas.sync(send_on_failed)
+
+ if retry := args.retry_failed_for:
+ log.info(f"Collect job with retry enabled for {retry}.")
+ has_result = False
+ deadline = utc() + retry
+ while not has_result and utc() < deadline:
+ # collect and do not send a message in the failing case
+ has_result = await collect_and_sync(False)
+ if not has_result:
+ log.info("Failed collect with retry enabled. Retrying in 30s.")
+ await asyncio.sleep(30)
+ # if we come here without a result, collect and also send a message in the failing case
+ if not has_result:
+ log.info("Last attempt to collect with retry enabled.")
+ await collect_and_sync(True)
+ else:
+ await collect_and_sync(True)
+
+
+def kv_pairs(s: str) -> Tuple[str, str]:
+ return tuple(s.split("=", maxsplit=1)) # type: ignore
+
+
+def main() -> None:
+    # 3 argument sets delimited by "---": <coordinator args> --- <core args> --- <worker args>
+ # coordinator --main-arg1 --main-arg2 --- --core-arg1 --core-arg2 --- --worker-arg1 --worker-arg2
+ args = iter(sys.argv[1:])
+ coordinator_args = list(takewhile(lambda x: x != "---", args))
+ core_args = list(takewhile(lambda x: x != "---", args))
+ worker_args = list(args)
+ # handle coordinator args
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--write",
+ type=kv_pairs,
+ help="Write config files in home dir from env vars. Format: --write path/in/home/dir=env-var-name",
+ default=[],
+ action="append",
+ )
+ parser.add_argument("--job-id", required=True, help="Job Id of the coordinator")
+ parser.add_argument("--tenant-id", required=True, help="Id of the tenant")
+ parser.add_argument("--account-id", help="Id of the account")
+ parser.add_argument("--cloud", help="Cloud provider.")
+ parser.add_argument("--redis-url", default="redis://localhost:6379/0", help="Redis host.")
+ parser.add_argument("--redis-password", default=os.environ.get("REDIS_PASSWORD"), help="Redis password")
+ parser.add_argument("--push-gateway-url", help="Prometheus push gateway url")
+ parser.add_argument("--ca-cert", help="Path to CA cert file")
+ parser.add_argument(
+ "--retry-failed-for", type=lambda x: timedelta(seconds=float(x)), help="Seconds to retry failed jobs."
+ )
+ parsed = parser.parse_args(coordinator_args)
+
+ # setup logging
+ logging_context = dict(job_id=parsed.job_id, workspace_id=parsed.tenant_id, cloud_account_id=parsed.account_id)
+ setup_logger("collect-single", get_logging_context=lambda: {"process": "coordinator", **logging_context})
+
+ # write config files from env vars
+ env_vars = {k.lower(): v for k, v in os.environ.items()}
+ for home_path, env_var_name in parsed.write:
+ path = (Path.home() / Path(home_path)).absolute()
+ content = env_vars.get(env_var_name.lower())
+ assert content is not None, f"Env var {env_var_name} not found"
+ log.info(f"Writing file: {path} from env var: {env_var_name}")
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with open(path, "w+") as f:
+ f.write(content)
+
+ log.info(f"Coordinator args:({coordinator_args}) Core args:({core_args}) Worker args:({worker_args})")
+ asyncio.run(startup(parsed, core_args, worker_args, logging_context))
+
+
+if __name__ == "__main__":
+ main()
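
For context, a minimal standalone sketch of the argument splitting that main() performs above: takewhile consumes items from the shared iterator up to and including each "---" delimiter, so each subsequent call resumes right after it. The argument values below are illustrative only.

from itertools import takewhile

# illustrative argv, mirroring: coordinator args --- core args --- worker args
argv = ["--job-id", "1", "--tenant-id", "t1", "---", "--core-arg", "---", "--worker-arg"]
args = iter(argv)
coordinator_args = list(takewhile(lambda x: x != "---", args))  # up to the first "---"
core_args = list(takewhile(lambda x: x != "---", args))         # resumes after it
worker_args = list(args)                                        # everything remaining
assert coordinator_args == ["--job-id", "1", "--tenant-id", "t1"]
assert core_args == ["--core-arg"] and worker_args == ["--worker-arg"]
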
diff --git a/collect_single/core_client.py b/collect_single/core_client.py
index 62c433e..3caa939 100644
--- a/collect_single/core_client.py
+++ b/collect_single/core_client.py
@@ -71,11 +71,14 @@ async def list_benchmarks(self, *, providers: Optional[List[str]] = None) -> Lis
else:
raise AttributeError(await response.text())
- async def create_benchmark_reports(self, account_id: str, benchmarks: List[str], task_id: Optional[str]) -> None:
+ async def create_benchmark_reports(
+ self, account_ids: List[str], benchmarks: List[str], run_id: Optional[str]
+ ) -> None:
bn = " ".join(benchmarks)
- run_id = task_id or str(uuid.uuid4())
- command = f"report benchmark run {bn} --accounts {account_id} --sync-security-section --run-id {run_id} | count"
- log.info(f"Create reports for following benchmarks: {bn} for accounts: {account_id}. Command: {command}")
+ an = " ".join(account_ids)
+ rid = run_id or str(uuid.uuid4())
+ command = f"report benchmark run {bn} --accounts {an} --sync-security-section --run-id {rid} | count"
+ log.info(f"Create reports for following benchmarks: {bn} for accounts: {an}. Command: {command}")
async for _ in self.client.cli_execute(command, headers={"Accept": "application/json"}):
pass # ignore the result
@@ -138,3 +141,10 @@ async def account_id_by_name(self) -> Dict[Optional[str], Optional[str]]:
value_in_path(r, "reported.name"): value_in_path(r, "reported.id")
async for r in self.client.cli_execute("search is(account) | dump")
}
+
+ async def merge_deferred_edges(self, task_ids: List[str], *, graph: str = "fix") -> None:
+ response = await self.client._post(f"/graph/{graph}/merge/deferred_edges", json=task_ids)
+ if response.status_code == 200:
+ return await response.json() # type: ignore
+ else:
+ raise AttributeError(await response.text())
diff --git a/collect_single/job.py b/collect_single/job.py
new file mode 100644
index 0000000..5d8e647
--- /dev/null
+++ b/collect_single/job.py
@@ -0,0 +1,53 @@
+import logging
+from abc import ABC
+from typing import Optional, Any
+
+import prometheus_client
+from fixcloudutils.redis.event_stream import RedisStreamPublisher
+from fixcloudutils.redis.pub_sub import RedisPubSubPublisher
+from fixcloudutils.service import Service
+from fixcloudutils.util import utc
+from redis.asyncio.client import Redis
+
+from collect_single.core_client import CoreClient
+
+log = logging.getLogger("fix.coordinator")
+
+
+class Job(Service, ABC):
+ def __init__(
+ self,
+ *,
+ redis: Redis, # type: ignore
+ job_id: str,
+ tenant_id: str,
+ push_gateway_url: Optional[str] = None,
+ core_url: str = "http://localhost:8980",
+ ) -> None:
+ self.redis = redis
+ self.job_id = job_id
+ self.tenant_id = tenant_id
+ self.push_gateway_url = push_gateway_url
+ self.core_client = CoreClient(core_url)
+ publisher = "collect-and-sync"
+ self.progress_update_publisher = RedisPubSubPublisher(redis, f"tenant-events::{tenant_id}", publisher)
+ self.collect_done_publisher = RedisStreamPublisher(redis, "collect-events", publisher)
+ self.started_at = utc()
+
+ async def start(self) -> Any:
+ await self.progress_update_publisher.start()
+ await self.collect_done_publisher.start()
+ # note: the client is not started (core is not running and no certificate required)
+
+ async def stop(self) -> None:
+ await self.progress_update_publisher.stop()
+ await self.collect_done_publisher.stop()
+ await self.core_client.stop()
+
+ async def push_metrics(self) -> None:
+ if gateway := self.push_gateway_url:
+ # Possible future option: retrieve metrics from core and worker and push them to prometheus
+ prometheus_client.push_to_gateway(
+ gateway=gateway, job="collect_single", registry=prometheus_client.REGISTRY
+ )
+ log.info("Metrics pushed to gateway")
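
For reference, a minimal sketch of how a subclass would use the new Job base class. It assumes the fixcloudutils Service base provides the async-with lifecycle that calls start()/stop(), as CollectSingle and PostCollect rely on; the ExampleJob name and the event kind below are purely illustrative.

from redis.asyncio import Redis

from collect_single.job import Job


class ExampleJob(Job):  # illustrative subclass, not part of the patch
    async def run(self) -> None:
        # Job.__init__ wires up the Redis publishers and the CoreClient;
        # Job.start/stop (entered via "async with") manage their lifecycle.
        await self.collect_done_publisher.publish("example-done", {"job_id": self.job_id})
        await self.push_metrics()  # no-op unless push_gateway_url was given


async def demo() -> None:
    async with Redis.from_url("redis://localhost:6379/0", decode_responses=True) as redis:
        async with ExampleJob(redis=redis, job_id="job-1", tenant_id="tenant-1") as job:
            await job.run()
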
diff --git a/collect_single/post_collect.py b/collect_single/post_collect.py
new file mode 100644
index 0000000..bc7cafa
--- /dev/null
+++ b/collect_single/post_collect.py
@@ -0,0 +1,137 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import os
+import sys
+from argparse import ArgumentParser, Namespace
+from itertools import takewhile
+from typing import List, Dict
+from typing import Optional
+
+import cattrs
+from attr import define
+from fixcloudutils.logging import setup_logger
+from fixcloudutils.util import utc, utc_str
+from redis.asyncio.client import Redis
+
+from collect_single.job import Job
+from collect_single.process import ProcessWrapper
+
+log = logging.getLogger("fix.coordinator")
+
+
+class PostCollect(Job):
+ def __init__(
+ self,
+ *,
+ redis: Redis, # type: ignore
+ tenant_id: str,
+ job_id: str,
+ accounts_collected: List[AccountCollected],
+ core_args: List[str],
+ logging_context: Dict[str, str],
+ push_gateway_url: Optional[str] = None,
+ ) -> None:
+ super().__init__(redis=redis, job_id=job_id, tenant_id=tenant_id, push_gateway_url=push_gateway_url)
+ self.accounts_collected = accounts_collected
+ self.core_args = core_args
+ self.logging_context = logging_context
+
+ async def send_result_events(self, exception: Optional[Exception] = None) -> None:
+ # send a collect done event for the tenant
+ await self.collect_done_publisher.publish(
+ "post-collect-done",
+ {
+ "job_id": self.job_id,
+ "tenant_id": self.tenant_id,
+ "started_at": utc_str(self.started_at),
+ "duration": int((utc() - self.started_at).total_seconds()),
+ "success": exception is None,
+ "exception": str(exception) if exception else None,
+ },
+ )
+
+ async def merge_deferred_edges(self) -> None:
+ await self.core_client.merge_deferred_edges([ac.task_id for ac in self.accounts_collected])
+
+ async def security_report(self) -> None:
+ for acc in self.accounts_collected:
+ benchmarks = await self.core_client.list_benchmarks(providers=[acc.cloud])
+ if benchmarks:
+ await self.core_client.create_benchmark_reports([acc.account_id], benchmarks, acc.task_id)
+
+ async def sync(self) -> None:
+ try:
+ if self.accounts_collected: # Don't do anything if no accounts were collected
+ aids = ", ".join([ac.account_id for ac in self.accounts_collected])
+ log.info(f"Sync tenant {self.tenant_id}: with {len(self.accounts_collected)} accounts: {aids}.")
+ async with ProcessWrapper(["fixcore", *self.core_args], self.logging_context):
+ log.info("Core started.")
+ await asyncio.wait_for(self.core_client.wait_connected(), timeout=60)
+ log.info("Core Client connected.")
+ await self.merge_deferred_edges()
+ log.info("All deferred edges have been updated.")
+ await self.security_report()
+ log.info("Security reports have been synchronized.")
+ await asyncio.wait_for(self.send_result_events(), 600) # wait up to 10 minutes
+ except Exception as ex:
+ log.info(f"Got Exception during sync: {ex}")
+ await asyncio.wait_for(self.send_result_events(ex), 600) # wait up to 10 minutes
+
+
+async def startup(args: Namespace, core_args: List[str], logging_context: Dict[str, str]) -> None:
+ redis_args = {}
+ if args.redis_password:
+ redis_args["password"] = args.redis_password
+ if args.redis_url.startswith("rediss://") and args.ca_cert:
+ redis_args["ssl_ca_certs"] = args.ca_cert
+ async with Redis.from_url(args.redis_url, decode_responses=True, **redis_args) as redis:
+ async with PostCollect(
+ redis=redis,
+ tenant_id=args.tenant_id,
+ job_id=args.job_id,
+ core_args=core_args,
+ accounts_collected=args.accounts_collected,
+ logging_context=logging_context,
+ push_gateway_url=args.push_gateway_url,
+ ) as post_collect:
+ await post_collect.sync()
+
+
+@define
+class AccountCollected:
+ cloud: str
+ account_id: str
+ task_id: str
+
+ @staticmethod
+ def from_string(json_str: str) -> List[AccountCollected]:
+ return cattrs.structure(json.loads(json_str), List[AccountCollected])
+
+
+def main() -> None:
+ args = iter(sys.argv[1:])
+ post_process_args = list(takewhile(lambda x: x != "---", args))
+ core_args = list(takewhile(lambda x: x != "---", args))
+ parser = ArgumentParser()
+ parser.add_argument("--job-id", required=True, help="Job Id of the coordinator")
+ parser.add_argument("--tenant-id", required=True, help="Id of the tenant")
+ parser.add_argument(
+ "--accounts-collected", default=os.environ.get("ACCOUNTS_COLLECTED"), type=AccountCollected.from_string
+ )
+ parser.add_argument("--redis-url", default="redis://localhost:6379/0", help="Redis host.")
+ parser.add_argument("--redis-password", default=os.environ.get("REDIS_PASSWORD"), help="Redis password")
+ parser.add_argument("--push-gateway-url", help="Prometheus push gateway url")
+ parser.add_argument("--ca-cert", help="Path to CA cert file")
+ parsed = parser.parse_args(post_process_args)
+
+ # setup logging
+ logging_context = dict(job_id=parsed.job_id, workspace_id=parsed.tenant_id)
+ setup_logger("post-collect", get_logging_context=lambda: {"process": "post-collect", **logging_context})
+ asyncio.run(startup(parsed, core_args, logging_context))
+
+
+if __name__ == "__main__":
+ main()
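
For reference, a small sketch of the JSON payload that --accounts-collected (or the ACCOUNTS_COLLECTED env var) is expected to carry, based on the AccountCollected definition above; the concrete values are made up.

from collect_single.post_collect import AccountCollected

# illustrative payload: one collected account per entry
payload = '[{"cloud": "aws", "account_id": "123456789012", "task_id": "task-1"}]'
collected = AccountCollected.from_string(payload)
assert collected[0].cloud == "aws" and collected[0].task_id == "task-1"
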
diff --git a/collect_single_shim b/collect_single_shim
deleted file mode 100755
index 314a7a4..0000000
--- a/collect_single_shim
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/bin/bash
-. /usr/local/etc/fixinventory/defaults
-. /usr/local/fix-venv-python3/bin/activate
-exec collect_single "$@"
diff --git a/dispatch_executable_shim.sh b/dispatch_executable_shim.sh
new file mode 100755
index 0000000..53e42bd
--- /dev/null
+++ b/dispatch_executable_shim.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+. /usr/local/etc/fixinventory/defaults
+. /usr/local/fix-venv-python3/bin/activate
+
+# dispatch to the correct console script
+case $1 in
+ collect)
+ shift
+ exec collect_single "$@"
+ ;;
+ post-collect)
+ shift
+ exec post_collect "$@"
+ ;;
+ *)
+ # backward compatibility: delegate to collect_single
+ exec collect_single "$@"
+ ;;
+esac
+
diff --git a/pyproject.toml b/pyproject.toml
index 03b0ce7..509cbce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,8 @@ test = [
[project.scripts]
-collect_single = "collect_single.__main__:main"
+collect_single = "collect_single.collect_single:main"
+post_collect = "collect_single.post_collect:main"
[build-system]
requires = ["setuptools>=67.8.0", "wheel>=0.40.0", "build>=0.10.0"]
diff --git a/tests/collect_and_sync_test.py b/tests/collect_and_sync_test.py
index 9b1ba21..dc2b3da 100644
--- a/tests/collect_and_sync_test.py
+++ b/tests/collect_and_sync_test.py
@@ -19,7 +19,7 @@
from fixclient.async_client import FixInventoryClient
from fixcore.query import query_parser
-from collect_single.collect_and_sync import CollectAndSync
+from collect_single.collect_single import CollectSingle
@pytest.mark.asyncio
@@ -40,18 +40,18 @@ async def test_client(core_client: FixInventoryClient) -> None:
@pytest.mark.asyncio
@pytest.mark.skip(reason="Only for manual testing")
-async def test_collect_and_sync(collect_and_sync: CollectAndSync) -> None:
+async def test_collect_and_sync(collect_and_sync: CollectSingle) -> None:
await collect_and_sync.send_result_events(True)
@pytest.mark.asyncio
@pytest.mark.skip(reason="Only for manual testing")
-async def test_migrate_ts(collect_and_sync: CollectAndSync) -> None:
+async def test_migrate_ts(collect_and_sync: CollectSingle) -> None:
await collect_and_sync.migrate_ts_data()
def test_load_metrics() -> None:
- metrics = CollectAndSync.load_metrics()
+ metrics = CollectSingle.load_metrics()
assert len(metrics) >= 14
for query in metrics:
# make sure the query parser does not explode
diff --git a/tests/conftest.py b/tests/conftest.py
index ce07103..2038a6c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -22,7 +22,7 @@
from redis.backoff import ExponentialBackoff
from fixclient.async_client import FixInventoryClient
-from collect_single.collect_and_sync import CollectAndSync
+from collect_single.collect_single import CollectSingle
@fixture
@@ -46,8 +46,8 @@ async def core_client() -> AsyncIterator[FixInventoryClient]:
@fixture
-async def collect_and_sync(redis: Redis) -> AsyncIterator[CollectAndSync]: # type: ignore
- async with CollectAndSync(
+async def collect_and_sync(redis: Redis) -> AsyncIterator[CollectSingle]: # type: ignore
+ async with CollectSingle(
redis=redis,
tenant_id="tenant_id",
cloud="aws",
diff --git a/tests/core_client_test.py b/tests/core_client_test.py
index 5fce3be..a7d3ff7 100644
--- a/tests/core_client_test.py
+++ b/tests/core_client_test.py
@@ -5,7 +5,7 @@
import pytest
-from collect_single.collect_and_sync import CollectAndSync
+from collect_single.collect_single import CollectSingle
from collect_single.core_client import CoreClient
@@ -53,7 +53,7 @@ async def test_create_benchmark_report(core_client: CoreClient) -> None:
accounts = [a async for a in core_client.client.search_list("is(aws_account) limit 1")]
single = accounts[0]["reported"]["id"]
task_id = str(uuid.uuid4())
- await core_client.create_benchmark_reports(single, ["aws_cis_1_5"], task_id)
+ await core_client.create_benchmark_reports([single], ["aws_cis_1_5"], task_id)
res = [
a
async for a in core_client.client.cli_execute(
@@ -74,7 +74,7 @@ async def test_wait_for_collect_task_to_finish(core_client: CoreClient) -> None:
async def test_timeseries_snapshot(core_client: CoreClient) -> None:
accounts = [a async for a in core_client.client.search_list("is(aws_account) limit 1")]
single = accounts[0]["reported"]["id"]
- for query in CollectAndSync.load_metrics():
+ for query in CollectSingle.load_metrics():
res = await core_client.timeseries_snapshot(query, single)
assert res >= 0