mongo.lua 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637
  1. local bson = require "bson"
  2. local socket = require "skynet.socket"
  3. local socketchannel = require "skynet.socketchannel"
  4. local skynet = require "skynet"
  5. local driver = require "skynet.mongo.driver"
  6. local md5 = require "md5"
  7. local crypt = require "skynet.crypt"
  -- Cache frequently used globals and bson entry points as upvalues.
  local rawget, assert, table = rawget, assert, table
  local bson_encode, bson_encode_order, bson_decode = bson.encode, bson.encode_order, bson.decode
  -- Pre-encoded empty document, reused whenever a query is omitted.
  local empty_bson = bson_encode {}

  -- Public module table; re-exports the bson sentinel values.
  local mongo = {
    null = assert(bson.null),
    maxkey = assert(bson.maxkey),
    minkey = assert(bson.minkey),
    type = assert(bson.type),
  }

  -- Cursor methods; every cursor object uses cursor_meta.
  local mongo_cursor = {}
  local cursor_meta = { __index = mongo_cursor }
  -- Client methods. Any key that is not a method is treated as a database
  -- name, so `client.mydb` is equivalent to `client:getDB("mydb")`.
  local mongo_client = {}
  local client_meta = {
    __index = function(self, key)
      local method = rawget(mongo_client, key)
      if method then
        return method
      end
      return self:getDB(key)
    end,
    __tostring = function(self)
      local port_suffix = self.port and (":" .. tostring(self.port)) or ""
      return "[mongo client : " .. self.host .. port_suffix .. "]"
    end,
    -- DO NOT need disconnect, because channel will shutdown during gc
  }
  -- Database methods. Any key that is not a method resolves to a collection
  -- via getCollection (which also caches it on the db table).
  local mongo_db = {}
  local db_meta = {
    __index = function(self, key)
      local method = rawget(mongo_db, key)
      if method then
        return method
      end
      return self:getCollection(key)
    end,
    __tostring = function(self)
      return "[mongo db : " .. self.name .. "]"
    end,
  }
  -- Collection methods. Any key that is not a method resolves to a nested
  -- collection ("a.b") via getCollection.
  local mongo_collection = {}
  local collection_meta = {
    __index = function(self, key)
      local method = rawget(mongo_collection, key)
      if method then
        return method
      end
      return self:getCollection(key)
    end,
    __tostring = function(self)
      return "[mongo collection : " .. self.full_name .. "]"
    end,
  }
  -- socketchannel response handler: read one reply message from `so`.
  -- A 4-byte header is read first; driver.length() turns it into the number of
  -- bytes still to read. Returns reply_id, success flag, and a result table
  -- holding the parsed documents plus the raw message bytes (`data`).
  local function dispatch_reply(so)
    local header = so:read(4)
    local body = so:read(driver.length(header))
    local docs = {}
    local succ, reply_id, document, cursor_id, startfrom = driver.reply(body, docs)
    return reply_id, succ, {
      result = docs,
      document = document,
      cursor_id = cursor_id,
      startfrom = startfrom,
      -- `document` points into this string, so it must stay referenced.
      data = body,
    }
  end
  --- Split a "host:port" address into host and numeric port.
  -- The split happens at the last colon, so a bracketed IPv6 literal such as
  -- "[::1]:27017" yields host "::1". IPv6 literals must be bracketed.
  -- An address with no numeric port falls back to MongoDB's default port 27017.
  -- @tparam string addr
  -- @treturn string host
  -- @treturn number port
  local function __parse_addr(addr)
    local host, port = string.match(addr, "^(.+):(%d+)$")
    if host == nil then
      host, port = addr, 27017
    end
    -- Strip the brackets around an IPv6 literal.
    host = string.match(host, "^%[(.*)%]$") or host
    return host, tonumber(port)
  end
  --- Build the socketchannel `auth` callback for a client.
  -- The returned function optionally authenticates with the configured auth
  -- method, then runs "ismaster". It updates the channel's backup list from
  -- `hosts`. When the server is not the primary, it moves the channel to the
  -- reported primary. If no primary is reported, it moves to the next
  -- untried host instead.
  -- @param mongoc client table (username/password/authmod read via rawget)
  local function mongo_auth(mongoc)
    local user = rawget(mongoc, "username")
    local pass = rawget(mongoc, "password")
    local authmod = "auth_" .. (rawget(mongoc, "authmod") or "scram_sha1")

    -- Record a new target server on the client and tell the channel about it.
    local function switch_to(addr)
      local host, port = __parse_addr(addr)
      mongoc.host = host
      mongoc.port = port
      mongoc.__sock:changehost(host, port)
    end

    return function()
      if user ~= nil and pass ~= nil then
        -- authmod can be "mongodb_cr" or "scram_sha1".
        -- Look the method up on the class table itself: indexing mongoc would
        -- fall through client_meta.__index to getDB() and return a (truthy)
        -- db table for unknown names, so the assert below could never fire.
        local auth_func = rawget(mongo_client, authmod)
        assert(auth_func, "Invalid authmod")
        assert(auth_func(mongoc, user, pass))
      end

      local rs_data = mongoc:runCommand("ismaster")
      if rs_data.ok ~= 1 then
        return
      end

      if rs_data.hosts then
        local backup = {}
        for _, v in ipairs(rs_data.hosts) do
          local host, port = __parse_addr(v)
          backup[#backup + 1] = { host = host, port = port }
        end
        mongoc.__sock:changebackup(backup)
      end

      if rs_data.ismaster then
        -- Connected to the primary: forget any candidates from earlier tries.
        rawset(mongoc, "__pickserver", nil)
        return
      end

      if rs_data.primary then
        switch_to(rs_data.primary)
        return
      end

      skynet.error("WARNING: NO PRIMARY RETURN " .. tostring(rs_data.me))
      -- Determine the primary by trying every other known host in turn.
      -- The candidate list is built once and then consumed across retries.
      local pickserver = rawget(mongoc, "__pickserver")
      if pickserver == nil then
        pickserver = {}
        for _, v in ipairs(rs_data.hosts or {}) do
          if v ~= rs_data.me then
            pickserver[#pickserver + 1] = v
          end
        end
        rawset(mongoc, "__pickserver", pickserver)
      end
      if #pickserver <= 0 then
        error("CAN NOT DETERMINE THE PRIMARY DB")
      end
      local addr = table.remove(pickserver, 1)
      skynet.error("INFO: TRY TO CONNECT " .. addr)
      switch_to(addr)
    end
  end
  --- Create a client and connect it.
  -- `conf` is either a single server description
  -- { host=, port=, username=, password=, authmod= } or { rs = { ... } }.
  -- With `rs`, rs[1] supplies the initial server and credentials, and the
  -- whole list is passed to the channel as `backup`.
  function mongo.client(conf)
    local first, backup = conf, nil
    if conf.rs then
      first, backup = conf.rs[1], conf.rs
    end

    local obj = {
      host = first.host,
      port = first.port or 27017,
      username = first.username,
      password = first.password,
      authmod = first.authmod,
      __id = 0,
    }
    obj.__sock = socketchannel.channel {
      host = obj.host,
      port = obj.port,
      response = dispatch_reply,
      auth = mongo_auth(obj),
      backup = backup,
      nodelay = true,
    }
    setmetatable(obj, client_meta)
    obj.__sock:connect(true) -- try connect only once
    return obj
  end
  --- Build a handle for database `dbname` on this client.
  -- A new table is built on every call (nothing is cached here). The handle's
  -- `database` field points back at itself, and `__cmd` names the command
  -- namespace "<dbname>.$cmd".
  function mongo_client:getDB(dbname)
    local db = {
      connection = self,
      name = dbname,
      full_name = dbname,
      __cmd = dbname .. ".$cmd",
    }
    db.database = db
    return setmetatable(db, db_meta)
  end
  --- Close the socket channel. After the first call `__sock` is false, so
  -- calling this again does nothing.
  function mongo_client:disconnect()
    local so = self.__sock
    if not so then
      return
    end
    self.__sock = false
    so:close()
  end
  --- Return the next request id (a per-client counter starting at 1).
  function mongo_client:genId()
    self.__id = self.__id + 1
    return self.__id
  end
  --- Run a command against the "admin" database.
  -- The admin handle is created on first use and cached on the client.
  -- Reading `self.admin` would go through client_meta.__index, which returns
  -- a fresh db table every time, so the cache must be checked with rawget.
  function mongo_client:runCommand(...)
    local admin = rawget(self, "admin")
    if not admin then
      admin = self:getDB "admin"
      self.admin = admin
    end
    return admin:runCommand(...)
  end
  --- Authenticate with the legacy MONGODB-CR handshake (getnonce + authenticate).
  -- Returns true when the server replies ok == 1 to "authenticate".
  function mongo_client:auth_mongodb_cr(user, password)
    local digest = md5.sumhexa(string.format("%s:mongo:%s", user, password))
    local nonce_reply = self:runCommand "getnonce"
    if nonce_reply.ok ~= 1 then
      return false
    end
    local nonce = nonce_reply.nonce
    local key = md5.sumhexa(string.format("%s%s%s", nonce, user, digest))
    local reply = self:runCommand("authenticate", 1, "user", user, "nonce", nonce, "key", key)
    return reply.ok == 1
  end
  --- PBKDF2-style key stretching with HMAC-SHA1 (single output block).
  -- U1 = HMAC(password, salt .. INT(1)); Ui = HMAC(password, Ui-1);
  -- the result is U1 xor U2 xor ... xor Uiter.
  local function salt_password(password, salt, iter)
    local u = crypt.hmac_sha1(password, salt .. "\0\0\0\1")
    local acc = u
    for _ = 2, iter do
      u = crypt.hmac_sha1(password, u)
      acc = crypt.xor_str(acc, u)
    end
    return acc
  end
  -- Parse a SCRAM message of the form "k1=v1,k2=v2,..." into a table.
  -- Only the first '=' of each attribute splits key from value, so base64
  -- padding inside a value is preserved.
  local function parse_scram_fields(message)
    local fields = {}
    for k, v in string.gmatch(message, "(%w+)=([^,]*)") do
      fields[k] = v
    end
    return fields
  end

  --- Authenticate with SCRAM-SHA-1 via saslStart / saslContinue.
  -- The password fed into the key derivation is md5("<user>:mongo:<password>").
  -- @treturn boolean true on success, false on any failed step
  function mongo_client:auth_scram_sha1(username, password)
    -- Escape '=' and ',' in the user name as the SCRAM syntax requires.
    local user = string.gsub(string.gsub(username, '=', '=3D'), ',', '=2C')
    local nonce = crypt.base64encode(crypt.randomkey())
    local first_bare = "n=" .. user .. ",r=" .. nonce
    local sasl_start_payload = crypt.base64encode("n,," .. first_bare)

    local r = self:runCommand("saslStart", 1, "autoAuthorize", 1, "mechanism", "SCRAM-SHA-1", "payload", sasl_start_payload)
    if r.ok ~= 1 then
      return false
    end
    local conversationId = r['conversationId']
    local server_first = crypt.base64decode(r['payload'])
    local challenge = parse_scram_fields(server_first)
    local iterations = tonumber(challenge['i'])
    local salt = challenge['s']
    local rnonce = challenge['r']
    -- The server nonce must start with our client nonce. Compare with ~=:
    -- `not a == b` would parse as `(not a) == b` and never be true.
    if rnonce == nil or string.sub(rnonce, 1, #nonce) ~= nonce then
      skynet.error("Server returned an invalid nonce.")
      return false
    end
    if iterations == nil or salt == nil then
      skynet.error("Server returned an invalid challenge.")
      return false
    end

    -- "biws" is base64("n,,"), the GS2 header sent in client-first.
    local without_proof = "c=biws,r=" .. rnonce
    local pbkdf2_key = md5.sumhexa(string.format("%s:mongo:%s", username, password))
    local salted_pass = salt_password(pbkdf2_key, crypt.base64decode(salt), iterations)
    local client_key = crypt.hmac_sha1(salted_pass, "Client Key")
    local stored_key = crypt.sha1(client_key)
    local auth_msg = first_bare .. ',' .. server_first .. ',' .. without_proof
    local client_sig = crypt.hmac_sha1(stored_key, auth_msg)
    local client_proof = "p=" .. crypt.base64encode(crypt.xor_str(client_key, client_sig))
    local client_final = crypt.base64encode(without_proof .. ',' .. client_proof)
    local server_key = crypt.hmac_sha1(salted_pass, "Server Key")
    local server_sig = crypt.base64encode(crypt.hmac_sha1(server_key, auth_msg))

    r = self:runCommand("saslContinue", 1, "conversationId", conversationId, "payload", client_final)
    if r.ok ~= 1 then
      return false
    end
    local server_final = parse_scram_fields(crypt.base64decode(r['payload']))
    if server_final['v'] ~= server_sig then
      skynet.error("Server returned an invalid signature.")
      return false
    end
    if not r.done then
      -- Some servers need one more empty round to finish the conversation.
      r = self:runCommand("saslContinue", 1, "conversationId", conversationId, "payload", "")
      if r.ok ~= 1 then
        return false
      end
      if not r.done then
        skynet.error("SASL conversation failed to complete.")
        return false
      end
    end
    return true
  end
  --- Send "logout" to the admin database; true when the server says ok == 1.
  function mongo_client:logout()
    local reply = self:runCommand("logout")
    return reply.ok == 1
  end
  --- Run a command on "<db>.$cmd" and return the decoded reply document.
  -- runCommand("ping") sends { ping = 1 }. With a second argument, all
  -- arguments form an ordered key/value list:
  -- runCommand("count", "coll", "query", q).
  function mongo_db:runCommand(cmd, cmd_v, ...)
    local conn = self.connection
    local request_id = conn:genId()
    local sock = conn.__sock
    local bson_cmd
    if cmd_v then
      bson_cmd = bson_encode_order(cmd, cmd_v, ...)
    else
      bson_cmd = bson_encode_order(cmd, 1)
    end
    local pack = driver.query(request_id, 0, self.__cmd, 0, 1, bson_cmd)
    -- Keep `req` in a local while decoding: req.document is a lightuserdata
    -- pointing into the string req.data.
    local req = sock:request(pack, request_id)
    return bson_decode(req.document)
  end
  --- Build a collection handle named `collection` under this db/collection
  -- and cache it on `self` so later lookups skip __index.
  -- Shared by collections, which allows nested names such as db.a.b.
  function mongo_db:getCollection(collection)
    local col = setmetatable({
      connection = self.connection,
      name = collection,
      full_name = self.full_name .. "." .. collection,
      database = self.database,
    }, collection_meta)
    self[collection] = col
    return col
  end
  mongo_collection.getCollection = mongo_db.getCollection
  --- Insert one document with a legacy OP_INSERT message.
  -- Assigns `doc._id` (mutating `doc`) when it is missing.
  function mongo_collection:insert(doc)
    if doc._id == nil then
      doc._id = bson.objectid()
    end
    local encoded = bson_encode(doc)
    -- flags support 1: ContinueOnError
    local pack = driver.insert(0, self.full_name, encoded)
    self.connection.__sock:request(pack)
  end
  --- Insert one document via the "insert" command and return the server reply.
  function mongo_collection:safe_insert(doc)
    local documents = { bson_encode(doc) }
    return self.database:runCommand("insert", self.name, "documents", documents)
  end
  --- Insert a list of documents in a single OP_INSERT message.
  -- Missing `_id`s are generated and written back into the caller's documents
  -- so generated ids can be read afterwards. The encoded payloads go into a
  -- separate array, so `docs` still holds the Lua tables afterwards and can
  -- safely be reused or retried.
  function mongo_collection:batch_insert(docs)
    local encoded = {}
    for i = 1, #docs do
      local doc = docs[i]
      if doc._id == nil then
        doc._id = bson.objectid()
      end
      encoded[i] = bson_encode(doc)
    end
    local pack = driver.insert(0, self.full_name, encoded)
    self.connection.__sock:request(pack)
  end
  --- Send a legacy update. Flag bit 1 = upsert, bit 2 = multi.
  function mongo_collection:update(selector, update, upsert, multi)
    local flags = 0
    if upsert then
      flags = flags + 1
    end
    if multi then
      flags = flags + 2
    end
    local encoded_selector = bson_encode(selector)
    local encoded_update = bson_encode(update)
    local pack = driver.update(self.full_name, flags, encoded_selector, encoded_update)
    self.connection.__sock:request(pack)
  end
  --- Send a legacy delete for documents matching `selector`; `single` is
  -- passed straight to driver.delete.
  function mongo_collection:delete(selector, single)
    local encoded_selector = bson_encode(selector)
    local pack = driver.delete(self.full_name, single, encoded_selector)
    self.connection.__sock:request(pack)
  end
  --- Query with numberToReturn = 1 and return the decoded document.
  -- `query` defaults to an empty document; `selector` is an optional projection.
  function mongo_collection:findOne(query, selector)
    local conn = self.connection
    local request_id = conn:genId()
    local query_bson = query and bson_encode(query) or empty_bson
    local selector_bson = selector and bson_encode(selector)
    local pack = driver.query(request_id, 0, self.full_name, 0, 1, query_bson, selector_bson)
    -- Keep `req` in a local while decoding: req.document is a lightuserdata
    -- pointing into the string req.data.
    local req = conn.__sock:request(pack, request_id)
    return bson_decode(req.document)
  end
  --- Create a cursor for `query` / `selector` (both optional Lua tables).
  -- No request is sent here; the first hasNext() issues the query.
  function mongo_collection:find(query, selector)
    local encoded_query = query and bson_encode(query) or empty_bson
    local encoded_selector = selector and bson_encode(selector)
    local cursor = {
      __collection = self,
      __query = encoded_query,
      __selector = encoded_selector,
      __document = {},
      __flags = 0,
      __skip = 0,
      __limit = 0,
      -- __ptr, __data, __cursor and __sortquery start out nil.
    }
    return setmetatable(cursor, cursor_meta)
  end
  --- Flatten single-pair tables into `list` as k1, v1, k2, v2, ...
  -- Stops at the first nil argument. Only the first pair (as produced by
  -- pairs()) of each table is used. Returns `list`.
  local function unfold(list, ...)
    local specs = { ... }
    for i = 1, select("#", ...) do
      local spec = specs[i]
      if spec == nil then
        break
      end
      local iter, state = pairs(spec)
      local k, v = iter(state)
      table.insert(list, k)
      table.insert(list, v)
    end
    return list
  end
  -- cursor:sort { key = 1 } or cursor:sort( {key1 = 1}, {key2 = -1})
  -- With several arguments the keys keep their argument order. Wraps the
  -- query as { $query = ..., $orderby = ... }; returns the cursor.
  function mongo_cursor:sort(key, key_v, ...)
    local order = key
    if key_v then
      order = bson_encode_order(table.unpack(unfold({}, key, key_v, ...)))
    end
    self.__sortquery = bson_encode { ['$query'] = self.__query, ['$orderby'] = order }
    return self
  end
  -- Build a chainable setter: cursor:method(amount) stores `amount` in
  -- `field` and returns the cursor.
  local function chainable_setter(field)
    return function(self, amount)
      self[field] = amount
      return self
    end
  end
  -- cursor:skip(n)  -- number of documents to skip
  mongo_cursor.skip = chainable_setter("__skip")
  -- cursor:limit(n) -- maximum number of documents to return (0 = no limit)
  mongo_cursor.limit = chainable_setter("__limit")
  --- Run "count" for this cursor's query and return the count (`n`).
  -- When `with_limit_and_skip` is truthy, the cursor's limit and skip are sent too.
  -- Raises if the reply is missing or ok ~= 1.
  function mongo_cursor:count(with_limit_and_skip)
    local collection = self.__collection
    local cmd = { 'count', collection.name, 'query', self.__query }
    if with_limit_and_skip then
      local n = #cmd
      cmd[n + 1], cmd[n + 2], cmd[n + 3], cmd[n + 4] = 'limit', self.__limit, 'skip', self.__skip
    end
    local ret = collection.database:runCommand(table.unpack(cmd))
    assert(ret and ret.ok == 1)
    return ret.n
  end
  -- For compatibility.
  -- collection:createIndex({username = 1}, {unique = true})
  -- Copies `option` into the index spec, names it "<key>_<value>" unless
  -- option.name is set, and sends "createIndexes". Only one key is allowed.
  local function createIndex_onekey(self, key, option)
    local spec = {}
    for name, value in pairs(option) do
      spec[name] = value
    end
    local field, direction = next(key) -- support only one key
    assert(next(key, field) == nil, "Use new api for multi-keys")
    spec.name = spec.name or (field .. "_" .. direction)
    spec.key = key
    return self.database:runCommand("createIndexes", self.name, "indexes", { spec })
  end
  --- Turn an index description into a createIndexes spec.
  -- String-keyed entries (unique, name, ...) are copied as options. Array
  -- entries are keys, given either as "field" (ascending, 1) or as
  -- { field = dir }. The default name joins field/dir pairs with "_".
  local function IndexModel(option)
    local doc = {}
    for field, value in pairs(option) do
      if type(field) == "string" then
        doc[field] = value
      end
    end

    local keys = {}
    local name
    for _, spec in ipairs(option) do
      local k, v
      if type(spec) == "string" then
        k, v = spec, 1
      else
        k, v = next(spec)
      end
      keys[#keys + 1] = k
      keys[#keys + 1] = v
      if name == nil then
        name = k
      else
        name = name .. "_" .. k
      end
      name = name .. "_" .. v
    end
    assert(name, "Need keys")

    doc.name = doc.name or name
    doc.key = bson_encode_order(table.unpack(keys))
    return doc
  end
  -- collection:createIndex { { key1 = 1}, { key2 = 1 }, unique = true }
  -- or collection:createIndex { "key1", "key2", unique = true }
  -- or collection:createIndex( { key1 = 1} , { unique = true } ) -- For compatibility
  function mongo_collection:createIndex(arg1, arg2)
    if arg2 then
      -- Legacy (key, option) form.
      return createIndex_onekey(self, arg1, arg2)
    end
    local model = IndexModel(arg1)
    return self.database:runCommand("createIndexes", self.name, "indexes", { model })
  end
  --- Create several indexes at once; each argument uses the IndexModel format.
  function mongo_collection:createIndexes(...)
    local specs = { ... }
    for i, spec in ipairs(specs) do
      specs[i] = IndexModel(spec)
    end
    return self.database:runCommand("createIndexes", self.name, "indexes", specs)
  end
  -- Older name kept as an alias.
  mongo_collection.ensureIndex = mongo_collection.createIndex
  --- Drop this collection; returns the server reply.
  function mongo_collection:drop()
    local db = self.database
    return db:runCommand("drop", self.name)
  end
  -- collection:dropIndex("age_1")
  -- collection:dropIndex("*")
  -- Sends "dropIndexes" with the given index name and returns the reply.
  function mongo_collection:dropIndex(indexName)
    local db = self.database
    return db:runCommand("dropIndexes", self.name, "index", indexName)
  end
  -- collection:findAndModify({query = {name = "userid"}, update = {["$inc"] = {nextid = 1}}, })
  -- Keys of `doc` are forwarded verbatim as command fields:
  --   query   table    (required)
  --   sort    table
  --   remove  boolean  (remove or update is required)
  --   update  table
  --   new     boolean
  --   fields  table    (projection document, per MongoDB's findAndModify docs)
  --   upsert  boolean
  -- Only the command name must come first; the other fields may be in any
  -- order, which is why pairs() iteration order does not matter here.
  function mongo_collection:findAndModify(doc)
    assert(doc.query)
    assert(doc.update or doc.remove)
    local cmd = { "findAndModify", self.name }
    for key, value in pairs(doc) do
      cmd[#cmd + 1] = key
      cmd[#cmd + 1] = value
    end
    return self.database:runCommand(table.unpack(cmd))
  end
  --- Return true if another document is available, fetching a batch if needed.
  -- The first call sends the query (with sort/skip/limit). Later calls send
  -- getMore while the server cursor is open. On a failed reply the cursor is
  -- marked exhausted and an error is raised: the server's `$err` when an
  -- error document came back, otherwise a message carrying the original cause.
  function mongo_cursor:hasNext()
    if self.__ptr ~= nil then
      -- Documents from the current batch are still buffered.
      return true
    end
    if self.__document == nil then
      -- Exhausted (or failed) earlier.
      return false
    end

    local conn = self.__collection.connection
    local request_id = conn:genId()
    local sock = conn.__sock
    local pack
    if self.__data == nil then
      local query = self.__sortquery or self.__query
      pack = driver.query(request_id, self.__flags, self.__collection.full_name, self.__skip, self.__limit, query, self.__selector)
    elseif self.__cursor then
      pack = driver.more(request_id, self.__collection.full_name, self.__limit, self.__cursor)
    else
      -- no more
      self.__document = nil
      self.__data = nil
      return false
    end

    local ok, result = pcall(sock.request, sock, pack, request_id)
    if not ok then
      self.__document = nil
      self.__data = nil
      self.__cursor = nil
      -- A failed reply raises the result table (with a `document`); transport
      -- errors raise other values, so only index the error if it is a table.
      local errdoc = type(result) == "table" and result.document
      if errdoc then
        local err = bson_decode(errdoc)
        error(err["$err"] or "Reply from mongod error")
      end
      error("Reply from mongod error: " .. tostring(result))
    end

    if not result.document then
      -- Empty reply: nothing more to read.
      self.__document = nil
      self.__data = nil
      self.__cursor = nil
      return false
    end

    local docs = result.result
    local cursor = result.cursor_id
    self.__document = docs
    -- Hold the raw reply: the entries of `docs` point into it.
    self.__data = result.data
    self.__ptr = 1
    self.__cursor = cursor
    local limit = self.__limit
    if cursor and limit > 0 then
      limit = limit - #docs
      if limit <= 0 then
        -- reach limit
        self:close()
      end
      self.__limit = limit
    end
    return true
  end
  --- Decode and return the next buffered document. hasNext() must have
  -- returned true first. Once the batch is used up, the pointer is cleared so
  -- hasNext() fetches more.
  function mongo_cursor:next()
    local ptr = self.__ptr
    if ptr == nil then
      error "Call hasNext first"
    end
    local docs = self.__document
    local r = bson_decode(docs[ptr])
    if ptr + 1 > #docs then
      self.__ptr = nil
    else
      self.__ptr = ptr + 1
    end
    return r
  end
  --- Kill the server-side cursor if one is still open, then forget its id.
  function mongo_cursor:close()
    local cursor_id = self.__cursor
    if not cursor_id then
      return
    end
    local sock = self.__collection.connection.__sock
    sock:request(driver.kill(cursor_id))
    self.__cursor = nil
  end
  return mongo