sprotoparser.lua 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. local lpeg = require "lpeg"
  2. local table = require "table"
  3. local packbytes
  4. local packvalue
  5. if _VERSION == "Lua 5.3" then
  6. function packbytes(str)
  7. return string.pack("<s4",str)
  8. end
  9. function packvalue(id)
  10. id = (id + 1) * 2
  11. return string.pack("<I2",id)
  12. end
  13. else
  14. function packbytes(str)
  15. local size = #str
  16. local a = size % 256
  17. size = math.floor(size / 256)
  18. local b = size % 256
  19. size = math.floor(size / 256)
  20. local c = size % 256
  21. size = math.floor(size / 256)
  22. local d = size
  23. return string.char(a)..string.char(b)..string.char(c)..string.char(d) .. str
  24. end
  25. function packvalue(id)
  26. id = (id + 1) * 2
  27. assert(id >=0 and id < 65536)
  28. local a = id % 256
  29. local b = math.floor(id / 256)
  30. return string.char(a) .. string.char(b)
  31. end
  32. end
  33. local P = lpeg.P
  34. local S = lpeg.S
  35. local R = lpeg.R
  36. local C = lpeg.C
  37. local Ct = lpeg.Ct
  38. local Cg = lpeg.Cg
  39. local Cc = lpeg.Cc
  40. local V = lpeg.V
  41. local function count_lines(_,pos, parser_state)
  42. if parser_state.pos < pos then
  43. parser_state.line = parser_state.line + 1
  44. parser_state.pos = pos
  45. end
  46. return pos
  47. end
  48. local exception = lpeg.Cmt( lpeg.Carg(1) , function ( _ , pos, parser_state)
  49. error(string.format("syntax error at [%s] line (%d)", parser_state.file or "", parser_state.line))
  50. return pos
  51. end)
-- End of input.
local eof = P(-1)
-- A newline ("\n" or "\r\n"). count_lines receives the state table passed as
-- the extra argument to lpeg.match and bumps its line counter.
local newline = lpeg.Cmt((P"\n" + "\r\n") * lpeg.Carg(1) ,count_lines)
-- "#" comment running to the end of the line (or of the input).
local line_comment = "#" * (1 - newline) ^0 * (newline + eof)
-- Separator: spaces, tabs, newlines and comments.
local blank = S" \t" + newline + line_comment
local blank0 = blank ^ 0
local blanks = blank ^ 1
-- Identifiers: a letter or "_" followed by letters, digits or "_".
local alpha = R"az" + R"AZ" + "_"
local alnum = alpha + R"09"
local word = alpha * alnum ^ 0
local name = C(word)
-- A possibly dotted type reference such as "foo.bar".
local typename = C(word * ("." * word) ^ 0)
-- Decimal tag number, converted with tonumber.
local tag = R"09" ^ 1 / tonumber
-- "(identifier)" suffix on a field type; convert.type stores it as the map key.
local mainkey = "(" * blank0 * name * blank0 * ")"
-- "(digits)" suffix on a field type. NOTE: lpeg.C returns the matched substring
-- followed by the nested captures, so C(tag) yields TWO values: the digit
-- string and the number produced by tag's own tonumber capture.
local decimal = "(" * blank0 * C(tag) * blank0 * ")"
-- Zero or more `pat`, separated (and optionally surrounded) by blanks,
-- collected into one table.
local function multipat(pat)
  return Ct(blank0 * (pat * blanks) ^ 0 * pat^0 * blank0)
end
-- Table of `pat`'s captures, with field `type` set to the string `name`.
local function namedpat(name, pat)
  return Ct(Cg(Cc(name), "type") * Cg(pat))
end
-- Schema grammar. convert.* reads these captures by position:
--   FIELD    -> { type = "field", name, tag, ["*"], typename, [key/decimal captures] }
--               (the optional "*" shifts every later capture by one slot)
--   STRUCT   -> { FIELD or TYPE captures ... }
--   TYPE     -> { type = "type", name, STRUCT }
--   SUBPROTO -> { "request" | "response", typename-string or STRUCT table }
--               ("nil" matches as an ordinary typename)
--   PROTOCOL -> { type = "protocol", name, tag, { SUBPROTO ... } }
--   ALL      -> { TYPE or PROTOCOL captures ... }
local typedef = P {
  "ALL",
  FIELD = namedpat("field", (name * blanks * tag * blank0 * ":" * blank0 * (C"*")^-1 * typename * (mainkey + decimal)^0)),
  STRUCT = P"{" * multipat(V"FIELD" + V"TYPE") * P"}",
  TYPE = namedpat("type", P"." * name * blank0 * V"STRUCT" ),
  SUBPROTO = Ct((C"request" + C"response") * blanks * (typename + V"STRUCT")),
  PROTOCOL = namedpat("protocol", name * blanks * tag * blank0 * P"{" * multipat(V"SUBPROTO") * P"}"),
  ALL = multipat(V"TYPE" + V"PROTOCOL"),
}
-- The whole schema, with optional surrounding blanks.
local proto = blank0 * typedef * blank0
  82. local convert = {}
  83. function convert.protocol(all, obj)
  84. local result = { tag = obj[2] }
  85. for _, p in ipairs(obj[3]) do
  86. assert(result[p[1]] == nil)
  87. local typename = p[2]
  88. if type(typename) == "table" then
  89. local struct = typename
  90. typename = obj[1] .. "." .. p[1]
  91. all.type[typename] = convert.type(all, { typename, struct })
  92. end
  93. if typename == "nil" then
  94. if p[1] == "response" then
  95. result.confirm = true
  96. end
  97. else
  98. result[p[1]] = typename
  99. end
  100. end
  101. return result
  102. end
  103. function convert.type(all, obj)
  104. local result = {}
  105. local typename = obj[1]
  106. local tags = {}
  107. local names = {}
  108. for _, f in ipairs(obj[2]) do
  109. if f.type == "field" then
  110. local name = f[1]
  111. if names[name] then
  112. error(string.format("redefine %s in type %s", name, typename))
  113. end
  114. names[name] = true
  115. local tag = f[2]
  116. if tags[tag] then
  117. error(string.format("redefine tag %d in type %s", tag, typename))
  118. end
  119. tags[tag] = true
  120. local field = { name = name, tag = tag }
  121. table.insert(result, field)
  122. local fieldtype = f[3]
  123. if fieldtype == "*" then
  124. field.array = true
  125. fieldtype = f[4]
  126. end
  127. local mainkey = f[5]
  128. if mainkey then
  129. if fieldtype == "integer" then
  130. field.decimal = mainkey
  131. else
  132. assert(field.array)
  133. field.key = mainkey
  134. end
  135. end
  136. field.typename = fieldtype
  137. else
  138. assert(f.type == "type") -- nest type
  139. local nesttypename = typename .. "." .. f[1]
  140. f[1] = nesttypename
  141. assert(all.type[nesttypename] == nil, "redefined " .. nesttypename)
  142. all.type[nesttypename] = convert.type(all, f)
  143. end
  144. end
  145. table.sort(result, function(a,b) return a.tag < b.tag end)
  146. return result
  147. end
  148. local function adjust(r)
  149. local result = { type = {} , protocol = {} }
  150. for _, obj in ipairs(r) do
  151. local set = result[obj.type]
  152. local name = obj[1]
  153. assert(set[name] == nil , "redefined " .. name)
  154. set[name] = convert[obj.type](result,obj)
  155. end
  156. return result
  157. end
  158. local buildin_types = {
  159. integer = 0,
  160. boolean = 1,
  161. string = 2,
  162. binary = 2, -- binary is a sub type of string
  163. }
  164. local function checktype(types, ptype, t)
  165. if buildin_types[t] then
  166. return t
  167. end
  168. local fullname = ptype .. "." .. t
  169. if types[fullname] then
  170. return fullname
  171. else
  172. ptype = ptype:match "(.+)%..+$"
  173. if ptype then
  174. return checktype(types, ptype, t)
  175. elseif types[t] then
  176. return t
  177. end
  178. end
  179. end
  180. local function check_protocol(r)
  181. local map = {}
  182. local type = r.type
  183. for name, v in pairs(r.protocol) do
  184. local tag = v.tag
  185. local request = v.request
  186. local response = v.response
  187. local p = map[tag]
  188. if p then
  189. error(string.format("redefined protocol tag %d at %s", tag, name))
  190. end
  191. if request and not type[request] then
  192. error(string.format("Undefined request type %s in protocol %s", request, name))
  193. end
  194. if response and not type[response] then
  195. error(string.format("Undefined response type %s in protocol %s", response, name))
  196. end
  197. map[tag] = v
  198. end
  199. return r
  200. end
  201. local function flattypename(r)
  202. for typename, t in pairs(r.type) do
  203. for _, f in pairs(t) do
  204. local ftype = f.typename
  205. local fullname = checktype(r.type, typename, ftype)
  206. if fullname == nil then
  207. error(string.format("Undefined type %s in type %s", ftype, typename))
  208. end
  209. f.typename = fullname
  210. end
  211. end
  212. return r
  213. end
  214. local function parser(text,filename)
  215. local state = { file = filename, pos = 0, line = 1 }
  216. local r = lpeg.match(proto * -1 + exception , text , 1, state )
  217. return flattypename(check_protocol(adjust(r)))
  218. end
  219. --[[
  220. -- The protocol of sproto
  221. .type {
  222. .field {
  223. name 0 : string
  224. buildin 1 : integer
  225. type 2 : integer
  226. tag 3 : integer
  227. array 4 : boolean
  228. key 5 : integer # If key exists, array must be true, and it's a map.
  229. }
  230. name 0 : string
  231. fields 1 : *field
  232. }
  233. .protocol {
  234. name 0 : string
  235. tag 1 : integer
  236. request 2 : integer # index
  237. response 3 : integer # index
  238. confirm 4 : boolean # true means response nil
  239. }
  240. .group {
  241. type 0 : *type
  242. protocol 1 : *protocol
  243. }
  244. ]]
  245. local function packfield(f)
  246. local strtbl = {}
  247. if f.array then
  248. if f.key then
  249. table.insert(strtbl, "\6\0") -- 6 fields
  250. else
  251. table.insert(strtbl, "\5\0") -- 5 fields
  252. end
  253. else
  254. table.insert(strtbl, "\4\0") -- 4 fields
  255. end
  256. table.insert(strtbl, "\0\0") -- name (tag = 0, ref an object)
  257. if f.buildin then
  258. table.insert(strtbl, packvalue(f.buildin)) -- buildin (tag = 1)
  259. if f.extra then
  260. table.insert(strtbl, packvalue(f.extra)) -- f.buildin can be integer or string
  261. else
  262. table.insert(strtbl, "\1\0") -- skip (tag = 2)
  263. end
  264. table.insert(strtbl, packvalue(f.tag)) -- tag (tag = 3)
  265. else
  266. table.insert(strtbl, "\1\0") -- skip (tag = 1)
  267. table.insert(strtbl, packvalue(f.type)) -- type (tag = 2)
  268. table.insert(strtbl, packvalue(f.tag)) -- tag (tag = 3)
  269. end
  270. if f.array then
  271. table.insert(strtbl, packvalue(1)) -- array = true (tag = 4)
  272. end
  273. if f.key then
  274. table.insert(strtbl, packvalue(f.key)) -- key tag (tag = 5)
  275. end
  276. table.insert(strtbl, packbytes(f.name)) -- external object (name)
  277. return packbytes(table.concat(strtbl))
  278. end
  279. local function packtype(name, t, alltypes)
  280. local fields = {}
  281. local tmp = {}
  282. for _, f in ipairs(t) do
  283. tmp.array = f.array
  284. tmp.name = f.name
  285. tmp.tag = f.tag
  286. tmp.extra = f.decimal
  287. tmp.buildin = buildin_types[f.typename]
  288. if f.typename == "binary" then
  289. tmp.extra = 1 -- binary is sub type of string
  290. end
  291. local subtype
  292. if not tmp.buildin then
  293. subtype = assert(alltypes[f.typename])
  294. tmp.type = subtype.id
  295. else
  296. tmp.type = nil
  297. end
  298. if f.key then
  299. tmp.key = subtype.fields[f.key]
  300. if not tmp.key then
  301. error("Invalid map index :" .. f.key)
  302. end
  303. else
  304. tmp.key = nil
  305. end
  306. table.insert(fields, packfield(tmp))
  307. end
  308. local data
  309. if #fields == 0 then
  310. data = {
  311. "\1\0", -- 1 fields
  312. "\0\0", -- name (id = 0, ref = 0)
  313. packbytes(name),
  314. }
  315. else
  316. data = {
  317. "\2\0", -- 2 fields
  318. "\0\0", -- name (tag = 0, ref = 0)
  319. "\0\0", -- field[] (tag = 1, ref = 1)
  320. packbytes(name),
  321. packbytes(table.concat(fields)),
  322. }
  323. end
  324. return packbytes(table.concat(data))
  325. end
  326. local function packproto(name, p, alltypes)
  327. if p.request then
  328. local request = alltypes[p.request]
  329. if request == nil then
  330. error(string.format("Protocol %s request type %s not found", name, p.request))
  331. end
  332. request = request.id
  333. end
  334. local tmp = {
  335. "\4\0", -- 4 fields
  336. "\0\0", -- name (id=0, ref=0)
  337. packvalue(p.tag), -- tag (tag=1)
  338. }
  339. if p.request == nil and p.response == nil and p.confirm == nil then
  340. tmp[1] = "\2\0" -- only two fields
  341. else
  342. if p.request then
  343. table.insert(tmp, packvalue(alltypes[p.request].id)) -- request typename (tag=2)
  344. else
  345. table.insert(tmp, "\1\0") -- skip this field (request)
  346. end
  347. if p.response then
  348. table.insert(tmp, packvalue(alltypes[p.response].id)) -- request typename (tag=3)
  349. elseif p.confirm then
  350. tmp[1] = "\5\0" -- add confirm field
  351. table.insert(tmp, "\1\0") -- skip this field (response)
  352. table.insert(tmp, packvalue(1)) -- confirm = true
  353. else
  354. tmp[1] = "\3\0" -- only three fields
  355. end
  356. end
  357. table.insert(tmp, packbytes(name))
  358. return packbytes(table.concat(tmp))
  359. end
  360. local function packgroup(t,p)
  361. if next(t) == nil then
  362. assert(next(p) == nil)
  363. return "\0\0"
  364. end
  365. local tt, tp
  366. local alltypes = {}
  367. for name in pairs(t) do
  368. table.insert(alltypes, name)
  369. end
  370. table.sort(alltypes) -- make result stable
  371. for idx, name in ipairs(alltypes) do
  372. local fields = {}
  373. for _, type_fields in ipairs(t[name]) do
  374. if buildin_types[type_fields.typename] then
  375. fields[type_fields.name] = type_fields.tag
  376. end
  377. end
  378. alltypes[name] = { id = idx - 1, fields = fields }
  379. end
  380. tt = {}
  381. for _,name in ipairs(alltypes) do
  382. table.insert(tt, packtype(name, t[name], alltypes))
  383. end
  384. tt = packbytes(table.concat(tt))
  385. if next(p) then
  386. local tmp = {}
  387. for name, tbl in pairs(p) do
  388. table.insert(tmp, tbl)
  389. tbl.name = name
  390. end
  391. table.sort(tmp, function(a,b) return a.tag < b.tag end)
  392. tp = {}
  393. for _, tbl in ipairs(tmp) do
  394. table.insert(tp, packproto(tbl.name, tbl, alltypes))
  395. end
  396. tp = packbytes(table.concat(tp))
  397. end
  398. local result
  399. if tp == nil then
  400. result = {
  401. "\1\0", -- 1 field
  402. "\0\0", -- type[] (id = 0, ref = 0)
  403. tt,
  404. }
  405. else
  406. result = {
  407. "\2\0", -- 2fields
  408. "\0\0", -- type array (id = 0, ref = 0)
  409. "\0\0", -- protocol array (id = 1, ref =1)
  410. tt,
  411. tp,
  412. }
  413. end
  414. return table.concat(result)
  415. end
  416. local function encodeall(r)
  417. return packgroup(r.type, r.protocol)
  418. end
  419. local sparser = {}
  420. function sparser.dump(str)
  421. local tmp = ""
  422. for i=1,#str do
  423. tmp = tmp .. string.format("%02X ", string.byte(str,i))
  424. if i % 8 == 0 then
  425. if i % 16 == 0 then
  426. print(tmp)
  427. tmp = ""
  428. else
  429. tmp = tmp .. "- "
  430. end
  431. end
  432. end
  433. print(tmp)
  434. end
  435. function sparser.parse(text, name)
  436. local r = parser(text, name or "=text")
  437. local data = encodeall(r)
  438. return data
  439. end
  440. return sparser