-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_helpers.py
156 lines (130 loc) · 6.05 KB
/
generate_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import pandas as pd
from pandas import DataFrame
from pathlib import Path
from sae_lens import SAE, HookedSAETransformer
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
# LOAD DATASET FROM .JSON OR .JSONL INTO A DATAFRAME
def load_json_dataset(str_path:str) -> pd.DataFrame:
    '''
    load_json_dataset(path_to_file)
    Takes in a .json or .jsonl dataset and loads it into a DataFrame

    str_path : path to the dataset file. Only its extension is inspected
               here; the value itself is handed to pandas.read_json unchanged.
               The extension is compared case-insensitively, so 'data.JSONL'
               is accepted as well as 'data.jsonl'.
    Returns  : the DataFrame produced by pandas.read_json
    Raises   : ValueError if the extension is neither .json nor .jsonl
    '''
    suffix = Path(str_path).suffix.lower()
    if suffix == '.json':
        return pd.read_json(str_path)
    if suffix == '.jsonl':
        # JSON Lines: one JSON record per line
        return pd.read_json(str_path, lines=True)
    raise ValueError('Invalid file type: Currently handles .json and .jsonl')
# LOAD MULTIPLE SAES AND GET A LIST OF OUTPUTS
def load_pretrained_SAEs (sae_release:str, sae_id_list:list, device) -> tuple:
    '''
    load_SAEs : Instead of loading each pretrained SAE individually,
    this helps you load all the SAEs you want all at once.

    sae_release : release name, passed to SAE.from_pretrained as `release`
    sae_id_list : SAE ids to load from that release (not always a hook point!)
    device      : passed through to SAE.from_pretrained as `device`

    Returns a tuple of three lists (sae_list, cfg_dict_list, sparsity_list),
    each in the same order as sae_id_list
    (note: ok if sparsity values are None).
    An empty sae_id_list gives three empty lists.
    '''
    sae_list, cfg_dict_list, sparsity_list = [], [], []
    for sae_id in sae_id_list:
        # NOTE(review): this unpacking relies on SAE.from_pretrained returning a
        # (sae, cfg_dict, sparsity) triple -- confirm against the installed SAELens version
        sae, cfg_dict, sparsity = SAE.from_pretrained(
            release=sae_release,   # <- Release name
            sae_id=sae_id,         # <- SAE id (not always a hook point!)
            device=device,
        )
        sae_list.append(sae)
        cfg_dict_list.append(cfg_dict)
        sparsity_list.append(sparsity)
    return sae_list, cfg_dict_list, sparsity_list
# LOAD SAE DIRECTORY AND FILTER BY MODEL AND RELEASES OF INTEREST
def sae_directory_info(model=None, release=None, exact_match_model:bool=True, exact_match_release:bool=False)-> DataFrame:
    '''
    Returns the SAELens pretrained-SAE directory (get_pretrained_saes_directory)
    as a DataFrame, optionally narrowed down by model and/or release.

    model   : value to look for in the 'model' column, or None for no model filter
    release : value to look for in the 'release' column, or None for no release filter
    exact_match_model / exact_match_release :
        True  -> the column value must equal the given name
        False -> the given name only has to appear inside the column value
                 (e.g. 'gemma-scope')
    By default the model has to match exactly and the release does not.
    With both model and release left as None the whole directory is returned.
    '''
    # table construction taken directly from tutorial 2.0 of sae lens:
    # one row per directory entry, one column per attribute of that entry
    entries = {name: info.__dict__ for name, info in get_pretrained_saes_directory().items()}
    directory_df = pd.DataFrame.from_records(entries).T
    directory_df = directory_df.drop(
        columns=["expected_var_explained", "expected_l0", "config_overrides", "conversion_func"]
    )

    # apply the model filter first, then the release filter; a None value skips that filter
    filters = (
        ('model', model, exact_match_model),
        ('release', release, exact_match_release),
    )
    for column, wanted, exact in filters:
        if wanted is None:
            continue
        if exact:
            keep = directory_df[column] == wanted
        else:
            keep = [wanted in value for value in directory_df[column].to_list()]
        directory_df = directory_df[keep]
    return directory_df
######################################################################################################
### FUNCTIONS TO FILTER OUT DATAFRAME RESULTS TO HELP LOAD APPROPRIATE SAES ####
### GEMMA-SCOPE SPECIFIC FUNCTIONS UNLESS OTHERWISE SPECIFIED #################
### WILL NEED TO UPDATE IN THE FUTURE DEPENDING ON HOW PEOPLE NAME THEIR SAES ###
######################################################################################################
def get_saeids_for_layer(sae_id_list:list, layer:int=0, width:int=16):
    '''
    get_saeids_for_layer(sae_id_list:list, layer:int=0, width:int=16)
    WORKS WITH GEMMA-SCOPE AND GPT2-SMALL RIGHT NOW
    FOR GEMMA: assumes that there will be a layer_XX and a width_YY somewhere in the name
    FOR GPT SMALL: assumes there will be blocks.XX. in the name

    sae_id_list : SAE id strings to search through
    layer       : layer number to select
    width       : width number to select (only used for the gemma-style match)

    Returns the ids (in their original order) that match the gemma-style
    pattern; if none do, the ids that match the gpt2-small pattern instead.
    The layer and width numbers must match in full: layer=1 selects
    'layer_1/...' but not 'layer_10/...'.
    '''
    import re  # function-scope import so the block does not depend on a file-level import

    # (?!\d) = "not followed by another digit": without it 'layer_1' would also
    # hit 'layer_10'..'layer_19' and 'width_16' would also hit e.g. 'width_163k'
    layer_pattern = re.compile(rf'layer_{re.escape(str(layer))}(?!\d)')
    width_pattern = re.compile(rf'width_{re.escape(str(width))}(?!\d)')

    # first tries gemmascope expectation
    newlist = [s for s in sae_id_list if width_pattern.search(s) and layer_pattern.search(s)]
    # then tries gpt2small expectation (the trailing '.' already delimits the layer number)
    if not newlist:
        gpt2_token = f'blocks.{layer}.'
        newlist = [s for s in sae_id_list if gpt2_token in s]
    return newlist
def parse_sae_LO(sae_id):
    '''
    FOR GEMMASCOPE
    Returns, as an int, whatever follows the last '_l0_' in the SAE id.
    Assumes the name ENDS with _l0_X where X is an integer.
    '''
    _, _, l0_text = sae_id.rpartition('_l0_')
    return int(l0_text)
def parse_sae_layer(sae_id):
    '''
    FOR GEMMASCOPE
    Returns the layer number of an SAE id as an int.
    Assumes the name BEGINS with 'layer_XX/': the text before the first '/'
    is taken, and the part after its last '_' is parsed as the number.
    '''
    first_segment = sae_id.partition('/')[0]
    layer_text = first_segment.rpartition('_')[2]
    return int(layer_text)
def lowest_L0_sae_id(sae_id_list:list):
    '''
    FOR GEMMASCOPE
    Returns the SAE id from sae_id_list whose L0 value, as read by
    parse_sae_LO, is the smallest.
    Assumes every name ENDS with _l0_X.
    '''
    best_sae_id = min(sae_id_list, key=parse_sae_LO)
    return best_sae_id
def get_lowest_L0_sae_id_for_each_layer(model_df:DataFrame, layer=None, width=16):
    '''
    FOR GEMMASCOPE
    This function takes DataFrame with a SINGLE ENTRY (single release) produced by sae_directory_info
    e.g. model_df = sae_directory_info(model=model_name, release=release_id, exact_match_release=True)
    Then it finds the lowest L0 for each layer of interest by name, not by dataframe
    (WE SHOULD CHECK: tutorial 2.0 removes the L0 column, so I think it's the correct way?)

    model_df : DataFrame with a 'saes_map' column; only the first row's
               mapping is used, and its keys are taken as the SAE ids
    layer    : list of natural numbers, or None.
               If None, every layer number found in the SAE ids is used,
               in ascending order.
    width    : width passed on to get_saeids_for_layer

    Returns a list with one entry per requested layer, in the same order as
    `layer`: the SAE id (a string) with the lowest L0 for that layer, or
    None when no SAE was found for it (a message is printed in that case).
    '''
    # THIS - [0] - IS WHERE THE SINGLE DATAFRAME ENTRY COMES IN, CAN BE FIXED IN THE FUTURE
    all_sae_ids = list(model_df.saes_map.to_list()[0].keys())

    if layer is None:
        # sorted so the output order is deterministic instead of depending on set iteration order
        layer = sorted({parse_sae_layer(s) for s in all_sae_ids})

    # one slot per requested layer; slots stay None for layers without any SAE
    lowest_L0_sae_id_list = [None]*len(layer)
    for idx, ilayer in enumerate(layer):
        layer_sae_ids = get_saeids_for_layer(all_sae_ids, layer=ilayer, width=width)
        if not layer_sae_ids:
            print(f'No SAEs found for layer {ilayer}')
            continue
        if len(layer_sae_ids) == 1:
            # single candidate: store the id itself (not the one-element list),
            # without trying to parse an L0 value out of its name
            lowest_L0_sae_id_list[idx] = layer_sae_ids[0]
        else:
            lowest_L0_sae_id_list[idx] = lowest_L0_sae_id(layer_sae_ids)
    return lowest_L0_sae_id_list