Skip to content

Commit

Permalink
Add TLS support to mailboxes used in the multi-stage engine (#14476)
Browse files Browse the repository at this point in the history
  • Loading branch information
yashmayya authored Nov 19, 2024
1 parent a6ac892 commit f08c159
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.apache.pinot.broker.querylog.QueryLogger;
import org.apache.pinot.broker.queryquota.QueryQuotaManager;
import org.apache.pinot.broker.routing.BrokerRoutingManager;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.config.provider.TableCache;
import org.apache.pinot.common.exception.QueryException;
import org.apache.pinot.common.metrics.BrokerMeter;
Expand Down Expand Up @@ -93,10 +94,11 @@ public MultiStageBrokerRequestHandler(PinotConfiguration config, String brokerId
String hostname = config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_HOSTNAME);
int port = Integer.parseInt(config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT));
_workerManager = new WorkerManager(hostname, port, _routingManager);
_queryDispatcher = new QueryDispatcher(new MailboxService(hostname, port, config), config.getProperty(
TlsConfig tlsConfig = config.getProperty(
CommonConstants.Helix.CONFIG_OF_MULTI_STAGE_ENGINE_TLS_ENABLED,
CommonConstants.Helix.DEFAULT_MULTI_STAGE_ENGINE_TLS_ENABLED) ? TlsUtils.extractTlsConfig(config,
CommonConstants.Broker.BROKER_TLS_PREFIX) : null);
CommonConstants.Broker.BROKER_TLS_PREFIX) : null;
_queryDispatcher = new QueryDispatcher(new MailboxService(hostname, port, config, tlsConfig), tlsConfig);
LOGGER.info("Initialized MultiStageBrokerRequestHandler on host: {}, port: {} with broker id: {}, timeout: {}ms, "
+ "query log max length: {}, query log max rate: {}", hostname, port, _brokerId, _brokerTimeoutMs,
_queryLogger.getMaxQueryLengthToLog(), _queryLogger.getLogRateLimit());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ public GrpcQueryClient(String host, int port) {
public GrpcQueryClient(String host, int port, GrpcConfig config) {
ManagedChannelBuilder<?> channelBuilder;
if (config.isUsePlainText()) {
channelBuilder =
ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.usePlaintext();
channelBuilder = ManagedChannelBuilder
.forAddress(host, port)
.maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.usePlaintext();
} else {
channelBuilder =
NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.sslContext(buildSslContext(config.getTlsConfig()));
channelBuilder = NettyChannelBuilder
.forAddress(host, port)
.maxInboundMessageSize(config.getMaxInboundMessageSizeBytes())
.sslContext(buildSslContext(config.getTlsConfig()));
}

// Set keep alive configs, if enabled
Expand All @@ -85,8 +87,8 @@ public GrpcQueryClient(String host, int port, GrpcConfig config) {
}

public static SslContext buildSslContext(TlsConfig tlsConfig) {
LOGGER.info("Building gRPC SSL context");
SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
LOGGER.info("Building gRPC client SSL context");
return CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
SSLFactory sslFactory = RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(tlsConfig,
PinotInsecureMode::isPinotInInsecureMode);
Expand All @@ -101,10 +103,9 @@ public static SslContext buildSslContext(TlsConfig tlsConfig) {
}
return sslContextBuilder.build();
} catch (SSLException e) {
throw new RuntimeException("Failed to build gRPC SSL context", e);
throw new RuntimeException("Failed to build gRPC client SSL context", e);
}
});
return sslContext;
}

public Iterator<Server.ServerResponse> submit(Server.ServerRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public GrpcQueryServer(int port, GrpcConfig config, TlsConfig tlsConfig, QueryEx
_serverMetrics = serverMetrics;
if (tlsConfig != null) {
try {
_server = NettyServerBuilder.forPort(port).sslContext(buildGRpcSslContext(tlsConfig))
_server = NettyServerBuilder.forPort(port).sslContext(buildGrpcSslContext(tlsConfig))
.maxInboundMessageSize(config.getMaxInboundMessageSizeBytes()).addService(this)
.addTransportFilter(new GrpcQueryTransportFilter()).build();
} catch (Exception e) {
Expand All @@ -119,13 +119,13 @@ public GrpcQueryServer(int port, GrpcConfig config, TlsConfig tlsConfig, QueryEx
ResourceManager.DEFAULT_QUERY_WORKER_THREADS);
}

public static SslContext buildGRpcSslContext(TlsConfig tlsConfig)
public static SslContext buildGrpcSslContext(TlsConfig tlsConfig)
throws IllegalArgumentException {
LOGGER.info("Building gRPC SSL context");
LOGGER.info("Building gRPC server SSL context");
if (tlsConfig.getKeyStorePath() == null) {
throw new IllegalArgumentException("Must provide key store path for secured gRpc server");
throw new IllegalArgumentException("Must provide key store path for secured gRPC server");
}
SslContext sslContext = SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
return SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
SSLFactory sslFactory =
RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(
Expand All @@ -138,10 +138,9 @@ public static SslContext buildGRpcSslContext(TlsConfig tlsConfig)
}
return GrpcSslContexts.configure(sslContextBuilder).build();
} catch (Exception e) {
throw new RuntimeException("Failed to build gRPC SSL context", e);
throw new RuntimeException("Failed to build gRPC server SSL context", e);
}
});
return sslContext;
}

public void start() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import com.google.common.cache.RemovalListener;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.datatable.StatMap;
import org.apache.pinot.query.mailbox.channel.ChannelManager;
import org.apache.pinot.query.mailbox.channel.GrpcMailboxServer;
Expand Down Expand Up @@ -60,14 +62,21 @@ public class MailboxService {
private final String _hostname;
private final int _port;
private final PinotConfiguration _config;
private final ChannelManager _channelManager = new ChannelManager();
private final ChannelManager _channelManager;
@Nullable private final TlsConfig _tlsConfig;

private GrpcMailboxServer _grpcMailboxServer;

public MailboxService(String hostname, int port, PinotConfiguration config) {
this(hostname, port, config, null);
}

public MailboxService(String hostname, int port, PinotConfiguration config, @Nullable TlsConfig tlsConfig) {
_hostname = hostname;
_port = port;
_config = config;
_tlsConfig = tlsConfig;
_channelManager = new ChannelManager(tlsConfig);
LOGGER.info("Initialized MailboxService with hostname: {}, port: {}", hostname, port);
}

Expand All @@ -76,7 +85,7 @@ public MailboxService(String hostname, int port, PinotConfiguration config) {
*/
public void start() {
LOGGER.info("Starting GrpcMailboxServer");
_grpcMailboxServer = new GrpcMailboxServer(this, _config);
_grpcMailboxServer = new GrpcMailboxServer(this, _config, _tlsConfig);
_grpcMailboxServer.start();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import java.util.concurrent.ConcurrentHashMap;
import javax.annotation.Nullable;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.utils.grpc.GrpcQueryClient;
import org.apache.pinot.spi.utils.CommonConstants;


Expand All @@ -33,14 +37,31 @@
*/
public class ChannelManager {
private final ConcurrentHashMap<Pair<String, Integer>, ManagedChannel> _channelMap = new ConcurrentHashMap<>();
private final TlsConfig _tlsConfig;

public ChannelManager(@Nullable TlsConfig tlsConfig) {
_tlsConfig = tlsConfig;
}

public ManagedChannel getChannel(String hostname, int port) {
// TODO: Revisit parameters
// TODO: Support TLS
return _channelMap.computeIfAbsent(Pair.of(hostname, port),
(k) -> ManagedChannelBuilder.forAddress(k.getLeft(), k.getRight())
.maxInboundMessageSize(
CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES)
.usePlaintext().build());
if (_tlsConfig != null) {
return _channelMap.computeIfAbsent(Pair.of(hostname, port),
(k) -> NettyChannelBuilder
.forAddress(k.getLeft(), k.getRight())
.maxInboundMessageSize(
CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES)
.sslContext(GrpcQueryClient.buildSslContext(_tlsConfig))
.build()
);
} else {
return _channelMap.computeIfAbsent(Pair.of(hostname, port),
(k) -> ManagedChannelBuilder
.forAddress(k.getLeft(), k.getRight())
.maxInboundMessageSize(
CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES)
.usePlaintext()
.build());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@

import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.proto.Mailbox;
import org.apache.pinot.common.proto.PinotMailboxGrpc;
import org.apache.pinot.core.transport.grpc.GrpcQueryServer;
import org.apache.pinot.query.mailbox.MailboxService;
import org.apache.pinot.spi.env.PinotConfiguration;
import org.apache.pinot.spi.utils.CommonConstants;
Expand All @@ -42,13 +46,27 @@ public class GrpcMailboxServer extends PinotMailboxGrpc.PinotMailboxImplBase {
private final MailboxService _mailboxService;
private final Server _server;

public GrpcMailboxServer(MailboxService mailboxService, PinotConfiguration config) {
public GrpcMailboxServer(MailboxService mailboxService, PinotConfiguration config, @Nullable TlsConfig tlsConfig) {
_mailboxService = mailboxService;
int port = mailboxService.getPort();
// TODO: Support TLS
_server = ServerBuilder.forPort(port).addService(this).maxInboundMessageSize(
config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES,
CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES)).build();
if (tlsConfig != null) {
_server = NettyServerBuilder
.forPort(port)
.addService(this)
.sslContext(GrpcQueryServer.buildGrpcSslContext(tlsConfig))
.maxInboundMessageSize(config.getProperty(
CommonConstants.MultiStageQueryRunner.KEY_OF_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES,
CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES))
.build();
} else {
_server = ServerBuilder
.forPort(port)
.addService(this)
.maxInboundMessageSize(config.getProperty(
CommonConstants.MultiStageQueryRunner.KEY_OF_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES,
CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES))
.build();
}
}

public void start() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.apache.helix.HelixManager;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.datatable.StatMap;
import org.apache.pinot.common.metrics.ServerMetrics;
import org.apache.pinot.common.proto.Worker;
Expand Down Expand Up @@ -120,7 +121,7 @@ public class QueryRunner {
* <p>Should be called only once and before calling any other method.
*/
public void init(PinotConfiguration config, InstanceDataManager instanceDataManager, HelixManager helixManager,
ServerMetrics serverMetrics) {
ServerMetrics serverMetrics, @Nullable TlsConfig tlsConfig) {
String hostname = config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_HOSTNAME);
if (hostname.startsWith(CommonConstants.Helix.PREFIX_OF_SERVER_INSTANCE)) {
hostname = hostname.substring(CommonConstants.Helix.SERVER_INSTANCE_PREFIX_LENGTH);
Expand Down Expand Up @@ -148,7 +149,7 @@ public void init(PinotConfiguration config, InstanceDataManager instanceDataMana
config, CommonConstants.Server.CONFIG_OF_QUERY_EXECUTOR_OPCHAIN_EXECUTOR, "query-runner-on-" + port,
CommonConstants.Server.DEFAULT_QUERY_EXECUTOR_OPCHAIN_EXECUTOR);
_opChainScheduler = new OpChainSchedulerService(_executorService);
_mailboxService = new MailboxService(hostname, port, config);
_mailboxService = new MailboxService(hostname, port, config, tlsConfig);
try {
_leafQueryExecutor = new ServerQueryExecutorV1Impl();
_leafQueryExecutor.init(config.subset(CommonConstants.Server.QUERY_EXECUTOR_CONFIG_PREFIX), instanceDataManager,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public void start() {
_server = NettyServerBuilder
.forPort(_port)
.addService(this)
.sslContext(GrpcQueryServer.buildGRpcSslContext(_tlsConfig))
.sslContext(GrpcQueryServer.buildGrpcSslContext(_tlsConfig))
.maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ public QueryServerEnclosure(MockInstanceDataManagerFactory factory, Map<String,
InstanceDataManager instanceDataManager = factory.buildInstanceDataManager();
HelixManager helixManager = mockHelixManager(factory.buildSchemaMap());
_queryRunner = new QueryRunner();
_queryRunner.init(new PinotConfiguration(runnerConfig), instanceDataManager, helixManager, mockServiceMetrics());
_queryRunner.init(new PinotConfiguration(runnerConfig), instanceDataManager, helixManager, mockServiceMetrics(),
null);
}

private HelixManager mockHelixManager(Map<String, Schema> schemaMap) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public WorkerQueryServer(PinotConfiguration configuration, InstanceDataManager i
_queryServicePort = _configuration.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT);
QueryRunner queryRunner = new QueryRunner();
queryRunner.init(_configuration, instanceDataManager, helixManager, serverMetrics);
queryRunner.init(_configuration, instanceDataManager, helixManager, serverMetrics, tlsConfig);
_queryWorkerService = new QueryServer(_queryServicePort, queryRunner, tlsConfig);
}

Expand Down

0 comments on commit f08c159

Please sign in to comment.