123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486 |
- #!/usr/bin/python
- import sys
- import struct
- import lark
# lark (LALR) grammar for WMA assembly source.  Commands are separated by
# newlines or ";".  Spaces, tabs and "\r" are ignored, as is everything from
# "#" to the end of the line (a bare "#" is a valid, empty comment).
GRAMMAR = r"""
start: _NL? command ((_NL+|";") command)* _NL?
command: LABEL? operation
    | LABEL (arg|mixed)
    | LABEL
    | INCLUDE
?operation: "org" INTEGER -> org
    | "nop" -> nop
    | ("mvj"|"mj") arg arg arg -> mj
    | "sblez" arg arg arg -> sjlez
    | "ablez" arg arg arg -> ajlez
    | "sblz" arg arg arg -> sjlz
    | "bles" arg arg arg -> jles
    | ("nbnz"|"tjt") arg arg -> tjt
    | ("dbnz"|"djt") arg arg -> djt
    | "sslez" arg arg -> sslez
    | "aslez" arg arg -> aslez
    | ("ibnc"|"ije") arg arg arg -> ije
    | "vblz" arg arg arg -> djlz
    | "xblz" arg arg arg -> xjlz
    | "dslz" arg -> dslz
    | ("ssgt"|"ssl") arg arg -> ssl
    | "mbnz" arg arg arg -> mbnz
    | "modbz" arg arg arg -> modbz
    | ("aja"|"aj") arg arg -> aj
    | "la" arg arg -> la
    | "ld" arg arg -> ld
    | "ia" arg -> ia
    | "jmc" arg -> jmc
    | ("jw"|"ja"|"call") arg -> ja
    | ("push"|"psh") arg -> psh
    | "pd" arg -> pd
    | "pop" arg -> pop
    | "shlbnz" arg arg arg -> shlbnz
    | "shrbnz" arg arg arg -> shrbnz
    | "nbz" arg arg -> nbz
    | "anz" arg arg arg -> anz
    | "abgz" arg arg arg -> abgz
    | "swp" arg arg -> swp
    | "str" arg arg -> str
    | "add" arg arg -> h_add
    | "sub" arg arg -> h_sub
    | "inc" arg -> h_inc
    | "dec" arg -> h_dec
    | ("mov"|"mv") arg arg -> h_mov
    | ("jmp"|"j") arg -> h_jmp
    | "hlt" -> h_halt
    | "out" arg -> h_out
    | "outn" arg -> h_outn
    | "in" arg -> h_in
    | "dir" arg -> h_dir
    | "ret" -> h_ret
    | "peek" arg -> h_peek
mixed: arg arg+
?arg: INTEGER
    | DOUBLE
    | CHAR
    | CHARS
    | OFFSET
    | LABELOFFSET
    | QMARK
    | EMARK
    | NAMEOFFSET
    | NAME
    | rep
rep: (INTEGER|DOUBLE|CHAR|CHARS|NAME|NAMEOFFSET) "*" COUNT
COUNT: /[0-9]+/
INTEGER: /-?[0-9]+/
DOUBLE: /-?[0-9]+\.[0-9]+/
CHAR: "'" /./ "'"
CHARS: "\"" /[^"]*/ "\""
QMARK: "?"
EMARK: "!"
OFFSET: "?" /(-|\+)[0-9]+/
LABELOFFSET: "?" /(-|\+)[A-Za-z_][a-zA-Z0-9_]*/
NAMEOFFSET: /[A-Za-z_][a-zA-Z0-9_]*(-|\+)[0-9]+/
LABEL: /[A-Za-z_][a-zA-Z0-9_]*:/
NAME: /[A-Za-z_][a-zA-Z0-9_]*/
INCLUDE: "+" /.+/
_NL: /\n+/
IG: /[ \t\r]+/
COM: /#[^\n]*/
%ignore IG
%ignore COM
"""
class WMA:
    """Two-pass assembler that turns WMA assembly source into a binary image.

    Pass one (``precompile``) parses the source with ``GRAMMAR`` and appends
    machine words to ``self.buffer``.  Operands that name a label or a
    general register cannot be resolved yet and are stored as tuples whose
    first element is ``False``:

    * ``(False, name)``                 -- ``name``
    * ``(False, name, sign)``           -- ``?+name`` / ``?-name``
    * ``(False, name, sign, offset)``   -- ``name+N`` / ``name-N``

    Pass two (``compile_labels``) replaces every tuple with a number computed
    from ``origin`` and the referenced label's word index, and ``encode``
    packs each word as a little-endian 64-bit signed integer or double.

    State is never reset, so use a fresh instance for each program.
    """

    # Special registers and the operand numbers they encode to.  These names
    # are reserved: they may not be used as labels.
    SPECIAL_REGISTERS = {
        "PC": 0, "IO": -1, "Z": -2, "O": -3, "N": -4, "J": -5,
        "T": -6, "SP": -7, "EZ": -8, "SZ": -9, "MZ": -10, "JZ": -11,
        "W": -12, "MM": -13, "DR": -14, "ZZ": -15,
    }

    # General registers.  compile() appends one zero word for each of these
    # that the program references but does not define itself, in this order.
    GENERAL_REGISTERS = "ABCDEFGHIXYKRS"

    # Native instructions: tree name produced by GRAMMAR's aliases -> opcode.
    # Each is encoded as its opcode word followed by its operands in order.
    OPCODES = {
        "nop": 0, "mj": 1, "sjlez": 2, "ajlez": 3, "sjlz": 4, "jles": 5,
        "tjt": 6, "djt": 7, "sslez": 8, "aslez": 9, "ije": 10, "djlz": 11,
        "xjlz": 12, "dslz": 13, "ssl": 14, "mbnz": 15, "modbz": 16,
        "aj": 17, "la": 18, "ld": 19, "ia": 20, "jmc": 21, "ja": 22,
        "psh": 23, "pd": 24, "pop": 25, "shlbnz": 26, "shrbnz": 27,
        "nbz": 28, "anz": 29, "abgz": 30, "swp": 31, "str": 32,
    }

    def __init__(self):
        # Address of buffer[0]; changed by the "org" directive.
        self.origin = 1
        # Emitted words: int, float, or an unresolved reference tuple.
        self.buffer = []
        # Number of words emitted so far (kept equal to len(self.buffer)).
        self.size = 0

        self.parser = lark.Lark(GRAMMAR, parser='lalr')
        # Non-special names referenced by operands; compile() allocates the
        # general registers among them.
        self.used_regs = []
        # Label name -> index into self.buffer.
        self.labels = {}

    def add_label(self, label):
        """Bind *label* to the index of the next word to be emitted.

        Raises:
            Exception: if *label* is already defined or names a special
                register.
        """
        if label in self.labels:
            raise Exception(f"Duplicated label: {label}.")
        if label in self.SPECIAL_REGISTERS:
            raise Exception(f"Register override: {label}.")
        self.labels[label] = len(self.buffer)

    def emit(self, *ops):
        """Append the words *ops* to the buffer and advance ``size``."""
        self.buffer.extend(ops)
        self.size += len(ops)

    def _emit_next_address(self):
        """Emit the absolute address of the word that follows this one.

        The ``h_*`` macros use it as the branch target of the native
        instruction they expand to, so execution continues with the next
        instruction.
        """
        self.emit(self.origin + self.size + 1)

    def compile_arg(self, arg):
        """Emit the word(s) for one operand.

        *arg* is a lark ``Token`` or a ``mixed``/``rep`` ``Tree``.  Literals
        and the ``?``/``!``/``?+N`` forms are emitted as final values; label
        and register names are emitted as unresolved tuples (see the class
        docstring).
        """
        if isinstance(arg, lark.Tree):
            if arg.data == "mixed":
                # Several operands in a row (a data line after a label).
                for item in arg.children:
                    self.compile_arg(item)
            elif arg.data == "rep":
                # "value * COUNT" emits the value COUNT times.
                item, count = arg.children[0], int(arg.children[1].value)
                for _ in range(count):
                    self.compile_arg(item)
            return

        kind, text = arg.type, arg.value
        # Address of the word about to be emitted.
        here = self.origin + self.size
        if kind == "INTEGER":
            self.emit(int(text))
        elif kind == "DOUBLE":
            self.emit(float(text))
        elif kind == "CHAR":
            self.emit(ord(text[1]))  # text is 'c' including the quotes
        elif kind == "CHARS":
            for ch in text[1:-1]:  # strip the surrounding double quotes
                self.emit(ord(ch))
        elif kind == "QMARK":
            self.emit(here)
        elif kind == "EMARK":
            self.emit(here + 1)
        elif kind == "OFFSET":
            self.emit(here + int(text[1:]))  # "?+N" / "?-N"
        elif kind == "LABELOFFSET":
            # "?+name" / "?-name": text[1] is the sign, text[2:] the name.
            name = text[2:]
            self.used_regs.append(name)
            self.emit((False, name, text[1]))
        elif kind == "NAMEOFFSET":
            sign = "+" if "+" in text else "-"
            name, offset = text.split(sign)
            # Record the name so a general register used only as "A+1"
            # still gets allocated by compile().
            self.used_regs.append(name)
            self.emit((False, name, sign, int(offset)))
        elif kind == "NAME":
            if text in self.SPECIAL_REGISTERS:
                self.emit(self.SPECIAL_REGISTERS[text])
            else:
                self.used_regs.append(text)
                self.emit((False, text))

    def compile_operation(self, op):
        """Emit the words for one parsed operation tree *op*.

        Native instructions (see ``OPCODES``) become their opcode followed by
        their operands.  The ``h_*`` trees are convenience macros that expand
        to a single native instruction; ``org`` emits nothing and sets
        ``origin``.

        Raises:
            Exception: for an operation name this assembler does not know.
        """
        name = op.data
        args = op.children
        code = self.OPCODES
        if name == "org":
            # Affects every address computed afterwards ("?", "!", "?+N", the
            # h_* fall-through targets) and all label resolution, so it
            # belongs before any code is emitted.
            self.origin = int(args[0].value)
        elif name in code:
            self.emit(code[name])
            for arg in args:
                self.compile_arg(arg)
        elif name in ("h_add", "h_sub"):
            # add a b / sub a b  ->  ablez/sblez a b <next>
            self.emit(code["ajlez"] if name == "h_add" else code["sjlez"])
            self.compile_arg(args[0])
            self.compile_arg(args[1])
            self._emit_next_address()
        elif name in ("h_inc", "h_dec"):
            # inc x / dec x  ->  ablez/sblez -3 x <next>   (-3 is register O)
            self.emit(code["ajlez"] if name == "h_inc" else code["sjlez"])
            self.emit(-3)
            self.compile_arg(args[0])
            self._emit_next_address()
        elif name == "h_mov":
            source = args[0]
            if type(source) is lark.Token and source.type in ("INTEGER", "DOUBLE", "CHAR"):
                # Literal source: ld value dest
                self.emit(code["ld"])
                self.compile_arg(source)
                self.compile_arg(args[1])
            else:
                # mj src dest <next>
                self.emit(code["mj"])
                self.compile_arg(source)
                self.compile_arg(args[1])
                self._emit_next_address()
        elif name == "h_jmp":
            # mj 0 0 target
            self.emit(code["mj"], 0, 0)
            self.compile_arg(args[0])
        elif name == "h_halt":
            # mj 0 0 -1
            self.emit(code["mj"], 0, 0, -1)
        elif name == "h_out":
            # mj x -1 <next>   (-1 is register IO)
            self.emit(code["mj"])
            self.compile_arg(args[0])
            self.emit(-1)
            self._emit_next_address()
        elif name == "h_outn":
            # mj x -2 <next>   (-2 is register Z)
            self.emit(code["mj"])
            self.compile_arg(args[0])
            self.emit(-2)
            self._emit_next_address()
        elif name == "h_in":
            # mj -1 x <next>   (-1 is register IO)
            self.emit(code["mj"], -1)
            self.compile_arg(args[0])
            self._emit_next_address()
        elif name == "h_dir":
            # ld x -14   (-14 is register DR)
            self.emit(code["ld"])
            self.compile_arg(args[0])
            self.emit(-14)
        elif name == "h_ret":
            # mj -15 -15 -5   (registers ZZ, ZZ, J)
            self.emit(code["mj"], -15, -15, -5)
        elif name == "h_peek":
            # pop x; push x
            self.emit(code["pop"])
            self.compile_arg(args[0])
            self.emit(code["psh"])
            self.compile_arg(args[0])
        else:
            raise Exception(f"Unknown operation: {name}.")

    def compile_labels(self):
        """Resolve every unresolved reference tuple in the buffer in place.

        Raises:
            Exception: if a referenced name is neither a label nor an
                allocated general register.
        """
        for position, word in enumerate(self.buffer):
            if not isinstance(word, tuple) or word[0]:
                continue
            label = word[1]
            if label not in self.labels:
                raise Exception(f"Undefined label/register: {label}.")
            target = self.labels[label]
            if len(word) == 3:
                # "?+name" / "?-name": this word's index plus/minus the
                # label's index.
                value = position + target if word[2] == "+" else position - target
            elif len(word) == 4:
                # "name+N" / "name-N": the label's index plus/minus N.
                value = target + word[3] if word[2] == "+" else target - word[3]
            else:
                value = target
            self.buffer[position] = self.origin + value

    def encode(self):
        """Return the buffer packed as little-endian 64-bit words.

        Floats are packed as doubles (``<d``), everything else as signed
        integers (``<q``).
        """
        return b"".join(
            struct.pack("<d", word) if isinstance(word, float) else struct.pack("<q", word)
            for word in self.buffer
        )

    def precompile(self, source):
        """Parse *source* and append its words and labels to the buffer.

        References stay unresolved until ``compile_labels``.  An include line
        ``+path`` assembles that file in place; the path is opened relative
        to the current working directory.
        """
        tree = self.parser.parse(source)
        for command in tree.children:
            parts = command.children
            if len(parts) == 2:
                # "label: operation" or "label: data..."
                self.add_label(parts[0].value[:-1])  # drop the trailing ':'
                body = parts[1]
                if isinstance(body, lark.Tree) and body.data not in ("mixed", "rep"):
                    self.compile_operation(body)
                else:
                    self.compile_arg(body)
                continue

            head = parts[0]
            if isinstance(head, lark.Tree):
                self.compile_operation(head)
            elif head.type == "LABEL":
                self.add_label(head.value[:-1])
            else:
                # INCLUDE's /.+/ also swallows trailing blanks and a CRLF "\r";
                # strip them so they do not become part of the file name.
                with open(head.value[1:].strip(), "r") as f:
                    self.precompile(f.read())

    def compile(self, source):
        """Assemble *source* and return the encoded binary image.

        After the program, one zero word is reserved for each referenced
        general register that the program does not define as a label itself,
        followed by the ``END`` label.
        """
        self.precompile(source)

        used = set(self.used_regs)
        for reg in self.GENERAL_REGISTERS:
            if reg in used and reg not in self.labels:
                self.add_label(reg)
                self.emit(0)
        self.add_label("END")
        self.compile_labels()
        return self.encode()
# Command-line entry point:
#   wma.py SOURCE OUTPUT   assemble SOURCE and write the image to OUTPUT
#   wma.py                 assemble stdin and write the image to stdout
# Exits with status 1 (message on stderr) if assembly fails.
if __name__ == "__main__":
    wma = WMA()
    try:
        if len(sys.argv) == 3:
            with open(sys.argv[1], "r") as fin:
                image = wma.compile(fin.read())
            # Only create/truncate the output once assembly has succeeded.
            with open(sys.argv[2], "wb") as fout:
                fout.write(image)
        else:
            sys.stdout.buffer.write(wma.compile(sys.stdin.read()))
    except Exception as e:
        # stderr, not stdout: stdout may be carrying the binary image.
        print(e, file=sys.stderr)
        sys.exit(1)
|