123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449 |
- local driver = require "skynet.socketdriver"
- local skynet = require "skynet"
- local skynet_core = require "skynet.core"
-- Keep a local reference to the global assert.
local assert = assert

-- Public API table returned at the end of this module.
local socket = {}

-- Stores all message buffer objects; passed to the driver alongside a
-- socket buffer on every buffer operation.
local buffer_pool = {}

-- Finalizer for socket_pool: closes every socket id still registered.
-- The per-socket buffers do not need to be cleared here, because the
-- buffer pool will be freed at the end.
local function close_all_sockets(pool)
	for id in pairs(pool) do
		driver.close(id)
		pool[id] = nil
	end
end

-- Stores all socket objects, keyed by socket id.
local socket_pool = setmetatable({}, { __gc = close_all_sockets })

-- Handlers for socket messages, indexed by message type number.
local socket_message = {}
-- Wake the coroutine parked on socket object `s`, if there is one.
-- The waiter slot is cleared before skynet.wakeup is called.
local function wakeup(s)
	local waiter = s.co
	if not waiter then
		return
	end
	s.co = nil
	skynet.wakeup(waiter)
end
-- Park the running coroutine on socket object `s` via skynet.wait.
-- Asserts that no other coroutine is already parked on `s`.
local function suspend(s)
	assert(not s.co)
	local current = coroutine.running()
	s.co = current
	skynet.wait(current)
	-- Wake the closing coroutine every time a suspend finishes, because
	-- socket.close() waits for the last socket buffer operation before it
	-- clears the buffer.
	local closer = s.closing
	if closer then
		skynet.wakeup(closer)
	end
end
-- read skynet_socket.h for these macros
-- SKYNET_SOCKET_TYPE_DATA = 1
-- Pushes the incoming chunk into the socket's buffer and wakes the pending
-- reader once its request (a byte count or a line separator) is satisfied.
-- Data for an unknown socket id is logged and dropped.
socket_message[1] = function(id, size, data)
	local s = socket_pool[id]
	if s == nil then
		skynet.error("socket: drop package from " .. id)
		driver.drop(data, size)
		return
	end

	local buffered = driver.push(s.buffer, buffer_pool, data, size)
	local wanted = s.read_required
	local wanted_type = type(wanted)

	if wanted_type == "number" then
		-- read size: the buffer limit is not checked on this path
		if buffered >= wanted then
			s.read_required = nil
			wakeup(s)
		end
		return
	end

	local limit = s.buffer_limit
	if limit and buffered > limit then
		-- too much unread data: discard it and close the socket
		skynet.error(string.format("socket buffer overflow: fd=%d size=%d", id , buffered))
		driver.clear(s.buffer, buffer_pool)
		driver.close(id)
		return
	end

	-- read line: wake the reader when driver.readline succeeds for the
	-- requested separator (called with a nil pool here)
	if wanted_type == "string" and driver.readline(s.buffer, nil, wanted) then
		s.read_required = nil
		wakeup(s)
	end
end
-- SKYNET_SOCKET_TYPE_CONNECT = 2
-- Marks the socket as connected and wakes the coroutine waiting on it.
-- Messages for unknown ids are ignored. The third argument (addr) is
-- received but not used by this handler.
socket_message[2] = function(id, _ , addr)
	local s = socket_pool[id]
	if s ~= nil then
		s.connected = true
		wakeup(s)
	end
end
-- SKYNET_SOCKET_TYPE_CLOSE = 3
-- Marks the socket as disconnected and wakes the coroutine waiting on it.
-- Messages for unknown ids are ignored.
socket_message[3] = function(id)
	local s = socket_pool[id]
	if s ~= nil then
		s.connected = false
		wakeup(s)
	end
end
-- SKYNET_SOCKET_TYPE_ACCEPT = 4
-- Hands the new socket id and address to the callback stored on the
-- socket object for `id`. If `id` is not registered, the new socket is
-- closed right away.
socket_message[4] = function(id, newid, addr)
	local listener = socket_pool[id]
	if listener ~= nil then
		listener.callback(newid, addr)
	else
		driver.close(newid)
	end
end
-- SKYNET_SOCKET_TYPE_ERROR = 5
-- Known socket: log the error if it was connected, or store it in
-- s.connecting if a connect is still pending (connect() returns it from
-- there); then mark it disconnected, shut it down and wake the waiter.
-- Unknown socket: only log.
socket_message[5] = function(id, _, err)
	local s = socket_pool[id]
	if s ~= nil then
		if s.connected then
			skynet.error("socket: error on", id, err)
		elseif s.connecting then
			s.connecting = err
		end
		s.connected = false
		driver.shutdown(id)
		wakeup(s)
	else
		skynet.error("socket: error on unknown", id, err)
	end
end
-- SKYNET_SOCKET_TYPE_UDP = 6
-- Converts the raw packet to a Lua string, releases the raw data and
-- passes (string, address) to the socket's callback. Packets for an
-- unknown id, or for a socket without a callback, are logged and dropped.
socket_message[6] = function(id, size, data, address)
	local s = socket_pool[id]
	if s ~= nil and s.callback ~= nil then
		local payload = skynet.tostring(data, size)
		skynet_core.trash(data, size)
		s.callback(payload, address)
	else
		skynet.error("socket: drop udp package from " .. id)
		driver.drop(data, size)
	end
end
-- Fallback warning handler: logs the size reported by the warning message
-- (formatted as K bytes) for a socket that is still registered; does
-- nothing for unknown ids.
local function default_warning(id, size)
	if socket_pool[id] then
		skynet.error(string.format("WARNING: %d K bytes need to send out (fd = %d)", size, id))
	end
end
-- SKYNET_SOCKET_TYPE_WARNING
-- Calls the socket's on_warning handler (set through socket.warning), or
-- default_warning when none is installed. Unknown ids are ignored.
socket_message[7] = function(id, size)
	local s = socket_pool[id]
	if not s then
		return
	end
	local warning = s.on_warning or default_warning
	warning(id, size)
end
-- Dispatch function for the "socket" protocol: the third argument is the
-- message type number; the remaining arguments are forwarded unchanged to
-- the handler registered in socket_message for that type.
local function dispatch_socket_message(_, _, t, ...)
	socket_message[t](...)
end

skynet.register_protocol {
	name = "socket",
	id = skynet.PTYPE_SOCKET,	-- PTYPE_SOCKET = 6
	unpack = driver.unpack,
	dispatch = dispatch_socket_message,
}
-- Register a TCP socket object for `id` and suspend the calling coroutine
-- until a socket message wakes it.
--   id   : socket id
--   func : optional callback stored on the object; a read buffer is only
--          created when func is nil
-- Returns id if the socket ended up connected. Otherwise the object is
-- removed from socket_pool and nil plus the value left in s.connecting
-- (the error handler stores its err there) is returned.
-- Asserts that `id` is not already registered.
local function connect(id, func)
	local read_buffer
	if func == nil then
		read_buffer = driver.buffer()
	end

	local s = {
		id = id,
		protocol = "TCP",
		buffer = read_buffer,
		callback = func,
		connected = false,
		connecting = true,
		read_required = false,
		co = false,
	}
	assert(not socket_pool[id], "socket is not closed")
	socket_pool[id] = s
	suspend(s)

	local reason = s.connecting
	s.connecting = nil
	if not s.connected then
		socket_pool[id] = nil
		return nil, reason
	end
	return id
end
-- Ask the driver to connect to addr/port, then wait for the result.
-- Returns whatever connect() returns: the socket id, or nil plus an error
-- value.
function socket.open(addr, port)
	-- the extra parentheses keep only the first value returned by the driver
	return connect((driver.connect(addr, port)))
end
-- Hand an OS file descriptor to the driver, then wait for the result.
-- Returns whatever connect() returns: the socket id, or nil plus an error
-- value.
function socket.bind(os_fd)
	-- the extra parentheses keep only the first value returned by the driver
	return connect((driver.bind(os_fd)))
end
-- Bind file descriptor 0 through socket.bind and return its result.
function socket.stdin()
	local stdin_fd = 0
	return socket.bind(stdin_fd)
end
-- Start socket `id` in the driver, then register it and wait.
--   id   : socket id
--   func : optional callback stored on the socket object; the accept
--          handler calls it with (newid, addr)
-- Returns whatever connect() returns: the socket id, or nil plus an error
-- value.
function socket.start(id, func)
	driver.start(id)
	return connect(id, func)
end
-- Discard buffered data for `id` and shut the socket down in the driver.
-- The socket object stays registered: the framework will send
-- SKYNET_SOCKET_TYPE_CLOSE, and close(id) still has to be called later.
-- Does nothing for unknown ids.
function socket.shutdown(id)
	local s = socket_pool[id]
	if not s then
		return
	end
	driver.clear(s.buffer, buffer_pool)
	driver.shutdown(id)
end
-- Close a socket id that is NOT registered in socket_pool.
-- Asserts when the id is registered; socket.close must be used for those.
function socket.close_fd(id)
	local managed = socket_pool[id]
	assert(managed == nil,"Use socket.close instead")
	driver.close(id)
end
-- Close socket `id` and remove it from socket_pool.
-- For a connected socket this asks the driver to close it and then waits
-- before the buffer is cleared. Unknown ids are ignored.
-- Asserts that no socket lock is held or queued when the socket is removed.
function socket.close(id)
	local s = socket_pool[id]
	if s == nil then
		return
	end

	if s.connected then
		driver.close(id)
		-- notice: calling socket.close in __gc should be done carefully,
		-- because skynet.wait never returns in __gc, so driver.clear may
		-- not be called
		if not s.co then
			suspend(s)
		else
			-- another coroutine is reading this socket, so don't clear the
			-- buffer immediately; wait until that coroutine has read it
			-- (suspend() wakes s.closing when the reader resumes)
			assert(not s.closing)
			local current = coroutine.running()
			s.closing = current
			skynet.wait(current)
		end
		s.connected = false
	end

	driver.clear(s.buffer, buffer_pool)
	local locks = s.lock
	assert(locks == nil or next(locks) == nil)
	socket_pool[id] = nil
end
-- Read from socket `id`.
--   sz == nil : return whatever is buffered, waiting for data if the
--               buffer is empty
--   sz given  : return the result of popping sz bytes from the buffer,
--               waiting until the data handler wakes the reader
-- Returns the data on success, or false plus the remaining buffered data
-- when the request cannot be met (e.g. the socket is not connected).
-- Asserts that the socket exists and that no other read is pending.
function socket.read(id, sz)
	local s = socket_pool[id]
	assert(s)

	if sz ~= nil then
		-- sized read
		local chunk = driver.pop(s.buffer, buffer_pool, sz)
		if chunk then
			return chunk
		end
		if not s.connected then
			return false, driver.readall(s.buffer, buffer_pool)
		end

		assert(not s.read_required)
		s.read_required = sz
		suspend(s)

		chunk = driver.pop(s.buffer, buffer_pool, sz)
		if not chunk then
			return false, driver.readall(s.buffer, buffer_pool)
		end
		return chunk
	end

	-- read some bytes
	local data = driver.readall(s.buffer, buffer_pool)
	if data ~= "" then
		return data
	end
	if not s.connected then
		return false, data
	end

	assert(not s.read_required)
	-- a requirement of 0 bytes is met by the next data message
	s.read_required = 0
	suspend(s)

	data = driver.readall(s.buffer, buffer_pool)
	if data == "" then
		return false, data
	end
	return data
end
-- Read everything from socket `id`.
-- While connected, the caller is suspended (the data handler does not wake
-- a reader whose read_required is `true`) and the whole buffer is returned
-- once it resumes with the socket disconnected. If the socket is already
-- disconnected, returns the buffered data, or false when there is none.
-- Asserts that the socket exists and that no other read is pending.
function socket.readall(id)
	local s = socket_pool[id]
	assert(s)

	if s.connected then
		assert(not s.read_required)
		s.read_required = true
		suspend(s)
		assert(s.connected == false)
		return driver.readall(s.buffer, buffer_pool)
	end

	local rest = driver.readall(s.buffer, buffer_pool)
	if rest == "" then
		return false
	end
	return rest
end
-- Read one line from socket `id`.
--   sep : line separator, "\n" when omitted
-- Returns the result of driver.readline on success, or false plus the
-- remaining buffered data when the socket is (or becomes) disconnected.
-- Asserts that the socket exists and that no other read is pending.
function socket.readline(id, sep)
	if not sep then
		sep = "\n"
	end
	local s = socket_pool[id]
	assert(s)

	local line = driver.readline(s.buffer, buffer_pool, sep)
	if line then
		return line
	end
	if not s.connected then
		return false, driver.readall(s.buffer, buffer_pool)
	end

	assert(not s.read_required)
	s.read_required = sep
	suspend(s)

	if not s.connected then
		return false, driver.readall(s.buffer, buffer_pool)
	end
	return driver.readline(s.buffer, buffer_pool, sep)
end
-- Suspend the caller until the socket's waiter is woken, without consuming
-- any data. Returns false right away for an unknown or disconnected
-- socket; otherwise returns s.connected as it is after waking up.
-- Asserts that no other read is pending.
function socket.block(id)
	local s = socket_pool[id]
	if s and s.connected then
		assert(not s.read_required)
		-- a requirement of 0 bytes is met by the next data message
		s.read_required = 0
		suspend(s)
		return s.connected
	end
	return false
end
-- Re-export the driver's sending functions directly; each one must exist
-- in the driver, otherwise loading this module fails on the assert.
do
	local send, lsend, header = driver.send, driver.lsend, driver.header
	socket.write = assert(send)
	socket.lwrite = assert(lsend)
	socket.header = assert(header)
end
-- Returns true when `id` is not registered in socket_pool.
function socket.invalid(id)
	local s = socket_pool[id]
	return s == nil
end
-- Returns true when the registered socket is neither connected nor in the
-- middle of connecting, false otherwise. Returns nothing for unknown ids.
function socket.disconnected(id)
	local s = socket_pool[id]
	if not s then
		return
	end
	return not s.connected and not s.connecting
end
-- Start listening through the driver.
--   host    : host name/address, or a combined "host:port" string when
--             `port` is omitted
--   port    : port number (optional when host is "host:port")
--   backlog : passed through to driver.listen unchanged
-- Returns whatever driver.listen returns.
-- Raises an error (reported at the caller's position) when `port` is
-- omitted and `host` cannot be split into a host and a numeric port;
-- previously nil values were silently passed on to the driver in that case.
function socket.listen(host, port, backlog)
	if port == nil then
		local address = host
		local parsed_host, parsed_port = string.match(address, "([^:]+):(.+)$")
		local port_number = parsed_port and tonumber(parsed_port)
		if not (parsed_host and port_number) then
			error(string.format(
				"socket.listen: invalid address '%s', expected 'host:port'",
				tostring(address)), 2)
		end
		host, port = parsed_host, port_number
	end
	return driver.listen(host, port, backlog)
end
-- Acquire the per-socket lock for `id`.
-- The lock is a queue stored in s.lock: slot 1 belongs to the current
-- holder (`true` for an uncontended acquire), later slots hold waiting
-- coroutines. When the lock is taken, the caller is appended to the queue
-- and suspended.
-- Asserts that the socket exists.
function socket.lock(id)
	local s = socket_pool[id]
	assert(s)

	local queue = s.lock
	if not queue then
		queue = {}
		s.lock = queue
	end

	if #queue ~= 0 then
		local current = coroutine.running()
		queue[#queue + 1] = current
		skynet.wait(current)
	else
		queue[1] = true
	end
end
-- Release the per-socket lock for `id`: drop the holder's slot from the
-- front of the queue and wake the next waiting coroutine, if any.
-- Asserts that the socket exists and that its lock queue was created.
function socket.unlock(id)
	local s = socket_pool[id]
	assert(s)
	local queue = assert(s.lock)
	table.remove(queue, 1)
	local next_holder = queue[1]
	if next_holder then
		skynet.wakeup(next_holder)
	end
end
-- abandon is used to forward a socket id to another service;
-- socket.start(id) must be called later in that other service.
-- Clears the buffer, marks the socket disconnected, wakes any waiting
-- coroutine and forgets the id, without closing it in the driver.
-- Does nothing for unknown ids.
function socket.abandon(id)
	local s = socket_pool[id]
	if not s then
		return
	end
	driver.clear(s.buffer, buffer_pool)
	s.connected = false
	wakeup(s)
	socket_pool[id] = nil
end
-- Set the buffer limit of socket `id`. The data handler compares the
-- buffered size against it and, when exceeded, clears the buffer and
-- closes the socket. Asserts that the socket exists.
function socket.limit(id, limit)
	local s = socket_pool[id]
	assert(s)
	s.buffer_limit = limit
end
---------------------- UDP
-- Register a UDP socket object for `id`; it is created already marked as
-- connected and has no read buffer.
--   cb : callback stored on the object; the UDP handler calls it with
--        (string, address)
-- Asserts that `id` is not already registered.
local function create_udp_object(id, cb)
	assert(not socket_pool[id], "socket is not closed")
	local obj = {
		id = id,
		protocol = "UDP",
		connected = true,
		callback = cb,
	}
	socket_pool[id] = obj
end
-- Create a UDP socket through the driver and register it.
--   callback   : stored on the socket object for incoming packets
--   host, port : passed through to driver.udp
-- Returns the new socket id.
function socket.udp(callback, host, port)
	local udp_id = driver.udp(host, port)
	create_udp_object(udp_id, callback)
	return udp_id
end
-- Call driver.udp_connect for UDP socket `id` with addr/port.
-- If the id is not registered yet, a UDP object is created for it first;
-- if it is, it must already be a UDP object (asserted), and its callback
-- is replaced only when a new one is given.
function socket.udp_connect(id, addr, port, callback)
	local obj = socket_pool[id]
	if not obj then
		create_udp_object(id, callback)
	else
		assert(obj.protocol == "UDP")
		if callback then
			obj.callback = callback
		end
	end
	driver.udp_connect(id, addr, port)
end
-- Re-export the driver's UDP functions directly; each one must exist in
-- the driver, otherwise loading this module fails on the assert.
do
	local udp_send, udp_address = driver.udp_send, driver.udp_address
	socket.sendto = assert(udp_send)
	socket.udp_address = assert(udp_address)
end
-- Install `callback` as the warning handler of socket `id`; the warning
-- message handler calls it with (id, size) instead of the default logger.
-- Asserts that the socket exists.
function socket.warning(id, callback)
	local obj = assert(socket_pool[id])
	obj.on_warning = callback
end

-- Module result: the public socket API table.
return socket
|