-- functions_replacer.lua (3.4 KB)
  1. --- Replace functions of table or upvalue.
  2. -- Search for the old functions and replace them with new ones.
  3. local M = {}
  4. -- Objects whose functions have been replaced already.
  5. -- Each objects need to be replaced once.
  6. local replaced_obj = {}
  7. -- Map old functions to new functions.
  8. -- Used to replace functions finally.
  9. -- Set to hotfix.updated_func_map.
  10. local updated_func_map = {}
  11. -- Do not update and replace protected objects.
  12. -- Set to hotfix.protected.
  13. local protected = {}
  14. local replace_functions -- forward declare
  15. -- Replace all updated functions in upvalues of function object.
  16. local function replace_functions_in_upvalues(function_object)
  17. local obj = function_object
  18. assert("function" == type(obj))
  19. assert(not protected[obj])
  20. assert(obj ~= updated_func_map)
  21. for i = 1, math.huge do
  22. local name, value = debug.getupvalue(obj, i)
  23. if not name then return end
  24. local new_func = updated_func_map[value]
  25. if new_func then
  26. assert("function" == type(value))
  27. debug.setupvalue(obj, i, new_func)
  28. else
  29. replace_functions(value)
  30. end
  31. end -- for
  32. assert(false, "Can not reach here!")
  33. end -- replace_functions_in_upvalues()
  34. -- Replace all updated functions in the table.
  35. local function replace_functions_in_table(table_object)
  36. local obj = table_object
  37. assert("table" == type(obj))
  38. assert(not protected[obj])
  39. assert(obj ~= updated_func_map)
  40. replace_functions(debug.getmetatable(obj))
  41. local new = {} -- to assign new fields
  42. for k, v in pairs(obj) do
  43. -- 配置文件不能更新
  44. if type(k)== "string" and not string.match(k, "../config") then
  45. local new_k = updated_func_map[k]
  46. local new_v = updated_func_map[v]
  47. if new_k then
  48. obj[k] = nil -- delete field
  49. new[new_k] = new_v or v
  50. else
  51. obj[k] = new_v or v
  52. replace_functions(k)
  53. end
  54. if not new_v then replace_functions(v) end
  55. end
  56. end -- for k, v
  57. for k, v in pairs(new) do obj[k] = v end
  58. end -- replace_functions_in_table()
  59. -- Replace all updated functions.
  60. -- Record all replaced objects in replaced_obj.
  61. function replace_functions(obj)
  62. if protected[obj] then return end
  63. local obj_type = type(obj)
  64. if "function" ~= obj_type and "table" ~= obj_type then return end
  65. if replaced_obj[obj] then return end
  66. replaced_obj[obj] = true
  67. assert(obj ~= updated_func_map)
  68. if "function" == obj_type then
  69. replace_functions_in_upvalues(obj)
  70. else -- table
  71. replace_functions_in_table(obj)
  72. end
  73. end -- replace_functions(obj)
  74. --- Replace all old functions with new ones.
  75. -- Replace in new_obj, _G and debug.getregistry().
  76. -- a_protected is a list of protected object.
  77. -- an_updated_func_map is a map from old function to new function.
  78. -- new_obj is the newly loaded module.
  79. function M.replace_all(a_protected, an_updated_func_map, new_obj, module_name)
  80. protected = a_protected
  81. updated_func_map = an_updated_func_map
  82. assert(type(protected) == "table")
  83. assert(type(updated_func_map) == "table")
  84. if nil == next(updated_func_map) then
  85. return
  86. end
  87. replaced_obj = {}
  88. replace_functions(new_obj) -- new_obj may be not in _G
  89. replace_functions(_G)
  90. replace_functions(debug.getregistry())
  91. replaced_obj = {}
  92. end -- M.replace_all()
  93. return M