-- mysql.lua (15 KB)
-- Copyright (C) 2012 Yichun Zhang (agentzh)
-- Copyright (C) 2014 Chang Feng
-- This file is a modified version of https://github.com/openresty/lua-resty-mysql
-- It is distributed under the BSD license.
-- Modified by Cloud Wu (removed bit32 for Lua 5.3)
  6. local socketchannel = require "skynet.socketchannel"
  7. local mysqlaux = require "skynet.mysqlaux.c"
  8. local crypt = require "skynet.crypt"
  9. local sub = string.sub
  10. local strgsub = string.gsub
  11. local strformat = string.format
  12. local strbyte = string.byte
  13. local strchar = string.char
  14. local strrep = string.rep
  15. local strunpack = string.unpack
  16. local strpack = string.pack
  17. local sha1= crypt.sha1
  18. local setmetatable = setmetatable
  19. local error = error
  20. local tonumber = tonumber
  21. local new_tab = function (narr, nrec) return {} end
  22. local _M = { _VERSION = '0.13' }
  23. -- constants
  24. local STATE_CONNECTED = 1
  25. local STATE_COMMAND_SENT = 2
  26. local COM_QUERY = 0x03
  27. local SERVER_MORE_RESULTS_EXISTS = 8
  28. -- 16MB - 1, the default max allowed packet size used by libmysqlclient
  29. local FULL_PACKET_SIZE = 16777215
  30. local mt = { __index = _M }
  31. -- mysql field value type converters
  32. local converters = new_tab(0, 8)
  33. for i = 0x01, 0x05 do
  34. -- tiny, short, long, float, double
  35. converters[i] = tonumber
  36. end
  37. converters[0x08] = tonumber -- long long
  38. converters[0x09] = tonumber -- int24
  39. converters[0x0d] = tonumber -- year
  40. converters[0xf6] = tonumber -- newdecimal
  41. local function _get_byte2(data, i)
  42. return strunpack("<I2",data,i)
  43. end
  44. local function _get_byte3(data, i)
  45. return strunpack("<I3",data,i)
  46. end
  47. local function _get_byte4(data, i)
  48. return strunpack("<I4",data,i)
  49. end
  50. local function _get_byte8(data, i)
  51. return strunpack("<I8",data,i)
  52. end
  53. local function _set_byte2(n)
  54. return strpack("<I2", n)
  55. end
  56. local function _set_byte3(n)
  57. return strpack("<I3", n)
  58. end
  59. local function _set_byte4(n)
  60. return strpack("<I4", n)
  61. end
  62. local function _from_cstring(data, i)
  63. return strunpack("z", data, i)
  64. end
  65. local function _dumphex(bytes)
  66. return strgsub(bytes, ".", function(x) return strformat("%02x ", strbyte(x)) end)
  67. end
  68. local function _compute_token(password, scramble)
  69. if password == "" then
  70. return ""
  71. end
  72. --_dumphex(scramble)
  73. local stage1 = sha1(password)
  74. --print("stage1:", _dumphex(stage1) )
  75. local stage2 = sha1(stage1)
  76. local stage3 = sha1(scramble .. stage2)
  77. local i = 0
  78. return strgsub(stage3,".",
  79. function(x)
  80. i = i + 1
  81. -- ~ is xor in lua 5.3
  82. return strchar(strbyte(x) ~ strbyte(stage1, i))
  83. end)
  84. end
  85. local function _compose_packet(self, req, size)
  86. self.packet_no = self.packet_no + 1
  87. local packet = _set_byte3(size) .. strchar(self.packet_no) .. req
  88. return packet
  89. end
  90. local function _recv_packet(self,sock)
  91. local data = sock:read( 4)
  92. if not data then
  93. return nil, nil, "failed to receive packet header: "
  94. end
  95. local len, pos = _get_byte3(data, 1)
  96. if len == 0 then
  97. return nil, nil, "empty packet"
  98. end
  99. if len > self._max_packet_size then
  100. return nil, nil, "packet size too big: " .. len
  101. end
  102. local num = strbyte(data, pos)
  103. self.packet_no = num
  104. data = sock:read(len)
  105. if not data then
  106. return nil, nil, "failed to read packet content: "
  107. end
  108. local field_count = strbyte(data, 1)
  109. local typ
  110. if field_count == 0x00 then
  111. typ = "OK"
  112. elseif field_count == 0xff then
  113. typ = "ERR"
  114. elseif field_count == 0xfe then
  115. typ = "EOF"
  116. else
  117. typ = "DATA"
  118. end
  119. return data, typ
  120. end
  121. local function _from_length_coded_bin(data, pos)
  122. local first = strbyte(data, pos)
  123. if not first then
  124. return nil, pos
  125. end
  126. if first >= 0 and first <= 250 then
  127. return first, pos + 1
  128. end
  129. if first == 251 then
  130. return nil, pos + 1
  131. end
  132. if first == 252 then
  133. pos = pos + 1
  134. return _get_byte2(data, pos)
  135. end
  136. if first == 253 then
  137. pos = pos + 1
  138. return _get_byte3(data, pos)
  139. end
  140. if first == 254 then
  141. pos = pos + 1
  142. return _get_byte8(data, pos)
  143. end
  144. return false, pos + 1
  145. end
  146. local function _from_length_coded_str(data, pos)
  147. local len
  148. len, pos = _from_length_coded_bin(data, pos)
  149. if len == nil then
  150. return nil, pos
  151. end
  152. return sub(data, pos, pos + len - 1), pos + len
  153. end
  154. local function _parse_ok_packet(packet)
  155. local res = new_tab(0, 5)
  156. local pos
  157. res.affected_rows, pos = _from_length_coded_bin(packet, 2)
  158. res.insert_id, pos = _from_length_coded_bin(packet, pos)
  159. res.server_status, pos = _get_byte2(packet, pos)
  160. res.warning_count, pos = _get_byte2(packet, pos)
  161. local message = sub(packet, pos)
  162. if message and message ~= "" then
  163. res.message = message
  164. end
  165. return res
  166. end
  167. local function _parse_eof_packet(packet)
  168. local pos = 2
  169. local warning_count, pos = _get_byte2(packet, pos)
  170. local status_flags = _get_byte2(packet, pos)
  171. return warning_count, status_flags
  172. end
  173. local function _parse_err_packet(packet)
  174. local errno, pos = _get_byte2(packet, 2)
  175. local marker = sub(packet, pos, pos)
  176. local sqlstate
  177. if marker == '#' then
  178. -- with sqlstate
  179. pos = pos + 1
  180. sqlstate = sub(packet, pos, pos + 5 - 1)
  181. pos = pos + 5
  182. end
  183. local message = sub(packet, pos)
  184. return errno, message, sqlstate
  185. end
  186. local function _parse_result_set_header_packet(packet)
  187. local field_count, pos = _from_length_coded_bin(packet, 1)
  188. local extra
  189. extra = _from_length_coded_bin(packet, pos)
  190. return field_count, extra
  191. end
  192. local function _parse_field_packet(data)
  193. local col = new_tab(0, 2)
  194. local catalog, db, table, orig_table, orig_name, charsetnr, length
  195. local pos
  196. catalog, pos = _from_length_coded_str(data, 1)
  197. db, pos = _from_length_coded_str(data, pos)
  198. table, pos = _from_length_coded_str(data, pos)
  199. orig_table, pos = _from_length_coded_str(data, pos)
  200. col.name, pos = _from_length_coded_str(data, pos)
  201. orig_name, pos = _from_length_coded_str(data, pos)
  202. pos = pos + 1 -- ignore the filler
  203. charsetnr, pos = _get_byte2(data, pos)
  204. length, pos = _get_byte4(data, pos)
  205. col.type = strbyte(data, pos)
  206. --[[
  207. pos = pos + 1
  208. col.flags, pos = _get_byte2(data, pos)
  209. col.decimals = strbyte(data, pos)
  210. pos = pos + 1
  211. local default = sub(data, pos + 2)
  212. if default and default ~= "" then
  213. col.default = default
  214. end
  215. --]]
  216. return col
  217. end
  218. local function _parse_row_data_packet(data, cols, compact)
  219. local pos = 1
  220. local ncols = #cols
  221. local row
  222. if compact then
  223. row = new_tab(ncols, 0)
  224. else
  225. row = new_tab(0, ncols)
  226. end
  227. for i = 1, ncols do
  228. local value
  229. value, pos = _from_length_coded_str(data, pos)
  230. local col = cols[i]
  231. local typ = col.type
  232. local name = col.name
  233. if value ~= nil then
  234. local conv = converters[typ]
  235. if conv then
  236. value = conv(value)
  237. end
  238. end
  239. if compact then
  240. row[i] = value
  241. else
  242. row[name] = value
  243. end
  244. end
  245. return row
  246. end
  247. local function _recv_field_packet(self, sock)
  248. local packet, typ, err = _recv_packet(self, sock)
  249. if not packet then
  250. return nil, err
  251. end
  252. if typ == "ERR" then
  253. local errno, msg, sqlstate = _parse_err_packet(packet)
  254. return nil, msg, errno, sqlstate
  255. end
  256. if typ ~= 'DATA' then
  257. return nil, "bad field packet type: " .. typ
  258. end
  259. -- typ == 'DATA'
  260. return _parse_field_packet(packet)
  261. end
  262. local function _recv_decode_packet_resp(self)
  263. return function(sock)
  264. -- don't return more than 2 results
  265. return true, (_recv_packet(self,sock))
  266. end
  267. end
  268. local function _recv_auth_resp(self)
  269. return function(sock)
  270. local packet, typ, err = _recv_packet(self,sock)
  271. if not packet then
  272. --print("recv auth resp : failed to receive the result packet")
  273. error ("failed to receive the result packet"..err)
  274. --return nil,err
  275. end
  276. if typ == 'ERR' then
  277. local errno, msg, sqlstate = _parse_err_packet(packet)
  278. error( strformat("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate))
  279. --return nil, errno,msg, sqlstate
  280. end
  281. if typ == 'EOF' then
  282. error "old pre-4.1 authentication protocol not supported"
  283. end
  284. if typ ~= 'OK' then
  285. error "bad packet type: "
  286. end
  287. return true, true
  288. end
  289. end
  290. local function _mysql_login(self,user,password,database,on_connect)
  291. return function(sockchannel)
  292. local packet, typ, err = sockchannel:response( _recv_decode_packet_resp(self) )
  293. --local aat={}
  294. if not packet then
  295. error( err )
  296. end
  297. if typ == "ERR" then
  298. local errno, msg, sqlstate = _parse_err_packet(packet)
  299. error( strformat("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate))
  300. end
  301. self.protocol_ver = strbyte(packet)
  302. local server_ver, pos = _from_cstring(packet, 2)
  303. if not server_ver then
  304. error "bad handshake initialization packet: bad server version"
  305. end
  306. self._server_ver = server_ver
  307. local thread_id, pos = _get_byte4(packet, pos)
  308. local scramble1 = sub(packet, pos, pos + 8 - 1)
  309. if not scramble1 then
  310. error "1st part of scramble not found"
  311. end
  312. pos = pos + 9 -- skip filler
  313. -- two lower bytes
  314. self._server_capabilities, pos = _get_byte2(packet, pos)
  315. self._server_lang = strbyte(packet, pos)
  316. pos = pos + 1
  317. self._server_status, pos = _get_byte2(packet, pos)
  318. local more_capabilities
  319. more_capabilities, pos = _get_byte2(packet, pos)
  320. self._server_capabilities = self._server_capabilities|more_capabilities<<16
  321. local len = 21 - 8 - 1
  322. pos = pos + 1 + 10
  323. local scramble_part2 = sub(packet, pos, pos + len - 1)
  324. if not scramble_part2 then
  325. error "2nd part of scramble not found"
  326. end
  327. local scramble = scramble1..scramble_part2
  328. local token = _compute_token(password, scramble)
  329. local client_flags = 260047;
  330. local req = strpack("<I4I4c24zs1z",
  331. client_flags,
  332. self._max_packet_size,
  333. strrep("\0", 24), -- TODO: add support for charset encoding
  334. user,
  335. token,
  336. database)
  337. local packet_len = #req
  338. local authpacket=_compose_packet(self,req,packet_len)
  339. sockchannel:request(authpacket,_recv_auth_resp(self))
  340. if on_connect then
  341. on_connect(self)
  342. end
  343. end
  344. end
  345. local function _compose_query(self, query)
  346. self.packet_no = -1
  347. local cmd_packet = strchar(COM_QUERY) .. query
  348. local packet_len = 1 + #query
  349. local querypacket = _compose_packet(self, cmd_packet, packet_len)
  350. return querypacket
  351. end
  352. local function read_result(self, sock)
  353. local packet, typ, err = _recv_packet(self, sock)
  354. if not packet then
  355. return nil, err
  356. --error( err )
  357. end
  358. if typ == "ERR" then
  359. local errno, msg, sqlstate = _parse_err_packet(packet)
  360. return nil, msg, errno, sqlstate
  361. --error( strformat("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate))
  362. end
  363. if typ == 'OK' then
  364. local res = _parse_ok_packet(packet)
  365. if res and res.server_status&SERVER_MORE_RESULTS_EXISTS ~= 0 then
  366. return res, "again"
  367. end
  368. return res
  369. end
  370. if typ ~= 'DATA' then
  371. return nil, "packet type " .. typ .. " not supported"
  372. --error( "packet type " .. typ .. " not supported" )
  373. end
  374. -- typ == 'DATA'
  375. local field_count, extra = _parse_result_set_header_packet(packet)
  376. local cols = new_tab(field_count, 0)
  377. for i = 1, field_count do
  378. local col, err, errno, sqlstate = _recv_field_packet(self, sock)
  379. if not col then
  380. return nil, err, errno, sqlstate
  381. --error( strformat("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate))
  382. end
  383. cols[i] = col
  384. end
  385. local packet, typ, err = _recv_packet(self, sock)
  386. if not packet then
  387. --error( err)
  388. return nil, err
  389. end
  390. if typ ~= 'EOF' then
  391. --error ( "unexpected packet type " .. typ .. " while eof packet is ".. "expected" )
  392. return nil, "unexpected packet type " .. typ .. " while eof packet is ".. "expected"
  393. end
  394. -- typ == 'EOF'
  395. local compact = self.compact
  396. local rows = new_tab( 4, 0)
  397. local i = 0
  398. while true do
  399. packet, typ, err = _recv_packet(self, sock)
  400. if not packet then
  401. --error (err)
  402. return nil, err
  403. end
  404. if typ == 'EOF' then
  405. local warning_count, status_flags = _parse_eof_packet(packet)
  406. if status_flags&SERVER_MORE_RESULTS_EXISTS ~= 0 then
  407. return rows, "again"
  408. end
  409. break
  410. end
  411. -- if typ ~= 'DATA' then
  412. -- return nil, 'bad row packet type: ' .. typ
  413. -- end
  414. -- typ == 'DATA'
  415. local row = _parse_row_data_packet(packet, cols, compact)
  416. i = i + 1
  417. rows[i] = row
  418. end
  419. return rows
  420. end
  421. local function _query_resp(self)
  422. return function(sock)
  423. local res, err, errno, sqlstate = read_result(self,sock)
  424. if not res then
  425. local badresult ={}
  426. badresult.badresult = true
  427. badresult.err = err
  428. badresult.errno = errno
  429. badresult.sqlstate = sqlstate
  430. return true , badresult
  431. end
  432. if err ~= "again" then
  433. return true, res
  434. end
  435. local mulitresultset = {res}
  436. mulitresultset.mulitresultset = true
  437. local i =2
  438. while err =="again" do
  439. res, err, errno, sqlstate = read_result(self,sock)
  440. if not res then
  441. mulitresultset.badresult = true
  442. mulitresultset.err = err
  443. mulitresultset.errno = errno
  444. mulitresultset.sqlstate = sqlstate
  445. return true, mulitresultset
  446. end
  447. mulitresultset[i]=res
  448. i=i+1
  449. end
  450. return true, mulitresultset
  451. end
  452. end
  453. function _M.connect(opts)
  454. local self = setmetatable( {}, mt)
  455. local max_packet_size = opts.max_packet_size
  456. if not max_packet_size then
  457. max_packet_size = 1024 * 1024 -- default 1 MB
  458. end
  459. self._max_packet_size = max_packet_size
  460. self.compact = opts.compact_arrays
  461. local database = opts.database or ""
  462. local user = opts.user or ""
  463. local password = opts.password or ""
  464. local channel = socketchannel.channel {
  465. host = opts.host,
  466. port = opts.port or 3306,
  467. auth = _mysql_login(self,user,password,database,opts.on_connect),
  468. }
  469. self.sockchannel = channel
  470. -- try connect first only once
  471. channel:connect(true)
  472. return self
  473. end
  474. function _M.disconnect(self)
  475. self.sockchannel:close()
  476. setmetatable(self, nil)
  477. end
  478. function _M.query(self, query)
  479. local querypacket = _compose_query(self, query)
  480. local sockchannel = self.sockchannel
  481. if not self.query_resp then
  482. self.query_resp = _query_resp(self)
  483. end
  484. return sockchannel:request( querypacket, self.query_resp )
  485. end
  486. function _M.server_ver(self)
  487. return self._server_ver
  488. end
  489. function _M.quote_sql_str( str)
  490. return mysqlaux.quote_sql_str(str)
  491. end
  492. function _M.set_compact_arrays(self, value)
  493. self.compact = value
  494. end
  495. return _M