#!/usr/bin/python
import sys
import struct

import lark

# Lark grammar for the assembler's source language.
#
# A program is a sequence of commands separated by newlines or ";". A command
# is an optionally labelled operation, a labelled data definition, a bare
# label, or an include line ("+" followed by the rest of the line).
# Rule aliases ("-> name") give the tree names that WMA.compile_operation
# dispatches on. Several mnemonics share one alias (e.g. "mvj" and "mj").
#
# COM accepts a "#" followed by zero or more non-newline characters, so an
# empty comment (a lone "#") is valid as well.
GRAMMAR = r"""
start: _NL? command ((_NL+|";") command)* _NL?

command: LABEL? operation
       | LABEL (arg|mixed)
       | LABEL
       | INCLUDE
?operation: "org" INTEGER -> org
          | "nop" -> nop
          | ("mvj"|"mj") arg arg arg -> mj
          | "sblez" arg arg arg -> sjlez
          | "ablez" arg arg arg -> ajlez
          | "sblz" arg arg arg -> sjlz
          | "bles" arg arg arg -> jles
          | ("nbnz"|"tjt") arg arg -> tjt
          | ("dbnz"|"djt") arg arg -> djt
          | "sslez" arg arg -> sslez
          | "aslez" arg arg -> aslez
          | ("ibnc"|"ije") arg arg arg -> ije
          | "vblz" arg arg arg -> djlz
          | "xblz" arg arg arg -> xjlz
          | "dslz" arg -> dslz
          | ("ssgt"|"ssl") arg arg -> ssl
          | "mbnz" arg arg arg -> mbnz
          | "modbz" arg arg arg -> modbz
          | ("aja"|"aj") arg arg -> aj
          | "la" arg arg -> la
          | "ld" arg arg -> ld
          | "ia" arg -> ia
          | "jmc" arg -> jmc
          | ("jw"|"ja"|"call") arg -> ja
          | ("push"|"psh") arg -> psh
          | "pd" arg -> pd
          | "pop" arg -> pop
          | "shlbnz" arg arg arg -> shlbnz
          | "shrbnz" arg arg arg -> shrbnz
          | "nbz" arg arg -> nbz
          | "anz" arg arg arg -> anz
          | "abgz" arg arg arg -> abgz
          | "swp" arg arg -> swp
          | "str" arg arg -> str
          | "add" arg arg -> h_add
          | "sub" arg arg -> h_sub
          | "inc" arg -> h_inc
          | "dec" arg -> h_dec
          | ("mov"|"mv") arg arg -> h_mov
          | ("jmp"|"j") arg -> h_jmp
          | "hlt" -> h_halt
          | "out" arg -> h_out
          | "outn" arg -> h_outn
          | "in" arg -> h_in
          | "dir" arg -> h_dir
          | "ret" -> h_ret
mixed: arg arg+
?arg: INTEGER
    | DOUBLE
    | CHAR
    | CHARS
    | OFFSET
    | LABELOFFSET
    | QMARK
    | EMARK
    | NAMEOFFSET
    | NAME
    | rep
rep: (INTEGER|DOUBLE|CHAR|CHARS|NAME|NAMEOFFSET) "*" COUNT
COUNT: /[0-9]+/
INTEGER: /-?[0-9]+/
DOUBLE: /-?[0-9]+\.[0-9]+/
CHAR: "'" /./ "'"
CHARS: "\"" /[^"]*/ "\""
QMARK: "?"
EMARK: "!"
OFFSET: "?" /(-|\+)[0-9]+/
LABELOFFSET: "?" /(-|\+)[A-Za-z_][a-zA-Z0-9_]*/
NAMEOFFSET: /[A-Za-z_][a-zA-Z0-9_]*(-|\+)[0-9]+/
LABEL: /[A-Za-z_][a-zA-Z0-9_]*:/
NAME: /[A-Za-z_][a-zA-Z0-9_]*/
INCLUDE: "+" /.+/
_NL: /\n+/
IG: /[ \t\r]+/
COM: /#[^\n]*/
%ignore IG
%ignore COM
"""

class WMA:
  """Assembler that turns source text into a flat image of 8-byte words.

  Source is parsed with a Lark LALR parser built from GRAMMAR. Words are
  collected in `buffer`. A reference to a label is stored as a placeholder
  tuple and replaced by an address in `compile_labels`, after which `encode`
  packs the image.

  Placeholder tuples (first element is always False when produced here):
    (False, name)           -> origin + labels[name]
    (False, name, sign)     -> origin + (position +/- labels[name])
    (False, name, sign, n)  -> origin + (labels[name] +/- n)
  where `position` is the buffer index of the placeholder itself.
  """

  # Names that are translated straight to a fixed word and therefore may not
  # be redefined as labels.
  _BUILTIN_REGISTERS = {
    "PC": 0, "IO": -1, "Z": -2, "O": -3, "N": -4, "J": -5, "T": -6, "SP": -7,
    "EZ": -8, "SZ": -9, "MZ": -10, "JZ": -11, "W": -12, "MM": -13, "DR": -14,
    "ZZ": -15,
  }

  # Single-letter names that get a zero-initialised word appended to the image
  # when the program references them (see `compile`).
  _AUTO_REGISTERS = "ABCDEFGHIXYKRS"

  # Operation alias (from GRAMMAR) -> (opcode word, number of operands).
  _PLAIN_OPS = {
    "nop": (0, 0),
    "mj": (1, 3),
    "sjlez": (2, 3),
    "ajlez": (3, 3),
    "sjlz": (4, 3),
    "jles": (5, 3),
    "tjt": (6, 2),
    "djt": (7, 2),
    "sslez": (8, 2),
    "aslez": (9, 2),
    "ije": (10, 3),
    "djlz": (11, 3),
    "xjlz": (12, 3),
    "dslz": (13, 1),
    "ssl": (14, 2),
    "mbnz": (15, 3),
    "modbz": (16, 3),
    "aj": (17, 2),
    "la": (18, 2),
    "ld": (19, 2),
    "ia": (20, 1),
    "jmc": (21, 1),
    "ja": (22, 1),
    "psh": (23, 1),
    "pd": (24, 1),
    "pop": (25, 1),
    "shlbnz": (26, 3),
    "shrbnz": (27, 3),
    "nbz": (28, 2),
    "anz": (29, 3),
    "abgz": (30, 3),
    "swp": (31, 2),
    "str": (32, 2),
  }

  def __init__(self):
    """Create an assembler with an empty image and a ready parser."""
    # Base added to buffer offsets whenever an address is computed.
    self.origin = 1

    # Words emitted so far, and how many of them count towards addresses.
    self.buffer = []
    self.size = 0

    self.parser = lark.Lark(GRAMMAR, parser='lalr')

    # Every non-built-in NAME operand seen (may contain duplicates).
    self.used_regs = []
    # Label name -> buffer index at the point of definition.
    self.labels = {}

  def add_label(self, label):
    """Bind `label` to the current end of the buffer.

    Raises:
      Exception: if the label is already defined, or if it is the name of a
        built-in register.
    """
    if label in self.labels:
      raise Exception(f"Duplicated label: {label}.")
    if label in self._BUILTIN_REGISTERS:
      raise Exception(f"Register override: {label}.")

    self.labels[label] = len(self.buffer)

  def emit(self, *ops):
    """Append `ops` to the buffer and advance `size`.

    At least one value must be given. If the first value is a tuple whose
    first element is truthy, the values are appended but `size` is left
    unchanged.
    """
    self.buffer.extend(ops)

    first = ops[0]
    if type(first) is tuple and first[0]:
      return

    self.size += len(ops)

  def compile_arg(self, arg):
    """Emit the word(s) for one operand.

    `arg` is either a Lark token (a literal, a position marker, or a name)
    or a `mixed` / `rep` tree whose children are themselves operands.
    Names that are not built-in registers are recorded in `used_regs` and
    emitted as placeholder tuples for `compile_labels` to resolve.
    """
    if type(arg) is lark.Tree:
      if arg.data == "mixed":
        for child in arg.children:
          self.compile_arg(child)
      elif arg.data == "rep":
        repeated, count_token = arg.children[0], arg.children[1]
        count = int(count_token.value)

        for _ in range(count):
          self.compile_arg(repeated)
      return

    kind, text = arg.type, arg.value

    if kind == "INTEGER":
      self.emit(int(text))
    elif kind == "DOUBLE":
      self.emit(float(text))
    elif kind == "CHAR":
      # text is 'c' including both quotes.
      self.emit(ord(text[1]))
    elif kind == "CHARS":
      # One word per character; an empty string emits nothing.
      for char in text[1:-1]:
        self.emit(ord(char))
    elif kind == "QMARK":
      # "?": address of the word being emitted.
      self.emit(self.origin + self.size)
    elif kind == "EMARK":
      # "!": address of the word after the one being emitted.
      self.emit(self.origin + self.size + 1)
    elif kind == "OFFSET":
      # "?+n" / "?-n": address of this word plus a signed constant.
      self.emit(self.origin + self.size + int(text[1:]))
    elif kind == "LABELOFFSET":
      # "?+name" / "?-name": text[1] is the sign, the rest is the name.
      self.emit((False, text[2:], text[1]))
    elif kind == "NAMEOFFSET":
      sign = '+' if '+' in text else '-'
      name, offset = text.split(sign)

      self.emit((False, name, sign, int(offset)))
    elif kind == "NAME":
      word = self._BUILTIN_REGISTERS.get(text)
      if word is None:
        self.used_regs.append(text)
        self.emit((False, text))
      else:
        self.emit(word)

  def _emit_next_address(self):
    """Emit the address of the word that follows the one being emitted."""
    # Written relative to `origin` so the value stays correct after `org`;
    # with the default origin of 1 this equals the former `size + 2`.
    self.emit(self.origin + self.size + 1)

  def compile_operation(self, op):
    """Emit the words for one operation tree.

    `op.data` is the alias assigned in GRAMMAR and `op.children` are the
    operands. `org` only changes `origin`. Plain operations emit their
    opcode followed by their operands. The `h_*` aliases are macros that
    expand to an existing opcode, some fixed operands, and, where a jump
    target is required, the address of the next instruction. Unknown aliases
    are ignored.
    """
    name, args = op.data, op.children

    if name == "org":
      # Must set `origin`: that is the attribute all address computations
      # read. Labels are resolved with the origin in effect at the end.
      self.origin = int(args[0].value)
      return

    if name in self._PLAIN_OPS:
      opcode, argc = self._PLAIN_OPS[name]
      self.emit(opcode)
      for operand in args[:argc]:
        self.compile_arg(operand)
      return

    if name == "h_add":
      # Opcode 3 (ajlez) with fall-through as the jump target.
      self.emit(3)
      self.compile_arg(args[0])
      self.compile_arg(args[1])
      self._emit_next_address()
    elif name == "h_sub":
      # Opcode 2 (sjlez) with fall-through as the jump target.
      self.emit(2)
      self.compile_arg(args[0])
      self.compile_arg(args[1])
      self._emit_next_address()
    elif name == "h_inc":
      # Same as h_add with the built-in register O (-3) as first operand.
      self.emit(3, -3)
      self.compile_arg(args[0])
      self._emit_next_address()
    elif name == "h_dec":
      # Same as h_sub with the built-in register O (-3) as first operand.
      self.emit(2, -3)
      self.compile_arg(args[0])
      self._emit_next_address()
    elif name == "h_mov":
      source = args[0]
      literal = isinstance(source, lark.Token) and source.type in ("INTEGER", "DOUBLE", "CHAR")
      if literal:
        # Literal source: opcode 19 (ld), no jump target.
        self.emit(19)
        self.compile_arg(args[0])
        self.compile_arg(args[1])
      else:
        # Otherwise opcode 1 (mj) with fall-through as the jump target.
        self.emit(1)
        self.compile_arg(args[0])
        self.compile_arg(args[1])
        self._emit_next_address()
    elif name == "h_jmp":
      # mj PC PC target
      self.emit(1, 0, 0)
      self.compile_arg(args[0])
    elif name == "h_halt":
      # mj PC PC IO
      self.emit(1, 0, 0, -1)
    elif name == "h_out":
      # mj value IO next
      self.emit(1)
      self.compile_arg(args[0])
      self.emit(-1)
      self._emit_next_address()
    elif name == "h_outn":
      # mj value Z next
      self.emit(1)
      self.compile_arg(args[0])
      self.emit(-2)
      self._emit_next_address()
    elif name == "h_in":
      # mj IO destination next
      self.emit(1, -1)
      self.compile_arg(args[0])
      self._emit_next_address()
    elif name == "h_dir":
      # ld value DR
      self.emit(19)
      self.compile_arg(args[0])
      self.emit(-14)
    elif name == "h_ret":
      # mj Z Z J
      self.emit(1, -2, -2, -5)

  def compile_labels(self):
    """Replace every placeholder tuple in the buffer with its address.

    Only tuples whose first element is falsy are treated as placeholders;
    see the class docstring for the three shapes and their formulas.

    Raises:
      Exception: if a placeholder names a label that was never defined.
    """
    for position, entry in enumerate(self.buffer):
      if type(entry) is not tuple or entry[0]:
        continue

      label = entry[1]
      if label not in self.labels:
        raise Exception(f"Undefined label/register: {label}.")

      target = self.labels[label]

      if len(entry) == 3:
        # "?+name" / "?-name": relative to the placeholder's own position.
        relative = position + target if entry[2] == '+' else position - target
      elif len(entry) == 4:
        # "name+n" / "name-n": label plus or minus a constant.
        relative = target + entry[3] if entry[2] == '+' else target - entry[3]
      else:
        relative = target

      self.buffer[position] = self.origin + relative

  def encode(self):
    """Pack the buffer into bytes, eight bytes per word, little-endian.

    Floats are packed as doubles ("<d"); every other value is packed as a
    signed 64-bit integer ("<q"). Placeholder tuples must have been resolved
    first, otherwise `struct.pack` raises `struct.error`.
    """
    return b"".join(
      struct.pack("<d", word) if type(word) is float else struct.pack("<q", word)
      for word in self.buffer
    )

  def precompile(self, source):
    """Parse `source` and emit its words, recording label definitions.

    Include lines are handled by reading the named file and precompiling its
    contents in place. Placeholders are not resolved here.
    """
    ast = self.parser.parse(source)

    for command in ast.children:
      parts = command.children

      if len(parts) == 2:
        # Labelled operation or labelled data.
        label_token, body = parts
        self.add_label(label_token.value[:-1])  # drop the trailing ':'

        if isinstance(body, lark.Tree) and body.data not in ("mixed", "rep"):
          self.compile_operation(body)
        else:
          self.compile_arg(body)
        continue

      only = parts[0]
      if not isinstance(only, lark.Token):
        self.compile_operation(only)
      elif only.type == "LABEL":
        self.add_label(only.value[:-1])
      else:
        # INCLUDE is "+" followed by the rest of the line. Strip the blanks
        # after "+" and any trailing whitespace (e.g. "\r") so that
        # "+ lib.s" opens "lib.s".
        path = only.value[1:].strip()
        with open(path, "r") as f:
          self.precompile(f.read())

  def compile(self, source):
    """Assemble `source` and return the encoded image as bytes.

    After the source is processed, each name in the automatic register list
    that the program referenced gets a label and a zero word appended to the
    image, in the fixed order of that list. Placeholders are then resolved
    and the buffer is encoded.
    """
    self.precompile(source)

    for reg in self._AUTO_REGISTERS:
      if reg in self.used_regs:
        self.add_label(reg)
        self.buffer.append(0)

    self.compile_labels()

    return self.encode()

# Command-line driver.
#   <script> INPUT OUTPUT   assemble INPUT and write the image to OUTPUT
#   <script>                assemble stdin and write the image to stdout
# Any other argument count falls back to the stdin/stdout mode.
wma = WMA()

try:
  if len(sys.argv) == 3:
    with open(sys.argv[1], "r") as fin:
      image = wma.compile(fin.read())

    # Open the output only after assembly succeeded, so a failed run does
    # not create or truncate the output file.
    with open(sys.argv[2], "wb") as fout:
      fout.write(image)
  else:
    sys.stdout.buffer.write(wma.compile(sys.stdin.read()))
except Exception as e:
  # stdout may be carrying the binary image, so report errors on stderr.
  print(e, file=sys.stderr)

  sys.exit(1)