-- Copyright (C) 2012 Yichun Zhang (agentzh) -- Copyright (C) 2014 Chang Feng -- This file is modified version from https://github.com/openresty/lua-resty-mysql -- The license is under the BSD license. -- Modified by Cloud Wu (remove bit32 for lua 5.3) local socketchannel = require "skynet.socketchannel" local mysqlaux = require "skynet.mysqlaux.c" local crypt = require "skynet.crypt" local sub = string.sub local strgsub = string.gsub local strformat = string.format local strbyte = string.byte local strchar = string.char local strrep = string.rep local strunpack = string.unpack local strpack = string.pack local sha1= crypt.sha1 local setmetatable = setmetatable local error = error local tonumber = tonumber local new_tab = function (narr, nrec) return {} end local _M = { _VERSION = '0.13' } -- constants local STATE_CONNECTED = 1 local STATE_COMMAND_SENT = 2 local COM_QUERY = 0x03 local SERVER_MORE_RESULTS_EXISTS = 8 -- 16MB - 1, the default max allowed packet size used by libmysqlclient local FULL_PACKET_SIZE = 16777215 local mt = { __index = _M } -- mysql field value type converters local converters = new_tab(0, 8) for i = 0x01, 0x05 do -- tiny, short, long, float, double converters[i] = tonumber end converters[0x08] = tonumber -- long long converters[0x09] = tonumber -- int24 converters[0x0d] = tonumber -- year converters[0xf6] = tonumber -- newdecimal local function _get_byte2(data, i) return strunpack(" self._max_packet_size then return nil, nil, "packet size too big: " .. 
len
    end

    local num = strbyte(data, pos)

    self.packet_no = num

    data = sock:read(len)
    if not data then
        return nil, nil, "failed to read packet content: "
    end

    local field_count = strbyte(data, 1)

    local typ
    if field_count == 0x00 then
        typ = "OK"
    elseif field_count == 0xff then
        typ = "ERR"
    elseif field_count == 0xfe then
        typ = "EOF"
    else
        typ = "DATA"
    end

    return data, typ
end


-- Decode a length-coded integer starting at byte `pos` of `data`.
--
-- Returns two values: the decoded value and the position to continue
-- reading from.
--   * no byte at `pos`    -> nil,   pos      (nothing consumed)
--   * prefix 0..250       -> the prefix byte itself, pos + 1
--   * prefix 251          -> nil,   pos + 1
--   * prefix 252/253/254  -> whatever _get_byte2/_get_byte3/_get_byte8
--                            return when called just after the prefix
--   * any other prefix    -> false, pos + 1
local function _from_length_coded_bin(data, pos)
    local first = strbyte(data, pos)

    if not first then
        return nil, pos
    end

    if first >= 0 and first <= 250 then
        return first, pos + 1
    end

    if first == 251 then
        return nil, pos + 1
    end

    if first == 252 then
        pos = pos + 1
        return _get_byte2(data, pos)
    end

    if first == 253 then
        pos = pos + 1
        return _get_byte3(data, pos)
    end

    if first == 254 then
        pos = pos + 1
        return _get_byte8(data, pos)
    end

    return false, pos + 1
end


-- Decode a length-coded string starting at byte `pos` of `data`.
--
-- Returns the string and the position just past it. When the length
-- prefix yields no usable length, returns nil and the position reported
-- by _from_length_coded_bin.
local function _from_length_coded_str(data, pos)
    local len
    len, pos = _from_length_coded_bin(data, pos)

    -- _from_length_coded_bin yields nil for a missing byte / prefix 251 and
    -- false for an unrecognised prefix. Both must stop here: letting false
    -- through would raise "attempt to perform arithmetic on a boolean value"
    -- in the slice computation below.
    if not len then
        return nil, pos
    end

    return sub(data, pos, pos + len - 1), pos + len
end


-- Parse an OK packet into a table with the fields affected_rows,
-- insert_id, server_status, warning_count and, when the remainder of the
-- packet is non-empty, message.
local function _parse_ok_packet(packet)
    local res = new_tab(0, 5)
    local pos

    -- byte 1 is the packet marker, so the payload starts at byte 2
    res.affected_rows, pos = _from_length_coded_bin(packet, 2)

    res.insert_id, pos = _from_length_coded_bin(packet, pos)

    res.server_status, pos = _get_byte2(packet, pos)

    res.warning_count, pos = _get_byte2(packet, pos)

    local message = sub(packet, pos)
    if message and message ~= "" then
        res.message = message
    end

    return res
end


-- Parse an EOF packet.
-- Returns warning_count, status_flags (two consecutive 2-byte reads
-- starting at byte 2).
local function _parse_eof_packet(packet)
    local warning_count, pos = _get_byte2(packet, 2)
    local status_flags = _get_byte2(packet, pos)
    return warning_count, status_flags
end


-- Parse an ERR packet.
-- Returns errno, message, sqlstate. sqlstate is nil when the packet has
-- no '#' marker after the error number.
local function _parse_err_packet(packet)
    local errno, pos = _get_byte2(packet, 2)
    local marker = sub(packet, pos, pos)
    local sqlstate
    if marker == '#' then
        -- with sqlstate
        pos = pos + 1
        sqlstate = sub(packet, pos, pos + 5 - 1)
        pos = pos + 5
    end

    local message = sub(packet, pos)
    return errno, message, sqlstate
end


local function _parse_result_set_header_packet(packet)
    local field_count,
pos = _from_length_coded_bin(packet, 1)

    local extra
    extra = _from_length_coded_bin(packet, pos)

    return field_count, extra
end


-- Parse a field (column definition) packet.
--
-- Returns a table with two fields:
--   name - the fifth length-coded string in the packet
--   type - the single byte that follows the 2-byte and 4-byte values
--          after the filler
-- Everything else in the packet is read only to advance the cursor.
local function _parse_field_packet(data)
    local col = new_tab(0, 2)
    -- `_` receives values that are decoded but not kept. The original used
    -- named locals for them, one of which (`table`) shadowed the standard
    -- library.
    local pos, _

    _, pos = _from_length_coded_str(data, 1)     -- catalog
    _, pos = _from_length_coded_str(data, pos)   -- db
    _, pos = _from_length_coded_str(data, pos)   -- table
    _, pos = _from_length_coded_str(data, pos)   -- orig_table
    col.name, pos = _from_length_coded_str(data, pos)
    _, pos = _from_length_coded_str(data, pos)   -- orig_name

    pos = pos + 1 -- ignore the filler

    _, pos = _get_byte2(data, pos)               -- charsetnr
    _, pos = _get_byte4(data, pos)               -- length

    col.type = strbyte(data, pos)

    --[[
    pos = pos + 1

    col.flags, pos = _get_byte2(data, pos)

    col.decimals = strbyte(data, pos)
    pos = pos + 1

    local default = sub(data, pos + 2)
    if default and default ~= "" then
        col.default = default
    end
    --]]

    return col
end


-- Parse one row packet into a Lua table.
--
-- data    - raw packet payload
-- cols    - sequence of column descriptors (tables with `name` and `type`)
-- compact - when truthy the row is an array indexed by column position,
--           otherwise it is keyed by column name
--
-- Each non-nil value is passed through the converter registered for the
-- column's type in `converters`, if there is one.
local function _parse_row_data_packet(data, cols, compact)
    local pos = 1
    local ncols = #cols
    local row
    if compact then
        row = new_tab(ncols, 0)
    else
        row = new_tab(0, ncols)
    end

    for i = 1, ncols do
        local value
        value, pos = _from_length_coded_str(data, pos)

        local col = cols[i]
        local typ = col.type
        local name = col.name

        -- nil values are stored as-is, without conversion
        if value ~= nil then
            local conv = converters[typ]
            if conv then
                value = conv(value)
            end
        end

        if compact then
            row[i] = value
        else
            row[name] = value
        end
    end

    return row
end


local function _recv_field_packet(self, sock)
    local packet, typ, err = _recv_packet(self, sock)
    if not packet then
        return nil, err
    end

    if typ == "ERR" then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, msg, errno, sqlstate
    end

    if typ ~= 'DATA' then
        return nil, "bad field packet type: " ..
typ
    end

    -- typ == 'DATA'

    return _parse_field_packet(packet)
end


-- Build a response reader that receives one packet via _recv_packet.
--
-- The returned function yields `true` followed by the FIRST result of
-- _recv_packet only: the extra parentheses around the call truncate its
-- results to a single value.
-- NOTE(review): a caller that expects the packet type or error string from
-- this reader will presumably not receive them - confirm against how
-- socketchannel forwards response results.
local function _recv_decode_packet_resp(self)
    return function(sock)
        -- don't return more than 2 results
        return true, (_recv_packet(self,sock))
    end
end


-- Build a response reader that validates the server's reply to the
-- authentication request.
--
-- The returned function returns `true, true` when an OK packet arrives and
-- raises an error otherwise: no packet received, ERR packet, EOF packet,
-- or any other packet type.
local function _recv_auth_resp(self)
    return function(sock)
        local packet, typ, err = _recv_packet(self,sock)
        if not packet then
            --print("recv auth resp : failed to receive the result packet")
            -- separator added and err stringified so the underlying reason
            -- stays readable and a missing err cannot mask the failure
            error("failed to receive the result packet: " .. tostring(err))
            --return nil,err
        end

        if typ == 'ERR' then
            local errno, msg, sqlstate = _parse_err_packet(packet)
            error( strformat("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate))
            --return nil, errno,msg, sqlstate
        end

        if typ == 'EOF' then
            error "old pre-4.1 authentication protocol not supported"
        end

        if typ ~= 'OK' then
            -- the message previously ended after the colon without naming
            -- the offending type
            error("bad packet type: " .. tostring(typ))
        end
        return true, true
    end
end


local function _mysql_login(self,user,password,database,on_connect)
    return function(sockchannel)
        local packet, typ, err = sockchannel:response( _recv_decode_packet_resp(self) )
        --local aat={}
        if not packet then
            error( err )
        end

        if typ == "ERR" then
            local errno, msg, sqlstate = _parse_err_packet(packet)
            error( strformat("errno:%d, msg:%s,sqlstate:%s",errno,msg,sqlstate))
        end

        self.protocol_ver = strbyte(packet)

        local server_ver, pos = _from_cstring(packet, 2)
        if not server_ver then
            error "bad handshake initialization packet: bad server version"
        end

        self._server_ver = server_ver

        local thread_id, pos = _get_byte4(packet, pos)

        local scramble1 = sub(packet, pos, pos + 8 - 1)
        if not scramble1 then
            error "1st part of scramble not found"
        end

        pos = pos + 9 -- skip filler

        -- two lower bytes
        self._server_capabilities, pos = _get_byte2(packet, pos)

        self._server_lang = strbyte(packet, pos)
        pos = pos + 1

        self._server_status, pos = _get_byte2(packet, pos)

        local more_capabilities
        more_capabilities, pos = _get_byte2(packet, pos)

        self._server_capabilities = self._server_capabilities|more_capabilities<<16

        local len = 21 - 8 - 1

        pos = pos + 1 + 10

        local scramble_part2 = sub(packet,
pos, pos + len - 1) if not scramble_part2 then error "2nd part of scramble not found" end local scramble = scramble1..scramble_part2 local token = _compute_token(password, scramble) local client_flags = 260047; local req = strpack("