Skip to content

Commit

Permalink
refactor: dns_resolver.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
a1383n committed Jun 9, 2024
1 parent 8c4c59e commit 279145f
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 22 deletions.
12 changes: 7 additions & 5 deletions src/inet/dns/dns.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
#include <cstring>
#include "dns.hpp"

DNS::Packet::Packet(const uint8_t *data, size_t len) {
DNS::Packet::Packet(SocketClient *socketClient, const uint8_t *data, size_t len) : socketClient(socketClient) {
ldns_enum_status s = ldns_wire2pkt(&_pkt, data, len);

if (s != LDNS_STATUS_OK) {
throw std::runtime_error("Invalid status");
}
}

std::vector<std::string> DNS::resolveQuestions(ldns_pkt *dnsPacket) {
std::vector<std::string> DNS::resolveQuestions(DNS::Packet *packet) {
ldns_pkt *dnsPacket = packet->_pkt;
std::vector<std::string> resolvedIPs;
HttpDNSResolver resolver = HttpDNSResolver();

Expand All @@ -24,7 +25,7 @@ std::vector<std::string> DNS::resolveQuestions(ldns_pkt *dnsPacket) {
if (qname) {
std::string qnameStr(qname);
std::string ip;
int resolveStatus = resolver.resolve(qnameStr, &ip);
int resolveStatus = resolver.resolve(inet_ntoa(packet->socketClient->addr.sin_addr), qnameStr, &ip);
if (resolveStatus == 0) {
resolvedIPs.push_back(ip);
}
Expand All @@ -37,7 +38,8 @@ std::vector<std::string> DNS::resolveQuestions(ldns_pkt *dnsPacket) {
return resolvedIPs;
}

ssize_t DNS::createResponse(uint8_t *buffer, ldns_pkt *dnsPacket) {
ssize_t DNS::createResponse(DNS::Packet *packet, uint8_t *buffer) {
ldns_pkt *dnsPacket = packet->_pkt;
ldns_pkt *responsePacket = ldns_pkt_new();
if (!responsePacket) {
return -1;
Expand All @@ -59,7 +61,7 @@ ssize_t DNS::createResponse(uint8_t *buffer, ldns_pkt *dnsPacket) {
}
}

auto resolvedIPs = DNS::resolveQuestions(dnsPacket);
auto resolvedIPs = DNS::resolveQuestions(packet);
for (const std::string &ip: resolvedIPs) {
ldns_rr *answerRR = ldns_rr_new();
if (questions && ldns_rr_list_rr_count(questions) > 0) {
Expand Down
8 changes: 5 additions & 3 deletions src/inet/dns/dns.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,22 @@

#include <ldns/ldns.h>

#include "../socket/socket.hpp"
#include "dns_resolver.hpp"

class DNS {
public:
class Packet {
public:
Packet(const uint8_t *data, size_t len);
Packet(SocketClient *socketClient, const uint8_t *data, size_t len);

ldns_pkt *_pkt;
SocketClient *socketClient;
};

static ssize_t createResponse(uint8_t *buffer, ldns_pkt *dnsPacket);
static ssize_t createResponse(DNS::Packet *packet, uint8_t *buffer);

static std::vector<std::string> resolveQuestions(ldns_pkt *dnsPacket);
static std::vector<std::string> resolveQuestions(DNS::Packet *packet);
};

#endif //DNS_REVERSE_PROXY_DNS_HPP
45 changes: 35 additions & 10 deletions src/inet/dns/dns_resolver.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "dns_resolver.hpp"
#include "../../config/config.h"
#include <curl/curl.h>
#include <jsoncpp/json/json.h>
#include <string>

size_t writeCallback(void *contents, size_t size, size_t nmemb, std::string *response) {
Expand All @@ -9,23 +10,30 @@ size_t writeCallback(void *contents, size_t size, size_t nmemb, std::string *res
return totalSize;
}

int HttpDNSResolver::resolve(std::string qname, std::string *ip) {
int HttpDNSResolver::resolve(std::string clientIp, std::string qname, std::string *ip) {
CURL *curl = curl_easy_init();
if (!curl) {
return 0;
return -1;
}

// Construct the URL with qname parameter
std::string url = Config::getHttpResolverUrl();
url += "?qname=" + qname;

curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeCallback);
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1);
Json::Value postData;
postData["q"] = qname;
postData["ip"] = clientIp;
Json::StreamWriterBuilder writer;
std::string postFields = Json::writeString(writer, postData);

std::string response;

curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeCallback);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &response);
curl_easy_setopt(curl, CURLOPT_POST, 1);
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, postFields.c_str());
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, postFields.size());
curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1);
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, curl_slist_append(nullptr, "Content-Type: application/json"));

CURLcode res = curl_easy_perform(curl);
long httpStatusCode = 0;
Expand All @@ -34,9 +42,26 @@ int HttpDNSResolver::resolve(std::string qname, std::string *ip) {
curl_easy_cleanup(curl);

if (res == CURLE_OK && httpStatusCode < 400 && !response.empty()) {
*ip = response;
return 0;
try {
Json::CharReaderBuilder readerBuilder;
Json::Value jsonResponse;
std::string errs;

std::istringstream responseStream(response);
if (Json::parseFromStream(readerBuilder, responseStream, &jsonResponse, &errs)) {
if (jsonResponse["ok"].asBool()) {
*ip = jsonResponse["result"]["address"].asString();
return 0;
} else {
return -2;
}
} else {
return -1;
}
} catch (const std::exception &e) {
return -1;
}
} else {
return -1;
}
}
}
4 changes: 2 additions & 2 deletions src/inet/dns/dns_resolver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

class DNSResolver {
public:
virtual int resolve(std::string qname, std::string *ip) = 0;
virtual int resolve(std::string clientIp, std::string qname, std::string *ip) = 0;
};

class HttpDNSResolver : public DNSResolver {
public:
virtual int resolve(std::string qname, std::string *ip);
int resolve(std::string clientIp, std::string qname, std::string *ip) override;
};


Expand Down
4 changes: 2 additions & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ int main(int argc, char *argv[]) {
auto handler = [](SocketClient *client, void *buff, size_t len) {
uint8_t b[4096];
try {
auto *packet = new DNS::Packet((const uint8_t *) buff, len);
auto l = DNS::createResponse(b, packet->_pkt);
auto *packet = new DNS::Packet(client, (const uint8_t *) buff, len);
auto l = DNS::createResponse(packet, b);
client->send(b, l);
delete packet;
} catch (std::runtime_error &e) {
Expand Down

0 comments on commit 279145f

Please sign in to comment.