dns.lua 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. --[[
  2. lua dns resolver library
  3. See https://github.com/xjdrew/levent/blob/master/levent/dns.lua for more detail
  4. -- resource record type:
  5. -- TYPE value and meaning
  6. -- A 1 a host address
  7. -- NS 2 an authoritative name server
  8. -- MD 3 a mail destination (Obsolete - use MX)
  9. -- MF 4 a mail forwarder (Obsolete - use MX)
  10. -- CNAME 5 the canonical name for an alias
  11. -- SOA 6 marks the start of a zone of authority
  12. -- MB 7 a mailbox domain name (EXPERIMENTAL)
  13. -- MG 8 a mail group member (EXPERIMENTAL)
  14. -- MR 9 a mail rename domain name (EXPERIMENTAL)
  15. -- NULL 10 a null RR (EXPERIMENTAL)
  16. -- WKS 11 a well known service description
  17. -- PTR 12 a domain name pointer
  18. -- HINFO 13 host information
  19. -- MINFO 14 mailbox or mail list information
  20. -- MX 15 mail exchange
  21. -- TXT 16 text strings
  22. -- AAAA 28 an IPv6 host address
  23. -- only appear in the question section:
  24. -- AXFR 252 A request for a transfer of an entire zone
  25. -- MAILB 253 A request for mailbox-related records (MB, MG or MR)
  26. -- MAILA 254 A request for mail agent RRs (Obsolete - see MX)
  27. -- * 255 A request for all records
  28. --
  29. -- resource record class:
  30. -- IN 1 the Internet
  31. -- CS 2 the CSNET class (Obsolete - used only for examples in some obsolete RFCs)
  32. -- CH 3 the CHAOS class
  33. -- HS 4 Hesiod [Dyer 87]
  34. -- only appear in the question section:
  35. -- * 255 any class
  36. -- ]]
  37. --[[
  38. -- struct header {
  39. -- uint16_t tid # identifier assigned by the program that generates any kind of query.
  40. -- uint16_t flags # flags
  41. -- uint16_t qdcount # the number of entries in the question section.
  42. -- uint16_t ancount # the number of resource records in the answer section.
  43. -- uint16_t nscount # the number of name server resource records in the authority records section.
  44. -- uint16_t arcount # the number of resource records in the additional records section.
  45. -- }
  46. --
  47. -- request body:
  48. -- struct request {
  49. -- string name
  50. -- uint16_t atype
  51. -- uint16_t class
  52. -- }
  53. --
  54. -- response body:
  55. -- struct response {
  56. -- string name
  57. -- uint16_t atype
  58. -- uint16_t class
  59. -- uint32_t ttl
  60. -- uint16_t rdlength
  61. -- string rdata
  62. -- }
  63. --]]
  64. local skynet = require "skynet"
  65. local socket = require "skynet.socket"
  66. local MAX_DOMAIN_LEN = 1024
  67. local MAX_LABEL_LEN = 63
  68. local MAX_PACKET_LEN = 2048
  69. local DNS_HEADER_LEN = 12
  70. local TIMEOUT = 30 * 100 -- 30 seconds
  71. local QTYPE = {
  72. A = 1,
  73. CNAME = 5,
  74. AAAA = 28,
  75. }
  76. local QCLASS = {
  77. IN = 1,
  78. }
  79. local weak = {__mode = "kv"}
  80. local CACHE = {}
  81. local dns = {}
  82. local request_pool = {}
  83. function dns.flush()
  84. CACHE[QTYPE.A] = setmetatable({},weak)
  85. CACHE[QTYPE.AAAA] = setmetatable({},weak)
  86. end
  87. dns.flush()
  88. local function verify_domain_name(name)
  89. if #name > MAX_DOMAIN_LEN then
  90. return false
  91. end
  92. if not name:match("^[_%l%d%-%.]+$") then
  93. return false
  94. end
  95. for w in name:gmatch("([_%w%-]+)%.?") do
  96. if #w > MAX_LABEL_LEN then
  97. return false
  98. end
  99. end
  100. return true
  101. end
  102. local next_tid = 1
  103. local function gen_tid()
  104. local tid = next_tid
  105. if request_pool[tid] then
  106. tid = nil
  107. for i = 1, 65535 do
  108. -- find available tid
  109. if not request_pool[i] then
  110. tid = i
  111. break
  112. end
  113. end
  114. assert(tid)
  115. end
  116. next_tid = tid + 1
  117. if next_tid > 65535 then
  118. next_tid = 1
  119. end
  120. return tid
  121. end
  122. local function pack_header(t)
  123. return string.pack(">HHHHHH",
  124. t.tid, t.flags, t.qdcount, t.ancount or 0, t.nscount or 0, t.arcount or 0)
  125. end
  126. local function pack_question(name, qtype, qclass)
  127. local labels = {}
  128. for w in name:gmatch("([_%w%-]+)%.?") do
  129. table.insert(labels, string.pack("s1",w))
  130. end
  131. table.insert(labels, '\0')
  132. table.insert(labels, string.pack(">HH", qtype, qclass))
  133. return table.concat(labels)
  134. end
  135. local function unpack_header(chunk)
  136. local tid, flags, qdcount, ancount, nscount, arcount, left = string.unpack(">HHHHHH", chunk)
  137. return {
  138. tid = tid,
  139. flags = flags,
  140. qdcount = qdcount,
  141. ancount = ancount,
  142. nscount = nscount,
  143. arcount = arcount
  144. }, left
  145. end
  146. -- unpack a resource name
  147. local function unpack_name(chunk, left)
  148. local t = {}
  149. local jump_pointer
  150. local tag, offset, label
  151. while true do
  152. tag, left = string.unpack("B", chunk, left)
  153. if tag & 0xc0 == 0xc0 then
  154. -- pointer
  155. offset,left = string.unpack(">H", chunk, left - 1)
  156. offset = offset & 0x3fff
  157. if not jump_pointer then
  158. jump_pointer = left
  159. end
  160. -- offset is base 0, need to plus 1
  161. left = offset + 1
  162. elseif tag == 0 then
  163. break
  164. else
  165. label, left = string.unpack("s1", chunk, left - 1)
  166. t[#t+1] = label
  167. end
  168. end
  169. return table.concat(t, "."), jump_pointer or left
  170. end
  171. local function unpack_question(chunk, left)
  172. local name, left = unpack_name(chunk, left)
  173. local atype, class, left = string.unpack(">HH", chunk, left)
  174. return {
  175. name = name,
  176. atype = atype,
  177. class = class
  178. }, left
  179. end
  180. local function unpack_answer(chunk, left)
  181. local name, left = unpack_name(chunk, left)
  182. local atype, class, ttl, rdata, left = string.unpack(">HHI4s2", chunk, left)
  183. return {
  184. name = name,
  185. atype = atype,
  186. class = class,
  187. ttl = ttl,
  188. rdata = rdata
  189. },left
  190. end
  191. local function unpack_rdata(qtype, chunk)
  192. if qtype == QTYPE.A then
  193. local a,b,c,d = string.unpack("BBBB", chunk)
  194. return string.format("%d.%d.%d.%d", a,b,c,d)
  195. elseif qtype == QTYPE.AAAA then
  196. local a,b,c,d,e,f,g,h = string.unpack(">HHHHHHHH", chunk)
  197. return string.format("%x:%x:%x:%x:%x:%x:%x:%x", a, b, c, d, e, f, g, h)
  198. else
  199. error("Error qtype " .. qtype)
  200. end
  201. end
  202. local dns_server = {
  203. fd = nil,
  204. address = nil,
  205. port = nil,
  206. retire = nil,
  207. }
  208. local function resolve(content)
  209. if #content < DNS_HEADER_LEN then
  210. -- drop
  211. skynet.error("Recv an invalid package when dns query")
  212. return
  213. end
  214. local answer_header,left = unpack_header(content)
  215. -- verify answer
  216. assert(answer_header.qdcount == 1, "malformed packet")
  217. local question,left = unpack_question(content, left)
  218. local ttl
  219. local answer
  220. local answers_ipv4
  221. local answers_ipv6
  222. for i=1, answer_header.ancount do
  223. answer, left = unpack_answer(content, left)
  224. local answers
  225. if answer.atype == QTYPE.A then
  226. answers_ipv4 = answers_ipv4 or {}
  227. answers = answers_ipv4
  228. elseif answer.atype == QTYPE.AAAA then
  229. answers_ipv6 = answers_ipv6 or {}
  230. answers = answers_ipv6
  231. end
  232. if answers then
  233. local ip = unpack_rdata(answer.atype, answer.rdata)
  234. ttl = ttl and math.min(ttl, answer.ttl) or answer.ttl
  235. answers[#answers+1] = ip
  236. end
  237. end
  238. if answers_ipv4 then
  239. CACHE[QTYPE.A][question.name] = { answers = answers_ipv4, ttl = skynet.now() + ttl * 100 }
  240. end
  241. if answers_ipv6 then
  242. CACHE[QTYPE.AAAA][question.name] = { answers = answers_ipv6, ttl = skynet.now() + ttl * 100 }
  243. end
  244. local resp = request_pool[answer_header.tid]
  245. if not resp then
  246. -- the resp may be timeout
  247. return
  248. end
  249. if question.name ~= resp.name then
  250. skynet.error("Recv an invalid name when dns query")
  251. end
  252. local r = CACHE[resp.qtype][resp.name]
  253. if r then
  254. resp.answers = r.answers
  255. end
  256. skynet.wakeup(resp.co)
  257. end
  258. local DNS_SERVER_RETIRE = 60 * 100
  259. local function touch_server()
  260. dns_server.retire = skynet.now()
  261. if dns_server.fd then
  262. return
  263. end
  264. dns_server.fd = socket.udp(function(str, from)
  265. resolve(str)
  266. end)
  267. skynet.error(string.format("Udp server open %s:%s (%d)", dns_server.address, dns_server.port, dns_server.fd))
  268. socket.udp_connect(dns_server.fd, dns_server.address, dns_server.port)
  269. local function check_alive()
  270. if skynet.now() > dns_server.retire + DNS_SERVER_RETIRE then
  271. local fd = dns_server.fd
  272. if fd then
  273. dns_server.fd = nil
  274. socket.close(fd)
  275. skynet.error(string.format("Udp server close %s:%s (%d)", dns_server.address, dns_server.port, fd))
  276. end
  277. else
  278. skynet.timeout( 2 * DNS_SERVER_RETIRE, check_alive)
  279. end
  280. end
  281. skynet.timeout( 2 * DNS_SERVER_RETIRE, check_alive)
  282. end
  283. function dns.server(server, port)
  284. if not server then
  285. local f = assert(io.open "/etc/resolv.conf")
  286. for line in f:lines() do
  287. server = line:match("%s*nameserver%s+([^%s]+)")
  288. if server then
  289. break
  290. end
  291. end
  292. f:close()
  293. assert(server, "Can't get nameserver")
  294. end
  295. assert(dns_server.fd == nil) -- only set dns.server once
  296. dns_server.address = server
  297. dns_server.port = port or 53
  298. touch_server()
  299. return dns_server.address
  300. end
  301. local function lookup_cache(name, qtype, ignorettl)
  302. local result = CACHE[qtype][name]
  303. if result then
  304. if ignorettl or (result.ttl > skynet.now()) then
  305. return result.answers
  306. end
  307. end
  308. end
  309. local function suspend(tid, name, qtype)
  310. local req = {
  311. name = name,
  312. tid = tid,
  313. qtype = qtype,
  314. co = coroutine.running(),
  315. }
  316. request_pool[tid] = req
  317. skynet.fork(function()
  318. skynet.sleep(TIMEOUT)
  319. local req = request_pool[tid]
  320. if req then
  321. -- cancel tid
  322. skynet.error(string.format("DNS query %s timeout", name))
  323. request_pool[tid] = nil
  324. skynet.wakeup(req.co)
  325. end
  326. end)
  327. skynet.wait(req.co)
  328. local answers = req.answers
  329. request_pool[tid] = nil
  330. if not req.answers then
  331. local answers = lookup_cache(name, qtype, true)
  332. if answers then
  333. return answers[1], answers
  334. end
  335. error "timeout or no answer"
  336. end
  337. return req.answers[1], req.answers
  338. end
  339. function dns.resolve(name, ipv6)
  340. local qtype = ipv6 and QTYPE.AAAA or QTYPE.A
  341. local name = name:lower()
  342. assert(verify_domain_name(name) , "illegal name")
  343. local answers = lookup_cache(name, qtype)
  344. if answers then
  345. return answers[1], answers
  346. end
  347. local question_header = {
  348. tid = gen_tid(),
  349. flags = 0x100, -- flags: 00000001 00000000, set RD
  350. qdcount = 1,
  351. }
  352. local req = pack_header(question_header) .. pack_question(name, qtype, QCLASS.IN)
  353. assert(dns_server.address, "Call dns.server first")
  354. touch_server()
  355. socket.write(dns_server.fd, req)
  356. return suspend(question_header.tid, name, qtype)
  357. end
  358. return dns