diff --git a/pycti/api/opencti_api_client.py b/pycti/api/opencti_api_client.py index d3ba32f1..7140d496 100644 --- a/pycti/api/opencti_api_client.py +++ b/pycti/api/opencti_api_client.py @@ -700,6 +700,36 @@ def upload_pending_file(self, **kwargs): self.app_logger.error("[upload] Missing parameter: file_name") return None + def send_bundle_to_api(self, **kwargs): + """Push a bundle to a queue through OpenCTI API + + :param `**kwargs`: arguments for bundle push (required: `connectorId` and `bundle`) + :return: returns the query response for the bundle push + :rtype: dict + """ + + connector_id = kwargs.get("connector_id", None) + bundle = kwargs.get("bundle", None) + + if connector_id is not None and bundle is not None: + self.app_logger.info( + "Pushing a bundle to queue through API", {connector_id} + ) + mutation = """ + mutation StixBundlePush($connectorId: String!, $bundle: String!) { + stixBundlePush(connectorId: $connectorId, bundle: $bundle) + } + """ + return self.query( + mutation, + {"connectorId": connector_id, "bundle": bundle}, + ) + else: + self.app_logger.error( + "[bundle push] Missing parameter: connector_id or bundle" + ) + return None + def get_stix_content(self, id): """get the STIX content of any entity diff --git a/pycti/connector/opencti_connector_helper.py b/pycti/connector/opencti_connector_helper.py index c299c2d9..e106af77 100644 --- a/pycti/connector/opencti_connector_helper.py +++ b/pycti/connector/opencti_connector_helper.py @@ -776,6 +776,9 @@ def __init__(self, config: Dict, playbook_compatible=False) -> None: self.connect_id = get_config_variable( "CONNECTOR_ID", ["connector", "id"], config ) + self.queue_protocol = get_config_variable( + "QUEUE_PROTOCOL", ["connector", "queue_protocol"], config, default="amqp" + ) self.connect_type = get_config_variable( "CONNECTOR_TYPE", ["connector", "type"], config ) @@ -994,7 +997,6 @@ def __init__(self, config: Dict, playbook_compatible=False) -> None: # Start ping thread if not self.connect_run_and_terminate: - is_run_and_terminate = False if self.connect_duration_period == 0: is_run_and_terminate = True @@ -1689,10 +1691,11 @@ def send_stix2_bundle(self, bundle: str, **kwargs) -> list: expectations_number = len(json.loads(bundle)["objects"]) else: stix2_splitter = OpenCTIStix2Splitter() - expectations_number, bundles = ( - stix2_splitter.split_bundle_with_expectations( - bundle, True, event_version - ) + ( + expectations_number, + bundles, + ) = stix2_splitter.split_bundle_with_expectations( + bundle, True, event_version ) if len(bundles) == 0: @@ -1704,42 +1707,53 @@ def send_stix2_bundle(self, bundle: str, **kwargs) -> list: self.api.work.add_expectations(work_id, expectations_number) if entities_types is None: entities_types = [] - pika_credentials = pika.PlainCredentials( - self.connector_config["connection"]["user"], - self.connector_config["connection"]["pass"], - ) - pika_parameters = pika.ConnectionParameters( - host=self.connector_config["connection"]["host"], - port=self.connector_config["connection"]["port"], - virtual_host=self.connector_config["connection"]["vhost"], - credentials=pika_credentials, - ssl_options=( - pika.SSLOptions( - create_mq_ssl_context(self.config), - self.connector_config["connection"]["host"], + if self.queue_protocol == "amqp": + pika_credentials = pika.PlainCredentials( + self.connector_config["connection"]["user"], + self.connector_config["connection"]["pass"], + ) + pika_parameters = pika.ConnectionParameters( + host=self.connector_config["connection"]["host"], + port=self.connector_config["connection"]["port"], + virtual_host=self.connector_config["connection"]["vhost"], + credentials=pika_credentials, + ssl_options=( + pika.SSLOptions( + create_mq_ssl_context(self.config), + self.connector_config["connection"]["host"], + ) + if self.connector_config["connection"]["use_ssl"] + else None + ), + ) + pika_connection = pika.BlockingConnection(pika_parameters) + channel = pika_connection.channel() + try: + channel.confirm_delivery() + except Exception as err: # pylint: disable=broad-except + self.connector_logger.warning(str(err)) + self.connector_logger.info( + self.connect_name + " sending bundle to queue" + ) + for sequence, bundle in enumerate(bundles, start=1): + self._send_bundle( + channel, + bundle, + work_id=work_id, + entities_types=entities_types, + sequence=sequence, + update=update, ) - if self.connector_config["connection"]["use_ssl"] - else None - ), - ) - pika_connection = pika.BlockingConnection(pika_parameters) - channel = pika_connection.channel() - try: - channel.confirm_delivery() - except Exception as err: # pylint: disable=broad-except - self.connector_logger.warning(str(err)) - self.connector_logger.info(self.connect_name + " sending bundle to queue") - for sequence, bundle in enumerate(bundles, start=1): - self._send_bundle( - channel, - bundle, - work_id=work_id, - entities_types=entities_types, - sequence=sequence, - update=update, + channel.close() + pika_connection.close() + elif self.queue_protocol == "api": + self.api.send_bundle_to_api( + connector_id=self.connector_id, bundle=bundle + ) + else: + raise ValueError( + f"{self.queue_protocol}: this queue protocol is not supported" ) - channel.close() - pika_connection.close() return bundles