Commit

[BUG_FIX] fix check for agg rules in detector trigger condition to create chained findings monitor (#992)

* remove check for agg rules in detector trigger condition to create bucket level monitor

Signed-off-by: Surya Sashank Nistala <[email protected]>

* add agg rule tags to the chained monitor query to match the detector's trigger condition

Signed-off-by: Surya Sashank Nistala <[email protected]>

* fix check to evaluate whether agg rules are present when creating chained findings monitor

Signed-off-by: Surya Sashank Nistala <[email protected]>

* fix tests that previously checked for the group-by trigger

Signed-off-by: Surya Sashank Nistala <[email protected]>

* fix race condition while creating first monitor

Signed-off-by: Surya Sashank Nistala <[email protected]>

* add test to verify detector trigger function for aggregation rules

Signed-off-by: Surya Sashank Nistala <[email protected]>

* revert step listener change

Signed-off-by: Surya Sashank Nistala <[email protected]>

---------

Signed-off-by: Surya Sashank Nistala <[email protected]>
eirsep committed Apr 28, 2024
1 parent 7c6b79d commit 0feb950
Showing 5 changed files with 181 additions and 42 deletions.
@@ -61,8 +61,8 @@
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.RangeQueryBuilder;
@@ -97,7 +97,6 @@
import org.opensearch.securityanalytics.rules.exceptions.SigmaError;
import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings;
import org.opensearch.securityanalytics.util.DetectorIndices;
import org.opensearch.securityanalytics.util.DetectorUtils;
import org.opensearch.securityanalytics.util.ExceptionChecker;
import org.opensearch.securityanalytics.util.IndexUtils;
import org.opensearch.securityanalytics.util.MonitorService;
@@ -117,6 +116,8 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
@@ -705,19 +706,29 @@ private void handleUpsertWorkflowFailure(final Exception e, final ActionListener
*/
private IndexMonitorRequest createDocLevelMonitorMatchAllRequest(
Detector detector,
WriteRequest.RefreshPolicy refreshPolicy,
RefreshPolicy refreshPolicy,
String monitorId,
RestRequest.Method restMethod
) {
Method restMethod,
List<Pair<String, Rule>> queries) {
List<DocLevelMonitorInput> docLevelMonitorInputs = new ArrayList<>();
List<DocLevelQuery> docLevelQueries = new ArrayList<>();
String monitorName = detector.getName() + "_chained_findings";
String actualQuery = "_id:*";
Set<String> tags = new HashSet<>();
for (Pair<String, Rule> query: queries) {
if(query.getRight().isAggregationRule()) {
Rule rule = query.getRight();
tags.add(rule.getLevel());
tags.add(rule.getCategory());
tags.addAll(rule.getTags().stream().map(Value::getValue).collect(Collectors.toList()));
}
}
tags.removeIf(Objects::isNull);
DocLevelQuery docLevelQuery = new DocLevelQuery(
monitorName,
monitorName + "doc",
actualQuery,
Collections.emptyList()
new ArrayList<>(tags)
);
docLevelQueries.add(docLevelQuery);

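The hunk above is the core of the fix: the chained-findings doc-level monitor still matches every document via "_id:*", but its query is now tagged with the level, category, and tags of the detector's aggregation rules, so a detector trigger that filters on those tags also fires on findings produced by the chained monitor. A minimal, self-contained sketch of that tag-collection step, using hypothetical Rule/RulePair stand-ins rather than the plugin's real classes:

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

class ChainedFindingsTagSketch {
    // Hypothetical stand-ins for the plugin's Rule and Pair<String, Rule> types.
    record Rule(boolean aggregationRule, String level, String category, List<String> tags) {}
    record RulePair(String ruleId, Rule rule) {}

    // Collect level, category, and tags from aggregation rules only,
    // de-duplicate them, and drop nulls - mirroring the new query construction.
    static List<String> collectTags(List<RulePair> queries) {
        Set<String> tags = new HashSet<>();
        for (RulePair pair : queries) {
            if (pair.rule().aggregationRule()) {
                tags.add(pair.rule().level());
                tags.add(pair.rule().category());
                tags.addAll(pair.rule().tags());
            }
        }
        tags.removeIf(Objects::isNull);
        return new ArrayList<>(tags);
    }

    public static void main(String[] args) {
        List<RulePair> queries = List.of(
                new RulePair("sumRule", new Rule(true, "high", "windows", List.of("attack.t1003"))),
                new RulePair("docRule", new Rule(false, "low", "windows", List.of("attack.t1059"))));
        // Only the aggregation rule contributes: [high, windows, attack.t1003] (order may vary).
        System.out.println(collectTags(queries));
    }
}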
@@ -104,17 +104,12 @@ public void onFailure(Exception e) {
});
}

public static List<String> getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(
Detector detector,
List<Pair<String, Rule>> rulesById,
public static List<String> getBucketLevelMonitorIds(
List<IndexMonitorResponse> monitorResponses
) {
List<String> aggRuleIdsConfiguredToTrigger = getAggRuleIdsConfiguredToTrigger(detector, rulesById);
return monitorResponses.stream().filter(
// In the case of bucket level monitors rule id is trigger id
it -> Monitor.MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()
&& !it.getMonitor().getTriggers().isEmpty()
&& aggRuleIdsConfiguredToTrigger.contains(it.getMonitor().getTriggers().get(0).getId())
).map(IndexMonitorResponse::getId).collect(Collectors.toList());
}
public static List<String> getAggRuleIdsConfiguredToTrigger(Detector detector, List<Pair<String, Rule>> rulesById) {
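Because the rendered diff interleaves removed and added lines, the net effect of this hunk is easy to miss: the helper is renamed to getBucketLevelMonitorIds and no longer filters on rules configured in a trigger, so every bucket-level monitor is chained into the chained-findings monitor. A sketch of what the simplified helper plausibly reduces to, with hypothetical stand-in types instead of IndexMonitorResponse:

import java.util.List;
import java.util.stream.Collectors;

class BucketLevelMonitorIdSketch {
    // Hypothetical stand-ins for the alerting monitor type and IndexMonitorResponse.
    enum MonitorType { DOC_LEVEL_MONITOR, BUCKET_LEVEL_MONITOR }
    record MonitorResponse(String id, MonitorType type) {}

    // After the change, every bucket-level monitor id is chained;
    // no per-trigger rule filtering remains.
    static List<String> getBucketLevelMonitorIds(List<MonitorResponse> responses) {
        return responses.stream()
                .filter(r -> r.type() == MonitorType.BUCKET_LEVEL_MONITOR)
                .map(MonitorResponse::id)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<MonitorResponse> responses = List.of(
                new MonitorResponse("m1", MonitorType.BUCKET_LEVEL_MONITOR),
                new MonitorResponse("m2", MonitorType.DOC_LEVEL_MONITOR),
                new MonitorResponse("m3", MonitorType.BUCKET_LEVEL_MONITOR));
        System.out.println(getBucketLevelMonitorIds(responses)); // [m1, m3]
    }
}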
@@ -21,7 +21,6 @@
import org.opensearch.commons.alerting.model.ChainedMonitorFindings;
import org.opensearch.commons.alerting.model.CompositeInput;
import org.opensearch.commons.alerting.model.Delegate;
import org.opensearch.commons.alerting.model.Monitor.MonitorType;
import org.opensearch.commons.alerting.model.Sequence;
import org.opensearch.commons.alerting.model.Workflow;
import org.opensearch.commons.alerting.model.Workflow.WorkflowType;
@@ -34,12 +33,11 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger;
import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIds;

/**
* Alerting common clas used for workflow manipulation
@@ -101,7 +99,7 @@ public void upsertWorkflow(
monitorResponses.addAll(updatedMonitorResponses);
}
cmfMonitorId = addedMonitorResponses.stream().filter(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName())).findFirst().get().getId();
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(detector, rulesById, monitorResponses));
chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIds(monitorResponses));
}

IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds,
@@ -149,16 +147,21 @@ public void deleteWorkflow(String workflowId, ActionListener<DeleteWorkflowRespo
private IndexWorkflowRequest createWorkflowRequest(List<String> monitorIds, Detector detector, RefreshPolicy refreshPolicy, String workflowId, Method method,
ChainedMonitorFindings chainedMonitorFindings, String cmfMonitorId) {
AtomicInteger index = new AtomicInteger();
List<Delegate> delegates = monitorIds.stream().map(
monitorId -> {
ChainedMonitorFindings cmf = null;
if (cmfMonitorId != null && chainedMonitorFindings != null && Objects.equals(monitorId, cmfMonitorId)) {
cmf = Objects.equals(monitorId, cmfMonitorId) ? chainedMonitorFindings : null;
}
Delegate delegate = new Delegate(index.incrementAndGet(), monitorId, cmf);
return delegate;
}
).collect(Collectors.toList());
List<Delegate> delegates = new ArrayList<>();
ChainedMonitorFindings cmf = null;
for (String monitorId : monitorIds) {
if (cmfMonitorId != null && chainedMonitorFindings != null && Objects.equals(monitorId, cmfMonitorId)) {
cmf = Objects.equals(monitorId, cmfMonitorId) ? chainedMonitorFindings : null;
} else {
Delegate delegate = new Delegate(index.incrementAndGet(), monitorId, null);
delegates.add(delegate);
}
}
if (cmf != null) {
// Add cmf with maximum value on "index"
Delegate cmfDelegate = new Delegate(index.incrementAndGet(), cmfMonitorId, cmf);
delegates.add(cmfDelegate);
}

Sequence sequence = new Sequence(delegates);
CompositeInput compositeInput = new CompositeInput(sequence);
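The rewritten loop above also changes delegate ordering: the chained-findings monitor is skipped in the first pass and appended last, so it runs only after the monitors whose findings it consumes. A small, self-contained sketch of that ordering logic, using a hypothetical Delegate record instead of the alerting plugin's class:

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;

class DelegateOrderingSketch {
    // Hypothetical stand-in for the alerting Delegate (order, monitorId, chained-findings flag).
    record Delegate(int order, String monitorId, boolean hasChainedFindings) {}

    // Build workflow delegates so the chained-findings monitor always runs last.
    static List<Delegate> buildDelegates(List<String> monitorIds, String cmfMonitorId) {
        AtomicInteger index = new AtomicInteger();
        List<Delegate> delegates = new ArrayList<>();
        for (String monitorId : monitorIds) {
            if (!Objects.equals(monitorId, cmfMonitorId)) {
                delegates.add(new Delegate(index.incrementAndGet(), monitorId, false));
            }
        }
        if (cmfMonitorId != null && monitorIds.contains(cmfMonitorId)) {
            // The chained-findings monitor gets the highest order value, i.e. executes last.
            delegates.add(new Delegate(index.incrementAndGet(), cmfMonitorId, true));
        }
        return delegates;
    }

    public static void main(String[] args) {
        // The chained-findings monitor "cmf" ends up last even though it is listed first.
        System.out.println(buildDelegates(List.of("cmf", "bucket1", "bucket2"), "cmf"));
    }
}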
148 changes: 139 additions & 9 deletions src/test/java/org/opensearch/securityanalytics/alerts/AlertsIT.java
@@ -5,22 +5,14 @@

package org.opensearch.securityanalytics.alerts;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.http.HttpStatus;
import org.apache.http.entity.StringEntity;
import org.apache.http.message.BasicHeader;
import org.junit.Assert;
import org.junit.Ignore;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.commons.alerting.model.Monitor;
import org.opensearch.commons.alerting.model.action.Action;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.search.SearchHit;
@@ -33,8 +25,20 @@
import org.opensearch.securityanalytics.model.DetectorRule;
import org.opensearch.securityanalytics.model.DetectorTrigger;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static org.opensearch.securityanalytics.TestHelpers.netFlowMappings;
import static org.opensearch.securityanalytics.TestHelpers.randomAction;
import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers;
import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithTriggers;
@@ -657,6 +661,132 @@ public void testAlertHistoryRollover_maxAge() throws IOException, InterruptedExc

restoreAlertsFindingsIMSettings();
}
/**
* 1. Creates detector with aggregation and prepackaged rules
* (sum rule - should match docIds: 1, 2, 3; maxRule - 4, 5, 6, 7; minRule - 7)
* 2. Verifies monitor execution
* 3. Verifies alerts
*
* @throws IOException
*/
public void testMultipleAggregationAndDocRules_alertSuccess() throws IOException {
String index = createTestIndex(randomIndex(), windowsIndexMapping());

Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI);
createMappingRequest.setJsonEntity(
"{ \"index_name\":\"" + index + "\"," +
" \"rule_topic\":\"" + randomDetectorType() + "\", " +
" \"partial\":true" +
"}"
);

Response createMappingResponse = client().performRequest(createMappingRequest);

assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode());

String infoOpCode = "Info";

String sumRuleId = createRule(randomAggregationRule("sum", " > 1", infoOpCode));


List<DetectorRule> detectorRules = List.of(new DetectorRule(sumRuleId));

DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules,
Collections.emptyList());
Detector detector = randomDetectorWithInputsAndTriggers(List.of(input),
List.of(new DetectorTrigger("randomtrigegr", "test-trigger", "1", List.of(randomDetectorType()), List.of(), List.of(), List.of(), List.of(), List.of()))
);

Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector));


String request = "{\n" +
" \"query\" : {\n" +
" \"match_all\":{\n" +
" }\n" +
" }\n" +
"}";
SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true);

assertEquals(1, response.getHits().getTotalHits().value); // 5 for rules, 1 for match_all query in chained findings monitor

assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));
Map<String, Object> responseBody = asMap(createResponse);
String detectorId = responseBody.get("_id").toString();
request = "{\n" +
" \"query\" : {\n" +
" \"match\":{\n" +
" \"_id\": \"" + detectorId + "\"\n" +
" }\n" +
" }\n" +
"}";
List<SearchHit> hits = executeSearch(Detector.DETECTORS_INDEX, request);
SearchHit hit = hits.get(0);
Map<String, List> updatedDetectorMap = (HashMap<String, List>) (hit.getSourceAsMap().get("detector"));

List<String> monitorIds = ((List<String>) (updatedDetectorMap).get("monitor_id"));

indexDoc(index, "1", randomDoc(2, 4, infoOpCode));
indexDoc(index, "2", randomDoc(3, 4, infoOpCode));

Map<String, Integer> numberOfMonitorTypes = new HashMap<>();

for (String monitorId : monitorIds) {
Map<String, String> monitor = (Map<String, String>) (entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor");
numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum);
Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap());

// Assert monitor executions
Map<String, Object> executeResults = entityAsMap(executeResponse);
if (Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue().equals(monitor.get("monitor_type")) && false == monitor.get("name").equals(detector.getName() + "_chained_findings")) {
int noOfSigmaRuleMatches = ((List<Map<String, Object>>) ((Map<String, Object>) executeResults.get("input_results")).get("results")).get(0).size();
assertEquals(5, noOfSigmaRuleMatches);
}
}

assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue());
assertEquals(1, numberOfMonitorTypes.get(Monitor.MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue());

Map<String, String> params = new HashMap<>();
params.put("detector_id", detectorId);
Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null);
Map<String, Object> getFindingsBody = entityAsMap(getFindingsResponse);

assertNotNull(getFindingsBody);
assertEquals(1, getFindingsBody.get("total_findings"));

String findingDetectorId = ((Map<String, Object>) ((List) getFindingsBody.get("findings")).get(0)).get("detectorId").toString();
assertEquals(detectorId, findingDetectorId);

String findingIndex = ((Map<String, Object>) ((List) getFindingsBody.get("findings")).get(0)).get("index").toString();
assertEquals(index, findingIndex);

List<String> docLevelFinding = new ArrayList<>();
List<Map<String, Object>> findings = (List) getFindingsBody.get("findings");


for (Map<String, Object> finding : findings) {
List<Map<String, Object>> queries = (List<Map<String, Object>>) finding.get("queries");
Set<String> findingRuleIds = queries.stream().map(it -> it.get("id").toString()).collect(Collectors.toSet());

// In the case of bucket level monitors, queries will always contain one value
String aggRuleId = findingRuleIds.iterator().next();
List<String> findingDocs = (List<String>) finding.get("related_doc_ids");

if (aggRuleId.equals(sumRuleId)) {
assertTrue(List.of("1", "2", "3", "4", "5", "6", "7").containsAll(findingDocs));
}
}

assertTrue(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8").containsAll(docLevelFinding));

Map<String, String> params1 = new HashMap<>();
params1.put("detector_id", detectorId);
Response getAlertsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.ALERTS_BASE_URI, params1, null);
Map<String, Object> getAlertsBody = asMap(getAlertsResponse);
// TODO enable asserts here when able
Assert.assertEquals(3, getAlertsBody.get("total_alerts")); // 2 doc level alerts for each doc, 1 bucket level alert
}

public void testAlertHistoryRollover_maxAge_low_retention() throws IOException, InterruptedException {
updateClusterSetting(ALERT_HISTORY_ROLLOVER_PERIOD.getKey(), "1s");
@@ -974,7 +974,7 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti
"}";
SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true);

assertEquals(6, response.getHits().getTotalHits().value);
assertEquals(7, response.getHits().getTotalHits().value); // 6 for rules, 1 for match_all query in chained findings monitor

assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));
Map<String, Object> responseBody = asMap(createResponse);
@@ -994,8 +994,7 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti
assertEquals(6, ((Map<String, Map<String, List>>) inputArr.get(0)).get("detector_input").get("custom_rules").size());

List<String> monitorIds = ((List<String>) (updatedDetectorMap).get("monitor_id"));

assertEquals(6, monitorIds.size());
assertEquals(7, monitorIds.size());

indexDoc(index, "1", randomDoc(2, 4, infoOpCode));
indexDoc(index, "2", randomDoc(3, 4, infoOpCode));
@@ -1037,7 +1036,7 @@ else if (ruleId == minRuleId) {
}

assertEquals(5, numberOfMonitorTypes.get(MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue());
assertEquals(1, numberOfMonitorTypes.get(MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue());
assertEquals(2, numberOfMonitorTypes.get(MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue());

Map<String, String> params = new HashMap<>();
params.put("detector_id", detectorId);
@@ -1122,7 +1121,7 @@ public void testCreateDetector_verifyWorkflowCreation_success_WithoutGroupByRule
"}";
SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true);

assertEquals(1, response.getHits().getTotalHits().value);
assertEquals(2, response.getHits().getTotalHits().value);

assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse));
Map<String, Object> responseBody = asMap(createResponse);
@@ -1143,13 +1142,13 @@ public void testCreateDetector_verifyWorkflowCreation_success_WithoutGroupByRule
assertEquals(2, ((Map<String, Map<String, List>>) inputArr.get(0)).get("detector_input").get("custom_rules").size());

List<String> monitorIds = ((List<String>) (detectorMap).get("monitor_id"));
assertEquals(2, monitorIds.size());
assertEquals(3, monitorIds.size());

assertNotNull("Workflow not created", detectorMap.get("workflow_ids"));
assertEquals("Number of workflows not correct", 1, ((List<String>) detectorMap.get("workflow_ids")).size());

// Verify workflow
verifyWorkflow(detectorMap, monitorIds, 2);
verifyWorkflow(detectorMap, monitorIds, 3);
}


@@ -1699,6 +1698,7 @@ public void testCreateDetector_verifyWorkflowExecutionMultipleBucketLevelDocLeve
indexDoc(index, "7", randomDoc(6, 2, testOpCode));
indexDoc(index, "8", randomDoc(1, 1, testOpCode));
// Verify workflow

verifyWorkflow(detectorMap, monitorIds, 7);

String workflowId = ((List<String>) detectorMap.get("workflow_ids")).get(0);