Skip to content

Commit

Permalink
Sync cleanups and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rmaucher committed Dec 11, 2023
1 parent 0170417 commit 8f26744
Show file tree
Hide file tree
Showing 11 changed files with 591 additions and 246 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,8 @@
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.ref.Cleaner;
import java.lang.ref.Cleaner.Cleanable;
import java.net.HttpURLConnection;
Expand All @@ -47,6 +42,7 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
Expand All @@ -66,6 +62,8 @@
import org.apache.tomcat.util.net.Constants;
import org.apache.tomcat.util.net.SSLUtil;
import org.apache.tomcat.util.net.openssl.ciphers.OpenSSLCipherConfigurationParser;
import org.apache.tomcat.util.openssl.SSL_set_info_callback$cb;
import org.apache.tomcat.util.openssl.SSL_set_verify$callback;
import org.apache.tomcat.util.res.StringManager;

/**
Expand All @@ -84,29 +82,10 @@ public final class OpenSSLEngine extends SSLEngine implements SSLUtil.ProtocolIn

public static final Set<String> IMPLEMENTED_PROTOCOLS_SET;

private static final MethodHandle openSSLCallbackInfoHandle;
private static final MethodHandle openSSLCallbackVerifyHandle;

private static final FunctionDescriptor openSSLCallbackInfoFunctionDescriptor =
FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.JAVA_INT, ValueLayout.JAVA_INT);
private static final FunctionDescriptor openSSLCallbackVerifyFunctionDescriptor =
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT, ValueLayout.ADDRESS);

static {
MethodHandles.Lookup lookup = MethodHandles.lookup();
try {
openSSLCallbackInfoHandle = lookup.findStatic(OpenSSLEngine.class, "openSSLCallbackInfo",
MethodType.methodType(void.class, MemorySegment.class, int.class, int.class));
openSSLCallbackVerifyHandle = lookup.findStatic(OpenSSLEngine.class, "openSSLCallbackVerify",
MethodType.methodType(int.class, int.class, MemorySegment.class));
} catch (Exception e) {
throw new IllegalStateException(e);
}

final Set<String> availableCipherSuites = new LinkedHashSet<>(128);
availableCipherSuites.addAll(OpenSSLLibrary.findCiphers("ALL"));
AVAILABLE_CIPHER_SUITES = Collections.unmodifiableSet(availableCipherSuites);

HashSet<String> protocols = new HashSet<>();
protocols.add(Constants.SSL_PROTO_SSLv2Hello);
protocols.add(Constants.SSL_PROTO_SSLv2);
Expand Down Expand Up @@ -211,9 +190,7 @@ private enum PHAState { NONE, START, COMPLETE }
session = new OpenSSLSession();
var ssl = SSL_new(sslCtx);
// Set ssl_info_callback
var openSSLCallbackInfo = Linker.nativeLinker().upcallStub(openSSLCallbackInfoHandle,
openSSLCallbackInfoFunctionDescriptor, engineArena);
SSL_set_info_callback(ssl, openSSLCallbackInfo);
SSL_set_info_callback(ssl, SSL_set_info_callback$cb.allocate(new InfoCallback(), engineArena));
if (clientMode) {
SSL_set_connect_state(ssl);
} else {
Expand Down Expand Up @@ -1158,27 +1135,34 @@ private void setClientAuth(ClientAuthMode mode) {
};
// SSL.setVerify(state.ssl, value, certificateVerificationDepth);
// Set int verify_callback(int preverify_ok, X509_STORE_CTX *x509_ctx) callback
var openSSLCallbackVerify =
Linker.nativeLinker().upcallStub(openSSLCallbackVerifyHandle,
openSSLCallbackVerifyFunctionDescriptor, engineArena);
int value = switch (mode) {
case NONE -> SSL_VERIFY_NONE();
case REQUIRE -> SSL_VERIFY_PEER() | SSL_VERIFY_FAIL_IF_NO_PEER_CERT();
case OPTIONAL -> SSL_VERIFY_PEER();
};
SSL_set_verify(state.ssl, value, openSSLCallbackVerify);
SSL_set_verify(state.ssl, value, SSL_set_verify$callback.allocate(new VerifyCallback(), engineArena));
clientAuth = mode;
}
}

public static void openSSLCallbackInfo(MemorySegment ssl, int where, int ret) {
EngineState state = getState(ssl);
if (state == null) {
log.warn(sm.getString("engine.noSSL", Long.valueOf(ssl.address())));
return;
private static class InfoCallback implements SSL_set_info_callback$cb {
@Override
public void apply(MemorySegment ssl, int where, @SuppressWarnings("unused") int ret) {
EngineState state = getState(ssl);
if (state == null) {
log.warn(sm.getString("engine.noSSL", Long.valueOf(ssl.address())));
return;
}
if (0 != (where & SSL_CB_HANDSHAKE_DONE())) {
state.handshakeCount++;
}
}
if (0 != (where & SSL_CB_HANDSHAKE_DONE())) {
state.handshakeCount++;
}

private static class VerifyCallback implements SSL_set_verify$callback {
@Override
public int apply(int preverify_ok, MemorySegment /*X509_STORE_CTX*/ x509ctx) {
return openSSLCallbackVerify(preverify_ok, x509ctx);
}
}

Expand Down Expand Up @@ -1740,19 +1724,24 @@ private EngineState(MemorySegment ssl, MemorySegment networkBIO,
this.noOcspCheck = noOcspCheck;
// Use another arena to avoid keeping a reference through segments
// This also allows making further accesses to the main pointers safer
this.ssl = ssl.reinterpret(ValueLayout.ADDRESS.byteSize(), stateArena, null);
this.networkBIO = networkBIO.reinterpret(ValueLayout.ADDRESS.byteSize(), stateArena, null);
this.ssl = ssl.reinterpret(ValueLayout.ADDRESS.byteSize(), stateArena,
new Consumer<MemorySegment>() {
@Override
public void accept(MemorySegment t) {
SSL_free(t);
}});
this.networkBIO = networkBIO.reinterpret(ValueLayout.ADDRESS.byteSize(), stateArena,
new Consumer<MemorySegment>() {
@Override
public void accept(MemorySegment t) {
BIO_free(t);
}});
}

@Override
public void run() {
try {
states.remove(Long.valueOf(ssl.address()));
BIO_free(networkBIO);
SSL_free(ssl);
} finally {
stateArena.close();
}
states.remove(Long.valueOf(ssl.address()));
stateArena.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import static org.apache.tomcat.util.openssl.openssl_h_Compatibility.*;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.net.openssl.OpenSSLStatus;
import org.apache.tomcat.util.net.openssl.ciphers.OpenSSLCipherConfigurationParser;
import org.apache.tomcat.util.res.StringManager;

Expand Down Expand Up @@ -73,10 +74,6 @@ public class OpenSSLLibrary {

protected static final Object lock = new Object();

public OpenSSLLibrary() {
OpenSSLStatus.setInstanceCreated(true);
}

static MemorySegment enginePointer = MemorySegment.NULL;

static void initLibrary() {
Expand All @@ -98,7 +95,6 @@ static void initLibrary() {
{ BN_get_rfc3526_prime_2048, NULL, 1025 },
{ BN_get_rfc2409_prime_1024, NULL, 0 }
*/
@Deprecated
static final class DHParam {
final MemorySegment dh;
final int min;
Expand All @@ -109,7 +105,6 @@ private DHParam(MemorySegment dh, int min) {
}
static final DHParam[] dhParameters = new DHParam[6];

@Deprecated
private static void initDHParameters() {
var dh = DH_new();
var p = BN_get_rfc3526_prime_8192(MemorySegment.NULL);
Expand Down Expand Up @@ -149,7 +144,6 @@ private static void initDHParameters() {
dhParameters[5] = new DHParam(dh, 0);
}

@Deprecated
private static void freeDHParameters() {
for (int i = 0; i < dhParameters.length; i++) {
if (dhParameters[i] != null) {
Expand All @@ -162,7 +156,7 @@ private static void freeDHParameters() {
}
}

static void init() {
public static void init() {
synchronized (lock) {

if (OpenSSLStatus.isInitialized()) {
Expand Down Expand Up @@ -330,7 +324,8 @@ static void init() {
}
}

static void destroy() {

public static void destroy() {
synchronized (lock) {
if (!OpenSSLStatus.isInitialized()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ public void setTicketKeys(byte[] keys) {
throw new IllegalArgumentException(sm.getString("sessionContext.nullTicketKeys"));
}
if (keys.length != TICKET_KEYS_SIZE) {
throw new IllegalArgumentException(sm.getString("sessionContext.invalidTicketKeysLength", keys.length));
throw new IllegalArgumentException(sm.getString("sessionContext.invalidTicketKeysLength",
Integer.valueOf(keys.length)));
}
try (var memorySession = Arena.ofConfined()) {
var array = memorySession.allocateFrom(ValueLayout.JAVA_BYTE, keys);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Generated by jextract

package org.apache.tomcat.util.openssl;

import java.lang.invoke.MethodHandle;
import java.lang.foreign.*;
import static java.lang.foreign.ValueLayout.*;

/**
* {@snippet lang = c
* : * int (*SSL_CTX_set_alpn_select_cb$cb)(struct ssl_st*,unsigned char**,unsigned char*,unsigned char*,unsigned int,void*);
* }
*/
public interface SSL_CTX_set_alpn_select_cb$cb {

FunctionDescriptor $DESC = FunctionDescriptor.of(JAVA_INT, openssl_h.C_POINTER, openssl_h.C_POINTER,
openssl_h.C_POINTER, openssl_h.C_POINTER, JAVA_INT, openssl_h.C_POINTER);

int apply(MemorySegment _x0, MemorySegment _x1, MemorySegment _x2, MemorySegment _x3, int _x4, MemorySegment _x5);

MethodHandle UP$MH = openssl_h.upcallHandle(SSL_CTX_set_alpn_select_cb$cb.class, "apply", $DESC);

static MemorySegment allocate(SSL_CTX_set_alpn_select_cb$cb fi, Arena scope) {
return Linker.nativeLinker().upcallStub(UP$MH.bindTo(fi), $DESC, scope);
}

MethodHandle DOWN$MH = Linker.nativeLinker().downcallHandle($DESC);

static SSL_CTX_set_alpn_select_cb$cb ofAddress(MemorySegment addr, Arena arena) {
MemorySegment symbol = addr.reinterpret(arena, null);
return (MemorySegment __x0, MemorySegment __x1, MemorySegment __x2, MemorySegment __x3, int __x4,
MemorySegment __x5) -> {
try {
return (int) DOWN$MH.invokeExact(symbol, __x0, __x1, __x2, __x3, __x4, __x5);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Generated by jextract

package org.apache.tomcat.util.openssl;

import java.lang.invoke.MethodHandle;
import java.lang.foreign.*;
import static java.lang.foreign.ValueLayout.*;

/**
* {@snippet lang = c : * int (*SSL_CTX_set_cert_verify_callback$cb)(struct x509_store_ctx_st*,void*);
* }
*/
public interface SSL_CTX_set_cert_verify_callback$cb {

FunctionDescriptor $DESC = FunctionDescriptor.of(JAVA_INT, openssl_h.C_POINTER, openssl_h.C_POINTER);

int apply(MemorySegment _x0, MemorySegment _x1);

MethodHandle UP$MH = openssl_h.upcallHandle(SSL_CTX_set_cert_verify_callback$cb.class, "apply", $DESC);

static MemorySegment allocate(SSL_CTX_set_cert_verify_callback$cb fi, Arena scope) {
return Linker.nativeLinker().upcallStub(UP$MH.bindTo(fi), $DESC, scope);
}

MethodHandle DOWN$MH = Linker.nativeLinker().downcallHandle($DESC);

static SSL_CTX_set_cert_verify_callback$cb ofAddress(MemorySegment addr, Arena arena) {
MemorySegment symbol = addr.reinterpret(arena, null);
return (MemorySegment __x0, MemorySegment __x1) -> {
try {
return (int) DOWN$MH.invokeExact(symbol, __x0, __x1);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Generated by jextract

package org.apache.tomcat.util.openssl;

import java.lang.invoke.MethodHandle;
import java.lang.foreign.*;
import static java.lang.foreign.ValueLayout.*;

/**
* {@snippet lang = c : * int (*SSL_CTX_set_default_passwd_cb$cb)(char*,int,int,void*);
* }
*/
public interface SSL_CTX_set_default_passwd_cb$cb {

FunctionDescriptor $DESC = FunctionDescriptor.of(JAVA_INT, openssl_h.C_POINTER, JAVA_INT, JAVA_INT,
openssl_h.C_POINTER);

int apply(MemorySegment _x0, int _x1, int _x2, MemorySegment _x3);

MethodHandle UP$MH = openssl_h.upcallHandle(SSL_CTX_set_default_passwd_cb$cb.class, "apply", $DESC);

static MemorySegment allocate(SSL_CTX_set_default_passwd_cb$cb fi, Arena scope) {
return Linker.nativeLinker().upcallStub(UP$MH.bindTo(fi), $DESC, scope);
}

MethodHandle DOWN$MH = Linker.nativeLinker().downcallHandle($DESC);

static SSL_CTX_set_default_passwd_cb$cb ofAddress(MemorySegment addr, Arena arena) {
MemorySegment symbol = addr.reinterpret(arena, null);
return (MemorySegment __x0, int __x1, int __x2, MemorySegment __x3) -> {
try {
return (int) DOWN$MH.invokeExact(symbol, __x0, __x1, __x2, __x3);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
};
}
}
Loading

0 comments on commit 8f26744

Please sign in to comment.