byafn.py 15 KB


  1. import os
  2. import sys
  3. import time
  4. import random
  5. import struct
  6. import asyncio
  7. from collections import namedtuple
  8. import zstd
  9. import cbor2
  10. import ecies
  11. import hjson
  12. import aiofiles
  13. import aiologger
  14. try:
  15. import aiofiles.os
  16. except AttributeError:
  17. os.link = os.symlink
  18. import aiofiles.os
  19. from Crypto.Hash import SHA256
  20. from Crypto.Random import get_random_bytes
  # --- Liveness -------------------------------------------------------------
  # Ping attempts made by tick() before a peer is dropped.
  PINGS_COUNT = 3
  # Seconds to wait for a single PONG.
  PING_TIMEOUT = 10
  # Seconds allowed for the whole handshake().
  HANDSHAKE_TIMEOUT = 10
  # Seconds to wait for one peer's answer while broadcasting.
  BROADCAST_TIMEOUT = 300
  # Seconds between tick() iterations; a ping happens on every third tick.
  HEARTBEAT = 10
  # Seconds between watcher() maintenance sweeps.
  WATCHER_INTERVAL = 30
  # Seconds a CachedMessage entry is kept for duplicate suppression.
  CACHE_LIFETIME = 3600

  # --- Routing / wire limits ------------------------------------------------
  # broadcast() refuses to forward a message whose ttl has reached this value.
  MAX_TTL = 7
  MAX_DISTANCE = 8  # NOTE(review): not referenced anywhere in the visible code
  # Upper bound accepted for an incoming encrypted payload (64 MiB).
  MAX_PAYLOAD_SIZE = 64 * 1024 * 1024
  # Maximum size of one length-prefixed payload chunk on the wire.
  CHUNK_SIZE = 512
  # Minimum message spacing in seconds enforced by cooldown().
  RATELIMIT = 0.5

  # --- Process-wide state ---------------------------------------------------
  # Parsed configuration, populated by main().
  config: dict = {}
  # Every known Peer, including closed placeholders awaiting a redial.
  peers: list = []
  # CachedMessage records of already-seen (kind, uid) pairs.
  cache: list = []
  # Shared asynchronous logger for the whole node; each record is rendered
  # as "<asctime> <levelname>: <message>".
  logger = aiologger.Logger.with_default_handlers(
      formatter=aiologger.formatters.base.Formatter(
          fmt='%(asctime)s %(levelname)s: %(message)s',
      ),
  )
  def sha256(data):
      """Return the raw 32-byte SHA-256 digest of the bytes-like *data*.

      Uses the standard-library hashlib implementation, which yields exactly
      the same digest as Crypto.Hash.SHA256.
      """
      import hashlib

      return hashlib.sha256(data).digest()
  def chunks(l, n):
      """Lazily split the sliceable *l* into consecutive pieces of *n* items.

      Every piece except possibly the last has exactly *n* items.
      """
      yield from (l[offset:offset + n] for offset in range(0, len(l), n))
  class Error(Exception):
      """Protocol, peer or request failure raised by this node."""
  async def is_piece_exists(hash):
      """Return True if a regular file named ``hash.hex()`` exists in StoragePath."""
      piece_path = os.path.join(config['StoragePath'], hash.hex())
      return bool(await aiofiles.os.path.isfile(piece_path))
  async def save_piece(data, hash=None):
      """Compress *data* with zstd and store it in StoragePath.

      Args:
          data: raw piece bytes.
          hash: digest used as the file name (hex); computed with sha256()
              when falsy. It is not checked against *data*.

      Returns:
          The digest the piece was stored under.

      The compressed bytes go to a uniquely named temporary file in the same
      directory, which is then moved over the final name with os.replace()
      (an atomic rename on POSIX). An interrupted or concurrent save of the
      same piece therefore never leaves a truncated file under the final name.
      """
      if not hash:
          hash = sha256(data)
      path = os.path.join(config['StoragePath'], hash.hex())
      # Compress before touching the filesystem, so a failure here cannot
      # leave an empty file behind.
      compressed = zstd.compress(data)
      tmp_path = f'{path}.{os.urandom(8).hex()}.tmp'
      try:
          async with aiofiles.open(tmp_path, 'wb') as f:
              await f.write(compressed)
          os.replace(tmp_path, path)
      except BaseException:
          # Best-effort removal of the partial temporary file, then re-raise
          # (including cancellation).
          try:
              os.remove(tmp_path)
          except OSError:
              pass
          raise
      return hash
  async def read_piece(hash):
      """Load, decompress and verify the stored piece named ``hash.hex()``.

      Returns:
          The original (decompressed) piece bytes.

      Raises:
          ValueError: the stored file is corrupt, i.e. it does not decompress
              or its SHA-256 differs from *hash*. The file is deleted first,
              so is_piece_exists() stops reporting it.
          OSError: the file cannot be opened or read.
      """
      path = os.path.join(config['StoragePath'], hash.hex())
      async with aiofiles.open(path, 'rb') as f:
          compressed = await f.read()
      # The file is closed at this point, so removing it below also works on
      # platforms that refuse to delete open files.
      try:
          data = zstd.decompress(compressed)
      except zstd.Error:
          # Truncated/garbled file. Without this branch the error escaped and
          # the broken file stayed in storage forever.
          data = None
      if data is None or sha256(data) != hash:
          await aiofiles.os.remove(path)
          raise ValueError(f'stored piece {hash.hex()} is corrupt')
      return data
  # Duplicate-suppression record: message kind, message uid, and the
  # time.time() timestamp at which the (kind, uid) pair was first seen.
  CachedMessage = namedtuple('CachedMessage', ['kind', 'uid', 'ts'])
  class Message:
      """A protocol message: a kind code, a uid and arbitrary payload fields.

      Extra keyword arguments given to the constructor become ``fields`` and
      can be read as attributes (``message.hash``, ``message.ttl``, ...).
      """

      QUERY = 0xa
      QUERY_HIT = 0xb
      NOT_AVAILABLE = 0xc

      def __init__(self, kind, uid=None, **fields):
          self.kind = kind
          # A random 16-byte uid is generated when none (or an empty one) is given.
          self.uid = uid if uid else get_random_bytes(16)
          self.fields = fields

      def __getattr__(self, field):
          """Expose payload fields as attributes.

          Only reached when normal attribute lookup fails. Unknown names raise
          AttributeError (not KeyError) so hasattr(), getattr() with a default,
          copy and pickle work. ``fields`` is read straight from ``__dict__``
          so an instance created without __init__ (as copy.copy does) cannot
          recurse into __getattr__('fields') forever.
          """
          fields = self.__dict__.get('fields')
          if fields is None or field not in fields:
              raise AttributeError(
                  f'{type(self).__name__} object has no field {field!r}'
              )
          return fields[field]

      def cache(self):
          """Record this message's (kind, uid) in the module-level cache.

          Returns:
              True if the pair was not cached yet and has now been added,
              False if it had been seen before.
          """
          for message in cache:
              if message.uid == self.uid and message.kind == self.kind:
                  return False
          cache.append(
              CachedMessage(kind=self.kind, uid=self.uid, ts=time.time())
          )
          return True

      def is_response_for(self, message):
          """Return True if this message is the reply to *message*.

          QUERY is answered by QUERY_HIT or NOT_AVAILABLE carrying the same
          uid. PING is answered by PONG, and the uid is not compared there:
          send() writes cleartext service frames without a uid, so receive()
          gives every PONG a fresh random one and a uid match could never
          succeed. No other kind has a reply.
          """
          if message.kind == Message.QUERY:
              return self.uid == message.uid and self.kind in (
                  Message.QUERY_HIT,
                  Message.NOT_AVAILABLE,
              )
          if message.kind == ServiceMessage.PING:
              return self.kind == ServiceMessage.PONG
          return False
  class ServiceMessage(Message):
      """Cleartext control message used for the handshake, keep-alive and close.

      send() writes these without encryption, and receive() accepts cleartext
      frames only for kinds up to CLOSE.
      """

      # Handshake sequence.
      HELLO, CHALLENGE, ANSWER, FINISH = 0x0, 0x1, 0x2, 0x3
      # Keep-alive probe and its reply.
      PING, PONG = 0x4, 0x5
      # Connection teardown notice.
      CLOSE = 0x6
  class Peer:
      """Connection state for one remote node.

      Args:
          reader, writer: the asyncio stream pair for the connection (None for
              the closed placeholder dial() records after a failed dial).
          address: the "host:port" string given to dial(), or None for
              inbound connections.
      """

      def __init__(self, reader, writer, address=None):
          self.reader, self.writer, self.address = reader, writer, address
          self.key = None                     # remote public key, set by handshake()
          self.queue = []                     # replies waiting for wait_response()
          self.is_open = True
          self.send_lock = asyncio.Lock()     # held by send() while writing a frame
          self.receive_lock = asyncio.Lock()  # held by receive() while reading a frame
          self.ticks = 0                      # iterations counted by tick()
          self.last_message_in_ts = -1        # last incoming cooldown() check; -1 = none
  async def write(peer, data):
      """Queue *data* on the peer's stream writer and wait for it to drain."""
      stream = peer.writer
      stream.write(data)
      await stream.drain()
  async def read(peer, size):
      """Read exactly *size* bytes from the peer's stream.

      Raises:
          asyncio.IncompleteReadError: the connection reached EOF before
              *size* bytes arrived. Looping over StreamReader.read() instead
              would spin forever at EOF, because it keeps returning b''.
      """
      return await peer.reader.readexactly(size)
  async def cooldown(peer, out=True):
      """Apply RATELIMIT message spacing on *peer*.

      Outgoing (out=True): sleep RATELIMIT seconds.
      Incoming (out=False): raise Error if the previous incoming check on this
      peer was less than RATELIMIT seconds ago; otherwise remember now.
      """
      if out:
          await asyncio.sleep(RATELIMIT)
          return
      elapsed = time.time() - peer.last_message_in_ts
      if elapsed < RATELIMIT:
          raise Error(f'rate limit (={RATELIMIT}s.) exceeded for incoming messages')
      peer.last_message_in_ts = time.time()
  async def send(peer, message):
      """Serialise *message* and write it to *peer*.

      Wire formats:
        * ServiceMessage (cleartext): 1-byte kind, '<H' payload length, then
          the CBOR-encoded fields (length 0 and no payload without fields).
        * Message (encrypted): a 0xff marker byte; an ECIES-encrypted header
          of kind (1 byte) + uid + '<H' chunk count; then the payload
          (CBOR -> zstd -> ECIES) in chunks of at most CHUNK_SIZE bytes, each
          prefixed by its '<H' length.

      Encrypted messages pass through cooldown() while the peer's send lock
      is held; service messages are written immediately.

      Raises:
          Error: the encrypted payload needs more chunks than the '<H' chunk
              counter can express.
      """
      if isinstance(message, ServiceMessage):
          frame = bytes([message.kind])
          if message.fields:
              payload = cbor2.dumps(message.fields)
              frame += struct.pack('<H', len(payload)) + payload
          else:
              frame += bytes(2)
          async with peer.send_lock:
              await write(peer, frame)
          return

      payload = b''
      chunks_count = 0
      if message.fields:
          payload = cbor2.dumps(message.fields)
          payload = zstd.compress(payload)
          payload = ecies.encrypt(peer.key, payload)
          # Round up: the receiver reads exactly chunks_count chunks, while
          # chunks() also emits a final partial chunk when the length is not a
          # multiple of CHUNK_SIZE. Floor division would announce one chunk
          # too few and desynchronise the stream.
          chunks_count = max(1, -(-len(payload) // CHUNK_SIZE))
          if chunks_count > 0xFFFF:
              raise Error('payload is too large')
      header = bytes([message.kind]) + message.uid + struct.pack('<H', chunks_count)
      frame = b'\xff' + ecies.encrypt(peer.key, header)
      async with peer.send_lock:
          await cooldown(peer)
          await write(peer, frame)
          for chunk in chunks(payload, CHUNK_SIZE):
              await write(peer, struct.pack('<H', len(chunk)) + chunk)
  async def receive(peer):
      """Read one message from *peer*; see send() for the wire format.

      Returns a ServiceMessage for cleartext frames and a Message for
      encrypted (0xff-prefixed) frames.

      The incoming rate limit, cooldown(peer, False), is applied only to
      encrypted messages, which are the only ones send() throttles.
      handshake() exchanges cleartext frames back-to-back, so checking every
      frame would make it fail on any link faster than RATELIMIT.

      Raises:
          Error: disallowed cleartext kind, rate limit exceeded, oversized
              payload, illegal chunk length, or malformed payload fields.
      """
      async with peer.receive_lock:
          kind = (await read(peer, 1))[0]
          if kind != 0xff:
              if kind > ServiceMessage.CLOSE:
                  raise Error(f'unecrypted non-service messages are not allowed')
              length = struct.unpack('<H', await read(peer, 2))[0]
              fields = {}
              if length:
                  fields = _check_fields(
                      cbor2.loads(await read(peer, length)), ('kind',)
                  )
              return ServiceMessage(kind, **fields)

          await cooldown(peer, False)
          # Encrypted 19-byte header (kind + 16-byte uid + '<H' chunk count).
          # 116 presumably = 65-byte ephemeral key + 16-byte nonce + 16-byte
          # tag + 19 bytes of ciphertext for the ecies library in use; TODO confirm.
          head = ecies.decrypt(config['Secret'], await read(peer, 116))
          kind = head[0]
          uid = head[1:17]
          chunks_count = struct.unpack('<H', head[17:19])[0]
          fields = {}
          if chunks_count:
              if chunks_count * CHUNK_SIZE > MAX_PAYLOAD_SIZE:
                  raise Error('payload is too large')
              parts = []
              for _ in range(chunks_count):
                  length = struct.unpack('<H', await read(peer, 2))[0]
                  if not length or length > CHUNK_SIZE:
                      raise Error('illegal chunk length')
                  parts.append(await read(peer, length))
              payload = ecies.decrypt(config['Secret'], b''.join(parts))
              fields = _check_fields(
                  cbor2.loads(zstd.decompress(payload)), ('kind', 'uid')
              )
          return Message(kind, uid=uid, **fields)


  def _check_fields(fields, reserved):
      """Return decoded CBOR *fields* if they can be passed as message kwargs.

      Raises:
          Error: *fields* is not a dict with str keys, or it contains one of
              the *reserved* constructor parameter names (which would
              otherwise fail with a TypeError on ``**fields``).
      """
      if (
          not isinstance(fields, dict)
          or not all(isinstance(key, str) for key in fields)
          or any(name in fields for name in reserved)
      ):
          raise Error('malformed message fields')
      return fields
  async def close(peer, gracefully=True):
      """Close the connection to *peer*; does nothing if it is already closed.

      With *gracefully*, a CLOSE message is sent first (best effort, at most
      3 seconds). The writer is then closed, the peer is marked closed, and
      up to 3 seconds are spent waiting for the transport to shut down.
      Task cancellation is propagated rather than swallowed.
      """
      if not peer.is_open:
          return
      if gracefully:
          try:
              # Bounded: drain() can block forever if the remote stopped reading.
              await asyncio.wait_for(
                  send(peer, ServiceMessage(ServiceMessage.CLOSE)),
                  timeout=3,
              )
          except Exception:
              pass  # best effort; the transport is closed regardless
      peer.writer.close()
      peer.is_open = False
      try:
          await asyncio.wait_for(peer.writer.wait_closed(), timeout=3)
      except Exception:
          pass  # errors or timeouts while tearing down are not actionable
  async def wait_response(peer, message):
      """Poll peer.queue once a second for the reply to *message*.

      The first queued item whose is_response_for(message) is true is removed
      from the queue and returned. Returns None once the peer is closed.
      """
      while peer.is_open:
          reply = next(
              (queued for queued in peer.queue if queued.is_response_for(message)),
              None,
          )
          if reply is not None:
              peer.queue.remove(reply)
              return reply
          await asyncio.sleep(1)
      return None
  async def communicate(peer, message, timeout=None):
      """Send *message* to *peer* and return its reply (see wait_response()).

      Raises:
          asyncio.TimeoutError: no reply within *timeout* seconds.
          Error: the peer was closed before a reply arrived.
      """
      await send(peer, message)
      reply = await asyncio.wait_for(wait_response(peer, message), timeout=timeout)
      if reply:
          return reply
      raise Error('communication timeout')
  async def ping(peer):
      """Send a PING to *peer* and wait up to PING_TIMEOUT seconds for a PONG.

      The PONG is taken from peer.queue by kind alone. send() writes
      cleartext service frames without a uid, so receive() gives every PONG
      a fresh random uid and it could never be matched to the PING by uid.

      Raises:
          asyncio.TimeoutError: no PONG arrived in time.
          Error: the connection was closed while waiting.
      """
      await send(peer, ServiceMessage(ServiceMessage.PING))

      async def wait_pong():
          while peer.is_open:
              for queued in peer.queue:
                  if queued.kind == ServiceMessage.PONG:
                      peer.queue.remove(queued)
                      return
              await asyncio.sleep(1)
          raise Error('communication timeout')

      await asyncio.wait_for(wait_pong(), timeout=PING_TIMEOUT)


  async def respond(peer, message, kind, is_service=False, **data):
      """Reply to *message* with a message of *kind* that reuses its uid.

      The reply is a ServiceMessage when *is_service* is true, otherwise a
      Message; *data* becomes its fields.
      """
      message_class = ServiceMessage if is_service else Message
      await send(peer, message_class(kind, uid=message.uid, **data))


  async def query(hash, uid=None, ttl=0, filter=None):
      """Return the piece whose SHA-256 digest is *hash*, fetching it if needed.

      A locally stored piece is returned directly. Otherwise a QUERY is
      broadcast to the peers accepted by *filter*. The first QUERY_HIT whose
      data hashes to *hash* is saved and returned.

      Args:
          hash: digest of the wanted piece.
          uid: uid of the QUERY being relayed, or None to start a new query.
          ttl: hop count of the relayed QUERY; broadcast() increments it.
          filter: optional predicate choosing which peers to ask.

      Returns:
          The piece bytes, or None if no peer supplied them.
      """
      if await is_piece_exists(hash):
          return await read_piece(hash)
      request = Message(Message.QUERY, uid=uid, hash=hash, ttl=ttl)
      # Register the (QUERY, uid) pair for duplicate suppression. serve() has
      # already done so for relayed queries (the call then just returns
      # False); for a query started here, this makes a copy relayed back to
      # us be answered NOT_AVAILABLE instead of being processed again.
      request.cache()
      answer = await broadcast(
          request,
          message_filter=lambda reply: (
              reply.kind == Message.QUERY_HIT and sha256(reply.data) == hash
          ),
          peer_filter=filter,
      )
      if not answer:
          return None
      await save_piece(answer.data, hash)
      return answer.data


  async def broadcast(message, message_filter=None, peer_filter=None):
      """Ask open peers one at a time until one returns an accepted answer.

      Returns None without sending anything once message.ttl has reached
      MAX_TTL; otherwise the ttl field is incremented in place first. Each
      asked peer gets up to BROADCAST_TIMEOUT seconds; failing or silent
      peers are skipped.

      Returns:
          The first answer for which *message_filter* is true, or None.
      """
      if message.ttl >= MAX_TTL:
          return None
      message.fields['ttl'] += 1
      # Iterate a snapshot: other tasks add/remove peers while we await.
      for peer in peers.copy():
          if not peer.is_open:
              continue
          if peer_filter and not peer_filter(peer):
              continue
          try:
              answer = await communicate(peer, message, timeout=BROADCAST_TIMEOUT)
              if message_filter and message_filter(answer):
                  return answer
          except Exception:
              # One bad peer must not abort the broadcast; task cancellation
              # (a BaseException) still propagates.
              continue
      return None


  async def tick(peer):
      """Keep-alive loop for *peer*; runs until the connection is closed.

      The tick counter advances every HEARTBEAT seconds. On every third tick
      ping() is tried up to PINGS_COUNT times; if no attempt succeeds, the
      connection is closed without sending CLOSE and the loop ends.
      """
      while peer.is_open:
          if peer.ticks > 0 and peer.ticks % 3 == 0:
              for _ in range(PINGS_COUNT):
                  try:
                      await ping(peer)
                      break
                  except Exception:
                      continue
              else:
                  # close() is a module-level coroutine and must be awaited.
                  await close(peer, False)
                  return
          peer.ticks += 1
          await asyncio.sleep(HEARTBEAT)
  async def handshake(peer):
      """Authenticate *peer* with a mutual public-key challenge.

      All steps use cleartext ServiceMessages:
        1. Exchange HELLO carrying each side's key (config['Key']).
        2. Reject our own key and the key of an already-open peer; closed
           peers holding that key are dropped from ``peers``.
        3. Send CHALLENGE: 16 random bytes encrypted to the remote key.
        4. Decrypt the remote CHALLENGE with config['Secret'] and send the
           plaintext as ANSWER; require the remote ANSWER to equal our bytes.
        5. Exchange FINISH, then store the remote key on *peer*.

      Raises:
          Error: an unexpected message kind or a failed check.

      NOTE(review): step 4 decrypts whatever ciphertext the remote sends and
      returns the plaintext. That appears to make this node a decryption
      oracle for any data encrypted to its key, including captured traffic.
      Closing it requires a protocol change (e.g. answering with a hash bound
      to this session).
      """
      await send(peer, ServiceMessage(ServiceMessage.HELLO, key=config['Key']))
      answer = await receive(peer)
      if answer.kind != ServiceMessage.HELLO:
          raise Error('handshake: expected HELLO')
      key = answer.key
      if key == config['Key']:
          raise Error('handshake: connected to ourselves')
      # Use a separate loop variable. Reusing `peer` here rebound the
      # parameter, so every later step talked to the last Peer in the list
      # instead of this connection.
      for other in peers.copy():
          if other.key == key:
              if not other.is_open:
                  peers.remove(other)
                  continue
              raise Error('handshake: peer is already connected')

      challenge = get_random_bytes(16)
      await send(
          peer,
          ServiceMessage(ServiceMessage.CHALLENGE, data=ecies.encrypt(key, challenge)),
      )
      answer = await receive(peer)
      if answer.kind != ServiceMessage.CHALLENGE:
          raise Error('handshake: expected CHALLENGE')
      await send(
          peer,
          ServiceMessage(
              ServiceMessage.ANSWER,
              data=ecies.decrypt(config['Secret'], answer.data),
          ),
      )
      answer = await receive(peer)
      if answer.kind != ServiceMessage.ANSWER:
          raise Error('handshake: expected ANSWER')
      if answer.data != challenge:
          raise Error('handshake: challenge answer mismatch')

      await send(peer, ServiceMessage(ServiceMessage.FINISH))
      answer = await receive(peer)
      if answer.kind != ServiceMessage.FINISH:
          raise Error('handshake: expected FINISH')
      peer.key = key
  async def serve(peer):
      """Authenticate *peer*, then read and dispatch its messages until closed.

      Dispatch rules:
        * PONG, QUERY_HIT and NOT_AVAILABLE are replies and are appended to
          peer.queue for wait_response()/ping(). They bypass duplicate
          suppression because the cache is global and keyed by (kind, uid),
          and every peer asked during one broadcast answers with the same
          uid. Deduplicating replies would drop every NOT_AVAILABLE after
          the first and leave broadcast() waiting BROADCAST_TIMEOUT.
        * Any other message seen before (Message.cache() returns False) is
          dropped; a duplicate QUERY is answered with NOT_AVAILABLE.
        * PING is answered with PONG; CLOSE closes the connection.
        * QUERY is resolved in a background task, so this loop keeps reading
          (and answering PINGs) while the query is broadcast, which can take
          up to BROADCAST_TIMEOUT per asked peer.

      Raises:
          asyncio.TimeoutError: the handshake exceeded HANDSHAKE_TIMEOUT.
          Error: handshake failure or an unknown message kind.
      """
      await asyncio.wait_for(handshake(peer), timeout=HANDSHAKE_TIMEOUT)
      await asyncio.sleep(RATELIMIT)
      # The event loop keeps only weak references to tasks, so hold on to
      # the background tasks started for this connection.
      ticker = asyncio.create_task(tick(peer))
      query_tasks = set()
      try:
          if peer.address:
              await logger.info(f'Connected to {peer.address}')
          while peer.is_open:
              message = await receive(peer)
              if message.kind in (
                  ServiceMessage.PONG,
                  Message.QUERY_HIT,
                  Message.NOT_AVAILABLE,
              ):
                  # NOTE(review): unmatched replies are never evicted from
                  # peer.queue; consider bounding it.
                  peer.queue.append(message)
                  continue
              if not message.cache():
                  if message.kind == Message.QUERY:
                      await respond(peer, message, Message.NOT_AVAILABLE)
                  continue
              if message.kind == ServiceMessage.PING:
                  await respond(peer, message, ServiceMessage.PONG, is_service=True)
              elif message.kind == ServiceMessage.CLOSE:
                  await close(peer, False)
              elif message.kind == Message.QUERY:
                  task = asyncio.create_task(_answer_query(peer, message))
                  query_tasks.add(task)
                  task.add_done_callback(query_tasks.discard)
              else:
                  raise Error(f'unknown message kind={message.kind}')
      finally:
          # tick() would also stop by itself once peer.is_open is False.
          # Query tasks are left to finish; their reply just fails if the
          # connection is gone.
          ticker.cancel()


  async def _answer_query(peer, message):
      """Resolve a QUERY received from *peer* and send the reply.

      Sends QUERY_HIT with the piece when query() finds it, NOT_AVAILABLE
      otherwise (including when query() fails). Errors are logged rather
      than raised, because nothing awaits this task.
      """
      try:
          answer = await query(
              message.hash,
              uid=message.uid,
              ttl=message.ttl,
              # Never route the query back to its sender or to ourselves.
              filter=lambda other_peer: other_peer.key not in (peer.key, config['Key']),
          )
      except Exception as e:
          await logger.warning(f'Query from {peer.address or "inbound peer"} failed: {e}')
          answer = None
      try:
          if answer:
              await respond(peer, message, Message.QUERY_HIT, data=answer)
          else:
              await respond(peer, message, Message.NOT_AVAILABLE)
      except Exception as e:
          await logger.warning(f'Reply to {peer.address or "inbound peer"} failed: {e}')
  async def accept(reader, writer, address=None):
      """Register a connection as a Peer and serve it until it ends.

      Used as the asyncio.start_server callback (inbound, no address) and by
      dial() (outbound, with address). Failures are logged only for peers
      that have an address; the connection is always closed afterwards.
      """
      connection = Peer(reader, writer, address)
      peers.append(connection)
      try:
          await serve(connection)
      except Exception as error:
          if connection.address:
              await logger.warning(f'Connection lost {connection.address}: {error}')
      finally:
          await close(connection)
  async def dial(address):
      """Connect to "host:port" and serve the connection in a background task.

      The port is the text after the last ':'. An IPv6 literal may be
      bracketed ("[::1]:4000"); the brackets are stripped before connecting.
      On a connection failure a closed placeholder Peer with this address is
      recorded, so that watcher() retries the dial later.

      Raises:
          ValueError: the port part is not an integer.
      """
      host, _, port = address.rpartition(':')
      port = int(port)
      if host.startswith('[') and host.endswith(']'):
          host = host[1:-1]
      try:
          reader, writer = await asyncio.open_connection(host, port)
      except Exception as e:
          placeholder = Peer(None, None, address)
          placeholder.is_open = False
          peers.append(placeholder)
          await logger.error(f'Dial {address}: {e}')
          return
      asyncio.create_task(accept(reader, writer, address))
  async def listen():
      """Accept inbound peers on ListenAddress:ListenPort until cancelled."""
      server = await asyncio.start_server(
          accept, config['ListenAddress'], int(config['ListenPort'])
      )
      await logger.info(f'Listening at {config["ListenAddress"]}:{config["ListenPort"]}')
      async with server:
          await server.serve_forever()
  async def watcher():
      """Periodic maintenance every WATCHER_INTERVAL seconds, forever.

      Closed peers are removed from ``peers`` (those with an address are
      re-dialed in a new task), and ``cache`` entries older than
      CACHE_LIFETIME are discarded.
      """
      while True:
          for stale in [p for p in peers if not p.is_open]:
              peers.remove(stale)
              if stale.address:
                  asyncio.create_task(dial(stale.address))
          cache[:] = [
              entry for entry in cache
              if time.time() - entry.ts <= CACHE_LIFETIME
          ]
          await asyncio.sleep(WATCHER_INTERVAL)
  async def shutdown(delay):
      """After *delay* seconds, send CLOSE to every open peer, then exit.

      CLOSE is sent best effort (each send bounded to 3 seconds, failures
      ignored). SIGINT is then delivered to this process; asyncio.run() and
      the top-level KeyboardInterrupt handler turn that into an orderly exit.
      The signal is repeated every second, and the coroutine yields between
      signals, so a SIGINT that only requests cancellation of the main task
      (asyncio.run() behaviour on Python 3.11+) can take effect.
      """
      import signal

      await asyncio.sleep(delay)
      await logger.info('Performing graceful shutdown')
      # Snapshot: other tasks may add or remove peers while we await.
      for peer in peers.copy():
          if not peer.is_open:
              continue
          try:
              await asyncio.wait_for(
                  send(peer, ServiceMessage(ServiceMessage.CLOSE)),
                  timeout=3,
              )
          except Exception:
              pass  # best effort: the process is exiting anyway
      while True:
          os.kill(os.getpid(), signal.SIGINT)
          await asyncio.sleep(1)
  async def accept_admin(reader, writer):
      """Serve one request on the admin UNIX socket, then close it.

      Request and response are CBOR maps, each framed by a '<I' length
      prefix. Recognised requests:
        {'store': {'piece': bytes}}    -> {'hash': digest}
        {'query': {'hash': digest}}    -> {'piece': bytes}, or {} if not found
        {'shutdown': {'delay': secs}}  -> {} and shutdown() is scheduled
      Failures are logged and answered with whatever response was built so
      far (normally an empty map).
      """
      # Defined before `try`, so a failed read still gets a well-formed reply.
      response = {}
      try:
          # readexactly(): read(n) may return fewer than n bytes, which would
          # truncate large 'store' requests.
          length = struct.unpack('<I', await reader.readexactly(4))[0]
          request = cbor2.loads(await reader.readexactly(length))
          if 'store' in request:
              response['hash'] = await save_piece(request['store']['piece'])
          elif 'query' in request:
              piece = await query(request['query']['hash'])
              if piece:
                  response['piece'] = piece
          elif 'shutdown' in request:
              delay = int(request['shutdown']['delay'])
              await logger.info(f'Requested shutdown in {delay}sec.')
              asyncio.create_task(shutdown(delay))
          else:
              raise Error('unrecognized command')
      except Exception as e:
          await logger.error(f'Process request on admin socket: {e}')
      try:
          encoded = cbor2.dumps(response)
          writer.write(struct.pack('<I', len(encoded)))
          writer.write(encoded)
          await writer.drain()
      finally:
          # Release the connection's file descriptor.
          writer.close()
  async def listen_admin():
      """Serve accept_admin() on the UNIX socket config['AdminSocketPath'] forever."""
      socket_path = config['AdminSocketPath']
      server = await asyncio.start_unix_server(accept_admin, socket_path)
      async with server:
          await server.serve_forever()
  async def main():
      """Load the HJSON config named on the command line and run the node.

      Creates StoragePath if it is missing, starts watcher(), dials every
      address in config['Peers'], starts the admin socket server, and then
      serves inbound peers via listen().
      """
      global config
      if len(sys.argv) < 2:
          print(f'usage: {sys.argv[0]} <config.conf>')
          return
      try:
          async with aiofiles.open(sys.argv[1], 'r') as f:
              raw_config = await f.read()
          config = hjson.loads(raw_config)
      except Exception as e:
          await logger.error(f'Load configuration `{sys.argv[1]}\': {e}')
          return

      storage_path = config['StoragePath']
      if not await aiofiles.os.path.isdir(storage_path):
          await aiofiles.os.mkdir(storage_path)

      asyncio.create_task(watcher())
      for address in config['Peers']:
          await dial(address)
      asyncio.create_task(listen_admin())
      await listen()
  # Start the node only when run as a script, not when the module is imported.
  if __name__ == '__main__':
      try:
          asyncio.run(main())
      except KeyboardInterrupt:
          # Raised on Ctrl-C and by shutdown(), which signals SIGINT to itself.
          print('Interrupted')