txlyre 2 ani în urmă
părinte
comite
532bb89a72
2 a modificat fișierele cu 151 adăugiri și 51 ștergeri
  1. 125 46
      byafn.py
  2. 26 5
      byafnctl.py

+ 125 - 46
byafn.py

@@ -36,12 +36,13 @@ WATCHER_INTERVAL  = 30
 CACHE_LIFETIME    = 3600
 
 MAX_TTL           = 7
-MAX_DISTANCE      = 8
 
 MAX_PAYLOAD_SIZE  = 1024*1024*64
 CHUNK_SIZE        = 512
 
-RATELIMIT         = 0.5
+COOLDOWN          = 0.1
+
+MAX_PIECE_SIZE    = 1024*1024*10
 
 config = {}
 
@@ -66,6 +67,32 @@ def chunks(l, n):
 
 class Error(Exception): pass
 
+async def cleanup_pieces(age):
+  count = 0
+
+  for filename in aiofiles.os.listdir(
+    config['StoragePath']
+  ):
+    path = os.path.join(
+      config['StoragePath'],
+      filename
+    )
+
+    access_time = await aiofiles.os.path.getatime(
+      path
+    )
+
+    if time.time() - access_time > age:
+      await logger.info(f'Purging `{filename}\' from the storage.')
+
+      await aiofiles.os.remove(
+        path
+      )
+
+      count += 1
+
+  return count
+
 async def is_piece_exists(hash):
   if await aiofiles.os.path.isfile(
     os.path.join(
@@ -112,7 +139,9 @@ async def read_piece(hash):
     if sha256(data) != hash:
       await aiofiles.os.remove(path)
 
-      raise ValueError
+      raise ValueError('piece actual checksum (i.e. on disk) and expected checksum do not match')
+
+    os.utime(path)
 
     return data
 
@@ -131,7 +160,39 @@ class Message:
     self.uid = uid if uid else get_random_bytes(16)
     self.fields = fields
 
+    if self.kind == Message.QUERY:
+      if 'hash' not in self.fields\
+      or type(self.fields['hash']) != bytes\
+      or len(self.fields['hash']) != 32:
+        raise ValueError('malformed `QUERY\' message: illegal or missing `hash\' field')
+
+      if 'ttl' not in self.fields\
+      or type(self.fields['ttl']) != int\
+      or self.fields['ttl'] < 1:
+        raise ValueError('malformed `QUERY\' message: illegal or missing `ttl\' field')
+    elif self.kind == Message.QUERY_HIT:
+      if 'data' not in self.fields\
+      or type(self.fields['data']) != bytes\
+      or len(self.fields['data']) < 1\
+      or len(self.fields['data']) > MAX_PIECE_SIZE:
+        raise ValueError('malformed `QUERY_HIT\' message: illegal or missing `data\' field')
+    elif self.kind == Message.NOT_AVAILABLE:
+      if self.fields:
+        raise ValueError('malformed `NOT_AVAILABLE\' message: unexpected payload')
+    elif self.kind == ServiceMessage.PING:
+      if self.fields:
+        raise ValueError('malformed `PING\' message: unexpected payload')
+    elif self.kind == ServiceMessage.PONG:
+      if self.fields:
+        raise ValueError('malformed `PONG\' message: unexpected payload')
+    elif self.kind == ServiceMessage.CLOSE:
+      if self.fields:
+        raise ValueError('malformed `CLOSE\' message: unexpected payload')
+
   def __getattr__(self, field):
+    if field not in self.fields:
+      raise Error(f'missing required field `{field}\'')
+
     return self.fields[field]
 
   def cache(self):
@@ -188,8 +249,7 @@ class Peer:
     self.is_open = True
     self.send_lock = asyncio.Lock()
     self.receive_lock = asyncio.Lock()
-    self.ticks = 0
-    self.last_message_in_ts = -1
+    self.last_message_ts = -1
 
 async def write(peer, data):
   peer.writer.write(data)
@@ -203,15 +263,6 @@ async def read(peer, size):
 
   return buffer
 
-async def cooldown(peer, out=True):
-  if out:
-    await asyncio.sleep(RATELIMIT + 0.01)
-  else:
-    if time.time() - peer.last_message_in_ts < RATELIMIT:
-      raise Error(f'rate limit (={RATELIMIT}s.) exceeded for incoming messages')
-
-    peer.last_message_in_ts = time.time()
-
 async def send(peer, message):
   if type(message) is ServiceMessage:
     buffer = bytes([message.kind])
@@ -250,7 +301,6 @@ async def send(peer, message):
   )
 
   async with peer.send_lock:
-    await cooldown(peer)
     await write(peer, buffer)
 
     if chunks_count:
@@ -263,16 +313,17 @@ async def send(peer, message):
 
 async def receive(peer):
   async with peer.receive_lock:
-    await cooldown(peer, False)
-
     kind = (await read(peer, 1))[0]
-
+      
     if kind != 0xff:
       if kind > ServiceMessage.CLOSE:
-        raise Error(f'unecrypted non-service messages are not allowed')
+        raise Error('unecrypted non-service messages are not allowed')
 
       length = struct.unpack('<H', await read(peer, 2))[0]
 
+      if length > MAX_PAYLOAD_SIZE:
+        raise Error('payload is too large')
+
       payload = {}
 
       if length:
@@ -359,7 +410,7 @@ async def wait_response(peer, message):
 
         return other_message
 
-    await asyncio.sleep(1)
+    await asyncio.sleep(0)
 
 async def communicate(peer, message, timeout=None):
   await send(peer, message)
@@ -427,7 +478,7 @@ async def broadcast(message, message_filter=None, peer_filter=None):
 
   message.fields['ttl'] += 1
 
-  for peer in peers:
+  for peer in random.sample(peers, len(peers)):
     if not peer.is_open:
       continue
 
@@ -477,12 +528,12 @@ async def handshake(peer):
 
   answer = await receive(peer)
   if answer.kind != ServiceMessage.HELLO:
-    raise Error
+    raise Error('handshake failed: illegal initial message')
 
   key = answer.key
 
   if key == config['Key']:
-    raise Error
+    raise Error('handshake failed: looping connection')
 
   for peer in peers.copy():
     if peer.key == key:
@@ -491,7 +542,7 @@ async def handshake(peer):
 
         continue
 
-      raise Error
+      raise Error('handshake failed: duplicated connection')
 
   data = get_random_bytes(16)
 
@@ -508,7 +559,7 @@ async def handshake(peer):
 
   answer = await receive(peer)
   if answer.kind != ServiceMessage.CHALLENGE:
-    raise Error
+    raise Error('handshake failed: illegal challenge initiation message')
 
   await send(
     peer,
@@ -523,10 +574,10 @@ async def handshake(peer):
 
   answer = await receive(peer)
   if answer.kind != ServiceMessage.ANSWER:
-    raise Error
+    raise Error('handshake failed: illegal challenge answer message')
 
   if answer.data != data:
-    raise Error
+    raise Error('handshake failed: challenge data mismatch')
 
   await send(
     peer,
@@ -537,28 +588,36 @@ async def handshake(peer):
 
   answer = await receive(peer)
   if answer.kind != ServiceMessage.FINISH:
-    raise Error
+    raise Error('handshake failed: illegal finish message')
 
   peer.key = key
 
+async def cooldown(peer):
+  delta = time.time() - peer.last_message_ts
+
+  if delta < COOLDOWN:
+    await asyncio.sleep(COOLDOWN - delta)
+
+  peer.last_message_ts = time.time()
+
 async def serve(peer):
   await asyncio.wait_for(
     handshake(peer),
     timeout=HANDSHAKE_TIMEOUT
   )
 
-  await asyncio.sleep(RATELIMIT)
-
   asyncio.create_task(
     tick(peer)
   )
 
   if peer.address:
-    await logger.info(f'Connected to {peer.address}')
+    await logger.info(f'Connected to {peer.address}.')
 
   while peer.is_open:
     message = await receive(peer)
 
+    await cooldown(peer)
+
     if not message.cache():
       if message.kind == Message.QUERY:
         await respond(
@@ -611,7 +670,7 @@ async def serve(peer):
         data=answer
       )
     else:
-      raise Error(f'unknown message kind={message.kind}')
+      raise Error(f'unknown message kind: {hex(message.kind)}')
 
 async def accept(reader, writer, address=None):
   peer = Peer(reader, writer, address)
@@ -621,7 +680,7 @@ async def accept(reader, writer, address=None):
     await serve(peer)
   except Exception as e:
     if peer.address:
-      await logger.warning(f'Connection lost {peer.address}: {e}')
+      await logger.warning(f'Connection lost to {peer.address}: {e}')
   finally:
     await close(peer)
 
@@ -650,11 +709,16 @@ async def dial(address):
   )
 
 async def listen():
-  server = await asyncio.start_server(
-    accept,
-    config['ListenAddress'],
-    int(config['ListenPort'])
-  )
+  try:
+    server = await asyncio.start_server(
+      accept,
+      config['ListenAddress'],
+      int(config['ListenPort'])
+    )
+  except Exception as e:
+    await logger.error(f'Bind {config["ListenAddress"]}:{config["ListenPort"]}: {e}')
+
+    return
 
   await logger.info(f'Listening at {config["ListenAddress"]}:{config["ListenPort"]}')
 
@@ -729,6 +793,12 @@ async def accept_admin(reader, writer):
       asyncio.create_task(
         shutdown(delay)
       )
+    elif 'cleanup' in request:
+      age = int(request['cleanup']['age'])
+
+      response['removed_count'] = await cleanup_pieces(
+        age
+      )
     else:
       raise Error('unrecognized command')
   except Exception as e:
@@ -742,10 +812,15 @@ async def accept_admin(reader, writer):
   await writer.drain()
 
 async def listen_admin():
-  server = await asyncio.start_unix_server(
-    accept_admin,
-    config['AdminSocketPath'],
-  )
+  try:
+    server = await asyncio.start_unix_server(
+      accept_admin,
+      config['AdminSocketPath'],
+    )
+  except Exception as e:
+    await logger.error(f'Bind {config["AdminSocketPath"]}: {e}')
+
+    return
 
   async with server:
     await server.serve_forever()
@@ -754,7 +829,7 @@ async def main():
   global config
 
   if len(sys.argv) < 2:
-    print(f'usage: {sys.argv[0]} <config.conf>')
+    print(f'Usage: {sys.argv[0]} <config.conf>')
 
     return
 
@@ -769,8 +844,12 @@ async def main():
 
     return
 
-  if not await aiofiles.os.path.isdir(config['StoragePath']):
-    await aiofiles.os.mkdir(config['StoragePath'])
+  if not await aiofiles.os.path.isdir(
+    config['StoragePath']
+  ):
+    await aiofiles.os.mkdir(
+      config['StoragePath']
+    )
 
   asyncio.create_task(
     watcher()
@@ -788,4 +867,4 @@ async def main():
 try:
   asyncio.run(main())
 except KeyboardInterrupt:
-  print('Interrupted')
+  print('Interrupted.')

+ 26 - 5
byafnctl.py

@@ -16,8 +16,13 @@ from Crypto.Hash import SHA256
 from Crypto.Cipher import Salsa20
 from Crypto.Random import get_random_bytes
 
+BYAFN_ADMINSOCKET = os.getenv(
+  'BYAFN_ADMINSOCKET',
+  default='./adminsocket'
+)
+
 logging.basicConfig(
-  format="%(asctime)s %(levelname)s: %(message)s", 
+  format='%(asctime)s %(levelname)s: %(message)s',
   level=logging.INFO
 )
 
@@ -195,7 +200,7 @@ parser.add_argument(
 parser.add_argument(
   '-a', '--admin-socket-path',
   help='Set AdminSocketPath (use together with --genconf, --share or --query)',
-  type=str, default='./adminsocket'
+  type=str, default=BYAFN_ADMINSOCKET
 )
 
 parser.add_argument(
@@ -234,6 +239,12 @@ parser.add_argument(
   type=int
 )
 
+parser.add_argument(
+  '--cleanup',
+  help='Remove every piece from the storage if they weren\'t accessed more than for specified amount of seconds',
+  type=int
+)
+
 args = parser.parse_args()
 
 if args.genconf:
@@ -277,8 +288,10 @@ if args.share:
     pieces_count = max(1, file_size // args.piece_size)
     progress = tqdm.trange(pieces_count)
     pieces = []
+
     while True:
       piece = f.read(args.piece_size)
+
       if not piece:
         break
 
@@ -326,8 +339,10 @@ if args.query:
     checksum = SHA256.new()
 
     progress = tqdm.trange(len(metafile.pieces))
+
     for piece in metafile.pieces:
       interval = 10
+
       while True:
         data = send_command(query={
           'hash': piece
@@ -359,14 +374,20 @@ if args.query:
         sys.exit(1)
 
       output.write(piece)
-
       progress.update(1)
 
     output.close()
 
-if args.shutdown:
+if args.shutdown is not None:
   send_command(shutdown={
     'delay': args.shutdown
   })
 
-  logging.info('Shutdown command sent')
+  logging.info('Shutdown command sent')
+
+if args.cleanup is not None:
+  response = send_command(cleanup={
+    'age': args.cleanup
+  })
+
+  logging.info(f'Removed {response["removed_count"]} piece(s).')