  -- msgserver.lua
  1. local skynet = require "skynet"
  2. local gateserver = require "snax.gateserver"
  3. local netpack = require "skynet.netpack"
  4. local crypt = require "skynet.crypt"
  5. local socketdriver = require "skynet.socketdriver"
  6. local assert = assert
  7. local b64encode = crypt.base64encode
  8. local b64decode = crypt.base64decode
  --[[

  Protocol:

      All numbers are big-endian.

  Handshake (the first package)

      Client -> Server :
          base64(uid)@base64(server)#base64(subid):index:base64(hmac)

          index must be greater than the last index accepted for this user,
          and hmac is crypt.hmac_hash(secret, "base64(uid)@base64(server)#base64(subid):index").

      Server -> Client
          XXX ErrorCode (a status code followed by its reason, one of:)
          404 User Not Found
          403 Index Expired
          401 Unauthorized
          400 Bad Request
          200 OK

  Req-Resp

      Client -> Server : Request
          word size (not including itself)
          string content (size-4)
          dword session

      Server -> Client : Response
          word size (not including itself)
          string content (size-5)
          byte ok (1 is ok, 0 is error)
          dword session

  API:

      server.userid(username)
          return uid, subid, server
      server.username(uid, subid, server)
          return username
      server.login(username, secret)
          register an online user with its secret (the user must not be online yet)
      server.logout(username)
          log the user out and close its connection, if any
      server.ip(username)
          return the ip recorded when the connection was established, or nil
      server.start(conf)
          start server

  Supported skynet command (the arguments are passed unchanged to the matching conf handler):

      kick uid subid (may be used by the login server)
      login uid secret (used by the login server)
      logout uid subid (used by the agent)

  Config for server.start:

      conf.expired_number : the number of response messages cached after they are sent (default is 128)
      conf.login_handler(uid, secret) -> subid : called when a new user logs in; allocates a subid for it. (may be called by the login server)
      conf.logout_handler(uid, subid) : called when a user logs out. (may be called by the agent)
      conf.kick_handler(uid, subid) : called when a user is kicked. (may be called by the login server)
      conf.request_handler(username, msg) : called when a new request arrives; its return value is the response content (if it raises, an empty response with ok = 0 is sent).
      conf.register_handler(servername) : called when the gate opens
      conf.disconnect_handler(username) : called when a connection disconnects (afk)

  ]]
  -- Module table; its functions form the public msgserver API.
  local server = {}

  -- Online users keyed by username (entries are created by server.login).
  local user_online = {}
  -- Sockets still waiting for their handshake package: fd -> client address.
  local handshake = {}
  -- Authenticated sockets: fd -> user record taken from user_online.
  local connection = {}

  -- Register the "client" protocol type (skynet.PTYPE_CLIENT) in this service.
  skynet.register_protocol({
      name = "client",
      id = skynet.PTYPE_CLIENT,
  })
  --- Split a username back into its components.
  -- A username has the form base64(uid)@base64(server)#base64(subid);
  -- every component is base64-decoded before being returned.
  -- @tparam string username
  -- @return uid, subid, servername
  function server.userid(username)
      local enc_uid, enc_server, enc_subid = username:match "([^@]*)@([^#]*)#(.*)"
      local uid = b64decode(enc_uid)
      local subid = b64decode(enc_subid)
      return uid, subid, b64decode(enc_server)
  end
  --- Build the username string used by the handshake.
  -- Result layout: base64(uid)@base64(servername)#base64(tostring(subid)).
  -- @param uid
  -- @param subid converted with tostring before encoding
  -- @param servername
  -- @treturn string
  function server.username(uid, subid, servername)
      local encoded_uid = b64encode(uid)
      local encoded_server = b64encode(servername)
      local encoded_subid = b64encode(tostring(subid))
      return encoded_uid .. "@" .. encoded_server .. "#" .. encoded_subid
  end
  --- Log `username` out: forget its session and close its connection, if any.
  -- Calling this for a user that is not online (for example one that was
  -- already logged out or kicked) is a no-op rather than a nil-index error.
  -- @tparam string username
  function server.logout(username)
      local u = user_online[username]
      if u == nil then
          return
      end
      user_online[username] = nil
      local fd = u.fd
      if fd then
          -- Detach the fd from the user before closing it; handler.disconnect
          -- looks users up through connection[fd].
          connection[fd] = nil
          gateserver.closeclient(fd)
      end
  end
  --- Register `username` as online with a fresh session record.
  -- Raises (assert) when the user is already online.
  -- @tparam string username
  -- @param secret key used to verify the handshake hmac
  function server.login(username, secret)
      assert(user_online[username] == nil)
      local record = {
          username = username,
          secret = secret,
          version = 0, -- last accepted handshake index
          index = 0,
          response = {}, -- response cache
      }
      user_online[username] = record
  end
  --- Address recorded at handshake time for an online, connected user.
  -- @tparam string username
  -- @return the address, or nothing when the user is offline or has no connection
  function server.ip(username)
      local u = user_online[username]
      if not u or not u.fd then
          return
      end
      return u.ip
  end
  --- Start the message server on top of snax.gateserver.
  -- Builds the gateserver handler table (command/open/connect/disconnect/
  -- error/message), wires the conf callbacks into it and passes it to
  -- gateserver.start.
  -- @tparam table conf login_handler, logout_handler, kick_handler and
  --   request_handler are required (asserted); register_handler is called when
  --   the gate opens; disconnect_handler and expired_number are optional.
  -- @return whatever gateserver.start returns
  function server.start(conf)
      -- Size of the per-user replay window for completed responses
      -- (see retire_response below).
      local expired_number = conf.expired_number or 128

      local handler = {}

      -- skynet commands routed to handler.command; their arguments are passed
      -- to the matching conf handler unchanged.
      local CMD = {
          login = assert(conf.login_handler),
          logout = assert(conf.logout_handler),
          kick = assert(conf.kick_handler),
      }

      function handler.command(cmd, source, ...)
          local f = assert(CMD[cmd])
          return f(...)
      end

      function handler.open(source, gateconf)
          local servername = assert(gateconf.servername)
          return conf.register_handler(servername)
      end

      -- A new socket must send the handshake as its first package; remember
      -- its address until then.
      function handler.connect(fd, addr)
          handshake[fd] = addr
          gateserver.openclient(fd)
      end

      function handler.disconnect(fd)
          handshake[fd] = nil
          local c = connection[fd]
          if c then
              connection[fd] = nil
              -- Only detach the user when this fd is still its current
              -- connection. If a newer handshake already moved the user to
              -- another fd, clearing c.fd here would orphan the live socket and
              -- report a disconnect that did not happen.
              if c.fd == fd then
                  c.fd = nil
                  if conf.disconnect_handler then
                      conf.disconnect_handler(c.username)
                  end
              end
          end
      end
      handler.error = handler.disconnect

      -- Validate the handshake package.
      -- Returns nil on success, otherwise the status line to send back.
      -- atomic , no yield
      local function do_auth(fd, message, addr)
          local username, index, hmac = string.match(message, "([^:]*):([^:]*):([^:]*)")
          local u = user_online[username]
          if u == nil then
              return "404 User Not Found"
          end
          local idx = assert(tonumber(index))
          hmac = b64decode(hmac)
          -- Every handshake must use a larger index than the last accepted one.
          if idx <= u.version then
              return "403 Index Expired"
          end
          local text = string.format("%s:%s", username, index)
          local v = crypt.hmac_hash(u.secret, text) -- equivalent to crypt.hmac64(crypt.hashkey(text), u.secret)
          if v ~= hmac then
              return "401 Unauthorized"
          end
          -- A successful handshake supersedes any previous connection of this
          -- user: detach and close it so it can no longer issue requests as
          -- this user, and so its eventual disconnect cannot clear u.fd.
          local old_fd = u.fd
          if old_fd and old_fd ~= fd then
              connection[old_fd] = nil
              gateserver.closeclient(old_fd)
          end
          u.version = idx
          u.fd = fd
          u.ip = addr
          connection[fd] = u
      end

      -- Answer the handshake package with a status line; on any failure the
      -- socket is closed after the reply is sent.
      local function auth(fd, addr, msg, sz)
          local message = netpack.tostring(msg, sz)
          local ok, result = pcall(do_auth, fd, message, addr)
          if not ok then
              skynet.error(result)
              result = "400 Bad Request"
          end
          local close = result ~= nil
          if result == nil then
              result = "200 OK"
          end
          socketdriver.send(fd, netpack.pack(result))
          if close then
              gateserver.closeclient(fd)
          end
      end

      local request_handler = assert(conf.request_handler)

      -- u.response[session] is a struct { return_fd, response, version, index }:
      --   [1] fd to send the response to (nil once it has been sent),
      --   [2] packed response (nil while the request is still running),
      --   [3] u.version and [4] u.index recorded for the entry.
      -- Once u.index reaches 2 * expired_number, completed entries whose index
      -- is below expired_number are dropped, the remaining ones are shifted
      -- down by expired_number, and u.index restarts after the largest one.
      local function retire_response(u)
          if u.index >= expired_number * 2 then
              local max = 0
              local response = u.response
              for k, p in pairs(response) do
                  if p[1] == nil then
                      -- request complete, check expired
                      if p[4] < expired_number then
                          response[k] = nil
                      else
                          p[4] = p[4] - expired_number
                          if p[4] > max then
                              max = p[4]
                          end
                      end
                  end
              end
              u.index = max + 1
          end
      end

      -- Serve one request package (content followed by a 4-byte session).
      -- May yield inside request_handler, so the socket can close or the user
      -- can reconnect before it returns.
      local function do_request(fd, message)
          local u = assert(connection[fd], "invalid fd")
          local session = string.unpack(">I4", message, -4)
          message = message:sub(1, -5)
          local p = u.response[session]
          if p then
              -- session can be reuse in the same connection
              if p[3] == u.version then
                  local last = u.response[session]
                  u.response[session] = nil
                  p = nil
                  if last[2] == nil then
                      -- Report the session as the 8 hex digits of its 32-bit
                      -- wire value (hexencode of the integer would encode its
                      -- decimal text instead).
                      local error_msg = string.format("Conflict session %08x", session)
                      skynet.error(error_msg)
                      error(error_msg)
                  end
              end
          end
          if p == nil then
              p = { fd }
              u.response[session] = p
              local ok, result = pcall(request_handler, u.username, message)
              -- NOTICE: YIELD here, socket may close.
              result = result or ""
              if not ok then
                  skynet.error(result)
                  result = string.pack(">BI4", 0, session)
              else
                  result = result .. string.pack(">BI4", 1, session)
              end
              p[2] = string.pack(">s2", result)
              p[3] = u.version
              p[4] = u.index
          else
              -- update version/index, change return fd.
              -- resend response.
              p[1] = fd
              p[3] = u.version
              p[4] = u.index
              if p[2] == nil then
                  -- already request, but response is not ready
                  return
              end
          end
          u.index = u.index + 1
          -- the return fd is p[1] (fd may change by multi request) check connect
          fd = p[1]
          if connection[fd] then
              socketdriver.send(fd, p[2])
          end
          p[1] = nil
          retire_response(u)
      end

      -- Run do_request in protected mode; on failure log it and close the socket.
      local function request(fd, msg, sz)
          local message = netpack.tostring(msg, sz)
          local ok, err = pcall(do_request, fd, message)
          -- not atomic, may yield
          if not ok then
              skynet.error(string.format("Invalid package %s : %s", err, message))
              if connection[fd] then
                  gateserver.closeclient(fd)
              end
          end
      end

      -- The first package on a socket is the handshake; later ones are requests.
      function handler.message(fd, msg, sz)
          local addr = handshake[fd]
          if addr then
              auth(fd, addr, msg, sz)
              handshake[fd] = nil
          else
              request(fd, msg, sz)
          end
      end

      return gateserver.start(handler)
  end
  -- Export the module table (userid, username, login, logout, ip, start).
  return server