123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396 |
--[[
Lua DNS resolver library.
See https://github.com/xjdrew/levent/blob/master/levent/dns.lua for more details.
-- resource record type:
-- TYPE value and meaning
-- A 1 a host address
-- NS 2 an authoritative name server
-- MD 3 a mail destination (Obsolete - use MX)
-- MF 4 a mail forwarder (Obsolete - use MX)
-- CNAME 5 the canonical name for an alias
-- SOA 6 marks the start of a zone of authority
-- MB 7 a mailbox domain name (EXPERIMENTAL)
-- MG 8 a mail group member (EXPERIMENTAL)
-- MR 9 a mail rename domain name (EXPERIMENTAL)
-- NULL 10 a null RR (EXPERIMENTAL)
-- WKS 11 a well known service description
-- PTR 12 a domain name pointer
-- HINFO 13 host information
-- MINFO 14 mailbox or mail list information
-- MX 15 mail exchange
-- TXT 16 text strings
-- AAAA 28 an IPv6 host address
-- the following only appear in the question section:
-- AXFR 252 A request for a transfer of an entire zone
-- MAILB 253 A request for mailbox-related records (MB, MG or MR)
-- MAILA 254 A request for mail agent RRs (Obsolete - see MX)
-- * 255 A request for all records
--
-- resource record class:
-- IN 1 the Internet
-- CS 2 the CSNET class (Obsolete - used only for examples in some obsolete RFCs)
-- CH 3 the CHAOS class
-- HS 4 Hesiod [Dyer 87]
-- the following only appears in the question section:
-- * 255 any class
-- ]]
--[[
-- struct header {
-- uint16_t tid # identifier assigned by the program that generates any kind of query.
-- uint16_t flags # flags
-- uint16_t qdcount # the number of entries in the question section.
-- uint16_t ancount # the number of resource records in the answer section.
-- uint16_t nscount # the number of name server resource records in the authority records section.
-- uint16_t arcount # the number of resource records in the additional records section.
-- }
--
-- request body (one question entry):
-- struct request {
-- string name
-- uint16_t atype
-- uint16_t class
-- }
--
-- response body (each resource record):
-- struct response {
-- string name
-- uint16_t atype
-- uint16_t class
-- uint32_t ttl
-- uint16_t rdlength
-- string rdata
-- }
--]]
- local skynet = require "skynet"
- local socket = require "skynet.socket"
-- Size limits, in bytes.
local MAX_DOMAIN_LEN, MAX_LABEL_LEN = 1024, 63
local MAX_PACKET_LEN, DNS_HEADER_LEN = 2048, 12

-- Query timeout in skynet ticks (1/100 s).
local TIMEOUT = 30 * 100 -- 30 seconds

-- Record types and classes this resolver knows about.
local QTYPE = { A = 1, CNAME = 5, AAAA = 28 }
local QCLASS = { IN = 1 }

-- Metatable for the per-type answer caches: weak keys and weak values.
local weak = { __mode = "kv" }

local CACHE = {}        -- qtype -> { [name] = { answers = {...}, ttl = expiry tick } }
local dns = {}          -- the public module table
local request_pool = {} -- tid -> pending request record
--- Discard every cached answer.
-- The A and AAAA caches are replaced by fresh tables carrying the `weak`
-- metatable, so cached entries may also disappear whenever the garbage
-- collector runs.
function dns.flush()
    for _, qtype in ipairs({ QTYPE.A, QTYPE.AAAA }) do
        CACHE[qtype] = setmetatable({}, weak)
    end
end

-- Create the empty caches at load time.
dns.flush()
-- RFC 1035 section 2.3.4: an encoded domain name is at most 255 octets.
local MAX_WIRE_NAME_LEN = 255

--- Check that `name` is a lowercase host name this resolver can query.
-- Accepted: labels made of `_`, lowercase letters, digits and `-`, separated
-- by single dots, optionally followed by one trailing dot.
-- Rejected: any other character, empty labels (leading or doubled dots),
-- labels longer than MAX_LABEL_LEN, names longer than MAX_DOMAIN_LEN
-- characters, and names whose wire encoding exceeds MAX_WIRE_NAME_LEN octets.
-- @tparam string name
-- @treturn boolean
local function verify_domain_name(name)
    if #name > MAX_DOMAIN_LEN then
        return false
    end
    if not name:match("^[_%l%d%-%.]+$") then
        return false
    end
    -- Allow one trailing dot (fully qualified form). Every other empty label
    -- would be silently skipped when the question is packed, so a different
    -- name would be queried; reject those instead.
    local body = name
    if body:sub(-1) == "." then
        body = body:sub(1, -2)
    end
    if body == "" then
        return false
    end
    local wire_len = 1 -- the terminating zero-length root label
    for label in (body .. "."):gmatch("([^.]*)%.") do
        if #label == 0 or #label > MAX_LABEL_LEN then
            return false
        end
        wire_len = wire_len + 1 + #label -- length byte + label bytes
    end
    return wire_len <= MAX_WIRE_NAME_LEN
end
-- Transaction id to try next; ids cycle through 1..65535.
local next_tid = 1

--- Allocate a DNS transaction id that no pending request is using.
-- Normally hands out `next_tid`. If that id is still pending, the lowest free
-- id in 1..65535 is used instead. Raises when every id is taken.
-- NOTE(review): ids are handed out sequentially, so they are predictable.
-- @treturn integer transaction id
local function gen_tid()
    local candidate = next_tid
    if request_pool[candidate] then
        candidate = nil
        local i = 1
        while candidate == nil and i <= 65535 do
            if not request_pool[i] then
                candidate = i
            end
            i = i + 1
        end
        assert(candidate)
    end
    -- 65535 wraps back to 1; every other id advances by one
    next_tid = candidate % 65535 + 1
    return candidate
end
--- Serialize a DNS header as six big-endian 16-bit fields.
-- `t.tid`, `t.flags` and `t.qdcount` are required; the answer, authority and
-- additional record counts default to 0.
-- @tparam table t header fields
-- @treturn string 12-byte header
local function pack_header(t)
    local ancount = t.ancount or 0
    local nscount = t.nscount or 0
    local arcount = t.arcount or 0
    return string.pack(">HHHHHH", t.tid, t.flags, t.qdcount, ancount, nscount, arcount)
end
--- Encode one question entry.
-- The name becomes a sequence of length-prefixed labels closed by a zero
-- byte, followed by the big-endian 16-bit qtype and qclass.
-- @tparam string name dotted host name
-- @tparam integer qtype
-- @tparam integer qclass
-- @treturn string encoded question
local function pack_question(name, qtype, qclass)
    local parts = {}
    for label in name:gmatch("([_%w%-]+)%.?") do
        parts[#parts + 1] = string.pack("s1", label)
    end
    parts[#parts + 1] = "\0"
    parts[#parts + 1] = string.pack(">HH", qtype, qclass)
    return table.concat(parts)
end
--- Parse the fixed 12-byte DNS header at the start of `chunk`.
-- @tparam string chunk raw packet
-- @treturn table header with tid, flags, qdcount, ancount, nscount, arcount
-- @treturn integer position of the first byte after the header
local function unpack_header(chunk)
    local values = table.pack(string.unpack(">HHHHHH", chunk))
    local header = {}
    for i, field in ipairs({ "tid", "flags", "qdcount", "ancount", "nscount", "arcount" }) do
        header[field] = values[i]
    end
    -- string.unpack's seventh result is the next read position
    return header, values[7]
end
--- Unpack a (possibly compressed) domain name from a DNS packet.
-- Labels are joined with "."; the result never carries a trailing dot.
-- @tparam string chunk raw packet
-- @tparam integer left 1-based position where the name starts
-- @treturn string the dotted name
-- @treturn integer position just past the name in the original byte stream
--   (after the first compression pointer, if one was followed)
-- @raise on truncated data, compression-pointer loops and unsupported label types
local function unpack_name(chunk, left)
    local labels = {}
    local resume_pos -- where parsing continues after the first pointer
    local jumps = 0
    -- A terminating walk never revisits a byte, so it cannot follow more
    -- pointers than the packet has bytes. Exceeding that means a loop in
    -- (untrusted) input, which would otherwise spin forever.
    local max_jumps = #chunk
    while true do
        local tag
        tag, left = string.unpack("B", chunk, left)
        local kind = tag & 0xc0
        if kind == 0xc0 then
            -- compression pointer: 14-bit offset into the packet
            jumps = jumps + 1
            if jumps > max_jumps then
                error("dns name compression pointer loop")
            end
            local offset
            offset, left = string.unpack(">H", chunk, left - 1)
            if not resume_pos then
                resume_pos = left
            end
            -- offset is base 0, string positions are base 1
            left = (offset & 0x3fff) + 1
        elseif tag == 0 then
            break
        elseif kind ~= 0 then
            -- 0x40 / 0x80 prefixes are not plain label lengths
            error("unsupported dns label type")
        else
            local label
            label, left = string.unpack("s1", chunk, left - 1)
            labels[#labels + 1] = label
        end
    end
    return table.concat(labels, "."), resume_pos or left
end
--- Parse one question entry: name, 16-bit type and 16-bit class.
-- @tparam string chunk raw packet
-- @tparam integer left position where the entry starts
-- @treturn table { name, atype, class }
-- @treturn integer position just past the entry
local function unpack_question(chunk, left)
    local qname, pos = unpack_name(chunk, left)
    local qtype, qclass, next_pos = string.unpack(">HH", chunk, pos)
    local question = { name = qname, atype = qtype, class = qclass }
    return question, next_pos
end
--- Parse one resource record.
-- Layout after the name: type (16 bit), class (16 bit), ttl (32 bit), then
-- rdata prefixed by its 16-bit length.
-- @tparam string chunk raw packet
-- @tparam integer left position where the record starts
-- @treturn table { name, atype, class, ttl, rdata }
-- @treturn integer position just past the record
local function unpack_answer(chunk, left)
    local rname, pos = unpack_name(chunk, left)
    local rtype, rclass, rttl, payload, next_pos = string.unpack(">HHI4s2", chunk, pos)
    local record = {
        name = rname,
        atype = rtype,
        class = rclass,
        ttl = rttl,
        rdata = payload,
    }
    return record, next_pos
end
--- Render the rdata of an address record as text.
-- A: dotted quad. AAAA: eight colon-separated lowercase hex groups, with no
-- zero compression. Any other type raises an error.
-- @tparam integer qtype record type
-- @tparam string chunk raw rdata
-- @treturn string textual address
local function unpack_rdata(qtype, chunk)
    if qtype == QTYPE.AAAA then
        local groups = { string.unpack(">HHHHHHHH", chunk) }
        return string.format("%x:%x:%x:%x:%x:%x:%x:%x", table.unpack(groups, 1, 8))
    elseif qtype == QTYPE.A then
        local o1, o2, o3, o4 = string.unpack("BBBB", chunk)
        return string.format("%d.%d.%d.%d", o1, o2, o3, o4)
    end
    error("Error qtype " .. qtype)
end
-- State of the upstream name server; every field starts out nil.
--   fd      UDP socket id while the socket is open, nil otherwise
--   address server host, set by dns.server
--   port    server port, set by dns.server
--   retire  skynet.now() at the most recent use
local dns_server = {}
-- Parse a raw response packet.
-- Returns the header, the question, the A and AAAA address lists (nil when
-- the response carries none) and the smallest TTL among those records.
-- Raises on malformed input.
local function parse_response(content)
    local header, pos = unpack_header(content)
    if header.qdcount ~= 1 then
        error("malformed packet")
    end
    local question
    question, pos = unpack_question(content, pos)
    local ipv4, ipv6, ttl
    for _ = 1, header.ancount do
        local record
        record, pos = unpack_answer(content, pos)
        local list
        if record.atype == QTYPE.A then
            ipv4 = ipv4 or {}
            list = ipv4
        elseif record.atype == QTYPE.AAAA then
            ipv6 = ipv6 or {}
            list = ipv6
        end
        if list then
            list[#list + 1] = unpack_rdata(record.atype, record.rdata)
            ttl = ttl and math.min(ttl, record.ttl) or record.ttl
        end
    end
    return header, question, ipv4, ipv6, ttl
end

--- Handle one UDP datagram received from the name server.
-- A/AAAA answers are cached under the echoed question name, expiring after
-- the smallest record TTL (converted to ticks). If the transaction id belongs
-- to a pending request whose name and type match the echoed question, that
-- request is given the cached answers (if any) and woken. Short, malformed
-- and mismatched packets are logged and dropped rather than raising inside
-- the socket callback.
-- @tparam string content raw packet
local function resolve(content)
    if #content < DNS_HEADER_LEN then
        -- drop
        skynet.error("Recv an invalid package when dns query")
        return
    end
    local ok, header, question, ipv4, ipv6, ttl = pcall(parse_response, content)
    if not ok then
        -- `header` holds the error message here
        skynet.error("Recv a malformed package when dns query: " .. tostring(header))
        return
    end
    -- ttl is non-nil whenever at least one address list exists
    local expire = ttl and skynet.now() + ttl * 100
    if ipv4 then
        CACHE[QTYPE.A][question.name] = { answers = ipv4, ttl = expire }
    end
    if ipv6 then
        CACHE[QTYPE.AAAA][question.name] = { answers = ipv6, ttl = expire }
    end
    local req = request_pool[header.tid]
    if not req then
        -- the request may have timed out already
        return
    end
    -- A reply for another question (e.g. a late reply whose tid has been
    -- reused) must not wake this request; keep waiting for the real answer.
    if question.name ~= req.name then
        skynet.error("Recv an invalid name when dns query")
        return
    end
    if question.atype ~= req.qtype then
        skynet.error("Recv an invalid type when dns query")
        return
    end
    local entry = CACHE[req.qtype][req.name]
    if entry then
        req.answers = entry.answers
    end
    skynet.wakeup(req.co)
end
-- Idle time, in ticks, after which the UDP socket to the server is closed.
local DNS_SERVER_RETIRE = 60 * 100

--- Mark the name server as in use and open the UDP socket if needed.
-- When a new socket is opened, a watchdog is armed. It re-checks every
-- 2 * DNS_SERVER_RETIRE ticks and closes the socket once no use has been
-- recorded for more than DNS_SERVER_RETIRE ticks.
local function touch_server()
    dns_server.retire = skynet.now()
    if dns_server.fd then
        return
    end

    local function check_alive()
        local deadline = dns_server.retire + DNS_SERVER_RETIRE
        if skynet.now() <= deadline then
            -- still in use: look again later
            skynet.timeout(2 * DNS_SERVER_RETIRE, check_alive)
            return
        end
        local stale_fd = dns_server.fd
        if not stale_fd then
            return
        end
        dns_server.fd = nil
        socket.close(stale_fd)
        skynet.error(string.format("Udp server close %s:%s (%d)",
            dns_server.address, dns_server.port, stale_fd))
    end

    local fd = socket.udp(function(packet)
        resolve(packet)
    end)
    dns_server.fd = fd
    skynet.error(string.format("Udp server open %s:%s (%d)",
        dns_server.address, dns_server.port, fd))
    socket.udp_connect(fd, dns_server.address, dns_server.port)
    skynet.timeout(2 * DNS_SERVER_RETIRE, check_alive)
end
--- Configure the upstream name server and open the UDP socket to it.
-- @param server host address. When nil, the first active `nameserver` line of
--   /etc/resolv.conf is used; commented-out lines are ignored.
-- @param port UDP port, default 53
-- @return the server address in use
-- @raise if no server can be determined, or if the socket is currently open
function dns.server(server, port)
    if not server then
        local f = assert(io.open "/etc/resolv.conf")
        for line in f:lines() do
            -- Anchored at line start: an unanchored match would also pick up
            -- commented-out entries such as "# nameserver 10.0.0.1".
            server = line:match("^%s*nameserver%s+([^%s]+)")
            if server then
                break
            end
        end
        f:close()
        assert(server, "Can't get nameserver")
    end
    -- refuse to reconfigure while the socket to the current server is open
    assert(dns_server.fd == nil)
    dns_server.address = server
    dns_server.port = port or 53
    touch_server()
    return dns_server.address
end
--- Look up cached answers for `name` and `qtype`.
-- @tparam string name lowercase host name
-- @tparam integer qtype QTYPE.A or QTYPE.AAAA
-- @param ignorettl when truthy, expired entries are returned too
-- @return the cached answer list, or nothing when absent or expired
local function lookup_cache(name, qtype, ignorettl)
    local entry = CACHE[qtype][name]
    if not entry then
        return
    end
    if not ignorettl and entry.ttl <= skynet.now() then
        return
    end
    return entry.answers
end
--- Park the running coroutine until the reply for `tid` arrives or the query
-- times out after TIMEOUT ticks.
-- @tparam integer tid transaction id of the query just sent
-- @tparam string name queried name
-- @tparam integer qtype queried record type
-- @return first answer and the full answer list. If the reply carried no
--   usable answer, or the query timed out, any cached entry is used, even
--   an expired one.
-- @raise "timeout or no answer" when nothing is available
local function suspend(tid, name, qtype)
    local req = {
        name = name,
        tid = tid,
        qtype = qtype,
        co = coroutine.running(),
    }
    request_pool[tid] = req
    skynet.fork(function()
        skynet.sleep(TIMEOUT)
        -- Cancel only this request. Once it has finished, the tid may have
        -- been handed to a newer query, which must not be woken early.
        if request_pool[tid] == req then
            skynet.error(string.format("DNS query %s timeout", name))
            request_pool[tid] = nil
            skynet.wakeup(req.co)
        end
    end)
    skynet.wait(req.co)
    -- Same reasoning: after a timeout the slot may already belong to a newer
    -- request, so only release it if it is still ours.
    if request_pool[tid] == req then
        request_pool[tid] = nil
    end
    local answers = req.answers
    if not answers then
        answers = lookup_cache(name, qtype, true)
        if not answers then
            error "timeout or no answer"
        end
    end
    return answers[1], answers
end
--- Resolve a host name to an address, consulting the cache first.
-- @tparam string name host name. It is case-insensitive, and one trailing
--   dot (fully qualified form) is accepted.
-- @param ipv6 truthy to query AAAA records instead of A records
-- @return first address and the full list of addresses
-- @raise "illegal name", "Call dns.server first", or a lookup failure
function dns.resolve(name, ipv6)
    local qtype = ipv6 and QTYPE.AAAA or QTYPE.A
    -- Names parsed from responses never end with a dot, and replies are
    -- matched against the requested name. Normalize "example.com." to
    -- "example.com" so fully qualified input does not always fail.
    local qname = (name:lower():gsub("%.$", ""))
    assert(verify_domain_name(qname) and qname:sub(-1) ~= ".", "illegal name")
    local answers = lookup_cache(qname, qtype)
    if answers then
        return answers[1], answers
    end
    -- check before allocating a transaction id
    assert(dns_server.address, "Call dns.server first")
    local question_header = {
        tid = gen_tid(),
        flags = 0x100, -- flags: 00000001 00000000, set RD
        qdcount = 1,
    }
    local req = pack_header(question_header) .. pack_question(qname, qtype, QCLASS.IN)
    touch_server()
    socket.write(dns_server.fd, req)
    return suspend(question_header.tid, qname, qtype)
end

return dns
|