-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMIL_train.py
63 lines (47 loc) · 1.38 KB
/
MIL_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import sys
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import MIL
import staintools
import utils
import data_prep
# ---------------------------------------------------------------------------
# Filesystem locations
# ---------------------------------------------------------------------------
LOG_DIR: str = './test/log'          # log_dir handed to MIL.MIL
METAGRAPH_DIR: str = './test/model'  # out_dir handed to the training call
TFR_DIR: str = './tfr_test'          # directory of TFRecord files to train on
PRE_DIR: str = './tfr_test'          # pre-training TFRecords (same directory as TFR_DIR)
SCN_DIR: str = '.'
PNG_DIR: str = './png'
DIC_DIR: str = '.'

# ---------------------------------------------------------------------------
# Tiling / colour settings
# NOTE(review): TILE_SIZE, OVERLAP and STD are not referenced in this script;
# presumably they configure data_prep / staintools elsewhere — TODO confirm.
# OVERLAP = -10000 looks like a sentinel rather than a real overlap — verify.
# ---------------------------------------------------------------------------
TILE_SIZE: int = 299
OVERLAP: int = -10000
STD: str = './colorstandard.png'

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
ARCHITECTURE: str = 'I3'  # passed as MIL.MIL(mode=...)
N_EPOCH: int = 1
BATCH_SIZE: int = 2
TOP_K: int = 2
SAVE: bool = True
def _list_tfrecords(directory, exclude=()):
    """Return the sorted ``.tfrecords`` file names found in *directory*.

    Args:
        directory: Directory to scan (not recursive).
        exclude: Bare file names to leave out of the result.

    Returns:
        A sorted list of bare file names (no directory prefix). Sorting makes
        the slide order deterministic; ``os.listdir`` order is arbitrary.
    """
    excluded = set(exclude)
    return sorted(
        name
        for name in os.listdir(directory)
        if name.endswith('.tfrecords') and name not in excluded
    )


def main(valid_slide='0000026280.tfrecords'):
    """Build an ``MIL.MIL`` model and train it on the TFRecords in ``TFR_DIR``.

    Args:
        valid_slide: File name, inside ``TFR_DIR``, of the TFRecord passed to
            ``m.train`` as validation data. It is removed from the training
            slide list so the model is not validated on a slide it trained on.
            NOTE(review): assumes ``m.train`` does not rely on the validation
            slide also appearing in ``slides`` — confirm against MIL.
    """
    valid_path = os.path.join(TFR_DIR, valid_slide)
    # Only TFRecord files, deterministic order, validation slide held out.
    slides_tfr = _list_tfrecords(TFR_DIR, exclude=(valid_slide,))

    m = MIL.MIL(mode=ARCHITECTURE, log_dir=LOG_DIR, meta_graph=None)

    # Optional pre-training step, currently disabled:
    # pretrain_tfr = _list_tfrecords(PRE_DIR)
    # m.pre_train(pretrain_data_path=[os.path.join(PRE_DIR, f) for f in pretrain_tfr],
    #             valid_data_path=[valid_path],
    #             batch_size=BATCH_SIZE, n_epoch=N_EPOCH,
    #             out_dir=METAGRAPH_DIR, save=SAVE)

    m.train(data_dir=TFR_DIR, slides=slides_tfr, top_k=TOP_K, sample_rate=0.8,
            valid_data_path=[valid_path],
            n_epoch=N_EPOCH, batch_size=BATCH_SIZE,
            save=SAVE, out_dir=METAGRAPH_DIR)
    print('Trained!')
if __name__ == "__main__":
    # Start from an empty default graph before MIL builds its own.
    # NOTE(review): tf.reset_default_graph is a TF 1.x API; under TF 2.x it
    # lives at tf.compat.v1.reset_default_graph — confirm the pinned version.
    tf.reset_default_graph()
    # makedirs creates missing parents (e.g. './test' for './test/log'), and
    # exist_ok keeps reruns harmless. The previous os.mkdir raised
    # FileNotFoundError on a fresh checkout where './test' did not exist.
    for directory in (LOG_DIR, METAGRAPH_DIR):
        os.makedirs(directory, exist_ok=True)
    main()
    sys.exit(0)