#!/usr/bin/python import os import os.path import sys import lark GRAMMAR = r""" start: toplevel+ ?toplevel: include | funcdef | arrdec ";" | vardec ";" | varinit ";" | arrinit ";" | asm ";" include: "#" "include" FILENAME FILENAME: "<" /.+/ ">" funcdef: NAME "(" params ")" block varinit: NAME arrinit: NAME "[" INTEGER "]" vardec: NAME "=" expr arrdec: NAME "[" expr "]" ("=" expr)? params: | NAME ("," NAME)* block: "{" op* "}" asm: "asm" "(" STRING+ ")" ?op: block | label | goto ";" | break ";" | continue ";" | vardec ";" | arrdec ";" | varinit ";" | arrinit ";" | inc ";" | dec ";" | rinc ";" | rdec ";" | asm ";" | funcall ";" | if | while | for | return ";" label: NAME ":" goto: "goto" NAME break: "break" continue: "continue" if: "if" "(" expr ")" op ("else" op)? while: "while" "(" expr ")" op for: "for" "(" vardec ";" expr ";" (inc|dec|rinc|rdec|vardec|funcall) ")" op return: "return" expr inc: NAME ("[" expr "]")? "++" dec: NAME ("[" expr "]")? "--" rinc: "++" NAME ("[" expr "]")? rdec: "--" NAME ("[" expr "]")? funcall: NAME "(" args ")" args: | [expr ("," expr)*] ?expr: op1 | op1 "?" op1 ":" expr -> ifexpr ?op1: op2 | vardec | arrdec | op1 "==" op1 -> equals | op1 "!=" op1 -> not_equals | op1 "+" op2 -> plus | op1 "-" op2 -> minus ?op2: op3 | op2 "*" op2 -> times | op2 "/" op3 -> divide | op2 "%" op3 -> modulo | op2 "<" op3 -> less | op2 ">" op3 -> greater | op2 "<=" op3 -> less_or_equals | op2 ">=" op3 -> greater_or_equals ?op3: op4 | op3 "**" op4 -> raise ?op4: op5 | op4 "||" op4 -> or | op4 "&&" op4 -> and | "!" atom -> not | "*" atom -> deref ?op5: atom | op5 "[" op1 "]" -> index ?atom: "(" op1 ")" | NAME | INTEGER | FLOAT | CHAR | STRING | "-" atom -> negate | array | funcall | inc | dec | rinc | rdec array: "{" atom ("," atom)* "}" NAME: /[A-Za-z_][a-zA-Z0-9_]*/ INTEGER: /[0-9]+/ FLOAT: /[0-9]+\.[0-9]+/ CHAR: /'(.|(\\.))'/ _STRING_INNER: /(.|\n)*?/ _STRING_ESC_INNER: _STRING_INNER /(?0, but it is {count}.") if len(node.children) == 3: value = self.compile_literal(node.children[2]) size = len(value.split(" ")) # Dirty. if size < count: value += f" 0*{count-size}" elif size != count: raise Exception(f"Illegal array declaration '{node.children[0].value}`: value size is {size}, but expected {count}.") buffer.emit( "ld {} Y", self.make_array(value) ) else: buffer.emit( "ld {} Y", self.make_array(f"0*{count}") ) buffer.emit( "mov Y {}", name ) return buffer def compile_op(self, node): buffer = Buffer() if node.data == "block": buffer.emit( self.compile_block( node ) ) elif node.data == "label": buffer.emit( "{}:", self.scope.get_label(node.children[0].value) ) elif node.data == "goto": buffer.emit( "jmp {}", self.scope.get_label(node.children[0].value) ) elif node.data == "break": if len(self.loops) < 1: raise Exception("'break` outside of a loop.") buffer.emit( "jmp {}", self.loops[-1][1] ) elif node.data == "continue": if len(self.loops) < 1: raise Exception("'continue` outside of a loop.") buffer.emit( "jmp {}", self.loops[-1][0] ) elif node.data == "varinit": name = node.children[0].value if self.scope.is_local(name): raise Exception(f"Duplicated declaration of a local variable: '{name}`.") self.scope.insert(name) elif node.data == "vardec": name = self.scope[node.children[0].value] self.record_usage(name) buffer.emit( self.compile_expr(node.children[1]) ) buffer.emit( "mov Y {}", name ) elif node.data in ("arrdec", "arrinit"): buffer.emit( self.compile_arrdec(node) ) elif node.data in ("inc", "rinc"): if len(node.children) == 2: raise Exception(f"Not implemented: {node}") name = self.scope[node.children[0].value] self.record_usage(name) buffer.emit( "inc {}", name ) elif node.data == ("dec", "rdec"): if len(node.children) == 2: raise Exception(f"Not implemented: {node}") name = self.scope[node.children[0].value] self.record_usage(name) buffer.emit( "dec {}", name ) elif node.data == "funcall": buffer.emit( self.compile_funcall(node, dest='ZZ') ) elif node.data == "return": buffer.emit( self.compile_expr( node.children[0] ) ) buffer.emit("push Y") buffer.emit("ret") elif node.data == "asm": buffer.emit( self.compile_asm(node) ) elif node.data == "if": else_label = self.make_label() exit_label = self.make_label() buffer.emit( self.compile_expr( node.children[0] ) ) buffer.emit( "nbnz Y {}", else_label ) buffer.emit( self.compile_op( node.children[1] ) ) buffer.emit( "jmp {}", exit_label ) buffer.emit( "{}:", else_label ) if len(node.children) == 3: buffer.emit( self.compile_op( node.children[2] ) ) buffer.emit( "{}:", exit_label ) elif node.data == "while": loop_label = self.make_label() exit_label = self.make_label() self.loops.append((loop_label, exit_label)) buffer.emit( "{}:", loop_label ) buffer.emit( self.compile_expr( node.children[0] ) ) buffer.emit( "nbnz Y {}", exit_label ) buffer.emit( self.compile_op( node.children[1] ) ) self.loops.pop() buffer.emit( "jmp {}", loop_label ) buffer.emit( "{}:", exit_label ) elif node.data == "for": loop_label = self.make_label() exit_label = self.make_label() self.loops.append((loop_label, exit_label)) self.scope.new() buffer.emit( self.compile_op( node.children[0] ) ) buffer.emit( "{}:", loop_label ) buffer.emit( self.compile_expr( node.children[1] ) ) buffer.emit( "nbnz Y {}", exit_label ) buffer.emit( self.compile_op( node.children[3] ) ) buffer.emit( self.compile_op( node.children[2] ) ) self.scope.leave() self.loops.pop() buffer.emit( "jmp {}", loop_label ) buffer.emit( "{}:", exit_label ) else: raise Exception(f"Not implemented: {node}") return buffer def collect_labels(self, node): for child in node.children: if child.data == "label": self.scope.add_label(child.children[0].value) def compile_block(self, node, *prepend_names, scope=True): if scope: self.scope.new() for name in prepend_names: self.scope.insert(name) buffer = Buffer() self.collect_labels(node) for child in node.children: buffer.emit( self.compile_op(child) ) if scope: self.scope.leave() return buffer def compile_toplevel(self, node): buffer = Buffer() if node.data == "funcdef": name = node.children[0].value params = self.funcs[name].params buffer.emit("__{}:", name) for param in params: buffer.emit( "pop __{}_{}", self.scope.ndx, param ) self.where = name buffer.emit( self.compile_block( node.children[2], *params ) ) buffer.emit( "push Z" ) buffer.emit( "ret" ) elif node.data == "vardec": name = node.children[0].value self.init_buffer.emit( self.compile_expr(node.children[1]) ) self.init_buffer.emit( "mov Y __0_{}", name ) self.record_usage(f"__0__{name}") elif node.data in ("arrdec", "arrinit"): self.init_buffer.emit( self.compile_arrdec(node) ) elif node.data == "asm": buffer.emit( self.compile_asm(node) ) elif node.data == "include": filename = node.children[0].value[1:-1] if not os.path.isfile(filename): filename = os.path.join(os.getenv("WC_I"), filename) if not os.path.isfile(filename): raise Exception(f"No such file: '{os.path.basename(filename)}`.") if filename not in self.included_files: with open(filename, "r") as f: self.compile_program(f.read()) self.included_files.add(filename) elif node.data == "varinit": pass else: raise Exception(f"Not implemented: {node}") return buffer def collect_toplevel(self, ast): for node in ast.children: if node.data == "funcdef": name = node.children[0].value if name in self.funcs: raise Exception(f"Duplicated function declaration: '{name}`.") self.funcs[name] = Func( len(node.children[1].children), tuple(map(lambda t: t.value, node.children[1].children)) ) elif node.data in ("varinit", "vardec", "arrinit", "arrdec"): name = node.children[0].value if name in self.scope: raise Exception(f"Duplicated top-level variable declaration: '{name}`.") self.scope.insert(name) # Because we're at the top-level rn. def compile_program(self, text): ast = self.parser.parse(text) #print(ast.pretty()) self.collect_toplevel(ast) for node in ast.children: buffer = self.compile_toplevel(node) if node.data == "funcdef": self.compiled_funcs[node.children[0].value] = buffer else: self.buffer.emit(buffer) def compile(self, text): self.compile_program(text) for name in self.compiled_funcs: if self.is_used(name): self.buffer.emit(self.compiled_funcs[name]) for param in self.funcs[name].params: self.record_usage(param) self.buffer = self.init_buffer + self.buffer + self.arrays_buffer for name in self.scope.names: if self.is_used(name): self.buffer.emit( "{}:0", name ) if "main" not in self.funcs: raise Exception("Missing 'main` function.") return self.buffer.generate() + "\n" wmc = WMC() try: if len(sys.argv) == 3: with open(sys.argv[1], "r") as fin: with open(sys.argv[2], "w") as fout: fout.write(wmc.compile(fin.read())) else: sys.stdout.write(wmc.compile(sys.stdin.read())) except Exception as e: #__import__('traceback').print_exc() print(e) sys.exit(1)