txlyre 1 月之前
父节点
当前提交
569a3fc6da
共有 3 个文件被更改,包括 317 次插入303 次删除
  1. 290 290
      actions.py
  2. 8 1
      commands.py
  3. 19 12
      markov.py

+ 290 - 290
actions.py

@@ -1,291 +1,291 @@
-from random import uniform
-from asyncio import sleep
-
-from tortoise.contrib.postgres.functions import Random
-from telethon.utils import get_input_document
-from telethon.tl.functions.stickers import (
-    CreateStickerSetRequest,
-    AddStickerToSetRequest,
-)
-from telethon.tl.functions.messages import UploadMediaRequest, GetStickerSetRequest
-from telethon.tl.types import (
-    InputStickerSetID,
-    InputStickerSetShortName,
-    InputStickerSetItem,
-    InputMediaUploadedDocument,
-    InputPeerSelf,
-)
-from tortoise.expressions import F
-
-from models import (
-    Action,
-    Gif,
-    StickerPack,
-    Admin,
-    BirthDay,
-    VPNServer,
-    AllowedChat,
-    MarkovChat,
-)
-from utils import is_valid_name, is_valid_ip
-from config import config
-
-
-async def is_admin(bot, user):
-    admin = await bot.get_entity(config.ADMIN)
-
-    if user.id == admin.id:
-        return True
-
-    admin = await Admin.filter(user_id=user.id)
-    if admin:
-        return True
-
-    return False
-
-
-async def add_admin(user):
-    await Admin(user_id=user.id).save()
-
-
-async def delete_admin(user):
-    admin = await Admin.filter(user_id=user.id).first()
-    if not admin:
-        raise IndexError
-
-    await admin.delete()
-
-
-async def create_action(name, template, kind):
-    if not is_valid_name(name):
-        raise SyntaxError
-
-    await Action(name=name, template=template, kind=kind).save()
-
-
-async def find_action(name):
-    if not is_valid_name(name):
-        raise SyntaxError
-
-    return await Action.filter(name=name).first()
-
-
-async def delete_action(name):
-    action = await find_action(name)
-    if not action:
-        raise NameError
-
-    gifs = await action.gifs.all()
-    for gif in gifs:
-        await gif.delete()
-
-    await action.delete()
-
-
-async def add_gif(action, file_id):
-    await Gif(action=action, file_id=file_id).save()
-
-
-async def get_random_gif(action):
-    return await action.gifs.all().annotate(order=Random()).order_by("order").first()
-
-
-async def create_new_pack(bot, sticker):
-    last_pack = await StickerPack.all().order_by("-id").first()
-    set_id = last_pack.id + 1 if last_pack else 1
-
-    user = await bot.get_entity(config.USER)
-
-    me = await bot.get_me()
-    bot_username = me.username
-
-    pack = await bot(
-        CreateStickerSetRequest(
-            user_id=user.id,
-            title=f"Messages #{set_id}.",
-            short_name=f"messages{set_id}_by_{bot_username}",
-            stickers=[sticker],
-        )
-    )
-
-    sid = pack.set.id
-    hash = pack.set.access_hash
-
-    await StickerPack(short_name=pack.set.short_name, sid=sid, hash=hash).save()
-
-    return pack
-
-
-async def get_current_pack(bot):
-    pack = await StickerPack.all().order_by("-id").first()
-    if not pack or pack.stickers_count >= 119:
-        return None
-
-    return pack
-
-
-async def add_sticker(bot, file, emoji):
-    file = await bot.upload_file(file)
-    file = InputMediaUploadedDocument(file, "image/png", [])
-    file = await bot(UploadMediaRequest(InputPeerSelf(), file))
-    file = get_input_document(file)
-    sticker = InputStickerSetItem(document=file, emoji=emoji)
-
-    pack = await get_current_pack(bot)
-    if not pack:
-        pack = await create_new_pack(bot, sticker)
-    else:
-        await StickerPack.filter(id=pack.id).update(
-            stickers_count=F("stickers_count") + 1
-        )
-
-        pack = await bot(
-            AddStickerToSetRequest(
-                stickerset=InputStickerSetID(id=pack.sid, access_hash=pack.hash),
-                sticker=sticker,
-            )
-        )
-
-    return get_input_document(pack.documents[-1])
-
-
-async def get_birthdays(peer_id):
-    return await BirthDay.filter(peer_id=peer_id).all()
-
-
-async def add_or_update_birthday(peer_id, user, date):
-    birthday = await BirthDay.filter(peer_id=peer_id, user_id=user.id).first()
-
-    if birthday:
-        await BirthDay.filter(id=birthday.id).update(date=date)
-
-        return False
-
-    await BirthDay(peer_id=peer_id, user_id=user.id, date=date).save()
-
-    return True
-
-
-async def get_all_birthdays():
-    return await BirthDay.all()
-
-
-async def add_server(name, ip):
-    if not is_valid_ip(ip):
-        raise ValueError
-
-    await VPNServer(name=name, ip=ip).save()
-
-
-async def add_server(name, ip):
-    if not is_valid_name(name):
-        raise SyntaxError
-
-    if not is_valid_ip(ip):
-        raise ValueError
-
-    await VPNServer(name=name, ip=ip).save()
-
-
-async def delete_server(name):
-    if not is_valid_name(name):
-        raise SyntaxError
-
-    server = await VPNServer.filter(name=name).first()
-
-    if not server:
-        raise IndexError
-
-    await server.delete()
-
-
-async def list_servers():
-    servers = await VPNServer.all()
-
-    if not servers:
-        return "*пусто*"
-
-    return ", ".join(map(lambda server: server.name, servers))
-
-
-async def get_server_ip(name):
-    if not is_valid_name(name):
-        raise SyntaxError
-
-    server = await VPNServer.filter(name=name).first()
-
-    if not server:
-        raise IndexError
-
-    return server.ip
-
-
-async def add_allowed(peer_id):
-    await AllowedChat(peer_id=peer_id).save()
-
-
-async def delete_allowed(peer_id):
-    chat = await AllowedChat.filter(peer_id=peer_id).first()
-
-    if not chat:
-        raise IndexError
-
-    await chat.delete()
-
-
-async def is_allowed(peer_id):
-    return await AllowedChat.filter(peer_id=peer_id).exists()
-
-
-async def is_markov_enabled(peer_id):
-    return await MarkovChat.filter(peer_id=peer_id).exists()
-
-
-async def enable_markov(peer_id):
-    await MarkovChat(peer_id=peer_id).save()
-
-
-async def set_markov_options(peer_id, **options):
-    chat = await MarkovChat.filter(peer_id=peer_id).first()
-
-    if not chat:
-        raise IndexError
-
-    await MarkovChat.filter(id=chat.id).update(**options)
-
-
-async def get_markov_option(peer_id, option):
-    chat = await MarkovChat.filter(peer_id=peer_id).first()
-
-    if not chat:
-        raise IndexError
-
-    return getattr(chat, option)
-
-
-async def disable_markov(peer_id):
-    chat = await MarkovChat.filter(peer_id=peer_id).first()
-
-    if not chat:
-        raise IndexError
-
-    await chat.delete()
-
-
-async def list_markov_chats():
-    return await MarkovChat.all()
-
-async def markov_say(bot, peer_id, reply_to=None):
-    if not bot.markov.is_ready:
-        return
-
-    text = bot.markov.generate()
-
-    async with bot.action(peer_id, "typing"):
-        amount = 0
-        for _ in range(len(text)):
-            amount += round(uniform(0.05, 0.2), 2)
-
-        await sleep(min(amount, 8))
-
+from random import uniform
+from asyncio import sleep
+
+from tortoise.contrib.postgres.functions import Random
+from telethon.utils import get_input_document
+from telethon.tl.functions.stickers import (
+    CreateStickerSetRequest,
+    AddStickerToSetRequest,
+)
+from telethon.tl.functions.messages import UploadMediaRequest, GetStickerSetRequest
+from telethon.tl.types import (
+    InputStickerSetID,
+    InputStickerSetShortName,
+    InputStickerSetItem,
+    InputMediaUploadedDocument,
+    InputPeerSelf,
+)
+from tortoise.expressions import F
+
+from models import (
+    Action,
+    Gif,
+    StickerPack,
+    Admin,
+    BirthDay,
+    VPNServer,
+    AllowedChat,
+    MarkovChat,
+)
+from utils import is_valid_name, is_valid_ip
+from config import config
+
+
+async def is_admin(bot, user):
+    admin = await bot.get_entity(config.ADMIN)
+
+    if user.id == admin.id:
+        return True
+
+    admin = await Admin.filter(user_id=user.id)
+    if admin:
+        return True
+
+    return False
+
+
+async def add_admin(user):
+    await Admin(user_id=user.id).save()
+
+
+async def delete_admin(user):
+    admin = await Admin.filter(user_id=user.id).first()
+    if not admin:
+        raise IndexError
+
+    await admin.delete()
+
+
+async def create_action(name, template, kind):
+    if not is_valid_name(name):
+        raise SyntaxError
+
+    await Action(name=name, template=template, kind=kind).save()
+
+
+async def find_action(name):
+    if not is_valid_name(name):
+        raise SyntaxError
+
+    return await Action.filter(name=name).first()
+
+
+async def delete_action(name):
+    action = await find_action(name)
+    if not action:
+        raise NameError
+
+    gifs = await action.gifs.all()
+    for gif in gifs:
+        await gif.delete()
+
+    await action.delete()
+
+
+async def add_gif(action, file_id):
+    await Gif(action=action, file_id=file_id).save()
+
+
+async def get_random_gif(action):
+    return await action.gifs.all().annotate(order=Random()).order_by("order").first()
+
+
+async def create_new_pack(bot, sticker):
+    last_pack = await StickerPack.all().order_by("-id").first()
+    set_id = last_pack.id + 1 if last_pack else 1
+
+    user = await bot.get_entity(config.USER)
+
+    me = await bot.get_me()
+    bot_username = me.username
+
+    pack = await bot(
+        CreateStickerSetRequest(
+            user_id=user.id,
+            title=f"Messages #{set_id}.",
+            short_name=f"messages{set_id}_by_{bot_username}",
+            stickers=[sticker],
+        )
+    )
+
+    sid = pack.set.id
+    hash = pack.set.access_hash
+
+    await StickerPack(short_name=pack.set.short_name, sid=sid, hash=hash).save()
+
+    return pack
+
+
+async def get_current_pack(bot):
+    pack = await StickerPack.all().order_by("-id").first()
+    if not pack or pack.stickers_count >= 119:
+        return None
+
+    return pack
+
+
+async def add_sticker(bot, file, emoji):
+    file = await bot.upload_file(file)
+    file = InputMediaUploadedDocument(file, "image/png", [])
+    file = await bot(UploadMediaRequest(InputPeerSelf(), file))
+    file = get_input_document(file)
+    sticker = InputStickerSetItem(document=file, emoji=emoji)
+
+    pack = await get_current_pack(bot)
+    if not pack:
+        pack = await create_new_pack(bot, sticker)
+    else:
+        await StickerPack.filter(id=pack.id).update(
+            stickers_count=F("stickers_count") + 1
+        )
+
+        pack = await bot(
+            AddStickerToSetRequest(
+                stickerset=InputStickerSetID(id=pack.sid, access_hash=pack.hash),
+                sticker=sticker,
+            )
+        )
+
+    return get_input_document(pack.documents[-1])
+
+
+async def get_birthdays(peer_id):
+    return await BirthDay.filter(peer_id=peer_id).all()
+
+
+async def add_or_update_birthday(peer_id, user, date):
+    birthday = await BirthDay.filter(peer_id=peer_id, user_id=user.id).first()
+
+    if birthday:
+        await BirthDay.filter(id=birthday.id).update(date=date)
+
+        return False
+
+    await BirthDay(peer_id=peer_id, user_id=user.id, date=date).save()
+
+    return True
+
+
+async def get_all_birthdays():
+    return await BirthDay.all()
+
+
+async def add_server(name, ip):
+    if not is_valid_ip(ip):
+        raise ValueError
+
+    await VPNServer(name=name, ip=ip).save()
+
+
+async def add_server(name, ip):
+    if not is_valid_name(name):
+        raise SyntaxError
+
+    if not is_valid_ip(ip):
+        raise ValueError
+
+    await VPNServer(name=name, ip=ip).save()
+
+
+async def delete_server(name):
+    if not is_valid_name(name):
+        raise SyntaxError
+
+    server = await VPNServer.filter(name=name).first()
+
+    if not server:
+        raise IndexError
+
+    await server.delete()
+
+
+async def list_servers():
+    servers = await VPNServer.all()
+
+    if not servers:
+        return "*пусто*"
+
+    return ", ".join(map(lambda server: server.name, servers))
+
+
+async def get_server_ip(name):
+    if not is_valid_name(name):
+        raise SyntaxError
+
+    server = await VPNServer.filter(name=name).first()
+
+    if not server:
+        raise IndexError
+
+    return server.ip
+
+
+async def add_allowed(peer_id):
+    await AllowedChat(peer_id=peer_id).save()
+
+
+async def delete_allowed(peer_id):
+    chat = await AllowedChat.filter(peer_id=peer_id).first()
+
+    if not chat:
+        raise IndexError
+
+    await chat.delete()
+
+
+async def is_allowed(peer_id):
+    return await AllowedChat.filter(peer_id=peer_id).exists()
+
+
+async def is_markov_enabled(peer_id):
+    return await MarkovChat.filter(peer_id=peer_id).exists()
+
+
+async def enable_markov(peer_id):
+    await MarkovChat(peer_id=peer_id).save()
+
+
+async def set_markov_options(peer_id, **options):
+    chat = await MarkovChat.filter(peer_id=peer_id).first()
+
+    if not chat:
+        raise IndexError
+
+    await MarkovChat.filter(id=chat.id).update(**options)
+
+
+async def get_markov_option(peer_id, option):
+    chat = await MarkovChat.filter(peer_id=peer_id).first()
+
+    if not chat:
+        raise IndexError
+
+    return getattr(chat, option)
+
+
+async def disable_markov(peer_id):
+    chat = await MarkovChat.filter(peer_id=peer_id).first()
+
+    if not chat:
+        raise IndexError
+
+    await chat.delete()
+
+
+async def list_markov_chats():
+    return await MarkovChat.all()
+
+async def markov_say(bot, peer_id, reply_to=None, init_state=None):
+    if not bot.markov.is_ready:
+        return
+
+    text = bot.markov.generate(init_state)
+
+    async with bot.action(peer_id, "typing"):
+        amount = 0
+        for _ in range(len(text)):
+            amount += round(uniform(0.05, 0.2), 2)
+
+        await sleep(min(amount, 8))
+
     await bot.send_message(peer_id, message=text, reply_to=reply_to)

+ 8 - 1
commands.py

@@ -661,8 +661,15 @@ async def say_handler(bot, event, command):
     if not bot.markov.is_ready:
         await event.reply("Генератор текста ещё не готов к использованию. Пожалуйста, попробуйте чуть позже.")
     else:
-        await markov_say(bot, get_peer_id(event.peer_id))
+        init_state = None
 
+        if command.argc > 0:
+          init_state = command.args_string
+
+        try:
+            await markov_say(bot, get_peer_id(event.peer_id), init_state=init_state)
+        except:
+            await event.reply("Ошибка :(")
 
 COMMANDS = {
     "newadmin": Handler(newadmin_handler, is_restricted=True),

+ 19 - 12
markov.py

@@ -27,8 +27,11 @@ class Markov:
     def is_ready(self):
         return self.chain is not None
 
-    def generate(self):
-        words = self.chain.walk()
+    def generate(self, init_state=None):
+        if isinstance(init_state, str):
+            init_state = self.tokenize(init_state)
+
+        words = self.chain.walk(init_state)
         if not words:
             return self.generate()
 
@@ -46,6 +49,19 @@ class Markov:
 
         self.counter = 0
 
+    def tokenize(self, text):
+        text = re.sub(r"(@[A-Za-z0-9_]+,?)", "", text)
+        text = re.sub(
+            "https?:\\/\\/(?:www\\.)?[-a-zA-Z0-9@:%._\\+~#=]{1,256}\\.[a-zA-Z0-9()]{1,6}\\b(?:[-a-zA-Z0-9()@:%_\\+.~#?&\\/=]*)",
+            "",
+            text,
+        )
+        text = self.nlp(text)
+        text = map(lambda word: str(word).strip(), text)
+        text = filter(bool, text)
+
+        return list(text)
+
     def extend_corpus(self, text):
         text = text.strip()
         if not text:
@@ -57,16 +73,7 @@ class Markov:
 
             return
 
-        text = re.sub(r"(@[A-Za-z0-9_]+,?)", "", text)
-        text = re.sub(
-            "https?:\\/\\/(?:www\\.)?[-a-zA-Z0-9@:%._\\+~#=]{1,256}\\.[a-zA-Z0-9()]{1,6}\\b(?:[-a-zA-Z0-9()@:%_\\+.~#?&\\/=]*)",
-            "",
-            text,
-        )
-        text = self.nlp(text)
-        text = map(lambda word: str(word).strip(), text)
-        text = filter(bool, text)
-        text = list(text)
+        text = self.tokenize(text)
 
         if text not in self.corpus:
             self.corpus.insert(0, text)