wmc.py 26 KB


  1. #!/usr/bin/python
  2. import os
  3. import os.path
  4. import sys
  5. import lark
# Lark grammar for the C-like source language accepted by this compiler; it is
# handed to lark.Lark() in WMC.__init__.
#
# Naming caveat: several rule names read backwards from what they match.
#   varinit - a bare `NAME`                        (introduces a variable)
#   vardec  - `NAME = expr`                        (assignment)
#   arrinit - `NAME[INTEGER]`                      (introduces an array)
#   arrdec  - `NAME[expr]` or `NAME[expr] = expr`  (array declaration / store)
# The WMC.compile_* methods dispatch on these exact rule names and on the
# `-> alias` names below, so grammar and compiler must be changed together.
#
# NOTE(review): COM is written as /\/\*(.|\!)*\*\//. `\!` only matches a
# literal "!", so block comments presumably cannot span lines, and the greedy
# `*` would swallow everything between two comments on the same line --
# confirm whether `(.|\n)*?` was intended.
GRAMMAR = r"""
start: toplevel+
?toplevel: include
| funcdef
| arrdec ";"
| vardec ";"
| varinit ";"
| arrinit ";"
| asm ";"
include: "#" "include" FILENAME
FILENAME: "<" /.+/ ">"
funcdef: NAME "(" params ")" block
varinit: NAME
arrinit: NAME "[" INTEGER "]"
vardec: NAME "=" expr
arrdec: NAME "[" expr "]" ("=" expr)?
params:
| NAME ("," NAME)*
block: "{" op* "}"
asm: "asm" "(" STRING+ ")"
?op: block
| label
| goto ";"
| break ";"
| continue ";"
| vardec ";"
| arrdec ";"
| varinit ";"
| arrinit ";"
| inc ";"
| dec ";"
| rinc ";"
| rdec ";"
| asm ";"
| funcall ";"
| if
| while
| for
| return ";"
label: NAME ":"
goto: "goto" NAME
break: "break"
continue: "continue"
if: "if" "(" expr ")" op ("else" op)?
while: "while" "(" expr ")" op
for: "for" "(" vardec ";" expr ";" (inc|dec|rinc|rdec|vardec|funcall) ")" op
return: "return" expr
inc: NAME ("[" expr "]")? "++"
dec: NAME ("[" expr "]")? "--"
rinc: "++" NAME ("[" expr "]")?
rdec: "--" NAME ("[" expr "]")?
funcall: NAME "(" args ")"
args:
| [expr ("," expr)*]
?expr: op1
| op1 "?" op1 ":" expr -> ifexpr
?op1: op2
| vardec
| op1 "==" op1 -> equals
| op1 "!=" op1 -> not_equals
| op1 "+" op2 -> plus
| op1 "-" op2 -> minus
?op2: op3
| op2 "*" op2 -> times
| op2 "/" op3 -> divide
| op2 "%" op3 -> modulo
| op2 "<" op3 -> less
| op2 ">" op3 -> greater
| op2 "<=" op3 -> less_or_equals
| op2 ">=" op3 -> greater_or_equals
?op3: op4
| op3 "**" op4 -> raise
?op4: op5
| op4 "||" op4 -> or
| op4 "&&" op4 -> and
| "!" atom -> not
| "*" atom -> deref
?op5: atom
| op5 "[" op1 "]" -> index
?atom: "(" op1 ")"
| NAME
| INTEGER
| FLOAT
| CHAR
| STRING
| "-" atom -> negate
| array
| funcall
| inc
| dec
| rinc
| rdec
array: "{" atom ("," atom)* "}"
NAME: /[A-Za-z_][a-zA-Z0-9_]*/
INTEGER: /[0-9]+/
FLOAT: /[0-9]+\.[0-9]+/
CHAR: /'(.|(\\.))'/
_STRING_INNER: /(.|\n)*?/
_STRING_ESC_INNER: _STRING_INNER /(?<!\\)(\\\\)*?/
STRING: "\"" _STRING_ESC_INNER "\""
IG: /[ \t\r\n]+/
COM: /\/\*(.|\!)*\*\//
%ignore IG
%ignore COM
"""
  111. def parse_escape(s):
  112. return bytes(s, "utf-8").decode("unicode_escape")
  113. class Buffer:
  114. def __init__(self, *init):
  115. self.buffer = list(init)
  116. def emit(self, asm, *args):
  117. if type(asm) is list:
  118. self.buffer.extend(asm)
  119. elif type(asm) is Buffer:
  120. self.buffer.extend(asm.buffer)
  121. else:
  122. self.buffer.append(asm.format(*args))
  123. def generate(self):
  124. return "\n".join(self.buffer)
  125. def __add__(self, other):
  126. if type(other) is Buffer:
  127. return Buffer(
  128. *self.buffer + other.buffer
  129. )
  130. raise TypeError
  131. class Scope:
  132. def __init__(self):
  133. self.scopes = []
  134. self.ndx = 0
  135. self.names = set()
  136. def new(self):
  137. self.scopes.append((self.ndx, {}, {}))
  138. self.ndx += 1
  139. def leave(self):
  140. self.scopes.pop()
  141. def add_label(self, name):
  142. renamed = f"__{self.scopes[-1][0]}l_{name}"
  143. self.scopes[-1][2][name] = renamed
  144. return renamed
  145. def get_label(self, name):
  146. if name not in self.scopes[-1][2]:
  147. raise Exception(f"Undeclared label: '{name}`.")
  148. return self.scopes[-1][2][name]
  149. def is_local(self, name):
  150. return name in self.scopes[-1][1]
  151. def insert(self, name):
  152. renamed = f"__{self.scopes[-1][0]}_{name}"
  153. self.scopes[-1][1][name] = renamed
  154. self.names.add(renamed)
  155. return renamed
  156. def find(self, name):
  157. for _, scope, _ in self.scopes[::-1]:
  158. if name in scope:
  159. return scope[name]
  160. raise Exception(f"Undeclared identifier: '{name}`.")
  161. def __contains__(self, name):
  162. for _, scope, _ in self.scopes[::-1]:
  163. if name in scope:
  164. return True
  165. return False
  166. def __getitem__(self, name):
  167. if name in self:
  168. return self.find(name)
  169. return self.insert(name)
  170. def __iter__(self):
  171. for _, scope, _ in self.scopes[::-1]:
  172. for name in scope:
  173. yield (name, scope[name])
  174. @property
  175. def is_toplevel(self):
  176. return self.scopes[-1][0] == 0
  177. class Func:
  178. def __init__(self, argc, params):
  179. self.argc = argc
  180. self.params = params
  181. class WMC:
  182. def __init__(self):
  183. self.funcs = {}
  184. self.used_symbols = {}
  185. self.where = "toplevel"
  186. self.scope = Scope()
  187. self.scope.new()
  188. self.compiled_funcs = {}
  189. self.arrays_buffer = Buffer()
  190. self.init_buffer = Buffer()
  191. self.buffer = Buffer(
  192. "jw __main",
  193. "pop Y",
  194. "mov 80 _dirBuf",
  195. "mov Y _dirBuf+1",
  196. "dir _dirBuf",
  197. "hlt",
  198. "_dirBuf:0*4",
  199. )
  200. self.label_ndx = 0
  201. self.included_files = set()
  202. self.parser = lark.Lark(GRAMMAR)
  203. self.loops = []
  204. def record_usage(self, name):
  205. if name in self.used_symbols:
  206. self.used_symbols[name].add(self.where)
  207. else:
  208. self.used_symbols[name] = set([self.where])
  209. def is_used(self, name):
  210. if name in ("main", "toplevel"):
  211. return True
  212. if name not in self.used_symbols:
  213. return False
  214. for other in self.used_symbols[name]:
  215. if other == name:
  216. continue
  217. if self.is_used(other):
  218. return True
  219. return False
  220. def make_label(self):
  221. name = f"__{self.label_ndx}l"
  222. self.label_ndx += 1
  223. return name
  224. def make_array(self, value):
  225. name = self.make_label()
  226. self.arrays_buffer.emit(
  227. "{}:{}",
  228. name, value
  229. )
  230. return name
  231. def compile_literal(self, node):
  232. if type(node) is lark.Token:
  233. if node.type == "NAME":
  234. value = self.scope.find(node.value)
  235. self.record_usage(value)
  236. return value
  237. elif node.type in ("INTEGER", "FLOAT"):
  238. return node.value
  239. elif node.type == "CHAR":
  240. return str(ord(parse_escape(node.value[1:-1])))
  241. elif node.type == "STRING":
  242. value = parse_escape(node.value[1:-1])
  243. value = map(ord, value)
  244. value = map(str, value)
  245. value = " ".join(value)
  246. value += " 0"
  247. return value
  248. elif node.data == "array":
  249. value = []
  250. for child in node.children:
  251. value.append(
  252. self.compile_literal(child)
  253. )
  254. return " ".join(value)
  255. raise Exception(f"Not implemented: {node}")
  256. def compile_unary_expr(self, node, *ops):
  257. buffer = Buffer()
  258. buffer.emit(
  259. self.compile_expr(
  260. node.children[0]
  261. )
  262. )
  263. for op in ops:
  264. buffer.emit(op)
  265. return buffer
  266. def compile_binary_expr(self, node):
  267. buffer = Buffer()
  268. buffer.emit(
  269. self.compile_expr(
  270. node.children[0]
  271. )
  272. )
  273. buffer.emit('push Y')
  274. buffer.emit(
  275. self.compile_expr(
  276. node.children[1]
  277. )
  278. )
  279. buffer.emit('push Y')
  280. buffer.emit('pop X')
  281. buffer.emit('pop Y')
  282. return buffer
  283. def compile_compare_expr(self, node, *ops, true='1', false='0'):
  284. buffer = Buffer()
  285. ret_label = self.make_label()
  286. exit_label = self.make_label()
  287. buffer.emit(
  288. self.compile_binary_expr(
  289. node
  290. )
  291. )
  292. for op in ops:
  293. buffer.emit(
  294. op,
  295. ret_label
  296. )
  297. buffer.emit(
  298. "mov {} Y",
  299. false
  300. )
  301. buffer.emit(
  302. "jmp {}",
  303. exit_label
  304. )
  305. buffer.emit(
  306. "{}:",
  307. ret_label
  308. )
  309. buffer.emit(
  310. "mov {} Y",
  311. true
  312. )
  313. buffer.emit(
  314. "{}:",
  315. exit_label
  316. )
  317. return buffer
  318. def compile_expr(self, node):
  319. buffer = Buffer()
  320. if type(node) is lark.Token:
  321. if node.type == "NAME":
  322. buffer.emit(
  323. "mov {} Y",
  324. self.compile_literal(node)
  325. )
  326. elif node.type in ("INTEGER", "FLOAT"):
  327. buffer.emit(
  328. "mov {} Y",
  329. self.compile_literal(node)
  330. )
  331. elif node.type == "CHAR":
  332. buffer.emit(
  333. "mov {} Y",
  334. self.compile_literal(node)
  335. )
  336. elif node.type == "STRING":
  337. buffer.emit(
  338. "ld {} Y",
  339. self.make_array(
  340. self.compile_literal(node)
  341. )
  342. )
  343. else:
  344. raise Exception(f"Not implemented: {node}")
  345. elif node.data == "ifexpr":
  346. else_label = self.make_label()
  347. exit_label = self.make_label()
  348. buffer.emit(
  349. self.compile_expr(
  350. node.children[0]
  351. )
  352. )
  353. buffer.emit(
  354. "nbnz Y {}",
  355. else_label
  356. )
  357. buffer.emit(
  358. self.compile_expr(
  359. node.children[1]
  360. )
  361. )
  362. buffer.emit(
  363. "jmp {}",
  364. exit_label
  365. )
  366. buffer.emit(
  367. "{}:",
  368. else_label
  369. )
  370. buffer.emit(
  371. self.compile_expr(
  372. node.children[2]
  373. )
  374. )
  375. buffer.emit(
  376. "{}:",
  377. exit_label
  378. )
  379. elif node.data == "equals":
  380. buffer.emit(
  381. self.compile_compare_expr(
  382. node,
  383. "sblez X Y !",
  384. "nbnz Y {}",
  385. true='1',
  386. false='0'
  387. )
  388. )
  389. elif node.data == "not_equals":
  390. buffer.emit(
  391. self.compile_compare_expr(
  392. node,
  393. "sblez X Y !",
  394. "nbnz Y {}",
  395. true='0',
  396. false='1'
  397. )
  398. )
  399. elif node.data == "plus":
  400. buffer.emit(
  401. self.compile_binary_expr(
  402. node
  403. )
  404. )
  405. buffer.emit(
  406. "ablez X Y !"
  407. )
  408. elif node.data == "minus":
  409. buffer.emit(
  410. self.compile_binary_expr(
  411. node
  412. )
  413. )
  414. buffer.emit(
  415. "sblez X Y !"
  416. )
  417. elif node.data == "times":
  418. buffer.emit(
  419. self.compile_binary_expr(
  420. node
  421. )
  422. )
  423. buffer.emit(
  424. "mbnz X Y !"
  425. )
  426. elif node.data == "divide":
  427. buffer.emit(
  428. self.compile_binary_expr(
  429. node
  430. )
  431. )
  432. buffer.emit(
  433. "vblz X Y !"
  434. )
  435. elif node.data == "modulo":
  436. buffer.emit(
  437. self.compile_binary_expr(
  438. node
  439. )
  440. )
  441. buffer.emit(
  442. "modbz X Y !"
  443. )
  444. elif node.data == "raise":
  445. buffer.emit(
  446. self.compile_binary_expr(
  447. node
  448. )
  449. )
  450. buffer.emit(
  451. "mov 12 _dirBuf"
  452. )
  453. buffer.emit(
  454. "mov X _dirBuf+1"
  455. )
  456. buffer.emit(
  457. "mov Y _dirBuf+2"
  458. )
  459. buffer.emit(
  460. "dir _dirBuf"
  461. )
  462. buffer.emit(
  463. "mov _dirBuf+2 Y"
  464. )
  465. elif node.data == "or":
  466. true_label = self.make_label()
  467. exit_label = self.make_label()
  468. buffer.emit(
  469. self.compile_expr(
  470. node.children[0]
  471. )
  472. )
  473. buffer.emit(
  474. "nbnz Y !"
  475. )
  476. buffer.emit(
  477. "nbnz Y {}",
  478. true_label
  479. )
  480. buffer.emit(
  481. self.compile_expr(
  482. node.children[1]
  483. )
  484. )
  485. buffer.emit(
  486. "jmp {}",
  487. exit_label
  488. )
  489. buffer.emit(
  490. "{}:",
  491. true_label
  492. )
  493. buffer.emit(
  494. "mov 1 Y"
  495. )
  496. buffer.emit(
  497. "{}:",
  498. exit_label
  499. )
  500. elif node.data == "and":
  501. false_label = self.make_label()
  502. exit_label = self.make_label()
  503. buffer.emit(
  504. self.compile_expr(
  505. node.children[0]
  506. )
  507. )
  508. buffer.emit(
  509. "nbnz Y {}",
  510. false_label
  511. )
  512. buffer.emit(
  513. self.compile_expr(
  514. node.children[1]
  515. )
  516. )
  517. buffer.emit(
  518. "jmp {}",
  519. exit_label
  520. )
  521. buffer.emit(
  522. "{}:",
  523. false_label
  524. )
  525. buffer.emit(
  526. "mov 0 Y"
  527. )
  528. buffer.emit(
  529. "{}:",
  530. exit_label
  531. )
  532. elif node.data == "less":
  533. buffer.emit(
  534. self.compile_compare_expr(
  535. node,
  536. "inc Y",
  537. "sblez X Y {}"
  538. )
  539. )
  540. elif node.data == "greater":
  541. buffer.emit(
  542. self.compile_compare_expr(
  543. node,
  544. "dec Y",
  545. "sblez X Y {}",
  546. true='0',
  547. false='1'
  548. )
  549. )
  550. elif node.data == "less_or_equals":
  551. buffer.emit(
  552. self.compile_compare_expr(
  553. node,
  554. "sblez X Y {}"
  555. )
  556. )
  557. elif node.data == "greater_or_equals":
  558. buffer.emit(
  559. self.compile_compare_expr(
  560. node,
  561. "inc Y",
  562. "sblez X Y {}",
  563. true='0',
  564. false='1'
  565. )
  566. )
  567. elif node.data == "not":
  568. buffer.emit(
  569. self.compile_unary_expr(
  570. node,
  571. "nbnz Y !"
  572. )
  573. )
  574. elif node.data == "deref":
  575. buffer.emit(
  576. self.compile_unary_expr(
  577. node,
  578. "la Y Y"
  579. )
  580. )
  581. elif node.data == "index":
  582. buffer.emit(
  583. self.compile_binary_expr(
  584. node
  585. )
  586. )
  587. buffer.emit(
  588. "ablez X Y !"
  589. )
  590. buffer.emit(
  591. "la Y Y"
  592. )
  593. elif node.data == "negate":
  594. buffer.emit(
  595. self.compile_unary_expr(
  596. node,
  597. "mov Y X",
  598. "sblez X Y !",
  599. "sblez X Y !",
  600. )
  601. )
  602. elif node.data == "array":
  603. buffer.emit(
  604. "ld {} Y",
  605. self.make_array(
  606. self.compile_literal(node)
  607. )
  608. )
  609. elif node.data == "inc":
  610. if len(node.children) == 2:
  611. raise Exception(f"Not implemented: {node}")
  612. name = self.scope[node.children[0].value]
  613. self.record_usage(name)
  614. buffer.emit(
  615. "mov {} Y",
  616. name
  617. )
  618. buffer.emit(
  619. "inc {}",
  620. name
  621. )
  622. elif node.data == "dec":
  623. if len(node.children) == 2:
  624. raise Exception(f"Not implemented: {node}")
  625. name = self.scope[node.children[0].value]
  626. self.record_usage(name)
  627. buffer.emit(
  628. "mov {} Y"
  629. )
  630. buffer.emit(
  631. "dec {}",
  632. name
  633. )
  634. elif node.data == "rinc":
  635. if len(node.children) == 2:
  636. raise Exception(f"Not implemented: {node}")
  637. name = self.scope[node.children[0].value]
  638. self.record_usage(name)
  639. buffer.emit(
  640. "inc {}",
  641. name
  642. )
  643. buffer.emit(
  644. "mov {} Y",
  645. name
  646. )
  647. elif node.data == "rdec":
  648. if len(node.children) == 2:
  649. raise Exception(f"Not implemented: {node}")
  650. name = self.scope[node.children[0].value]
  651. self.record_usage(name)
  652. buffer.emit(
  653. "dec {}",
  654. name
  655. )
  656. buffer.emit(
  657. "mov {} Y"
  658. )
  659. elif node.data == "funcall":
  660. buffer.emit(
  661. self.compile_funcall(node, dest='Y')
  662. )
  663. else:
  664. raise Exception(f"Not implemented: {node}")
  665. return buffer
  666. def compile_funcall(self, node, dest='Y'):
  667. buffer = Buffer()
  668. name = node.children[0].value
  669. if name not in self.funcs:
  670. raise Exception(f"Call to an undeclared function: '{name}`.")
  671. if self.funcs[name].argc != len(node.children[1].children):
  672. raise Exception(f"Function '{name}` expects {self.funcs[name].argc} arguments, but got {node.children[1].children}.")
  673. for arg in node.children[1].children[::-1]:
  674. buffer.emit(
  675. self.compile_expr(arg)
  676. )
  677. buffer.emit(
  678. "push Y"
  679. )
  680. buffer.emit(
  681. "jw __{}",
  682. name
  683. )
  684. buffer.emit(
  685. "pop {}",
  686. dest
  687. )
  688. self.record_usage(name)
  689. return buffer
  690. def compile_asm(self, node):
  691. buffer = Buffer()
  692. table = {}
  693. for name, renamed in self.scope:
  694. table[name] = renamed
  695. for child in node.children:
  696. value = parse_escape(child.value[1:-1])
  697. for name in table:
  698. if f"{{{name}}}" in value:
  699. self.record_usage(table[name])
  700. try:
  701. value = value.format(
  702. **table
  703. )
  704. except:
  705. raise Exception("Malformed asm directive.")
  706. buffer.emit(value)
  707. return buffer
  708. def compile_arrdec(self, node):
  709. buffer = Buffer()
  710. name = node.children[0].value
  711. if name in self.scope and len(node.children) == 3: # Index assignment.
  712. name = self.scope.find(name)
  713. self.record_usage(name)
  714. buffer.emit(
  715. self.compile_expr(
  716. node.children[1]
  717. )
  718. )
  719. buffer.emit(
  720. "mov {} X",
  721. name
  722. )
  723. buffer.emit(
  724. "ablez X Y !"
  725. )
  726. buffer.emit(
  727. "push Y"
  728. )
  729. buffer.emit(
  730. self.compile_expr(
  731. node.children[2]
  732. )
  733. )
  734. buffer.emit(
  735. "pop X"
  736. )
  737. buffer.emit(
  738. "str Y X"
  739. )
  740. return buffer
  741. if not self.scope.is_toplevel and self.scope.is_local(name):
  742. raise Exception(f"Duplicated declaration of a local variable: '{node.children[0].value}`.")
  743. name = self.scope[name]
  744. self.record_usage(name)
  745. count = int(node.children[1].value)
  746. if count <= 0:
  747. raise Exception(f"Illegal array declaration '{node.children[0].value}`: array size should be >0, but it is {count}.")
  748. if len(node.children) == 3:
  749. value = self.compile_literal(node.children[2])
  750. size = len(value.split(" ")) # Dirty.
  751. if size < count:
  752. value += f" 0*{count-size}"
  753. elif size != count:
  754. raise Exception(f"Illegal array declaration '{node.children[0].value}`: value size is {size}, but expected {count}.")
  755. buffer.emit(
  756. "ld {} Y",
  757. self.make_array(value)
  758. )
  759. else:
  760. buffer.emit(
  761. "ld {} Y",
  762. self.make_array(f"0*{count}")
  763. )
  764. buffer.emit(
  765. "mov Y {}",
  766. name
  767. )
  768. return buffer
  769. def compile_op(self, node):
  770. buffer = Buffer()
  771. if node.data == "block":
  772. buffer.emit(
  773. self.compile_block(
  774. node
  775. )
  776. )
  777. elif node.data == "label":
  778. buffer.emit(
  779. "{}:",
  780. self.scope.get_label(node.children[0].value)
  781. )
  782. elif node.data == "goto":
  783. buffer.emit(
  784. "jmp {}",
  785. self.scope.get_label(node.children[0].value)
  786. )
  787. elif node.data == "break":
  788. if len(self.loops) < 1:
  789. raise Exception("'break` outside of a loop.")
  790. buffer.emit(
  791. "jmp {}",
  792. self.loops[-1][1]
  793. )
  794. elif node.data == "continue":
  795. if len(self.loops) < 1:
  796. raise Exception("'continue` outside of a loop.")
  797. buffer.emit(
  798. "jmp {}",
  799. self.loops[-1][0]
  800. )
  801. elif node.data == "varinit":
  802. name = node.children[0].value
  803. if self.scope.is_local(name):
  804. raise Exception(f"Duplicated declaration of a local variable: '{name}`.")
  805. self.scope.insert(name)
  806. elif node.data == "vardec":
  807. name = self.scope[node.children[0].value]
  808. self.record_usage(name)
  809. buffer.emit(
  810. self.compile_expr(node.children[1])
  811. )
  812. buffer.emit(
  813. "mov Y {}",
  814. name
  815. )
  816. elif node.data in ("arrdec", "arrinit"):
  817. buffer.emit(
  818. self.compile_arrdec(node)
  819. )
  820. elif node.data in ("inc", "rinc"):
  821. if len(node.children) == 2:
  822. raise Exception(f"Not implemented: {node}")
  823. name = self.scope[node.children[0].value]
  824. self.record_usage(name)
  825. buffer.emit(
  826. "inc {}",
  827. name
  828. )
  829. elif node.data == ("dec", "rdec"):
  830. if len(node.children) == 2:
  831. raise Exception(f"Not implemented: {node}")
  832. name = self.scope[node.children[0].value]
  833. self.record_usage(name)
  834. buffer.emit(
  835. "dec {}",
  836. name
  837. )
  838. elif node.data == "funcall":
  839. buffer.emit(
  840. self.compile_funcall(node, dest='ZZ')
  841. )
  842. elif node.data == "return":
  843. buffer.emit(
  844. self.compile_expr(
  845. node.children[0]
  846. )
  847. )
  848. buffer.emit("push Y")
  849. buffer.emit("ret")
  850. elif node.data == "asm":
  851. buffer.emit(
  852. self.compile_asm(node)
  853. )
  854. elif node.data == "if":
  855. else_label = self.make_label()
  856. exit_label = self.make_label()
  857. buffer.emit(
  858. self.compile_expr(
  859. node.children[0]
  860. )
  861. )
  862. buffer.emit(
  863. "nbnz Y {}",
  864. else_label
  865. )
  866. buffer.emit(
  867. self.compile_op(
  868. node.children[1]
  869. )
  870. )
  871. buffer.emit(
  872. "jmp {}",
  873. exit_label
  874. )
  875. buffer.emit(
  876. "{}:",
  877. else_label
  878. )
  879. if len(node.children) == 3:
  880. buffer.emit(
  881. self.compile_op(
  882. node.children[2]
  883. )
  884. )
  885. buffer.emit(
  886. "{}:",
  887. exit_label
  888. )
  889. elif node.data == "while":
  890. loop_label = self.make_label()
  891. exit_label = self.make_label()
  892. self.loops.append((loop_label, exit_label))
  893. buffer.emit(
  894. "{}:",
  895. loop_label
  896. )
  897. buffer.emit(
  898. self.compile_expr(
  899. node.children[0]
  900. )
  901. )
  902. buffer.emit(
  903. "nbnz Y {}",
  904. exit_label
  905. )
  906. buffer.emit(
  907. self.compile_op(
  908. node.children[1]
  909. )
  910. )
  911. self.loops.pop()
  912. buffer.emit(
  913. "jmp {}",
  914. loop_label
  915. )
  916. buffer.emit(
  917. "{}:",
  918. exit_label
  919. )
  920. elif node.data == "for":
  921. loop_label = self.make_label()
  922. exit_label = self.make_label()
  923. self.loops.append((loop_label, exit_label))
  924. self.scope.new()
  925. buffer.emit(
  926. self.compile_op(
  927. node.children[0]
  928. )
  929. )
  930. buffer.emit(
  931. "{}:",
  932. loop_label
  933. )
  934. buffer.emit(
  935. self.compile_expr(
  936. node.children[1]
  937. )
  938. )
  939. buffer.emit(
  940. "nbnz Y {}",
  941. exit_label
  942. )
  943. buffer.emit(
  944. self.compile_op(
  945. node.children[3]
  946. )
  947. )
  948. buffer.emit(
  949. self.compile_op(
  950. node.children[2]
  951. )
  952. )
  953. self.scope.leave()
  954. self.loops.pop()
  955. buffer.emit(
  956. "jmp {}",
  957. loop_label
  958. )
  959. buffer.emit(
  960. "{}:",
  961. exit_label
  962. )
  963. else:
  964. raise Exception(f"Not implemented: {node}")
  965. return buffer
  966. def collect_labels(self, node):
  967. for child in node.children:
  968. if child.data == "label":
  969. self.scope.add_label(child.children[0].value)
  970. def compile_block(self, node, *prepend_names, scope=True):
  971. if scope:
  972. self.scope.new()
  973. for name in prepend_names:
  974. self.scope.insert(name)
  975. buffer = Buffer()
  976. self.collect_labels(node)
  977. for child in node.children:
  978. buffer.emit(
  979. self.compile_op(child)
  980. )
  981. if scope:
  982. self.scope.leave()
  983. return buffer
  984. def compile_toplevel(self, node):
  985. buffer = Buffer()
  986. if node.data == "funcdef":
  987. name = node.children[0].value
  988. params = self.funcs[name].params
  989. buffer.emit("__{}:", name)
  990. for param in params:
  991. buffer.emit(
  992. "pop __{}_{}",
  993. self.scope.ndx, param
  994. )
  995. self.where = name
  996. buffer.emit(
  997. self.compile_block(
  998. node.children[2],
  999. *params
  1000. )
  1001. )
  1002. buffer.emit(
  1003. "push Z"
  1004. )
  1005. buffer.emit(
  1006. "ret"
  1007. )
  1008. elif node.data == "vardec":
  1009. name = node.children[0].value
  1010. self.init_buffer.emit(
  1011. self.compile_expr(node.children[1])
  1012. )
  1013. self.init_buffer.emit(
  1014. "mov Y __0_{}",
  1015. name
  1016. )
  1017. self.record_usage(f"__0__{name}")
  1018. elif node.data in ("arrdec", "arrinit"):
  1019. self.init_buffer.emit(
  1020. self.compile_arrdec(node)
  1021. )
  1022. elif node.data == "asm":
  1023. buffer.emit(
  1024. self.compile_asm(node)
  1025. )
  1026. elif node.data == "include":
  1027. filename = node.children[0].value[1:-1]
  1028. if not os.path.isfile(filename):
  1029. filename = os.path.join(os.getenv("WC_I"), filename)
  1030. if not os.path.isfile(filename):
  1031. raise Exception(f"No such file: '{os.path.basename(filename)}`.")
  1032. if filename not in self.included_files:
  1033. with open(filename, "r") as f:
  1034. self.compile_program(f.read())
  1035. self.included_files.add(filename)
  1036. elif node.data == "varinit":
  1037. pass
  1038. else:
  1039. raise Exception(f"Not implemented: {node}")
  1040. return buffer
  1041. def collect_toplevel(self, ast):
  1042. for node in ast.children:
  1043. if node.data == "funcdef":
  1044. name = node.children[0].value
  1045. if name in self.funcs:
  1046. raise Exception(f"Duplicated function declaration: '{name}`.")
  1047. self.funcs[name] = Func(
  1048. len(node.children[1].children),
  1049. tuple(map(lambda t: t.value, node.children[1].children))
  1050. )
  1051. elif node.data in ("varinit", "vardec", "arrinit", "arrdec"):
  1052. name = node.children[0].value
  1053. if name in self.scope:
  1054. raise Exception(f"Duplicated top-level variable declaration: '{name}`.")
  1055. self.scope.insert(name) # Because we're at the top-level rn.
  1056. def compile_program(self, text):
  1057. ast = self.parser.parse(text)
  1058. #print(ast.pretty())
  1059. self.collect_toplevel(ast)
  1060. for node in ast.children:
  1061. buffer = self.compile_toplevel(node)
  1062. if node.data == "funcdef":
  1063. self.compiled_funcs[node.children[0].value] = buffer
  1064. else:
  1065. self.buffer.emit(buffer)
  1066. def compile(self, text):
  1067. self.compile_program(text)
  1068. for name in self.compiled_funcs:
  1069. if self.is_used(name):
  1070. self.buffer.emit(self.compiled_funcs[name])
  1071. for param in self.funcs[name].params:
  1072. self.record_usage(param)
  1073. self.buffer = self.init_buffer + self.buffer + self.arrays_buffer
  1074. for name in self.scope.names:
  1075. if self.is_used(name):
  1076. self.buffer.emit(
  1077. "{}:0", name
  1078. )
  1079. if "main" not in self.funcs:
  1080. raise Exception("Missing 'main` function.")
  1081. return self.buffer.generate() + "\n"
  1082. wmc = WMC()
  1083. try:
  1084. if len(sys.argv) == 3:
  1085. with open(sys.argv[1], "r") as fin:
  1086. with open(sys.argv[2], "w") as fout:
  1087. fout.write(wmc.compile(fin.read()))
  1088. else:
  1089. sys.stdout.write(wmc.compile(sys.stdin.read()))
  1090. except Exception as e:
  1091. #__import__('traceback').print_exc()
  1092. print(e)
  1093. sys.exit(1)