Skip to content

Commit

Permalink
[Bug] Fixed ListIOCs number of findings cap. (opensearch-project#1373)
Browse files Browse the repository at this point in the history
* Fixed finding number returned by ListIOCs API capping at 10,000.

Signed-off-by: AWSHurneyt <[email protected]>

* Added integ test for fix.

Signed-off-by: AWSHurneyt <[email protected]>

* Removed extraneous query params.

Signed-off-by: AWSHurneyt <[email protected]>

* Added additional test case.

Signed-off-by: AWSHurneyt <[email protected]>

---------

Signed-off-by: AWSHurneyt <[email protected]>
  • Loading branch information
AWSHurneyt authored Oct 23, 2024
1 parent 6f543b5 commit d5c8f7a
Show file tree
Hide file tree
Showing 2 changed files with 305 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.alerting.model.Table;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
Expand All @@ -33,6 +32,9 @@
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.Aggregations;
import org.opensearch.search.aggregations.bucket.terms.Terms;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.SortBuilder;
Expand All @@ -44,12 +46,8 @@
import org.opensearch.securityanalytics.model.DetailedSTIX2IOCDto;
import org.opensearch.securityanalytics.model.STIX2IOC;
import org.opensearch.securityanalytics.model.STIX2IOCDto;
import org.opensearch.securityanalytics.model.threatintel.IocFinding;
import org.opensearch.securityanalytics.model.threatintel.IocWithFeeds;
import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
import org.opensearch.securityanalytics.threatIntel.action.GetIocFindingsAction;
import org.opensearch.securityanalytics.threatIntel.action.GetIocFindingsRequest;
import org.opensearch.securityanalytics.threatIntel.action.GetIocFindingsResponse;
import org.opensearch.securityanalytics.threatIntel.iocscan.dao.IocFindingService;
import org.opensearch.securityanalytics.threatIntel.model.DefaultIocStoreConfig;
import org.opensearch.securityanalytics.threatIntel.model.SATIFSourceConfig;
import org.opensearch.securityanalytics.threatIntel.service.DefaultTifSourceConfigLoaderService;
Expand Down Expand Up @@ -87,6 +85,10 @@ public class TransportListIOCsAction extends HandledTransportAction<ListIOCsActi
private final SATIFSourceConfigService saTifSourceConfigService;
private final Settings settings;
private volatile Boolean filterByEnabled;
private final IocFindingService iocFindingService;

public static String IOC_COUNT_AGG_NAME = "ioc_id_count";
public static String IOC_ID_KEYWORD_FIELD = "ioc_feed_ids.ioc_id.keyword";

@Inject
public TransportListIOCsAction(
Expand All @@ -110,6 +112,7 @@ public TransportListIOCsAction(
this.threadPool = this.client.threadPool();
this.settings = settings;
this.filterByEnabled = SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES.get(this.settings);
this.iocFindingService = new IocFindingService(client, clusterService, xContentRegistry);
}

@Override
Expand Down Expand Up @@ -216,6 +219,7 @@ private void listIocs(List<String> iocIndices) {
.version(true)
.seqNoAndPrimaryTerm(true)
.fetchSource(true)
.trackTotalHits(true)
.query(boolQueryBuilder)
.sort(sortBuilder)
.size(request.getTable().getSize())
Expand All @@ -233,80 +237,7 @@ public void onResponse(SearchResponse searchResponse) {
onFailures(new OpenSearchStatusException("Search request timed out", RestStatus.REQUEST_TIMEOUT));
}

// Concurrently compiling a separate list of IOC IDs to create the subsequent GetIocFindingsRequest
Set<String> iocIds = new HashSet<>();
List<STIX2IOCDto> iocs = new ArrayList<>();
Arrays.stream(searchResponse.getHits().getHits())
.forEach(hit -> {
try {
XContentParser xcp = XContentType.JSON.xContent().createParser(
xContentRegistry,
LoggingDeprecationHandler.INSTANCE,
hit.getSourceAsString());
xcp.nextToken();

STIX2IOCDto ioc = STIX2IOCDto.parse(xcp, hit.getId(), hit.getVersion());

iocIds.add(ioc.getId());
iocs.add(ioc);
} catch (Exception e) {
log.error(
() -> new ParameterizedMessage("Failed to parse IOC doc from hit {}", hit.getId()), e
);
}
});

GetIocFindingsRequest getFindingsRequest = new GetIocFindingsRequest(
Collections.emptyList(),
new ArrayList<>(iocIds),
null,
null,
new Table(
"asc",
"timestamp",
request.getTable().getMissing(),
10000,
0,
""
)
);

// Calling GetIocFindings API to get number of findings for each returned IOC
client.execute(GetIocFindingsAction.INSTANCE, getFindingsRequest, new ActionListener<>() {
@Override
public void onResponse(GetIocFindingsResponse getFindingsResponse) {
// Iterate through the GetIocFindingsResponse to count occurrences of each IOC
Map<String, Integer> iocIdToNumFindings = new HashMap<>();
for (IocFinding iocFinding : getFindingsResponse.getIocFindings()) {
for (IocWithFeeds iocWithFeeds : iocFinding.getFeedIds()) {
// Set the count to 0 if it's not already
iocIdToNumFindings.putIfAbsent(iocWithFeeds.getIocId(), 0);
// Increment the count for the IOC
iocIdToNumFindings.merge(iocWithFeeds.getIocId(), 1, Integer::sum);
}
}

// Iterate through each IOC returned by the SearchRequest to create the detailed model for response
List<DetailedSTIX2IOCDto> iocDetails = new ArrayList<>();
iocs.forEach((ioc) -> {
Integer numFindings = iocIdToNumFindings.get(ioc.getId());
if (numFindings == null) {
// Logging instances of 'null' separately from 0 instances for investigation purposes
log.debug("Null number of findings found for IOC {}", ioc.getId());
numFindings = 0;
}
iocDetails.add(new DetailedSTIX2IOCDto(ioc, numFindings));
});

onOperation(new ListIOCsActionResponse(searchResponse.getHits().getTotalHits().value, iocDetails));
}

@Override
public void onFailure(Exception e) {
log.error("Failed to get IOC findings count:", e);
listener.onFailure(SecurityAnalyticsException.wrap(e));
}
});
getFindingsCount(searchResponse);
}

@Override
Expand All @@ -322,6 +253,79 @@ public void onFailure(Exception e) {
});
}

private void getFindingsCount(SearchResponse iocSearchResponse) {
// Concurrently compiling a separate list of IOC IDs to create the subsequent findings count searchRequest
Set<String> iocIds = new HashSet<>();
List<STIX2IOCDto> iocs = new ArrayList<>();
Arrays.stream(iocSearchResponse.getHits().getHits())
.forEach(hit -> {
try {
XContentParser xcp = XContentType.JSON.xContent().createParser(
xContentRegistry,
LoggingDeprecationHandler.INSTANCE,
hit.getSourceAsString());
xcp.nextToken();

STIX2IOCDto ioc = STIX2IOCDto.parse(xcp, hit.getId(), hit.getVersion());

iocIds.add(ioc.getId());
iocs.add(ioc);
} catch (Exception e) {
log.error(
() -> new ParameterizedMessage("Failed to parse IOC doc from hit {}", hit.getId()), e
);
}
});

// Create an aggregation query that will group by the IOC IDs in the findings
SearchSourceBuilder findingsCountSourceBuilder = new SearchSourceBuilder()
.fetchSource(false)
.trackTotalHits(true)
.query(QueryBuilders.termsQuery(IOC_ID_KEYWORD_FIELD, iocIds))
.size(0)
.aggregation(
AggregationBuilders
.terms(IOC_COUNT_AGG_NAME)
.field(IOC_ID_KEYWORD_FIELD)
.size(iocIds.size())
);

iocFindingService.search(findingsCountSourceBuilder, new ActionListener<>() {
@Override
public void onResponse(SearchResponse findingsSearchResponse) {
Map<String, Integer> iocIdToNumFindings = new HashMap<>();

// Retrieve and store the counts from the aggregation response
Aggregations aggregations = findingsSearchResponse.getAggregations();
if (aggregations != null) {
Terms iocIdCount = aggregations.get(IOC_COUNT_AGG_NAME);
if (iocIdCount != null) {
for (Terms.Bucket bucket : iocIdCount.getBuckets()) {
String iocId = bucket.getKeyAsString();
long findingCount = bucket.getDocCount();
iocIdToNumFindings.put(iocId, (int) findingCount);
}
}
}

// Iterate through each IOC returned by the SearchRequest to create the detailed model for response
List<DetailedSTIX2IOCDto> iocDetails = new ArrayList<>();
iocs.forEach((ioc) -> {
Integer numFindings = iocIdToNumFindings.getOrDefault(ioc.getId(), 0);
iocDetails.add(new DetailedSTIX2IOCDto(ioc, numFindings));
});

// Return API response
onOperation(new ListIOCsActionResponse(iocSearchResponse.getHits().getTotalHits().value, iocDetails));
}

@Override
public void onFailure(Exception e) {
log.error("Failed to get IOC findings count:", e);
listener.onFailure(SecurityAnalyticsException.wrap(e));
}
});
}
private void onOperation(ListIOCsActionResponse response) {
this.response.set(response);
if (counter.compareAndSet(false, true)) {
Expand Down
Loading

0 comments on commit d5c8f7a

Please sign in to comment.