dispatch.lua 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. -----------------------------------------------------------------------------
  2. -- A hacked dispatcher module
  3. -- LuaSocket sample files
  4. -- Author: Diego Nehab
  5. -----------------------------------------------------------------------------
  6. local base = _G
  7. local table = require("table")
  8. local string = require("string")
  9. local socket = require("socket")
  10. local coroutine = require("coroutine")
  11. module("dispatch")
  12. -- idle limit checked by step(): a "tcp{client}" socket whose stamp is older
  13. -- than this has its pending send/receive resumed with "timeout"
  14. TIMEOUT = 60
  15. -----------------------------------------------------------------------------
  16. -- We implement 3 types of dispatchers:
  17. -- sequential
  18. -- coroutine
  19. -- threaded
  20. -- The user can choose whatever one is needed
  21. -----------------------------------------------------------------------------
  -- registry of dispatcher constructors, keyed by mode name
  local handlert = {}

  -----------------------------------------------------------------------------
  -- Creates a dispatcher of the requested kind.
  --   mode: key into handlert; defaults to "coroutine" when nil or false
  -- Returns whatever the constructor registered under that mode returns.
  -- Raises an error naming the mode when no constructor is registered for it
  -- (previously this surfaced as "attempt to call a nil value").
  -----------------------------------------------------------------------------
  function newhandler(mode)
      mode = mode or "coroutine"
      local factory = handlert[mode]
      if not factory then
          -- level 2 blames the caller that passed the bad mode
          base.error("unknown dispatcher mode: " .. base.tostring(mode), 2)
      end
      return factory()
  end
  -- start() for the sequential dispatcher: runs func right away, with no
  -- arguments, and hands back everything it returns. self is not used.
  local seqstart = function(self, func)
      return func()
  end
  -- Sequential dispatcher: exposes the plain socket.tcp constructor (no I/O
  -- wrapping) and a start() that simply calls the function it is given.
  function handlert.sequential()
      local handler = {}
      handler.tcp = socket.tcp
      handler.start = seqstart
      return handler
  end
  38. -----------------------------------------------------------------------------
  39. -- Mega hack. Don't try to do this at home.
  40. -----------------------------------------------------------------------------
  41. -- we can't yield across calls to protect on Lua 5.1, so we rewrite it with
  42. -- coroutines
  43. -- make sure you don't require any module that uses socket.protect before
  44. -- loading our hack
  if string.sub(base._VERSION, -3) == "5.1" then
      -- Drives the coroutine `co` that runs the protected function.
      --   status, ...: the results of the latest coroutine.resume(co, ...)
      -- On failure with a table as error value, returns nil plus the table's
      -- first element; any other error value is re-raised unchanged (level 0).
      -- While co is suspended, whatever it yielded is yielded onwards to our
      -- own caller, and the values we are resumed with are fed back into co.
      -- Once co is no longer suspended, its results are returned.
      local function _protect(co, status, ...)
          if status then
              if coroutine.status(co) ~= "suspended" then
                  return ...
              end
              return _protect(co, coroutine.resume(co, coroutine.yield(...)))
          end
          local msg = ...
          if base.type(msg) ~= 'table' then
              base.error(msg, 0)
          end
          return nil, msg[1]
      end

      -- Replacement for socket.protect: each call of the returned function
      -- runs f inside a fresh coroutine driven by _protect.
      function socket.protect(f)
          return function(...)
              local co = coroutine.create(f)
              return _protect(co, coroutine.resume(co, ...))
          end
      end
  end
  68. -----------------------------------------------------------------------------
  69. -- Simple set data structure. O(1) everything.
  70. -----------------------------------------------------------------------------
  -- Builds an empty set. Members live in the array part of the returned table
  -- (so it can be used as a plain list), while a private value -> position map
  -- lets insert and remove avoid scanning.
  local function newset()
      local positions = {}
      local methods = {}

      -- adds value unless it is already a member
      function methods.insert(set, value)
          if positions[value] then return end
          local slot = #set + 1
          set[slot] = value
          positions[value] = slot
      end

      -- removes value if present; the last element is moved into the hole so
      -- the array part stays dense
      function methods.remove(set, value)
          local slot = positions[value]
          if not slot then return end
          positions[value] = nil
          local last = table.remove(set)
          if last ~= value then
              positions[last] = slot
              set[slot] = last
          end
      end

      return base.setmetatable({}, { __index = methods })
  end
  94. -----------------------------------------------------------------------------
  95. -- socket.tcp() wrapper for the coroutine dispatcher
  96. -----------------------------------------------------------------------------
  -- Wraps a tcp object so that its blocking calls cooperate with the
  -- coroutine dispatcher.
  --   dispatcher: table with stamp, sending and receiving entries
  --   tcp, error: the results of a socket constructor or accept call
  -- Returns nil, error when tcp is nil or false. Otherwise returns a table
  -- that overrides settimeout, send, receive, connect, accept and close, and
  -- forwards any other method to tcp.
  -- send, receive, connect and accept yield (operation, tcp) to the dispatcher
  -- and must therefore be called from inside a coroutine.
  local function cowrap(dispatcher, tcp, error)
      if not tcp then return nil, error end
      -- put it in non-blocking mode right away
      tcp:settimeout(0)
      -- methods we do not override are created on first access and cached in
      -- the wrapper; they call the real socket with the wrapper (first
      -- argument) replaced by tcp
      local metat = { __index = function(self, key)
          local function forward(...)
              -- base.select, not select: after module() the chunk's globals
              -- are looked up in the module table, where select does not exist
              return tcp[key](tcp, base.select(2, ...))
          end
          self[key] = forward
          return forward
      end }
      -- does our user want to do his own non-blocking I/O?
      local zero = false
      -- the object that will stand in for the real socket
      local wrap = { }

      -- the real timeout stays 0; we only record whether the user asked for a
      -- zero timeout. mode is ignored.
      function wrap:settimeout(value, mode)
          zero = (value == 0)
          return 1
      end

      -- send in non-blocking mode, yielding to the dispatcher before every
      -- attempt
      function wrap:send(data, first, last)
          first = (first or 1) - 1
          while true do
              -- tell the dispatcher we want to send; if it resumes us with
              -- "timeout", report that to whoever called us
              if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
                  return nil, "timeout"
              end
              -- resume right after the index returned by the last attempt
              local result, err
              result, err, first = tcp:send(data, first + 1, last)
              -- done, or failed with something other than a timeout
              if err ~= "timeout" then return result, err, first end
          end
      end

      -- receive in non-blocking mode, yielding to the dispatcher before every
      -- attempt; if the user requested timeout = 0, return after one attempt
      function wrap:receive(pattern, partial)
          while true do
              -- tell the dispatcher we want to receive; if it resumes us with
              -- "timeout", report that to whoever called us
              if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                  return nil, "timeout"
              end
              -- the partial result of one attempt is fed into the next
              local value, err
              value, err, partial = tcp:receive(pattern, partial)
              if err ~= "timeout" or zero then
                  return value, err, partial
              end
          end
      end

      -- connect in non-blocking mode and yield on timeout
      function wrap:connect(host, port)
          local result, err = tcp:connect(host, port)
          if err ~= "timeout" then return result, err end
          -- wait for the dispatcher to report the socket as writable; if it
          -- resumes us with "timeout" instead, just abort
          if coroutine.yield(dispatcher.sending, tcp) == "timeout" then
              return nil, "timeout"
          end
          -- when we come back, check if connection was successful
          result, err = tcp:connect(host, port)
          if result or err == "already connected" then
              return 1
          end
          return nil, "non-blocking connect failed"
      end

      -- accept in non-blocking mode and yield on timeout; the accepted client
      -- is wrapped with cowrap as well
      function wrap:accept()
          while true do
              -- wait for the dispatcher to report the socket as readable; if
              -- it resumes us with "timeout" instead, just abort
              if coroutine.yield(dispatcher.receiving, tcp) == "timeout" then
                  return nil, "timeout"
              end
              local client, err = tcp:accept()
              if err ~= "timeout" then
                  return cowrap(dispatcher, client, err)
              end
          end
      end

      -- drop every reference the dispatcher holds for this socket, then close
      function wrap:close()
          dispatcher.stamp[tcp] = nil
          dispatcher.sending.set:remove(tcp)
          dispatcher.sending.cortn[tcp] = nil
          dispatcher.receiving.set:remove(tcp)
          dispatcher.receiving.cortn[tcp] = nil
          return tcp:close()
      end

      return base.setmetatable(wrap, metat)
  end
  204. -----------------------------------------------------------------------------
  205. -- Our coroutine dispatcher
  206. -----------------------------------------------------------------------------
  -- metatable shared by coroutine dispatchers; methods go into __index
  local cometat = { __index = {} }

  -- Digests the outcome of resuming a dispatcher coroutine.
  --   cortn: the coroutine that was resumed (nil when nothing was resumed)
  --   status, operation, tcp: the values coroutine.resume returned
  -- When the resume failed, `operation` holds the error value and is
  -- re-raised. When the coroutine yielded (operation, tcp), it is registered
  -- under that operation and the socket's activity stamp is refreshed.
  -- A coroutine that is no longer suspended is never registered: its return
  -- values are not an I/O request, and used to be indexed as if they were.
  function schedule(cortn, status, operation, tcp)
      if not status then
          base.error(operation)
      end
      if cortn and operation and coroutine.status(cortn) == "suspended" then
          operation.set:insert(tcp)
          operation.cortn[tcp] = cortn
          operation.stamp[tcp] = socket.gettime()
      end
  end
  -- Forgets that a coroutine is waiting on tcp for this operation: tcp leaves
  -- the operation's socket set and its coroutine entry is cleared.
  function kick(operation, tcp)
      operation.set:remove(tcp)
      operation.cortn[tcp] = nil
  end
  -- Resumes the coroutine waiting on tcp for this operation, if there is one.
  -- Returns the coroutine followed by the results of coroutine.resume, in the
  -- shape schedule() takes as arguments.
  function wakeup(operation, tcp)
      local cortn = operation.cortn[tcp]
      if not cortn then
          -- nobody is waiting: otherwise-empty result that makes the
          -- scheduler do nothing
          return nil, true
      end
      -- unregister first; the coroutine registers again through its next yield
      kick(operation, tcp)
      return cortn, coroutine.resume(cortn)
  end
  -- If a coroutine is waiting on tcp for this operation, unregisters it and
  -- resumes it with "timeout" so the pending I/O call can report the timeout.
  -- Whatever the resume returns is discarded.
  function abort(operation, tcp)
      local cortn = operation.cortn[tcp]
      if not cortn then return end
      kick(operation, tcp)
      coroutine.resume(cortn, "timeout")
  end
  239. -- step through all active cortns
  -- Runs one round of the dispatcher: wakes the coroutines whose sockets
  -- socket.select reports as ready, then times out idle client sockets.
  function cometat.__index:step()
      local receiving, sending = self.receiving, self.sending
      -- find out which of the sockets we are waiting on are ready
      local readable, writable = socket.select(receiving.set, sending.set, 1)
      -- resume whoever waits on a readable socket, and register it again
      -- under whatever it yields back
      for _, tcp in base.ipairs(readable) do
          schedule(wakeup(receiving, tcp))
      end
      -- same for the writable sockets
      for _, tcp in base.ipairs(writable) do
          schedule(wakeup(sending, tcp))
      end
      -- client sockets whose stamp is more than TIMEOUT behind the current
      -- time get their pending I/O resumed with "timeout"
      local now = socket.gettime()
      for tcp, stamp in base.pairs(self.stamp) do
          if tcp.class == "tcp{client}" and now - stamp > TIMEOUT then
              abort(sending, tcp)
              abort(receiving, tcp)
          end
      end
  end
  -- Wraps func in a new coroutine, runs it (without arguments) up to its
  -- first yield or its end, and passes the outcome to schedule().
  cometat.__index.start = function(self, func)
      local cortn = coroutine.create(func)
      schedule(cortn, coroutine.resume(cortn))
  end
  -- Builds a coroutine dispatcher. The sending and receiving operations each
  -- own a socket set and a socket -> coroutine map, and share one stamp table
  -- with the dispatcher itself. step() and start() come from cometat.
  function handlert.coroutine()
      local stamp = {}
      -- one bookkeeping record per kind of I/O
      local function newoperation(name)
          return {
              name = name,
              set = newset(),
              cortn = {},
              stamp = stamp
          }
      end
      local dispatcher = { stamp = stamp }
      dispatcher.sending = newoperation("sending")
      dispatcher.receiving = newoperation("receiving")
      -- called without self: creates a socket already wrapped for us
      dispatcher.tcp = function()
          return cowrap(dispatcher, socket.tcp())
      end
      return base.setmetatable(dispatcher, cometat)
  end