diff --git a/include/PgSQL_Protocol.h b/include/PgSQL_Protocol.h index 18aebda45..25f485bd7 100644 --- a/include/PgSQL_Protocol.h +++ b/include/PgSQL_Protocol.h @@ -694,7 +694,35 @@ class PgSQL_Protocol : public MySQL_Protocol { * updates the output buffer with the generated packet. If `ready` is * true, it also generates and sends a ready-for-query packet. */ - bool generate_ok_packet(bool send, bool ready, const char* msg, int rows, const char* query, PtrSize_t* _ptr = NULL); + bool generate_ok_packet(bool send, bool ready, const char* msg, int rows, const char* query, char trx_state = 'I', PtrSize_t* _ptr = NULL); + + // temporary overriding generate_pkt_OK to avoid crash. FIXME remove this + bool generate_pkt_OK(bool send, void** ptr, unsigned int* len, uint8_t sequence_id, unsigned int affected_rows, + uint64_t last_insert_id, uint16_t status, uint16_t warnings, char* msg, bool eof_identifier = false) { + char txn_state = 'I'; + if (status & SERVER_STATUS_IN_TRANS) { + txn_state = 'T'; + } + return generate_ok_packet(send, true, msg, affected_rows, "OK 1", txn_state); + } + + // temporary overriding generate_pkt_EOF to avoid crash. FIXME remove this + bool generate_pkt_EOF(bool send, void** ptr, unsigned int* len, uint8_t sequence_id, uint16_t warnings, + uint16_t status, MySQL_ResultSet* myrs = NULL) { + char txn_state = 'I'; + if (status & SERVER_STATUS_IN_TRANS) { + txn_state = 'T'; + } + return generate_ok_packet(send, true, NULL, 0, "OK 1", txn_state); + } + + // temporary overriding generate_pkt_ERR to avoid crash. FIXME remove this + bool generate_pkt_ERR(bool send, void** ptr, unsigned int* len, uint8_t sequence_id, uint16_t error_code, + char* sql_state, const char* sql_message, bool track = false) { + + generate_error_packet(send, true, sql_message, PGSQL_ERROR_CODES::ERRCODE_RAISE_EXCEPTION, false, track); + return true; + } //bool generate_row_description(bool send, PgSQL_Query_Result* rs, const PG_Fields& fields, unsigned int size); diff --git a/include/query_processor.h b/include/query_processor.h index 4b95e115d..c5e0a2471 100644 --- a/include/query_processor.h +++ b/include/query_processor.h @@ -1,9 +1,9 @@ #ifndef __CLASS_QUERY_PROCESSOR_H #define __CLASS_QUERY_PROCESSOR_H -#include "proxysql.h" -#include "cpp.h" #include #include +#include "proxysql.h" +#include "cpp.h" // Optimization introduced in 2.0.6 // to avoid a lot of unnecessary copy diff --git a/lib/MySQL_Protocol.cpp b/lib/MySQL_Protocol.cpp index 1b23d5fd1..5eedf0247 100644 --- a/lib/MySQL_Protocol.cpp +++ b/lib/MySQL_Protocol.cpp @@ -1173,6 +1173,7 @@ bool MySQL_Protocol::generate_pkt_initial_handshake(bool send, void **ptr, unsig return true; } +#ifdef PROXYSQLCLICKHOUSE void ch_account_to_my(account_details_t& account, ch_account_details_t& ch_account) { account.username = ch_account.username; account.password = ch_account.password; @@ -1196,6 +1197,7 @@ void ch_account_to_my(account_details_t& account, ch_account_details_t& ch_accou account.attributes = nullptr; // No attributes by default account.comment = nullptr; // No comment by default } +#endif /* PROXYSQLCLICKHOUSE */ bool MySQL_Protocol::process_pkt_auth_swich_response(unsigned char *pkt, unsigned int len) { bool ret=false; diff --git a/lib/PgSQL_Protocol.cpp b/lib/PgSQL_Protocol.cpp index 513f0b1f9..1a87ff8a4 100644 --- a/lib/PgSQL_Protocol.cpp +++ b/lib/PgSQL_Protocol.cpp @@ -692,6 +692,27 @@ bool PgSQL_Protocol::process_startup_packet(unsigned char* pkt, unsigned int len return true; } +char* extract_password(const pgsql_hdr* hdr, uint32_t* len) { + char* pass = NULL; + uint32_t pass_len = hdr->data.size; + + if (pass_len == 0) + return NULL; + + pass = (char*)malloc(pass_len + 1); + memcpy(pass, hdr->data.ptr, pass_len); + pass[pass_len] = 0; + + if (pass_len) { + if (pass[pass_len - 1] == 0) { + pass_len--; // remove the extra 0 if present + } + } + + if (len) *len = pass_len; + return pass; +} + EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* pkt, unsigned int len) { #ifdef DEBUG //if (dump_pkt) { __dump_pkt(__func__, pkt, len); } @@ -764,31 +785,24 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* ((*myds)->sess->session_type == PROXYSQL_SESSION_SQLITE) ) { if (strcmp((const char*)user, mysql_thread___monitor_username) == 0) { - if (strcmp(password, mysql_thread___monitor_password) == 0) { - (*myds)->sess->default_hostgroup = STATS_HOSTGROUP; - (*myds)->sess->default_schema = strdup((char*)"main"); // just the pointer is passed - (*myds)->sess->schema_locked = false; - (*myds)->sess->transaction_persistent = false; - (*myds)->sess->session_fast_forward = false; - (*myds)->sess->user_max_connections = 0; - password = l_strdup(mysql_thread___monitor_password); - ret = EXECUTION_STATE::SUCCESSFUL; - } + (*myds)->sess->default_hostgroup = STATS_HOSTGROUP; + (*myds)->sess->default_schema = strdup((char*)"main"); // just the pointer is passed + (*myds)->sess->schema_locked = false; + (*myds)->sess->transaction_persistent = false; + (*myds)->sess->session_fast_forward = false; + (*myds)->sess->user_max_connections = 0; + password = l_strdup(mysql_thread___monitor_password); } } } - if (password) { proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p , user='%s' , auth_method=%s\n", (*myds), (*myds)->sess, user, AUTHENTICATION_METHOD_STR[(int)(*myds)->auth_method]); switch ((*myds)->auth_method) { case AUTHENTICATION_METHOD::MD5_PASSWORD: { - uint32_t pass_len = hdr.data.size; - pass = (char*)malloc(pass_len + 1); - memcpy(pass, hdr.data.ptr, pass_len); - pass[pass_len] = 0; - + uint32_t pass_len = 0; + pass = extract_password(&hdr, &pass_len); using_password = (pass_len > 0); if (pass_len) { @@ -831,19 +845,10 @@ EXECUTION_STATE PgSQL_Protocol::process_handshake_response_packet(unsigned char* break; case AUTHENTICATION_METHOD::CLEAR_TEXT_PASSWORD: { - uint32_t pass_len = hdr.data.size; - pass = (char*)malloc(pass_len + 1); - memcpy(pass, hdr.data.ptr, pass_len); - pass[pass_len] = 0; - + uint32_t pass_len = 0; + pass = extract_password(&hdr, &pass_len); using_password = (pass_len > 0); - if (pass_len) { - if (pass[pass_len - 1] == 0) { - pass_len--; // remove the extra 0 if present - } - } - if (!pass || *pass == '\0') { proxy_debug(PROXY_DEBUG_MYSQL_AUTH, 5, "Session=%p , DS=%p , user='%s'. Empty password returned by client.\n", (*myds), (*myds)->sess, user); generate_error_packet(true, false, "empty password returned by client", PGSQL_ERROR_CODES::ERRCODE_PROTOCOL_VIOLATION, true); @@ -1261,7 +1266,7 @@ char* extract_tag_from_query(const char* query) { } -bool PgSQL_Protocol::generate_ok_packet(bool send, bool ready, const char* msg, int rows, const char* query, PtrSize_t* _ptr) { +bool PgSQL_Protocol::generate_ok_packet(bool send, bool ready, const char* msg, int rows, const char* query, char trx_state, PtrSize_t* _ptr) { // to avoid memory leak assert(send == true || _ptr); @@ -1293,7 +1298,7 @@ bool PgSQL_Protocol::generate_ok_packet(bool send, bool ready, const char* msg, } if (ready == true) { - pgpkt.write_ReadyForQuery(); + pgpkt.write_ReadyForQuery(trx_state); pgpkt.set_multi_pkt_mode(false); } diff --git a/lib/PgSQL_Query_Processor.cpp b/lib/PgSQL_Query_Processor.cpp index e3427e712..e4fa0038e 100644 --- a/lib/PgSQL_Query_Processor.cpp +++ b/lib/PgSQL_Query_Processor.cpp @@ -642,7 +642,7 @@ SQLite3_result* PgSQL_Query_Processor::get_stats_commands_counters() { result->add_column_definition(SQLITE_TEXT, "cnt_5s"); result->add_column_definition(SQLITE_TEXT, "cnt_10s"); result->add_column_definition(SQLITE_TEXT, "cnt_INFs"); - for (int i = 0; i < MYSQL_COM_QUERY__UNINITIALIZED; i++) { + for (int i = 0; i < PGSQL_QUERY__UNINITIALIZED; i++) { char** pta = commands_counters[i]->get_row(); result->add_row(pta); commands_counters[i]->free_row(pta); diff --git a/lib/PgSQL_Session.cpp b/lib/PgSQL_Session.cpp index c6b2f42d9..c454982d1 100644 --- a/lib/PgSQL_Session.cpp +++ b/lib/PgSQL_Session.cpp @@ -4785,15 +4785,15 @@ void PgSQL_Session::handler_WCD_SS_MCQ_qpo_QueryRewrite(PtrSize_t* pkt) { if (thread->variables.stats_time_query_processor) { clock_gettime(CLOCK_THREAD_CPUTIME_ID, &begint); } - pkt->size = sizeof(mysql_hdr) + 1 + qpo->new_query->length(); - pkt->ptr = l_alloc(pkt->size); - mysql_hdr hdr; - hdr.pkt_id = 0; - hdr.pkt_length = pkt->size - sizeof(mysql_hdr); - memcpy((unsigned char*)pkt->ptr, &hdr, sizeof(mysql_hdr)); // copy header - unsigned char* c = (unsigned char*)pkt->ptr + sizeof(mysql_hdr); - *c = (unsigned char)_MYSQL_COM_QUERY; // set command type - memcpy((unsigned char*)pkt->ptr + sizeof(mysql_hdr) + 1, qpo->new_query->data(), qpo->new_query->length()); // copy query + + PG_pkt pgpkt(1 + 4 + qpo->new_query->length() + 1); + pgpkt.put_char('Q'); + pgpkt.put_uint32(4 + qpo->new_query->length() + 1); + pgpkt.put_bytes(qpo->new_query->data(), qpo->new_query->length()); + pgpkt.put_char('\0'); + auto buff = pgpkt.detach(); + pkt->ptr = buff.first; + pkt->size = buff.second; CurrentQuery.query_parser_free(); CurrentQuery.begin((unsigned char*)pkt->ptr, pkt->size, true); delete qpo->new_query; @@ -4811,9 +4811,8 @@ void PgSQL_Session::handler_WCD_SS_MCQ_qpo_OK_msg(PtrSize_t* pkt) { client_myds->DSS = STATE_QUERY_SENT_NET; unsigned int nTrx = NumActiveTransactions(); - uint16_t setStatus = (nTrx ? SERVER_STATUS_IN_TRANS : 0); - if (autocommit) setStatus |= SERVER_STATUS_AUTOCOMMIT; - client_myds->myprot.generate_pkt_OK(true, NULL, NULL, client_myds->pkt_sid + 1, 0, 0, setStatus, 0, qpo->OK_msg); + const char trx_state = (nTrx ? 'T' : 'I'); + client_myds->myprot.generate_ok_packet(true, true, qpo->OK_msg, 0, (const char*)pkt->ptr + 5, trx_state); RequestEnd(NULL); l_free(pkt->size, pkt->ptr); } @@ -4821,7 +4820,8 @@ void PgSQL_Session::handler_WCD_SS_MCQ_qpo_OK_msg(PtrSize_t* pkt) { // this function as inline in handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_COM_QUERY_qpo void PgSQL_Session::handler_WCD_SS_MCQ_qpo_error_msg(PtrSize_t* pkt) { client_myds->DSS = STATE_QUERY_SENT_NET; - client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, client_myds->pkt_sid + 1, 1148, (char*)"42000", qpo->error_msg); + client_myds->myprot.generate_error_packet(true, true, qpo->error_msg, + PGSQL_ERROR_CODES::ERRCODE_INSUFFICIENT_PRIVILEGE, false); RequestEnd(NULL); l_free(pkt->size, pkt->ptr); } @@ -4830,7 +4830,8 @@ void PgSQL_Session::handler_WCD_SS_MCQ_qpo_error_msg(PtrSize_t* pkt) { void PgSQL_Session::handler_WCD_SS_MCQ_qpo_LargePacket(PtrSize_t* pkt) { // ER_NET_PACKET_TOO_LARGE client_myds->DSS = STATE_QUERY_SENT_NET; - client_myds->myprot.generate_pkt_ERR(true, NULL, NULL, client_myds->pkt_sid + 1, 1153, (char*)"08S01", (char*)"Got a packet bigger than 'max_allowed_packet' bytes", true); + client_myds->myprot.generate_error_packet(true, true, "Got a packet bigger than 'max_allowed_packet' bytes", + PGSQL_ERROR_CODES::ERRCODE_PROGRAM_LIMIT_EXCEEDED, false); RequestEnd(NULL); l_free(pkt->size, pkt->ptr); } @@ -5830,6 +5831,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C return true; } //} + /* Query Cache is not supported for PgSQL if (qpo->cache_ttl > 0 && ((prepare_stmt_type & PgSQL_ps_type_prepare_stmt) == 0)) { bool deprecate_eof_active = client_myds->myconn->options.client_flag & CLIENT_DEPRECATE_EOF; uint32_t resbuf = 0; @@ -5855,7 +5857,7 @@ bool PgSQL_Session::handler___status_WAITING_CLIENT_DATA___STATE_SLEEP___MYSQL_C l_free(pkt->size, pkt->ptr); return true; } - } + }*/ __exit_set_destination_hostgroup: diff --git a/lib/Query_Processor.cpp b/lib/Query_Processor.cpp index 7c736ed2c..87077fd13 100644 --- a/lib/Query_Processor.cpp +++ b/lib/Query_Processor.cpp @@ -30,12 +30,6 @@ using json = nlohmann::json; #define GET_THREAD_VARIABLE(VARIABLE_NAME) \ ({((std::is_same_v) ? mysql_thread___##VARIABLE_NAME : pgsql_thread___##VARIABLE_NAME) ;}) -template -class Query_Processor; - -template -class Query_Processor; - extern MySQL_Threads_Handler *GloMTH; extern PgSQL_Threads_Handler* GloPTH; extern ProxySQL_Admin *GloAdmin; @@ -68,7 +62,7 @@ static bool rules_sort_comp_function (QP_rule_t * a, QP_rule_t * b) { } static unsigned long long mem_used_rule(QP_rule_t *qr) { - unsigned long long s = sizeof(QP_rule_t); + unsigned long long s = 0; if (qr->username) s+=strlen(qr->username); if (qr->schemaname) @@ -384,6 +378,7 @@ bool Query_Processor::insert(QP_rule_t *qr, bool lock) { if (lock) wrlock(); rules.push_back(qr); + rules_mem_used += sizeof(TypeQueryRule); rules_mem_used += mem_used_rule(qr); if (lock) wrunlock(); @@ -2435,3 +2430,8 @@ void Query_Processor_Output::get_info_json(json& j) { j["retries"] = retries; j["max_lag_ms"] = max_lag_ms; } +template +class Query_Processor; + +template +class Query_Processor;