generated from scikit-learn-contrib/project-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathxcd.py
28 lines (22 loc) · 1.19 KB
/
xcd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import click
import os
from src.datasets.BACH import *
import json
from src.predict_cancer import predict_cancer
# Human-readable labels, indexed by the position of the highest score in the
# model output (see `prediction.argmax()` in `cli`). Leading underscore keeps
# this out of any `from ... import *` namespace so it cannot clash with names
# imported from src.datasets.BACH.
_CLASS_NAMES = ("Normal", "Benign", "In Situ", "Invasive")


@click.command()
@click.argument('img_location')
@click.option('--concept_path', '-cp', default=os.path.join("data", "CONCEPTS_32"),
              show_default=True,
              help="Concept folder, passed to predict_cancer as concept_folder.")
@click.option('--hovernet_path', '-hp', default=os.path.join("model", "HoVerNet.ckpt"),
              show_default=True,
              help="HoVerNet checkpoint, passed to predict_cancer as hover_net_loc.")
@click.option('--cell_encoder_path', '-cep', default=os.path.join("model", "CellEncoder.ckpt"),
              show_default=True,
              help="Cell encoder checkpoint, passed to predict_cancer as resnet_encoder.")
@click.option('--gnn_path', '-gp', default=os.path.join("model", "GCN.ckpt"),
              show_default=True,
              help="GNN checkpoint, passed to predict_cancer as gnn_loc.")
@click.option('--explanation_file', '-eout', default=None,
              help="Optional location passed to predict_cancer as explainability_location.")
def cli(img_location, concept_path, hovernet_path, cell_encoder_path, gnn_path, explanation_file):
    """Classify the image at IMG_LOCATION as Normal, Benign, In Situ or Invasive.

    Runs predict_cancer with the given model checkpoints and concept folder,
    then prints the label with the highest score.
    """
    print(f"Predicting cancer for image: {img_location}")
    prediction = predict_cancer(
        img_loc=img_location,
        hover_net_loc=hovernet_path,
        resnet_encoder=cell_encoder_path,
        gnn_loc=gnn_path,
        explainability_location=explanation_file,
        concept_folder=concept_path,
    )
    # NOTE(review): assumes `prediction` is a tensor of per-class scores ordered
    # like _CLASS_NAMES; argmax() over a batched tensor would flatten — confirm.
    label = _CLASS_NAMES[prediction.argmax().item()]
    print(f"Prediction: {label}")
    # The former trailing `print(id_to_path(226))` was a debug leftover: it
    # printed a hard-coded dataset entry unrelated to IMG_LOCATION.
if __name__ == '__main__':
    # Work from the script's own directory so the relative default paths
    # ("data/...", "model/...") used by `cli` resolve next to this file.
    # NOTE(review): this chdir runs before click parses argv, so a relative
    # IMG_LOCATION or --explanation_file is also resolved against this
    # directory rather than the caller's cwd — confirm that is intended.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(script_dir)
    cli()