sproto.lua 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. local core = require "sproto.core"
  -- Module-level state shared by every proto and host object created below.
  local assert = assert

  --- Public module table; also the method table of sproto objects.
  local sproto = {}
  --- Method table of RPC host objects created by sproto:host().
  local host = {}

  -- Metatable for the per-object caches: weak keys and weak values, so a
  -- cache entry never keeps its key or value alive by itself.
  local weak_mt = { __mode = "kv" }

  -- Metatable for sproto objects that own their C object: the __gc handler
  -- releases it through core.deleteproto when the wrapper is collected.
  local sproto_mt = {
    __index = sproto,
    __gc = function(self)
      core.deleteproto(self.__cobj)
    end,
  }

  -- Metatable for wrappers around a shared C object: no __gc, so collecting
  -- the wrapper never frees the underlying object.
  local sproto_nogc = { __index = sproto }

  local host_mt = { __index = host }
  --- Build an owning sproto object from a compiled schema blob.
  -- The object uses sproto_mt, so its C object is released via
  -- core.deleteproto when the wrapper is garbage-collected.
  -- @param bin compiled schema, e.g. the output of the sprotoparser module
  -- @return sproto object; raises if core.newproto returns a falsy value
  function sproto.new(bin)
    local obj = {}
    obj.__cobj = assert(core.newproto(bin))
    -- Per-object caches for type handles and protocol records.
    obj.__tcache = setmetatable({}, weak_mt)
    obj.__pcache = setmetatable({}, weak_mt)
    return setmetatable(obj, sproto_mt)
  end
  --- Wrap an existing C proto object without taking ownership of it.
  -- The wrapper uses sproto_nogc (no __gc), so collecting it never frees
  -- `cobj`; whoever created `cobj` stays responsible for it.
  -- @param cobj C proto object to share
  -- @return sproto object backed by `cobj`
  function sproto.sharenew(cobj)
    local obj = { __cobj = cobj }
    obj.__tcache = setmetatable({}, weak_mt)
    obj.__pcache = setmetatable({}, weak_mt)
    return setmetatable(obj, sproto_nogc)
  end
  --- Compile a textual schema and build an owning sproto object from it.
  -- The "sprotoparser" module is required lazily, only when text schemas
  -- are actually parsed.
  -- @param ptext schema source text
  -- @return sproto object (see sproto.new)
  function sproto.parse(ptext)
    local compiled = require("sprotoparser").parse(ptext)
    return sproto.new(compiled)
  end
  --- Create an RPC host bound to this schema.
  -- Every message the host reads or writes starts with a header of type
  -- `packagename`, encoded and decoded via the returned object's __package.
  -- @param packagename name of the header type (defaults to "package")
  -- @return host object; raises "type <packagename> not found" when the
  --   schema has no such type
  function sproto:host(packagename)
    packagename = packagename or "package"
    -- Name the type that was actually requested. With the default name
    -- this is exactly the previous "type package not found" message.
    local header_type = assert(core.querytype(self.__cobj, packagename),
      "type " .. tostring(packagename) .. " not found")
    local obj = {
      __proto = self,
      __package = header_type,
      -- Pending outgoing sessions: session id -> response type, or true.
      __session = {},
    }
    return setmetatable(obj, host_mt)
  end
  --- Resolve a type name to the handle returned by core.querytype.
  -- Results are memoised in self.__tcache.
  -- @param self sproto object
  -- @param typename schema type name
  -- @return type handle; raises "type not found" if the schema lacks it
  local function querytype(self, typename)
    local cache = self.__tcache
    local cached = cache[typename]
    if cached then
      return cached
    end
    local handle = assert(core.querytype(self.__cobj, typename), "type not found")
    cache[typename] = handle
    return handle
  end
  --- Report whether the schema defines `typename`.
  -- Consults the type cache first and falls back to core.querytype. It never
  -- raises for a missing type and does not populate the cache.
  -- @return boolean
  function sproto:exist_type(typename)
    if self.__tcache[typename] then
      return true
    end
    return core.querytype(self.__cobj, typename) ~= nil
  end
  --- Encode `tbl` as a message of type `typename` (without core.pack).
  -- Raises if the type does not exist.
  function sproto:encode(typename, tbl)
    return core.encode(querytype(self, typename), tbl)
  end
  --- Decode a message of type `typename`.
  -- Any extra arguments are forwarded unchanged to core.decode.
  function sproto:decode(typename, ...)
    return core.decode(querytype(self, typename), ...)
  end
  --- Encode `tbl` as type `typename`, then run the result through core.pack.
  function sproto:pencode(typename, tbl)
    return core.pack(core.encode(querytype(self, typename), tbl))
  end
  --- Inverse of pencode: core.unpack the input, then decode as `typename`.
  -- The type is resolved before unpacking, so an unknown type is reported
  -- first.
  function sproto:pdecode(typename, ...)
    local message_type = querytype(self, typename)
    return core.decode(message_type, core.unpack(...))
  end
  --- Look up a protocol by name (string) or by tag (number).
  -- The resulting record is memoised in self.__pcache under BOTH its name
  -- and its tag, so either key hits on later lookups.
  -- @param self sproto object
  -- @param pname protocol name or numeric tag
  -- @return record { request = ..., response = ..., name = ..., tag = ... };
  --   request/response are nil when the protocol declares none
  -- Raises "<pname> not found" when the schema has no such protocol.
  local function queryproto(self, pname)
    local cache = self.__pcache
    local v = cache[pname]
    if v then
      return v
    end
    -- core.protocol answers a tag query with the protocol name and a name
    -- query with the tag; the remaining results are request/response types.
    local key, req, resp = core.protocol(self.__cobj, pname)
    if not key then
      error(tostring(pname) .. " not found", 0)
    end
    -- Decide by Lua type, not tonumber(): tonumber("42") would treat a
    -- numeric-looking NAME as a tag and swap name/tag in the cached record.
    -- NOTE(review): assumes core.protocol does a tag lookup only for actual
    -- Lua numbers (lua_type check in lsproto.c) - confirm on upgrades.
    local name, tag
    if type(pname) == "number" then
      name, tag = key, pname
    else
      name, tag = pname, key
    end
    v = { request = req, response = resp, name = name, tag = tag }
    cache[name] = v
    cache[tag] = v
    return v
  end
  --- Report whether the schema defines protocol `pname` (name or tag).
  -- Consults the protocol cache first and falls back to core.protocol. It
  -- never raises for a missing protocol and does not populate the cache.
  -- @return boolean
  function sproto:exist_proto(pname)
    if self.__pcache[pname] then
      return true
    end
    return core.protocol(self.__cobj, pname) ~= nil
  end
  --- Encode the request body of protocol `protoname`.
  -- @return encoded body ("" when the protocol has no request type), tag
  function sproto:request_encode(protoname, tbl)
    local p = queryproto(self, protoname)
    local body = ""
    if p.request then
      body = core.encode(p.request, tbl)
    end
    return body, p.tag
  end
  --- Encode the response body of protocol `protoname`.
  -- @return encoded body, or "" when the protocol has no response type
  function sproto:response_encode(protoname, tbl)
    local response = queryproto(self, protoname).response
    if not response then
      return ""
    end
    return core.encode(response, tbl)
  end
  --- Decode a request body of protocol `protoname`.
  -- Extra arguments are forwarded to core.decode.
  -- @return decoded request (nil when the protocol has no request type),
  --   protocol name
  function sproto:request_decode(protoname, ...)
    local p = queryproto(self, protoname)
    if not p.request then
      return nil, p.name
    end
    return core.decode(p.request, ...), p.name
  end
  --- Decode a response body of protocol `protoname`.
  -- Extra arguments are forwarded to core.decode. Returns no values when
  -- the protocol has no response type.
  function sproto:response_decode(protoname, ...)
    local response = queryproto(self, protoname).response
    if not response then
      return
    end
    return core.decode(response, ...)
  end
  -- Re-export the core packing helpers (the same functions pencode/pdecode
  -- apply) as plain module functions.
  sproto.pack, sproto.unpack = core.pack, core.unpack
  --- Return core.default(...) for a schema type or for one side of a protocol.
  -- @param typename type name, or protocol name/tag when `kind` is given
  -- @param kind nil for a plain type; "REQUEST" or "RESPONSE" to select a
  --   protocol's request or response type
  -- @return result of core.default, or nothing when the selected protocol
  --   side does not exist; raises "Invalid type" for any other `kind`
  function sproto:default(typename, kind)
    if kind == nil then
      return core.default(querytype(self, typename))
    end
    -- The protocol is resolved before `kind` is validated.
    local p = queryproto(self, typename)
    local selected
    if kind == "REQUEST" then
      selected = p.request
    elseif kind == "RESPONSE" then
      selected = p.response
    else
      error "Invalid type"
    end
    if selected then
      return core.default(selected)
    end
  end
  -- Scratch table reused as the package header for every message encoded or
  -- decoded by hosts in this module. Each user overwrites all three fields
  -- (type, session, ud) before use.
  local header_tmp = {}

  --- Build the reply function for an incoming request that carried a session.
  -- @param self host object (its __package is the header type)
  -- @param response the protocol's response type, or nil if it has none
  -- @param session session id to echo in the reply header
  -- @return function(args, ud) returning the packed reply message
  local function gen_response(self, response, session)
    return function(args, ud)
      -- Replies carry no type; only session and ud go into the header.
      header_tmp.type = nil
      header_tmp.session = session
      header_tmp.ud = ud
      local message = core.encode(self.__package, header_tmp)
      if response then
        message = message .. core.encode(response, args)
      end
      return core.pack(message)
    end
  end
  --- Decode one incoming packed message.
  -- @return for a request: "REQUEST", protocol name, decoded request (or nil),
  --   reply function (nil when no session was sent), header ud
  -- @return for a response: "RESPONSE", session, decoded response (or nil),
  --   header ud
  -- Raises "session not found" / "Unknown session" for a response whose
  -- session is missing or was never registered by attach().
  function host:dispatch(...)
    local bin = core.unpack(...)
    -- Clear the scratch header so no field survives from a previous message.
    header_tmp.type, header_tmp.session, header_tmp.ud = nil, nil, nil
    local header, size = core.decode(self.__package, bin, header_tmp)
    local session = header_tmp.session
    -- Whatever follows the header is the message body.
    local content = bin:sub(size + 1)

    if not header.type then
      -- No type in the header: a reply to one of our own requests.
      session = assert(session, "session not found")
      local response = assert(self.__session[session], "Unknown session")
      self.__session[session] = nil
      -- `true` marks a protocol without a response type: nothing to decode.
      local result = nil
      if response ~= true then
        result = core.decode(response, content)
      end
      return "RESPONSE", session, result, header.ud
    end

    -- Typed header: an incoming request.
    local proto = queryproto(self.__proto, header.type)
    local result
    if proto.request then
      result = core.decode(proto.request, content)
    end
    local reply = nil
    if session then
      reply = gen_response(self, proto.response, session)
    end
    return "REQUEST", proto.name, result, reply, header.ud
  end
  --- Return a sender that builds outgoing requests using protocols of `sp`.
  -- The sender has the signature function(name, args, session, ud) and
  -- returns the packed message. When `session` is given, the protocol's
  -- response type (or true when it has none) is recorded in self.__session
  -- so dispatch() can decode the matching reply.
  -- @param sp sproto object providing the protocol definitions
  function host:attach(sp)
    return function(name, args, session, ud)
      local proto = queryproto(sp, name)
      header_tmp.type = proto.tag
      header_tmp.session = session
      header_tmp.ud = ud
      local message = core.encode(self.__package, header_tmp)
      if proto.request then
        message = message .. core.encode(proto.request, args)
      end
      local packed = core.pack(message)
      -- Register the session only once the message is fully built. If an
      -- encode raises (e.g. bad args), no stale entry is left behind; such an
      -- entry would leak and let dispatch() accept a later reply for a
      -- request that was never sent.
      if session then
        self.__session[session] = proto.response or true
      end
      return packed
    end
  end
  -- Export the module table: constructors, schema methods and pack/unpack.
  return sproto