wma.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. #!/usr/bin/python
  2. import sys
  3. import struct
  4. import lark
  5. GRAMMAR = r"""
  6. start: _NL? command ((_NL+|";") command)* _NL?
  7. command: LABEL? operation
  8. | LABEL (arg|mixed)
  9. | LABEL
  10. | INCLUDE
  11. ?operation: "nop" -> nop
  12. | ("mvj"|"mj") arg arg arg -> mj
  13. | "sblez" arg arg arg -> sjlez
  14. | "ablez" arg arg arg -> ajlez
  15. | "sblz" arg arg arg -> sjlz
  16. | "bles" arg arg arg -> jles
  17. | ("nbnz"|"tjt") arg arg -> tjt
  18. | ("dbnz"|"djt") arg arg -> djt
  19. | "sslez" arg arg -> sslez
  20. | "aslez" arg arg -> aslez
  21. | ("ibnc"|"ije") arg arg arg -> ije
  22. | "vblz" arg arg arg -> djlz
  23. | "xblz" arg arg arg -> xjlz
  24. | "dslz" arg -> dslz
  25. | ("ssgt"|"ssl") arg arg -> ssl
  26. | "mbnz" arg arg arg -> mbnz
  27. | "modbz" arg arg arg -> modbz
  28. | ("aja"|"aj") arg arg -> aj
  29. | "la" arg arg -> la
  30. | "ld" arg arg -> ld
  31. | "ia" arg -> ia
  32. | "jmc" arg -> jmc
  33. | ("jw"|"ja") arg -> ja
  34. | ("push"|"psh") arg -> psh
  35. | "pd" arg -> pd
  36. | "pop" arg -> pop
  37. | "shlbnz" arg arg arg -> shlbnz
  38. | "shrbnz" arg arg arg -> shrbnz
  39. | "nbz" arg arg -> nbz
  40. | "anz" arg arg arg -> anz
  41. | "abgz" arg arg arg -> abgz
  42. | "swp" arg arg -> swp
  43. | "add" arg arg -> h_add
  44. | "sub" arg arg -> h_sub
  45. | "inc" arg -> h_inc
  46. | "dec" arg -> h_dec
  47. | ("mov"|"mv") arg arg -> h_mov
  48. | ("jmp"|"j") arg -> h_jmp
  49. | "hlt" -> h_halt
  50. | "out" arg -> h_out
  51. | "outn" arg -> h_outn
  52. | "in" arg -> h_in
  53. mixed: arg arg+
  54. ?arg: INTEGER
  55. | DOUBLE
  56. | CHAR
  57. | CHARS
  58. | OFFSET
  59. | LABELOFFSET
  60. | QMARK
  61. | NAMEOFFSET
  62. | NAME
  63. | rep
  64. rep: arg "*" COUNT
  65. COUNT: /[0-9]+/
  66. INTEGER: /-?[0-9]+/
  67. DOUBLE: /-?[0-9]+\.[0-9]+/
  68. CHAR: "'" /./ "'"
  69. CHARS: "\"" /[^"]*/ "\""
  70. QMARK: "?"
  71. OFFSET: "?" /(-|\+)[0-9]+/
  72. LABELOFFSET: "?" /(-|\+)[A-Za-z][a-zA-Z0-9_]*/
  73. NAMEOFFSET: /[A-Za-z][a-zA-Z0-9_]*(-|\+)[0-9]+/
  74. LABEL: /[A-Za-z][a-zA-Z0-9_]*:/
  75. NAME: /[A-Za-z][a-zA-Z0-9_]*/
  76. INCLUDE: "+" /.+/
  77. _NL: /\n+/
  78. IG: /[ \t\r]+/
  79. COM: /#.*[^\n]/
  80. %ignore IG
  81. %ignore COM
  82. """
  83. class WMA:
  84. def __init__(self):
  85. self.buffer = []
  86. self.size = 0
  87. self.parser = lark.Lark(GRAMMAR)
  88. self.used_regs = []
  89. def emit(self, *ops):
  90. self.buffer.extend(ops)
  91. if type(ops[0]) is tuple and ops[0][0]:
  92. return
  93. self.size += len(ops)
  94. def compile_arg(self, arg):
  95. if type(arg) is lark.Tree:
  96. if arg.data == "mixed":
  97. for subnode in arg.children:
  98. self.compile_arg(subnode)
  99. elif arg.data == "rep":
  100. count = int(arg.children[1].value)
  101. for _ in range(count):
  102. self.compile_arg(arg.children[0])
  103. elif arg.type == "INTEGER":
  104. self.emit(int(arg.value))
  105. elif arg.type == "DOUBLE":
  106. self.emit(float(arg.value))
  107. elif arg.type == "CHAR":
  108. self.emit(ord(arg.value[1]))
  109. elif arg.type == "CHARS":
  110. for char in arg.value[1:-1]:
  111. self.emit(ord(char))
  112. elif arg.type == "QMARK":
  113. self.emit(self.size+1)
  114. elif arg.type == "OFFSET":
  115. self.emit(self.size+int(arg.value[1:])+1)
  116. elif arg.type == "LABELOFFSET":
  117. self.emit((False, arg.value[2:], arg.value[1]))
  118. elif arg.type == "NAMEOFFSET":
  119. n, o = arg.value.split(
  120. '+' if '+' in arg.value else '-'
  121. )
  122. self.emit((False, n, '+' if '+' in arg.value else '-', int(o)))
  123. elif arg.type == "NAME":
  124. if arg.value == "IO":
  125. self.emit(-1)
  126. elif arg.value == "Z":
  127. self.emit(-2)
  128. elif arg.value == "O":
  129. self.emit(-3)
  130. elif arg.value == "N":
  131. self.emit(-4)
  132. elif arg.value == "J":
  133. self.emit(-5)
  134. elif arg.value == "T":
  135. self.emit(-6)
  136. elif arg.value == "SP":
  137. self.emit(-7)
  138. elif arg.value == "EZ":
  139. self.emit(-8)
  140. elif arg.value == "SZ":
  141. self.emit(-9)
  142. elif arg.value == "MZ":
  143. self.emit(-10)
  144. elif arg.value == "JZ":
  145. self.emit(-11)
  146. elif arg.value == "W":
  147. self.emit(-12)
  148. elif arg.value == "MM":
  149. self.emit(-13)
  150. elif arg.value == "DR":
  151. self.emit(-14)
  152. elif arg.value == "ZZ":
  153. self.emit(-15)
  154. else:
  155. self.used_regs.append(arg.value)
  156. self.emit((False, arg.value))
  157. def compile_operation(self, op):
  158. if op.data == "nop":
  159. self.emit(0)
  160. elif op.data == "mj":
  161. self.emit(1)
  162. self.compile_arg(op.children[0])
  163. self.compile_arg(op.children[1])
  164. self.compile_arg(op.children[2])
  165. elif op.data == "sjlez":
  166. self.emit(2)
  167. self.compile_arg(op.children[0])
  168. self.compile_arg(op.children[1])
  169. self.compile_arg(op.children[2])
  170. elif op.data == "ajlez":
  171. self.emit(3)
  172. self.compile_arg(op.children[0])
  173. self.compile_arg(op.children[1])
  174. self.compile_arg(op.children[2])
  175. elif op.data == "sjlz":
  176. self.emit(4)
  177. self.compile_arg(op.children[0])
  178. self.compile_arg(op.children[1])
  179. self.compile_arg(op.children[2])
  180. elif op.data == "jles":
  181. self.emit(5)
  182. self.compile_arg(op.children[0])
  183. self.compile_arg(op.children[1])
  184. self.compile_arg(op.children[2])
  185. elif op.data == "tjt":
  186. self.emit(6)
  187. self.compile_arg(op.children[0])
  188. self.compile_arg(op.children[1])
  189. elif op.data == "djt":
  190. self.emit(7)
  191. self.compile_arg(op.children[0])
  192. self.compile_arg(op.children[1])
  193. elif op.data == "sslez":
  194. self.emit(8)
  195. self.compile_arg(op.children[0])
  196. self.compile_arg(op.children[1])
  197. elif op.data == "aslez":
  198. self.emit(9)
  199. self.compile_arg(op.children[0])
  200. self.compile_arg(op.children[1])
  201. elif op.data == "ije":
  202. self.emit(10)
  203. self.compile_arg(op.children[0])
  204. self.compile_arg(op.children[1])
  205. self.compile_arg(op.children[2])
  206. elif op.data == "djlz":
  207. self.emit(11)
  208. self.compile_arg(op.children[0])
  209. self.compile_arg(op.children[1])
  210. self.compile_arg(op.children[2])
  211. elif op.data == "xjlz":
  212. self.emit(12)
  213. self.compile_arg(op.children[0])
  214. self.compile_arg(op.children[1])
  215. self.compile_arg(op.children[2])
  216. elif op.data == "dslz":
  217. self.emit(13)
  218. self.compile_arg(op.children[0])
  219. elif op.data == "ssl":
  220. self.emit(14)
  221. self.compile_arg(op.children[0])
  222. self.compile_arg(op.children[1])
  223. elif op.data == "mbnz":
  224. self.emit(15)
  225. self.compile_arg(op.children[0])
  226. self.compile_arg(op.children[1])
  227. self.compile_arg(op.children[2])
  228. elif op.data == "modbz":
  229. self.emit(16)
  230. self.compile_arg(op.children[0])
  231. self.compile_arg(op.children[1])
  232. self.compile_arg(op.children[2])
  233. elif op.data == "aj":
  234. self.emit(17)
  235. self.compile_arg(op.children[0])
  236. self.compile_arg(op.children[1])
  237. elif op.data == "la":
  238. self.emit(18)
  239. self.compile_arg(op.children[0])
  240. self.compile_arg(op.children[1])
  241. elif op.data == "ld":
  242. self.emit(19)
  243. self.compile_arg(op.children[0])
  244. self.compile_arg(op.children[1])
  245. elif op.data == "ia":
  246. self.emit(20)
  247. self.compile_arg(op.children[0])
  248. elif op.data == "jmc":
  249. self.emit(21)
  250. self.compile_arg(op.children[0])
  251. elif op.data == "ja":
  252. self.emit(22)
  253. self.compile_arg(op.children[0])
  254. elif op.data == "psh":
  255. self.emit(23)
  256. self.compile_arg(op.children[0])
  257. elif op.data == "pd":
  258. self.emit(24)
  259. self.compile_arg(op.children[0])
  260. elif op.data == "pop":
  261. self.emit(25)
  262. self.compile_arg(op.children[0])
  263. elif op.data == "shlbnz":
  264. self.emit(26)
  265. self.compile_arg(op.children[0])
  266. self.compile_arg(op.children[1])
  267. self.compile_arg(op.children[2])
  268. elif op.data == "shrbnz":
  269. self.emit(27)
  270. self.compile_arg(op.children[0])
  271. self.compile_arg(op.children[1])
  272. self.compile_arg(op.children[2])
  273. elif op.data == "nbz":
  274. self.emit(28)
  275. self.compile_arg(op.children[0])
  276. self.compile_arg(op.children[1])
  277. elif op.data == "anz":
  278. self.emit(29)
  279. self.compile_arg(op.children[0])
  280. self.compile_arg(op.children[1])
  281. self.compile_arg(op.children[2])
  282. elif op.data == "abgz":
  283. self.emit(30)
  284. self.compile_arg(op.children[0])
  285. self.compile_arg(op.children[1])
  286. self.compile_arg(op.children[2])
  287. elif op.data == "swp":
  288. self.emit(31)
  289. self.compile_arg(op.children[0])
  290. self.compile_arg(op.children[1])
  291. elif op.data == "h_add":
  292. self.emit(3)
  293. self.compile_arg(op.children[0])
  294. self.compile_arg(op.children[1])
  295. self.emit(self.size+2)
  296. elif op.data == "h_sub":
  297. self.emit(2)
  298. self.compile_arg(op.children[0])
  299. self.compile_arg(op.children[1])
  300. self.emit(self.size+2)
  301. elif op.data == "h_inc":
  302. self.emit(3)
  303. self.emit(-3)
  304. self.compile_arg(op.children[0])
  305. self.emit(self.size+2)
  306. elif op.data == "h_dec":
  307. self.emit(2)
  308. self.emit(-3)
  309. self.compile_arg(op.children[0])
  310. self.emit(self.size+2)
  311. elif op.data == "h_mov":
  312. self.emit(1)
  313. self.compile_arg(op.children[0])
  314. self.compile_arg(op.children[1])
  315. self.emit(self.size+2)
  316. elif op.data == "h_jmp":
  317. self.emit(1)
  318. self.emit(0)
  319. self.emit(0)
  320. self.compile_arg(op.children[0])
  321. elif op.data == "h_halt":
  322. self.emit(1)
  323. self.emit(0)
  324. self.emit(0)
  325. self.emit(-1)
  326. elif op.data == "h_out":
  327. self.emit(1)
  328. self.compile_arg(op.children[0])
  329. self.emit(-1)
  330. self.emit(self.size+2)
  331. elif op.data == "h_outn":
  332. self.emit(1)
  333. self.compile_arg(op.children[0])
  334. self.emit(-2)
  335. self.emit(self.size+2)
  336. elif op.data == "h_in":
  337. self.emit(1)
  338. self.emit(-1)
  339. self.compile_arg(op.children[0])
  340. self.emit(self.size+2)
  341. def compile_labels(self):
  342. labels = {}
  343. position = 0
  344. while position < len(self.buffer):
  345. this = self.buffer[position]
  346. if type(this) is tuple and this[0]:
  347. label = this[1]
  348. if label in labels:
  349. raise Exception(f"Duplicated label: {label}.")
  350. elif label in ("IO", "Z", "O", "N", "J", "T", "SP", "EZ", "SZ", "MZ", "JZ", "W", "MM", "DR", "ZZ"):
  351. raise Exception(f"Register override: {label}.")
  352. self.buffer.pop(position)
  353. labels[label] = position + 1
  354. position += 1
  355. position = 0
  356. while position < len(self.buffer):
  357. this = self.buffer[position]
  358. if type(this) is tuple and not this[0]:
  359. label = this[1]
  360. if label not in labels:
  361. raise Exception(f"Undefined label/register: {label}.")
  362. if len(this) == 3:
  363. self.buffer[position] = 1 + (position + labels[label] if this[2] == '+' else position - labels[label])
  364. elif len(this) == 4:
  365. self.buffer[position] = labels[label] + this[3] if this[2] == '+' else labels[label] - this[3]
  366. else:
  367. self.buffer[position] = labels[label]
  368. position += 1
  369. def encode(self):
  370. buffer = b""
  371. for b in self.buffer:
  372. if type(b) is float:
  373. b = struct.pack("<d", b)
  374. else:
  375. b = struct.pack("<q", b)
  376. buffer += b
  377. return buffer
  378. def precompile(self, source):
  379. ast = self.parser.parse(source)
  380. for command in ast.children:
  381. if len(command.children) == 2:
  382. label = command.children[0].value[:-1]
  383. self.emit((True, label))
  384. if type(command.children[1]) is lark.Tree and command.children[1].data != "mixed":
  385. self.compile_operation(command.children[1])
  386. else:
  387. self.compile_arg(command.children[1])
  388. else:
  389. if type(command.children[0]) is lark.Token:
  390. if command.children[0].type == "LABEL":
  391. label = command.children[0].value[:-1]
  392. self.emit((True, label))
  393. else:
  394. with open(command.children[0].value[1:], "r") as f:
  395. self.precompile(f.read())
  396. else:
  397. self.compile_operation(command.children[0])
  398. def compile(self, source):
  399. self.precompile(source)
  400. for reg in "ABCDEFGHIXYK":
  401. if reg in self.used_regs:
  402. self.buffer.append((True, reg))
  403. self.buffer.append(0)
  404. self.compile_labels()
  405. return self.encode()
  406. wma = WMA()
  407. try:
  408. if len(sys.argv) == 3:
  409. with open(sys.argv[1], "r") as fin:
  410. with open(sys.argv[2], "wb") as fout:
  411. fout.write(wma.compile(fin.read()))
  412. else:
  413. sys.stdout.buffer.write(wma.compile(sys.stdin.read()))
  414. except Exception as e:
  415. print(e)
  416. sys.exit(1)