+import sys
+import struct
+import lark
# Lark grammar for the assembler source language.
#
# A program is a sequence of commands separated by newlines or ";".  A command
# is an operation (optionally labelled), a labelled data item, a bare label, or
# an include directive ("+path").  Every terminal that WMA.compile_arg and
# WMA.precompile handle (CHARS, QMARK, OFFSET, LABELOFFSET, INCLUDE, bare
# LABEL) is reachable from a rule, so the grammar and the compiler agree.
#
# NOTE(review): the alternatives "| LABEL", "| INCLUDE", "| CHARS", "| QMARK",
# "| OFFSET" and "| LABELOFFSET" were reconstructed from what the compiler
# handles -- confirm against the intended language definition.
GRAMMAR = r"""
start: _NL* command ((_NL+|";") command)* _NL*
command: LABEL? operation
    | LABEL (arg|mixed)
    | LABEL
    | INCLUDE
?operation: "nop" -> nop
    | ("mvj"|"mj") arg arg arg -> mj
    | "sblez" arg arg arg -> sjlez
    | "ablez" arg arg arg -> ajlez
    | "sblz" arg arg arg -> sjlz
    | "bles" arg arg arg -> jles
    | ("nbnz"|"tjt") arg arg -> tjt
    | ("dbnz"|"djt") arg arg -> djt
    | "sslez" arg arg -> sslez
    | "aslez" arg arg -> aslez
    | ("ibnc"|"ije") arg arg arg -> ije
    | "vblz" arg arg arg -> djlz
    | "xblz" arg arg arg -> xjlz
    | "dslz" arg -> dslz
    | ("ssgt"|"ssl") arg arg -> ssl
    | "mbnz" arg arg arg -> mbnz
    | "modbz" arg arg arg -> modbz
    | ("aja"|"aj") arg arg -> aj
    | "la" arg arg -> la
    | "ld" arg arg -> ld
    | "ia" arg -> ia
    | "jmc" arg -> jmc
    | ("jw"|"ja") arg -> ja
    | ("push"|"psh") arg -> psh
    | "pd" arg -> pd
    | "pop" arg -> pop
    | "shlbnz" arg arg arg -> shlbnz
    | "shrbnz" arg arg arg -> shrbnz
    | "nbz" arg arg -> nbz
    | "anz" arg arg arg -> anz
    | "abgz" arg arg arg -> abgz
    | "swp" arg arg -> swp
    | "add" arg arg -> h_add
    | "sub" arg arg -> h_sub
    | "inc" arg -> h_inc
    | "dec" arg -> h_dec
    | ("mov"|"mv") arg arg -> h_mov
    | ("jmp"|"j") arg -> h_jmp
    | "hlt" -> h_halt
    | "out" arg -> h_out
    | "in" arg -> h_in
mixed: arg arg+
?arg: INTEGER
    | CHAR
    | CHARS
    | QMARK
    | OFFSET
    | LABELOFFSET
    | NAME
    | rep
rep: arg "*" COUNT
COUNT: /[0-9]+/
INTEGER: /-?[0-9]+/
CHAR: "'" /./ "'"
CHARS: "\"" /[^"]*/ "\""
QMARK: "?"
OFFSET: "$" /(-|\+)[0-9]+/
LABELOFFSET: "$" /(-|\+)[A-Za-z][a-zA-Z0-9_]*/
LABEL: /[A-Za-z][a-zA-Z0-9_]*:/
NAME: /[A-Za-z][a-zA-Z0-9_]*/
INCLUDE: "+" /.+/
_NL: /\n+/
IG: /[ \t\r]+/
COM: /#[^\n]*/
%ignore IG
%ignore COM
"""
class WMA:
    """Assembler that turns source text into a packed stream of 64-bit words.

    Assembly happens in three stages, driven by :meth:`compile`:

    1. :meth:`precompile` parses the source and appends words to
       ``self.buffer``.  Besides plain integers the buffer temporarily holds
       two kinds of tuples: ``(True, name)`` marks the definition of a label
       and ``(False, name)`` / ``(False, name, sign)`` is a reference to one.
    2. :meth:`compile_labels` removes the definition markers and replaces
       every reference with a number.
    3. :meth:`encode` packs the integers as little-endian signed 64-bit
       values.

    Addresses are 1-based: the first emitted word has address 1.

    Attributes:
        buffer: Words (and, before label resolution, marker tuples) emitted
            so far.
        size: Number of real words emitted so far; label definition markers
            are not counted because they are removed before encoding.
        parser: The ``lark.Lark`` parser built from ``GRAMMAR``.
    """

    # Built-in names and the fixed negative values they assemble to.
    _REGISTERS = {
        "IO": -1, "Z": -2, "O": -3, "N": -4, "J": -5, "T": -6, "SP": -7,
        "EZ": -8, "SZ": -9, "MZ": -10, "JZ": -11, "W": -12, "MM": -13,
    }

    # Parse-tree name of each native operation -> (opcode, operand count).
    _OPCODES = {
        "nop": (0, 0), "mj": (1, 3), "sjlez": (2, 3), "ajlez": (3, 3),
        "sjlz": (4, 3), "jles": (5, 3), "tjt": (6, 2), "djt": (7, 2),
        "sslez": (8, 2), "aslez": (9, 2), "ije": (10, 3), "djlz": (11, 3),
        "xjlz": (12, 3), "dslz": (13, 1), "ssl": (14, 2), "mbnz": (15, 3),
        "modbz": (16, 3), "aj": (17, 2), "la": (18, 2), "ld": (19, 2),
        "ia": (20, 1), "jmc": (21, 1), "ja": (22, 1), "psh": (23, 1),
        "pd": (24, 1), "pop": (25, 1), "shlbnz": (26, 3), "shrbnz": (27, 3),
        "nbz": (28, 2), "anz": (29, 3), "abgz": (30, 3), "swp": (31, 2),
    }

    # Shorthand operations expanded into a native four-word instruction.
    # Template items: an int is emitted literally, "a<i>" compiles operand i,
    # and "next" emits ``self.size + 2`` -- one past the address of the word
    # being emitted, i.e. the word right after the expansion.
    _MACROS = {
        "h_add": (3, "a0", "a1", "next"),
        "h_sub": (2, "a0", "a1", "next"),
        "h_inc": (3, -3, "a0", "next"),
        "h_dec": (2, -3, "a0", "next"),
        "h_mov": (1, "a0", "a1", "next"),
        "h_jmp": (1, 0, 0, "a0"),
        "h_halt": (1, 0, 0, -1),
        "h_out": (1, "a0", -1, "next"),
        "h_in": (1, -1, "a0", "next"),
    }

    def __init__(self):
        """Create an empty assembler with a parser for ``GRAMMAR``."""
        self.buffer = []
        self.size = 0
        self.parser = lark.Lark(GRAMMAR)

    def emit(self, *ops):
        """Append ``ops`` to the buffer and account for them in ``size``.

        A call whose first item is a label definition marker
        (``(True, name)``) does not change ``size``.  Calling with no
        operands is a no-op.
        """
        if not ops:
            return
        self.buffer.extend(ops)
        first = ops[0]
        if isinstance(first, tuple) and first[0]:
            return
        self.size += len(ops)

    def compile_arg(self, arg):
        """Emit the word(s) for one operand or data item of the parse tree.

        ``arg`` is either a ``lark.Tree`` (``mixed``: several items in a row,
        ``rep``: one item repeated COUNT times) or a ``lark.Token``.  Trees
        with any other name and tokens of any other type emit nothing.
        """
        if isinstance(arg, lark.Tree):
            if arg.data == "mixed":
                for child in arg.children:
                    self.compile_arg(child)
            elif arg.data == "rep":
                item, count = arg.children[0], int(arg.children[1].value)
                for _ in range(count):
                    self.compile_arg(item)
            return

        kind, text = arg.type, arg.value
        if kind == "INTEGER":
            self.emit(int(text))
        elif kind == "CHAR":
            # text is 'x' including both quotes.
            self.emit(ord(text[1]))
        elif kind == "CHARS":
            for char in text[1:-1]:
                self.emit(ord(char))
        elif kind == "QMARK":
            # Address of the word following this one.
            self.emit(self.size + 2)
        elif kind == "OFFSET":
            # text is "$+n" / "$-n": this word's own address plus n.
            self.emit(self.size + int(text[1:]) + 1)
        elif kind == "LABELOFFSET":
            # text is "$+name" / "$-name"; resolved in compile_labels.
            self.emit((False, text[2:], text[1]))
        elif kind == "NAME":
            if text in self._REGISTERS:
                self.emit(self._REGISTERS[text])
            else:
                self.emit((False, text))

    def compile_operation(self, op):
        """Emit the words for one operation node of the parse tree.

        Native operations emit their opcode followed by their operands;
        ``h_*`` shorthands are expanded according to ``_MACROS``.  Nodes with
        an unknown name emit nothing.
        """
        name = str(op.data)
        native = self._OPCODES.get(name)
        if native is not None:
            opcode, arity = native
            self.emit(opcode)
            for operand in op.children[:arity]:
                self.compile_arg(operand)
            return

        for item in self._MACROS.get(name, ()):
            if item == "next":
                self.emit(self.size + 2)
            elif isinstance(item, str):
                self.compile_arg(op.children[int(item[1:])])
            else:
                self.emit(item)

    def compile_labels(self):
        """Strip label definitions from the buffer and resolve references.

        After this call the buffer contains only integers.  A plain reference
        becomes the label's 1-based address; a signed reference
        ``(False, name, sign)`` becomes ``1 + position + address`` for ``'+'``
        and ``1 + position - address`` otherwise, where ``position`` is the
        0-based index of the referencing word.

        Raises:
            Exception: on a duplicated label, a label that reuses a built-in
                register name, or a reference to an undefined label.
        """
        labels = {}
        words = []
        for item in self.buffer:
            if isinstance(item, tuple) and item[0]:
                label = item[1]
                if label in labels:
                    raise Exception(f"Duplicated label: {label}.")
                if label in self._REGISTERS:
                    raise Exception(f"Register override: {label}.")
                # The label names the next real word.  Building a new list
                # (instead of popping while indexing) guarantees that a
                # marker directly following another marker is not skipped.
                labels[label] = len(words) + 1
            else:
                words.append(item)
        # Replace contents in place so the list object itself stays the same.
        self.buffer[:] = words

        for position, item in enumerate(self.buffer):
            if not isinstance(item, tuple) or item[0]:
                continue
            label = item[1]
            if label not in labels:
                raise Exception(f"Undefined label/register: {label}.")
            address = labels[label]
            if len(item) == 3:
                delta = address if item[2] == '+' else -address
                self.buffer[position] = 1 + position + delta
            else:
                self.buffer[position] = address

    def encode(self):
        """Return the buffer packed as little-endian signed 64-bit integers."""
        return struct.pack(f"<{self.size}q", *self.buffer)

    def precompile(self, source):
        """Parse ``source`` and emit all of its commands into the buffer.

        Labels are emitted as definition markers, operations and data items
        are compiled, and include directives (``+path``) are read from disk
        and precompiled recursively.
        """
        tree = self.parser.parse(source)
        for command in tree.children:
            parts = command.children
            if len(parts) == 2:
                # "label: <operation or data>"
                self.emit((True, parts[0].value[:-1]))
                body = parts[1]
                # "mixed" and "rep" trees are data, not operations; sending a
                # "rep" to compile_operation would silently emit nothing.
                is_data_tree = body.data in ("mixed", "rep") if isinstance(body, lark.Tree) else True
                if is_data_tree:
                    self.compile_arg(body)
                else:
                    self.compile_operation(body)
                continue

            head = parts[0]
            if not isinstance(head, lark.Token):
                self.compile_operation(head)
            elif head.type == "LABEL":
                self.emit((True, head.value[:-1]))
            else:
                # Include directive: everything after the leading "+" is the
                # path; surrounding blanks are not part of it.
                with open(head.value[1:].strip(), "r") as included:
                    self.precompile(included.read())

    def compile(self, source):
        """Assemble ``source`` and return the encoded program as ``bytes``.

        One zero-initialised word labelled with each of the letters in
        ``ABCDEFGHIXYK`` is appended after the program before assembling.
        """
        registers = ";".join(f"{reg}:0" for reg in "ABCDEFGHIXYK")
        self.precompile(f"{source}\n{registers}")
        self.compile_labels()
        return self.encode()
# Command-line driver.
#   prog INPUT OUTPUT  -> assemble INPUT and write the binary to OUTPUT
#   prog               -> assemble stdin and write the binary to stdout
# Any failure is reported on stderr and the process exits with status 1.
wma = WMA()
try:
    if len(sys.argv) == 3:
        with open(sys.argv[1], "r") as fin:
            program = wma.compile(fin.read())
        # Open the output only after a successful compile so that a failed
        # run does not leave a truncated, empty output file behind.
        with open(sys.argv[2], "wb") as fout:
            fout.write(program)
    else:
        sys.stdout.buffer.write(wma.compile(sys.stdin.read()))
except Exception as e:
    # stdout may be carrying the binary program, so diagnostics go to stderr.
    print(e, file=sys.stderr)
    sys.exit(1)