Skip to content

Commit

Permalink
Merge pull request #43882 from mkouba/ws-next-security-cleanup
Browse files Browse the repository at this point in the history
WebSockets Next: Security cleanup
  • Loading branch information
sberyozkin authored Oct 16, 2024
2 parents 29b0b6b + 8d678da commit 2c9a5ac
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@
import io.quarkus.security.spi.runtime.SecurityCheck;
import io.quarkus.vertx.http.deployment.RouteBuildItem;
import io.quarkus.vertx.http.runtime.HandlerType;
import io.quarkus.vertx.http.runtime.HttpBuildTimeConfig;
import io.quarkus.websockets.next.HttpUpgradeCheck;
import io.quarkus.websockets.next.InboundProcessingMode;
import io.quarkus.websockets.next.WebSocketClientConnection;
Expand Down Expand Up @@ -445,18 +444,11 @@ public String apply(String name) {
@Record(RUNTIME_INIT)
@BuildStep
public void registerRoutes(WebSocketServerRecorder recorder, List<GeneratedEndpointBuildItem> generatedEndpoints,
HttpBuildTimeConfig httpConfig, Capabilities capabilities,
BuildProducer<RouteBuildItem> routes) {
for (GeneratedEndpointBuildItem endpoint : generatedEndpoints.stream().filter(GeneratedEndpointBuildItem::isServer)
.toList()) {
RouteBuildItem.Builder builder = RouteBuildItem.builder();
if (capabilities.isPresent(Capability.SECURITY) && !httpConfig.auth.proactive) {
// Add a special handler so that it's possible to capture the SecurityIdentity before the HTTP upgrade
builder.routeFunction(endpoint.path, recorder.initializeSecurityHandler());
} else {
builder.route(endpoint.path);
}
builder
RouteBuildItem.Builder builder = RouteBuildItem.builder()
.route(endpoint.path)
.displayOnNotFoundPage("WebSocket Endpoint")
.handlerType(HandlerType.NORMAL)
.handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointId));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,12 @@ public static abstract class ChainHttpUpgradeCheckBase implements HttpUpgradeChe

@Override
public Uni<CheckResult> perform(HttpUpgradeContext request) {
if (identityPropagated(request) && testCheckChain(request)) {
return CheckResult.permitUpgrade(getResponseHeaders());
}
return CheckResult.permitUpgrade();
return request.securityIdentity().chain(identity -> {
if (identity != null && identity.isAnonymous() && testCheckChain(request)) {
return CheckResult.permitUpgrade(getResponseHeaders());
}
return CheckResult.permitUpgrade();
});
}

protected Map<String, List<String>> getResponseHeaders() {
Expand All @@ -134,11 +136,6 @@ protected static boolean testCheckChain(HttpUpgradeContext context) {
return context.httpRequest().headers().contains(TEST_CHECK_CHAIN);
}

private static boolean identityPropagated(HttpUpgradeContext context) {
// point of this method is to check that identity is present in the context
return context.securityIdentity() != null && context.securityIdentity().isAnonymous();
}

}

@Dependent
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ default boolean appliesTo(String endpointId) {
* @param securityIdentity {@link SecurityIdentity}; the identity is null if the Quarkus Security extension is absent
* @param endpointId {@link WebSocket#endpointId()}
*/
record HttpUpgradeContext(HttpServerRequest httpRequest, SecurityIdentity securityIdentity, String endpointId) {
record HttpUpgradeContext(HttpServerRequest httpRequest, Uni<SecurityIdentity> securityIdentity, String endpointId) {
}

final class CheckResult {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@ public class SecurityHttpUpgradeCheck implements HttpUpgradeCheck {

@Override
public Uni<CheckResult> perform(HttpUpgradeContext context) {
return endpointToCheck
.get(context.endpointId())
.nonBlockingApply(context.securityIdentity(), (MethodDescription) null, null)
return context.securityIdentity().chain(identity -> endpointToCheck.get(context.endpointId())
.nonBlockingApply(identity, (MethodDescription) null, null)
.replaceWith(CheckResult::permitUpgradeSync)
.onFailure(SecurityException.class).recoverWithItem(this::rejectUpgrade);
.onFailure(SecurityException.class).recoverWithItem(this::rejectUpgrade));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Supplier;

import jakarta.enterprise.inject.Instance;
Expand Down Expand Up @@ -31,7 +30,6 @@
import io.vertx.core.Handler;
import io.vertx.core.Vertx;
import io.vertx.core.http.ServerWebSocket;
import io.vertx.ext.web.Route;
import io.vertx.ext.web.RoutingContext;

@Recorder
Expand Down Expand Up @@ -62,34 +60,6 @@ public Object get() {
};
}

public Consumer<Route> initializeSecurityHandler() {
return new Consumer<Route>() {

@Override
public void accept(Route route) {
// Force authentication so that it's possible to capture the SecurityIdentity before the HTTP upgrade
route.handler(new Handler<RoutingContext>() {

@Override
public void handle(RoutingContext ctx) {
if (ctx.user() == null) {
Uni<SecurityIdentity> deferredIdentity = ctx
.<Uni<SecurityIdentity>> get(QuarkusHttpUser.DEFERRED_IDENTITY_KEY);
deferredIdentity.subscribe().with(i -> {
if (ctx.response().ended()) {
return;
}
ctx.next();
}, ctx::fail);
} else {
ctx.next();
}
}
});
}
};
}

public Handler<RoutingContext> createEndpointHandler(String generatedEndpointClass, String endpointId) {
ArcContainer container = Arc.container();
ConnectionManager connectionManager = container.instance(ConnectionManager.class).get();
Expand Down Expand Up @@ -142,7 +112,13 @@ private void httpUpgrade(RoutingContext ctx) {
}

private Uni<CheckResult> checkHttpUpgrade(RoutingContext ctx, String endpointId) {
SecurityIdentity identity = ctx.user() instanceof QuarkusHttpUser user ? user.getSecurityIdentity() : null;
QuarkusHttpUser user = (QuarkusHttpUser) ctx.user();
Uni<SecurityIdentity> identity;
if (user == null) {
identity = ctx.<Uni<SecurityIdentity>> get(QuarkusHttpUser.DEFERRED_IDENTITY_KEY);
} else {
identity = Uni.createFrom().item(user.getSecurityIdentity());
}
return checkHttpUpgrade(new HttpUpgradeContext(ctx.request(), identity, endpointId), httpUpgradeChecks, 0);
}

Expand Down

0 comments on commit 2c9a5ac

Please sign in to comment.