tftp.lua 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. -----------------------------------------------------------------------------
  2. -- TFTP support for the Lua language
  3. -- LuaSocket toolkit.
  4. -- Author: Diego Nehab
  5. -----------------------------------------------------------------------------
  6. -----------------------------------------------------------------------------
  7. -- Load required files
  8. -----------------------------------------------------------------------------
  9. local base = _G
  10. local table = require("table")
  11. local math = require("math")
  12. local string = require("string")
  13. local socket = require("socket")
  14. local ltn12 = require("ltn12")
  15. local url = require("socket.url")
  16. module("socket.tftp")
  17. -----------------------------------------------------------------------------
  18. -- Program constants
  19. -----------------------------------------------------------------------------
  -- Local aliases for the byte-level string helpers used by the packet code.
  local char, byte = string.char, string.byte

  -- Default server port (used in the URL defaults below). It is assigned
  -- without `local`, so it lives in the module environment.
  PORT = 69

  -- TFTP opcodes, as written into the first two bytes of every packet.
  local OP_RRQ, OP_WRQ, OP_DATA, OP_ACK, OP_ERROR = 1, 2, 3, 4, 5

  -- Opcode -> human-readable name (OP_INV[OP_DATA] == "DATA").
  local OP_INV = { "RRQ", "WRQ", "DATA", "ACK", "ERROR" }
  29. -----------------------------------------------------------------------------
  30. -- Packet creation functions
  31. -----------------------------------------------------------------------------
  -- Build a read-request (RRQ) packet: a 2-byte opcode, then the file name
  -- and the transfer mode, each terminated by a zero byte.
  local function RRQ(source, mode)
      local nul = char(0)
      return nul .. char(OP_RRQ) .. source .. nul .. mode .. nul
  end
  -- Build a write-request (WRQ) packet: a 2-byte opcode, then the file name
  -- and the transfer mode, each terminated by a zero byte.
  -- Fix: the opcode must be OP_WRQ; it previously emitted OP_RRQ, which made
  -- this packet indistinguishable from a read request.
  local function WRQ(source, mode)
      return char(0, OP_WRQ) .. source .. char(0) .. mode .. char(0)
  end
  -- Build an ACK packet: opcode OP_ACK followed by the 2-byte (big-endian)
  -- block number being acknowledged.
  -- @param block block number; reduced modulo 65536 so it always fits the
  --   2-byte field (string.char rejects values above 255 per byte).
  local function ACK(block)
      -- `%` and math.floor replace math.mod, which is not available in
      -- Lua 5.2+ (and in 5.1 only through a compatibility alias).
      block = block % 65536
      local high = math.floor(block / 256)
      local low = block % 256
      return char(0, OP_ACK, high, low)
  end
  -- Decode the big-endian 2-byte opcode at the start of a datagram.
  -- The datagram must be at least 2 bytes long.
  local function get_OP(dgram)
      local hi, lo = byte(dgram, 1, 2)
      return hi * 256 + lo
  end
  48. -----------------------------------------------------------------------------
  49. -- Packet analysis functions
  50. -----------------------------------------------------------------------------
  -- Split a DATA packet into its big-endian block number (bytes 3-4) and its
  -- payload (everything from byte 5 on).
  -- @return block number, payload string
  local function split_DATA(dgram)
      local hi, lo = byte(dgram, 3, 4)
      local payload = string.sub(dgram, 5)
      return hi * 256 + lo, payload
  end
  -- Turn an ERROR packet into a readable message "error code N: text".
  -- The error code is the big-endian 2-byte field at bytes 3-4. The text
  -- starts at byte 5 and runs up to (not including) the first zero byte, or
  -- to the end of the packet when the terminator is missing.
  -- Fixes: the result of string.find no longer goes into a global `_`.
  -- The terminator search is a plain find, which behaves the same on every
  -- Lua version. The "(.*)\000" pattern was cut at the NUL in 5.1 and matched
  -- greedily in 5.2+. A truncated packet still yields a message instead of
  -- an error.
  local function get_ERROR(dgram)
      local code = (byte(dgram, 3) or 0) * 256 + (byte(dgram, 4) or 0)
      local stop = string.find(dgram, "\0", 5, true)
      local msg
      if stop then
          msg = string.sub(dgram, 5, stop - 1)
      else
          msg = string.sub(dgram, 5)
      end
      return string.format("error code %d: %s", code, msg)
  end
  62. -----------------------------------------------------------------------------
  63. -- The real work
  64. -----------------------------------------------------------------------------
  -- Download a file over TFTP (octet mode) into a sink.
  -- @param gett request table: `host` (required), `port`, `path` (a leading
  --   "/" is stripped and %XX escapes are decoded) and optional `sink`
  --   (defaults to ltn12.sink.null()). `gett.host` is overwritten with the
  --   resolved IP address.
  -- @return 1 on success; failures are raised through socket.try / try, and
  --   `try` closes the socket first.
  -- Fixes relative to the earlier version:
  -- * `err` is a declared local instead of a stray module global.
  -- * The error message is only built for ERROR packets, instead of running
  --   get_ERROR on every DATA packet.
  -- * Datagrams shorter than 4 bytes are rejected through `try` instead of
  --   crashing with a nil-arithmetic error and leaking the socket.
  -- * Block numbers wrap modulo 65536, so transfers longer than 65535 blocks
  --   keep progressing.
  local function tget(gett)
      socket.try(gett.host, "missing host")
      local con = socket.try(socket.udp())
      -- from here on, any failure closes the socket before propagating
      local try = socket.newtry(function() con:close() end)
      -- convert from name to ip if needed
      gett.host = try(socket.dns.toip(gett.host))
      con:settimeout(1)
      local path = string.gsub(gett.path or "", "^/", "")
      path = url.unescape(path)
      -- Send the read request (up to six attempts on timeout). The first reply
      -- tells us which host/port carries the rest of the transfer.
      local dgram, datahost, dataport
      local retries = 0
      repeat
          try(con:sendto(RRQ(path, "octet"), gett.host, gett.port))
          dgram, datahost, dataport = con:receivefrom()
          retries = retries + 1
      until dgram or datahost ~= "timeout" or retries > 5
      -- on failure receivefrom's second result holds the error message
      try(dgram, datahost)
      -- associate socket with data host/port
      try(con:setpeername(datahost, dataport))
      local sink = gett.sink or ltn12.sink.null()
      local last = 0
      while true do
          -- opcode plus block/error-code field: anything shorter is unusable
          try(string.len(dgram) >= 4, "malformed packet")
          local code = get_OP(dgram)
          if code == OP_ERROR then try(nil, get_ERROR(dgram)) end
          try(code == OP_DATA, "unhandled opcode " .. code)
          local block, data = split_DATA(dgram)
          -- only write the block we are waiting for; repeats are just re-ACKed.
          -- Block numbers occupy 2 bytes, so the successor of 65535 is 0.
          if block == (last + 1) % 65536 then
              try(sink(data))
              last = block
          end
          -- last packet brings less than 512 bytes of data
          if string.len(data) < 512 then
              try(con:send(ACK(block)))
              try(con:close())
              try(sink(nil))
              return 1
          end
          -- acknowledge and wait for the next packet (up to six attempts)
          local err
          retries = 0
          repeat
              try(con:send(ACK(last)))
              dgram, err = con:receive()
              retries = retries + 1
          until dgram or err ~= "timeout" or retries > 5
          try(dgram, err)
      end
  end
  -- URL components used wherever the caller's URL leaves them out.
  local defaults = {
      scheme = "tftp",
      path = "/",
      port = PORT,
  }

  -- Parse a URL string into a request table (host, port, path, ...).
  -- Raises through socket.try if url.parse fails, if the scheme is not
  -- "tftp", or if no host is present.
  local function parse(u)
      local req = socket.try(url.parse(u, defaults))
      local scheme = req.scheme
      socket.try(scheme == "tftp", "invalid scheme '" .. scheme .. "'")
      socket.try(req.host, "invalid host")
      return req
  end
  -- Fetch the file named by a tftp URL string and return its whole contents
  -- as one string, buffered in a table sink.
  local function sget(u)
      local req = parse(u)
      local chunks = {}
      req.sink = ltn12.sink.table(chunks)
      tget(req)
      return table.concat(chunks)
  end
  -- Public entry point (a module field, since it is not declared local).
  -- A string argument is treated as a URL, and the downloaded contents are
  -- returned as a string. Any other argument is passed to tget as a request
  -- table (host, port, path, sink), and tget's result is returned.
  -- NOTE(review): socket.protect is expected to turn errors thrown via try
  -- into `nil, message` returns; confirm against the LuaSocket reference.
  get = socket.protect(function(gett)
      if base.type(gett) ~= "string" then
          return tget(gett)
      end
      return sget(gett)
  end)