From 6a8ae4333e27cd37b003977d5845ceab73300e23 Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Fri, 12 Jul 2024 14:36:46 +0300 Subject: [PATCH] chore: fixes in tls and client sockets Signed-off-by: Roman Gershman --- util/fiber_socket_base.h | 7 +-- util/fibers/epoll_socket.cc | 6 ++- util/fibers/epoll_socket.h | 3 +- util/fibers/uring_socket.cc | 11 +++-- util/fibers/uring_socket.h | 5 +- util/http/http_client.cc | 15 ++++-- util/tls/tls_engine.cc | 13 +++-- util/tls/tls_engine.h | 5 +- util/tls/tls_engine_test.cc | 13 +++-- util/tls/tls_socket.cc | 98 +++++++++++++++++++++++++------------ util/tls/tls_socket.h | 6 ++- 11 files changed, 117 insertions(+), 65 deletions(-) diff --git a/util/fiber_socket_base.h b/util/fiber_socket_base.h index bc8ec082..170ff718 100644 --- a/util/fiber_socket_base.h +++ b/util/fiber_socket_base.h @@ -39,7 +39,8 @@ class FiberSocketBase : public io::Sink, public io::AsyncSink, public io::Source ABSL_MUST_USE_RESULT virtual AcceptResult Accept() = 0; - ABSL_MUST_USE_RESULT virtual error_code Connect(const endpoint_type& ep) = 0; + ABSL_MUST_USE_RESULT virtual error_code Connect(const endpoint_type& ep, + std::function on_pre_connect = {}) = 0; ABSL_MUST_USE_RESULT virtual error_code Close() = 0; @@ -200,8 +201,8 @@ class LinuxSocketBase : public FiberSocketBase { // gives me 256M descriptors. int32_t fd_; - private: - uint32_t timeout_ = UINT32_MAX; + private: + uint32_t timeout_ = UINT32_MAX; }; void SetNonBlocking(int fd); diff --git a/util/fibers/epoll_socket.cc b/util/fibers/epoll_socket.cc index 390c72bc..051a865e 100644 --- a/util/fibers/epoll_socket.cc +++ b/util/fibers/epoll_socket.cc @@ -189,7 +189,7 @@ auto EpollSocket::Accept() -> AcceptResult { return nonstd::make_unexpected(ec); } -auto EpollSocket::Connect(const endpoint_type& ep) -> error_code { +error_code EpollSocket::Connect(const endpoint_type& ep, std::function on_pre_connect) { CHECK_EQ(fd_, -1); CHECK(proactor() && proactor()->InMyThread()); @@ -208,7 +208,9 @@ auto EpollSocket::Connect(const endpoint_type& ep) -> error_code { write_context_ = detail::FiberActive(); absl::Cleanup clean = [this]() { write_context_ = nullptr; }; - // RegisterEvents(GetProactor()->ev_loop_fd(), fd, arm_index_ + 1024); + if (on_pre_connect) { + on_pre_connect(fd); + } DVSOCK(2) << "Connecting"; diff --git a/util/fibers/epoll_socket.h b/util/fibers/epoll_socket.h index 7143bc28..395a56a2 100644 --- a/util/fibers/epoll_socket.h +++ b/util/fibers/epoll_socket.h @@ -20,7 +20,8 @@ class EpollSocket : public LinuxSocketBase { ABSL_MUST_USE_RESULT AcceptResult Accept() final; - ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep) final; + ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep, + std::function on_pre_connect) final; ABSL_MUST_USE_RESULT error_code Close() final; // Really need here expected. diff --git a/util/fibers/uring_socket.cc b/util/fibers/uring_socket.cc index 49540c4d..62592e15 100644 --- a/util/fibers/uring_socket.cc +++ b/util/fibers/uring_socket.cc @@ -146,7 +146,7 @@ auto UringSocket::Accept() -> AcceptResult { return fs; } -auto UringSocket::Connect(const endpoint_type& ep) -> error_code { +auto UringSocket::Connect(const endpoint_type& ep, std::function on_pre_connect) -> error_code { CHECK_EQ(fd_, -1); CHECK(proactor() && proactor()->InMyThread()); @@ -163,12 +163,13 @@ auto UringSocket::Connect(const endpoint_type& ep) -> error_code { // TODO: support direct descriptors. For now client sockets always use regular linux fds. fd_ = fd << kFdShift; - IoResult io_res; - ep.data(); + if (on_pre_connect) { + on_pre_connect(fd); + } FiberCall fc(proactor, timeout()); fc->PrepConnect(fd, (const sockaddr*)ep.data(), ep.size()); - io_res = fc.Get(); + IoResult io_res = fc.Get(); if (io_res < 0) { // In that case connect returns -errno. ec = error_code(-io_res, system_category()); @@ -333,7 +334,7 @@ io::Result UringSocket::Recv(const io::MutableBytes& mb, int flags) { Proactor* p = GetProactor(); DCHECK(ProactorBase::me() == p); - VSOCK(2) << "Recv [" << fd << "] " << flags; + VSOCK(2) << "Recv [" << fd << "], flags: " << flags; ssize_t res; while (true) { FiberCall fc(p, timeout()); diff --git a/util/fibers/uring_socket.h b/util/fibers/uring_socket.h index 154ac9ea..4603d6a6 100644 --- a/util/fibers/uring_socket.h +++ b/util/fibers/uring_socket.h @@ -32,7 +32,8 @@ class UringSocket : public LinuxSocketBase { ABSL_MUST_USE_RESULT AcceptResult Accept() final; - ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep) final; + ABSL_MUST_USE_RESULT error_code Connect(const endpoint_type& ep, + std::function on_pre_connect) final; ABSL_MUST_USE_RESULT error_code Close() final; io::Result WriteSome(const iovec* v, uint32_t len) override; @@ -75,7 +76,7 @@ class UringSocket : public LinuxSocketBase { struct ErrorCbRefWrapper { uint32_t error_cb_id = 0; - uint32_t ref_count = 2; // one for the socket reference, one for the completion lambda. + uint32_t ref_count = 2; // one for the socket reference, one for the completion lambda. std::function cb; static ErrorCbRefWrapper* New(std::function cb) { diff --git a/util/http/http_client.cc b/util/http/http_client.cc index 42337004..4ef3bd00 100644 --- a/util/http/http_client.cc +++ b/util/http/http_client.cc @@ -82,12 +82,15 @@ std::error_code Client::Reconnect() { return berr; FiberSocketBase* sock = proactor_->CreateSocket(); - if (on_connect_cb_) { - on_connect_cb_(sock->native_handle()); - } + socket_.reset(sock); FiberSocketBase::endpoint_type ep{address, port_}; - return socket_->Connect(ep); + auto on_connect = [this](int fd) { + if (on_connect_cb_) { + on_connect_cb_(fd); + } + }; + return socket_->Connect(ep, std::move(on_connect)); } #if 0 @@ -181,7 +184,9 @@ std::error_code TlsClient::Connect(string_view host, string_view service, SSL_CT // verify server cert using server hostname SSL_dane_enable(ssl_handle, host); ec = tls_socket->Connect(FiberSocketBase::endpoint_type{}); - if (!ec) { + if (ec) { + std::ignore = tls_socket->Close(); + } else { socket_.reset(tls_socket.release()); } } diff --git a/util/tls/tls_engine.cc b/util/tls/tls_engine.cc index 8d4c7130..83983e89 100644 --- a/util/tls/tls_engine.cc +++ b/util/tls/tls_engine.cc @@ -97,8 +97,11 @@ Engine::Engine(SSL_CTX* context) : ssl_(::SSL_new(context)) { // SSL_set0_[rw]bio take ownership of the passed reference, // so if we call both with the same BIO, we need the refcount to be 2. BIO_up_ref(int_bio); + SSL_set0_rbio(ssl_, int_bio); SSL_set0_wbio(ssl_, int_bio); + SSL_set_msg_callback(ssl_, SSL_trace); + SSL_set_msg_callback_arg(ssl_, BIO_new_fp(stdout,0)); } Engine::~Engine() { @@ -111,21 +114,21 @@ Engine::~Engine() { } -auto Engine::FetchOutputBuf() -> BufResult { +auto Engine::FetchOutputBuf() -> Buffer { char* buf = nullptr; int res = BIO_nread(external_bio_, &buf, INT_MAX); if (res < 0) { unsigned long error = ::ERR_get_error(); - return nonstd::make_unexpected(error); + LOG(DFATAL) << "Unexpected result " << res << " " << error; + + return Buffer{}; } return Buffer(reinterpret_cast(buf), res); } -// TODO: to consider replacing BufResult with Buffer since -// it seems BIO_C_NREAD0 should not return negative values when used properly. -auto Engine::PeekOutputBuf() -> BufResult { +auto Engine::PeekOutputBuf() -> Buffer { char* buf = nullptr; long res = BIO_ctrl(external_bio_, BIO_C_NREAD0, 0, &buf); diff --git a/util/tls/tls_engine.h b/util/tls/tls_engine.h index e8c91bfb..c45a85f9 100644 --- a/util/tls/tls_engine.h +++ b/util/tls/tls_engine.h @@ -33,7 +33,6 @@ class Engine { // write. In any case for non-error OpResult a caller must check OutputPending and write the // output buffer to the appropriate channel. using OpResult = io::Result; - using BufResult = io::Result; // Construct a new engine for the specified context. explicit Engine(SSL_CTX* context); @@ -67,11 +66,11 @@ class Engine { //! Returns output (read) buffer. This operation is destructive, i.e. after calling //! this function the buffer is being consumed. //! See OutputPending() for checking if there is a output buffer to consume. - BufResult FetchOutputBuf(); + Buffer FetchOutputBuf(); //! Returns output buffer which is the read buffer of tls engine. //! This operation is not destructive. - BufResult PeekOutputBuf(); + Buffer PeekOutputBuf(); //! Tells the engine that sz bytes were consumed from the output buffer. //! sz should be not greater than the buffer size from the last PeekOutputBuf() call. diff --git a/util/tls/tls_engine_test.cc b/util/tls/tls_engine_test.cc index 3145e5a9..7436d872 100644 --- a/util/tls/tls_engine_test.cc +++ b/util/tls/tls_engine_test.cc @@ -143,18 +143,17 @@ static unsigned long RunPeer(SslStreamTest::Options opts, SslStreamTest::OpCb cb if (opts.drain_output) src->FetchOutputBuf(); else { - auto buf_result = src->PeekOutputBuf(); - CHECK(buf_result); - VLOG(1) << opts.name << " wrote " << buf_result->size() << " bytes"; - CHECK(!buf_result->empty()); + auto buffer = src->PeekOutputBuf(); + VLOG(1) << opts.name << " wrote " << buffer.size() << " bytes"; + CHECK(!buffer.empty()); if (opts.mutate_indx) { - uint8_t* mem = const_cast(buf_result->data()); - mem[opts.mutate_indx % buf_result->size()] = opts.mutate_val; + uint8_t* mem = const_cast(buffer.data()); + mem[opts.mutate_indx % buffer.size()] = opts.mutate_val; opts.mutate_indx = 0; } - auto write_result = dest->WriteBuf(*buf_result); + auto write_result = dest->WriteBuf(buffer); if (!write_result) { return write_result.error(); } diff --git a/util/tls/tls_socket.cc b/util/tls/tls_socket.cc index 969c9dc8..df5cb39b 100644 --- a/util/tls/tls_socket.cc +++ b/util/tls/tls_socket.cc @@ -136,7 +136,7 @@ auto TlsSocket::Accept() -> AcceptResult { return make_unexpected(make_error_code(errc::connection_reset)); } if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) { - ec = HandleRead(); + ec = HandleSocketRead(); if (ec) return make_unexpected(ec); } @@ -145,19 +145,53 @@ auto TlsSocket::Accept() -> AcceptResult { return nullptr; } -auto TlsSocket::Connect(const endpoint_type& endpoint) -> error_code { +error_code TlsSocket::Connect(const endpoint_type& endpoint, + std::function on_pre_connect) { DCHECK(engine_); - auto io_result = engine_->Handshake(Engine::HandshakeType::CLIENT); - if (!io_result.has_value()) { - return std::error_code(io_result.error(), std::system_category()); + Engine::OpResult op_result = engine_->Handshake(Engine::HandshakeType::CLIENT); + if (!op_result) { + return std::error_code(op_result.error(), std::system_category()); } // If the socket is already open, we should not call connect on it - if (IsOpen()) { - return {}; + if (!IsOpen()) { + error_code ec = next_sock_->Connect(endpoint, std::move(on_pre_connect)); + if (ec) + return ec; + } + + // Flush the ssl data to the socket and run the loop that ensures handshaking converges. + int op_val = *op_result; + error_code ec; + + // it should guide us to write and then read. + DCHECK_EQ(op_val, Engine::NEED_READ_AND_MAYBE_WRITE); + while (op_val < 0) { + if (op_val == Engine::EOF_STREAM) { + return make_error_code(errc::connection_reset); + } + + if (op_val == Engine::NEED_WRITE) { + ec = HandleSocketWrite(); + if (ec) + return ec; + } else if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) { + ec = HandleSocketWrite(); + if (ec) + return ec; + + ec = HandleSocketRead(); + if (ec) + return ec; + } + op_result = engine_->Handshake(Engine::HandshakeType::CLIENT); + if (!op_result) { + return std::error_code(op_result.error(), std::system_category()); + } + op_val = *op_result; } - return next_sock_->Connect(endpoint); + return ec; } auto TlsSocket::Close() -> error_code { @@ -249,7 +283,7 @@ io::Result TlsSocket::RecvMsg(const msghdr& msg, int flags) { } if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) { - ec = HandleRead(); + ec = HandleSocketRead(); if (ec) return make_unexpected(ec); } @@ -341,7 +375,7 @@ io::Result TlsSocket::SendBuffer(Engine::Buffer buf) { } if (op_val == Engine::NEED_READ_AND_MAYBE_WRITE) { - ec = HandleRead(); + ec = HandleSocketRead(); if (ec) return make_unexpected(ec); } @@ -381,28 +415,10 @@ auto TlsSocket::MaybeSendOutput() -> error_code { return error_code{}; } - auto buf_result = engine_->PeekOutputBuf(); - CHECK(buf_result); - - if (!buf_result->empty()) { - // we do not allow concurrent writes from multiple fibers. - state_ |= WRITE_IN_PROGRESS; - io::Result write_result = next_sock_->WriteSome(*buf_result); - - // Safe to clear here since the code below is atomic fiber-wise. - state_ &= ~WRITE_IN_PROGRESS; - DCHECK(engine_); - if (!write_result) { - return write_result.error(); - } - CHECK_GT(*write_result, 0u); - engine_->ConsumeOutputBuf(*write_result); - } - - return error_code{}; + return HandleSocketWrite(); } -auto TlsSocket::HandleRead() -> error_code { +auto TlsSocket::HandleSocketRead() -> error_code { if (state_ & READ_IN_PROGRESS) { // We need to Yield because otherwise we might end up in an infinite loop. // See also comments in MaybeSendOutput. @@ -423,6 +439,28 @@ auto TlsSocket::HandleRead() -> error_code { return error_code{}; } +error_code TlsSocket::HandleSocketWrite() { + Engine::Buffer buffer = engine_->PeekOutputBuf(); + + while (!buffer.empty()) { + // we do not allow concurrent writes from multiple fibers. + state_ |= WRITE_IN_PROGRESS; + io::Result write_result = next_sock_->WriteSome(buffer); + + // Safe to clear here since the code below is atomic fiber-wise. + state_ &= ~WRITE_IN_PROGRESS; + DCHECK(engine_); + if (!write_result) { + return write_result.error(); + } + CHECK_GT(*write_result, 0u); + engine_->ConsumeOutputBuf(*write_result); + buffer.remove_prefix(*write_result); + } + + return error_code{}; +} + TlsSocket::endpoint_type TlsSocket::LocalEndpoint() const { return next_sock_->LocalEndpoint(); } diff --git a/util/tls/tls_socket.h b/util/tls/tls_socket.h index c0ed9985..57d7b087 100644 --- a/util/tls/tls_socket.h +++ b/util/tls/tls_socket.h @@ -37,7 +37,7 @@ class TlsSocket final : public FiberSocketBase { // The endpoint should not really pass here, it is to keep // the interface with FiberSocketBase. - error_code Connect(const endpoint_type&) final; + error_code Connect(const endpoint_type& ep, std::function on_pre_connect = {}) final; error_code Close() final; @@ -92,7 +92,9 @@ class TlsSocket final : public FiberSocketBase { error_code MaybeSendOutput(); /// Read encrypted data from the network socket and feed it into the TLS engine. - error_code HandleRead(); + error_code HandleSocketRead(); + + error_code HandleSocketWrite(); std::unique_ptr next_sock_; std::unique_ptr engine_;