socket.lua 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. local driver = require "skynet.socketdriver"
  2. local skynet = require "skynet"
  3. local skynet_core = require "skynet.core"
  -- Local alias of the global assert.
  local assert = assert

  -- Public API table; this is what the module returns.
  local socket = {}

  -- Pool object passed to every driver buffer operation in this file.
  local buffer_pool = {}

  -- Finalizer for the socket registry: asks the driver to close every socket
  -- that is still registered and drops it from the registry.
  local function close_all_sockets(pool)
    for id in pairs(pool) do
      driver.close(id)
      -- The per-socket buffer is deliberately not cleared here, because the
      -- buffer pool will be freed at the end anyway.
      pool[id] = nil
    end
  end

  -- Registry of all socket objects, keyed by socket id.
  local socket_pool = setmetatable({}, { __gc = close_all_sockets })

  -- Handlers for incoming socket messages, indexed by message type number.
  local socket_message = {}
  -- Wake the coroutine recorded in `s.co`, if there is one.
  -- The slot is cleared before skynet.wakeup is called.
  local function wakeup(s)
    local waiting = s.co
    if not waiting then
      return
    end
    s.co = nil
    skynet.wakeup(waiting)
  end
  -- Park the running coroutine on socket object `s` until it is woken up.
  -- Asserts that no other coroutine is already recorded in `s.co`.
  local function suspend(s)
    assert(not s.co)
    local current = coroutine.running()
    s.co = current
    skynet.wait(current)
    -- Every time a suspended coroutine resumes, wake the closing coroutine
    -- (if any): socket.close() waits for the last buffer operation on this
    -- socket before it clears the buffer.
    local closer = s.closing
    if closer then
      skynet.wakeup(closer)
    end
  end
  -- See skynet_socket.h for the message type macros used as indices below.
  -- SKYNET_SOCKET_TYPE_DATA = 1
  -- Handles an incoming data message for socket `id`: pushes the data into the
  -- socket's buffer and wakes the reader once its pending request can be met.
  socket_message[1] = function(id, size, data)
    local s = socket_pool[id]
    if s == nil then
      skynet.error("socket: drop package from " .. id)
      driver.drop(data, size)
      return
    end

    -- NOTE(review): the value returned by driver.push is compared against the
    -- requested read size and the buffer limit, so it is presumably the number
    -- of bytes now buffered; confirm against the driver.
    local buffered = driver.push(s.buffer, buffer_pool, data, size)
    local required = s.read_required

    if type(required) == "number" then
      -- A sized read is pending: wake the reader once enough has arrived.
      if buffered >= required then
        s.read_required = nil
        wakeup(s)
      end
      return
    end

    -- The buffer limit is only checked when no sized read is pending.
    local limit = s.buffer_limit
    if limit and buffered > limit then
      skynet.error(string.format("socket buffer overflow: fd=%d size=%d", id , buffered))
      driver.clear(s.buffer, buffer_pool)
      driver.close(id)
      return
    end

    -- A line read is pending (read_required holds the separator): wake the
    -- reader when the driver reports that a line is available.
    if type(required) == "string" and driver.readline(s.buffer, nil, required) then
      s.read_required = nil
      wakeup(s)
    end
  end
  -- SKYNET_SOCKET_TYPE_CONNECT = 2
  -- Marks a known socket as connected and wakes the coroutine waiting on it.
  -- Messages for ids that are not in the registry are ignored.
  -- The remote address (`addr`) is currently unused; it could be logged here.
  socket_message[2] = function(id, _ , addr)
    local s = socket_pool[id]
    if s ~= nil then
      s.connected = true
      wakeup(s)
    end
  end
  -- SKYNET_SOCKET_TYPE_CLOSE = 3
  -- Marks a known socket as no longer connected and wakes the coroutine
  -- waiting on it. Messages for ids that are not in the registry are ignored.
  socket_message[3] = function(id)
    local s = socket_pool[id]
    if s ~= nil then
      s.connected = false
      wakeup(s)
    end
  end
  -- SKYNET_SOCKET_TYPE_ACCEPT = 4
  -- Passes a newly accepted socket id and its address to the callback of
  -- socket `id`. If `id` is not registered, the new socket is closed
  -- right away.
  socket_message[4] = function(id, newid, addr)
    local listener = socket_pool[id]
    if listener ~= nil then
      listener.callback(newid, addr)
    else
      driver.close(newid)
    end
  end
  -- SKYNET_SOCKET_TYPE_ERROR = 5
  -- Handles an error report for socket `id`. For a registered socket the
  -- error is logged (if it was connected) or stored in `connecting` (if a
  -- connect is in progress), then the socket is marked disconnected, shut
  -- down through the driver, and its waiting coroutine is woken.
  socket_message[5] = function(id, _, err)
    local s = socket_pool[id]
    if s ~= nil then
      if s.connected then
        skynet.error("socket: error on", id, err)
      elseif s.connecting then
        -- connect() reads this field back and returns it as the error.
        s.connecting = err
      end
      s.connected = false
      driver.shutdown(id)
      wakeup(s)
    else
      skynet.error("socket: error on unknown", id, err)
    end
  end
  -- SKYNET_SOCKET_TYPE_UDP = 6
  -- Converts an incoming UDP packet to a Lua string, releases the raw data,
  -- and hands the string plus the sender address to the socket's callback.
  -- Packets for unknown sockets, or sockets without a callback, are dropped.
  socket_message[6] = function(id, size, data, address)
    local s = socket_pool[id]
    local on_packet = s and s.callback
    if on_packet == nil then
      skynet.error("socket: drop udp package from " .. id)
      driver.drop(data, size)
      return
    end
    local payload = skynet.tostring(data, size)
    -- The raw message is released once it has been copied into `payload`.
    skynet_core.trash(data, size)
    on_packet(payload, address)
  end
  -- Default warning handler: logs the reported size (in K bytes, per the
  -- message text) for socket `id`. Does nothing if `id` is not registered.
  local function default_warning(id, size)
    if socket_pool[id] then
      skynet.error(string.format("WARNING: %d K bytes need to send out (fd = %d)", size, id))
    end
  end
  -- SKYNET_SOCKET_TYPE_WARNING
  -- Forwards a warning for a registered socket to its `on_warning` handler
  -- (set via socket.warning), falling back to default_warning.
  socket_message[7] = function(id, size)
    local s = socket_pool[id]
    if not s then
      return
    end
    local handler = s.on_warning
    if not handler then
      handler = default_warning
    end
    handler(id, size)
  end
  -- Register the "socket" protocol with skynet so that socket messages are
  -- unpacked by the driver and routed to the handler for their type.
  do
    local protocol = {
      name = "socket",
      id = skynet.PTYPE_SOCKET, -- PTYPE_SOCKET = 6
      unpack = driver.unpack,
    }

    -- The first two dispatch arguments are ignored; `msg_type` selects the
    -- entry in socket_message and the remaining values are its arguments.
    function protocol.dispatch(_, _, msg_type, ...)
      socket_message[msg_type](...)
    end

    skynet.register_protocol(protocol)
  end
  -- Register a TCP socket object for `id` and block the calling coroutine
  -- until it is woken up (by a connect, close or error message).
  -- A read buffer is only allocated when no callback `func` is given.
  -- Returns `id` if the socket ended up connected; otherwise removes the
  -- object from the registry and returns nil plus the value left in
  -- `connecting` (the error stored by the error handler, if any).
  -- Asserts that `id` is not already registered.
  local function connect(id, func)
    local read_buffer
    if func == nil then
      read_buffer = driver.buffer()
    end

    local s = {
      protocol = "TCP",
      id = id,
      callback = func,
      buffer = read_buffer,
      connected = false,
      connecting = true,
      read_required = false,
      co = false,
    }

    assert(not socket_pool[id], "socket is not closed")
    socket_pool[id] = s
    suspend(s)

    local err = s.connecting
    s.connecting = nil
    if not s.connected then
      socket_pool[id] = nil
      return nil, err
    end
    return id
  end
  -- Start a connection to addr/port through the driver and wait for the
  -- result. Returns what connect() returns: the socket id, or nil plus error.
  function socket.open(addr, port)
    -- The extra parentheses keep only the first value from driver.connect,
    -- so connect() never receives a callback argument.
    return connect((driver.connect(addr, port)))
  end
  -- Bind the OS file descriptor `os_fd` through the driver and wait for the
  -- resulting socket. Returns what connect() returns.
  function socket.bind(os_fd)
    -- The extra parentheses keep only the first value from driver.bind,
    -- so connect() never receives a callback argument.
    return connect((driver.bind(os_fd)))
  end
  -- Convenience wrapper: bind file descriptor 0 (standard input).
  function socket.stdin()
    local stdin_fd = 0
    return socket.bind(stdin_fd)
  end
  -- Ask the driver to start socket `id`, then register it in this service
  -- and wait for the result. `func` is stored as the socket's callback;
  -- when it is nil the socket gets a read buffer instead (see connect()).
  -- Returns what connect() returns.
  function socket.start(id, func)
    driver.start(id)
    return connect(id, func)
  end
  -- Clear the buffer of a registered socket and shut it down via the driver.
  -- Does nothing for ids that are not registered.
  function socket.shutdown(id)
    local s = socket_pool[id]
    if not s then
      return
    end
    driver.clear(s.buffer, buffer_pool)
    -- The framework will send SKYNET_SOCKET_TYPE_CLOSE afterwards, so
    -- close(id) still has to be called later.
    driver.shutdown(id)
  end
  -- Close a socket id that is NOT registered in this service's registry.
  -- Registered sockets must go through socket.close instead (asserted).
  function socket.close_fd(id)
    local managed = socket_pool[id] ~= nil
    assert(not managed, "Use socket.close instead")
    driver.close(id)
  end
  -- Wait until it is safe to tear down socket object `s` after driver.close.
  local function await_close(s)
    if not s.co then
      -- Nobody is waiting on this socket: park here like any other waiter.
      suspend(s)
      return
    end
    -- Another coroutine is reading this socket, so the buffer must not be
    -- cleared immediately; wait until that reader has resumed (suspend()
    -- wakes `s.closing` after every resume).
    assert(not s.closing)
    local me = coroutine.running()
    s.closing = me
    skynet.wait(me)
  end

  -- Close a registered socket and remove it from the registry.
  -- Does nothing for ids that are not registered. Asserts that no lock
  -- entries remain on the socket.
  -- Note: be careful when calling socket.close from a __gc metamethod;
  -- skynet.wait never returns there, so driver.clear may not be called.
  function socket.close(id)
    local s = socket_pool[id]
    if s == nil then
      return
    end

    if s.connected then
      driver.close(id)
      await_close(s)
      s.connected = false
    end

    driver.clear(s.buffer, buffer_pool)
    local lock_set = s.lock
    assert(lock_set == nil or next(lock_set) == nil)
    socket_pool[id] = nil
  end
  -- Read whatever is buffered. If the buffer is empty and the socket is
  -- still connected, wait once and try again.
  -- Returns the data, or false plus the (empty) remainder.
  local function read_some(s)
    local data = driver.readall(s.buffer, buffer_pool)
    if data ~= "" then
      return data
    end
    if not s.connected then
      return false, data
    end
    assert(not s.read_required)
    -- A numeric requirement of 0 is satisfied by the next data message.
    s.read_required = 0
    suspend(s)
    data = driver.readall(s.buffer, buffer_pool)
    if data == "" then
      return false, data
    end
    return data
  end

  -- Read exactly `sz` bytes, waiting once if they are not buffered yet.
  -- Returns the data, or false plus whatever was left in the buffer.
  local function read_exact(s, sz)
    local chunk = driver.pop(s.buffer, buffer_pool, sz)
    if chunk then
      return chunk
    end
    if not s.connected then
      return false, driver.readall(s.buffer, buffer_pool)
    end
    assert(not s.read_required)
    s.read_required = sz
    suspend(s)
    chunk = driver.pop(s.buffer, buffer_pool, sz)
    if chunk then
      return chunk
    end
    return false, driver.readall(s.buffer, buffer_pool)
  end

  -- Read from socket `id`.
  --   sz == nil : return any available bytes (waiting if none are buffered).
  --   otherwise : return exactly `sz` bytes.
  -- On failure returns false followed by the data remaining in the buffer.
  -- Asserts that `id` is registered and that no other read is pending.
  function socket.read(id, sz)
    local s = socket_pool[id]
    assert(s)
    if sz == nil then
      return read_some(s)
    end
    return read_exact(s, sz)
  end
  -- Read everything from socket `id` until it is disconnected.
  -- If the socket is already disconnected, returns the buffered data, or
  -- false when that data is the empty string. Otherwise waits until the
  -- socket is no longer connected (asserted) and returns the buffered data.
  -- Asserts that `id` is registered and that no other read is pending.
  function socket.readall(id)
    local s = socket_pool[id]
    assert(s)

    if s.connected then
      assert(not s.read_required)
      -- `true` is neither a number nor a string, so the data handler never
      -- wakes this reader; only a close/error style wakeup does.
      s.read_required = true
      suspend(s)
      assert(s.connected == false)
      return driver.readall(s.buffer, buffer_pool)
    end

    local rest = driver.readall(s.buffer, buffer_pool)
    if rest == "" then
      return false
    end
    return rest
  end
  -- Read one line from socket `id`, using `sep` as the separator
  -- (defaults to "\n"). Waits once if no complete line is buffered yet.
  -- Returns the line, or false plus the data remaining in the buffer when
  -- the socket is disconnected.
  -- Asserts that `id` is registered and that no other read is pending.
  function socket.readline(id, sep)
    if not sep then
      sep = "\n"
    end
    local s = socket_pool[id]
    assert(s)

    local line = driver.readline(s.buffer, buffer_pool, sep)
    if line then
      return line
    end

    if s.connected then
      assert(not s.read_required)
      -- A string requirement tells the data handler which separator to
      -- look for before waking this reader.
      s.read_required = sep
      suspend(s)
      if s.connected then
        return driver.readline(s.buffer, buffer_pool, sep)
      end
    end

    return false, driver.readall(s.buffer, buffer_pool)
  end
  -- Wait for activity on socket `id` without consuming any data.
  -- Returns false immediately if the socket is unknown or not connected;
  -- otherwise waits once and returns the socket's `connected` flag.
  -- Asserts that no other read is pending.
  function socket.block(id)
    local s = socket_pool[id]
    if s and s.connected then
      assert(not s.read_required)
      s.read_required = 0
      suspend(s)
      return s.connected
    end
    return false
  end
  -- Sending helpers are exported directly from the driver; loading the
  -- module fails (assert) if the driver does not provide one of them.
  do
    local send, lsend, header = driver.send, driver.lsend, driver.header
    socket.write = assert(send)
    socket.lwrite = assert(lsend)
    socket.header = assert(header)
  end
  -- True when `id` is not registered in this service's socket registry.
  function socket.invalid(id)
    local s = socket_pool[id]
    return s == nil
  end
  -- For a registered socket, returns true when it is neither connected nor
  -- connecting, false otherwise. Returns nothing for unregistered ids.
  function socket.disconnected(id)
    local s = socket_pool[id]
    if not s then
      return
    end
    return not s.connected and not s.connecting
  end
  -- Start listening through the driver and return what driver.listen returns.
  --   host    : host string, or a combined "host:port" string when `port`
  --             is omitted.
  --   port    : port; when nil it is parsed out of `host`.
  --   backlog : passed through to driver.listen unchanged.
  -- Raises an error (reported at the caller's position) when `port` is
  -- omitted and `host` is not a "host:port" string with a numeric port.
  -- Previously such input passed nil values straight to the driver, which
  -- failed with an unrelated-looking argument error.
  function socket.listen(host, port, backlog)
    if port == nil then
      local address = host
      host = nil
      if type(address) == "string" then
        host, port = string.match(address, "([^:]+):(.+)$")
        port = port and tonumber(port)
      end
      if host == nil or port == nil then
        error("socket.listen: expected a \"host:port\" address, got " .. tostring(address), 2)
      end
    end
    return driver.listen(host, port, backlog)
  end
  -- Acquire the per-socket lock for `id`. The first caller takes it
  -- immediately; later callers are queued and wait until socket.unlock
  -- wakes them. Asserts that `id` is registered.
  function socket.lock(id)
    local s = assert(socket_pool[id])

    local queue = s.lock
    if not queue then
      queue = {}
      s.lock = queue
    end

    local held = #queue
    if held == 0 then
      -- Slot 1 marks the current holder; `true` stands for "held, nobody
      -- to wake".
      queue[1] = true
      return
    end

    local me = coroutine.running()
    queue[held + 1] = me
    skynet.wait(me)
  end
  -- Release the per-socket lock for `id` and wake the next queued
  -- coroutine, if any. Asserts that `id` is registered and that a lock
  -- table exists for it.
  function socket.unlock(id)
    local s = assert(socket_pool[id])
    local queue = assert(s.lock)
    -- Drop the current holder; the next waiter (if any) moves to slot 1.
    table.remove(queue, 1)
    local next_holder = queue[1]
    if next_holder then
      skynet.wakeup(next_holder)
    end
  end
  -- Abandon is used to hand a socket id over to another service:
  -- the other service must call socket.start(id) afterwards.
  -- Clears the buffer, marks the socket disconnected, wakes any waiting
  -- coroutine and removes the socket from this service's registry.
  -- Does nothing for ids that are not registered.
  function socket.abandon(id)
    local s = socket_pool[id]
    if not s then
      return
    end
    driver.clear(s.buffer, buffer_pool)
    s.connected = false
    wakeup(s)
    socket_pool[id] = nil
  end
  -- Set the buffer limit for socket `id`; the data handler closes the socket
  -- when the buffered size exceeds it. Asserts that `id` is registered.
  function socket.limit(id, limit)
    local s = socket_pool[id]
    assert(s)
    s.buffer_limit = limit
  end
  369. ---------------------- UDP
  -- Register a UDP socket object for `id` with packet callback `cb`.
  -- UDP objects start out marked as connected and have no read buffer.
  -- Asserts that `id` is not already registered.
  local function create_udp_object(id, cb)
    assert(not socket_pool[id], "socket is not closed")
    local udp = {
      protocol = "UDP",
      id = id,
      callback = cb,
      connected = true,
    }
    socket_pool[id] = udp
  end
  -- Create a UDP socket through the driver (host/port are passed through)
  -- and register it with `callback` as its packet handler.
  -- Returns the new socket id.
  function socket.udp(callback, host, port)
    local udp_id = driver.udp(host, port)
    create_udp_object(udp_id, callback)
    return udp_id
  end
  -- Connect UDP socket `id` to addr/port through the driver.
  -- If `id` is not registered yet it is registered as a UDP object with
  -- `callback`; if it is registered it must be a UDP object (asserted) and
  -- its callback is replaced only when a new `callback` is given.
  function socket.udp_connect(id, addr, port, callback)
    local existing = socket_pool[id]
    if not existing then
      create_udp_object(id, callback)
    else
      assert(existing.protocol == "UDP")
      if callback then
        existing.callback = callback
      end
    end
    driver.udp_connect(id, addr, port)
  end
  -- UDP helpers are exported directly from the driver; loading the module
  -- fails (assert) if the driver does not provide one of them.
  do
    local udp_send, udp_address = driver.udp_send, driver.udp_address
    socket.sendto = assert(udp_send)
    socket.udp_address = assert(udp_address)
  end
  -- Install `callback` as the warning handler for socket `id`; it is called
  -- with (id, size) by the warning message handler. Passing nil restores the
  -- default handler. Asserts that `id` is registered.
  function socket.warning(id, callback)
    local s = assert(socket_pool[id])
    s.on_warning = callback
  end
  403. return socket