From 105e590ae242ec58af47ae2dfc8a97121a85fb7e Mon Sep 17 00:00:00 2001 From: AWSHurneyt Date: Tue, 16 Jul 2024 17:36:51 -0700 Subject: [PATCH] Refactored ListIOCs API to return the correct number of findings for each IOC. (#1163) Signed-off-by: AWSHurneyt --- .../transport/TransportListIOCsAction.java | 82 ++++++++++++++++--- .../ThreatIntelMonitorRestApiIT.java | 69 +++++++++++++++- 2 files changed, 138 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportListIOCsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportListIOCsAction.java index 2e1954ce7..1e91ce1f3 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportListIOCsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportListIOCsAction.java @@ -21,6 +21,7 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.commons.alerting.model.Table; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; @@ -42,6 +43,11 @@ import org.opensearch.securityanalytics.model.DetailedSTIX2IOCDto; import org.opensearch.securityanalytics.model.STIX2IOC; import org.opensearch.securityanalytics.model.STIX2IOCDto; +import org.opensearch.securityanalytics.model.threatintel.IocFinding; +import org.opensearch.securityanalytics.model.threatintel.IocWithFeeds; +import org.opensearch.securityanalytics.threatIntel.action.GetIocFindingsAction; +import org.opensearch.securityanalytics.threatIntel.action.GetIocFindingsRequest; +import org.opensearch.securityanalytics.threatIntel.action.GetIocFindingsResponse; import org.opensearch.securityanalytics.threatIntel.model.DefaultIocStoreConfig; import org.opensearch.securityanalytics.threatIntel.model.SATIFSourceConfig; import org.opensearch.securityanalytics.threatIntel.service.DefaultTifSourceConfigLoaderService; @@ -55,7 +61,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -178,12 +188,8 @@ private void listIocs(List iocIndices) { } boolQueryBuilder.must(typeQueryBuilder); } -// todo remove filter. not needed because feed ids are fetch before listIocs() -// if (request.getFeedIds() != null && !request.getFeedIds().isEmpty()) { -// boolQueryBuilder.filter(QueryBuilders.termQuery(STIX2_IOC_NESTED_PATH + STIX2IOC.FEED_ID_FIELD, request.getFeedIds())); -// } - if (!request.getTable().getSearchString().isEmpty()) { + if (request.getTable().getSearchString() != null && !request.getTable().getSearchString().isEmpty()) { boolQueryBuilder.must( QueryBuilders.queryStringQuery(request.getTable().getSearchString()) .defaultOperator(Operator.OR) @@ -202,7 +208,7 @@ private void listIocs(List iocIndices) { SortBuilder sortBuilder = SortBuilders .fieldSort(STIX2_IOC_NESTED_PATH + request.getTable().getSortString()) - .order(SortOrder.fromString(request.getTable().getSortOrder().toString())); + .order(SortOrder.fromString(request.getTable().getSortOrder())); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder() .version(true) @@ -224,7 +230,10 @@ public void onResponse(SearchResponse searchResponse) { if (searchResponse.isTimedOut()) { onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } - List iocs = new ArrayList<>(); + + // Concurrently compiling a separate list of IOC IDs to create the subsequent GetIocFindingsRequest + Set iocIds = new HashSet<>(); + List iocs = new ArrayList<>(); Arrays.stream(searchResponse.getHits().getHits()) .forEach(hit -> { try { @@ -236,17 +245,66 @@ public void onResponse(SearchResponse searchResponse) { STIX2IOCDto ioc = STIX2IOCDto.parse(xcp, hit.getId(), hit.getVersion()); - // TODO integrate with findings API that returns IOCMatches - long numFindings = 0L; - - iocs.add(new DetailedSTIX2IOCDto(ioc, numFindings)); + iocIds.add(ioc.getId()); + iocs.add(ioc); } catch (Exception e) { log.error( () -> new ParameterizedMessage("Failed to parse IOC doc from hit {}", hit.getId()), e ); } }); - onOperation(new ListIOCsActionResponse(searchResponse.getHits().getTotalHits().value, iocs)); + + GetIocFindingsRequest getFindingsRequest = new GetIocFindingsRequest( + Collections.emptyList(), + new ArrayList<>(iocIds), + null, + null, + new Table( + "asc", + "timestamp", + request.getTable().getMissing(), + 10000, + 0, + "" + ) + ); + + // Calling GetIocFindings API to get number of findings for each returned IOC + client.execute(GetIocFindingsAction.INSTANCE, getFindingsRequest, new ActionListener<>() { + @Override + public void onResponse(GetIocFindingsResponse getFindingsResponse) { + // Iterate through the GetIocFindingsResponse to count occurrences of each IOC + Map iocIdToNumFindings = new HashMap<>(); + for (IocFinding iocFinding : getFindingsResponse.getIocFindings()) { + for (IocWithFeeds iocWithFeeds : iocFinding.getFeedIds()) { + // Set the count to 0 if it's not already + iocIdToNumFindings.putIfAbsent(iocWithFeeds.getIocId(), 0); + // Increment the count for the IOC + iocIdToNumFindings.merge(iocWithFeeds.getIocId(), 1, Integer::sum); + } + } + + // Iterate through each IOC returned by the SearchRequest to create the detailed model for response + List iocDetails = new ArrayList<>(); + iocs.forEach((ioc) -> { + Integer numFindings = iocIdToNumFindings.get(ioc.getId()); + if (numFindings == null) { + // Logging instances of 'null' separately from 0 instances for investigation purposes + log.debug("Null number of findings found for IOC {}", ioc.getId()); + numFindings = 0; + } + iocDetails.add(new DetailedSTIX2IOCDto(ioc, numFindings)); + }); + + onOperation(new ListIOCsActionResponse(searchResponse.getHits().getTotalHits().value, iocDetails)); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to get IOC findings count:", e); + listener.onFailure(SecurityAnalyticsException.wrap(e)); + } + }); } @Override diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/ThreatIntelMonitorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/ThreatIntelMonitorRestApiIT.java index e6a61df0d..9aab45814 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/ThreatIntelMonitorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/ThreatIntelMonitorRestApiIT.java @@ -6,6 +6,7 @@ import org.apache.logging.log4j.Logger; import org.junit.Assert; import org.opensearch.client.Response; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.commons.alerting.model.IntervalSchedule; import org.opensearch.commons.alerting.model.Monitor; @@ -13,6 +14,7 @@ import org.opensearch.search.SearchHit; import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase; +import org.opensearch.securityanalytics.action.ListIOCsActionRequest; import org.opensearch.securityanalytics.commons.model.IOCType; import org.opensearch.securityanalytics.model.STIX2IOC; import org.opensearch.securityanalytics.threatIntel.common.RefreshType; @@ -29,6 +31,7 @@ import java.io.IOException; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -40,7 +43,47 @@ import static org.opensearch.securityanalytics.threatIntel.resthandler.monitor.RestSearchThreatIntelMonitorAction.SEARCH_THREAT_INTEL_MONITOR_PATH; public class ThreatIntelMonitorRestApiIT extends SecurityAnalyticsRestTestCase { - private static final Logger log = LogManager.getLogger(ThreatIntelMonitorRestApiIT.class); + private final String iocIndexMappings = "\"properties\": {\n" + + " \"stix2_ioc\": {\n" + + " \"properties\": {\n" + + " \"name\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"type\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"value\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"severity\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"spec_version\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"created\": {\n" + + " \"type\": \"date\"\n" + + " },\n" + + " \"modified\": {\n" + + " \"type\": \"date\"\n" + + " },\n" + + " \"description\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"labels\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"feed_id\": {\n" + + " \"type\": \"keyword\"\n" + + " },\n" + + " \"feed_name\": {\n" + + " \"type\": \"keyword\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }"; + + private List testIocs = new ArrayList<>(); public void indexSourceConfigsAndIocs(int num, List iocVals) throws IOException { for (int i = 0; i < num; i++) { @@ -48,6 +91,12 @@ public void indexSourceConfigsAndIocs(int num, List iocVals) throws IOEx String iocActiveIndex = ".opensearch-sap-ioc-" + configId + Instant.now().toEpochMilli(); String indexPattern = ".opensearch-sap-ioc-" + configId; indexTifSourceConfig(num, configId, indexPattern, iocActiveIndex, i); + + // Create the index before ingesting docs to ensure the mappings are correct + createIndex(iocActiveIndex, Settings.EMPTY, iocIndexMappings); + + // Refresh testIocs list between tests + testIocs = new ArrayList<>(); for (int i1 = 0; i1 < iocVals.size(); i1++) { indexIocs(iocVals, iocActiveIndex, i1, configId); } @@ -71,6 +120,10 @@ private void indexIocs(List iocVals, String iocIndexName, int i1, String "", STIX2IOC.NO_VERSION ); + + // Add IOC to testIocs List for future validation + testIocs.add(stix2IOC); + indexDoc(iocIndexName, iocId, stix2IOC.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS).toString()); List searchHits = executeSearch(iocIndexName, getMatchAllSearchRequestString(iocVals.size())); assertEquals(searchHits.size(), i1 + 1); @@ -179,6 +232,20 @@ public void testCreateThreatIntelMonitor() throws IOException { Map.of(), null); responseAsMap = responseAsMap(iocFindingsResponse); Assert.assertEquals(4, ((List>) responseAsMap.get("ioc_findings")).size()); + + // Use ListIOCs API to confirm expected number of findings are returned + String listIocsUri = String.format("?%s=%s", ListIOCsActionRequest.FEED_IDS_FIELD, "id0"); + Response listIocsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.LIST_IOCS_URI + listIocsUri, Collections.emptyMap(), null); + Map listIocsResponseMap = responseAsMap(listIocsResponse); + List> iocsMap = (List>) listIocsResponseMap.get("iocs"); + assertEquals(2, iocsMap.size()); + iocsMap.forEach((iocDetails) -> { + String iocId = (String) iocDetails.get("id"); + int numFindings = (Integer) iocDetails.get("num_findings"); + assertTrue(testIocs.stream().anyMatch(ioc -> iocId.equals(ioc.getId()))); + assertEquals(2, numFindings); + }); + //alerts via system index search searchHits = executeSearch(ThreatIntelAlertService.THREAT_INTEL_ALERT_ALIAS_NAME, matchAllRequest); Assert.assertEquals(4, searchHits.size());