#!/usr/bin/python
"""Assemble a small assembly-like language into a flat binary image.

The source text is parsed with a Lark grammar, lowered to a list of integer
cells and packed as little-endian signed 64-bit values.

Usage:
    <script> INPUT OUTPUT    read INPUT and write the image to OUTPUT
    <script>                 read the source from stdin, write to stdout
"""
import sys
import struct
import lark

GRAMMAR = r"""
start: _NL? command ((_NL+|";") command)* _NL?
command: LABEL? operation | LABEL (arg|mixed) | LABEL | INCLUDE
?operation: "nop" -> nop
    | ("mvj"|"mj") arg arg arg -> mj
    | "sblez" arg arg arg -> sjlez
    | "ablez" arg arg arg -> ajlez
    | "sblz" arg arg arg -> sjlz
    | "bles" arg arg arg -> jles
    | ("nbnz"|"tjt") arg arg -> tjt
    | ("dbnz"|"djt") arg arg -> djt
    | "sslez" arg arg -> sslez
    | "aslez" arg arg -> aslez
    | ("ibnc"|"ije") arg arg arg -> ije
    | "vblz" arg arg arg -> djlz
    | "xblz" arg arg arg -> xjlz
    | "dslz" arg -> dslz
    | ("ssgt"|"ssl") arg arg -> ssl
    | "mbnz" arg arg arg -> mbnz
    | "modbz" arg arg arg -> modbz
    | ("aja"|"aj") arg arg -> aj
    | "la" arg arg -> la
    | "ld" arg arg -> ld
    | "ia" arg -> ia
    | "jmc" arg -> jmc
    | ("jw"|"ja") arg -> ja
    | ("push"|"psh") arg -> psh
    | "pd" arg -> pd
    | "pop" arg -> pop
    | "shlbnz" arg arg arg -> shlbnz
    | "shrbnz" arg arg arg -> shrbnz
    | "nbz" arg arg -> nbz
    | "anz" arg arg arg -> anz
    | "abgz" arg arg arg -> abgz
    | "swp" arg arg -> swp
    | "add" arg arg -> h_add
    | "sub" arg arg -> h_sub
    | "inc" arg -> h_inc
    | "dec" arg -> h_dec
    | ("mov"|"mv") arg arg -> h_mov
    | ("jmp"|"j") arg -> h_jmp
    | "hlt" -> h_halt
    | "out" arg -> h_out
    | "in" arg -> h_in
mixed: arg arg+
?arg: INTEGER | CHAR | CHARS | QMARK | LABELOFFSET | OFFSET | NAME | rep
rep: arg "*" COUNT
COUNT: /[0-9]+/
INTEGER: /-?[0-9]+/
CHAR: "'" /./ "'"
CHARS: "\"" /[^"]*/ "\""
QMARK: "?"
OFFSET: "$" /(-|\+)[0-9]+/
LABELOFFSET: "$" /(-|\+)[A-Za-z][a-zA-Z0-9_]*/
LABEL: /[A-Za-z][a-zA-Z0-9_]*:/
NAME: /[A-Za-z][a-zA-Z0-9_]*/
INCLUDE: "+" /.+/
_NL: /\n+/
IG: /[ \t\r]+/
COM: /#.*[^\n]/
%ignore IG
%ignore COM
"""

# Reserved names and the fixed cell value each one assembles to.  A label
# may not reuse any of these names.
_REGISTERS = {
    "IO": -1,
    "Z": -2,
    "O": -3,
    "N": -4,
    "J": -5,
    "T": -6,
    "SP": -7,
    "EZ": -8,
    "SZ": -9,
    "MZ": -10,
    "JZ": -11,
    "W": -12,
    "MM": -13,
}

# Operations that assemble to "opcode, then every operand in order".
_OPCODES = {
    "nop": 0,
    "mj": 1,
    "sjlez": 2,
    "ajlez": 3,
    "sjlz": 4,
    "jles": 5,
    "tjt": 6,
    "djt": 7,
    "sslez": 8,
    "aslez": 9,
    "ije": 10,
    "djlz": 11,
    "xjlz": 12,
    "dslz": 13,
    "ssl": 14,
    "mbnz": 15,
    "modbz": 16,
    "aj": 17,
    "la": 18,
    "ld": 19,
    "ia": 20,
    "jmc": 21,
    "ja": 22,
    "psh": 23,
    "pd": 24,
    "pop": 25,
    "shlbnz": 26,
    "shrbnz": 27,
    "nbz": 28,
    "anz": 29,
    "abgz": 30,
    "swp": 31,
}

# Template markers used in _MACROS.
_ARG = object()   # compile the next operand of the operation here
_NEXT = object()  # emit the 1-based number of the cell right after this one

# Convenience operations: each expands to a fixed sequence of literal cells,
# operands (_ARG) and "next cell" references (_NEXT), emitted left to right.
_MACROS = {
    "h_add": (3, _ARG, _ARG, _NEXT),
    "h_sub": (2, _ARG, _ARG, _NEXT),
    "h_inc": (3, -3, _ARG, _NEXT),
    "h_dec": (2, -3, _ARG, _NEXT),
    "h_mov": (1, _ARG, _ARG, _NEXT),
    "h_jmp": (1, 0, 0, _ARG),
    "h_halt": (1, 0, 0, -1),
    "h_out": (1, _ARG, -1, _NEXT),
    "h_in": (1, -1, _ARG, _NEXT),
}


class WMA:
    """Assembler state: the cell buffer, its size and the source parser.

    ``buffer`` holds integers (finished cells) and tuples (placeholders):
    ``(True, name)`` marks where label ``name`` is defined and occupies no
    cell; ``(False, name)`` is a reference to a label; ``(False, name, sign)``
    is a ``$+name`` / ``$-name`` reference.  ``size`` counts the entries that
    occupy a cell, i.e. everything except label definitions.
    """

    def __init__(self):
        self.buffer = []
        self.size = 0
        self.parser = lark.Lark(GRAMMAR)

    def emit(self, *ops):
        """Append ``ops`` to the buffer and advance ``size`` accordingly.

        A label-definition marker (a tuple whose first item is true) is
        stored but not counted, because it does not occupy a cell.
        """
        self.buffer.extend(ops)
        if isinstance(ops[0], tuple) and ops[0][0]:
            return
        self.size += len(ops)

    def compile_arg(self, arg):
        """Emit the cell(s) for one operand or data item.

        ``arg`` is either a Lark token or a ``mixed`` / ``rep`` tree.
        """
        if isinstance(arg, lark.Tree):
            if arg.data == "mixed":
                for subnode in arg.children:
                    self.compile_arg(subnode)
            elif arg.data == "rep":
                # children are [item, COUNT]; the item is compiled COUNT times
                count = int(arg.children[1].value)
                for _ in range(count):
                    self.compile_arg(arg.children[0])
            return

        kind = arg.type
        text = arg.value
        if kind == "INTEGER":
            self.emit(int(text))
        elif kind == "CHAR":
            # text is 'x' including the quotes
            self.emit(ord(text[1]))
        elif kind == "CHARS":
            for char in text[1:-1]:
                self.emit(ord(char))
        elif kind == "QMARK":
            # 1-based number of the cell following this one
            self.emit(self.size + 2)
        elif kind == "OFFSET":
            # 1-based number of this cell plus the signed literal after "$"
            self.emit(self.size + int(text[1:]) + 1)
        elif kind == "LABELOFFSET":
            # "$+name" / "$-name": resolved later in compile_labels
            self.emit((False, text[2:], text[1]))
        elif kind == "NAME":
            if text in _REGISTERS:
                self.emit(_REGISTERS[text])
            else:
                self.emit((False, text))

    def compile_operation(self, op):
        """Emit the cells for one operation tree.

        Raises:
            Exception: if ``op.data`` is not a known operation name.
        """
        name = op.data
        if name in _OPCODES:
            self.emit(_OPCODES[name])
            for operand in op.children:
                self.compile_arg(operand)
            return

        template = _MACROS.get(name)
        if template is None:
            raise Exception(f"Unknown operation: {name}.")
        operands = iter(op.children)
        for item in template:
            if item is _ARG:
                self.compile_arg(next(operands))
            elif item is _NEXT:
                # evaluated at emission time, after any operands before it
                self.emit(self.size + 2)
            else:
                self.emit(item)

    def compile_labels(self):
        """Strip label definitions from the buffer and resolve references.

        A label's value is the 1-based number of the first cell after its
        definition.  A plain reference is replaced by that value; a
        ``$+name`` / ``$-name`` reference is replaced by the 1-based number
        of the referencing cell plus / minus that value.

        Raises:
            Exception: on a duplicated label, a label named like a register,
                or a reference to an undefined label.
        """
        labels = {}
        cells = []
        # Rebuild the buffer without definition markers.  Doing it in one
        # pass (instead of pop-and-advance) keeps consecutive definitions
        # such as "a:" directly followed by "b:" from being skipped.
        for item in self.buffer:
            if isinstance(item, tuple) and item[0]:
                label = item[1]
                if label in labels:
                    raise Exception(f"Duplicated label: {label}.")
                if label in _REGISTERS:
                    raise Exception(f"Register override: {label}.")
                labels[label] = len(cells) + 1
            else:
                cells.append(item)
        self.buffer = cells

        for position, item in enumerate(self.buffer):
            if not isinstance(item, tuple):
                continue
            label = item[1]
            if label not in labels:
                raise Exception(f"Undefined label/register: {label}.")
            target = labels[label]
            if len(item) == 3:
                if item[2] == '+':
                    self.buffer[position] = 1 + (position + target)
                else:
                    self.buffer[position] = 1 + (position - target)
            else:
                self.buffer[position] = target

    def encode(self):
        """Return the buffer packed as little-endian signed 64-bit integers."""
        return struct.pack(f"<{self.size}q", *self.buffer)

    def _compile_payload(self, node):
        """Compile what follows a label: an operation or a data item."""
        is_data = (
            not isinstance(node, lark.Tree)
            # "rep" is data as well: "buf: 0*10" must reach compile_arg
            or node.data in ("mixed", "rep")
        )
        if is_data:
            self.compile_arg(node)
        else:
            self.compile_operation(node)

    def precompile(self, source):
        """Parse ``source`` and emit its cells and label placeholders.

        An include command (``+path``) is handled by reading ``path`` and
        precompiling its contents in place.
        """
        ast = self.parser.parse(source)
        for command in ast.children:
            children = command.children
            if len(children) == 2:
                # LABEL followed by an operation or data
                self.emit((True, children[0].value[:-1]))
                self._compile_payload(children[1])
                continue

            only = children[0]
            if not isinstance(only, lark.Token):
                self.compile_operation(only)
            elif only.type == "LABEL":
                self.emit((True, only.value[:-1]))
            else:
                # INCLUDE token: "+" followed by the file path
                with open(only.value[1:], "r") as f:
                    self.precompile(f.read())

    def compile(self, source):
        """Assemble ``source`` and return the binary image.

        Twelve zero-initialised labelled cells (A, B, C, D, E, F, G, H, I,
        X, Y and K) are appended after the user's source.
        """
        trailer = ";".join(f"{reg}:0" for reg in "ABCDEFGHIXYK")
        self.precompile(source + "\n" + trailer)
        self.compile_labels()
        return self.encode()


def main():
    """Command-line entry point; exits with status 1 on any error."""
    wma = WMA()
    try:
        if len(sys.argv) == 3:
            with open(sys.argv[1], "r") as fin:
                source = fin.read()
            # Compile before opening the output so that a failed build does
            # not truncate an existing file or leave an empty one behind.
            image = wma.compile(source)
            with open(sys.argv[2], "wb") as fout:
                fout.write(image)
        else:
            sys.stdout.buffer.write(wma.compile(sys.stdin.read()))
    except Exception as e:
        # stdout may be carrying the binary image, so report on stderr
        print(e, file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()