import argparse
import os
import os.path
import queue
import shlex
import sys
import tempfile
import threading
import time

import cv2
import requests
from Crypto.Hash import BLAKE2b
from loguru import logger

try:
    import rdrand
except ImportError:
    logger.warning("RdRand is not available.")

try:
    from v4l2py.device import Device, VideoCapture
except ImportError:
    logger.warning("v4l2py is not available.")

# Directory for intermediate files (ffmpeg captures, extracted frames);
# None falls back to the system temp dir. Set from --tmp-dir when run as a script.
tmp_dir = None
# Connect/read timeout in seconds for pool uploads; set from --push-timeout.
push_timeout = 10


class TemporaryFile:
    """Handle for a temporary path, optionally wrapping an open file object.

    Attribute lookups not found on the wrapper are forwarded to the wrapped
    file object ``io`` (which may be ``None`` when only a name was reserved).
    When ``delete`` is true the path is unlinked once the wrapper is garbage
    collected.
    """

    def __init__(self, name, io, delete):
        self.name = name
        self.__io = io
        self.__delete = delete

    def __getattr__(self, k):
        # __getattr__ only runs when normal lookup fails. Our own private
        # attributes are missing only if __init__ never ran (or failed);
        # forwarding those would re-enter __getattr__ via self.__io until
        # RecursionError, so report them as plain missing attributes.
        if k.startswith("_TemporaryFile__"):
            raise AttributeError(k)
        return getattr(self.__io, k)

    def __del__(self):
        # Read state through __dict__ so a half-initialised instance is safe.
        state = self.__dict__
        io = state.get("_TemporaryFile__io")
        if io is not None:
            # Close before unlinking: some platforms (e.g. Windows) refuse to
            # remove a file that is still open. Errors cannot usefully
            # propagate out of a finalizer, so a failed flush is ignored.
            try:
                io.close()
            except OSError:
                pass
        if state.get("_TemporaryFile__delete"):
            try:
                os.unlink(self.name)
            except FileNotFoundError:
                pass


def NamedTemporaryFile(
    mode="w+b", bufsize=-1, suffix="", prefix="tmp", dir=None, delete=True
):
    """Create a uniquely named temporary path, optionally opened as a file.

    The name is ``prefix`` + 64 random hex characters + ``suffix`` inside
    ``dir`` (default: ``tempfile.gettempdir()``).

    Args:
        mode: mode passed to ``open``; ``None`` only reserves a name without
            creating the file (used here when ffmpeg writes the file itself).
        bufsize: ``buffering`` argument passed to every ``open`` call.
        suffix, prefix, dir: components of the generated path.
        delete: unlink the path when the returned wrapper is collected.

    Returns:
        A ``TemporaryFile`` wrapping the open file object, or wrapping ``None``
        when ``mode`` is ``None``.
    """
    if not dir:
        dir = tempfile.gettempdir()

    name = os.path.join(dir, prefix + os.urandom(32).hex() + suffix)

    if mode is None:
        return TemporaryFile(name, None, delete)

    # The file is always created in binary read/write mode first, then
    # reopened with the requested mode if that differs.
    fh = open(name, "w+b", bufsize)
    if mode != "w+b":
        fh.close()
        try:
            fh = open(name, mode, bufsize)
        except BaseException:
            # The caller never receives the name, so remove what we created.
            try:
                os.unlink(name)
            except FileNotFoundError:
                pass
            raise

    return TemporaryFile(name, fh, delete)


def chunks(lst, n):
    """Yield successive slices of ``lst``, each ``n`` items long.

    The final slice holds whatever remains and may be shorter than ``n``.
    Slicing preserves the input's type (``bytes`` in, ``bytes`` out).
    """
    yield from (lst[start : start + n] for start in range(0, len(lst), n))


def run_check(cmd):
    """Run ``cmd`` through the shell and raise if it does not exit cleanly.

    ``cmd`` is handed verbatim to ``os.system``, so the shell interprets it:
    callers are responsible for quoting anything they splice in.

    Raises:
        ValueError: the command returned a non-zero status.
    """
    logger.info(f"Executing '{cmd}'.")

    status = os.system(cmd)
    if status:
        raise ValueError("Exit code != 0.")


def extract_image(path):
    """Pack low-order colour bits of every pixel of the image at ``path``.

    Produces one byte per pixel in row-major order. OpenCV loads channels in
    B, G, R order; with c0/c1/c2 denoting channels 0/1/2 the byte layout is::

        bit:  7   6   5   4   3   2   1   0
             c0b3 c0b2 c2b1 c2b0 c1b1 c1b0 c0b1 c0b0

    i.e. the four lowest bits of channel 0 and the two lowest bits of each
    of the other two channels, every bit in its own position.

    Raises:
        ValueError: the image could not be read.
    """
    logger.info(f"Extract image '{path}'.")

    im = cv2.imread(path)
    if im is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Cannot read image '{path}'.")

    c0, c1, c2 = im[:, :, 0], im[:, :, 1], im[:, :, 2]
    # Shift each extracted bit *value* into its own slot. OR-ing ``1 << bit``
    # instead would only ever set bits 0-1 and yield the values 1..3.
    packed = (
        (c0 & 0b11)
        | ((c1 & 0b11) << 2)
        | ((c2 & 0b11) << 4)
        | (((c0 >> 2) & 0b11) << 6)
    )

    return packed.astype("uint8", copy=False).tobytes()


def extract_wav(path):
    """Return the first byte of every 2-byte sample in the file at ``path``.

    The first 44 bytes (presumably a canonical WAV header) are skipped before
    sampling. For little-endian 16-bit PCM the first byte of each sample is
    its low byte.

    NOTE(review): the callers in this file write with ``-f s16le``, which per
    ffmpeg is headerless raw PCM; if so the skip just drops the first 22
    samples, which is harmless for entropy collection.
    """
    logger.info(f"Extract audio: '{path}'.")

    with open(path, "rb") as f:
        pcm = f.read()[44:]

    # Every even offset starts a sample.
    return bytes(pcm[::2])


def extract_video(path):
    """Extract raw entropy bytes from the video frames of the file at ``path``.

    ffmpeg drops near-duplicate frames (``mpdecimate``) and resamples to one
    frame per second, writing numbered BMP stills into a scratch directory.
    Each still is then passed through ``extract_image`` in frame order.

    Returns:
        The concatenation of the per-frame ``extract_image`` outputs.

    Raises:
        ValueError: ffmpeg exited with a non-zero status (from ``run_check``).
    """
    logger.info(f"Extract video: '{path}'.")

    with tempfile.TemporaryDirectory(dir=tmp_dir) as tmpd:
        frame_pattern = os.path.join(tmpd, "%d.bmp")
        # Quote paths: the command runs through the shell, and a temp dir or
        # input path containing spaces or metacharacters would otherwise break it.
        run_check(
            "ffmpeg -hide_banner -loglevel error -y"
            f" -i {shlex.quote(str(path))} -vf mpdecimate -r 1/1"
            f" {shlex.quote(frame_pattern)}"
        )

        # Frames are named by ffmpeg's %d pattern; sort numerically so that
        # "10.bmp" follows "9.bmp".
        frame_files = sorted(
            os.listdir(tmpd), key=lambda filename: int(filename.split(".")[0])
        )

        # Join once instead of repeated ``bytes +=`` (quadratic copying).
        return b"".join(
            extract_image(os.path.join(tmpd, filename)) for filename in frame_files
        )


def extract_lsbs(data):
    """Concentrate the two least-significant bits of every input byte.

    Each group of four input bytes yields one output byte: bits 0-1 of the
    ``i``-th byte of the group land in bits ``2*i`` and ``2*i + 1``. A
    trailing group shorter than four bytes is discarded, so the output is
    ``len(data) // 4`` bytes long.
    """
    logger.info("Extract LSBs.")

    buffer = bytearray()
    for chunk in chunks(data, 4):
        if len(chunk) != 4:
            break  # drop the incomplete trailing group

        packed = 0
        for position, byte in enumerate(chunk):
            # Shift the 2-bit value into its own slot. OR-ing ``1 << bit``
            # would collapse every group to one of only three values.
            packed |= (byte & 0b11) << (2 * position)

        buffer.append(packed)

    return bytes(buffer)


def whiten(data):
    """Compress ``data`` through BLAKE2b to remove bias.

    Every full 256-byte block is hashed with BLAKE2b-256 and the 32-byte
    digests are concatenated (8:1 reduction). A trailing partial block is
    dropped.
    """
    logger.info("Whitening.")

    digests = []
    for chunk in chunks(data, 256):
        if len(chunk) < 256:
            break  # drop the incomplete trailing block

        digests.append(BLAKE2b.new(data=chunk, digest_bits=256).digest())

    # Join once; ``bytes +=`` in the loop recopies the growing buffer each time.
    return b"".join(digests)


def read_video(source, duration=60):
    """Record ``duration`` seconds from ffmpeg input ``source``; extract video entropy.

    Streams are copied without re-encoding into a temporary ``.mkv`` that is
    unlinked when ``tmpf`` is garbage-collected after this function returns.

    Raises:
        ValueError: ffmpeg exited with a non-zero status (from ``run_check``).
    """
    tmpf = NamedTemporaryFile(suffix=".mkv", mode=None, dir=tmp_dir)

    # Quote shell arguments: a URL containing '&' or a temp dir with spaces
    # would otherwise be mangled by the shell.
    run_check(
        "ffmpeg -hide_banner -loglevel error -y"
        f" -i {shlex.quote(str(source))} -t {duration}"
        f" -acodec copy -vcodec copy {shlex.quote(tmpf.name)}"
    )

    return extract_video(tmpf.name)


def read_audio(source, duration=60):
    """Record ``duration`` seconds from ALSA device ``source``; extract audio entropy.

    ffmpeg writes 44.1 kHz signed 16-bit little-endian PCM to a temporary
    file that is unlinked when ``tmpf`` is garbage-collected after return.

    Raises:
        ValueError: ffmpeg exited with a non-zero status (from ``run_check``).
    """
    tmpf = NamedTemporaryFile(suffix=".wav", mode=None, dir=tmp_dir)

    # Quote shell arguments so device names or temp paths with spaces or
    # metacharacters reach ffmpeg intact.
    run_check(
        "ffmpeg -hide_banner -loglevel error -y"
        f" -f alsa -i {shlex.quote(str(source))} -t {duration}"
        f" -ar 44100 -f s16le -acodec pcm_s16le {shlex.quote(tmpf.name)}"
    )

    return extract_wav(tmpf.name)


def read_audio_video(source, duration=60):
    """Record ``duration`` seconds from ``source``; XOR its video and audio entropy.

    The capture is stream-copied into a temporary ``.mkv``. Video bytes come
    from ``extract_video`` on that file; audio bytes come from re-encoding its
    audio track to 44.1 kHz s16le PCM and running ``extract_wav``. The two byte
    strings are XOR-ed position by position and truncated to the shorter one.

    Raises:
        ValueError: an ffmpeg step exited non-zero (from ``run_check``).
    """
    tmpf = NamedTemporaryFile(suffix=".mkv", mode=None, dir=tmp_dir)

    # Quote shell arguments: URLs with '&' or temp paths with spaces would
    # otherwise be mangled by the shell.
    run_check(
        "ffmpeg -hide_banner -loglevel error -y"
        f" -i {shlex.quote(str(source))} -t {duration}"
        f" -acodec copy -vcodec copy {shlex.quote(tmpf.name)}"
    )

    video_bytes = extract_video(tmpf.name)

    tmpf2 = NamedTemporaryFile(suffix=".wav", mode=None, dir=tmp_dir)

    run_check(
        "ffmpeg -hide_banner -loglevel error -y"
        f" -i {shlex.quote(tmpf.name)} -vn -ar 44100 -f s16le -acodec pcm_s16le"
        f" {shlex.quote(tmpf2.name)}"
    )

    audio_bytes = extract_wav(tmpf2.name)

    return bytes(a ^ b for a, b in zip(video_bytes, audio_bytes))


def read_rdseed(_, amount=16):
    """Fetch ``amount`` bytes via ``rdrand.rdseed_get_bytes``.

    The first positional argument is ignored; it exists so this function
    matches the ``sampler(source, amount)`` call shape used by ``sample``.

    Raises:
        ValueError: the result has the wrong length or every byte is zero.
    """
    data = rdrand.rdseed_get_bytes(amount)
    if len(data) == amount and data.count(0) != amount:
        return data
    raise ValueError("bad data")


def sample(source, source_type, multiplier=1):
    """Collect one batch of entropy from ``source``.

    Args:
        source: device or URL handed to the sampler (ignored for ``"rdseed"``).
        source_type: one of ``"video"``, ``"audio"``, ``"video+audio"``,
            ``"rdseed"``.
        multiplier: for the ffmpeg-based types, recording length in minutes
            (scaled to seconds); for ``"rdseed"``, the number of bytes.
            Truncated to an int, which must be at least 1.

    Returns:
        The sampled bytes; for every type except ``"rdseed"`` they are reduced
        with ``extract_lsbs`` and then hashed with ``whiten``.

    Raises:
        ValueError: unknown ``source_type`` or a truncated multiplier below 1.
    """
    if source_type == "video":
        sampler = read_video
    elif source_type == "audio":
        sampler = read_audio
    elif source_type == "video+audio":
        sampler = read_audio_video
    elif source_type == "rdseed":
        sampler = read_rdseed
    else:
        raise ValueError(source_type)

    via_ffmpeg = source_type != "rdseed"
    if via_ffmpeg:
        # ffmpeg samplers take a duration in seconds: one minute per unit.
        multiplier *= 60

    amount = int(multiplier)
    if amount < 1:
        raise ValueError(amount)

    logger.info("Sampling...")
    data = sampler(source, amount)
    logger.info(f"Sample ready: {len(data)}b.")

    if via_ffmpeg:
        data = whiten(extract_lsbs(data))

    return data


def video2_sampler(q, source):
    """Continuously sample a V4L2 device, emitting at most one sample per second.

    Opens device ``source`` through v4l2py and requests YUYV at the first
    frame size the device reports. For each frame arriving at least one
    second after the last emitted one, puts ``whiten(extract_lsbs(frame))``
    on ``q``. Runs until the frame iterator ends or raises.
    """
    with Device.from_id(source) as device:
        capture = VideoCapture(device)
        capture.set_format(
            device.info.frame_sizes[0].width, device.info.frame_sizes[0].height, "YUYV"
        )

        previous = 0
        for frame in device:
            now = time.monotonic()
            if now - previous < 1:
                continue  # throttle: skip frames within a second of the last sample

            sample_bytes = whiten(extract_lsbs(bytes(frame)))
            logger.info(f"Sample ready: {len(sample_bytes)}b.")
            q.put(sample_bytes)

            previous = now


def push(pool_url, data, secret):
    """POST one piece of entropy to ``{pool_url}/api/pool`` and log the reply.

    Authentication travels in the ``X-Secret`` header; the module-level
    ``push_timeout`` is used for both the connect and the read timeout.
    A non-200 status is logged as an error but not raised; exceptions from
    ``requests.post`` propagate to the caller.
    """
    logger.info(f"Pushing {len(data)}b.")

    endpoint = f"{pool_url}/api/pool"
    response = requests.post(
        endpoint,
        data=data,
        headers={"X-Secret": secret},
        timeout=(push_timeout, push_timeout),
    )

    if response.status_code == 200:
        log = logger.success
    else:
        log = logger.error
    log(f"{response.status_code}: {response.text}")


def puller(queue, source, source_type, multiplier, retry_delay=1):
    """Forever sample ``source`` and enqueue the result in pieces of <= 500 KiB.

    Args:
        queue: queue that receives ``bytes`` pieces for the pusher.
        source, source_type, multiplier: forwarded to ``sample``.
        retry_delay: seconds to wait after a failed sample before retrying,
            so a missing device does not respawn ffmpeg in a tight loop.
            ``0`` disables the wait.
    """
    while True:
        try:
            data = sample(source, source_type, multiplier)
        except KeyboardInterrupt:
            logger.info("Interrupted by user.")

            sys.exit(0)
        except Exception as e:
            logger.error(f"Pull exception: {e}")

            if retry_delay:
                time.sleep(retry_delay)

            continue

        for piece in chunks(data, 1024 * 500):
            queue.put(piece)


def pusher(queue, pool_url, secret, cooldown=0, retry_delay=1):
    """Forever take pieces from ``queue`` and upload them with ``push``.

    Args:
        queue: source of ``bytes`` pieces.
        pool_url, secret: forwarded to ``push``.
        cooldown: seconds to sleep after every attempt, successful or not.
        retry_delay: extra seconds to wait after an attempt that raised, so an
            unreachable pool is not hammered in a tight loop. ``0`` disables it.

    A piece whose upload raises is put back on the queue to be retried later.
    """
    while True:
        piece = queue.get()

        try:
            push(pool_url, piece, secret)
        except KeyboardInterrupt:
            logger.info("Interrupted by user.")

            sys.exit(0)
        except Exception as e:
            logger.error(f"Push exception: {e}")

            queue.put(piece)

            if retry_delay:
                time.sleep(retry_delay)

        if cooldown:
            time.sleep(cooldown)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", type=str, required=True)
    parser.add_argument("--source-type", type=str, default="video+audio")
    parser.add_argument("--multiplier", type=float, default=1)
    parser.add_argument("--secret-file", type=str, default="./.secret")
    parser.add_argument("--cooldown", type=int, default=0)
    parser.add_argument("--pool-url", type=str, default="https://yebi.su")
    parser.add_argument("--push-timeout", type=int, default=10)
    parser.add_argument("--tmp-dir", type=str)

    args = parser.parse_args()

    # These assignments rebind the module-level defaults read by the
    # samplers and by push().
    if args.tmp_dir:
        if os.path.isdir(args.tmp_dir):
            tmp_dir = args.tmp_dir

            logger.info(f"Changed temp-dir: '{tmp_dir}'")
        else:
            logger.warning(f"Ignoring --tmp-dir '{args.tmp_dir}': not a directory.")

    push_timeout = max(args.push_timeout, 1)

    # Secret file: first line is the identity, second line the secret.
    with open(args.secret_file, "r") as f:
        lines = f.read().strip().split("\n")

    if len(lines) < 2:
        parser.error(
            f"secret file '{args.secret_file}' needs two lines: ident and secret"
        )

    ident = lines[0].strip()
    secret = f"{ident} {lines[1].strip()}"

    q = queue.Queue()

    # Both workers loop forever; as daemon threads they no longer keep the
    # process alive once the main thread exits (e.g. on Ctrl+C).
    if args.source_type == "video2":
        sampler_th = threading.Thread(
            target=video2_sampler, args=(q, args.source), daemon=True
        )
    else:
        sampler_th = threading.Thread(
            target=puller,
            args=(q, args.source, args.source_type, args.multiplier),
            daemon=True,
        )
    sampler_th.start()

    pusher_th = threading.Thread(
        target=pusher, args=(q, args.pool_url, secret, args.cooldown), daemon=True
    )
    pusher_th.start()

    try:
        pusher_th.join()
    except KeyboardInterrupt:
        logger.info("Interrupted by user.")