Skip to content

Commit

Permalink
implement tlsuv_stream flow control [fixes #171]
Browse files Browse the repository at this point in the history
  • Loading branch information
ekoby committed Sep 11, 2023
1 parent da949f4 commit 8b534c7
Show file tree
Hide file tree
Showing 3 changed files with 231 additions and 14 deletions.
11 changes: 6 additions & 5 deletions src/tls_link.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ static int tls_write(uv_link_t *link, uv_link_t *source, const uv_buf_t bufs[],
static void tls_close(uv_link_t *link, uv_link_t *source, uv_link_close_cb cb);

static const uv_link_methods_t tls_methods = {
.close = tls_close,
.read_start = tls_read_start,
.write = tls_write,
.alloc_cb_override = tls_alloc,
.read_cb_override = tls_read_cb
.close = tls_close,
.read_start = tls_read_start,
.read_stop = uv_link_default_read_stop,
.write = tls_write,
.alloc_cb_override = tls_alloc,
.read_cb_override = tls_read_cb
};

typedef struct tls_link_write_s {
Expand Down
41 changes: 32 additions & 9 deletions src/tlsuv.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@
static void tcp_connect_cb(uv_connect_t* req, int status);

static const uv_link_methods_t mbed_methods = {
.close = uv_link_default_close,
.read_start = uv_link_default_read_start,
.write = uv_link_default_write,
.alloc_cb_override = uv_link_default_alloc_cb_override,
.read_cb_override = uv_link_default_read_cb_override,
.close = uv_link_default_close,
.read_start = uv_link_default_read_start,
.read_stop = uv_link_default_read_stop,
.write = uv_link_default_write,
.alloc_cb_override = uv_link_default_alloc_cb_override,
.read_cb_override = uv_link_default_read_cb_override,
};

static tls_context *DEFAULT_TLS = NULL;
Expand Down Expand Up @@ -72,6 +73,8 @@ int tlsuv_stream_init(uv_loop_t *l, tlsuv_stream_t *clt, tls_context *tls) {

uv_link_init((uv_link_t *) clt, &mbed_methods);
clt->tls = tls != NULL ? tls : get_default_tls();
clt->read_cb = NULL;
clt->alloc_cb = NULL;

return 0;
}
Expand Down Expand Up @@ -129,6 +132,7 @@ static void on_tls_hs(tls_link_t *tls_link, int status) {
}

if (status == TLS_HS_COMPLETE) {
tlsuv_stream_read_stop(stream);
req->cb(req, 0);
} else if (status == TLS_HS_ERROR) {
UM_LOG(WARN, "handshake failed: %s", tls_link->engine->strerror(tls_link->engine));
Expand Down Expand Up @@ -197,13 +201,32 @@ int tlsuv_stream_connect(uv_connect_t *req, tlsuv_stream_t *clt, const char *hos
}

int tlsuv_stream_read_start(tlsuv_stream_t *clt, uv_alloc_cb alloc_cb, uv_read_cb read_cb) {
clt->alloc_cb = (uv_link_alloc_cb) alloc_cb;
clt->read_cb = (uv_link_read_cb) read_cb;
uv_link_read_start((uv_link_t *)clt);
return 0;
if (clt == NULL || alloc_cb == NULL || read_cb == NULL) {
return UV_EINVAL;
}

if (clt->read_cb) {
return UV_EALREADY;
}

int rc = uv_link_read_start((uv_link_t *)clt);
if (rc == 0) {
clt->alloc_cb = (uv_link_alloc_cb) alloc_cb;
clt->read_cb = (uv_link_read_cb) read_cb;
}
return rc;
}

int tlsuv_stream_read_stop(tlsuv_stream_t *clt) {
if (clt == NULL) {
return UV_EINVAL;
}

if (clt->read_cb == NULL) {
return 0;
}
clt->read_cb = NULL;
clt->alloc_cb = NULL;
return uv_link_read_stop((uv_link_t *) clt);
}

Expand Down
193 changes: 193 additions & 0 deletions tests/stream_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ limitations under the License.
#include "fixtures.h"
#include "catch.hpp"

extern tls_context *testServerTLS();

TEST_CASE("stream connect fail", "[stream]") {
UvLoopTest test;

Expand Down Expand Up @@ -178,4 +180,195 @@ accept: application/dns-json
tlsuv_stream_free(&s);

tls->free_ctx(tls);
}

struct connect_args_s {
tlsuv_stream_t *s;
const char *hostname;
int port;
struct test_result *result;
};

struct sleep_args_s {
int timeout;
};

struct write_args_s {
tlsuv_stream_t *s;
const char *data;
};

struct expected_result_s {
struct test_result *result;
const char *data;
int count;
};

typedef struct step_s step_t;
typedef void (*step_fn)(uv_loop_t *, struct step_s *);
struct step_s {
step_fn fn;

union {
connect_args_s connect_args;
write_args_s write_args;
sleep_args_s sleep_args;
expected_result_s expected;
};
};

static inline void start(uv_loop_t *l, step_t *s) {
if (s && s->fn) {
s->fn(l, s);
}
}
static inline step_t *next_step(step_t *s) { return ++s; }

static inline void run_next(uv_loop_t *l, step_t *step) {
start(l, next_step(step));
}

static void sleep_cb(uv_timer_t *t) {
step_t *step = static_cast<step_t *>(t->data);
uv_close(reinterpret_cast<uv_handle_t *>(t),
reinterpret_cast<uv_close_cb>(free));
printf("sleep step is done\n");
run_next(t->loop, step);
}

static void sleep_step(uv_loop_t *l, step_t *step) {
printf("running sleep step\n");
uv_timer_t *t = (uv_timer_t *)calloc(1, sizeof(*t));
uv_timer_init(l, t);
t->data = step;
uv_timer_start(t, sleep_cb, step->sleep_args.timeout, 0);
}

static void connect_cb(uv_connect_t *r, int status) {
printf("connected: %d\n", status);
step_t *s = (step_t *)r->data;
auto stream = (tlsuv_stream_t *)r->handle;
auto l = stream->loop;
free(r);
REQUIRE(status == 0);
run_next(l, s);
}

static void connect_step(uv_loop_t *l, step_t *step) {
tlsuv_stream_t *clt = step->connect_args.s;
REQUIRE(tlsuv_stream_init(l, clt, testServerTLS()) == 0);
clt->data = step->connect_args.result;
uv_connect_t *r = (uv_connect_t *)calloc(1, sizeof(*r));
r->data = step;
REQUIRE(tlsuv_stream_connect(r, clt, step->connect_args.hostname, step->connect_args.port, connect_cb) == 0);
}

static void disconnect_cb(uv_handle_t *h) {
auto s = (tlsuv_stream_t *)h;
auto step = (step_t *)s->data;
tlsuv_stream_free(s);
run_next(s->loop, step);
}

static void disconnect_step(uv_loop_t *l, step_t *step) {
auto s = step->connect_args.s;
s->data = step;
tlsuv_stream_close(step->connect_args.s, disconnect_cb);
}

static void write_cb(uv_write_t *r, int status) {
auto stream = (tlsuv_stream_t *)r->handle;
auto step = (step_t *)r->data;
REQUIRE(status == 0);
delete r;
run_next(stream->loop, step);
}

static void write_step(uv_loop_t *l, step_t *step) {
uv_write_t *r = new uv_write_t;
auto buf = uv_buf_init((char *)step->write_args.data,
strlen(step->write_args.data));
r->data = step;
REQUIRE(tlsuv_stream_write(r, step->write_args.s, &buf, write_cb) == 0);
}

struct test_result {
int read_count;
std::string read_data;
tlsuv_stream_t *stream;

public:
explicit test_result(tlsuv_stream_t *s)
: stream(s), read_count(0), read_data("") {}
};

static void check_result(uv_loop_t *l, step_t *step) {
printf("read: %s\n", step->expected.result->read_data.c_str());
REQUIRE(step->expected.result->read_data == step->expected.data);
CHECK(step->expected.result->read_count <= step->expected.count);
run_next(l, step);
}

static void read_alloc(uv_handle_t *handle, size_t size, uv_buf_t *buf) {
buf->base = (char *)malloc(size);
buf->len = size;
}

static void read_cb(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf) {
tlsuv_stream_t *clt = reinterpret_cast<tlsuv_stream_t *>(stream);
test_result *result = static_cast<test_result *>(clt->data);

REQUIRE(nread > 0);
result->read_count++;
result->read_data.append(buf->base, nread);

free(buf->base);
}

static void start_read_step(uv_loop_t *l, step_t *step) {
CHECK(tlsuv_stream_read_start(step->write_args.s, nullptr, nullptr) == UV_EINVAL);
CHECK(tlsuv_stream_read_start(step->write_args.s, read_alloc, nullptr) == UV_EINVAL);
CHECK(tlsuv_stream_read_start(step->write_args.s, read_alloc, read_cb) == 0);
CHECK(tlsuv_stream_read_start(step->write_args.s, read_alloc, read_cb) == UV_EALREADY);
run_next(l, step);
}

static void stop_read_step(uv_loop_t *l, step_t *step) {
REQUIRE(tlsuv_stream_read_stop(step->write_args.s) == 0);
run_next(l, step);
}

TEST_CASE("read start/stop", "[stream]") {
UvLoopTest loopTest;
tlsuv_stream_t s;
test_result r(&s);

step_t steps[] = {
{
.fn = connect_step,

Check failure on line 348 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, mbedtls)

use of designated initializers requires at least '/std:c++20'

Check failure on line 348 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, openssl)

use of designated initializers requires at least '/std:c++20'
.connect_args = { .s = &s, .hostname = "localhost", .port = 7443, .result = &r },

Check failure on line 349 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, mbedtls)

use of designated initializers requires at least '/std:c++20'

Check failure on line 349 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, openssl)

use of designated initializers requires at least '/std:c++20'
},
{ .fn = write_step, .write_args = { .s = &s, .data = "1",}},

Check failure on line 351 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, mbedtls)

use of designated initializers requires at least '/std:c++20'

Check failure on line 351 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, mbedtls)

use of designated initializers requires at least '/std:c++20'

Check failure on line 351 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, openssl)

use of designated initializers requires at least '/std:c++20'

Check failure on line 351 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, openssl)

use of designated initializers requires at least '/std:c++20'
{ .fn = write_step, .write_args = { .s = &s, .data = "2",}},

Check failure on line 352 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, mbedtls)

use of designated initializers requires at least '/std:c++20'

Check failure on line 352 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, mbedtls)

use of designated initializers requires at least '/std:c++20'

Check failure on line 352 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, openssl)

use of designated initializers requires at least '/std:c++20'

Check failure on line 352 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, openssl)

use of designated initializers requires at least '/std:c++20'
{ .fn = sleep_step, .sleep_args = { .timeout = 100, } },

Check failure on line 353 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, mbedtls)

use of designated initializers requires at least '/std:c++20'

Check failure on line 353 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, mbedtls)

use of designated initializers requires at least '/std:c++20'

Check failure on line 353 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, openssl)

use of designated initializers requires at least '/std:c++20'

Check failure on line 353 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, openssl)

use of designated initializers requires at least '/std:c++20'
{ .fn = check_result, .expected = { .result = &r, .data = "", .count = 0, }}, // not reading yet

Check failure on line 354 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, mbedtls)

use of designated initializers requires at least '/std:c++20'

Check failure on line 354 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, mbedtls)

use of designated initializers requires at least '/std:c++20'

Check failure on line 354 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, openssl)

use of designated initializers requires at least '/std:c++20'

Check failure on line 354 in tests/stream_tests.cpp

View workflow job for this annotation

GitHub Actions / build (windows, openssl)

use of designated initializers requires at least '/std:c++20'
{ .fn = start_read_step, .write_args = { .s = &s }},
{ .fn = sleep_step, .sleep_args = { .timeout = 100, } },
{ .fn = check_result, .expected = { .result = &r, .data = "12", .count = 1,}}, // should read echo from two writes
{ .fn = stop_read_step, .write_args = {.s = &s }},
{ .fn = write_step, .write_args = { .s = &s, .data = "3",}},
{ .fn = write_step, .write_args = { .s = &s, .data = "4",}},
{ .fn = sleep_step, .sleep_args = { .timeout = 100, } },
{ .fn = check_result, .expected = { .result = &r, .data = "12", .count = 1, }}, // not reading
{ .fn = start_read_step, .write_args = { .s = &s }},
{ .fn = write_step, .write_args = { .s = &s, .data = "5",}},
{ .fn = write_step, .write_args = { .s = &s, .data = "6",}},
{ .fn = sleep_step, .sleep_args = { .timeout = 100, } },
{ .fn = check_result, .expected = { .result = &r, .data = "123456", .count = 4,}}, // should read echo from writes 3,4,5,6
{ .fn = disconnect_step, .connect_args = { .s = &s } },
{ .fn = nullptr }
};

start(loopTest.loop, steps);
loopTest.run();
}

0 comments on commit 8b534c7

Please sign in to comment.