|
- #!/usr/bin/python
- import os
- import os.path
- import re
- import sys
- import lark
# Lark grammar for the C-like source language accepted by the compiler.
#
# Rules named op1..op12 encode operator precedence from loosest (ternary)
# to tightest (indexing).  Rules prefixed with "?" are inlined by Lark when
# they have a single child, so the tree only contains meaningful nodes.
#
# The COM terminal matches /* ... */ comments.  It is non-greedy and allows
# newlines so that multi-line comments are skipped and two comments on
# either side of real code do not swallow that code.
GRAMMAR = r"""
start: toplevel+
?toplevel: include
    | define
    | funcdef
    | arrdec ";"
    | vardec ";"
    | varinit ";"
    | arrinit ";"
    | asm ";"
include: "#" "include" FILENAME
define: "#" "define" NAME atom3
FILENAME: "<" /.+/ ">"
funcdef: NAME "(" params ")" block
varinit: NAME
arrinit: NAME "[" INTEGER "]"
vardec: NAME "=" expr
arrdec: NAME "[" expr "]" ("=" expr)?
params:
    | NAME ("," NAME)*
block: "{" op* "}"
asm: "asm" "(" STRING+ ")"
?op: block
    | label
    | goto ";"
    | break ";"
    | continue ";"
    | vardec ";"
    | arrdec ";"
    | return ";"
    | varinit ";"
    | arrinit ";"
    | inc ";"
    | dec ";"
    | rinc ";"
    | rdec ";"
    | asm ";"
    | funcall ";"
    | switch
    | if
    | while
    | for
label: NAME ":"
goto: "goto" NAME
break: "break"
continue: "continue"
switch: "switch" "(" expr ")" "{" case+ default? "}"
case: "case" atom2 ":" op*
default: "default" ":" op+
if: "if" "(" expr ")" op ("else" op)?
while: "while" "(" expr ")" op
for: "for" "(" vardec ";" expr ";" (inc|dec|rinc|rdec|vardec|funcall) ")" op
return: "return" expr?
inc: NAME ("[" expr "]")? "++"
dec: NAME ("[" expr "]")? "--"
rinc: "++" NAME ("[" expr "]")?
rdec: "--" NAME ("[" expr "]")?
funcall: NAME "(" args ")"
args:
    | [expr ("," expr)*]
?expr: op1
    | vardec
    | arrdec
?op1: op2
    | op2 "?" expr ":" op1 -> ifexpr
?op2: op3
    | op2 "||" op3 -> or
?op3: op4
    | op3 "&&" op4 -> and
?op4: op5
    | op4 "==" op5 -> equals
    | op4 "!=" op5 -> not_equals
?op5: op6
    | op5 "<" op6 -> less
    | op5 ">" op6 -> greater
    | op5 "<=" op6 -> less_or_equals
    | op5 ">=" op6 -> greater_or_equals
?op6: op7
    | op6 "+" op7 -> plus
    | op6 "-" op7 -> minus
?op7: op8
    | op7 "*" op8 -> times
    | op7 "/" op8 -> divide
    | op7 "%" op8 -> modulo
?op8: op9
    | op8 "**" op9 -> raise
?op9: op10
    | "*" op9 -> deref
    | "&" op9 -> ref
?op10: op11
    | "!" op10 -> not
?op11: op12
    | "-" op11 -> negate
?op12: atom
    | op12 "[" expr "]" -> index
?atom: "(" op1 ")"
    | atom2
    | funcall
    | inc
    | dec
    | rinc
    | rdec
?atom2: NAME
    | INTEGER
    | FLOAT
    | CHAR
    | STRING
    | "-" atom -> negate
    | array
?atom3: NAME
    | INTEGER
    | FLOAT
    | CHAR
    | STRING
    | "-" (INTEGER|FLOAT) -> negate
array: "{" atom ("," atom)* "}"
NAME: /[A-Za-z_][a-zA-Z0-9_]*/
INTEGER: /[0-9]+/
FLOAT: /[0-9]+\.[0-9]+/
CHAR: /'(.|(\\.))'/
_STRING_INNER: /(.|\n)*?/
_STRING_ESC_INNER: _STRING_INNER /(?<!\\)(\\\\)*?/
STRING: "\"" _STRING_ESC_INNER "\""
IG: /[ \t\r\n]+/
COM: /\/\*(.|\n)*?\*\//
%ignore IG
%ignore COM
"""
def parse_escape(s):
    """Return ``s`` with backslash escape sequences decoded.

    Sequences such as ``\\n``, ``\\t`` or ``\\x41`` are replaced by the
    characters they denote.  Characters outside ASCII are preserved: the
    text is encoded as Latin-1 (the byte interpretation used by the
    ``unicode_escape`` codec) and anything that does not fit is turned into
    a ``\\uXXXX`` / ``\\UXXXXXXXX`` escape first, so it round-trips instead
    of being mangled into several Latin-1 characters.
    """
    return s.encode("latin-1", "backslashreplace").decode("unicode_escape")
class ASMWalker:
    """Cursor over a list of assembly lines with small pattern-matching rules.

    The walker holds a reference to the list it is given and edits it in
    place: lines selected by a rule are prefixed with ``#`` (presumably the
    target assembler's comment marker - confirm) rather than removed.
    """

    def __init__(self, asm):
        self.asm = asm
        self.pos = -1  # advanced to 0 by the first call to step()
        self.size = len(asm)

    def step(self):
        """Advance the cursor by one line.

        Returns False once the cursor is already past the last line.  Note
        that the call which moves the cursor *onto* ``size`` still returns
        True; callers are expected to check ``end`` afterwards.
        """
        if self.pos >= self.size:
            return False
        self.pos += 1
        return True

    def skip(self, offset):
        """Prefix ``offset`` lines, starting at the cursor, with ``#``."""
        for i in range(self.pos, min(self.pos + offset, self.size)):
            self.asm[i] = f"#{self.asm[i]}"

    def match(self, offset, *args):
        """Match the line ``offset`` positions after the cursor.

        ``args`` are regex fragments joined with single spaces and anchored
        to the whole line.  Returns the match object, or None when there is
        no match or the position is past the end of the list.
        """
        offset = self.pos + offset
        if offset >= self.size:
            return None
        pattern = " ".join(args)
        return re.match(f"^{pattern}$", self.asm[offset])

    @staticmethod
    def _expand(rule, m):
        """Substitute the first capture of ``m`` for ``{m}`` in each fragment.

        The capture is regex-escaped so operands containing metacharacters
        (for example ``_dirBuf+1``) are matched literally.
        """
        captured = re.escape(m[1])
        return tuple(part.format(m=captured) for part in rule)

    def rule2(self, rule1, rule2, skip=2):
        """Comment out ``skip`` lines when two consecutive lines match.

        ``rule1`` is matched at the cursor and must contain one capture
        group; ``rule2`` is matched on the following line with ``{m}``
        replaced by that capture.
        """
        m = self.match(0, *rule1)
        if m and self.match(1, *self._expand(rule2, m)):
            self.skip(skip)

    def rule3(self, rule1, rule2, rule3, skip=3):
        """Comment out ``skip`` lines when three consecutive lines match.

        Works like :meth:`rule2`, with ``rule3`` matched against the line two
        positions after the cursor.
        """
        m = self.match(0, *rule1)
        if not m:
            return
        if (self.match(1, *self._expand(rule2, m))
                and self.match(2, *self._expand(rule3, m))):
            self.skip(skip)

    def get(self):
        """Return the line under the cursor."""
        return self.asm[self.pos]

    @property
    def end(self):
        """True when the cursor is past the last line."""
        return self.pos >= self.size
class Buffer:
    """Ordered list of assembly lines being generated."""

    def __init__(self, *init):
        self.buffer = list(init)

    def emit(self, asm, *args):
        """Append output to the buffer.

        ``asm`` may be a list of finished lines, another ``Buffer`` (its
        lines are appended), or a ``str.format`` template that is filled
        with ``args`` and appended as a single line.
        """
        if isinstance(asm, list):
            self.buffer.extend(asm)
        elif isinstance(asm, Buffer):
            self.buffer.extend(asm.buffer)
        else:
            self.buffer.append(asm.format(*args))

    def optimize(self):
        """Return a new ``Buffer`` with the peephole rules applied.

        Lines matched by a rule are kept but prefixed with ``#`` by the
        walker.  The walker operates on a copy, so this buffer is left
        untouched, and finished lines are copied verbatim rather than being
        passed through ``str.format`` a second time.
        """
        walker = ASMWalker(list(self.buffer))
        kept = []
        while walker.step():
            # "jmp X" immediately followed by "X:".
            walker.rule2(
                ("jmp", "(.+)"),
                ("{m}:",),
                skip=1
            )
            # "push X" immediately followed by "pop X".
            walker.rule2(
                ("push", "(.+)"),
                ("pop {m}",)
            )
            # "jmp X", any label, then "X:".
            walker.rule3(
                ("jmp", "(.+)"),
                (".*:",),
                ("{m}:",),
                skip=1
            )
            # step() reports True once more after the last line.
            if walker.end:
                break
            kept.append(walker.get())
        return Buffer(*kept)

    def generate(self):
        """Return all lines joined with newlines."""
        return "\n".join(self.buffer)

    def __add__(self, other):
        """Concatenate two buffers into a new one; TypeError otherwise."""
        if isinstance(other, Buffer):
            return Buffer(*self.buffer, *other.buffer)
        raise TypeError
class Scope:
    """Stack of lexical scopes mapping source names to unique assembly names.

    Each entry of ``scopes`` is a tuple ``(index, variables, labels)`` where
    ``index`` is a number unique to the scope and the two dicts map a source
    name to its renamed form.  ``names`` collects every renamed variable
    ever inserted, across all scopes.
    """

    def __init__(self):
        self.scopes = []
        self.ndx = 0  # index that the next new scope will receive
        self.names = set()

    def new(self):
        """Open a new innermost scope."""
        self.scopes.append((self.ndx, {}, {}))
        self.ndx += 1

    def leave(self):
        """Close the innermost scope."""
        self.scopes.pop()

    def add_label(self, name):
        """Declare label ``name`` in the innermost scope; return its renamed form."""
        renamed = f"__{self.scopes[-1][0]}l_{name}"
        self.scopes[-1][2][name] = renamed
        return renamed

    def get_label(self, name):
        """Return the renamed form of label ``name``.

        Scopes are searched from the innermost outwards, so a ``goto``
        inside a nested block can reach a label declared in an enclosing
        block.  Raises ``Exception`` if no scope declares the label.
        """
        for _, _, labels in reversed(self.scopes):
            if name in labels:
                return labels[name]
        raise Exception(f"Undeclared label: '{name}`.")

    def is_local(self, name):
        """True if ``name`` is declared in the innermost scope itself."""
        return name in self.scopes[-1][1]

    def insert(self, name):
        """Declare variable ``name`` in the innermost scope; return its renamed form."""
        renamed = f"__{self.scopes[-1][0]}_{name}"
        self.scopes[-1][1][name] = renamed
        self.names.add(renamed)
        return renamed

    def find(self, name):
        """Return the renamed form of the innermost visible ``name``.

        Raises ``Exception`` if the name is not declared in any scope.
        """
        for _, variables, _ in reversed(self.scopes):
            if name in variables:
                return variables[name]
        raise Exception(f"Undeclared identifier: '{name}`.")

    def __contains__(self, name):
        """True if ``name`` is declared in any open scope."""
        return any(name in variables for _, variables, _ in self.scopes)

    def __getitem__(self, name):
        """Look ``name`` up, declaring it in the innermost scope if unknown."""
        if name in self:
            return self.find(name)
        return self.insert(name)

    def __iter__(self):
        """Yield ``(name, renamed)`` pairs, innermost scope first."""
        for _, variables, _ in reversed(self.scopes):
            yield from variables.items()

    @property
    def is_toplevel(self):
        """True when the innermost scope is the first scope ever created."""
        return self.scopes[-1][0] == 0
class Func:
    """Signature record for a declared function.

    ``argc`` is the number of parameters and ``params`` the parameter names,
    stored exactly as given by the caller.
    """

    def __init__(self, argc, params):
        self.argc, self.params = argc, params
class WMC:
    """Compiler from the C-like source language to assembly text.

    Source is parsed with the Lark ``GRAMMAR``; the tree is walked and lines
    of assembly are collected in several ``Buffer`` objects that are stitched
    together by :meth:`compile`.  Expression code leaves its result in ``Y``.
    """

    # Instructions emitted after both operands of a binary node have been
    # evaluated by compile_binary_expr.
    _ARITHMETIC = {
        "plus": ("ablez X Y !",),
        "minus": ("sblez X Y !",),
        "times": ("mbnz X Y !",),
        "divide": ("vblz X Y !",),
        "modulo": ("modbz X Y !",),
        "index": ("ablez X Y !", "la Y Y"),
        "raise": (
            "mov 12 _dirBuf",
            "mov X _dirBuf+1",
            "mov Y _dirBuf+2",
            "dir _dirBuf",
            "mov _dirBuf+2 Y",
        ),
    }

    # node kind -> (instruction templates, value when the branch is taken,
    # value otherwise) for compile_compare_expr.
    _COMPARISONS = {
        "equals": (("sblez X Y !", "nbnz Y {}"), "1", "0"),
        "not_equals": (("sblez X Y !", "nbnz Y {}"), "0", "1"),
        "less": (("inc Y", "sblez X Y {}"), "1", "0"),
        "greater": (("dec Y", "sblez X Y {}"), "0", "1"),
        "less_or_equals": (("sblez X Y {}",), "1", "0"),
        "greater_or_equals": (("inc Y", "sblez X Y {}"), "0", "1"),
    }

    # Instructions emitted after the single operand has been evaluated.
    _UNARY = {
        "not": ("nbnz Y !",),
        "deref": ("la Y Y",),
        "negate": ("mov Y X", "sblez X Y !", "sblez X Y !"),
    }

    def __init__(self):
        self.funcs = {}  # function name -> Func
        self.used_symbols = {}  # symbol -> set of places that use it
        self.where = "toplevel"  # function currently compiled, or "toplevel"
        self.scope = Scope()
        self.scope.new()
        self.defined = {}  # #define name -> literal text
        self.compiled_funcs = {}  # function name -> Buffer with its code
        self.arrays_buffer = Buffer()
        self.init_buffer = Buffer()
        # Fixed startup code emitted ahead of everything compiled from source.
        self.buffer = Buffer(
            "jw __main",
            "pop Y",
            "mov 80 _dirBuf",
            "mov Y _dirBuf+1",
            "dir _dirBuf",
            "hlt",
            "_dirBuf:0*4",
        )
        self.label_ndx = 0
        self.included_files = set()
        self.parser = lark.Lark(GRAMMAR)
        # Stack of enclosing constructs: (continue_label, exit_label) for
        # loops, (exit_label,) for switch statements.
        self.loops = []

    def record_usage(self, name):
        """Record that ``name`` is referenced from the current ``where``."""
        self.used_symbols.setdefault(name, set()).add(self.where)

    def is_used(self, name, _seen=None):
        """Return True if ``name`` is reachable from ``main`` or the top level.

        Follows the "used by" relation recorded by :meth:`record_usage`.
        ``_seen`` is internal: it tracks visited symbols so that mutually
        recursive functions do not cause unbounded recursion.
        """
        if name in ("main", "toplevel"):
            return True
        if _seen is None:
            _seen = set()
        if name in _seen:
            return False
        _seen.add(name)
        return any(
            self.is_used(user, _seen)
            for user in self.used_symbols.get(name, ())
        )

    def make_label(self):
        """Return a fresh, unique label name."""
        name = f"__{self.label_ndx}l"
        self.label_ndx += 1
        return name

    def make_array(self, value):
        """Emit a labelled data definition for ``value``; return its label."""
        name = self.make_label()
        self.arrays_buffer.emit("{}:{}", name, value)
        return name

    def compile_literal(self, node, soft=False):
        """Return the assembly operand text for a literal node.

        Names resolve to their ``#define`` value if any; otherwise to the
        name itself when ``soft`` is true, or to the scoped variable (whose
        usage is recorded).  Strings become space-separated character codes
        followed by a terminating ``0``; array literals become their
        space-joined elements.  Raises ``Exception`` for anything else.
        """
        if isinstance(node, lark.Token):
            if node.type == "NAME":
                name = node.value
                if name in self.defined:
                    return self.defined[name]
                if soft:
                    return name
                value = self.scope.find(name)
                self.record_usage(value)
                return value
            if node.type in ("INTEGER", "FLOAT"):
                return node.value
            if node.type == "CHAR":
                return str(ord(parse_escape(node.value[1:-1])))
            if node.type == "STRING":
                codes = [str(ord(ch)) for ch in parse_escape(node.value[1:-1])]
                # Joining a list keeps the empty string as "0" instead of
                # " 0", whose leading space skews the size count in
                # compile_arrdec.
                codes.append("0")
                return " ".join(codes)
        elif node.data == "array":
            return " ".join(
                self.compile_literal(child) for child in node.children
            )
        raise Exception(f"Not implemented: {node}")

    def compile_unary_expr(self, node, *ops):
        """Compile the operand of ``node`` and append the instructions ``ops``."""
        buffer = Buffer()
        buffer.emit(self.compile_expr(node.children[0]))
        for op in ops:
            buffer.emit(op)
        return buffer

    def compile_binary_expr(self, node):
        """Compile both operands, leaving the left in ``Y`` and the right in ``X``."""
        buffer = Buffer()
        buffer.emit(self.compile_expr(node.children[0]))
        buffer.emit('push Y')
        buffer.emit(self.compile_expr(node.children[1]))
        buffer.emit('push Y')
        buffer.emit('pop X')
        buffer.emit('pop Y')
        return buffer

    def compile_compare_expr(self, node, *ops, true='1', false='0'):
        """Compile a comparison that yields ``true`` or ``false`` in ``Y``.

        ``ops`` are instruction templates formatted with the label that
        loads ``true``; falling through them loads ``false`` instead.
        """
        buffer = Buffer()
        ret_label = self.make_label()
        exit_label = self.make_label()
        buffer.emit(self.compile_binary_expr(node))
        for op in ops:
            buffer.emit(op, ret_label)
        buffer.emit("mov {} Y", false)
        buffer.emit("jmp {}", exit_label)
        buffer.emit("{}:", ret_label)
        buffer.emit("mov {} Y", true)
        buffer.emit("{}:", exit_label)
        return buffer

    def _counter_name(self, node):
        """Resolve the variable of a ++/-- node and record its usage.

        Raises ``Exception`` for the indexed form, which is not implemented.
        """
        if len(node.children) == 2:
            raise Exception(f"Not implemented: {node}")
        name = self.scope[node.children[0].value]
        self.record_usage(name)
        return name

    def compile_expr(self, node):
        """Compile an expression node; the generated code leaves its value in ``Y``.

        Raises ``Exception`` for node kinds that are not implemented.
        """
        buffer = Buffer()
        if isinstance(node, lark.Token):
            if node.type in ("NAME", "INTEGER", "FLOAT", "CHAR"):
                buffer.emit("mov {} Y", self.compile_literal(node))
            elif node.type == "STRING":
                buffer.emit(
                    "ld {} Y",
                    self.make_array(self.compile_literal(node))
                )
            else:
                raise Exception(f"Not implemented: {node}")
            return buffer

        kind = node.data
        if kind == "ifexpr":
            else_label = self.make_label()
            exit_label = self.make_label()
            buffer.emit(self.compile_expr(node.children[0]))
            buffer.emit("nbnz Y {}", else_label)
            buffer.emit(self.compile_expr(node.children[1]))
            buffer.emit("jmp {}", exit_label)
            buffer.emit("{}:", else_label)
            buffer.emit(self.compile_expr(node.children[2]))
            buffer.emit("{}:", exit_label)
        elif kind in self._COMPARISONS:
            ops, true, false = self._COMPARISONS[kind]
            buffer.emit(
                self.compile_compare_expr(node, *ops, true=true, false=false)
            )
        elif kind in self._ARITHMETIC:
            buffer.emit(self.compile_binary_expr(node))
            for op in self._ARITHMETIC[kind]:
                buffer.emit(op)
        elif kind == "or":
            true_label = self.make_label()
            exit_label = self.make_label()
            buffer.emit(self.compile_expr(node.children[0]))
            buffer.emit("nbnz Y !")
            buffer.emit("nbnz Y {}", true_label)
            buffer.emit(self.compile_expr(node.children[1]))
            buffer.emit("jmp {}", exit_label)
            buffer.emit("{}:", true_label)
            buffer.emit("mov 1 Y")
            buffer.emit("{}:", exit_label)
        elif kind == "and":
            false_label = self.make_label()
            exit_label = self.make_label()
            buffer.emit(self.compile_expr(node.children[0]))
            buffer.emit("nbnz Y {}", false_label)
            buffer.emit(self.compile_expr(node.children[1]))
            buffer.emit("jmp {}", exit_label)
            buffer.emit("{}:", false_label)
            buffer.emit("mov 0 Y")
            buffer.emit("{}:", exit_label)
        elif kind in self._UNARY:
            buffer.emit(self.compile_unary_expr(node, *self._UNARY[kind]))
        elif kind == "ref":
            target = node.children[0]
            if isinstance(target, lark.Token) and target.type == "NAME":
                value = self.compile_literal(target)
            else:
                # Not a plain name: store the value in a fresh cell and
                # take the address of that cell.
                label = self.make_label()
                self.arrays_buffer.emit("{}:0", label)
                buffer.emit(self.compile_expr(target))
                buffer.emit("mov Y {}", label)
                value = label
            buffer.emit("ld {} Y", value)
        elif kind == "array":
            buffer.emit(
                "ld {} Y",
                self.make_array(self.compile_literal(node))
            )
        elif kind in ("inc", "dec", "rinc", "rdec"):
            name = self._counter_name(node)
            step = "inc {}" if kind in ("inc", "rinc") else "dec {}"
            if kind in ("inc", "dec"):
                # Postfix: read the value, then update the variable.
                buffer.emit("mov {} Y", name)
                buffer.emit(step, name)
            else:
                # Prefix: update the variable, then read the value.
                buffer.emit(step, name)
                buffer.emit("mov {} Y", name)
        elif kind == "funcall":
            buffer.emit(self.compile_funcall(node, dest='Y'))
        elif kind == "vardec":
            buffer.emit(self.compile_op(node))
        elif kind == "arrdec":
            buffer.emit(self.compile_arrdec(node))
        else:
            raise Exception(f"Not implemented: {node}")
        return buffer

    def compile_funcall(self, node, dest='Y'):
        """Compile a call; the value popped after the call goes to ``dest``.

        Arguments are pushed last-to-first.  Raises ``Exception`` for an
        undeclared function or a wrong number of arguments.
        """
        name = node.children[0].value
        if name not in self.funcs:
            raise Exception(f"Call to an undeclared function: '{name}`.")
        args = node.children[1].children
        expected = self.funcs[name].argc
        if expected != len(args):
            raise Exception(f"Function '{name}` expects {expected} arguments, but got {len(args)}.")
        buffer = Buffer()
        for arg in args[::-1]:
            buffer.emit(self.compile_expr(arg))
            buffer.emit("push Y")
        buffer.emit("jw __{}", name)
        buffer.emit("pop {}", dest)
        self.record_usage(name)
        return buffer

    def compile_asm(self, node):
        """Compile an ``asm(...)`` directive.

        Each string is unescaped and ``{name}`` placeholders are replaced by
        the renamed form of the innermost visible variable of that name.
        Raises ``Exception`` if a string is not a valid format template.
        """
        buffer = Buffer()
        table = {}
        for name, renamed in self.scope:
            # The scope iterates innermost first; keep the first hit so that
            # a shadowing local wins over an outer variable.
            table.setdefault(name, renamed)
        for child in node.children:
            value = parse_escape(child.value[1:-1])
            for name, renamed in table.items():
                if f"{{{name}}}" in value:
                    self.record_usage(renamed)
            try:
                value = value.format(**table)
            except Exception as err:
                raise Exception("Malformed asm directive.") from err
            # Emit as a finished line: it must not be formatted again.
            buffer.emit([value])
        return buffer

    def compile_arrdec(self, node):
        """Compile ``name[expr] = expr`` or an array declaration.

        A three-child node whose name is already visible is treated as an
        indexed store; anything else declares an array of a fixed size,
        optionally initialised and zero-padded.  Raises ``Exception`` for a
        duplicate local, a non-positive size or an oversized initialiser.
        """
        buffer = Buffer()
        source_name = node.children[0].value
        if source_name in self.scope and len(node.children) == 3:
            # Index assignment.
            name = self.scope.find(source_name)
            self.record_usage(name)
            buffer.emit(self.compile_expr(node.children[1]))
            buffer.emit("mov {} X", name)
            buffer.emit("ablez X Y !")
            buffer.emit("push Y")
            buffer.emit(self.compile_expr(node.children[2]))
            buffer.emit("pop X")
            buffer.emit("str Y X")
            return buffer

        if not self.scope.is_toplevel and self.scope.is_local(source_name):
            raise Exception(f"Duplicated declaration of a local variable: '{source_name}`.")
        name = self.scope[source_name]
        self.record_usage(name)
        count = int(node.children[1].value)
        if count <= 0:
            raise Exception(f"Illegal array declaration '{source_name}`: array size should be >0, but it is {count}.")
        if len(node.children) == 3:
            value = self.compile_literal(node.children[2])
            size = len(value.split(" "))  # Dirty.
            if size < count:
                value += f" 0*{count-size}"
            elif size != count:
                raise Exception(f"Illegal array declaration '{source_name}`: value size is {size}, but expected {count}.")
        else:
            value = f"0*{count}"
        buffer.emit("ld {} Y", self.make_array(value))
        buffer.emit("mov Y {}", name)
        return buffer

    def compile_op(self, node):
        """Compile one statement node and return its code.

        Raises ``Exception`` for misplaced ``break``/``continue``, duplicate
        local declarations and node kinds that are not implemented.
        """
        buffer = Buffer()
        kind = node.data
        if kind == "block":
            buffer.emit(self.compile_block(node))
        elif kind == "label":
            buffer.emit("{}:", self.scope.get_label(node.children[0].value))
        elif kind == "goto":
            buffer.emit("jmp {}", self.scope.get_label(node.children[0].value))
        elif kind == "break":
            if not self.loops:
                raise Exception("'break` outside of a loop/switch.")
            # The exit label is the last element of both tuple shapes.
            buffer.emit("jmp {}", self.loops[-1][-1])
        elif kind == "continue":
            # Skip enclosing switch entries to reach the nearest loop.
            target = next(
                (entry[0] for entry in reversed(self.loops) if len(entry) == 2),
                None
            )
            if target is None:
                raise Exception("'continue` outside of a loop.")
            buffer.emit("jmp {}", target)
        elif kind == "varinit":
            name = node.children[0].value
            if self.scope.is_local(name):
                raise Exception(f"Duplicated declaration of a local variable: '{name}`.")
            self.scope.insert(name)
        elif kind == "vardec":
            name = self.scope[node.children[0].value]
            self.record_usage(name)
            buffer.emit(self.compile_expr(node.children[1]))
            buffer.emit("mov Y {}", name)
        elif kind in ("arrdec", "arrinit"):
            buffer.emit(self.compile_arrdec(node))
        elif kind in ("inc", "rinc"):
            buffer.emit("inc {}", self._counter_name(node))
        elif kind in ("dec", "rdec"):
            buffer.emit("dec {}", self._counter_name(node))
        elif kind == "funcall":
            # The returned value is discarded into ZZ.
            buffer.emit(self.compile_funcall(node, dest='ZZ'))
        elif kind == "return":
            if len(node.children) == 1:
                buffer.emit(self.compile_expr(node.children[0]))
                buffer.emit("push Y")
            else:
                buffer.emit("push Z")
            buffer.emit("ret")
        elif kind == "asm":
            buffer.emit(self.compile_asm(node))
        elif kind == "switch":
            buffer.emit(self._compile_switch(node))
        elif kind == "if":
            else_label = self.make_label()
            exit_label = self.make_label()
            buffer.emit(self.compile_expr(node.children[0]))
            buffer.emit("nbnz Y {}", else_label)
            buffer.emit(self.compile_op(node.children[1]))
            buffer.emit("jmp {}", exit_label)
            buffer.emit("{}:", else_label)
            if len(node.children) == 3:
                buffer.emit(self.compile_op(node.children[2]))
            buffer.emit("{}:", exit_label)
        elif kind == "while":
            loop_label = self.make_label()
            exit_label = self.make_label()
            self.loops.append((loop_label, exit_label))
            buffer.emit("{}:", loop_label)
            buffer.emit(self.compile_expr(node.children[0]))
            buffer.emit("nbnz Y {}", exit_label)
            buffer.emit(self.compile_op(node.children[1]))
            self.loops.pop()
            buffer.emit("jmp {}", loop_label)
            buffer.emit("{}:", exit_label)
        elif kind == "for":
            loop_label = self.make_label()
            exit_label = self.make_label()
            # "continue" must still run the step, so it targets a label
            # placed right before the step rather than the loop head.
            step_label = self.make_label()
            self.loops.append((step_label, exit_label))
            self.scope.new()
            buffer.emit(self.compile_op(node.children[0]))
            buffer.emit("{}:", loop_label)
            buffer.emit(self.compile_expr(node.children[1]))
            buffer.emit("nbnz Y {}", exit_label)
            buffer.emit(self.compile_op(node.children[3]))
            buffer.emit("{}:", step_label)
            buffer.emit(self.compile_op(node.children[2]))
            self.scope.leave()
            self.loops.pop()
            buffer.emit("jmp {}", loop_label)
            buffer.emit("{}:", exit_label)
        else:
            raise Exception(f"Not implemented: {node}")
        return buffer

    def _compile_switch(self, node):
        """Compile a ``switch``: a run of case tests followed by the bodies."""
        buffer = Buffer()
        default_label = self.make_label()
        exit_label = self.make_label()
        buffer.emit(self.compile_expr(node.children[0]))
        buffer.emit("push Y")
        self.loops.append((exit_label,))
        bodies = {}  # label -> Buffer, in source order
        for child in node.children[1:]:
            if child.data == "default":
                label = default_label
                ops = child.children
            else:
                label = self.make_label()
                buffer.emit(self.compile_expr(child.children[0]))
                buffer.emit("peek X")
                buffer.emit("sblez X Y !")
                buffer.emit("nbnz Y {}", label)
                ops = child.children[1:]
            body = Buffer()
            for op in ops:
                body.emit(self.compile_op(op))
            bodies[label] = body
        self.loops.pop()
        # NOTE(review): the case tests above branch straight to a body and
        # so bypass this pop - confirm the pushed switch value is not left
        # on the stack when a case matches.
        buffer.emit("pop ZZ")
        buffer.emit(
            "jmp {}",
            default_label if default_label in bodies else exit_label
        )
        for label, body in bodies.items():
            buffer.emit("{}:", label)
            buffer.emit(body)
        buffer.emit("{}:", exit_label)
        return buffer

    def collect_labels(self, node):
        """Declare every label that is a direct child of block ``node``."""
        for child in node.children:
            if child.data == "label":
                self.scope.add_label(child.children[0].value)

    def compile_block(self, node, *prepend_names, scope=True):
        """Compile the statements of a block.

        When ``scope`` is true a new scope is opened for the block and
        ``prepend_names`` are declared in it before compilation.
        """
        if scope:
            self.scope.new()
            for name in prepend_names:
                self.scope.insert(name)
        buffer = Buffer()
        self.collect_labels(node)
        for child in node.children:
            buffer.emit(self.compile_op(child))
        if scope:
            self.scope.leave()
        return buffer

    def compile_toplevel(self, node):
        """Compile one top-level node.

        Function bodies and ``asm`` are returned; global initialisers are
        appended to ``init_buffer`` and includes are compiled recursively.
        Raises ``Exception`` for missing include files and unknown nodes.
        """
        buffer = Buffer()
        kind = node.data
        if kind == "funcdef":
            name = node.children[0].value
            params = self.funcs[name].params
            previous = self.where
            self.where = name
            try:
                buffer.emit("__{}:", name)
                for param in params:
                    # scope.ndx is the index compile_block's new scope gets.
                    storage = f"__{self.scope.ndx}_{param}"
                    buffer.emit("pop {}", storage)
                    # The pop writes the cell even if the body never reads it.
                    self.record_usage(storage)
                buffer.emit(self.compile_block(node.children[2], *params))
                buffer.emit("push Z")
                buffer.emit("ret")
            finally:
                # Later top-level code must not be attributed to this function.
                self.where = previous
        elif kind == "vardec":
            name = node.children[0].value
            self.init_buffer.emit(self.compile_expr(node.children[1]))
            self.init_buffer.emit("mov Y __0_{}", name)
            self.record_usage(f"__0_{name}")
        elif kind in ("arrdec", "arrinit"):
            self.init_buffer.emit(self.compile_arrdec(node))
        elif kind == "asm":
            buffer.emit(self.compile_asm(node))
        elif kind == "include":
            filename = node.children[0].value[1:-1]
            if not os.path.isfile(filename):
                include_dir = os.getenv("WC_I")
                if include_dir:
                    filename = os.path.join(include_dir, filename)
            if not os.path.isfile(filename):
                raise Exception(f"No such file: '{os.path.basename(filename)}`.")
            if filename not in self.included_files:
                # Mark first so that circular includes terminate.
                self.included_files.add(filename)
                with open(filename, "r") as f:
                    self.compile_program(f.read())
        elif kind in ("varinit", "define"):
            pass  # fully handled by collect_toplevel
        else:
            raise Exception(f"Not implemented: {node}")
        return buffer

    def collect_toplevel(self, ast):
        """Register functions, global variables and defines before compiling.

        Raises ``Exception`` on duplicate function or variable declarations.
        """
        for node in ast.children:
            kind = node.data
            if kind == "funcdef":
                name = node.children[0].value
                if name in self.funcs:
                    raise Exception(f"Duplicated function declaration: '{name}`.")
                params = tuple(t.value for t in node.children[1].children)
                self.funcs[name] = Func(len(params), params)
            elif kind in ("varinit", "vardec", "arrinit", "arrdec"):
                name = node.children[0].value
                if name in self.scope:
                    raise Exception(f"Duplicated top-level variable declaration: '{name}`.")
                self.scope.insert(name)  # Because we're at the top-level rn.
            elif kind == "define":
                name = node.children[0].value
                value = node.children[1]
                if isinstance(value, lark.Tree) and value.data == "negate":
                    value = "-" + self.compile_literal(value.children[0])
                else:
                    value = self.compile_literal(value, soft=True)
                self.defined[name] = value

    def compile_program(self, text):
        """Parse ``text`` and compile every top-level node in it."""
        ast = self.parser.parse(text)
        self.collect_toplevel(ast)
        for node in ast.children:
            buffer = self.compile_toplevel(node)
            if node.data == "funcdef":
                # Kept aside so unused functions can be dropped in compile().
                self.compiled_funcs[node.children[0].value] = buffer
            else:
                self.buffer.emit(buffer)

    def compile(self, text):
        """Compile ``text`` and return the complete assembly listing.

        Only functions and variables reachable from ``main`` or the top
        level are emitted.  Raises ``Exception`` if there is no ``main``.
        """
        self.compile_program(text)
        for name, code in self.compiled_funcs.items():
            if self.is_used(name):
                self.buffer.emit(code)
        self.buffer = self.init_buffer + self.buffer + self.arrays_buffer
        for name in self.scope.names:
            if self.is_used(name):
                self.buffer.emit("{}:0", name)
        if "main" not in self.funcs:
            raise Exception("Missing 'main` function.")
        self.buffer = self.buffer.optimize()
        return self.buffer.generate() + "\n"
# Command-line driver.  With exactly two arguments the first is read as the
# source file and the second receives the generated assembly; otherwise the
# program is read from stdin and the assembly is written to stdout.  Any
# error is printed and the process exits with status 1.
wmc = WMC()
try:
    if len(sys.argv) != 3:
        sys.stdout.write(wmc.compile(sys.stdin.read()))
    else:
        with open(sys.argv[1], "r") as fin, open(sys.argv[2], "w") as fout:
            fout.write(wmc.compile(fin.read()))
except Exception as e:
    print(e)
    sys.exit(1)
|