From d5c8f7a374027565e77e59c5d7a3aaac8a85ac4e Mon Sep 17 00:00:00 2001 From: AWSHurneyt Date: Wed, 23 Oct 2024 14:39:02 -0700 Subject: [PATCH] [Bug] Fixed ListIOCs number of findings cap. (#1373) * Fixed finding number returned by ListIOCs API capping at 10,000. Signed-off-by: AWSHurneyt * Added integ test for fix. Signed-off-by: AWSHurneyt * Removed extraneous query params. Signed-off-by: AWSHurneyt * Added additional test case. Signed-off-by: AWSHurneyt --------- Signed-off-by: AWSHurneyt --- .../transport/TransportListIOCsAction.java | 164 ++++++------- .../resthandler/ListIOCsRestApiIT.java | 221 ++++++++++++++++++ 2 files changed, 305 insertions(+), 80 deletions(-) diff --git a/src/main/java/org/opensearch/securityanalytics/threatIntel/transport/TransportListIOCsAction.java b/src/main/java/org/opensearch/securityanalytics/threatIntel/transport/TransportListIOCsAction.java index 4131c00ca..80a6b538c 100644 --- a/src/main/java/org/opensearch/securityanalytics/threatIntel/transport/TransportListIOCsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/threatIntel/transport/TransportListIOCsAction.java @@ -22,7 +22,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentType; -import org.opensearch.commons.alerting.model.Table; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; @@ -33,6 +32,9 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.search.SearchHit; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.terms.Terms; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.SortBuilder; @@ -44,12 +46,8 @@ import org.opensearch.securityanalytics.model.DetailedSTIX2IOCDto; import org.opensearch.securityanalytics.model.STIX2IOC; import org.opensearch.securityanalytics.model.STIX2IOCDto; -import org.opensearch.securityanalytics.model.threatintel.IocFinding; -import org.opensearch.securityanalytics.model.threatintel.IocWithFeeds; import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; -import org.opensearch.securityanalytics.threatIntel.action.GetIocFindingsAction; -import org.opensearch.securityanalytics.threatIntel.action.GetIocFindingsRequest; -import org.opensearch.securityanalytics.threatIntel.action.GetIocFindingsResponse; +import org.opensearch.securityanalytics.threatIntel.iocscan.dao.IocFindingService; import org.opensearch.securityanalytics.threatIntel.model.DefaultIocStoreConfig; import org.opensearch.securityanalytics.threatIntel.model.SATIFSourceConfig; import org.opensearch.securityanalytics.threatIntel.service.DefaultTifSourceConfigLoaderService; @@ -87,6 +85,10 @@ public class TransportListIOCsAction extends HandledTransportAction iocIndices) { .version(true) .seqNoAndPrimaryTerm(true) .fetchSource(true) + .trackTotalHits(true) .query(boolQueryBuilder) .sort(sortBuilder) .size(request.getTable().getSize()) @@ -233,80 +237,7 @@ public void onResponse(SearchResponse searchResponse) { onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT)); } - // Concurrently compiling a separate list of IOC IDs to create the subsequent GetIocFindingsRequest - Set iocIds = new HashSet<>(); - List iocs = new ArrayList<>(); - Arrays.stream(searchResponse.getHits().getHits()) - .forEach(hit -> { - try { - XContentParser xcp = XContentType.JSON.xContent().createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, - hit.getSourceAsString()); - xcp.nextToken(); - - STIX2IOCDto ioc = STIX2IOCDto.parse(xcp, hit.getId(), hit.getVersion()); - - iocIds.add(ioc.getId()); - iocs.add(ioc); - } catch (Exception e) { - log.error( - () -> new ParameterizedMessage("Failed to parse IOC doc from hit {}", hit.getId()), e - ); - } - }); - - GetIocFindingsRequest getFindingsRequest = new GetIocFindingsRequest( - Collections.emptyList(), - new ArrayList<>(iocIds), - null, - null, - new Table( - "asc", - "timestamp", - request.getTable().getMissing(), - 10000, - 0, - "" - ) - ); - - // Calling GetIocFindings API to get number of findings for each returned IOC - client.execute(GetIocFindingsAction.INSTANCE, getFindingsRequest, new ActionListener<>() { - @Override - public void onResponse(GetIocFindingsResponse getFindingsResponse) { - // Iterate through the GetIocFindingsResponse to count occurrences of each IOC - Map iocIdToNumFindings = new HashMap<>(); - for (IocFinding iocFinding : getFindingsResponse.getIocFindings()) { - for (IocWithFeeds iocWithFeeds : iocFinding.getFeedIds()) { - // Set the count to 0 if it's not already - iocIdToNumFindings.putIfAbsent(iocWithFeeds.getIocId(), 0); - // Increment the count for the IOC - iocIdToNumFindings.merge(iocWithFeeds.getIocId(), 1, Integer::sum); - } - } - - // Iterate through each IOC returned by the SearchRequest to create the detailed model for response - List iocDetails = new ArrayList<>(); - iocs.forEach((ioc) -> { - Integer numFindings = iocIdToNumFindings.get(ioc.getId()); - if (numFindings == null) { - // Logging instances of 'null' separately from 0 instances for investigation purposes - log.debug("Null number of findings found for IOC {}", ioc.getId()); - numFindings = 0; - } - iocDetails.add(new DetailedSTIX2IOCDto(ioc, numFindings)); - }); - - onOperation(new ListIOCsActionResponse(searchResponse.getHits().getTotalHits().value, iocDetails)); - } - - @Override - public void onFailure(Exception e) { - log.error("Failed to get IOC findings count:", e); - listener.onFailure(SecurityAnalyticsException.wrap(e)); - } - }); + getFindingsCount(searchResponse); } @Override @@ -322,6 +253,79 @@ public void onFailure(Exception e) { }); } + private void getFindingsCount(SearchResponse iocSearchResponse) { + // Concurrently compiling a separate list of IOC IDs to create the subsequent findings count searchRequest + Set iocIds = new HashSet<>(); + List iocs = new ArrayList<>(); + Arrays.stream(iocSearchResponse.getHits().getHits()) + .forEach(hit -> { + try { + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + hit.getSourceAsString()); + xcp.nextToken(); + + STIX2IOCDto ioc = STIX2IOCDto.parse(xcp, hit.getId(), hit.getVersion()); + + iocIds.add(ioc.getId()); + iocs.add(ioc); + } catch (Exception e) { + log.error( + () -> new ParameterizedMessage("Failed to parse IOC doc from hit {}", hit.getId()), e + ); + } + }); + + // Create an aggregation query that will group by the IOC IDs in the findings + SearchSourceBuilder findingsCountSourceBuilder = new SearchSourceBuilder() + .fetchSource(false) + .trackTotalHits(true) + .query(QueryBuilders.termsQuery(IOC_ID_KEYWORD_FIELD, iocIds)) + .size(0) + .aggregation( + AggregationBuilders + .terms(IOC_COUNT_AGG_NAME) + .field(IOC_ID_KEYWORD_FIELD) + .size(iocIds.size()) + ); + + iocFindingService.search(findingsCountSourceBuilder, new ActionListener<>() { + @Override + public void onResponse(SearchResponse findingsSearchResponse) { + Map iocIdToNumFindings = new HashMap<>(); + + // Retrieve and store the counts from the aggregation response + Aggregations aggregations = findingsSearchResponse.getAggregations(); + if (aggregations != null) { + Terms iocIdCount = aggregations.get(IOC_COUNT_AGG_NAME); + if (iocIdCount != null) { + for (Terms.Bucket bucket : iocIdCount.getBuckets()) { + String iocId = bucket.getKeyAsString(); + long findingCount = bucket.getDocCount(); + iocIdToNumFindings.put(iocId, (int) findingCount); + } + } + } + + // Iterate through each IOC returned by the SearchRequest to create the detailed model for response + List iocDetails = new ArrayList<>(); + iocs.forEach((ioc) -> { + Integer numFindings = iocIdToNumFindings.getOrDefault(ioc.getId(), 0); + iocDetails.add(new DetailedSTIX2IOCDto(ioc, numFindings)); + }); + + // Return API response + onOperation(new ListIOCsActionResponse(iocSearchResponse.getHits().getTotalHits().value, iocDetails)); + } + + @Override + public void onFailure(Exception e) { + log.error("Failed to get IOC findings count:", e); + listener.onFailure(SecurityAnalyticsException.wrap(e)); + } + }); + } private void onOperation(ListIOCsActionResponse response) { this.response.set(response); if (counter.compareAndSet(false, true)) { diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/ListIOCsRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/ListIOCsRestApiIT.java index 240fe962b..bc86e11a1 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/ListIOCsRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/ListIOCsRestApiIT.java @@ -8,26 +8,130 @@ import org.junit.Assert; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; import org.opensearch.securityanalytics.SecurityAnalyticsRestTestCase; import org.opensearch.securityanalytics.TestHelpers; +import org.opensearch.securityanalytics.model.DetailedSTIX2IOCDto; +import org.opensearch.securityanalytics.model.threatintel.IocFinding; +import org.opensearch.securityanalytics.model.threatintel.IocWithFeeds; import org.opensearch.securityanalytics.threatIntel.action.ListIOCsActionResponse; import org.opensearch.securityanalytics.commons.model.IOCType; import org.opensearch.securityanalytics.model.STIX2IOC; import org.opensearch.securityanalytics.model.STIX2IOCDto; import org.opensearch.securityanalytics.threatIntel.common.SourceConfigType; +import org.opensearch.securityanalytics.threatIntel.iocscan.dao.IocFindingService; import org.opensearch.securityanalytics.threatIntel.model.IocUploadSource; import org.opensearch.securityanalytics.threatIntel.model.SATIFSourceConfigDto; import org.opensearch.securityanalytics.util.STIX2IOCGenerator; import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; public class ListIOCsRestApiIT extends SecurityAnalyticsRestTestCase { + public void testListIOCsWithNoFindingsIndex() throws IOException { + // Delete findings system indexes if they exist + try { + makeRequest(client(), "DELETE", IocFindingService.IOC_FINDING_INDEX_PATTERN_REGEXP, Collections.emptyMap(), null); + } catch (IndexNotFoundException indexNotFoundException) { + logger.info("No threat intel findings indexes to delete."); + } catch (Exception e) { + logger.error(e.getMessage()); + } + + // Create IOCs + String searchString = "test-list-iocs-no-findings-index"; + Map iocs = new HashMap<>(); + for (int i = 0; i < 100; i++) { + String iocId = searchString + "-" + i; + iocs.put( + iocId, + new STIX2IOCDto( + iocId, + iocId + "-name", + new IOCType(IOCType.IPV4_TYPE), + "ipv4value" + i, + "severity", + null, + null, + "description", + List.of("labels"), + "specversion", + "feedId", + "feedName", + 1L + ) + ); + } + + // Creating source config + SATIFSourceConfigDto saTifSourceConfigDto = new SATIFSourceConfigDto( + null, + null, + "test_list_ioc_" + searchString, + "STIX", + SourceConfigType.IOC_UPLOAD, + null, + null, + null, + new IocUploadSource(null, new ArrayList<>(iocs.values())), + null, + null, + null, + null, + null, + null, + null, + false, + List.of(IOCType.IPV4_TYPE), + true + ); + + // Create the IOC system indexes using IOC_UPLOAD config + Response response = makeRequest(client(), "POST", SecurityAnalyticsPlugin.THREAT_INTEL_SOURCE_URI, Collections.emptyMap(), toHttpEntity(saTifSourceConfigDto)); + Assert.assertEquals(RestStatus.CREATED, restStatus(response)); + + // Call ListIOCs API + Map params = Map.of( + "searchString", searchString, + "size", "10000" + ); + Response iocResponse = makeRequest(client(), "GET", STIX2IOCGenerator.getListIOCsURI(), params, null); + Assert.assertEquals(RestStatus.OK, restStatus(iocResponse)); + Map respMap = asMap(iocResponse); + + // Evaluate response + int totalHits = (int) respMap.get(ListIOCsActionResponse.TOTAL_HITS_FIELD); + assertEquals(iocs.size(), totalHits); + + List> iocHits = (List>) respMap.get(ListIOCsActionResponse.HITS_FIELD); + assertEquals(iocs.size(), iocHits.size()); + + iocHits.forEach((hit) -> { + String iocId = (String) hit.get(STIX2IOC.ID_FIELD); + String iocName = (String) hit.get(STIX2IOC.NAME_FIELD); + String iocValue = (String) hit.get(STIX2IOC.VALUE_FIELD); + + STIX2IOCDto iocDto = iocs.get(iocId); + assertNotNull(iocDto); + + assertEquals(iocDto.getId(), iocId); + assertEquals(iocDto.getName(), iocName); + assertEquals(iocDto.getValue(), iocValue); + + int findingsNum = (int) hit.get(DetailedSTIX2IOCDto.NUM_FINDINGS_FIELD); + int expectedNumFindings = 0; + assertEquals(expectedNumFindings, findingsNum); + }); + } + public void testListIOCsBySearchString() throws IOException { String searchString = "test-search-string"; List iocs = List.of( @@ -126,4 +230,121 @@ public void testListIOCsBySearchString() throws IOException { } // TODO: Implement additional tests using various query param combinations + + public void testListIOCsNumFindings() throws Exception { + // Create IOCs + String searchString = "test-list-iocs-num-findings"; + List iocs = new ArrayList<>(); + Map> iocIdFindingsNum = new HashMap<>(); + for (int i = 0; i < 5; i++) { + String iocId = searchString + "-" + i; + iocs.add( + new STIX2IOCDto( + iocId, + iocId + "-name", + new IOCType(IOCType.IPV4_TYPE), + "ipv4value", + "severity", + null, + null, + "description", + List.of("labels"), + "specversion", + "feedId", + "feedName", + 1L + ) + ); + + // Confirming the ListIOCs API can return a findings count greater than 10,000 by giving the first IOC 10,005 findings + int numFindings = i == 0 ? 10005 : randomInt(10); + List iocFindings = generateIOCMatches(numFindings, iocId); + + // Tracking the number of findings expected for each IOC + iocIdFindingsNum.put(iocId, iocFindings); + } + + // Creating source config + SATIFSourceConfigDto saTifSourceConfigDto = new SATIFSourceConfigDto( + null, + null, + "test_list_ioc_" + searchString, + "STIX", + SourceConfigType.IOC_UPLOAD, + null, + null, + null, + new IocUploadSource(null, iocs), + null, + null, + null, + null, + null, + null, + null, + false, + List.of(IOCType.IPV4_TYPE), + true + ); + + // Create the IOC system indexes using IOC_UPLOAD config + Response response = makeRequest(client(), "POST", SecurityAnalyticsPlugin.THREAT_INTEL_SOURCE_URI, Collections.emptyMap(), toHttpEntity(saTifSourceConfigDto)); + Assert.assertEquals(RestStatus.CREATED, restStatus(response)); + + // Generate IOC matches + for (Map.Entry> entry : iocIdFindingsNum.entrySet()) { + ingestIOCMatches(entry.getValue()); + } + + // Call ListIOCs API + Response iocResponse = makeRequest(client(), "GET", STIX2IOCGenerator.getListIOCsURI(), Map.of("searchString", searchString), null); + Assert.assertEquals(RestStatus.OK, restStatus(iocResponse)); + Map respMap = asMap(iocResponse); + + // Evaluate response + int totalHits = (int) respMap.get(ListIOCsActionResponse.TOTAL_HITS_FIELD); + assertEquals(iocs.size(), totalHits); + + List> iocHits = (List>) respMap.get(ListIOCsActionResponse.HITS_FIELD); + assertEquals(iocs.size(), iocHits.size()); + + iocHits.forEach((hit) -> { + String iocId = (String) hit.get(STIX2IOC.ID_FIELD); + int findingsNum = (int) hit.get(DetailedSTIX2IOCDto.NUM_FINDINGS_FIELD); + int expectedNumFindings = iocIdFindingsNum.get(iocId).size(); + assertEquals(expectedNumFindings, findingsNum); + }); + } + + private List generateIOCMatches(int numMatches, String iocId) { + List iocFindings = new ArrayList<>(); + String monitorId = randomAlphaOfLength(10); + String monitorName = randomAlphaOfLength(10); + for (int i = 0; i < numMatches; i++) { + iocFindings.add(new IocFinding( + randomAlphaOfLength(10), + randomList(1, 10, () -> randomAlphaOfLength(10)),//docIds + randomList(1, 10, () -> new IocWithFeeds( + iocId, + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomAlphaOfLength(10)) + ), //feedIds + monitorId, + monitorName, + randomAlphaOfLength(10), + IOCType.IPV4_TYPE, + Instant.now(), + randomAlphaOfLength(10) + )); + } + return iocFindings; + } + + private void ingestIOCMatches(List iocFindings) throws IOException { + for (IocFinding iocFinding: iocFindings) { + makeRequest(client(), "POST", IocFindingService.IOC_FINDING_ALIAS_NAME + "/_doc?refresh", Map.of(), + toHttpEntity(iocFinding)); + } + } }