diff --git a/Makefile.am b/Makefile.am index eb47ea9c3..30072dbd1 100644 --- a/Makefile.am +++ b/Makefile.am @@ -37,6 +37,7 @@ libdqlite_la_SOURCES = \ src/registry.c \ src/request.c \ src/response.c \ + src/revamp.c \ src/roles.c \ src/server.c \ src/stmt.c \ diff --git a/src/bind.c b/src/bind.c index 1bd622986..d32e20cdc 100644 --- a/src/bind.c +++ b/src/bind.c @@ -2,7 +2,7 @@ #include "tuple.h" /* Bind a single parameter. */ -static int bind_one(sqlite3_stmt *stmt, int n, struct value *value) +static int bind_one(sqlite3_stmt *stmt, int n, const struct value *value) { int rc; @@ -48,37 +48,79 @@ static int bind_one(sqlite3_stmt *stmt, int n, struct value *value) return rc; } -int bind__params(sqlite3_stmt *stmt, struct cursor *cursor, int format) +int parseParams(struct cursor *cursor, int format, struct value **out) { struct tuple_decoder decoder; + struct value *head; + struct value *prev; unsigned long i; - int rc; + int rv; assert(format == TUPLE__PARAMS || format == TUPLE__PARAMS32); - sqlite3_reset(stmt); - /* If the payload has been fully consumed, it means there are no * parameters to bind. */ if (cursor->cap == 0) { return 0; } - rc = tuple_decoder__init(&decoder, 0, format, cursor); - if (rc != 0) { - return rc; + rv = tuple_decoder__init(&decoder, 0, format, cursor); + if (rv != 0) { + goto err; } + + head = sqlite3_malloc(sizeof(*head)); + if (head == NULL) { + rv = DQLITE_NOMEM; + goto err; + } + prev = head; for (i = 0; i < tuple_decoder__n(&decoder); i++) { - struct value value; - rc = tuple_decoder__next(&decoder, &value); - if (rc != 0) { - return rc; + prev->next = sqlite3_malloc(sizeof(*prev->next)); + if (prev->next == NULL) { + goto err_after_alloc_head; } - rc = bind_one(stmt, (int)(i + 1), &value); - if (rc != 0) { - return rc; + rv = tuple_decoder__next(&decoder, prev->next); + if (rv != 0) { + goto err_after_alloc_head; } + prev = prev->next; } + *out = head; return 0; + +err_after_alloc_head: + freeParams(head); +err: + return rv; +} + +int bindParams(sqlite3_stmt *stmt, const struct value *params) +{ + int i; + int rv; + + i = 1; + for (const struct value *cur = params; cur != NULL; cur = cur->next) { + rv = bind_one(stmt, i, cur); + if (rv != 0) { + return rv; + } + i += 1; + } + return 0; +} + +void freeParams(struct value *params) +{ + struct value *cur; + struct value *old; + + cur = params; + while (cur != NULL) { + old = cur; + cur = old->next; + sqlite3_free(old); + } } diff --git a/src/bind.h b/src/bind.h index 1b2bc5a60..ebc019f95 100644 --- a/src/bind.h +++ b/src/bind.h @@ -9,9 +9,12 @@ #include "lib/serialize.h" -/** - * Bind the parameters of the given statement by decoding the given payload. - */ -int bind__params(sqlite3_stmt *stmt, struct cursor *cursor, int format); +struct value; + +int parseParams(struct cursor *cursor, int format, struct value **out); + +int bindParams(sqlite3_stmt *stmt, const struct value *params); + +void freeParams(struct value *params); #endif /* BIND_H_*/ diff --git a/src/conn.c b/src/conn.c index 1b031bc5b..33a59f2b4 100644 --- a/src/conn.c +++ b/src/conn.c @@ -2,6 +2,7 @@ #include "message.h" #include "protocol.h" #include "request.h" +#include "revamp.h" #include "tracing.h" #include "transport.h" @@ -295,7 +296,8 @@ int conn__start(struct conn *c, struct uv_stream_s *stream, struct raft_uv_transport *uv_transport, struct id_state seed, - conn_close_cb close_cb) + conn_close_cb close_cb, + struct db_context *db_ctx) { int rv; (void)loop; @@ -309,7 +311,7 @@ int conn__start(struct conn *c, c->transport.data = c; c->uv_transport = uv_transport; c->close_cb = close_cb; - gateway__init(&c->gateway, config, registry, raft, seed); + gateway__init(&c->gateway, config, registry, raft, seed, db_ctx); rv = buffer__init(&c->read); if (rv != 0) { goto err_after_transport_init; diff --git a/src/conn.h b/src/conn.h index fb93c8f27..5da7d0be0 100644 --- a/src/conn.h +++ b/src/conn.h @@ -14,6 +14,7 @@ #include "gateway.h" #include "id.h" #include "message.h" +#include "revamp.h" /** * Callbacks. @@ -52,7 +53,8 @@ int conn__start(struct conn *c, struct uv_stream_s *stream, struct raft_uv_transport *uv_transport, struct id_state seed, - conn_close_cb close_cb); + conn_close_cb close_cb, + struct db_context *db_ctx); /** * Force closing the connection. The close callback will be invoked when it's diff --git a/src/gateway.c b/src/gateway.c index 61288a1c1..dfe2108e6 100644 --- a/src/gateway.c +++ b/src/gateway.c @@ -15,7 +15,8 @@ void gateway__init(struct gateway *g, struct config *config, struct registry *registry, struct raft *raft, - struct id_state seed) + struct id_state seed, + struct db_context *db_ctx) { tracef("gateway init"); g->config = config; @@ -31,6 +32,7 @@ void gateway__init(struct gateway *g, g->protocol = DQLITE_PROTOCOL_VERSION; g->client_id = 0; g->random_state = seed; + g->db_ctx = db_ctx; } void gateway__leader_close(struct gateway *g, int reason) @@ -465,6 +467,7 @@ static int handle_exec(struct gateway *g, struct handle *req) struct request_exec request = {0}; int tuple_format; uint64_t req_id; + struct value *params; int rv; switch (req->schema) { @@ -491,12 +494,20 @@ static int handle_exec(struct gateway *g, struct handle *req) LOOKUP_DB(request.db_id); LOOKUP_STMT(request.stmt_id); FAIL_IF_CHECKPOINTING; - rv = bind__params(stmt->stmt, cursor, tuple_format); + rv = parseParams(cursor, tuple_format, ¶ms); + if (rv != 0) { + tracef("handle exec parse params failed %d", rv); + failure(req, rv, "parse parameters"); + return 0; + } + rv = bindParams(stmt->stmt, params); if (rv != 0) { tracef("handle exec bind failed %d", rv); + freeParams(params); failure(req, rv, "bind parameters"); return 0; } + freeParams(params); req->stmt_id = stmt->id; g->req = req; req_id = idNext(&g->random_state); @@ -595,6 +606,7 @@ static int handle_query(struct gateway *g, struct handle *req) int tuple_format; bool is_readonly; uint64_t req_id; + struct value *params; int rv; switch (req->schema) { @@ -621,12 +633,20 @@ static int handle_query(struct gateway *g, struct handle *req) LOOKUP_DB(request.db_id); LOOKUP_STMT(request.stmt_id); FAIL_IF_CHECKPOINTING; - rv = bind__params(stmt->stmt, cursor, tuple_format); + rv = parseParams(cursor, tuple_format, ¶ms); + if (rv != 0) { + tracef("handle query parse params failed %d", rv); + failure(req, rv, "bind parameters"); + return 0; + } + rv = bindParams(stmt->stmt, params); if (rv != 0) { tracef("handle query bind failed %d", rv); + freeParams(params); failure(req, rv, "bind parameters"); return 0; } + freeParams(params); req->stmt_id = stmt->id; g->req = req; @@ -697,6 +717,7 @@ static void handle_exec_sql_next(struct gateway *g, const char *tail; int tuple_format; uint64_t req_id; + struct value *params; int rv; if (req->sql == NULL || strcmp(req->sql, "") == 0) { @@ -728,8 +749,14 @@ static void handle_exec_sql_next(struct gateway *g, /* Should have been caught by handle_exec_sql */ assert(0); } - rv = bind__params(stmt, cursor, tuple_format); - if (rv != SQLITE_OK) { + rv = parseParams(cursor, tuple_format, ¶ms); + if (rv != 0) { + failure(req, rv, "parse parameters"); + goto done_after_prepare; + } + rv = bindParams(stmt, params); + if (rv != 0) { + freeParams(params); failure(req, rv, "bind parameters"); goto done_after_prepare; } @@ -847,6 +874,7 @@ static void querySqlBarrierCb(struct barrier *barrier, int status) int tuple_format; bool is_readonly; uint64_t req_id; + struct value *params; int rv; if (status != 0) { @@ -886,9 +914,17 @@ static void querySqlBarrierCb(struct barrier *barrier, int status) /* Should have been caught by handle_query_sql */ assert(0); } - rv = bind__params(stmt, cursor, tuple_format); + rv = parseParams(cursor, tuple_format, ¶ms); + if (rv != 0) { + tracef("handle query sql parse params failed %d", rv); + sqlite3_finalize(stmt); + failure(req, rv, "parse parameters"); + return; + } + rv = bindParams(stmt, params); if (rv != 0) { tracef("handle query sql bind failed %d", rv); + freeParams(params); sqlite3_finalize(stmt); failure(req, rv, "bind parameters"); return; diff --git a/src/gateway.h b/src/gateway.h index dd07fdd59..c6fcacc49 100644 --- a/src/gateway.h +++ b/src/gateway.h @@ -37,13 +37,15 @@ struct gateway uint64_t protocol; /* Protocol format version */ uint64_t client_id; struct id_state random_state; /* For generating IDs */ + struct db_context *db_ctx; }; void gateway__init(struct gateway *g, struct config *config, struct registry *registry, struct raft *raft, - struct id_state seed); + struct id_state seed, + struct db_context *db_ctx); void gateway__close(struct gateway *g); diff --git a/src/revamp.c b/src/revamp.c new file mode 100644 index 000000000..18745b686 --- /dev/null +++ b/src/revamp.c @@ -0,0 +1 @@ +#include "revamp.h" diff --git a/src/revamp.h b/src/revamp.h new file mode 100644 index 000000000..fd5eed561 --- /dev/null +++ b/src/revamp.h @@ -0,0 +1,11 @@ +#ifndef DQLITE_REVAMP_H +#define DQLITE_REVAMP_H + +#include + +struct db_context +{ + sem_t sem; +}; + +#endif diff --git a/src/server.c b/src/server.c index 75b9e8d07..760367d9e 100644 --- a/src/server.c +++ b/src/server.c @@ -167,6 +167,8 @@ void dqlite__close(struct dqlite_node *d) if (!d->initialized) { return; } + sem_destroy(&d->db_ctx->sem); + free(d->db_ctx); raft_free(d->listener); rv = sem_destroy(&d->stopped); assert(rv == 0); /* Fails only if sem object is not valid */ @@ -526,6 +528,9 @@ static void stopCb(uv_async_t *stop) conn__stop(conn); } raft_close(&d->raft, raftCloseCb); + + sem_post(&d->db_ctx->sem); + pthread_join(d->db_thread, NULL); } /* Callback invoked as soon as the loop as started. @@ -624,7 +629,7 @@ static void listenCb(uv_stream_t *listener, int status) goto err; } rv = conn__start(conn, &t->config, &t->loop, &t->registry, &t->raft, - stream, &t->raft_transport, seed, destroy_conn); + stream, &t->raft_transport, seed, destroy_conn, t->db_ctx); if (rv != 0) { goto err_after_conn_alloc; } @@ -671,10 +676,27 @@ static void roleManagementTimerCb(uv_timer_t *handle) RolesAdjust(d); } +static void *dbTask(void *arg) +{ + struct db_context *ctx = arg; + sem_wait(&ctx->sem); + return NULL; +} + static int taskRun(struct dqlite_node *d) { int rv; + d->db_ctx = malloc(sizeof *d->db_ctx); + if (d->db_ctx == NULL) { + return DQLITE_NOMEM; + } + rv = sem_init(&d->db_ctx->sem, 0, 0); + assert(rv == 0); + + rv = pthread_create(&d->db_thread, NULL, dbTask, d->db_ctx); + assert(rv == 0); + /* TODO: implement proper cleanup upon error by spinning the loop a few * times. */ assert(d->listener != NULL); diff --git a/src/server.h b/src/server.h index 300e74086..4803127fd 100644 --- a/src/server.h +++ b/src/server.h @@ -13,6 +13,7 @@ #include "lib/assert.h" #include "logger.h" #include "registry.h" +#include "revamp.h" #define DQLITE_ERRMSG_BUF_SIZE 300 @@ -57,6 +58,8 @@ struct dqlite_node void *connect_func_arg; /* User data for connection function */ char errmsg[DQLITE_ERRMSG_BUF_SIZE]; /* Last error occurred */ struct id_state random_state; /* For seeding ID generation */ + pthread_t db_thread; + struct db_context *db_ctx; }; /* Dynamic array of node info objects. This is the in-memory representation of diff --git a/src/tuple.h b/src/tuple.h index 4f4ba9bab..732e87e6a 100644 --- a/src/tuple.h +++ b/src/tuple.h @@ -78,6 +78,7 @@ struct value int64_t unixtime; /* Unix time in seconds since epoch */ uint64_t boolean; }; + struct value *next; }; /**