Skip to content

Commit

Permalink
ENG-50887: Mask value using a masking config (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhant2001 authored Oct 10, 2024
1 parent e092fbf commit cd60ebb
Show file tree
Hide file tree
Showing 8 changed files with 650 additions and 22 deletions.
4 changes: 2 additions & 2 deletions query-service-impl/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ dependencies {
implementation("org.apache.calcite:calcite-babel:1.34.0") {
because("CVE-2022-39135")
}
implementation("org.apache.avro:avro:1.11.3") {
because("CVE-2023-39410")
implementation("org.apache.avro:avro:1.11.4") {
because("CVE-2024-47561")
}
implementation("org.apache.commons:commons-compress:1.26.0") {
because("CVE-2024-25710")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package org.hypertrace.core.query.service;

import com.typesafe.config.Config;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Value;
import lombok.experimental.NonFinal;
import lombok.extern.slf4j.Slf4j;

/**
 * Masking configuration scoped to a single request handler.
 *
 * <p>The handler config may contain a {@code tenantScopedMaskingConfig} list. Each entry names a
 * {@code tenantId} and a {@code timeRangeToMaskedAttributes} list whose elements carry {@code
 * startTimeMillis}, {@code endTimeMillis} and {@code maskedAttributes}. At query time the
 * attributes of every configured window that overlaps the query's time range are reported as
 * masked for that tenant.
 */
@Slf4j
public class HandlerScopedMaskingConfig {
  private static final String TENANT_SCOPED_MASKS_CONFIG_KEY = "tenantScopedMaskingConfig";

  private final Map<String, List<TimeRangeToMaskedAttributes>> tenantToTimeRangeMaskedAttributes;

  /**
   * Reads the tenant scoped masking entries from the handler config.
   *
   * <p>When the same tenant id appears in more than one entry, the time ranges of all its entries
   * are combined instead of failing on the duplicate key.
   *
   * @param config handler config; when it has no {@code tenantScopedMaskingConfig} path nothing is
   *     masked
   */
  public HandlerScopedMaskingConfig(Config config) {
    if (config.hasPath(TENANT_SCOPED_MASKS_CONFIG_KEY)) {
      this.tenantToTimeRangeMaskedAttributes =
          config.getConfigList(TENANT_SCOPED_MASKS_CONFIG_KEY).stream()
              .map(TenantMaskingConfig::new)
              .collect(
                  Collectors.toMap(
                      TenantMaskingConfig::getTenantId,
                      TenantMaskingConfig::getTimeRangeToMaskedAttributes,
                      // Without a merge function toMap throws IllegalStateException on a
                      // repeated tenant id, which would abort handler construction.
                      HandlerScopedMaskingConfig::concatenate));
    } else {
      this.tenantToTimeRangeMaskedAttributes = Collections.emptyMap();
    }
  }

  /**
   * Collects the attributes to mask for the tenant and time range of the given request.
   *
   * <p>When the request carries no time range, the widest possible range ({@link Instant#MIN} to
   * {@link Instant#MAX}) is assumed, so every configured window of the tenant applies.
   *
   * @param executionContext supplies the tenant id and the optional query time range
   * @return a new mutable {@link HashSet} with the masked attribute names; empty when the tenant
   *     has no masking configured or no configured window overlaps the query
   */
  public Set<String> getMaskedAttributes(ExecutionContext executionContext) {
    // Deliberately a HashSet: it accepts contains(null), unlike the immutable Set.of(...) sets.
    Set<String> maskedAttributes = new HashSet<>();
    List<TimeRangeToMaskedAttributes> tenantTimeRanges =
        tenantToTimeRangeMaskedAttributes.getOrDefault(
            executionContext.getTenantId(), Collections.emptyList());
    if (tenantTimeRanges.isEmpty()) {
      return maskedAttributes;
    }

    Optional<QueryTimeRange> queryTimeRange = executionContext.getQueryTimeRange();
    Instant queryStartTime = queryTimeRange.map(QueryTimeRange::getStartTime).orElse(Instant.MIN);
    Instant queryEndTime = queryTimeRange.map(QueryTimeRange::getEndTime).orElse(Instant.MAX);

    for (TimeRangeToMaskedAttributes timeRangeAndMasks : tenantTimeRanges) {
      if (isTimeRangeOverlap(timeRangeAndMasks, queryStartTime, queryEndTime)) {
        maskedAttributes.addAll(timeRangeAndMasks.maskedAttributes);
      }
    }
    return maskedAttributes;
  }

  /** Returns a new list holding the elements of {@code first} followed by those of {@code second}. */
  private static List<TimeRangeToMaskedAttributes> concatenate(
      List<TimeRangeToMaskedAttributes> first, List<TimeRangeToMaskedAttributes> second) {
    List<TimeRangeToMaskedAttributes> combined = new ArrayList<>(first);
    combined.addAll(second);
    return combined;
  }

  /**
   * Tells whether the configured window and the query range share at least one instant. Both
   * ranges are treated as closed intervals, so touching end points count as an overlap.
   */
  private static boolean isTimeRangeOverlap(
      TimeRangeToMaskedAttributes timeRangeAndMasks, Instant queryStartTime, Instant queryEndTime) {
    return !timeRangeAndMasks.startTimeMillis.isAfter(queryEndTime)
        && !timeRangeAndMasks.endTimeMillis.isBefore(queryStartTime);
  }

  /** One {@code tenantScopedMaskingConfig} entry: a tenant id plus its valid masking windows. */
  @Value
  @NonFinal
  static class TenantMaskingConfig {
    private static final String TENANT_ID_CONFIG_KEY = "tenantId";
    private static final String TIME_RANGE_AND_MASK_VALUES_CONFIG_KEY =
        "timeRangeToMaskedAttributes";
    String tenantId;
    List<TimeRangeToMaskedAttributes> timeRangeToMaskedAttributes;

    /**
     * Parses one tenant entry. Windows that fail {@link TimeRangeToMaskedAttributes#isValid()} are
     * logged and dropped rather than failing the whole configuration.
     */
    private TenantMaskingConfig(Config config) {
      // Keep the id in a local so the filter lambda does not read a field mid-construction.
      String configuredTenantId = config.getString(TENANT_ID_CONFIG_KEY);
      this.tenantId = configuredTenantId;
      this.timeRangeToMaskedAttributes =
          config.getConfigList(TIME_RANGE_AND_MASK_VALUES_CONFIG_KEY).stream()
              .map(TimeRangeToMaskedAttributes::new)
              .filter(timeRangeAndMasks -> isValidOrWarn(timeRangeAndMasks, configuredTenantId))
              .collect(Collectors.toList());
    }

    /** Returns the validity of the window, logging a warning for the tenant when it is invalid. */
    private static boolean isValidOrWarn(
        TimeRangeToMaskedAttributes timeRangeAndMasks, String tenantId) {
      if (timeRangeAndMasks.isValid()) {
        return true;
      }
      // A window whose start is not before its end is also reported through this message.
      log.warn(
          "Invalid masking configuration for tenant: {}. Either the time range is missing or the mask list is empty.",
          tenantId);
      return false;
    }
  }

  /**
   * A masking window: the attributes listed in {@code maskedAttributes} are masked for queries
   * overlapping the range from {@code startTimeMillis} to {@code endTimeMillis}.
   *
   * <p>Despite the field names, both bounds are {@link Instant}s built from the epoch millisecond
   * values found in the config.
   */
  @NonFinal
  static class TimeRangeToMaskedAttributes {
    private static final String START_TIME_CONFIG_PATH = "startTimeMillis";
    private static final String END_TIME_CONFIG_PATH = "endTimeMillis";
    private static final String MASK_ATTRIBUTES_CONFIG_PATH = "maskedAttributes";
    final Instant startTimeMillis;
    final Instant endTimeMillis;
    final ArrayList<String> maskedAttributes;

    /**
     * Parses one window. The bounds stay {@code null} and the mask list stays empty unless both
     * bounds are configured and the start is strictly before the end; such an instance reports
     * {@code false} from {@link #isValid()}.
     */
    private TimeRangeToMaskedAttributes(Config config) {
      Instant parsedStart = null;
      Instant parsedEnd = null;
      ArrayList<String> parsedMasks = new ArrayList<>();

      if (config.hasPath(START_TIME_CONFIG_PATH) && config.hasPath(END_TIME_CONFIG_PATH)) {
        Instant configuredStart = Instant.ofEpochMilli(config.getLong(START_TIME_CONFIG_PATH));
        Instant configuredEnd = Instant.ofEpochMilli(config.getLong(END_TIME_CONFIG_PATH));

        if (configuredStart.isBefore(configuredEnd)) {
          parsedStart = configuredStart;
          parsedEnd = configuredEnd;
          if (config.hasPath(MASK_ATTRIBUTES_CONFIG_PATH)) {
            parsedMasks = new ArrayList<>(config.getStringList(MASK_ATTRIBUTES_CONFIG_PATH));
          }
        }
      }

      this.startTimeMillis = parsedStart;
      this.endTimeMillis = parsedEnd;
      this.maskedAttributes = parsedMasks;
    }

    /** Returns {@code true} when both bounds are set and at least one attribute is listed. */
    boolean isValid() {
      return startTimeMillis != null && endTimeMillis != null && !maskedAttributes.isEmpty();
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.pinot.client.ResultSetGroup;
import org.hypertrace.core.query.service.ExecutionContext;
import org.hypertrace.core.query.service.HandlerScopedFiltersConfig;
import org.hypertrace.core.query.service.HandlerScopedMaskingConfig;
import org.hypertrace.core.query.service.QueryCost;
import org.hypertrace.core.query.service.RequestHandler;
import org.hypertrace.core.query.service.api.Expression;
Expand Down Expand Up @@ -58,6 +59,10 @@ public class PinotBasedRequestHandler implements RequestHandler {
private static final String START_TIME_ATTRIBUTE_NAME_CONFIG_KEY = "startTimeAttributeName";
private static final String SLOW_QUERY_THRESHOLD_MS_CONFIG = "slowQueryThresholdMs";

private static final String DEFAULT_MASKED_VALUE = "*";
// This is how empty list is represented in Pinot
private static final String ARRAY_TYPE_MASKED_VALUE = "[\"\"]";

private static final int DEFAULT_SLOW_QUERY_THRESHOLD_MS = 3000;
private static final Set<Operator> GTE_OPERATORS = Set.of(Operator.GE, Operator.GT, Operator.EQ);

Expand All @@ -67,6 +72,7 @@ public class PinotBasedRequestHandler implements RequestHandler {
private QueryRequestToPinotSQLConverter request2PinotSqlConverter;
private final PinotMapConverter pinotMapConverter;
private HandlerScopedFiltersConfig handlerScopedFiltersConfig;
private HandlerScopedMaskingConfig handlerScopedMaskingConfig;
// The implementations of ResultSet are package private and hence there's no way to determine the
// shape of the results
// other than to do string comparison on the simple class names. In order to be able to unit test
Expand Down Expand Up @@ -143,6 +149,7 @@ private void processConfig(Config config) {

this.handlerScopedFiltersConfig =
new HandlerScopedFiltersConfig(config, this.startTimeAttributeName);
this.handlerScopedMaskingConfig = new HandlerScopedMaskingConfig(config);
LOG.info(
"Using {}ms as the threshold for logging slow queries of handler: {}",
slowQueryThreshold,
Expand Down Expand Up @@ -424,7 +431,7 @@ public Observable<Row> handleRequest(
LOG.debug("Query results: [ {} ]", resultSetGroup.toString());
}
// need to merge data especially for Pinot. That's why we need to track the map columns
return this.convert(resultSetGroup, executionContext.getSelectedColumns())
return this.convert(resultSetGroup, executionContext)
.doOnComplete(
() -> {
long requestTimeMs = stopwatch.stop().elapsed(TimeUnit.MILLISECONDS);
Expand Down Expand Up @@ -493,17 +500,21 @@ private Filter rewriteLeafFilter(
return queryFilter;
}

Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> selectedAttributes) {
Observable<Row> convert(ResultSetGroup resultSetGroup, ExecutionContext executionContext) {
List<Row.Builder> rowBuilderList = new ArrayList<>();
if (resultSetGroup.getResultSetCount() > 0) {
LinkedHashSet<String> selectedAttributes = executionContext.getSelectedColumns();
Set<String> maskedAttributes =
handlerScopedMaskingConfig.getMaskedAttributes(executionContext);
ResultSet resultSet = resultSetGroup.getResultSet(0);
// Pinot has different Response format for selection and aggregation/group by query.
if (resultSetTypePredicateProvider.isSelectionResultSetType(resultSet)) {
// map merging is only supported in the selection. Filtering and Group by has its own
// syntax in Pinot
handleSelection(resultSetGroup, rowBuilderList, selectedAttributes);
handleSelection(resultSetGroup, rowBuilderList, selectedAttributes, maskedAttributes);
} else if (resultSetTypePredicateProvider.isResultTableResultSetType(resultSet)) {
handleTableFormatResultSet(resultSetGroup, rowBuilderList);
handleTableFormatResultSet(
resultSetGroup, rowBuilderList, selectedAttributes, maskedAttributes);
} else {
handleAggregationAndGroupBy(resultSetGroup, rowBuilderList);
}
Expand All @@ -516,7 +527,8 @@ Observable<Row> convert(ResultSetGroup resultSetGroup, LinkedHashSet<String> sel
private void handleSelection(
ResultSetGroup resultSetGroup,
List<Builder> rowBuilderList,
LinkedHashSet<String> selectedAttributes) {
LinkedHashSet<String> selectedAttributes,
Set<String> maskedAttributes) {
int resultSetGroupCount = resultSetGroup.getResultSetCount();
for (int i = 0; i < resultSetGroupCount; i++) {
ResultSet resultSet = resultSetGroup.getResultSet(i);
Expand All @@ -536,7 +548,11 @@ private void handleSelection(
for (String logicalName : selectedAttributes) {
// colVal will never be null. But getDataRow can throw a runtime exception if it failed
// to retrieve data
String colVal = resultAnalyzer.getDataFromRow(rowId, logicalName);
String colVal =
maskedAttributes.contains(logicalName)
? DEFAULT_MASKED_VALUE
: resultAnalyzer.getDataFromRow(rowId, logicalName);

builder.addColumn(Value.newBuilder().setString(colVal).build());
}
}
Expand Down Expand Up @@ -588,10 +604,15 @@ private void handleAggregationAndGroupBy(
}

private void handleTableFormatResultSet(
ResultSetGroup resultSetGroup, List<Builder> rowBuilderList) {
ResultSetGroup resultSetGroup,
List<Builder> rowBuilderList,
LinkedHashSet<String> selectedAttributes,
Set<String> maskedAttributes) {
int resultSetGroupCount = resultSetGroup.getResultSetCount();
for (int i = 0; i < resultSetGroupCount; i++) {
ResultSet resultSet = resultSetGroup.getResultSet(i);
PinotResultAnalyzer resultAnalyzer =
PinotResultAnalyzer.create(resultSet, selectedAttributes, viewDefinition);
for (int rowIdx = 0; rowIdx < resultSet.getRowCount(); rowIdx++) {
Builder builder;
builder = Row.newBuilder();
Expand All @@ -602,8 +623,13 @@ private void handleTableFormatResultSet(
// Read the key and value column values. The columns should be side by side. That's how
// the Pinot query
// is structured
String logicalName = resultAnalyzer.getLogicalNameFromColIdx(colIdx);
String mapKeys = resultSet.getString(rowIdx, colIdx);
String mapVals = resultSet.getString(rowIdx, colIdx + 1);
String mapVals =
maskedAttributes.contains(logicalName)
? ARRAY_TYPE_MASKED_VALUE
: resultSet.getString(rowIdx, colIdx + 1);

try {
builder.addColumn(
Value.newBuilder().setString(pinotMapConverter.merge(mapKeys, mapVals)).build());
Expand All @@ -615,7 +641,11 @@ private void handleTableFormatResultSet(
// advance colIdx by 1 since we have read 2 columns
colIdx++;
} else {
String val = resultSet.getString(rowIdx, colIdx);
String logicalName = resultAnalyzer.getLogicalNameFromColIdx(colIdx);
String val =
maskedAttributes.contains(logicalName)
? DEFAULT_MASKED_VALUE
: resultSet.getString(rowIdx, colIdx);
builder.addColumn(Value.newBuilder().setString(val).build());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,20 @@ class PinotResultAnalyzer {
private final ViewDefinition viewDefinition;
private final Map<String, RateLimiter> attributeLogRateLimitter;
private final PinotMapConverter pinotMapConverter;
private final Map<Integer, String> indexToLogicalName;

PinotResultAnalyzer(
ResultSet resultSet,
LinkedHashSet<String> selectedAttributes,
ViewDefinition viewDefinition,
Map<String, Integer> mapLogicalNameToKeyIndex,
Map<String, Integer> mapLogicalNameToValueIndex,
Map<String, Integer> logicalNameToPhysicalNameIndex) {
Map<String, Integer> logicalNameToPhysicalNameIndex,
Map<Integer, String> indexToLogicalName) {
this.mapLogicalNameToKeyIndex = mapLogicalNameToKeyIndex;
this.mapLogicalNameToValueIndex = mapLogicalNameToValueIndex;
this.logicalNameToPhysicalNameIndex = logicalNameToPhysicalNameIndex;
this.indexToLogicalName = indexToLogicalName;
this.resultSet = resultSet;
this.viewDefinition = viewDefinition;
this.attributeLogRateLimitter = new HashMap<>();
Expand All @@ -53,6 +56,7 @@ static PinotResultAnalyzer create(
Map<String, Integer> mapLogicalNameToKeyIndex = new HashMap<>();
Map<String, Integer> mapLogicalNameToValueIndex = new HashMap<>();
Map<String, Integer> logicalNameToPhysicalNameIndex = new HashMap<>();
Map<Integer, String> indexToLogicalName = new HashMap<>();

for (String logicalName : selectedAttributes) {
if (viewDefinition.isMap(logicalName)) {
Expand All @@ -62,8 +66,10 @@ static PinotResultAnalyzer create(
String physName = resultSet.getColumnName(colIndex);
if (physName.equalsIgnoreCase(keyPhysicalName)) {
mapLogicalNameToKeyIndex.put(logicalName, colIndex);
indexToLogicalName.put(colIndex, logicalName);
} else if (physName.equalsIgnoreCase(valuePhysicalName)) {
mapLogicalNameToValueIndex.put(logicalName, colIndex);
indexToLogicalName.put(colIndex, logicalName);
}
}
} else {
Expand All @@ -73,21 +79,24 @@ static PinotResultAnalyzer create(
String physName = resultSet.getColumnName(colIndex);
if (physName.equalsIgnoreCase(names.get(0))) {
logicalNameToPhysicalNameIndex.put(logicalName, colIndex);
indexToLogicalName.put(colIndex, logicalName);
break;
}
}
}
}
LOG.info("Map LogicalName to Key Index: {} ", mapLogicalNameToKeyIndex);
LOG.info("Map LogicalName to Value Index: {}", mapLogicalNameToValueIndex);
LOG.info("Attributes to Index: {}", logicalNameToPhysicalNameIndex);
LOG.debug("Map LogicalName to Key Index: {} ", mapLogicalNameToKeyIndex);
LOG.debug("Map LogicalName to Value Index: {}", mapLogicalNameToValueIndex);
LOG.debug("Attributes to Index: {}", logicalNameToPhysicalNameIndex);
LOG.debug("Index to LogicalName: {}", indexToLogicalName);
return new PinotResultAnalyzer(
resultSet,
selectedAttributes,
viewDefinition,
mapLogicalNameToKeyIndex,
mapLogicalNameToValueIndex,
logicalNameToPhysicalNameIndex);
logicalNameToPhysicalNameIndex,
indexToLogicalName);
}

@VisibleForTesting
Expand Down Expand Up @@ -149,4 +158,8 @@ String getDataFromRow(int rowIndex, String logicalName) {
}
return result;
}

/**
 * Looks up the logical attribute name recorded for a result set column index.
 *
 * @param colIdx column index used as the key into {@code indexToLogicalName}
 * @return the logical name stored for the index, or {@code null} when the lookup finds no value
 */
String getLogicalNameFromColIdx(Integer colIdx) {
final String logicalName = indexToLogicalName.get(colIdx);
return logicalName;
}
}
Loading

0 comments on commit cd60ebb

Please sign in to comment.