diff --git a/include/tlsuv/tlsuv.h b/include/tlsuv/tlsuv.h index d3b98ed3..9a32b46e 100644 --- a/include/tlsuv/tlsuv.h +++ b/include/tlsuv/tlsuv.h @@ -43,7 +43,8 @@ int tlsuv_stream_connect(uv_connect_t *req, tlsuv_stream_t *clt, const char *hos int tlsuv_stream_connect_addr(uv_connect_t *req, tlsuv_stream_t *clt, const struct addrinfo *addr, uv_connect_cb cb); -int tlsuv_stream_read(tlsuv_stream_t *clt, uv_alloc_cb, uv_read_cb); +int tlsuv_stream_read_start(tlsuv_stream_t *clt, uv_alloc_cb alloc_cb, uv_read_cb read_cb); +int tlsuv_stream_read_stop(tlsuv_stream_t *clt); int tlsuv_stream_write(uv_write_t *req, tlsuv_stream_t *clt, uv_buf_t *buf, uv_write_cb cb); diff --git a/sample/sample-cf.c b/sample/sample-cf.c index 23c82d2d..b6446cba 100755 --- a/sample/sample-cf.c +++ b/sample/sample-cf.c @@ -67,7 +67,7 @@ void on_connect(uv_connect_t *cr, int status) { } tlsuv_stream_t *mbed = (tlsuv_stream_t *) cr->handle; - tlsuv_stream_read(mbed, alloc, on_data); + tlsuv_stream_read_start(mbed, alloc, on_data); uv_write_t *wr = malloc(sizeof(uv_write_t)); char req[] = "GET " PATH " HTTP/1.1\r\n" diff --git a/sample/sample.c b/sample/sample.c index bfa186e1..849e3811 100755 --- a/sample/sample.c +++ b/sample/sample.c @@ -61,7 +61,7 @@ void on_connect(uv_connect_t *cr, int status) { } tlsuv_stream_t *mbed = (tlsuv_stream_t *) cr->handle; - tlsuv_stream_read(mbed, alloc, on_data); + tlsuv_stream_read_start(mbed, alloc, on_data); uv_write_t *wr = malloc(sizeof(uv_write_t)); char req[] = "GET / HTTP/1.1\r\n" diff --git a/src/tls_link.c b/src/tls_link.c index 1b930ef0..fd43465d 100644 --- a/src/tls_link.c +++ b/src/tls_link.c @@ -26,11 +26,12 @@ static int tls_write(uv_link_t *link, uv_link_t *source, const uv_buf_t bufs[], static void tls_close(uv_link_t *link, uv_link_t *source, uv_link_close_cb cb); static const uv_link_methods_t tls_methods = { - .close = tls_close, - .read_start = tls_read_start, - .write = tls_write, - .alloc_cb_override = tls_alloc, - .read_cb_override = tls_read_cb + .close = tls_close, + .read_start = tls_read_start, + .read_stop = uv_link_default_read_stop, + .write = tls_write, + .alloc_cb_override = tls_alloc, + .read_cb_override = tls_read_cb }; typedef struct tls_link_write_s { diff --git a/src/tlsuv.c b/src/tlsuv.c index c33a561b..85614274 100644 --- a/src/tlsuv.c +++ b/src/tlsuv.c @@ -31,16 +31,15 @@ #define TLSUV_VERS "" #endif -static void tls_debug_f(void *ctx, int level, const char *file, int line, const char *str); static void tcp_connect_cb(uv_connect_t* req, int status); -static int mbed_ssl_send(void* ctx, const uint8_t *buf, size_t len); static const uv_link_methods_t mbed_methods = { - .close = uv_link_default_close, - .read_start = uv_link_default_read_start, - .write = uv_link_default_write, - .alloc_cb_override = uv_link_default_alloc_cb_override, - .read_cb_override = uv_link_default_read_cb_override, + .close = uv_link_default_close, + .read_start = uv_link_default_read_start, + .read_stop = uv_link_default_read_stop, + .write = uv_link_default_write, + .alloc_cb_override = uv_link_default_alloc_cb_override, + .read_cb_override = uv_link_default_read_cb_override, }; static tls_context *DEFAULT_TLS = NULL; @@ -74,6 +73,8 @@ int tlsuv_stream_init(uv_loop_t *l, tlsuv_stream_t *clt, tls_context *tls) { uv_link_init((uv_link_t *) clt, &mbed_methods); clt->tls = tls != NULL ? tls : get_default_tls(); + clt->read_cb = NULL; + clt->alloc_cb = NULL; return 0; } @@ -131,6 +132,7 @@ static void on_tls_hs(tls_link_t *tls_link, int status) { } if (status == TLS_HS_COMPLETE) { + tlsuv_stream_read_stop(stream); req->cb(req, 0); } else if (status == TLS_HS_ERROR) { UM_LOG(WARN, "handshake failed: %s", tls_link->engine->strerror(tls_link->engine)); @@ -198,10 +200,34 @@ int tlsuv_stream_connect(uv_connect_t *req, tlsuv_stream_t *clt, const char *hos return clt->socket->connect((tlsuv_src_t *) clt->socket, host, portstr, on_src_connect, clt); } -int tlsuv_stream_read(tlsuv_stream_t *clt, uv_alloc_cb alloc_cb, uv_read_cb read_cb) { - clt->alloc_cb = (uv_link_alloc_cb) alloc_cb; - clt->read_cb = (uv_link_read_cb) read_cb; - return 0; +int tlsuv_stream_read_start(tlsuv_stream_t *clt, uv_alloc_cb alloc_cb, uv_read_cb read_cb) { + if (clt == NULL || alloc_cb == NULL || read_cb == NULL) { + return UV_EINVAL; + } + + if (clt->read_cb) { + return UV_EALREADY; + } + + int rc = uv_link_read_start((uv_link_t *)clt); + if (rc == 0) { + clt->alloc_cb = (uv_link_alloc_cb) alloc_cb; + clt->read_cb = (uv_link_read_cb) read_cb; + } + return rc; +} + +int tlsuv_stream_read_stop(tlsuv_stream_t *clt) { + if (clt == NULL) { + return UV_EINVAL; + } + + if (clt->read_cb == NULL) { + return 0; + } + clt->read_cb = NULL; + clt->alloc_cb = NULL; + return uv_link_read_stop((uv_link_t *) clt); } static void on_mbed_link_write(uv_link_t* l, int status, void *ctx) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b342147a..6673d225 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -40,7 +40,7 @@ set(test_srcs http_tests.cpp ws_tests.cpp engine_tests.cpp - uv_mbed_tests.cpp + stream_tests.cpp compression_tests.cpp key_tests.cpp ) @@ -48,7 +48,7 @@ set(test_srcs add_executable(all_tests all_tests.cpp ${test_srcs}) -set_property(TARGET all_tests PROPERTY CXX_STANDARD 11) +set_property(TARGET all_tests PROPERTY CXX_STANDARD 20) target_compile_options(all_tests PRIVATE ${PKCS11_OPTS}) target_compile_definitions(all_tests PRIVATE TEST_DATA_DIR=${CMAKE_CURRENT_BINARY_DIR}/testdata) @@ -80,4 +80,4 @@ add_test(key_tests all_tests [key]) add_test(engine_tests all_tests [engine]) add_test(http_tests all_tests [http]) add_test(ws_tests all_tests [websocket]) -add_test(uv_mbed all_tests [uv-mbed]) +add_test(uv_mbed all_tests [stream]) diff --git a/tests/Readme.md b/tests/Readme.md index ac9f9e0e..5b4b273d 100644 --- a/tests/Readme.md +++ b/tests/Readme.md @@ -9,6 +9,7 @@ The test server opens the following endpoints: | 8080 | HTTP test endpoint (httpbin API) | | 8443 | HTTPS test endpoint (httpbin API) | | 9443 | client auth endpoint: checks supplied client certificate | +| 7443 | TLS echo server | Start test server: diff --git a/tests/http_tests.cpp b/tests/http_tests.cpp index b82c9dce..72c2e5d8 100644 --- a/tests/http_tests.cpp +++ b/tests/http_tests.cpp @@ -67,22 +67,22 @@ tls_context* testServerTLS() { return srv.TLS(); } -struct ci_less : std::binary_function -{ - // case-independent (ci) compare_less binary function - struct nocase_compare : public std::binary_function - { - bool operator() (const unsigned char& c1, const unsigned char& c2) const { - return tolower (c1) < tolower (c2); - } - }; - bool operator() (const std::string & s1, const std::string & s2) const { - return std::lexicographical_compare - (s1.begin (), s1.end (), // source range - s2.begin (), s2.end (), // dest range - nocase_compare ()); // comparison - } -}; +//struct ci_less : std::binary_function +//{ +// // case-independent (ci) compare_less binary function +// struct nocase_compare : public std::binary_function +// { +// bool operator() (const unsigned char& c1, const unsigned char& c2) const { +// return tolower (c1) < tolower (c2); +// } +// }; +// bool operator() (const std::string & s1, const std::string & s2) const { +// return std::lexicographical_compare +// (s1.begin (), s1.end (), // source range +// s2.begin (), s2.end (), // dest range +// nocase_compare ()); // comparison +// } +//}; class resp_capture { public: @@ -95,7 +95,7 @@ class resp_capture { string http_version; ssize_t code; string status; - map headers; + map headers; string body; string req_body; diff --git a/tests/stream_tests.cpp b/tests/stream_tests.cpp new file mode 100644 index 00000000..5e934d94 --- /dev/null +++ b/tests/stream_tests.cpp @@ -0,0 +1,374 @@ +/* +Copyright 2019-2021 NetFoundry, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include +#include +#include + +#include "fixtures.h" +#include "catch.hpp" + +extern tls_context *testServerTLS(); + +TEST_CASE("stream connect fail", "[stream]") { + UvLoopTest test; + + tlsuv_stream_t s; + tls_context *tls = default_tls_context(nullptr, 0); + tlsuv_stream_init(test.loop, &s, tls); + + uv_connect_t cr; + int conn_cb_called = 0; + cr.data = &conn_cb_called; + + auto cb = [](uv_connect_t *r, int status) { + int *countp = (int*)r->data; + *countp = *countp + 1; + printf("conn cb called status = %d(%s)\n", status, status != 0 ? uv_strerror(status) : ""); + + }; + int rc = 0; + + WHEN("connect fail") { + rc = tlsuv_stream_connect(&cr, &s, "127.0.0.1", 62443, cb); + test.run(); + CHECK(((rc == 0 && conn_cb_called == 1) || (rc != 0 && conn_cb_called == 0))); + } + WHEN("resolve fail") { + rc = tlsuv_stream_connect(&cr, &s, "foo.bar.baz", 443, cb); + test.run(); + CHECK(((rc == 0 && conn_cb_called == 1) || (rc != 0 && conn_cb_called == 0))); + } + + tlsuv_stream_free(&s); + tls->free_ctx(tls); +} + +TEST_CASE("cancel connect", "[stream]") { + UvLoopTest test; + + tlsuv_stream_t s; + tls_context *tls = default_tls_context(nullptr, 0); + tlsuv_stream_init(test.loop, &s, tls); + + struct test_ctx { + int connect_result; + bool close_called; + } test_ctx; + + test_ctx.connect_result = 0; + test_ctx.close_called = false; + + s.data = &test_ctx; + + uv_connect_t cr; + cr.data = &test_ctx; + int rc = tlsuv_stream_connect(&cr, &s, "1.1.1.1", 5555, [](uv_connect_t *r, int status) { + auto ctx = (struct test_ctx *) r->data; + ctx->connect_result = status; + }); + + uv_timer_t t; + uv_timer_init(test.loop, &t); + t.data = &s; + auto timer_cb = [](uv_timer_t* t){ + auto *c = static_cast(t->data); + uv_close_cb closeCb = [](uv_handle_t *h) { + auto s = (tlsuv_stream_t *) h; + auto ctx = (struct test_ctx*)s->data; + ctx->close_called = true; + }; + tlsuv_stream_close(c, closeCb); + uv_close(reinterpret_cast(t), nullptr); + }; + uv_timer_start(&t, timer_cb, 1000, 0); + + test.run(); + + CHECK(rc == 0); + CHECK(test_ctx.close_called); + CHECK(test_ctx.connect_result == UV_ECANCELED); + + tlsuv_stream_free(&s); + tls->free_ctx(tls); +} + +static void test_alloc(uv_handle_t *s, size_t req, uv_buf_t* b) { + b->base = static_cast(calloc(1, req)); + b->len = req; +} + +TEST_CASE("read/write","[stream]") { + UvLoopTest test; + + const char* proto[] = { + "foo", + "bar", + "http/1.1" + }; + tlsuv_stream_t s; + tls_context *tls = default_tls_context(nullptr, 0); + tlsuv_stream_init(test.loop, &s, tls); + tlsuv_stream_set_protocols(&s, 3, proto); + + struct test_ctx { + int connect_result; + bool close_called; + } test_ctx; + + test_ctx.connect_result = 0; + test_ctx.close_called = false; + + s.data = &test_ctx; + + uv_connect_t cr; + cr.data = &test_ctx; + int rc = tlsuv_stream_connect(&cr, &s, "1.1.1.1", 443, [](uv_connect_t *r, int status) { + REQUIRE(status == 0); + auto c = (tlsuv_stream_t *) r->handle; + + auto proto = tlsuv_stream_get_protocol(c); + REQUIRE(proto != nullptr); + CHECK_THAT(proto, Catch::Equals("http/1.1")); + + tlsuv_stream_read_start(c, test_alloc, [](uv_stream_t *s, ssize_t status, const uv_buf_t *b) { + auto c = (tlsuv_stream_t *) s; + auto ctx = (struct test_ctx *) c->data; + if (status == UV_EOF) { + tlsuv_stream_close(c, nullptr); + } else { + REQUIRE(status > 0); + REQUIRE_THAT(b->base, Catch::StartsWith("HTTP/1.1 200 OK")); + fprintf(stderr, "%.*s\n", (int) status, b->base); + } + free(b->base); + }); + + auto *wr = static_cast(calloc(1, sizeof(uv_write_t))); + const char *msg = R"(GET /dns-query?name=openziti.org&type=A HTTP/1.1 +Accept-Encoding: gzip, deflate +Connection: close +Host: 1.1.1.1 +User-Agent: HTTPie/1.0.2 +accept: application/dns-json + +)"; + uv_buf_t buf = uv_buf_init((char *) msg, strlen(msg)); + tlsuv_stream_write(wr, c, &buf, [](uv_write_t *wr, int rc) { + REQUIRE(rc == 0); + free(wr); + }); + }); + + test.run(); + + CHECK(rc == 0); + + tlsuv_stream_free(&s); + + tls->free_ctx(tls); +} + +struct connect_args_s { + tlsuv_stream_t *s; + const char *hostname; + int port; + struct test_result *result; +}; + +struct sleep_args_s { + int timeout; +}; + +struct write_args_s { + tlsuv_stream_t *s; + const char *data; +}; + +struct expected_result_s { + struct test_result *result; + const char *data; + int count; +}; + +typedef struct step_s step_t; +typedef void (*step_fn)(uv_loop_t *, struct step_s *); +struct step_s { + step_fn fn; + + union { + connect_args_s connect_args; + write_args_s write_args; + sleep_args_s sleep_args; + expected_result_s expected; + }; +}; + +static inline void start(uv_loop_t *l, step_t *s) { + if (s && s->fn) { + s->fn(l, s); + } +} +static inline step_t *next_step(step_t *s) { return ++s; } + +static inline void run_next(uv_loop_t *l, step_t *step) { + start(l, next_step(step)); +} + +static void sleep_cb(uv_timer_t *t) { + step_t *step = static_cast(t->data); + uv_close(reinterpret_cast(t), + reinterpret_cast(free)); + printf("sleep step is done\n"); + run_next(t->loop, step); +} + +static void sleep_step(uv_loop_t *l, step_t *step) { + printf("running sleep step\n"); + uv_timer_t *t = (uv_timer_t *)calloc(1, sizeof(*t)); + uv_timer_init(l, t); + t->data = step; + uv_timer_start(t, sleep_cb, step->sleep_args.timeout, 0); +} + +static void connect_cb(uv_connect_t *r, int status) { + printf("connected: %d\n", status); + step_t *s = (step_t *)r->data; + auto stream = (tlsuv_stream_t *)r->handle; + auto l = stream->loop; + free(r); + REQUIRE(status == 0); + run_next(l, s); +} + +static void connect_step(uv_loop_t *l, step_t *step) { + tlsuv_stream_t *clt = step->connect_args.s; + REQUIRE(tlsuv_stream_init(l, clt, testServerTLS()) == 0); + clt->data = step->connect_args.result; + uv_connect_t *r = (uv_connect_t *)calloc(1, sizeof(*r)); + r->data = step; + REQUIRE(tlsuv_stream_connect(r, clt, step->connect_args.hostname, step->connect_args.port, connect_cb) == 0); +} + +static void disconnect_cb(uv_handle_t *h) { + auto s = (tlsuv_stream_t *)h; + auto step = (step_t *)s->data; + tlsuv_stream_free(s); + run_next(s->loop, step); +} + +static void disconnect_step(uv_loop_t *l, step_t *step) { + auto s = step->connect_args.s; + s->data = step; + tlsuv_stream_close(step->connect_args.s, disconnect_cb); +} + +static void write_cb(uv_write_t *r, int status) { + auto stream = (tlsuv_stream_t *)r->handle; + auto step = (step_t *)r->data; + REQUIRE(status == 0); + delete r; + run_next(stream->loop, step); +} + +static void write_step(uv_loop_t *l, step_t *step) { + uv_write_t *r = new uv_write_t; + auto buf = uv_buf_init((char *)step->write_args.data, + strlen(step->write_args.data)); + r->data = step; + REQUIRE(tlsuv_stream_write(r, step->write_args.s, &buf, write_cb) == 0); +} + +struct test_result { + int read_count; + std::string read_data; + tlsuv_stream_t *stream; + + public: + explicit test_result(tlsuv_stream_t *s) + : stream(s), read_count(0), read_data("") {} +}; + +static void check_result(uv_loop_t *l, step_t *step) { + printf("read: %s\n", step->expected.result->read_data.c_str()); + REQUIRE(step->expected.result->read_data == step->expected.data); + CHECK(step->expected.result->read_count <= step->expected.count); + run_next(l, step); +} + +static void read_alloc(uv_handle_t *handle, size_t size, uv_buf_t *buf) { + buf->base = (char *)malloc(size); + buf->len = size; +} + +static void read_cb(uv_stream_t *stream, ssize_t nread, const uv_buf_t *buf) { + tlsuv_stream_t *clt = reinterpret_cast(stream); + test_result *result = static_cast(clt->data); + + REQUIRE(nread > 0); + result->read_count++; + result->read_data.append(buf->base, nread); + + free(buf->base); +} + +static void start_read_step(uv_loop_t *l, step_t *step) { + CHECK(tlsuv_stream_read_start(step->write_args.s, nullptr, nullptr) == UV_EINVAL); + CHECK(tlsuv_stream_read_start(step->write_args.s, read_alloc, nullptr) == UV_EINVAL); + CHECK(tlsuv_stream_read_start(step->write_args.s, read_alloc, read_cb) == 0); + CHECK(tlsuv_stream_read_start(step->write_args.s, read_alloc, read_cb) == UV_EALREADY); + run_next(l, step); +} + +static void stop_read_step(uv_loop_t *l, step_t *step) { + REQUIRE(tlsuv_stream_read_stop(step->write_args.s) == 0); + run_next(l, step); +} + +TEST_CASE("read start/stop", "[stream]") { + UvLoopTest loopTest; + tlsuv_stream_t s; + test_result r(&s); + + step_t steps[] = { + { + .fn = connect_step, + .connect_args = { .s = &s, .hostname = "localhost", .port = 7443, .result = &r }, + }, + { .fn = write_step, .write_args = { .s = &s, .data = "1",}}, + { .fn = write_step, .write_args = { .s = &s, .data = "2",}}, + { .fn = sleep_step, .sleep_args = { .timeout = 100, } }, + { .fn = check_result, .expected = { .result = &r, .data = "", .count = 0, }}, // not reading yet + { .fn = start_read_step, .write_args = { .s = &s }}, + { .fn = sleep_step, .sleep_args = { .timeout = 100, } }, + { .fn = check_result, .expected = { .result = &r, .data = "12", .count = 1,}}, // should read echo from two writes + { .fn = stop_read_step, .write_args = {.s = &s }}, + { .fn = write_step, .write_args = { .s = &s, .data = "3",}}, + { .fn = write_step, .write_args = { .s = &s, .data = "4",}}, + { .fn = sleep_step, .sleep_args = { .timeout = 100, } }, + { .fn = check_result, .expected = { .result = &r, .data = "12", .count = 1, }}, // not reading + { .fn = start_read_step, .write_args = { .s = &s }}, + { .fn = write_step, .write_args = { .s = &s, .data = "5",}}, + { .fn = write_step, .write_args = { .s = &s, .data = "6",}}, + { .fn = sleep_step, .sleep_args = { .timeout = 100, } }, + { .fn = check_result, .expected = { .result = &r, .data = "123456", .count = 4,}}, // should read echo from writes 3,4,5,6 + { .fn = disconnect_step, .connect_args = { .s = &s } }, + { .fn = nullptr } + }; + + start(loopTest.loop, steps); + loopTest.run(); +} \ No newline at end of file diff --git a/tests/test_server/test-server.go b/tests/test_server/test-server.go index 5bb3e4a9..ba400a96 100644 --- a/tests/test_server/test-server.go +++ b/tests/test_server/test-server.go @@ -53,6 +53,43 @@ func runClientAuth(port int, keyFile, certFile string) chan error { return done } +func runEchoServer(port int, keyFile, certFile string) chan error { + done := make(chan error) + cfg := &tls.Config{} + cert, _ := tls.LoadX509KeyPair(certFile, keyFile) + cfg.Certificates = append(cfg.Certificates, cert) + + go func() { + server, err := tls.Listen("tcp", fmt.Sprintf(":%d", port), cfg) + if err != nil { + done <- err + return + } + + for { + clt, err := server.Accept() + if err != nil { + done <- err + return + } + + go func() { + for { + buf := make([]byte, 1024) + n, err := clt.Read(buf) + if err != nil { + return + } + if wc, err := clt.Write(buf[:n]); err != nil || wc != n { + return + } + } + }() + } + }() + return done +} + var keyFile string var certFile string @@ -70,6 +107,7 @@ func main() { case err = <-runHTTP(8080, httpb): case err = <-runHTTPS(8443, httpb, keyFile, certFile): case err = <-runClientAuth(9443, keyFile, certFile): + case err = <-runEchoServer(7443, keyFile, certFile): } fmt.Println(err) diff --git a/tests/uv_mbed_tests.cpp b/tests/uv_mbed_tests.cpp deleted file mode 100644 index 075eb078..00000000 --- a/tests/uv_mbed_tests.cpp +++ /dev/null @@ -1,181 +0,0 @@ -/* -Copyright 2019-2021 NetFoundry, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -https://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#include -#include -#include - -#include "fixtures.h" -#include "catch.hpp" - -TEST_CASE("uv-mbed connect fail", "[uv-mbed]") { - UvLoopTest test; - - tlsuv_stream_t mbed; - tls_context *tls = default_tls_context(nullptr, 0); - tlsuv_stream_init(test.loop, &mbed, tls); - - uv_connect_t cr; - int conn_cb_called = 0; - cr.data = &conn_cb_called; - - auto cb = [](uv_connect_t *r, int status) { - int *countp = (int*)r->data; - *countp = *countp + 1; - printf("conn cb called status = %d(%s)\n", status, status != 0 ? uv_strerror(status) : ""); - - }; - int rc = 0; - - WHEN("connect fail") { - rc = tlsuv_stream_connect(&cr, &mbed, "127.0.0.1", 62443, cb); - test.run(); - CHECK(((rc == 0 && conn_cb_called == 1) || (rc != 0 && conn_cb_called == 0))); - } - WHEN("resolve fail") { - rc = tlsuv_stream_connect(&cr, &mbed, "foo.bar.baz", 443, cb); - test.run(); - CHECK(((rc == 0 && conn_cb_called == 1) || (rc != 0 && conn_cb_called == 0))); - } - - tlsuv_stream_free(&mbed); - tls->free_ctx(tls); -} - -TEST_CASE("cancel connect", "[uv-mbed]") { - UvLoopTest test; - - tlsuv_stream_t mbed; - tls_context *tls = default_tls_context(nullptr, 0); - tlsuv_stream_init(test.loop, &mbed, tls); - - struct test_ctx { - int connect_result; - bool close_called; - } test_ctx; - - test_ctx.connect_result = 0; - test_ctx.close_called = false; - - mbed.data = &test_ctx; - - uv_connect_t cr; - cr.data = &test_ctx; - int rc = tlsuv_stream_connect(&cr, &mbed, "1.1.1.1", 5555, [](uv_connect_t *r, int status) { - auto ctx = (struct test_ctx *) r->data; - ctx->connect_result = status; - }); - - uv_timer_t t; - uv_timer_init(test.loop, &t); - t.data = &mbed; - auto timer_cb = [](uv_timer_t* t){ - auto *c = static_cast(t->data); - uv_close_cb closeCb = [](uv_handle_t *h) { - auto mbed = (tlsuv_stream_t *) h; - auto ctx = (struct test_ctx*)mbed->data; - ctx->close_called = true; - }; - tlsuv_stream_close(c, closeCb); - uv_close(reinterpret_cast(t), nullptr); - }; - uv_timer_start(&t, timer_cb, 1000, 0); - - test.run(); - - CHECK(rc == 0); - CHECK(test_ctx.close_called); - CHECK(test_ctx.connect_result == UV_ECANCELED); - - tlsuv_stream_free(&mbed); - tls->free_ctx(tls); -} - -static void test_alloc(uv_handle_t *s, size_t req, uv_buf_t* b) { - b->base = static_cast(calloc(1, req)); - b->len = req; -} - -TEST_CASE("read/write","[uv-mbed]") { - UvLoopTest test; - - const char* proto[] = { - "foo", - "bar", - "http/1.1" - }; - tlsuv_stream_t mbed; - tls_context *tls = default_tls_context(nullptr, 0); - tlsuv_stream_init(test.loop, &mbed, tls); - tlsuv_stream_set_protocols(&mbed, 3, proto); - - struct test_ctx { - int connect_result; - bool close_called; - } test_ctx; - - test_ctx.connect_result = 0; - test_ctx.close_called = false; - - mbed.data = &test_ctx; - - uv_connect_t cr; - cr.data = &test_ctx; - int rc = tlsuv_stream_connect(&cr, &mbed, "1.1.1.1", 443, [](uv_connect_t *r, int status) { - REQUIRE(status == 0); - auto c = (tlsuv_stream_t *) r->handle; - - auto proto = tlsuv_stream_get_protocol(c); - REQUIRE(proto != nullptr); - CHECK_THAT(proto, Catch::Equals("http/1.1")); - - tlsuv_stream_read(c, test_alloc, [](uv_stream_t *s, ssize_t status, const uv_buf_t *b) { - auto c = (tlsuv_stream_t *) s; - auto ctx = (struct test_ctx *) c->data; - if (status == UV_EOF) { - tlsuv_stream_close(c, nullptr); - } else { - REQUIRE(status > 0); - REQUIRE_THAT(b->base, Catch::StartsWith("HTTP/1.1 200 OK")); - fprintf(stderr, "%.*s\n", (int) status, b->base); - } - free(b->base); - }); - - auto *wr = static_cast(calloc(1, sizeof(uv_write_t))); - const char *msg = R"(GET /dns-query?name=openziti.org&type=A HTTP/1.1 -Accept-Encoding: gzip, deflate -Connection: close -Host: 1.1.1.1 -User-Agent: HTTPie/1.0.2 -accept: application/dns-json - -)"; - uv_buf_t buf = uv_buf_init((char *) msg, strlen(msg)); - tlsuv_stream_write(wr, c, &buf, [](uv_write_t *wr, int rc) { - REQUIRE(rc == 0); - free(wr); - }); - }); - - test.run(); - - CHECK(rc == 0); - - tlsuv_stream_free(&mbed); - - tls->free_ctx(tls); -} \ No newline at end of file