diff --git a/src/conf.h b/src/conf.h index 3c8d41d19..5a755a3af 100644 --- a/src/conf.h +++ b/src/conf.h @@ -204,7 +204,8 @@ typedef struct _popular_server_t { */ typedef struct _ip_trusted_t { - char ip[HTTP_IP_ADDR_LEN]; + char ip[INET_ADDRSTRLEN]; + unsigned int uip; struct _ip_trusted_t *next; } t_ip_trusted; diff --git a/src/dns_forward.c b/src/dns_forward.c index d29a820a4..634673aaa 100644 --- a/src/dns_forward.c +++ b/src/dns_forward.c @@ -42,110 +42,108 @@ strrstr(const char *haystack, const char *needle) return result; }; +static void +parse_name(unsigned char *response, unsigned char **ptr, char *name) { + int jumped = 0, offset; + unsigned char *pos = *ptr; + char *name_pos = name; + + while (*pos) { + if (*pos >= 192) { + offset = ((*pos) << 8 | *(pos + 1)) & 0x3FFF; + pos = response + offset; + jumped = 1; + } else { + int segment_len = *pos; + pos++; + memcpy(name_pos, pos, segment_len); + name_pos += segment_len; + pos += segment_len; + *name_pos++ = '.'; + } + if (!jumped) (*ptr)++; + } + *name_pos = '\0'; + if (!jumped) (*ptr)++; +} + static int -parse_dns_response_ex(unsigned char *response, int response_len) { +process_dns_response(unsigned char *response, int response_len) { s_config *config = config_get_config(); - unsigned char *ptr = response + 12; // Skip the DNS header - int qdcount = ntohs(*(unsigned short *)(response + 4)); - int ancount = ntohs(*(unsigned short *)(response + 6)); + + if (response_len <= sizeof(struct dns_header)) { + debug(LOG_WARNING, "Invalid DNS response"); + return -1; + } + struct dns_header *header = (struct dns_header *)response; + if (header->qr != 1 || header->opcode != 0 || header->rcode != 0) { + debug(LOG_WARNING, "Invalid DNS response"); + return -1; + } + + char query_name[MAX_DNS_NAME] = {0}; + int qdcount = ntohs(header->qdcount); + int ancount = ntohs(header->ancount); + unsigned char *ptr = response + sizeof(struct dns_header); // Skip the DNS header + debug(LOG_DEBUG, "DNS response: qdcount=%d, ancount=%d", qdcount, ancount); // Skip the question section + assert(qdcount == 1); // Only handle one question (A record) for (int i = 0; i < qdcount; i++) { - while (*ptr != 0) ptr++; - ptr += 5; // Skip null byte, type, and class + parse_name(response, &ptr, query_name); + ptr += 4; // Skip QTYPE and QCLASS + } + debug(LOG_DEBUG, "DNS query: %s", query_name); + // find the trusted domain in the query_name + t_domain_trusted *p = config->pan_domains_trusted; + while (p) { + if (strrstr(query_name, p->domain)) { + debug(LOG_DEBUG, "Find trusted wildcard domain: %s", p->domain); + break; + } + p = p->next; + } + + if (!p) { + debug(LOG_DEBUG, "No trusted wildcard domain found in the query"); + return 0; } // Parse the answer section for (int i = 0; i < ancount; i++) { - char domain[256]; - int len = 0; - - while (*ptr != 0) { - if (*ptr >= 192) { // Handle pointers - ptr += 2; - break; - } - memcpy(domain + len, ptr + 1, *ptr); - len += *ptr; - domain[len++] = '.'; - ptr += *ptr + 1; - } - domain[--len] = '\0'; // Null-terminate the domain string - ptr++; + char answer_name[MAX_DNS_NAME] = {0}; + parse_name(response, &ptr, answer_name); unsigned short type = ntohs(*(unsigned short *)ptr); - ptr += 2; ptr += 8; // Skip class, TTL, and data length unsigned short data_len = ntohs(*(unsigned short *)ptr); ptr += 2; - if (type == 1 && data_len == 4) { // Type A - char ip[INET_ADDRSTRLEN]; - inet_ntop(AF_INET, ptr, ip, sizeof(ip)); - t_domain_trusted *p = config->pan_domains_trusted; - while (p) { - if (strstr(domain, p->domain)) { - t_ip_trusted *ip_entry = malloc(sizeof(t_ip_trusted)); - strcpy(ip_entry->ip, ip); - ip_entry->next = p->ips_trusted; - p->ips_trusted = ip_entry; - printf("Trusted domain: %s -> %s\n", domain, ip); + if (type == 1 && data_len == 4) { // Type A record + t_ip_trusted *ip_trusted = p->ips_trusted; + while(ip_trusted) { + if (ip_trusted->uip == *(unsigned int *)ptr) { break; } - p = p->next; + ip_trusted = ip_trusted->next; + } + if (!ip_trusted) { + char ip[INET_ADDRSTRLEN] = {0}; + inet_ntop(AF_INET, ptr, ip, sizeof(ip)); + debug(LOG_DEBUG, "Trusted domain: %s -> %s", query_name, ip); + t_ip_trusted *new_ip_trusted = (t_ip_trusted *)malloc(sizeof(t_ip_trusted)); + new_ip_trusted->uip = *(unsigned int *)ptr; + new_ip_trusted->next = p->ips_trusted; + p->ips_trusted = new_ip_trusted; + char cmd[128] = {0}; + snprintf(cmd, sizeof(cmd), "nft add element set wifidogx_inner_trust_domains %s", ip); + system(cmd); } } ptr += data_len; } - return 0; -} - -static void -process_dns_response(char *response, int response_len) -{ - ns_msg handle; - s_config *config = config_get_config(); - - - if (ns_initparse((const uint8_t *)response, response_len, &handle) < 0) { - debug(LOG_WARNING, "ns_initparse: %s", strerror(errno)); - return; - } - int msg_count = ns_msg_count(handle, ns_s_an); - debug(LOG_DEBUG, "DNS response contains %d answers, response_len %d", msg_count, response_len); - for (int i = 0; i < msg_count; i++) { - ns_rr rr; - int nret; - if (ns_parserr(&handle, ns_s_an, i, &rr) < 0) { - debug(LOG_WARNING, "ns_parserr: %s", strerror(errno)); - continue; - } - debug(LOG_DEBUG, "DNS response type: %d", ns_rr_type(rr)); - - if (ns_rr_type(rr) == ns_t_a) { - t_domain_trusted *p = NULL; - char domain[4096] = {0}; - nret = ns_name_uncompress(ns_msg_base(handle), ns_msg_end(handle), ns_rr_name(rr), domain, sizeof(domain)); - if (nret < 0) { - debug(LOG_WARNING, "ns_name_uncompress: %s", strerror(errno)); - continue; - } - debug(LOG_DEBUG, "Get dns response domain: %s", domain); - for (p = config->pan_domains_trusted; p; p = p->next) { - // reverse match the domain, check if the domain include p->domain - if (strrstr(domain, p->domain) != NULL) { - struct in_addr addr; - char cmd[128] = {0}; - memcpy(&addr, ns_rr_rdata(rr), sizeof(addr)); - debug(LOG_DEBUG, "Trusted domain: %s -> %s", domain, inet_ntoa(addr)); - snprintf(cmd, sizeof(cmd), "nft add element set wifidogx_inner_trust_domains %s", inet_ntoa(addr)); - system(cmd); - break; - } - } - } - } + return 0; } static void diff --git a/src/dns_forward.h b/src/dns_forward.h index c7b9bec00..c7812cfaa 100644 --- a/src/dns_forward.h +++ b/src/dns_forward.h @@ -3,6 +3,24 @@ #define DNS_FORWARD_PORT 15353 #define LOCAL_DNS_PORT 53 +#define MAX_DNS_NAME 256 + +// define a structure to hold the DNS header +struct dns_header { + unsigned short id; + unsigned char rd :1; + unsigned char tc :1; + unsigned char aa :1; + unsigned char opcode :4; + unsigned char qr :1; + unsigned char rcode :4; + unsigned char z :3; + unsigned char ra :1; + unsigned short qdcount; + unsigned short ancount; + unsigned short nscount; + unsigned short arcount; +}; void *dns_forward_thread(void *);