socketchannel.lua 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. local skynet = require "skynet"
  2. local socket = require "skynet.socket"
  3. local socketdriver = require "skynet.socketdriver"
  -- Socket channel: reconnects automatically and turns socket failures that
  -- happen inside a request/response transaction into errors for the caller.
  -- Channel description:
  --   { host = "", port = , auth = function(so), response = function(so) -> session, data }
  local socket_channel = {}
  local channel = {}          -- methods of channel objects
  local channel_socket = {}   -- methods of channel socket objects
  local channel_meta = { __index = channel }

  -- Finalizer for channel socket objects ({ fd }): clears the fd slot and
  -- shuts the socket down if it was still set.
  local function shutdown_channel_socket(cs)
    local fd = cs[1]
    cs[1] = false
    if fd then
      socket.shutdown(fd)
    end
  end

  local channel_socket_meta = {
    __index = channel_socket,
    __gc = shutdown_channel_socket,
  }

  -- Unique error object raised on socket failures; exported as
  -- socket_channel.error so callers can compare against it.
  local socket_error = setmetatable({}, {
    __tostring = function()
      return "[Error: socket]"
    end,
  })
  socket_channel.error = socket_error
  -- Create a channel object from a description table.
  -- desc.host and desc.port are required (asserted); desc.backup, desc.auth,
  -- desc.response and desc.nodelay are optional. With desc.response set the
  -- channel dispatches replies by session, otherwise strictly in order.
  function socket_channel.channel(desc)
    local host = assert(desc.host)
    local port = assert(desc.port)
    local c = {
      __host = host,
      __port = port,
      __backup = desc.backup,
      __auth = desc.auth,
      __response = desc.response, -- session mode reader
      __nodelay = desc.nodelay,
      __request = {},       -- order mode: FIFO of response functions
      __thread = {},        -- coroutine FIFO (order) or session -> coroutine map (session)
      __result = {},        -- coroutine -> result flag
      __result_data = {},   -- coroutine -> result payload
      __connecting = {},
      __sock = false,
      __closed = false,
      __authcoroutine = false,
    }
    return setmetatable(c, channel_meta)
  end
  -- Detach the current socket object from the channel and close its fd.
  -- Errors from socket.close are swallowed so this never raises.
  local function close_channel_socket(self)
    local so = self.__sock
    if not so then
      return
    end
    self.__sock = false
    pcall(socket.close, so[1])
  end
  -- Fail every pending request: each waiting coroutine gets socket_error as
  -- its result (with `errmsg` as the payload) and is woken. In order mode the
  -- request queue is emptied too; `false` entries (close signal) are skipped.
  local function wakeup_all(self, errmsg)
    local threads = self.__thread
    local function fail(co)
      self.__result[co] = socket_error
      self.__result_data[co] = errmsg
      skynet.wakeup(co)
    end

    if self.__response then
      for key, co in pairs(threads) do
        threads[key] = nil
        fail(co)
      end
      return
    end

    local requests = self.__request
    for i = 1, #requests do
      requests[i] = nil
    end
    for i = 1, #threads do
      local co = threads[i]
      threads[i] = nil
      if co then
        fail(co)
      end
    end
  end
  -- Called when a dispatch loop ends. If the running coroutine is still the
  -- channel's registered dispatch thread, unregister it and wake the
  -- coroutine parked in channel:connect (__connecting_thread), if any.
  local function exit_thread(self)
    if self.__dispatch_thread ~= coroutine.running() then
      return
    end
    self.__dispatch_thread = nil
    local waiter = self.__connecting_thread
    if waiter then
      skynet.wakeup(waiter)
    end
  end
  -- Record one reply for coroutine `co`. A successful reply flagged with
  -- `padding` is only buffered (appended to the list in __result_data[co]);
  -- any other reply completes the request: the session slot is cleared, the
  -- result is stored (appended to the buffered list when there is one) and
  -- `co` is woken.
  local function settle_session_reply(self, session, co, result_ok, result_data, padding)
    local datas = self.__result_data
    if result_ok and padding then
      local parts = datas[co] or {}
      datas[co] = parts
      table.insert(parts, result_data)
      return
    end
    self.__thread[session] = nil
    self.__result[co] = result_ok
    local parts = datas[co]
    if result_ok and parts then
      table.insert(parts, result_data)
    else
      datas[co] = result_data
    end
    skynet.wakeup(co)
  end

  -- Dispatch loop for session mode: desc.response(sock) is called repeatedly
  -- and returns session, ok, data[, padding]; each reply is routed to the
  -- coroutine registered for that session. A reader error (or a nil session)
  -- closes the socket and fails all pending requests.
  local function dispatch_by_session(self)
    local response = self.__response
    while self.__sock do
      local ok, session, result_ok, result_data, padding = pcall(response, self.__sock)
      if not (ok and session) then
        close_channel_socket(self)
        local errormsg = nil
        if session ~= socket_error then
          errormsg = session
        end
        wakeup_all(self, errormsg)
      else
        local co = self.__thread[session]
        if co then
          settle_session_reply(self, session, co, result_ok, result_data, padding)
        else
          self.__thread[session] = nil
          skynet.error("socket: unknown session :", session)
        end
      end
    end
    exit_thread(self)
  end
  -- Take the oldest (response function, coroutine) pair from the order-mode
  -- queues. While the queue is empty the caller parks itself in
  -- __wait_response until push_response wakes it.
  local function pop_response(self)
    local func, co
    repeat
      func = table.remove(self.__request, 1)
      co = table.remove(self.__thread, 1)
      if not func then
        local waiter = coroutine.running()
        self.__wait_response = waiter
        skynet.wait(waiter)
      end
    until func
    return func, co
  end
  -- Register `co` as waiting for `response`.
  -- Session mode: `response` is the session key in the __thread map.
  -- Order mode: the (response function, co) pair is queued and a dispatch
  -- thread parked in pop_response is woken.
  local function push_response(self, response, co)
    if self.__response then
      self.__thread[response] = co
      return
    end
    local requests, threads = self.__request, self.__thread
    requests[#requests + 1] = response
    threads[#threads + 1] = co
    local waiter = self.__wait_response
    if waiter then
      skynet.wakeup(waiter)
      self.__wait_response = nil
    end
  end
  -- Dispatch loop for order mode (no desc.response): replies are matched to
  -- requests strictly in FIFO order. Each queued response function is called
  -- with the channel socket object and its result is delivered to the
  -- coroutine queued with it. The loop ends when the socket is gone, when the
  -- close signal (a `false` coroutine, pushed by channel:close) is popped, or
  -- after a reader error, which closes the socket and fails all waiters.
  local function dispatch_by_order(self)
    while self.__sock do
      local func, co = pop_response(self)
      if not co then
        -- close signal
        wakeup_all(self, "channel_closed")
        break
      end
      -- pop_response may have blocked, and the socket can be closed meanwhile
      -- (changehost, disconnect detection, a failed write). Fail the request
      -- with socket_error rather than calling `func` with `false` as socket.
      local sock = self.__sock
      if not sock then
        self.__result[co] = socket_error
        self.__result_data[co] = nil
        skynet.wakeup(co)
        wakeup_all(self)
        break
      end
      -- A successful reply flagged with `padding` is only one part of the
      -- answer: buffer it and call `func` again for the SAME coroutine until
      -- the final part arrives (popping the next request here would hand the
      -- remaining parts to the wrong caller and leave `co` asleep forever).
      local ok, result_ok, result_data, padding
      while true do
        ok, result_ok, result_data, padding = pcall(func, sock)
        if not (ok and padding and result_ok) then
          break
        end
        local buffered = self.__result_data[co] or {}
        self.__result_data[co] = buffered
        table.insert(buffered, result_data)
      end
      if ok then
        self.__result[co] = result_ok
        if result_ok and self.__result_data[co] then
          table.insert(self.__result_data[co], result_data)
        else
          self.__result_data[co] = result_data
        end
        skynet.wakeup(co)
      else
        close_channel_socket(self)
        local errmsg
        if result_ok ~= socket_error then
          errmsg = result_ok
        end
        self.__result[co] = socket_error
        self.__result_data[co] = errmsg
        skynet.wakeup(co)
        wakeup_all(self, errmsg)
      end
    end
    exit_thread(self)
  end
  -- Pick the dispatch loop for this channel: session mode when a response
  -- reader was configured, order mode otherwise.
  local function dispatch_function(self)
    return self.__response and dispatch_by_session or dispatch_by_order
  end
  -- Try each backup address in turn. An entry is either a host string (using
  -- the channel's current port) or a { host = , port = } table. On the first
  -- successful open the channel adopts that host/port and the fd is returned;
  -- returns nothing when there are no backups or all of them fail.
  local function connect_backup(self)
    local backups = self.__backup
    if not backups then
      return
    end
    for _, addr in ipairs(backups) do
      local host, port = addr, self.__port
      if type(addr) == "table" then
        host, port = addr.host, addr.port
      end
      skynet.error("socket: connect to backup host", host, port)
      local fd = socket.open(host, port)
      if fd then
        self.__host, self.__port = host, port
        return fd
      end
    end
  end
  -- Make one connection attempt: primary address first, then the backups.
  -- On success the socket object is installed, the dispatch thread is forked
  -- and desc.auth (if any) runs with __authcoroutine set to the caller.
  -- Returns true on success; false (plus the primary open error, if any) when
  -- closed, when no address could be opened, or when auth fails.
  local function connect_once(self)
    if self.__closed then
      return false
    end
    assert(not self.__sock and not self.__authcoroutine)
    local fd, err = socket.open(self.__host, self.__port)
    fd = fd or connect_backup(self)
    if not fd then
      return false, err
    end
    if self.__nodelay then
      socketdriver.nodelay(fd)
    end
    self.__sock = setmetatable({ fd }, channel_socket_meta)
    self.__dispatch_thread = skynet.fork(dispatch_function(self), self)

    local auth = self.__auth
    if not auth then
      return true
    end
    self.__authcoroutine = coroutine.running()
    local ok, message = pcall(auth, self)
    if not ok then
      close_channel_socket(self)
      if message ~= socket_error then
        self.__authcoroutine = false
        skynet.error("socket: auth failed", message)
      end
    end
    self.__authcoroutine = false
    if not ok then
      return false
    end
    if self.__sock then
      return true
    end
    -- auth may change host, so connect again
    return connect_once(self)
  end
  -- Call connect_once until it succeeds or the channel is closed.
  -- With `once`, give up after the first failure and return its error.
  -- Between attempts the delay passed to skynet.sleep grows 0, 100, ...,
  -- 1100; the 1100 wait is announced as a reconnect and the delay then
  -- restarts at 100.
  local function try_connect(self, once)
    local delay = 0
    while not self.__closed do
      local ok, err = connect_once(self)
      if ok then
        if not once then
          skynet.error("socket: connect to", self.__host, self.__port)
        end
        return
      end
      if once then
        return err
      end
      skynet.error("socket: connect", err)
      local long_wait = delay > 1000
      if long_wait then
        skynet.error("socket: try to reconnect", self.__host, self.__port)
      end
      skynet.sleep(delay)
      if long_wait then
        delay = 100
      else
        delay = delay + 100
      end
    end
  end
  -- Inspect the channel state without connecting.
  -- Returns true when a socket exists and no auth is running in another
  -- coroutine (or the caller is the auth coroutine itself), false when the
  -- channel is closed, and nothing when a (re)connect is required. A socket
  -- closed by the peer is dropped here.
  local function check_connection(self)
    local sock = self.__sock
    if sock then
      if socket.disconnected(sock[1]) then
        -- closed by peer
        skynet.error("socket: disconnect detected ", self.__host, self.__port)
        close_channel_socket(self)
        return
      end
      local authco = self.__authcoroutine
      if not authco or authco == coroutine.running() then
        return true
      end
    end
    if self.__closed then
      return false
    end
  end
  -- Ensure the channel has a usable socket, connecting if necessary.
  -- Returns true when connected (or when called from the auth coroutine
  -- itself) and false when the channel is closed; raises socket_error when no
  -- connection could be made.
  -- Only one coroutine runs try_connect at a time: __connecting[1] is a busy
  -- marker, other callers queue at [2], [3], ... and are woken when that
  -- attempt finishes. `once` is passed through to try_connect (a single
  -- attempt whose error is returned instead of retrying).
  local function block_connect(self, once)
    local r = check_connection(self)
    if r ~= nil then
      return r
    end
    local err
    local connecting = self.__connecting
    if #connecting > 0 then
      -- connecting in other coroutine
      local co = coroutine.running()
      table.insert(connecting, co)
      skynet.wait(co)
    else
      connecting[1] = true
      local ok, result = pcall(try_connect, self, once)
      -- Count the waiters while [1] is still set: with [1] == nil the table
      -- has a hole and `#` may legally return 0, which would leave every
      -- waiter asleep forever.
      local n = #connecting
      connecting[1] = nil
      for i = 2, n do
        local co = connecting[i]
        connecting[i] = nil
        skynet.wakeup(co)
      end
      if not ok then
        -- try_connect raised: the busy marker is cleared and the waiters are
        -- released above, so later callers are not blocked; re-raise as is.
        error(result, 0)
      end
      err = result
    end
    r = check_connection(self)
    if r == nil then
      skynet.error(string.format("Connect to %s:%d failed (%s)", self.__host, self.__port, tostring(err)))
      error(socket_error)
    else
      return r
    end
  end
  -- Connect (or reconnect) the channel. A closed channel is reopened; if its
  -- previous dispatch thread is still running, wait until exit_thread wakes
  -- us through __connecting_thread. `once` is forwarded to block_connect.
  function channel:connect(once)
    if not self.__closed then
      return block_connect(self, once)
    end
    if self.__dispatch_thread then
      -- closing, wait
      assert(self.__connecting_thread == nil, "already connecting")
      local me = coroutine.running()
      self.__connecting_thread = me
      skynet.wait(me)
      self.__connecting_thread = nil
    end
    self.__closed = false
    return block_connect(self, once)
  end
  -- Register the running coroutine for `response`, sleep until a dispatch
  -- loop wakes it, then consume the stored result.
  -- Returns the result payload; raises the stored error message (or the bare
  -- socket_error) on socket failure, and asserts a truthy result otherwise.
  local function wait_for_response(self, response)
    local co = coroutine.running()
    push_response(self, response, co)
    skynet.wait(co)
    local results, datas = self.__result, self.__result_data
    local result, result_data = results[co], datas[co]
    results[co], datas[co] = nil, nil
    if result ~= socket_error then
      assert(result, result_data)
      return result_data
    end
    error(result_data or socket_error)
  end
  -- Socket writers used by channel:request (normal and low-priority).
  local socket_write, socket_lwrite = socket.write, socket.lwrite
  -- Abort after a failed write: drop the socket, fail every pending waiter
  -- with the bare socket_error (no message), then raise socket_error.
  local function sock_err(self)
    close_channel_socket(self)
    wakeup_all(self, nil)
    error(socket_error)
  end
  -- Send `request` (plus the extra parts in the `padding` array, if given)
  -- and, when `response` is not nil, wait for the matching reply.
  -- `response` is a session key (session mode) or a reader function (order
  -- mode). A failed write raises socket_error; with no `response` nothing is
  -- returned.
  function channel:request(request, response, padding)
    assert(block_connect(self, true)) -- connect once
    local fd = self.__sock[1]
    local write = socket_write
    if padding then
      -- multi-part request: every part goes through the low-priority writer
      write = socket_lwrite
    end
    if not write(fd, request) then
      sock_err(self)
    end
    if padding then
      for _, part in ipairs(padding) do
        if not socket_lwrite(fd, part) then
          sock_err(self)
        end
      end
    end
    if response ~= nil then
      return wait_for_response(self, response)
    end
  end
  -- Wait for a reply without sending anything first (connecting, with
  -- retries, if needed).
  function channel:response(response)
    local connected = block_connect(self)
    assert(connected)
    return wait_for_response(self, response)
  end
  -- Close the channel (no-op when already closed). In order mode a close
  -- signal (true, false) is queued so a dispatch thread waiting for requests
  -- can fail its waiters and exit; the dispatch thread is re-checked after
  -- closing the socket in case it changed in the meantime.
  function channel:close()
    if self.__closed then
      return
    end
    local dispatcher = self.__dispatch_thread
    self.__closed = true
    close_channel_socket(self)
    if dispatcher and not self.__response and self.__dispatch_thread == dispatcher then
      push_response(self, true, false)
    end
  end
  -- Point the channel at a new address (the port is only replaced when
  -- given). An open channel drops its current socket, so the next operation
  -- connects to the new address.
  function channel:changehost(host, port)
    self.__host = host
    self.__port = port or self.__port
    if self.__closed then
      return
    end
    close_channel_socket(self)
  end
  -- Replace the list of backup addresses used by connect_backup.
  function channel.changebackup(self, backup)
    self.__backup = backup
  end
  -- Close a channel when the channel object is garbage-collected.
  rawset(channel_meta, "__gc", channel.close)
  -- Adapt a socket primitive f(fd, ...) into a channel socket method: the fd
  -- is taken from self[1], the first result is returned when truthy, and a
  -- falsy result raises socket_error.
  local function wrapper_socket_function(f)
    return function(self, ...)
      local result = f(self[1], ...)
      if result then
        return result
      end
      error(socket_error)
    end
  end
  -- Expose read/readline on channel socket objects; they raise socket_error
  -- instead of returning a falsy value.
  for _, name in ipairs({ "read", "readline" }) do
    channel_socket[name] = wrapper_socket_function(socket[name])
  end
  return socket_channel