From d992428d65c894c9eb9a68ae41cf455893d62369 Mon Sep 17 00:00:00 2001 From: Nat Lee Date: Tue, 17 Dec 2024 14:41:18 +0800 Subject: [PATCH] [update] embed into deepface module --- deepface/models/demography/Age.py | 11 +- deepface/models/demography/Emotion.py | 12 +- deepface/models/demography/Gender.py | 11 +- deepface/models/demography/Race.py | 11 +- deepface/modules/demography.py | 158 +++++++++++++++----------- 5 files changed, 99 insertions(+), 104 deletions(-) diff --git a/deepface/models/demography/Age.py b/deepface/models/demography/Age.py index 9c7ef3c3d..d470cff6f 100644 --- a/deepface/models/demography/Age.py +++ b/deepface/models/demography/Age.py @@ -40,7 +40,7 @@ def __init__(self): self.model = load_model() self.model_name = "Age" - def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]: + def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: """ Predict apparent age(s) for single or multiple faces Args: @@ -48,8 +48,7 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, List of images as List[np.ndarray] or Batch of images as np.ndarray (n, 224, 224, 3) Returns: - Single age as np.float64 or - Multiple ages as np.ndarray (n,) + np.ndarray (n,) """ # Convert to numpy array if input is list if isinstance(img, list): @@ -64,9 +63,6 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, if len(imgs.shape) == 3: # Single image - add batch dimension imgs = np.expand_dims(imgs, axis=0) - is_single = True - else: - is_single = False # Batch prediction age_predictions = self.model.predict_on_batch(imgs) @@ -76,9 +72,6 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, [find_apparent_age(age_prediction) for age_prediction in age_predictions] ) - # Return single value for single image - if is_single: - return apparent_ages[0] return apparent_ages diff --git a/deepface/models/demography/Emotion.py b/deepface/models/demography/Emotion.py index e6cb3d94d..065795e37 100644 --- a/deepface/models/demography/Emotion.py +++ b/deepface/models/demography/Emotion.py @@ -58,7 +58,7 @@ def _preprocess_image(self, img: np.ndarray) -> np.ndarray: img_gray = cv2.resize(img_gray, (48, 48)) return img_gray - def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.ndarray]: + def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: """ Predict emotion probabilities for single or multiple faces Args: @@ -66,8 +66,7 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, List of images as List[np.ndarray] or Batch of images as np.ndarray (n, 224, 224, 3) Returns: - Single prediction as np.ndarray (n_emotions,) [emotion_probs] or - Multiple predictions as np.ndarray (n, n_emotions) + np.ndarray (n, n_emotions) where n_emotions is the number of emotion categories """ # Convert to numpy array if input is list @@ -83,9 +82,6 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, if len(imgs.shape) == 3: # Single image - add batch dimension imgs = np.expand_dims(imgs, axis=0) - is_single = True - else: - is_single = False # Preprocess each image processed_imgs = np.array([self._preprocess_image(img) for img in imgs]) @@ -96,13 +92,9 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, # Batch prediction predictions = self.model.predict_on_batch(processed_imgs) - # Return single prediction for single image - if is_single: - return predictions[0] return predictions - def load_model( url=WEIGHTS_URL, ) -> Sequential: diff --git a/deepface/models/demography/Gender.py b/deepface/models/demography/Gender.py index ac8716af2..23fd69b2c 100644 --- a/deepface/models/demography/Gender.py +++ b/deepface/models/demography/Gender.py @@ -40,7 +40,7 @@ def __init__(self): self.model = load_model() self.model_name = "Gender" - def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.ndarray]: + def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: """ Predict gender probabilities for single or multiple faces Args: @@ -48,8 +48,7 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, List of images as List[np.ndarray] or Batch of images as np.ndarray (n, 224, 224, 3) Returns: - Single prediction as np.ndarray (2,) [female_prob, male_prob] or - Multiple predictions as np.ndarray (n, 2) + np.ndarray (n, 2) """ # Convert to numpy array if input is list if isinstance(img, list): @@ -64,16 +63,10 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, if len(imgs.shape) == 3: # Single image - add batch dimension imgs = np.expand_dims(imgs, axis=0) - is_single = True - else: - is_single = False # Batch prediction predictions = self.model.predict_on_batch(imgs) - # Return single prediction for single image - if is_single: - return predictions[0] return predictions diff --git a/deepface/models/demography/Race.py b/deepface/models/demography/Race.py index cec6aaad2..dc4a7889a 100644 --- a/deepface/models/demography/Race.py +++ b/deepface/models/demography/Race.py @@ -40,7 +40,7 @@ def __init__(self): self.model = load_model() self.model_name = "Race" - def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, np.ndarray]: + def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray: """ Predict race probabilities for single or multiple faces Args: @@ -48,8 +48,7 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, List of images as List[np.ndarray] or Batch of images as np.ndarray (n, 224, 224, 3) Returns: - Single prediction as np.ndarray (n_races,) [race_probs] or - Multiple predictions as np.ndarray (n, n_races) + np.ndarray (n, n_races) where n_races is the number of race categories """ # Convert to numpy array if input is list @@ -65,16 +64,10 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, if len(imgs.shape) == 3: # Single image - add batch dimension imgs = np.expand_dims(imgs, axis=0) - is_single = True - else: - is_single = False # Batch prediction predictions = self.model.predict_on_batch(imgs) - # Return single prediction for single image - if is_single: - return predictions[0] return predictions diff --git a/deepface/modules/demography.py b/deepface/modules/demography.py index b68314b9c..4c58314c2 100644 --- a/deepface/modules/demography.py +++ b/deepface/modules/demography.py @@ -9,7 +9,7 @@ from deepface.modules import modeling, detection, preprocessing from deepface.models.demography import Gender, Race, Emotion - +# pylint: disable=trailing-whitespace def analyze( img_path: Union[str, np.ndarray], actions: Union[tuple, list] = ("emotion", "age", "gender", "race"), @@ -130,83 +130,107 @@ def analyze( anti_spoofing=anti_spoofing, ) + # Anti-spoofing check + if anti_spoofing: + for img_obj in img_objs: + if img_obj.get("is_real", True) is False: + raise ValueError("Spoof detected in the given image.") + + # Prepare the input for the model + valid_faces = [] + face_regions = [] + face_confidences = [] + for img_obj in img_objs: - if anti_spoofing is True and img_obj.get("is_real", True) is False: - raise ValueError("Spoof detected in the given image.") - + # Extract the face content img_content = img_obj["face"] - img_region = img_obj["facial_area"] - img_confidence = img_obj["confidence"] + # Check if the face content is empty if img_content.shape[0] == 0 or img_content.shape[1] == 0: continue - # rgb to bgr + # Convert the image to RGB format from BGR img_content = img_content[:, :, ::-1] - - # resize input image + # Resize the image to the target size for the model img_content = preprocessing.resize_image(img=img_content, target_size=(224, 224)) - obj = {} - # facial attribute analysis - pbar = tqdm( - range(0, len(actions)), - desc="Finding actions", - disable=silent if len(actions) > 1 else True, - ) - for index in pbar: - action = actions[index] - pbar.set_description(f"Action: {action}") - - if action == "emotion": - emotion_predictions = modeling.build_model( - task="facial_attribute", model_name="Emotion" - ).predict(img_content) - sum_of_predictions = emotion_predictions.sum() - - obj["emotion"] = {} - for i, emotion_label in enumerate(Emotion.labels): - emotion_prediction = 100 * emotion_predictions[i] / sum_of_predictions - obj["emotion"][emotion_label] = emotion_prediction - - obj["dominant_emotion"] = Emotion.labels[np.argmax(emotion_predictions)] - - elif action == "age": - apparent_age = modeling.build_model( - task="facial_attribute", model_name="Age" - ).predict(img_content) - # int cast is for exception - object of type 'float32' is not JSON serializable - obj["age"] = int(apparent_age) - - elif action == "gender": - gender_predictions = modeling.build_model( - task="facial_attribute", model_name="Gender" - ).predict(img_content) - obj["gender"] = {} - for i, gender_label in enumerate(Gender.labels): - gender_prediction = 100 * gender_predictions[i] - obj["gender"][gender_label] = gender_prediction + valid_faces.append(img_content) + face_regions.append(img_obj["facial_area"]) + face_confidences.append(img_obj["confidence"]) - obj["dominant_gender"] = Gender.labels[np.argmax(gender_predictions)] + # If no valid faces are found, return an empty list + if not valid_faces: + return [] - elif action == "race": - race_predictions = modeling.build_model( - task="facial_attribute", model_name="Race" - ).predict(img_content) - sum_of_predictions = race_predictions.sum() + # Convert the list of valid faces to a numpy array + faces_array = np.array(valid_faces) + resp_objects = [{} for _ in range(len(valid_faces))] - obj["race"] = {} + # For each action, predict the corresponding attribute + pbar = tqdm( + range(0, len(actions)), + desc="Finding actions", + disable=silent if len(actions) > 1 else True, + ) + + for index in pbar: + action = actions[index] + pbar.set_description(f"Action: {action}") + + if action == "emotion": + # Build the emotion model + model = modeling.build_model(task="facial_attribute", model_name="Emotion") + emotion_predictions = model.predict(faces_array) + + for idx, predictions in enumerate(emotion_predictions): + sum_of_predictions = predictions.sum() + resp_objects[idx]["emotion"] = {} + + for i, emotion_label in enumerate(Emotion.labels): + emotion_prediction = 100 * predictions[i] / sum_of_predictions + resp_objects[idx]["emotion"][emotion_label] = emotion_prediction + + resp_objects[idx]["dominant_emotion"] = Emotion.labels[np.argmax(predictions)] + + elif action == "age": + # Build the age model + model = modeling.build_model(task="facial_attribute", model_name="Age") + age_predictions = model.predict(faces_array) + + for idx, age in enumerate(age_predictions): + resp_objects[idx]["age"] = int(age) + + elif action == "gender": + # Build the gender model + model = modeling.build_model(task="facial_attribute", model_name="Gender") + gender_predictions = model.predict(faces_array) + + for idx, predictions in enumerate(gender_predictions): + resp_objects[idx]["gender"] = {} + + for i, gender_label in enumerate(Gender.labels): + gender_prediction = 100 * predictions[i] + resp_objects[idx]["gender"][gender_label] = gender_prediction + + resp_objects[idx]["dominant_gender"] = Gender.labels[np.argmax(predictions)] + + elif action == "race": + # Build the race model + model = modeling.build_model(task="facial_attribute", model_name="Race") + race_predictions = model.predict(faces_array) + + for idx, predictions in enumerate(race_predictions): + sum_of_predictions = predictions.sum() + resp_objects[idx]["race"] = {} + for i, race_label in enumerate(Race.labels): - race_prediction = 100 * race_predictions[i] / sum_of_predictions - obj["race"][race_label] = race_prediction - - obj["dominant_race"] = Race.labels[np.argmax(race_predictions)] - - # ----------------------------- - # mention facial areas - obj["region"] = img_region - # include image confidence - obj["face_confidence"] = img_confidence - - resp_objects.append(obj) + race_prediction = 100 * predictions[i] / sum_of_predictions + resp_objects[idx]["race"][race_label] = race_prediction + + resp_objects[idx]["dominant_race"] = Race.labels[np.argmax(predictions)] + + # Add the face region and confidence to the response objects + for idx, resp_obj in enumerate(resp_objects): + resp_obj["region"] = face_regions[idx] + resp_obj["face_confidence"] = face_confidences[idx] return resp_objects