-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtinyrv_opcodes_gen.py
executable file
·83 lines (77 loc) · 4.91 KB
/
tinyrv_opcodes_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#!/usr/bin/env python3
import re, csv, yaml, pathlib, collections
try:
    def dr(h,l):
        """Return the descending list of integers from h down to l, both inclusive."""
        return list(range(h,l-1,-1))

    base = pathlib.Path('riscv-opcodes')

    # Each input is opened in a `with` block so its handle is closed as soon as it has been parsed.
    with open(base / 'instr_dict.yaml') as fh:
        opcodes = yaml.safe_load(fh)

    # Numeric key (first column, parsed as hex) -> name (second column); rows from the
    # later file overwrite rows with the same key from the earlier one.
    csrs = {}
    for csv_name in ['csrs.csv', 'csrs32.csv']:
        with open(base / csv_name) as fh:
            csrs.update((int(a, 16), n) for a, n in csv.reader(fh, skipinitialspace=True))

    # Field name -> instruction bit positions it occupies, highest bit first.
    with open(base / 'arg_lut.csv') as fh:
        arg_bits = dict((a, dr(int(h),int(l))) for a, h, l in csv.reader(fh, skipinitialspace=True))

    # Immediate scrambling is recovered from the latex_mapping lines of constants.py
    # (the original author already asks whether there is some better way).
    with open(base / 'constants.py') as fh:
        for s in fh:
            if mask := re.match(r"latex_mapping\[['\"](.*?)['\"]\] = ['\"][^\[]*\[([^\]]*)\]['\"]", s):
                # mask[1] is the field name, mask[2] the bracketed bit list: '$\\vert$'-separated
                # parts, each either a single bit number or a 'hi:lo' range.
                fbits = sum([(dr(*(int(i) for i in part.split(':'))) if ':' in part else [int(part)]) for part in mask[2].split('$\\\\vert$')], [])
                # locs is indexed from the end by value bit number: locs[-b-1] is the
                # instruction bit that supplies value bit b, or -1 if no bit supplies it.
                locs = [-1] * (max(fbits)+1)
                for i, b in enumerate(fbits): locs[-b-1] = arg_bits[mask[1]][i]
                arg_bits[mask[1]] = locs
except Exception as e:
    # Keep the generic Exception type and message, but chain the real cause explicitly
    # so the traceback shows what actually failed (missing file, bad row, unknown field, ...).
    raise Exception("Unable to load RISC-V specs. Do:\n"
                    "git clone https://github.com/riscv/riscv-opcodes.git\n"
                    "make -C riscv-opcodes") from e
# Post-process every opcode record in place: parse its numeric fields, work out which
# instruction bits feed each argument, and generate a getter expression per argument.
for name, op in opcodes.items():
    op['name'] = name
    op['mask'] = int(op['mask'], 16)
    op['match'] = int(op['match'], 16)
    del op['encoding']

    merged = op['arg_bits'] = {}
    for vf in op['variable_fields']:
        if vf in arg_bits:
            field_bits = arg_bits[vf]
            # Left-pad with -1 to 32 entries so that hi and lo parts line up when combined.
            bits = [-1] * (32 - len(field_bits)) + field_bits
            # Parts of one argument share a key once 'hi', 'lo' and 'c_' are stripped.
            vf2 = vf.replace('hi', '').replace('lo', '').replace('c_', '')
            if vf2 in merged:
                # Unused slots are -1, so the element-wise maximum overlays the two parts.
                merged[vf2] = [max(a, b) for a, b in zip(merged[vf2], bits)]
            else:
                merged[vf2] = bits

    getters = op['arg_getter'] = {}
    for vf in merged:
        padded = merged[vf]
        # Drop the leading padding: keep everything from the first real bit position on.
        first_used = next(n for n, e in enumerate(padded) if e >= 0)
        trimmed = padded[first_used:]
        merged[vf] = trimmed

        # Walk the value bits from the least significant one. Consecutive bits that need
        # the same shift are folded into a single '(x<shift>)&<mask>' term.
        pieces, prev_shift, prev_mask = [], None, None
        for target, source in enumerate(reversed(trimmed)):
            if source >= 0:
                if target < source:
                    shift = f'>>{source-target}'
                elif target > source:
                    shift = f'<<{target-source}'
                else:
                    shift = ''
                if shift != prev_shift:
                    prev_shift, prev_mask = shift, 1 << target
                    pieces.append(f'(x{shift})&{prev_mask}')
                else:
                    prev_mask |= 1 << target
                    pieces[-1] = f'(x{shift})&{prev_mask}'

        expr = '|'.join(pieces)
        # Sign-extend when the top value bit comes from instruction bit 31 (csr excepted),
        # or when the argument is one of the explicitly listed immediates.
        signed = (trimmed[0] == 31 and vf != 'csr') or vf in {'imm6', 'nzimm6', 'nzimm18', 'nzimm10', 'imm12', 'bimm9'}
        # The surrounding '$' markers are placeholders that are substituted when the
        # tables are written out.
        if signed:
            getters[vf] = f'$sext({len(trimmed)},' + expr + ')$'
        else:
            getters[vf] = '$' + expr + '$'
# Serves two purposes: the masks of the most common ops come first in the mask list,
# and the more specific compressed ops are given precedence.
common_ops = (
    'addi,sw,lw,jal,bne,beq,add,jalr,lbu,slli,lui,andi,or,bltu,srli,and,sub,blt,bgeu,xor,'
    'sb,auipc,sltiu,bge,lb,mul,sltu,lhu,sll,srl,sh,amoadd_w,xori,ori,csrrci,csrrs,'
    'c_nop,c_addi16sp,c_ebreak,c_jr,c_jalr'
).split(',')
def make_mm(affinity, ops=None, common=None):
    """Build the mask/match decode table and the alias map for one ISA flavour.

    Args:
        affinity: Substring looked for in an opcode's extension names (the script
            passes 'rv32' or 'rv64'). When two opcodes share the same mask and
            match, an opcode with a matching extension takes over the slot.
        ops: Mapping of opcode name to record; each record needs 'name', 'mask',
            'match' and 'extension'. Defaults to the module-level ``opcodes``.
        common: Opcode names whose masks are placed first in the result.
            Defaults to the module-level ``common_ops``.

    Returns:
        A pair ``(mask_match, aliases)``. ``mask_match`` is a list of
        ``(mask, {match: '$name$'})`` tuples in first-seen mask order, walking
        ``common`` first and then ``ops``. ``aliases`` maps the name occupying a
        slot to the set of other names that collided with it.

    Raises:
        KeyError: If a name in ``common`` is not a key of ``ops``.
    """
    ops = opcodes if ops is None else ops
    common = common_ops if common is None else common

    # Bucket the opcodes by mask in a single pass instead of rescanning every opcode
    # for every mask. Names ending in '_rv32' are never entered into the table.
    by_mask = collections.defaultdict(list)
    for op in ops.values():
        if not op['name'].endswith('_rv32'):
            by_mask[op['mask']].append(op)

    # dict.fromkeys de-duplicates while keeping first-seen order.
    mask_order = dict.fromkeys(ops[op_name]['mask'] for op_name in list(common) + list(ops))

    mask_match = []
    mask_match_aliases = collections.defaultdict(set)
    for mask in mask_order:
        matches = {}
        # .get() keeps masks that only have '_rv32' opcodes: they yield an empty dict.
        for op in by_mask.get(mask, ()):
            name, match = op['name'], op['match']
            if match not in matches:
                # '$...$' marks the name so the writer can turn it into an opcodes[...] lookup.
                matches[match] = f'${name}$'
            elif any(affinity in ext for ext in op['extension']):
                # The newcomer fits this flavour: it takes the slot, the old holder becomes its alias.
                mask_match_aliases[name].add(matches[match][1:-1])
                matches[match] = f'${name}$'
            else:
                # The current holder stays; record the newcomer as one of its aliases.
                mask_match_aliases[matches[match][1:-1]].add(name)
        mask_match.append((mask, matches))
    return mask_match, dict(mask_match_aliases)
# Build the decode tables for both flavours, then emit everything as a Python module.
mask_match_rv32, mask_match_aliases_rv32 = make_mm('rv32')
mask_match_rv64, mask_match_aliases_rv64 = make_mm('rv64')

print('writing tinyrv/opcodes.py')
with open('tinyrv/opcodes.py', 'w') as out:
    # In the repr of the tables, '$...$'-delimited strings are placeholders:
    # in opcodes they become `lambda x:` getter expressions, in the mask/match
    # tables they become `opcodes['name']` lookups.
    opcodes_src = str(opcodes).replace("'$", "lambda x:").replace("$'", "")
    rv32_src = str(mask_match_rv32).replace("'$", "opcodes['").replace("$'", "']")
    rv64_src = str(mask_match_rv64).replace("'$", "opcodes['").replace("$'", "']")
    out.writelines([
        '# auto-generated by tinyrv_opcodes_gen.py\n',
        # Helper used by the generated getters: sign-extend the low `length` bits of `word`.
        'def sext(length, word): return word|~((1<<length)-1) if word&(1<<(length-1)) else word&((1<<length)-1)\n',
        f'opcodes={opcodes_src}\n',
        f'arg_bits={str(arg_bits)}\n',
        f'csrs={str(csrs)}\n',
        f'mask_match_rv32={rv32_src}\n',
        f'mask_match_aliases_rv32={dict(mask_match_aliases_rv32)}\n',
        f'mask_match_rv64={rv64_src}\n',
        f'mask_match_aliases_rv64={dict(mask_match_aliases_rv64)}\n',
    ])