-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpicture_analysis.py
52 lines (36 loc) · 1.84 KB
/
picture_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import functools

import numpy as np
import tensorflow as tf
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification

from model_training import train_model
_CLIP_MODEL_NAME = "openai/clip-vit-large-patch14"
_DEFAULT_ANIMAL_CLASSES = ("dog", "cat", "bird", "other")


@functools.lru_cache(maxsize=4)
def _load_clip(model_name):
    """Load the zero-shot CLIP model and its processor once per model name.

    ``from_pretrained`` reads (and possibly downloads) the full checkpoint,
    so doing it on every classification call is very expensive; the pair is
    memoized here and reused by subsequent calls.
    """
    model = AutoModelForZeroShotImageClassification.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(
        model_name, clean_up_tokenization_spaces=True
    )
    return model, processor


def calculate_probabilities_for_classes(
    picture, animal_classes=None, *, model_name=_CLIP_MODEL_NAME
):
    """Score *picture* against text labels with CLIP zero-shot classification.

    Args:
        picture: Image passed straight to the CLIP processor as ``images=``
            (presumably a PIL image or array — TODO confirm with callers).
        animal_classes: Candidate text labels. Defaults to
            ``("dog", "cat", "bird", "other")``.
        model_name: Hugging Face model id to load (cached after first use).

    Returns:
        A list of ``{"probability": p, "class": label}`` dicts, one per label,
        sorted by descending probability. Probabilities are the softmax over
        the image's logits for the given labels, so they sum to ~1.
    """
    classes = list(
        _DEFAULT_ANIMAL_CLASSES if animal_classes is None else animal_classes
    )
    model, processor = _load_clip(model_name)
    inputs = processor(
        images=picture, text=classes, return_tensors="pt", padding=True
    )
    with torch.no_grad():  # inference only; no autograd graph needed
        output = model(**inputs)
    # Row 0: the single image in the batch; softmax across the text labels.
    probabilities = output.logits_per_image[0].softmax(dim=-1).numpy()
    # reverse=True keeps the sort stable, so ties stay in label order.
    ranked = sorted(
        zip(probabilities, classes), key=lambda pair: pair[0], reverse=True
    )
    return [
        {"probability": prob, "class": animal_class}
        for prob, animal_class in ranked
    ]
def check_animal(picture, picture_keras):
    """Decide which animal *picture* shows and run the follow-up model for it.

    The top CLIP zero-shot prediction is accepted when its probability is at
    least 0.9, or at least 0.8 when the label is ``"cat"``; otherwise the type
    is ``"error"``. A model for that type is then obtained from
    ``train_model`` and applied to *picture_keras*; its predicted class and
    confidence (in percent) are printed, not returned.

    Args:
        picture: Image for the CLIP zero-shot classifier.
        picture_keras: The same image preprocessed for the Keras model's
            ``predict`` (presumably a batch of one — TODO confirm).

    Returns:
        The accepted CLIP label, or ``"error"`` if no label was confident
        enough.
    """
    top_prediction = calculate_probabilities_for_classes(picture)[0]
    # Read by key rather than by dict value order.
    top_probability = top_prediction["probability"]
    top_class = top_prediction["class"]

    animal_type = "error"
    if top_probability >= 0.9:
        animal_type = top_class
    elif top_class == "cat" and top_probability >= 0.8:
        animal_type = "cat"

    # NOTE(review): train_model is also called when animal_type == "error";
    # confirm model_training.train_model handles that value.
    model, animal_classes = train_model(animal_type)
    if len(animal_classes) > 1:
        scores = model.predict(picture_keras)[0]
        animal_prediction_class = animal_classes[np.argmax(scores)]
        # NOTE(review): if the Keras model already ends in a softmax layer,
        # this re-normalizes probabilities — confirm the model outputs logits.
        animal_prediction_probability = np.max(tf.nn.softmax(scores)) * 100
    else:
        # A single possible class is certain; report it in the same
        # percentage units as the multi-class branch (was 1, i.e. 1%).
        animal_prediction_probability = 100
        animal_prediction_class = animal_classes[0]
    print(animal_prediction_probability)
    print(animal_prediction_class)
    return animal_type