markov.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import os.path
  2. import re
  3. import atexit
  4. import string
  5. import threading
  6. import spacy
  7. import ujson
  8. import markovify
  9. from config import config
  10. class Markov:
  11. def __init__(self):
  12. self.counter = 0
  13. self.corpus = []
  14. self.chain = None
  15. self.nlp = spacy.load("xx_sent_ud_sm")
  16. self.load()
  17. atexit.register(self.save)
  18. @property
  19. def is_ready(self):
  20. return self.chain is not None
  21. def generate(self, init_state=None):
  22. orig_init_state = init_state
  23. if init_state is not None:
  24. init_state = self.tokenize(init_state)
  25. init_state = tuple(init_state)
  26. size = len(init_state)
  27. if size < config.MARKOV_STATE_SIZE:
  28. init_state = (markovify.chain.BEGIN,) * (
  29. config.MARKOV_STATE_SIZE - size
  30. ) + init_state
  31. elif size > config.MARKOV_STATE_SIZE:
  32. init_state = init_state[: -config.MARKOV_STATE_SIZE]
  33. words = self.chain.walk(init_state)
  34. if not words:
  35. return self.generate(init_state)
  36. text = orig_init_state if orig_init_state is not None else ""
  37. for word in words:
  38. if word in "-–—" or not all(
  39. c in string.punctuation or c == "…" for c in word
  40. ):
  41. text += " "
  42. text += word
  43. return text.strip()
  44. def _rebuild(self):
  45. self.chain = markovify.Chain(self.corpus, config.MARKOV_STATE_SIZE).compile()
  46. def rebuild(self):
  47. self.counter = 0
  48. t = threading.Thread(target=self._rebuild)
  49. t.start()
  50. def tokenize(self, text):
  51. text = re.sub(r"(@[A-Za-z0-9_]+,?)", "", text)
  52. text = re.sub(
  53. "https?:\\/\\/(?:www\\.)?[-a-zA-Z0-9@:%._\\+~#=]{1,256}\\.[a-zA-Z0-9()]{1,6}\\b(?:[-a-zA-Z0-9()@:%_\\+.~#?&\\/=]*)",
  54. "",
  55. text,
  56. )
  57. text = self.nlp(text)
  58. text = map(lambda word: str(word).strip(), text)
  59. text = filter(bool, text)
  60. return list(text)
  61. def extend_corpus(self, text):
  62. text = text.strip()
  63. if not text:
  64. return
  65. if "\n" in text:
  66. for line in text.split("\n"):
  67. self.extend_corpus(line)
  68. return
  69. text = self.tokenize(text)
  70. if text not in self.corpus:
  71. self.corpus.insert(0, text)
  72. if (
  73. config.MARKOV_CORPUS_SIZE > 0
  74. and len(self.corpus) > config.MARKOV_CORPUS_SIZE
  75. ):
  76. self.corpus = self.corpus[: config.MARKOV_CORPUS_SIZE]
  77. self.counter += 1
  78. if (
  79. config.MARKOV_REBUILD_RATE > 0
  80. and self.counter % config.MARKOV_REBUILD_RATE == 0
  81. ):
  82. self.rebuild()
  83. def load(self):
  84. if os.path.isfile(config.MARKOV_CHAIN_PATH):
  85. with open(config.MARKOV_CHAIN_PATH, "r") as f:
  86. self.chain = markovify.Chain.from_json(f.read())
  87. if os.path.isfile(config.MARKOV_CORPUS_PATH):
  88. with open(config.MARKOV_CORPUS_PATH, "r") as f:
  89. self.corpus = ujson.load(f)
  90. def save(self):
  91. if self.chain:
  92. with open(config.MARKOV_CHAIN_PATH, "w") as f:
  93. f.write(self.chain.to_json())
  94. with open(config.MARKOV_CORPUS_PATH, "w") as f:
  95. ujson.dump(self.corpus, f)