markov.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import os.path
  2. import re
  3. import atexit
  4. import string
  5. import threading
  6. import spacy
  7. import ujson
  8. import markovify
  9. from config import config
  10. class Markov:
  11. def __init__(self):
  12. self.rebuilding = False
  13. self.counter = 0
  14. self.corpus = []
  15. self.chain = None
  16. self.nlp = spacy.load("xx_sent_ud_sm")
  17. self.load()
  18. atexit.register(self.save)
  19. @property
  20. def is_ready(self):
  21. return self.chain is not None
  22. def generate(self, init_state=None):
  23. orig_init_state = init_state
  24. if init_state is not None:
  25. init_state = self.tokenize(init_state)
  26. init_state = tuple(init_state)
  27. size = len(init_state)
  28. if size < config.MARKOV_STATE_SIZE:
  29. init_state = (markovify.chain.BEGIN,) * (
  30. config.MARKOV_STATE_SIZE - size
  31. ) + init_state
  32. elif size > config.MARKOV_STATE_SIZE:
  33. init_state = init_state[: -config.MARKOV_STATE_SIZE]
  34. words = self.chain.walk(init_state)
  35. if not words:
  36. return self.generate(init_state)
  37. text = orig_init_state if orig_init_state is not None else ""
  38. for word in words:
  39. if word in "-–—" or not all(
  40. c in string.punctuation or c == "…" for c in word
  41. ):
  42. text += " "
  43. text += word
  44. return text.strip()
  45. def _rebuild(self):
  46. try:
  47. self.chain = markovify.Chain(self.corpus, config.MARKOV_STATE_SIZE).compile()
  48. finally:
  49. self.rebuilding = False
  50. def rebuild(self):
  51. if self.rebuilding:
  52. return
  53. self.rebuilding = True
  54. self.counter = 0
  55. t = threading.Thread(target=self._rebuild)
  56. t.start()
  57. def tokenize(self, text):
  58. text = re.sub(r"(@[A-Za-z0-9_]+,?)", "", text)
  59. text = re.sub(
  60. "https?:\\/\\/(?:www\\.)?[-a-zA-Z0-9@:%._\\+~#=]{1,256}\\.[a-zA-Z0-9()]{1,6}\\b(?:[-a-zA-Z0-9()@:%_\\+.~#?&\\/=]*)",
  61. "",
  62. text,
  63. )
  64. text = self.nlp(text)
  65. text = map(lambda word: str(word).strip(), text)
  66. text = filter(bool, text)
  67. return list(text)
  68. def extend_corpus(self, text):
  69. text = text.strip()
  70. if not text:
  71. return
  72. if "\n" in text:
  73. for line in text.split("\n"):
  74. self.extend_corpus(line)
  75. return
  76. text = self.tokenize(text)
  77. if text not in self.corpus:
  78. self.corpus.insert(0, text)
  79. if (
  80. config.MARKOV_CORPUS_SIZE > 0
  81. and len(self.corpus) > config.MARKOV_CORPUS_SIZE
  82. ):
  83. self.corpus = self.corpus[: config.MARKOV_CORPUS_SIZE]
  84. self.counter += 1
  85. if (
  86. config.MARKOV_REBUILD_RATE > 0
  87. and self.counter % config.MARKOV_REBUILD_RATE == 0
  88. ):
  89. self.rebuild()
  90. def load(self):
  91. if os.path.isfile(config.MARKOV_CHAIN_PATH):
  92. with open(config.MARKOV_CHAIN_PATH, "r") as f:
  93. self.chain = markovify.Chain.from_json(f.read())
  94. if os.path.isfile(config.MARKOV_CORPUS_PATH):
  95. with open(config.MARKOV_CORPUS_PATH, "r") as f:
  96. self.corpus = ujson.load(f)
  97. def save(self):
  98. if self.chain:
  99. with open(config.MARKOV_CHAIN_PATH, "w") as f:
  100. f.write(self.chain.to_json())
  101. with open(config.MARKOV_CORPUS_PATH, "w") as f:
  102. ujson.dump(self.corpus, f)