12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868 |
- import os
- import sys
- import time
- import random
- import struct
- import asyncio
- from collections import namedtuple
- import zstd
- import cbor2
- import ecies
- import hjson
- import aiofiles
- import aiologger
- try:
- import aiofiles.os
- except AttributeError:
- os.link = os.symlink
- import aiofiles.os
- from Crypto.Hash import SHA256
- from Crypto.Random import get_random_bytes
# --- Liveness and timing (seconds) ---
PINGS_COUNT = 3          # ping attempts per heartbeat round (see tick)
PING_TIMEOUT = 10        # wait for a PONG
HANDSHAKE_TIMEOUT = 10   # whole handshake() exchange
BROADCAST_TIMEOUT = 300  # wait for one peer's answer during broadcast()
HEARTBEAT = 10           # pause between heartbeat rounds
WATCHER_INTERVAL = 30    # pause between watcher() housekeeping passes
CACHE_LIFETIME = 3600    # age after which watcher() drops a cache entry
COOLDOWN = 0.1           # minimum gap between two messages handled per peer

# --- Protocol limits ---
MAX_TTL = 7                          # hop limit checked by broadcast()
MAX_PAYLOAD_SIZE = 64 * 1024 * 1024  # bytes accepted by receive()
CHUNK_SIZE = 512                     # bytes per length-prefixed chunk
MAX_PIECE_SIZE = 10 * 1024 * 1024    # largest QUERY_HIT data accepted

# --- Shared mutable state ---
config = {}  # parsed configuration; replaced by main()
peers = []   # Peer objects, including closed placeholders for failed dials
cache = []   # CachedMessage entries used to spot repeated messages
# Shared asynchronous logger; records render as "<time> <LEVEL>: <message>".
_log_formatter = aiologger.formatters.base.Formatter(
    fmt='%(asctime)s %(levelname)s: %(message)s'
)
logger = aiologger.Logger.with_default_handlers(formatter=_log_formatter)
def sha256(data):
    """Return the raw SHA-256 digest of *data* (PyCryptodome backend)."""
    hasher = SHA256.new()
    hasher.update(data)
    return hasher.digest()
def chunks(l, n):
    """Lazily yield consecutive slices of *l*, each at most *n* items long.

    The last slice is shorter when len(l) is not a multiple of n.
    """
    yield from (l[start:start + n] for start in range(0, len(l), n))
class Error(Exception):
    """Protocol or application failure raised by this node."""
async def cleanup_pieces(age):
    """Delete stored pieces whose last access is older than *age* seconds.

    Args:
        age: maximum allowed seconds since a piece file was last accessed.

    Returns:
        Number of files removed.
    """
    count = 0
    storage = config['StoragePath']
    # aiofiles.os.listdir is a coroutine function: the listing must be
    # awaited (iterating the bare call fails with a TypeError).
    for filename in await aiofiles.os.listdir(storage):
        path = os.path.join(storage, filename)
        try:
            access_time = await aiofiles.os.path.getatime(path)
        except FileNotFoundError:
            # Removed meanwhile (e.g. read_piece dropping a corrupt piece).
            continue
        if time.time() - access_time > age:
            await logger.info(f'Purging `{filename}\' from the storage.')
            try:
                await aiofiles.os.remove(path)
            except FileNotFoundError:
                continue
            count += 1
    return count
async def is_piece_exists(hash):
    """Tell whether a piece file named after *hash* (hex) is in storage."""
    piece_path = os.path.join(config['StoragePath'], hash.hex())
    return bool(await aiofiles.os.path.isfile(piece_path))
async def save_piece(data, hash=None):
    """Store *data* zstd-compressed under the hex form of its digest.

    Args:
        data: raw piece bytes.
        hash: precomputed SHA-256 digest of *data*; computed when falsy.

    Returns:
        The digest used as the file name.
    """
    digest = hash if hash else sha256(data)
    target = os.path.join(config['StoragePath'], digest.hex())
    async with aiofiles.open(target, 'wb') as f:
        await f.write(zstd.compress(data))
    return digest
async def read_piece(hash):
    """Load, decompress and verify the stored piece named after *hash*.

    Returns:
        The original piece bytes.

    Raises:
        ValueError: the decompressed content does not hash to *hash*; the
            file is deleted before raising.
    """
    piece_path = os.path.join(config['StoragePath'], hash.hex())
    async with aiofiles.open(piece_path, 'rb') as f:
        compressed = await f.read()
    data = zstd.decompress(compressed)
    if sha256(data) != hash:
        await aiofiles.os.remove(piece_path)
        raise ValueError('piece actual checksum (i.e. on disk) and expected checksum do not match')
    # Refresh the timestamps so cleanup_pieces() (atime based) keeps it.
    os.utime(piece_path)
    return data
# Cache entry: a message's kind and uid plus the time.time() it was recorded.
CachedMessage = namedtuple('CachedMessage', ['kind', 'uid', 'ts'])
- class Message:
- QUERY = 0xa
- QUERY_HIT = 0xb
- NOT_AVAILABLE = 0xc
- def __init__(self, kind, uid=None, **fields):
- self.kind = kind
- self.uid = uid if uid else get_random_bytes(16)
- self.fields = fields
- if self.kind == Message.QUERY:
- if 'hash' not in self.fields\
- or type(self.fields['hash']) != bytes\
- or len(self.fields['hash']) != 32:
- raise ValueError('malformed `QUERY\' message: illegal or missing `hash\' field')
- if 'ttl' not in self.fields\
- or type(self.fields['ttl']) != int\
- or self.fields['ttl'] < 0:
- raise ValueError('malformed `QUERY\' message: illegal or missing `ttl\' field')
- elif self.kind == Message.QUERY_HIT:
- if 'data' not in self.fields\
- or type(self.fields['data']) != bytes\
- or len(self.fields['data']) < 1\
- or len(self.fields['data']) > MAX_PIECE_SIZE:
- raise ValueError('malformed `QUERY_HIT\' message: illegal or missing `data\' field')
- elif self.kind == Message.NOT_AVAILABLE:
- if self.fields:
- raise ValueError('malformed `NOT_AVAILABLE\' message: unexpected payload')
- elif self.kind == ServiceMessage.PING:
- if self.fields:
- raise ValueError('malformed `PING\' message: unexpected payload')
- elif self.kind == ServiceMessage.PONG:
- if self.fields:
- raise ValueError('malformed `PONG\' message: unexpected payload')
- elif self.kind == ServiceMessage.CLOSE:
- if self.fields:
- raise ValueError('malformed `CLOSE\' message: unexpected payload')
- def __getattr__(self, field):
- if field not in self.fields:
- raise Error(f'missing required field `{field}\'')
- return self.fields[field]
- def cache(self):
- for message in cache:
- if (
- message.uid == self.uid and
- message.kind == self.kind
- ):
- return False
- cache.append(
- CachedMessage(
- kind=self.kind,
- uid=self.uid,
- ts=time.time()
- )
- )
- return True
- def is_response_for(self, message):
- if self.uid != message.uid:
- return False
- if message.kind == ServiceMessage.PING:
- return self.kind == ServiceMessage.PONG
- elif message.kind == Message.QUERY:
- return self.kind in (
- Message.QUERY_HIT,
- Message.NOT_AVAILABLE
- )
- return False
class ServiceMessage(Message):
    """Connection-management message; send() writes this exact type unencrypted.

    Its kinds occupy the lowest values; receive() rejects any unencrypted
    kind above CLOSE.
    """

    # Handshake sequence
    HELLO = 0
    CHALLENGE = 1
    ANSWER = 2
    FINISH = 3
    # Liveness and teardown
    PING = 4
    PONG = 5
    CLOSE = 6
class Peer:
    """State of one connection: inbound, outbound, or a closed placeholder.

    Attributes:
        reader / writer: asyncio stream pair (None for placeholders).
        address: ``host:port`` for outbound connections, otherwise None.
        key: remote public key, set once handshake() succeeds.
        queue: received responses waiting to be claimed by wait_response().
        is_open: False once the connection is closed.
        send_lock / receive_lock: keep whole frames from interleaving.
        last_message_ts: time.time() of the last handled message, -1 at start.
    """

    def __init__(self, reader, writer, address=None):
        self.reader, self.writer, self.address = reader, writer, address
        self.key = None
        self.is_open = True
        self.queue = []
        self.send_lock, self.receive_lock = asyncio.Lock(), asyncio.Lock()
        self.last_message_ts = -1
async def write(peer, data):
    """Hand *data* to the peer's stream writer and wait for it to drain."""
    stream = peer.writer
    stream.write(data)
    await stream.drain()
async def read(peer, size):
    """Read exactly *size* bytes from the peer's stream.

    Raises:
        asyncio.IncompleteReadError: the stream ended before *size* bytes
            arrived (an EOFError subclass).
    """
    # readexactly() fails cleanly at EOF. A manual read() loop never ends
    # there: StreamReader.read() returns b'' at EOF without suspending, so a
    # disconnected peer would freeze the whole event loop. It also avoids
    # re-copying the buffer with `bytes +=` on every partial read.
    return await peer.reader.readexactly(size)
async def send(peer, message):
    """Serialize *message* and write it to *peer* (layouts mirror receive()).

    Messages whose exact type is ServiceMessage go out in the clear as
    ``kind | <H length | CBOR fields``. All others are sent as ``0xff``
    followed by an ECIES head (encrypted to ``peer.key``) holding
    ``kind | uid | <H chunk count``. Then come the CBOR -> zstd -> ECIES
    fields, split into CHUNK_SIZE pieces, each prefixed with its ``<H`` length.

    Raises:
        Error: the encrypted payload needs more chunks than a ``<H`` counts.
    """
    if type(message) is ServiceMessage:
        body = cbor2.dumps(message.fields) if message.fields else b''
        frame = bytes([message.kind]) + struct.pack('<H', len(body)) + body
        async with peer.send_lock:
            await write(peer, frame)
        return
    payload = b''
    if message.fields:
        payload = ecies.encrypt(
            peer.key,
            zstd.compress(cbor2.dumps(message.fields))
        )
    # Round UP: receive() reads exactly this many chunks, and chunks() emits a
    # final short chunk whenever the size is not a multiple of CHUNK_SIZE.
    # Rounding down left that chunk unannounced and desynchronised the stream.
    chunks_count = -(-len(payload) // CHUNK_SIZE)
    if chunks_count > 0xFFFF:
        raise Error('payload is too large')
    head = b'\xff' + ecies.encrypt(
        peer.key,
        bytes([message.kind]) + message.uid + struct.pack('<H', chunks_count)
    )
    async with peer.send_lock:
        await write(peer, head)
        for chunk in chunks(payload, CHUNK_SIZE):
            await write(peer, struct.pack('<H', len(chunk)) + chunk)
async def receive(peer):
    """Read and decode the next message from *peer*.

    Frame layouts (mirror of send()):
      * service: ``kind`` (1 byte, at most ServiceMessage.CLOSE), ``<H``
        payload length, CBOR fields -> ServiceMessage.
      * encrypted: ``0xff``, a 116-byte ECIES head holding
        ``kind | uid (16) | <H chunk count``, then that many ``<H``
        length-prefixed chunks that concatenate to ECIES(zstd(CBOR)) -> Message.

    Raises:
        Error: disallowed kind, oversized payload, illegal chunk length, or a
            payload that is not a field mapping.
        ValueError: the fields fail Message validation.
        (Errors from ecies / zstd / cbor2 decoding propagate unchanged.)
    """
    async with peer.receive_lock:
        kind = (await read(peer, 1))[0]
        if kind != 0xff:
            if kind > ServiceMessage.CLOSE:
                raise Error('unecrypted non-service messages are not allowed')
            length = struct.unpack('<H', await read(peer, 2))[0]
            if length > MAX_PAYLOAD_SIZE:
                raise Error('payload is too large')
            fields = {}
            if length:
                fields = cbor2.loads(await read(peer, length))
            if not isinstance(fields, dict):
                raise Error('malformed payload: expected a field mapping')
            return ServiceMessage(kind, **fields)
        head = ecies.decrypt(config['Secret'], await read(peer, 116))
        kind = head[0]
        uid = head[1:17]
        chunks_count = struct.unpack('<H', head[17:19])[0]
        fields = {}
        if chunks_count:
            if chunks_count * CHUNK_SIZE > MAX_PAYLOAD_SIZE:
                raise Error('payload is too large')
            # bytearray grows in place; `bytes += chunk` would copy the whole
            # accumulated payload for every chunk (quadratic in its size).
            buffer = bytearray()
            for _ in range(chunks_count):
                length = struct.unpack('<H', await read(peer, 2))[0]
                if not length or length > CHUNK_SIZE:
                    raise Error('illegal chunk length')
                buffer += await read(peer, length)
            plain = ecies.decrypt(config['Secret'], bytes(buffer))
            fields = cbor2.loads(zstd.decompress(plain))
        if not isinstance(fields, dict):
            raise Error('malformed payload: expected a field mapping')
        return Message(kind, uid=uid, **fields)
async def close(peer, gracefully=True):
    """Close the connection to *peer*; calling it again is a no-op.

    Args:
        peer: the Peer to close.
        gracefully: first try (for at most 3 seconds) to tell the remote
            side with a CLOSE message.
    """
    if not peer.is_open:
        return
    # Clear the flag before the first await so a concurrent close() returns
    # early and loops polling is_open (wait_response, tick, serve) stop.
    peer.is_open = False
    if gracefully:
        try:
            await asyncio.wait_for(
                send(peer, ServiceMessage(ServiceMessage.CLOSE)),
                timeout=3
            )
        except Exception:
            pass  # best effort: the remote side may already be gone
    peer.writer.close()
    try:
        await asyncio.wait_for(peer.writer.wait_closed(), timeout=3)
    except Exception:
        pass  # best effort: the transport is closed either way
async def wait_response(peer, message, poll_interval=0.01):
    """Wait for the queued response to *message* and remove it from the queue.

    Args:
        peer: object with ``is_open`` and a ``queue`` list filled by serve().
        message: the request whose response is awaited.
        poll_interval: seconds to sleep between queue scans.

    Returns:
        The matching message, or None once the peer is closed.
    """
    while peer.is_open:
        for candidate in peer.queue:
            if candidate.is_response_for(message):
                peer.queue.remove(candidate)
                return candidate
        # A real sleep, not sleep(0): sleep(0) reschedules immediately, so a
        # pending wait would spin the event loop at full CPU until timeout.
        await asyncio.sleep(poll_interval)
    return None
async def communicate(peer, message, timeout=None):
    """Send *message* to *peer* and wait up to *timeout* seconds for its response.

    Raises:
        asyncio.TimeoutError: no response within *timeout*.
        Error: the peer closed before a response arrived.
    """
    await send(peer, message)
    response = await asyncio.wait_for(
        wait_response(peer, message),
        timeout=timeout
    )
    if response:
        return response
    raise Error('communication timeout')
async def ping(peer):
    """Send a PING and wait up to PING_TIMEOUT seconds for the PONG."""
    probe = ServiceMessage(ServiceMessage.PING)
    await communicate(peer, probe, timeout=PING_TIMEOUT)
async def respond(peer, message, kind, is_service=False, **data):
    """Send *peer* a reply of *kind* carrying *message*'s uid and *data* fields.

    The reply is a ServiceMessage when *is_service* is set, else a Message.
    """
    factory = ServiceMessage if is_service else Message
    reply = factory(kind, uid=message.uid, **data)
    await send(peer, reply)
async def query(hash, uid=None, ttl=0, filter=None):
    """Return the piece whose SHA-256 digest is *hash*, or None.

    Local storage is tried first. Otherwise the query is broadcast, and the
    first QUERY_HIT whose data hashes to *hash* is saved locally and returned.

    Args:
        hash: 32-byte SHA-256 digest of the wanted piece.
        uid: uid to reuse when forwarding a peer's query (random if None).
        ttl: hops taken so far; broadcast() enforces MAX_TTL.
        filter: optional predicate selecting which peers may be asked.
    """
    if await is_piece_exists(hash):
        try:
            return await read_piece(hash)
        except (ValueError, OSError):
            # Either the copy is corrupt (read_piece deletes it and raises
            # ValueError) or it vanished after the existence check. Fall back
            # to the network instead of failing the request.
            pass
    answer = await broadcast(
        Message(Message.QUERY, uid=uid, hash=hash, ttl=ttl),
        message_filter=lambda reply: reply.kind == Message.QUERY_HIT and sha256(reply.data) == hash,
        peer_filter=filter
    )
    if not answer:
        return None
    await save_piece(answer.data, hash)
    return answer.data
async def broadcast(message, message_filter=None, peer_filter=None):
    """Ask open peers one at a time, in random order, until one answers acceptably.

    The message's ``ttl`` field is incremented once before sending. Nothing
    is sent when it already reached MAX_TTL.

    Args:
        message: request with a ``ttl`` field (e.g. a QUERY).
        message_filter: predicate an answer must satisfy; when None, the
            first answer received is accepted.
        peer_filter: predicate a peer must satisfy to be asked.

    Returns:
        The accepted answer, or None.
    """
    if message.ttl >= MAX_TTL:
        return None
    message.fields['ttl'] += 1
    for peer in random.sample(peers, len(peers)):
        if not peer.is_open:
            continue
        if peer_filter and not peer_filter(peer):
            continue
        try:
            answer = await communicate(peer, message, timeout=BROADCAST_TIMEOUT)
            if message_filter is None or message_filter(answer):
                return answer
        except Exception:
            # Timeout, closed connection or send failure: try the next peer.
            # (Cancellation is a BaseException and still propagates.)
            continue
    return None
async def tick(peer):
    """Heartbeat: ping *peer* every HEARTBEAT seconds while it is open.

    Each round makes up to PINGS_COUNT ping attempts and stops at the first
    success. If every attempt fails, the connection is closed without sending
    CLOSE.
    """
    while peer.is_open:
        for _ in range(PINGS_COUNT):
            if not peer.is_open:
                return
            try:
                await ping(peer)
            except Exception:
                continue  # timeout or send failure: retry
            break
        else:
            # close() is a module-level coroutine, not a method; it must be
            # awaited for the dead connection to actually be torn down.
            await close(peer, False)
            return
        await asyncio.sleep(HEARTBEAT)
async def handshake(peer):
    """Authenticate *peer* with a symmetric challenge-response exchange.

    Both sides run the same sequence:
      1. exchange HELLO carrying their ``Key``;
      2. reject a connection to ourselves or to an already-connected key;
      3. send 16 random bytes encrypted to the remote key (CHALLENGE);
      4. decrypt the remote challenge with our ``Secret`` and return it (ANSWER);
      5. check the remote ANSWER;
      6. exchange FINISH.
    On success ``peer.key`` is set.

    Raises:
        Error: an unexpected message or a failed check at any step (errors
            from ecies or receive() propagate unchanged).
    """
    await send(peer, ServiceMessage(ServiceMessage.HELLO, key=config['Key']))
    answer = await receive(peer)
    if answer.kind != ServiceMessage.HELLO:
        raise Error('handshake failed: illegal initial message')
    key = answer.key
    if not key:
        raise Error('handshake failed: illegal key')
    if key == config['Key']:
        raise Error('handshake failed: looping connection')
    # Use `other`, not `peer`: rebinding `peer` here would make every later
    # step of the handshake talk to whichever connection was last in `peers`.
    for other in peers.copy():
        if other is peer or other.key != key:
            continue
        if not other.is_open:
            peers.remove(other)
            continue
        raise Error('handshake failed: duplicated connection')
    challenge = get_random_bytes(16)
    await send(
        peer,
        ServiceMessage(ServiceMessage.CHALLENGE, data=ecies.encrypt(key, challenge))
    )
    answer = await receive(peer)
    if answer.kind != ServiceMessage.CHALLENGE:
        raise Error('handshake failed: illegal challenge initiation message')
    await send(
        peer,
        ServiceMessage(
            ServiceMessage.ANSWER,
            data=ecies.decrypt(config['Secret'], answer.data)
        )
    )
    answer = await receive(peer)
    if answer.kind != ServiceMessage.ANSWER:
        raise Error('handshake failed: illegal challenge answer message')
    if answer.data != challenge:
        raise Error('handshake failed: challenge data mismatch')
    await send(peer, ServiceMessage(ServiceMessage.FINISH))
    answer = await receive(peer)
    if answer.kind != ServiceMessage.FINISH:
        raise Error('handshake failed: illegal finish message')
    peer.key = key
async def cooldown(peer):
    """Rate-limit *peer*: keep handled messages at least COOLDOWN seconds apart.

    Sleeps for whatever remains of the window since ``peer.last_message_ts``,
    then stamps the current time.
    """
    elapsed = time.time() - peer.last_message_ts
    remaining = COOLDOWN - elapsed
    if remaining > 0:
        await asyncio.sleep(remaining)
    peer.last_message_ts = time.time()
async def _answer_query(peer, message):
    """Resolve one incoming QUERY from *peer* and send the result back.

    Runs as its own task, so a slow broadcast (up to BROADCAST_TIMEOUT per
    peer asked) does not stop serve() from handling PINGs and responses on
    the same connection.
    """
    try:
        piece = await query(
            message.hash,
            uid=message.uid,
            ttl=message.ttl,
            filter=lambda other_peer: other_peer.key not in (peer.key, config['Key'])
        )
    except Exception as e:
        await logger.warning(f'Query {message.hash.hex()} failed: {e}')
        piece = None
    try:
        if piece:
            await respond(peer, message, Message.QUERY_HIT, data=piece)
        else:
            await respond(peer, message, Message.NOT_AVAILABLE)
    except Exception as e:
        # Usually the requester went away meanwhile; serve() notices that on
        # its own next receive, so only record it.
        if peer.address:
            await logger.warning(f'Reply to {peer.address} failed: {e}')


async def serve(peer):
    """Handshake with *peer*, start its heartbeat and dispatch its messages.

    Returns once the peer is closed. Raises on handshake timeout or failure,
    on receive errors, and on an unknown message kind; accept() logs the
    error and closes the peer.
    """
    await asyncio.wait_for(handshake(peer), timeout=HANDSHAKE_TIMEOUT)
    # Hold strong references: the event loop keeps only weak ones to tasks.
    heartbeat = asyncio.create_task(tick(peer))
    query_tasks = set()
    if peer.address:
        await logger.info(f'Connected to {peer.address}.')
    try:
        while peer.is_open:
            message = await receive(peer)
            await cooldown(peer)
            kind = message.kind
            if kind == Message.QUERY and not message.cache():
                # Seen before (e.g. it looped back through another peer).
                await respond(peer, message, Message.NOT_AVAILABLE)
                continue
            if kind in (ServiceMessage.PONG, Message.QUERY_HIT, Message.NOT_AVAILABLE):
                # Responses reuse their request's uid, and several peers may
                # answer the same forwarded query. So they skip the duplicate
                # cache and go straight to the queue read by wait_response().
                peer.queue.append(message)
            elif kind == ServiceMessage.PING:
                await respond(peer, message, ServiceMessage.PONG, is_service=True)
            elif kind == ServiceMessage.CLOSE:
                await close(peer, False)
            elif kind == Message.QUERY:
                task = asyncio.create_task(_answer_query(peer, message))
                query_tasks.add(task)
                task.add_done_callback(query_tasks.discard)
            else:
                raise Error(f'unknown message kind: {hex(message.kind)}')
    finally:
        heartbeat.cancel()
async def accept(reader, writer, address=None):
    """Register a connection in ``peers``, serve it, and always close it afterwards.

    Used as the start_server callback (inbound: no address) and by dial()
    (outbound: address given). Failures are only logged for outbound peers.
    """
    peer = Peer(reader, writer, address)
    peers.append(peer)
    try:
        await serve(peer)
    except Exception as failure:
        outbound = bool(peer.address)
        if outbound:
            await logger.warning(f'Connection lost to {peer.address}: {failure}')
    finally:
        await close(peer)
async def dial(address):
    """Open an outbound connection to *address* (``host:port``) and serve it.

    The text after the last ``:`` is the port, and a bracketed IPv6 host such
    as ``[::1]:8000`` is unwrapped. If the connection fails, a closed
    placeholder Peer carrying *address* is appended to ``peers`` so watcher()
    dials it again later. A malformed port is logged and not retried.
    """
    host, _, port_text = address.rpartition(':')
    host = host.strip('[]')
    try:
        port = int(port_text)
    except ValueError:
        await logger.error(f'Dial {address}: invalid port `{port_text}\'')
        return
    try:
        reader, writer = await asyncio.open_connection(host, port)
    except Exception as e:
        dummy_peer = Peer(None, None, address)
        dummy_peer.is_open = False
        peers.append(dummy_peer)
        await logger.error(f'Dial {address}: {e}')
        return
    asyncio.create_task(accept(reader, writer, address))
async def listen():
    """Accept peer connections on ListenAddress:ListenPort until cancelled.

    A bind failure is logged and the coroutine returns.
    """
    host = config['ListenAddress']
    port = config['ListenPort']
    try:
        server = await asyncio.start_server(accept, host, int(port))
    except Exception as e:
        await logger.error(f'Bind {host}:{port}: {e}')
        return
    await logger.info(f'Listening at {host}:{port}')
    async with server:
        await server.serve_forever()
async def watcher():
    """Housekeeping loop that runs forever, once every WATCHER_INTERVAL seconds.

    Drops closed peers (re-dialling those that have an address) and expires
    message-cache entries older than CACHE_LIFETIME seconds.
    """
    while True:
        closed = [peer for peer in peers if not peer.is_open]
        # Slice assignment keeps the same list objects the rest of the module
        # uses, and prunes in one O(n) pass instead of an O(n) list.remove()
        # per stale entry.
        peers[:] = [peer for peer in peers if peer.is_open]
        for peer in closed:
            if peer.address:
                asyncio.create_task(dial(peer.address))
        now = time.time()
        cache[:] = [entry for entry in cache if now - entry.ts <= CACHE_LIFETIME]
        await asyncio.sleep(WATCHER_INTERVAL)
async def shutdown(delay):
    """After *delay* seconds, notify peers and stop the process.

    Every open peer gets a CLOSE message (best effort). Then SIGINT is sent to
    this process, repeated once per second until it stops, so asyncio.run()
    unwinds like on Ctrl+C.
    """
    import signal

    await asyncio.sleep(delay)
    await logger.info('Performing graceful shutdown')
    for peer in list(peers):
        if not peer.is_open:
            continue
        try:
            await send(peer, ServiceMessage(ServiceMessage.CLOSE))
        except Exception:
            pass  # best effort: the peer may already be gone
    while True:
        os.kill(os.getpid(), signal.SIGINT)
        # Yield between attempts instead of spinning, so the event loop can
        # act on the interrupt (asyncio.run cancels the main task on SIGINT).
        await asyncio.sleep(1)
async def accept_admin(reader, writer):
    """Handle one request on the admin socket, reply, and close the connection.

    Framing (both directions): a little-endian uint32 length, then that many
    bytes of CBOR. The first recognised key wins:
      * ``{'store': {'piece': bytes}}`` -> ``{'hash': digest}``
      * ``{'query': {'hash': digest}}`` -> ``{'piece': bytes}`` when found
      * ``{'shutdown': {'delay': seconds}}`` -> ``{}`` (shutdown is scheduled)
      * ``{'cleanup': {'age': seconds}}`` -> ``{'removed_count': n}``
    Any failure is logged, and the reply is whatever was built so far
    (normally ``{}``).
    """
    response = {}  # defined up front so a failed read still gets a reply
    try:
        try:
            # readexactly(): read(n) may legally return fewer than n bytes.
            length = struct.unpack('<I', await reader.readexactly(4))[0]
            request = cbor2.loads(await reader.readexactly(length))
            if 'store' in request:
                response['hash'] = await save_piece(request['store']['piece'])
            elif 'query' in request:
                piece = await query(request['query']['hash'])
                if piece:
                    response['piece'] = piece
            elif 'shutdown' in request:
                delay = int(request['shutdown']['delay'])
                await logger.info(f'Requested shutdown in {delay}sec.')
                asyncio.create_task(shutdown(delay))
            elif 'cleanup' in request:
                age = int(request['cleanup']['age'])
                response['removed_count'] = await cleanup_pieces(age)
            else:
                raise Error('unrecognized command')
        except Exception as e:
            await logger.error(f'Process request on admin socket: {e}')
        encoded = cbor2.dumps(response)
        writer.write(struct.pack('<I', len(encoded)))
        writer.write(encoded)
        await writer.drain()
    finally:
        writer.close()
async def listen_admin():
    """Serve admin requests on the Unix socket at AdminSocketPath until cancelled.

    A bind failure is logged and the coroutine returns.
    """
    socket_path = config['AdminSocketPath']
    try:
        server = await asyncio.start_unix_server(accept_admin, socket_path)
    except Exception as e:
        await logger.error(f'Bind {socket_path}: {e}')
        return
    async with server:
        await server.serve_forever()
async def main():
    """Program entry: load the config, prepare storage, then run the node.

    Steps, in order:
      1. load the hjson file named by argv[1] into the global ``config``;
      2. create StoragePath if it is missing;
      3. start watcher();
      4. dial each configured peer in turn;
      5. start the admin socket;
      6. serve the peer listener.
    """
    global config
    if len(sys.argv) < 2:
        print(f'Usage: {sys.argv[0]} <config.conf>')
        return
    config_path = sys.argv[1]
    try:
        async with aiofiles.open(config_path, 'r') as f:
            config = hjson.loads(await f.read())
    except Exception as e:
        await logger.error(f'Load configuration `{config_path}\': {e}')
        return
    storage = config['StoragePath']
    if not await aiofiles.os.path.isdir(storage):
        await aiofiles.os.mkdir(storage)
    asyncio.create_task(watcher())
    for address in config['Peers']:
        await dial(address)
    asyncio.create_task(listen_admin())
    await listen()
# Run the node only when executed as a script, not when imported.
if __name__ == '__main__':
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print('Interrupted.')
|