From 8b534c79dc9901e7d8af29ec05152e4aaeee63e3 Mon Sep 17 00:00:00 2001 From: Eugene K Date: Mon, 11 Sep 2023 18:32:08 -0400 Subject: [PATCH] implement tlsuv_stream flow control [fixes #171] --- src/tls_link.c | 11 +-- src/tlsuv.c | 41 +++++++-- tests/stream_tests.cpp | 193 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 231 insertions(+), 14 deletions(-) diff --git a/src/tls_link.c b/src/tls_link.c index 1b930ef0..fd43465d 100644 --- a/src/tls_link.c +++ b/src/tls_link.c @@ -26,11 +26,12 @@ static int tls_write(uv_link_t *link, uv_link_t *source, const uv_buf_t bufs[], static void tls_close(uv_link_t *link, uv_link_t *source, uv_link_close_cb cb); static const uv_link_methods_t tls_methods = { - .close = tls_close, - .read_start = tls_read_start, - .write = tls_write, - .alloc_cb_override = tls_alloc, - .read_cb_override = tls_read_cb + .close = tls_close, + .read_start = tls_read_start, + .read_stop = uv_link_default_read_stop, + .write = tls_write, + .alloc_cb_override = tls_alloc, + .read_cb_override = tls_read_cb }; typedef struct tls_link_write_s { diff --git a/src/tlsuv.c b/src/tlsuv.c index b28a149d..85614274 100644 --- a/src/tlsuv.c +++ b/src/tlsuv.c @@ -34,11 +34,12 @@ static void tcp_connect_cb(uv_connect_t* req, int status); static const uv_link_methods_t mbed_methods = { - .close = uv_link_default_close, - .read_start = uv_link_default_read_start, - .write = uv_link_default_write, - .alloc_cb_override = uv_link_default_alloc_cb_override, - .read_cb_override = uv_link_default_read_cb_override, + .close = uv_link_default_close, + .read_start = uv_link_default_read_start, + .read_stop = uv_link_default_read_stop, + .write = uv_link_default_write, + .alloc_cb_override = uv_link_default_alloc_cb_override, + .read_cb_override = uv_link_default_read_cb_override, }; static tls_context *DEFAULT_TLS = NULL; @@ -72,6 +73,8 @@ int tlsuv_stream_init(uv_loop_t *l, tlsuv_stream_t *clt, tls_context *tls) { uv_link_init((uv_link_t *) clt, &mbed_methods); clt->tls = tls != NULL ? tls : get_default_tls(); + clt->read_cb = NULL; + clt->alloc_cb = NULL; return 0; } @@ -129,6 +132,7 @@ static void on_tls_hs(tls_link_t *tls_link, int status) { } if (status == TLS_HS_COMPLETE) { + tlsuv_stream_read_stop(stream); req->cb(req, 0); } else if (status == TLS_HS_ERROR) { UM_LOG(WARN, "handshake failed: %s", tls_link->engine->strerror(tls_link->engine)); @@ -197,13 +201,32 @@ int tlsuv_stream_connect(uv_connect_t *req, tlsuv_stream_t *clt, const char *hos } int tlsuv_stream_read_start(tlsuv_stream_t *clt, uv_alloc_cb alloc_cb, uv_read_cb read_cb) { - clt->alloc_cb = (uv_link_alloc_cb) alloc_cb; - clt->read_cb = (uv_link_read_cb) read_cb; - uv_link_read_start((uv_link_t *)clt); - return 0; + if (clt == NULL || alloc_cb == NULL || read_cb == NULL) { + return UV_EINVAL; + } + + if (clt->read_cb) { + return UV_EALREADY; + } + + int rc = uv_link_read_start((uv_link_t *)clt); + if (rc == 0) { + clt->alloc_cb = (uv_link_alloc_cb) alloc_cb; + clt->read_cb = (uv_link_read_cb) read_cb; + } + return rc; } int tlsuv_stream_read_stop(tlsuv_stream_t *clt) { + if (clt == NULL) { + return UV_EINVAL; + } + + if (clt->read_cb == NULL) { + return 0; + } + clt->read_cb = NULL; + clt->alloc_cb = NULL; return uv_link_read_stop((uv_link_t *) clt); } diff --git a/tests/stream_tests.cpp b/tests/stream_tests.cpp index 2e9aeeb3..5e934d94 100644 --- a/tests/stream_tests.cpp +++ b/tests/stream_tests.cpp @@ -21,6 +21,8 @@ limitations under the License. #include "fixtures.h" #include "catch.hpp" +extern tls_context *testServerTLS(); + TEST_CASE("stream connect fail", "[stream]") { UvLoopTest test; @@ -178,4 +180,195 @@ accept: application/dns-json tlsuv_stream_free(&s); tls->free_ctx(tls); +} + +struct connect_args_s { + tlsuv_stream_t *s; + const char *hostname; + int port; + struct test_result *result; +}; + +struct sleep_args_s { + int timeout; +}; + +struct write_args_s { + tlsuv_stream_t *s; + const char *data; +}; + +struct expected_result_s { + struct test_result *result; + const char *data; + int count; +}; + +typedef struct step_s step_t; +typedef void (*step_fn)(uv_loop_t *, struct step_s *); +struct step_s { + step_fn fn; + + union { + connect_args_s connect_args; + write_args_s write_args; + sleep_args_s sleep_args; + expected_result_s expected; + }; +}; + +static inline void start(uv_loop_t *l, step_t *s) { + if (s && s->fn) { + s->fn(l, s); + } +} +static inline step_t *next_step(step_t *s) { return ++s; } + +static inline void run_next(uv_loop_t *l, step_t *step) { + start(l, next_step(step)); +} + +static void sleep_cb(uv_timer_t *t) { + step_t *step = static_cast(t->data); + uv_close(reinterpret_cast(t), + reinterpret_cast(free)); + printf("sleep step is done\n"); + run_next(t->loop, step); +} + +static void sleep_step(uv_loop_t *l, step_t *step) { + printf("running sleep step\n"); + uv_timer_t *t = (uv_timer_t *)calloc(1, sizeof(*t)); + uv_timer_init(l, t); + t->data = step; + uv_timer_start(t, sleep_cb, step->sleep_args.timeout, 0); +} + +static void connect_cb(uv_connect_t *r, int status) { + printf("connected: %d\n", status); + step_t *s = (step_t *)r->data; + auto stream = (tlsuv_stream_t *)r->handle; + auto l = stream->loop; + free(r); + REQUIRE(status == 0); + run_next(l, s); +} + +static void connect_step(uv_loop_t *l, step_t *step) { + tlsuv_stream_t *clt = step->connect_args.s; + REQUIRE(tlsuv_stream_init(l, clt, testServerTLS()) == 0); + clt->data = step->connect_args.result; + uv_connect_t *r = (uv_connect_t *)calloc(1, sizeof(*r)); + r->data = step; + REQUIRE(tlsuv_stream_connect(r, clt, step->connect_args.hostname, step->connect_args.port, connect_cb) == 0); +} + +static void disconnect_cb(uv_handle_t *h) { + auto s = (tlsuv_stream_t *)h; + auto step = (step_t *)s->data; + tlsuv_stream_free(s); + run_next(s->loop, step); +} + +static void disconnect_step(uv_loop_t *l, step_t *step) { + auto s = step->connect_args.s; + s->data = step; + tlsuv_stream_close(step->connect_args.s, disconnect_cb); +} + +static void write_cb(uv_write_t *r, int status) { + auto stream = (tlsuv_stream_t *)r->handle; + auto step = (step_t *)r->data; + REQUIRE(status == 0); + delete r; + run_next(stream->loop, step); +} + +static void write_step(uv_loop_t *l, step_t *step) { + uv_write_t *r = new uv_write_t; + auto buf = uv_buf_init((char *)step->write_args.data, + strlen(step->write_args.data)); + r->data = step; + REQUIRE(tlsuv_stream_write(r, step->write_args.s, &buf, write_cb) == 0); +} + +struct test_result { + int read_count; + std::string read_data; + tlsuv_stream_t *stream; + + public: + explicit test_result(tlsuv_stream_t *s) + : stream(s), read_count(0), read_data("") {} +}; + +static void check_result(uv_loop_t *l, step_t *step) { + printf("read: %s\n", step->expected.result->read_data.c_str()); + REQUIRE(step->expected.result->read_data == step->expected.data); + CHECK(step->expected.result->read_count <= step->expected.count); + run_next(l, step); +} + +static void read_alloc(uv_handle_t *handle, size_t size, uv_buf_t *buf) { + buf->base = (char *)malloc(size); + buf->len = size; +} + +static void read_cb(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf) { + tlsuv_stream_t *clt = reinterpret_cast(stream); + test_result *result = static_cast(clt->data); + + REQUIRE(nread > 0); + result->read_count++; + result->read_data.append(buf->base, nread); + + free(buf->base); +} + +static void start_read_step(uv_loop_t *l, step_t *step) { + CHECK(tlsuv_stream_read_start(step->write_args.s, nullptr, nullptr) == UV_EINVAL); + CHECK(tlsuv_stream_read_start(step->write_args.s, read_alloc, nullptr) == UV_EINVAL); + CHECK(tlsuv_stream_read_start(step->write_args.s, read_alloc, read_cb) == 0); + CHECK(tlsuv_stream_read_start(step->write_args.s, read_alloc, read_cb) == UV_EALREADY); + run_next(l, step); +} + +static void stop_read_step(uv_loop_t *l, step_t *step) { + REQUIRE(tlsuv_stream_read_stop(step->write_args.s) == 0); + run_next(l, step); +} + +TEST_CASE("read start/stop", "[stream]") { + UvLoopTest loopTest; + tlsuv_stream_t s; + test_result r(&s); + + step_t steps[] = { + { + .fn = connect_step, + .connect_args = { .s = &s, .hostname = "localhost", .port = 7443, .result = &r }, + }, + { .fn = write_step, .write_args = { .s = &s, .data = "1",}}, + { .fn = write_step, .write_args = { .s = &s, .data = "2",}}, + { .fn = sleep_step, .sleep_args = { .timeout = 100, } }, + { .fn = check_result, .expected = { .result = &r, .data = "", .count = 0, }}, // not reading yet + { .fn = start_read_step, .write_args = { .s = &s }}, + { .fn = sleep_step, .sleep_args = { .timeout = 100, } }, + { .fn = check_result, .expected = { .result = &r, .data = "12", .count = 1,}}, // should read echo from two writes + { .fn = stop_read_step, .write_args = {.s = &s }}, + { .fn = write_step, .write_args = { .s = &s, .data = "3",}}, + { .fn = write_step, .write_args = { .s = &s, .data = "4",}}, + { .fn = sleep_step, .sleep_args = { .timeout = 100, } }, + { .fn = check_result, .expected = { .result = &r, .data = "12", .count = 1, }}, // not reading + { .fn = start_read_step, .write_args = { .s = &s }}, + { .fn = write_step, .write_args = { .s = &s, .data = "5",}}, + { .fn = write_step, .write_args = { .s = &s, .data = "6",}}, + { .fn = sleep_step, .sleep_args = { .timeout = 100, } }, + { .fn = check_result, .expected = { .result = &r, .data = "123456", .count = 4,}}, // should read echo from writes 3,4,5,6 + { .fn = disconnect_step, .connect_args = { .s = &s } }, + { .fn = nullptr } + }; + + start(loopTest.loop, steps); + loopTest.run(); } \ No newline at end of file