"""A small peer-to-peer node that stores content-addressed "pieces" on disk and
exchanges them with other peers over an encrypted, flooding-style protocol
(QUERY / QUERY_HIT with a TTL), plus a local admin interface."""

import os
import sys
import time
import random
import struct
import asyncio
from collections import namedtuple

import zstd
import cbor2
import ecies
import hjson
import aiofiles
import aiologger

try:
    import aiofiles.os
except AttributeError:
    # Work around platforms without os.link so that aiofiles.os still imports.
    os.link = os.symlink
    import aiofiles.os

from Crypto.Hash import SHA256
from Crypto.Random import get_random_bytes

PINGS_COUNT = 3
PING_TIMEOUT = 10
HANDSHAKE_TIMEOUT = 10
BROADCAST_TIMEOUT = 300
HEARTBEAT = 10
WATCHER_INTERVAL = 30
CACHE_LIFETIME = 3600
MAX_TTL = 7
MAX_PAYLOAD_SIZE = 1024 * 1024 * 64
CHUNK_SIZE = 512
COOLDOWN = 0.1
MAX_PIECE_SIZE = 1024 * 1024 * 10

config = {}
peers = []
cache = []

logger = aiologger.Logger.with_default_handlers(
    formatter=aiologger.formatters.base.Formatter(
        fmt='%(asctime)s %(levelname)s: %(message)s'
    )
)


def sha256(data):
    hash = SHA256.new()
    hash.update(data)
    return hash.digest()


def chunks(l, n):
    """Split a sequence into pieces of at most `n` items."""
    for i in range(0, len(l), n):
        yield l[i:i + n]


class Error(Exception):
    pass


async def cleanup_pieces(age):
    """Remove stored pieces that have not been accessed for `age` seconds."""
    count = 0
    for filename in await aiofiles.os.listdir(config['StoragePath']):
        path = os.path.join(config['StoragePath'], filename)
        access_time = await aiofiles.os.path.getatime(path)
        if time.time() - access_time > age:
            await logger.info(f'Purging `{filename}\' from the storage.')
            await aiofiles.os.remove(path)
            count += 1
    return count


async def is_piece_exists(hash):
    if await aiofiles.os.path.isfile(
        os.path.join(config['StoragePath'], hash.hex())
    ):
        return True
    return False


async def save_piece(data, hash=None):
    """Compress a piece and store it under its SHA-256 hash."""
    if not hash:
        hash = sha256(data)
    path = os.path.join(config['StoragePath'], hash.hex())
    async with aiofiles.open(path, 'wb') as f:
        data = zstd.compress(data)
        await f.write(data)
    return hash


async def read_piece(hash):
    """Load and verify a stored piece; drop it if the checksum does not match."""
    path = os.path.join(config['StoragePath'], hash.hex())
    async with aiofiles.open(path, 'rb') as f:
        data = await f.read()
    data = zstd.decompress(data)
    if sha256(data) != hash:
        await aiofiles.os.remove(path)
        raise ValueError('piece actual checksum (i.e. on disk) and expected checksum do not match')
    os.utime(path)
    return data
CachedMessage = namedtuple(
    'CachedMessage',
    'kind uid ts'
)


class Message:
    QUERY = 0xa
    QUERY_HIT = 0xb
    NOT_AVAILABLE = 0xc

    def __init__(self, kind, uid=None, **fields):
        self.kind = kind
        self.uid = uid if uid else get_random_bytes(16)
        self.fields = fields

        if self.kind == Message.QUERY:
            if 'hash' not in self.fields \
                    or type(self.fields['hash']) != bytes \
                    or len(self.fields['hash']) != 32:
                raise ValueError('malformed `QUERY\' message: illegal or missing `hash\' field')
            if 'ttl' not in self.fields \
                    or type(self.fields['ttl']) != int \
                    or self.fields['ttl'] < 0:
                raise ValueError('malformed `QUERY\' message: illegal or missing `ttl\' field')
        elif self.kind == Message.QUERY_HIT:
            if 'data' not in self.fields \
                    or type(self.fields['data']) != bytes \
                    or len(self.fields['data']) < 1 \
                    or len(self.fields['data']) > MAX_PIECE_SIZE:
                raise ValueError('malformed `QUERY_HIT\' message: illegal or missing `data\' field')
        elif self.kind == Message.NOT_AVAILABLE:
            if self.fields:
                raise ValueError('malformed `NOT_AVAILABLE\' message: unexpected payload')
        elif self.kind == ServiceMessage.PING:
            if self.fields:
                raise ValueError('malformed `PING\' message: unexpected payload')
        elif self.kind == ServiceMessage.PONG:
            if self.fields:
                raise ValueError('malformed `PONG\' message: unexpected payload')
        elif self.kind == ServiceMessage.CLOSE:
            if self.fields:
                raise ValueError('malformed `CLOSE\' message: unexpected payload')

    def __getattr__(self, field):
        if field not in self.fields:
            raise Error(f'missing required field `{field}\'')
        return self.fields[field]

    def cache(self):
        """Remember this message's (kind, uid); return False if it was already seen."""
        for message in cache:
            if message.uid == self.uid and message.kind == self.kind:
                return False
        cache.append(
            CachedMessage(
                kind=self.kind,
                uid=self.uid,
                ts=time.time()
            )
        )
        return True

    def is_response_for(self, message):
        if self.uid != message.uid:
            return False
        if message.kind == ServiceMessage.PING:
            return self.kind == ServiceMessage.PONG
        elif message.kind == Message.QUERY:
            return self.kind in (
                Message.QUERY_HIT,
                Message.NOT_AVAILABLE
            )
        return False


class ServiceMessage(Message):
    HELLO = 0x0
    CHALLENGE = 0x1
    ANSWER = 0x2
    FINISH = 0x3
    PING = 0x4
    PONG = 0x5
    CLOSE = 0x6


class Peer:
    def __init__(self, reader, writer, address=None):
        self.reader = reader
        self.writer = writer
        self.address = address
        self.key = None
        self.queue = []
        self.is_open = True
        self.send_lock = asyncio.Lock()
        self.receive_lock = asyncio.Lock()
        self.last_message_ts = -1


async def write(peer, data):
    peer.writer.write(data)
    await peer.writer.drain()


async def read(peer, size):
    buffer = b''
    while len(buffer) < size:
        piece = await peer.reader.read(size - len(buffer))
        if not piece:
            raise Error('connection closed while reading')
        buffer += piece
    return buffer


# Wire framing. Parts of send() and receive() were lost in the source, so the
# exact byte layout below is an assumption, kept consistent between the two:
#   * service frame:  kind (1) | uid (16) | payload length (4, LE) | CBOR payload
#   * regular frame:  kind (1) | ECIES head (116) | chunks, each: length (2, LE) | data
#     where the head decrypts to kind (1) | uid (16) | chunk count (2, LE), and the
#     chunks reassemble into an ECIES-encrypted, zstd-compressed CBOR payload.
async def send(peer, message):
    if type(message) is ServiceMessage:
        buffer = bytes([message.kind])
        buffer += message.uid
        if message.fields:
            payload = cbor2.dumps(message.fields)
            buffer += struct.pack('<I', len(payload))
            buffer += payload
        else:
            buffer += struct.pack('<I', 0)
        async with peer.send_lock:
            await write(peer, buffer)
        return
    payload = cbor2.dumps(message.fields)
    payload = zstd.compress(payload)
    payload = ecies.encrypt(peer.key, payload)
    pieces = list(chunks(payload, CHUNK_SIZE))
    head = bytes([message.kind]) + message.uid + struct.pack('<H', len(pieces))
    head = ecies.encrypt(peer.key, head)
    buffer = bytes([message.kind]) + head
    for piece in pieces:
        buffer += struct.pack('<H', len(piece))
        buffer += piece
    async with peer.send_lock:
        await write(peer, buffer)


async def receive(peer):
    async with peer.receive_lock:
        kind = (await read(peer, 1))[0]
        if not peer.key or kind <= ServiceMessage.CLOSE:
            if kind > ServiceMessage.CLOSE:
                raise Error('unecrypted non-service messages are not allowed')
            uid = await read(peer, 16)
            length = struct.unpack('<I', await read(peer, 4))[0]
            if length > MAX_PAYLOAD_SIZE:
                raise Error('payload is too large')
            payload = {}
            if length:
                payload = await read(peer, length)
                payload = cbor2.loads(payload)
            return ServiceMessage(
                kind,
                uid=uid,
                **payload
            )
        head = await read(peer, 116)
        head = ecies.decrypt(
            config['Secret'],
            head
        )
        kind = head[0]
        uid = head[1:17]
        chunks_count = struct.unpack('<H', head[17:19])[0]
        if chunks_count * CHUNK_SIZE > MAX_PAYLOAD_SIZE:
            raise Error('payload is too large')
        payload = b''
        for _ in range(chunks_count):
            length = struct.unpack('<H', await read(peer, 2))[0]
            if length > CHUNK_SIZE:
                raise Error('illegal chunk length')
            payload += await read(peer, length)
        payload = ecies.decrypt(
            config['Secret'],
            payload
        )
        payload = zstd.decompress(payload)
        payload = cbor2.loads(payload)
        return Message(
            kind,
            uid=uid,
            **payload
        )
async def close(peer, gracefully=True):
    if not peer.is_open:
        return
    if gracefully:
        try:
            await send(
                peer,
                ServiceMessage(
                    ServiceMessage.CLOSE
                )
            )
        except:
            pass
    peer.writer.close()
    peer.is_open = False
    try:
        await asyncio.wait_for(
            peer.writer.wait_closed(),
            timeout=3
        )
    except:
        pass


async def wait_response(peer, message):
    """Poll the peer's queue until a message answering `message` shows up."""
    while peer.is_open:
        for other_message in peer.queue:
            if other_message.is_response_for(message):
                peer.queue.remove(other_message)
                return other_message
        await asyncio.sleep(0)


async def communicate(peer, message, timeout=None):
    await send(peer, message)
    answer = await asyncio.wait_for(
        wait_response(peer, message),
        timeout=timeout
    )
    if not answer:
        raise Error('communication timeout')
    return answer


async def ping(peer):
    await communicate(
        peer,
        ServiceMessage(ServiceMessage.PING),
        timeout=PING_TIMEOUT
    )


async def respond(peer, message, kind, is_service=False, **data):
    await send(
        peer,
        (ServiceMessage if is_service else Message)(
            kind,
            uid=message.uid,
            **data
        )
    )


async def query(hash, uid=None, ttl=0, filter=None):
    """Return a piece by hash, from local storage or by asking the network."""
    if await is_piece_exists(hash):
        return await read_piece(hash)
    answer = await broadcast(
        Message(
            Message.QUERY,
            uid=uid,
            hash=hash,
            ttl=ttl
        ),
        message_filter=lambda answer: answer.kind == Message.QUERY_HIT and sha256(answer.data) == hash,
        peer_filter=filter
    )
    if not answer:
        return None
    await save_piece(answer.data, hash)
    return answer.data


async def broadcast(message, message_filter=None, peer_filter=None):
    if message.ttl >= MAX_TTL:
        return
    message.fields['ttl'] += 1
    for peer in random.sample(peers, len(peers)):
        if not peer.is_open:
            continue
        if peer_filter and not peer_filter(peer):
            continue
        try:
            answer = await communicate(
                peer,
                message,
                timeout=BROADCAST_TIMEOUT
            )
            if message_filter and message_filter(answer):
                return answer
        except:
            continue


async def tick(peer):
    """Heartbeat: ping the peer periodically and drop it after repeated failures."""
    while peer.is_open:
        attempts = PINGS_COUNT
        while True:
            try:
                await ping(peer)
            except:
                attempts -= 1
                if attempts < 1:
                    await close(peer, False)
                    return
                continue
            break
        await asyncio.sleep(HEARTBEAT)


async def handshake(peer):
    await send(
        peer,
        ServiceMessage(
            ServiceMessage.HELLO,
            key=config['Key']
        )
    )
    answer = await receive(peer)
    if answer.kind != ServiceMessage.HELLO:
        raise Error('handshake failed: illegal initial message')
    key = answer.key
    if key == config['Key']:
        raise Error('handshake failed: looping connection')
    for other_peer in peers.copy():
        if other_peer.key == key:
            if not other_peer.is_open:
                peers.remove(other_peer)
                continue
            raise Error('handshake failed: duplicated connection')
    data = get_random_bytes(16)
    await send(
        peer,
        ServiceMessage(
            ServiceMessage.CHALLENGE,
            data=ecies.encrypt(key, data)
        )
    )
    answer = await receive(peer)
    if answer.kind != ServiceMessage.CHALLENGE:
        raise Error('handshake failed: illegal challenge initiation message')
    await send(
        peer,
        ServiceMessage(
            ServiceMessage.ANSWER,
            data=ecies.decrypt(config['Secret'], answer.data)
        )
    )
    answer = await receive(peer)
    if answer.kind != ServiceMessage.ANSWER:
        raise Error('handshake failed: illegal challenge answer message')
    if answer.data != data:
        raise Error('handshake failed: challenge data mismatch')
    await send(
        peer,
        ServiceMessage(ServiceMessage.FINISH)
    )
    answer = await receive(peer)
    if answer.kind != ServiceMessage.FINISH:
        raise Error('handshake failed: illegal finish message')
    peer.key = key


async def cooldown(peer):
    """Enforce a minimal delay between processed messages from one peer."""
    delta = time.time() - peer.last_message_ts
    if delta < COOLDOWN:
        await asyncio.sleep(COOLDOWN - delta)
    peer.last_message_ts = time.time()
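# Connection lifecycle, as implemented above: once the TCP connection is up the
# peers exchange HELLO (public keys), each side then proves ownership of its key
# by decrypting the other's CHALLENGE and returning the plaintext in ANSWER, and
# FINISH seals the handshake. From then on tick() pings the peer every HEARTBEAT
# seconds, while serve() below handles QUERY flooding and the matching
# QUERY_HIT / NOT_AVAILABLE replies.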
async def serve(peer):
    await asyncio.wait_for(
        handshake(peer),
        timeout=HANDSHAKE_TIMEOUT
    )
    asyncio.create_task(
        tick(peer)
    )
    if peer.address:
        await logger.info(f'Connected to {peer.address}.')
    while peer.is_open:
        message = await receive(peer)
        await cooldown(peer)
        if not message.cache():
            # Already-seen message: answer queries negatively, drop the rest.
            if message.kind == Message.QUERY:
                await respond(
                    peer,
                    message,
                    Message.NOT_AVAILABLE
                )
            continue
        if message.kind in (
            ServiceMessage.PONG,
            Message.QUERY_HIT,
            Message.NOT_AVAILABLE
        ):
            # Responses are queued for whoever waits in wait_response().
            peer.queue.append(message)
            continue
        if message.kind == ServiceMessage.PING:
            await respond(
                peer,
                message,
                ServiceMessage.PONG,
                is_service=True
            )
        elif message.kind == ServiceMessage.CLOSE:
            await close(peer, False)
        elif message.kind == Message.QUERY:
            answer = await query(
                message.hash,
                uid=message.uid,
                ttl=message.ttl,
                filter=lambda other_peer: other_peer.key not in (peer.key, config['Key'])
            )
            if not answer:
                await respond(
                    peer,
                    message,
                    Message.NOT_AVAILABLE
                )
                continue
            await respond(
                peer,
                message,
                Message.QUERY_HIT,
                data=answer
            )
        else:
            raise Error(f'unknown message kind: {hex(message.kind)}')


async def accept(reader, writer, address=None):
    peer = Peer(reader, writer, address)
    peers.append(peer)
    try:
        await serve(peer)
    except Exception as e:
        if peer.address:
            await logger.warning(f'Connection lost to {peer.address}: {e}')
    finally:
        await close(peer)


async def dial(address):
    parts = address.split(':')
    host = ':'.join(parts[:-1])
    port = int(parts[-1])
    try:
        reader, writer = await asyncio.open_connection(host, port)
    except Exception as e:
        # Register a closed placeholder so the watcher keeps retrying the dial.
        dummy_peer = Peer(None, None, address)
        dummy_peer.is_open = False
        peers.append(dummy_peer)
        await logger.error(f'Dial {address}: {e}')
        return
    asyncio.create_task(
        accept(reader, writer, address)
    )


async def listen():
    try:
        server = await asyncio.start_server(
            accept,
            config['ListenAddress'],
            int(config['ListenPort'])
        )
    except Exception as e:
        await logger.error(f'Bind {config["ListenAddress"]}:{config["ListenPort"]}: {e}')
        return
    await logger.info(f'Listening at {config["ListenAddress"]}:{config["ListenPort"]}')
    async with server:
        await server.serve_forever()


async def watcher():
    """Periodically re-dial lost outgoing connections and expire cached messages."""
    while True:
        for peer in peers.copy():
            if not peer.is_open:
                peers.remove(peer)
                if peer.address:
                    asyncio.create_task(
                        dial(peer.address)
                    )
                continue
        for message in cache.copy():
            if time.time() - message.ts > CACHE_LIFETIME:
                cache.remove(message)
        await asyncio.sleep(WATCHER_INTERVAL)


async def shutdown(delay):
    await asyncio.sleep(delay)
    await logger.info('Performing graceful shutdown')
    for peer in peers:
        if not peer.is_open:
            continue
        try:
            await send(
                peer,
                ServiceMessage(
                    ServiceMessage.CLOSE
                )
            )
        except:
            pass
    while True:
        os.kill(os.getpid(), 2)


async def accept_admin(reader, writer):
    try:
        length = struct.unpack('<I', await reader.read(4))[0]
        request = await reader.read(length)
        # NOTE: the original admin command handling is missing here; it most
        # likely decoded `request` and dispatched to query(), cleanup_pieces()
        # or shutdown().
        ...
    except Exception as e:
        await logger.error(f'Admin request: {e}')
    finally:
        writer.close()


async def listen_admin():
    # Admin listener, reconstructed to mirror listen(); 'AdminAddress' and
    # 'AdminPort' are assumed configuration key names.
    try:
        server = await asyncio.start_server(
            accept_admin,
            config['AdminAddress'],
            int(config['AdminPort'])
        )
    except Exception as e:
        await logger.error(f'Bind {config["AdminAddress"]}:{config["AdminPort"]}: {e}')
        return
    async with server:
        await server.serve_forever()


async def main():
    global config
    if len(sys.argv) < 2:
        print(f'Usage: {sys.argv[0]} <path to the configuration file>')
        return
    try:
        async with aiofiles.open(sys.argv[1], 'r') as f:
            config = hjson.loads(await f.read())
    except Exception as e:
        await logger.error(f'Load configuration `{sys.argv[1]}\': {e}')
        return
    if not await aiofiles.os.path.isdir(config['StoragePath']):
        await aiofiles.os.mkdir(config['StoragePath'])
    asyncio.create_task(
        watcher()
    )
    for address in config['Peers']:
        await dial(address)
    asyncio.create_task(
        listen_admin()
    )
    await listen()


try:
    asyncio.run(main())
except KeyboardInterrupt:
    print('Interrupted.')
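# A minimal example configuration (hjson) assembled from the keys this script
# reads; all values are placeholders, and 'AdminAddress'/'AdminPort' follow the
# assumed names used in listen_admin() above:
#
# {
#     StoragePath: ./storage
#     ListenAddress: 0.0.0.0
#     ListenPort: 49152
#     AdminAddress: 127.0.0.1
#     AdminPort: 49153
#     Key: <hex-encoded secp256k1 public key of this node>
#     Secret: <hex-encoded secp256k1 private key of this node>
#     Peers: [
#         198.51.100.7:49152
#     ]
# }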