-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtrain.py
49 lines (39 loc) · 2.21 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import tensorflow as tf
from dataset import *
import os
import argparse
from model import Model
import time
def train(epoch, batch_size, learning_rate, max_article_length):
    """Train the estimator on the SST data, then evaluate it and print the metrics.

    Checkpoints are written to a new ``ckpt/<MMDD_HHMMSS>`` directory whose
    name is taken from the local time at which the run configuration is built.

    Args:
        epoch: Forwarded to ``train_input_fn`` as its ``epoch`` argument.
        batch_size: Batch size handed to both the training and the
            evaluation input functions.
        learning_rate: Passed to the model via ``params['learning_rate']``.
        max_article_length: Used as ``padded_size`` by both input functions
            and passed to the model via ``params['max_article_length']``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    # Ask TensorFlow to grow its GPU memory allocation on demand.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    model_fn = Model().build

    checkpoint_dir = os.path.join('ckpt', time.strftime("%m%d_%H%M%S"))
    run_config = tf.estimator.RunConfig(
        session_config=session_config,
        model_dir=checkpoint_dir,
    )

    hyper_params = {
        'feature_columns': [tf.feature_column.numeric_column(key='x')],
        # NOTE(review): presumably (kernel size, filter count) pairs consumed
        # by Model.build -- confirm against model.py.
        'kernels': [(3, 512), (4, 512), (5, 512)],
        'num_classes': 2,
        'learning_rate': learning_rate,
        'max_article_length': max_article_length,
    }

    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config,
        params=hyper_params,
    )

    dataset = SST(Word2vecEnWordEmbedder)

    def train_inputs():
        # Zero-argument input function, as passed to Estimator.train.
        return dataset.train_input_fn(
            batch_size=batch_size,
            padded_size=max_article_length,
            epoch=epoch,
        )

    def eval_inputs():
        # Zero-argument input function, as passed to Estimator.evaluate.
        return dataset.eval_input_fn(
            batch_size=batch_size,
            padded_size=max_article_length,
        )

    estimator.train(input_fn=train_inputs)
    metrics = estimator.evaluate(input_fn=eval_inputs)

    # NOTE(review): the banner says "Test Set"; which split eval_input_fn
    # actually serves is decided in dataset.py -- confirm.
    print("----------------Evaluation Test Set----------------")
    print(metrics)
if __name__ == "__main__":
    # Padding length handed to train() as max_article_length; not exposed
    # on the command line.
    MAX_ARTICLE_LENGTH = 500

    # One row per command-line option: (flag, dest, help, default, type).
    option_specs = (
        ("--epoch", "epoch", "Train Epoch", 10, int),
        ("--batch-size", "batch_size", "Train Batch Size", 300, int),
        ("--learning-rate", "learning_rate", "Train Learning Rate", 0.001, float),
        ("--gpu-index", "gpu_index", "GPU Index Number", "0", str),
    )

    cli = argparse.ArgumentParser()
    for flag, dest, help_text, default, caster in option_specs:
        cli.add_argument(flag, dest=dest, help=help_text, default=default, type=caster)
    options = cli.parse_args()

    # Set CUDA_VISIBLE_DEVICES from --gpu-index before train() runs.
    os.environ['CUDA_VISIBLE_DEVICES'] = options.gpu_index

    train(
        epoch=options.epoch,
        batch_size=options.batch_size,
        learning_rate=options.learning_rate,
        max_article_length=MAX_ARTICLE_LENGTH,
    )