# markov.py
import atexit
import os.path
import re
import string

import markovify
import spacy
import ujson

from config import config
  9. class Markov:
  10. def __init__(self):
  11. self.counter = 0
  12. self.corpus = []
  13. self.chain = None
  14. self.nlp = spacy.load("xx_sent_ud_sm")
  15. self.load()
  16. atexit.register(self.save)
  17. @property
  18. def is_ready(self):
  19. return self.chain is not None
  20. def generate(self, init_state=None):
  21. orig_init_state = init_state
  22. if init_state is not None:
  23. init_state = self.tokenize(init_state)
  24. init_state = tuple(init_state)
  25. size = len(init_state)
  26. if size < config.MARKOV_STATE_SIZE:
  27. init_state = (markovify.chain.BEGIN,) * (
  28. config.MARKOV_STATE_SIZE - size
  29. ) + init_state
  30. elif size > config.MARKOV_STATE_SIZE:
  31. init_state = init_state[: -config.MARKOV_STATE_SIZE]
  32. words = self.chain.walk(init_state)
  33. if not words:
  34. return self.generate(init_state)
  35. text = orig_init_state if orig_init_state is not None else ""
  36. for word in words:
  37. if word in "-–—" or not all(
  38. c in string.punctuation or c == "…" for c in word
  39. ):
  40. text += " "
  41. text += word
  42. return text.strip()
  43. def rebuild(self):
  44. self.chain = markovify.Chain(self.corpus, config.MARKOV_STATE_SIZE).compile()
  45. self.counter = 0
  46. def tokenize(self, text):
  47. text = re.sub(r"(@[A-Za-z0-9_]+,?)", "", text)
  48. text = re.sub(
  49. "https?:\\/\\/(?:www\\.)?[-a-zA-Z0-9@:%._\\+~#=]{1,256}\\.[a-zA-Z0-9()]{1,6}\\b(?:[-a-zA-Z0-9()@:%_\\+.~#?&\\/=]*)",
  50. "",
  51. text,
  52. )
  53. text = self.nlp(text)
  54. text = map(lambda word: str(word).strip(), text)
  55. text = filter(bool, text)
  56. return list(text)
  57. def extend_corpus(self, text):
  58. text = text.strip()
  59. if not text:
  60. return
  61. if "\n" in text:
  62. for line in text.split("\n"):
  63. self.extend_corpus(line)
  64. return
  65. text = self.tokenize(text)
  66. if text not in self.corpus:
  67. self.corpus.insert(0, text)
  68. if (
  69. config.MARKOV_CORPUS_SIZE > 0
  70. and len(self.corpus) > config.MARKOV_CORPUS_SIZE
  71. ):
  72. self.corpus = self.corpus[: config.MARKOV_CORPUS_SIZE]
  73. self.counter += 1
  74. if (
  75. config.MARKOV_REBUILD_RATE > 0
  76. and self.counter % config.MARKOV_REBUILD_RATE == 0
  77. ):
  78. self.rebuild()
  79. def load(self):
  80. if os.path.isfile(config.MARKOV_CHAIN_PATH):
  81. with open(config.MARKOV_CHAIN_PATH, "r") as f:
  82. self.chain = markovify.Chain.from_json(f.read())
  83. if os.path.isfile(config.MARKOV_CORPUS_PATH):
  84. with open(config.MARKOV_CORPUS_PATH, "r") as f:
  85. self.corpus = ujson.load(f)
  86. def save(self):
  87. if self.chain:
  88. with open(config.MARKOV_CHAIN_PATH, "w") as f:
  89. f.write(self.chain.to_json())
  90. with open(config.MARKOV_CORPUS_PATH, "w") as f:
  91. ujson.dump(self.corpus, f)