Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: revisit tls_socket code #296

Merged
merged 1 commit into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 20 additions & 28 deletions util/tls/tls_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,36 +41,28 @@ static Engine::OpResult ToOpResult(const SSL* ssl, int result, const char* locat
return nonstd::make_unexpected(error);
}

int want = SSL_want(ssl);

if (want == SSL_NOTHING) {
int ssl_error = SSL_get_error(ssl, result);
int io_err = errno;

switch (ssl_error) {
case SSL_ERROR_ZERO_RETURN:
break;
case SSL_ERROR_SYSCALL:
LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location;
break;
case SSL_ERROR_SSL:
LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location;
break;
default:
LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location;
break;
}

return Engine::EOF_STREAM;
int ssl_error = SSL_get_error(ssl, result);
int io_err = errno;

switch (ssl_error) {
case SSL_ERROR_ZERO_RETURN:
break;
case SSL_ERROR_WANT_READ:
return Engine::NEED_READ_AND_MAYBE_WRITE;
case SSL_ERROR_WANT_WRITE:
VLOG(1) << "SSL_ERROR_WANT_WRITE " << location;
return Engine::NEED_WRITE;
case SSL_ERROR_SYSCALL:
LOG(WARNING) << "SSL syscall error " << io_err << ":" << result << " " << location;
break;
case SSL_ERROR_SSL:
LOG(WARNING) << "SSL protocol error " << io_err << ":" << result << " " << location;
break;
default:
LOG(WARNING) << "Unexpected SSL error " << io_err << ":" << result << " " << location;
break;
}

if (SSL_WRITING == want)
return Engine::NEED_WRITE;
if (SSL_READING == want)
return Engine::NEED_READ_AND_MAYBE_WRITE;

LOG(ERROR) << "Unsupported want value " << want << ", ssl_error: " << SSL_get_error(ssl, result);

return Engine::EOF_STREAM;
}

Expand Down
9 changes: 8 additions & 1 deletion util/tls/tls_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ class Engine {
// Selects which side of the TLS handshake the engine performs; passed to Handshake().
enum HandshakeType { CLIENT = 1, SERVER = 2 };
// Negative status codes carried in an OpResult when an SSL operation did not make progress.
// Non-negative OpResult values are treated by callers as the operation's own result.
enum OpCode {
// The stream cannot continue; TlsSocket reports this as errc::connection_reset.
EOF_STREAM = -1,

// We use BIO buffers, therefore any SSL operation can end up writing to the internal BIO
// and result in success, even though the data has not been flushed to the underlying socket.
// See https://www.openssl.org/docs/man1.0.2/man3/BIO_new_bio_pair.html
// As a result, we must flush the output buffer (if OutputPending() > 0) before we do any
// socket reads. We could flush after each SSL operation but that would result in fragmented
// socket writes which we want to avoid.
NEED_READ_AND_MAYBE_WRITE = -2,
// The engine reported SSL_ERROR_WANT_WRITE: pending output has to be written out
// before the operation is retried.
NEED_WRITE = -3,
};
Expand Down Expand Up @@ -89,7 +96,7 @@ class Engine {
void CommitInput(unsigned sz);

// Returns size of pending data that needs to be flushed out from SSL to I/O.
// See https://www.openssl.org/docs/man1.1.0/man3/BIO_new_bio_pair.html
// See https://www.openssl.org/docs/man1.0.2/man3/BIO_new_bio_pair.html
// Specifically, warning that says: "An application must not rely on the error value of
// SSL_operation() but must assure that the write buffer is always flushed first".
size_t OutputPending() const {
Expand Down
139 changes: 65 additions & 74 deletions util/tls/tls_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ auto TlsSocket::Shutdown(int how) -> error_code {
Engine::OpResult op_result = engine_->Shutdown();
if (op_result) {
// engine_ could send notification messages to the peer.
MaybeSendOutput();
std::ignore = MaybeSendOutput();
}

// In any case we should also shutdown the underlying TCP socket without relying on the
Expand Down Expand Up @@ -132,14 +132,10 @@ auto TlsSocket::Accept() -> AcceptResult {
if (op_val >= 0) { // Shutdown or empty read/write may return 0.
break;
}
if (op_val == Engine::EOF_STREAM) {
return make_unexpected(make_error_code(errc::connection_reset));
}
if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}

ec = HandleOp(op_val);
if (ec)
return make_unexpected(ec);
}

return nullptr;
Expand All @@ -162,36 +158,26 @@ error_code TlsSocket::Connect(const endpoint_type& endpoint,

// Flush the ssl data to the socket and run the loop that ensures handshaking converges.
int op_val = *op_result;
error_code ec;

// it should guide us to write and then read.
DCHECK_EQ(op_val, Engine::NEED_READ_AND_MAYBE_WRITE);
while (op_val < 0) {
if (op_val == Engine::EOF_STREAM) {
return make_error_code(errc::connection_reset);
}
error_code ec = HandleOp(op_val);
if (ec)
return ec;

if (op_val == Engine::NEED_WRITE) {
ec = HandleSocketWrite();
if (ec)
return ec;
} else if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleSocketWrite();
if (ec)
return ec;

ec = HandleSocketRead();
if (ec)
return ec;
}
op_result = engine_->Handshake(Engine::HandshakeType::CLIENT);
if (!op_result) {
return std::error_code(op_result.error(), std::system_category());
}
op_val = *op_result;
}

return ec;
const auto* cipher = SSL_get_current_cipher(engine_->native_handle());
VLOG(1) << "SSL handshake success, chosen " << SSL_CIPHER_get_name(cipher) << "/"
<< SSL_CIPHER_get_version(cipher);

return {};
}

auto TlsSocket::Close() -> error_code {
Expand Down Expand Up @@ -245,11 +231,6 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
return make_unexpected(SSL2Error(op_result.error()));
}

error_code ec = MaybeSendOutput();
if (ec) {
return make_unexpected(ec);
}

int op_val = *op_result;
if (spin_count.Check(op_val <= 0)) {
// Once every 30 seconds.
Expand All @@ -267,26 +248,18 @@ io::Result<size_t> TlsSocket::RecvMsg(const msghdr& msg, int flags) {
++io;
--io_len;
if (io_len == 0)
break;
break; // Finished reading everything.
dest = Engine::MutableBuffer{reinterpret_cast<uint8_t*>(io->iov_base), io->iov_len};
}
continue; // We read everything we asked for - lets retry.
// We read everything we asked for but there are still buffers left to fill.
continue;
}
break;
}

if (read_total) // if we read something lets return it before we handle other states.
break;

if (op_val == Engine::EOF_STREAM) {
return make_unexpected(make_error_code(errc::connection_reset));
}

if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}
error_code ec = HandleOp(op_val);
if (ec)
return make_unexpected(ec);
}
return read_total;
}
Expand All @@ -307,12 +280,12 @@ io::Result<size_t> TlsSocket::WriteSome(const iovec* ptr, uint32_t len) {
// Chosen to be sufficiently smaller than the usual MTU (1500) and a multiple of 16.
// IP - max 24 bytes. TCP - max 60 bytes. TLS - max 21 bytes.
constexpr size_t kBufferSize = 1392;
io::Result<size_t> ec;
io::Result<size_t> res;
size_t total_sent = 0;

while (len) {
if (ptr->iov_len > kBufferSize || len == 1) {
ec = SendBuffer(Engine::Buffer{reinterpret_cast<uint8_t*>(ptr->iov_base), ptr->iov_len});
res = SendBuffer(Engine::Buffer{reinterpret_cast<uint8_t*>(ptr->iov_base), ptr->iov_len});
ptr++;
len--;
} else {
Expand All @@ -324,18 +297,18 @@ io::Result<size_t> TlsSocket::WriteSome(const iovec* ptr, uint32_t len) {
ptr++;
len--;
}
ec = SendBuffer({scratch, buffered_size});
res = SendBuffer({scratch, buffered_size});
}
if (!ec.has_value()) {
return ec;
} else {
total_sent += ec.value();
if (!res) {
return res;
}
total_sent += *res;
}
return total_sent;
}

io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
// Sending buffer into ssl.
DCHECK(engine_);
DCHECK_GT(buf.size(), 0u);

Expand All @@ -348,17 +321,7 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
return make_unexpected(SSL2Error(op_result.error()));
}

error_code ec = MaybeSendOutput();
if (ec) {
return make_unexpected(ec);
}

int op_val = *op_result;
if (spin_count.Check(op_val <= 0)) {
// Once every 30 seconds.
LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit()
<< " Spins: " << spin_count.Spins();
}

if (op_val > 0) {
send_total += op_val;
Expand All @@ -370,15 +333,15 @@ io::Result<size_t> TlsSocket::SendBuffer(Engine::Buffer buf) {
}
}

if (op_val == Engine::EOF_STREAM) {
return make_unexpected(make_error_code(errc::connection_reset));
if (spin_count.Check(op_val <= 0)) {
// Once every 30 seconds.
LOG_EVERY_T(WARNING, 30) << "IO loop spin limit reached. Limit: " << spin_count.Limit()
<< " Spins: " << spin_count.Spins();
}

if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) {
ec = HandleSocketRead();
if (ec)
return make_unexpected(ec);
}
error_code ec = HandleOp(op_val);
if (ec)
return make_unexpected(ec);
}

return send_total;
Expand All @@ -395,6 +358,9 @@ SSL* TlsSocket::ssl_handle() {
}

auto TlsSocket::MaybeSendOutput() -> error_code {
if (engine_->OutputPending() == 0)
return {};

// This function is present in both read and write paths.
// meaning that both of them can be called concurrently from different fibers and then
// race over flushing the output buffer. We use state_ to prevent that.
Expand All @@ -419,6 +385,10 @@ auto TlsSocket::MaybeSendOutput() -> error_code {
}

auto TlsSocket::HandleSocketRead() -> error_code {
error_code ec = MaybeSendOutput();
if (ec)
return ec;

if (state_ & READ_IN_PROGRESS) {
// We need to Yield because otherwise we might end up in an infinite loop.
// See also comments in MaybeSendOutput.
Expand All @@ -434,33 +404,54 @@ auto TlsSocket::HandleSocketRead() -> error_code {
return esz.error();
}

DVLOG(1) << "TlsSocket:Read " << *esz << " bytes";

engine_->CommitInput(*esz);

return error_code{};
}

// Writes the engine's pending output buffer (engine_->PeekOutputBuf()) to next_sock_,
// consuming it from the engine as the socket accepts bytes.
// Returns the socket error if a write fails, otherwise an empty error_code.
// NOTE(review): this listing looks like a rendered diff with removed and added lines
// interleaved (WRITE_IN_PROGRESS is set/cleared both outside and inside the loop, and the
// DCHECK on an empty buffer precedes an explicit empty-buffer return) - confirm against
// the checked-in file before relying on the exact statement order.
error_code TlsSocket::HandleSocketWrite() {
Engine::Buffer buffer = engine_->PeekOutputBuf();
DCHECK(!buffer.empty());

// Nothing pending - nothing to write.
if (buffer.empty())
return {};

// we do not allow concurrent writes from multiple fibers.
state_ |= WRITE_IN_PROGRESS;
// WriteSome may accept only part of the buffer, so loop until it is fully drained.
while (!buffer.empty()) {
// we do not allow concurrent writes from multiple fibers.
state_ |= WRITE_IN_PROGRESS;
io::Result<size_t> write_result = next_sock_->WriteSome(buffer);

// Safe to clear here since the code below is atomic fiber-wise.
state_ &= ~WRITE_IN_PROGRESS;
DCHECK(engine_);
if (!write_result) {
state_ &= ~WRITE_IN_PROGRESS;

return write_result.error();
}
// A successful write is expected to make progress; zero bytes is treated as fatal.
CHECK_GT(*write_result, 0u);
// Drop the written bytes from both the engine's buffer and our local view of it.
engine_->ConsumeOutputBuf(*write_result);
buffer.remove_prefix(*write_result);
}
DCHECK_EQ(engine_->OutputPending(), 0u);

state_ &= ~WRITE_IN_PROGRESS;

return error_code{};
}

// Reacts to a negative Engine::OpCode produced by an SSL operation so that the caller
// can retry that operation afterwards.
//
//   EOF_STREAM                 - the stream cannot continue; returns errc::connection_reset.
//   NEED_READ_AND_MAYBE_WRITE  - delegates to HandleSocketRead().
//   NEED_WRITE                 - delegates to MaybeSendOutput() to flush pending engine output.
//                                Without this case the code fell through to the DFATAL branch
//                                and, in release builds, returned success without flushing,
//                                so the caller's retry loop could not make progress.
//   anything else              - programming error: logs DFATAL and returns success.
//
// Returns the error of the delegated socket operation, or an empty error_code.
error_code TlsSocket::HandleOp(int op_val) {
  switch (op_val) {
    case Engine::EOF_STREAM:
      return make_error_code(errc::connection_reset);
    case Engine::NEED_READ_AND_MAYBE_WRITE:
      return HandleSocketRead();
    case Engine::NEED_WRITE:
      return MaybeSendOutput();
    default:
      LOG(DFATAL) << "Unsupported " << op_val;
  }
  return {};
}

TlsSocket::endpoint_type TlsSocket::LocalEndpoint() const {
return next_sock_->LocalEndpoint();
}
Expand Down
1 change: 1 addition & 0 deletions util/tls/tls_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class TlsSocket final : public FiberSocketBase {
error_code HandleSocketRead();

error_code HandleSocketWrite();
error_code HandleOp(int op);

std::unique_ptr<FiberSocketBase> next_sock_;
std::unique_ptr<Engine> engine_;
Expand Down
Loading