-
Notifications
You must be signed in to change notification settings - Fork 287
/
Copy pathcode2vec.py
38 lines (32 loc) · 1.53 KB
/
code2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from vocabularies import VocabType
from config import Config
from interactive_predict import InteractivePredictor
from model_base import Code2VecModelBase
def load_model_dynamically(config: 'Config') -> 'Code2VecModelBase':
    """Build the code2vec model for the backend named by ``config.DL_FRAMEWORK``.

    The backend module is imported lazily inside the matching branch, so only
    the selected framework's implementation is ever loaded.

    Args:
        config: Run configuration. Only ``DL_FRAMEWORK`` is read here; the
            whole object is then passed to the model constructor.

    Returns:
        A ``Code2VecModel`` instance from ``tensorflow_model`` (for
        ``'tensorflow'``) or ``keras_model`` (for ``'keras'``).

    Raises:
        ValueError: if ``config.DL_FRAMEWORK`` is neither ``'tensorflow'``
            nor ``'keras'``.
    """
    framework = config.DL_FRAMEWORK
    if framework == 'tensorflow':
        from tensorflow_model import Code2VecModel
    elif framework == 'keras':
        from keras_model import Code2VecModel
    else:
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would let an unknown framework skip both
        # branches and crash later with an UnboundLocalError on Code2VecModel.
        raise ValueError(
            "Unsupported DL_FRAMEWORK {!r}; expected 'tensorflow' or 'keras'.".format(framework))
    return Code2VecModel(config)
if __name__ == '__main__':
    # Parse command-line arguments into a verified configuration.
    config = Config(set_defaults=True, load_from_args=True, verify=True)

    model = load_model_dynamically(config)
    config.log('Done creating code2vec model')

    # Everything below may raise (training, file output, evaluation, the
    # interactive loop); the finally clause guarantees the model's session is
    # closed on every exit path, not only on success.
    try:
        if config.is_training:
            model.train()

        # Optionally export the token (origin) and target embeddings in
        # word2vec text format.
        if config.SAVE_W2V is not None:
            model.save_word2vec_format(config.SAVE_W2V, VocabType.Token)
            config.log('Origin word vectors saved in word2vec text format in: %s' % config.SAVE_W2V)
        if config.SAVE_T2V is not None:
            model.save_word2vec_format(config.SAVE_T2V, VocabType.Target)
            config.log('Target word vectors saved in word2vec text format in: %s' % config.SAVE_T2V)

        # Run a standalone evaluation when testing without training, or when
        # config.RELEASE is set. NOTE(review): presumably training evaluates on
        # its own, hence the `not is_training` guard — confirm in model code.
        if (config.is_testing and not config.is_training) or config.RELEASE:
            eval_results = model.evaluate()
            if eval_results is not None:
                # Replace the generic 'topk' label with the configured k value.
                config.log(
                    str(eval_results).replace('topk', 'top{}'.format(config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION)))

        if config.PREDICT:
            predictor = InteractivePredictor(config, model)
            predictor.predict()
    finally:
        model.close_session()