
Feat/batch predict age and gender #1396

Merged: 66 commits, Feb 16, 2025

Conversation

@NatLee (Contributor) commented Dec 6, 2024

Tickets

#441
#678
#1069
#1101

What has been done

This PR introduces a new predict function designed to support batch predictions.

How to test

The batch_analyze function allows users to load a model and perform batch predictions efficiently.

In this case, the goal is to implement a batch method capable of handling a single image containing multiple faces.

# built-in dependencies
from typing import Any, Dict, List, Union, Optional

# 3rd party dependencies
import numpy as np
from tqdm import tqdm

# project dependencies
from deepface.modules import modeling, detection, preprocessing
from deepface.models.demography import Gender, Race, Emotion

def batch_analyze(
    img_path: Union[str, np.ndarray],
    actions: Union[tuple, list] = ("emotion", "age", "gender", "race"),
    enforce_detection: bool = True,
    detector_backend: str = "opencv",
    align: bool = True,
    expand_percentage: int = 0,
    silent: bool = False,
    anti_spoofing: bool = False,
) -> List[Dict[str, Any]]:
    """
    Analyze facial attributes for every face in an image, running each
    attribute model once over the whole batch of detected faces.
    """
    # if actions is passed as tuple with single item, interestingly it becomes str here
    if isinstance(actions, str):
        actions = (actions,)

    # check if actions is not an iterable or empty.
    if not hasattr(actions, "__getitem__") or not actions:
        raise ValueError("`actions` must be a list of strings.")

    actions = list(actions)

    # For each action, check if it is valid
    for action in actions:
        if action not in ("emotion", "age", "gender", "race"):
            raise ValueError(
                f"Invalid action passed ({repr(action)})). "
                "Valid actions are `emotion`, `age`, `gender`, `race`."
            )

    img_objs = detection.extract_faces(
        img_path=img_path,
        detector_backend=detector_backend,
        enforce_detection=enforce_detection,
        grayscale=False,
        align=align,
        expand_percentage=expand_percentage,
        anti_spoofing=anti_spoofing,
    )

    if anti_spoofing and any(img_obj.get("is_real", True) is False for img_obj in img_objs):
        raise ValueError("Spoof detected in the given image.")

    def preprocess_face(img_obj: Dict[str, Any]) -> Optional[np.ndarray]:
        """
        Preprocess the face image for analysis.
        """
        img_content = img_obj["face"]
        if img_content.shape[0] == 0 or img_content.shape[1] == 0:
            return None
        img_content = img_content[:, :, ::-1]  # BGR to RGB
        return preprocessing.resize_image(img=img_content, target_size=(224, 224))

    # Filter out empty faces
    face_data = [
        (
            preprocess_face(img_obj),
            img_obj["facial_area"],
            img_obj["confidence"],
        )
        for img_obj in img_objs
        if img_obj["face"].size > 0
    ]

    if not face_data:
        return []

    # Unpack the face data
    valid_faces, face_regions, face_confidences = zip(*face_data)
    faces_array = np.array(valid_faces)
    # Initialize the results list with face regions and confidence scores
    results = [
        {"region": region, "face_confidence": conf}
        for region, conf in zip(face_regions, face_confidences)
    ]
    # Iterate over the actions and perform analysis
    pbar = tqdm(
        actions,
        desc="Finding actions",
        disable=silent if len(actions) > 1 else True,
    )
    for action in pbar:
        pbar.set_description(f"Action: {action}")
        model = modeling.build_model(task="facial_attribute", model_name=action.capitalize())
        predictions = model.predict(faces_array.squeeze())
        # If the model returns a single prediction, reshape it to match the number of faces.
        # Determine the correct shape of predictions by using number of faces and predictions shape.
        # Example: For 1 face with Emotion model, predictions will be reshaped to (1, 7).
        if faces_array.shape[0] == 1 and len(predictions.shape) <= 1:
            # For models like `Emotion`, which return a single prediction for a single face
            predictions = predictions.reshape(1, -1)
        # Update the results with the predictions
        # ----------------------------------------
        # For emotion, calculate the percentage of each emotion and find the dominant emotion
        if action == "emotion":
            emotion_results = [
                {
                    "emotion": {
                        label: 100 * pred[i] / pred.sum()
                        for i, label in enumerate(Emotion.labels)
                    },
                    "dominant_emotion": Emotion.labels[np.argmax(pred)]
                }
                for pred in predictions
            ]
            for result, emotion_result in zip(results, emotion_results):
                result.update(emotion_result)
        # ----------------------------------------
        # For age, find the dominant age category (0-100)
        elif action == "age":
            age_results = [
                {"age": int(np.argmax(pred) if len(pred.shape) > 0 else pred)}
                for pred in predictions
            ]
            for result, age_result in zip(results, age_results):
                result.update(age_result)
        # ----------------------------------------
        # For gender, calculate the percentage of each gender and find the dominant gender
        elif action == "gender":
            gender_results = [
                {
                    "gender": {
                        label: 100 * pred[i]
                        for i, label in enumerate(Gender.labels)
                    },
                    "dominant_gender": Gender.labels[np.argmax(pred)]
                }
                for pred in predictions
            ]
            for result, gender_result in zip(results, gender_results):
                result.update(gender_result)
        # ----------------------------------------
        # For race, calculate the percentage of each race and find the dominant race
        elif action == "race":
            race_results = [
                {
                    "race": {
                        label: 100 * pred[i] / pred.sum()
                        for i, label in enumerate(Race.labels)
                    },
                    "dominant_race": Race.labels[np.argmax(pred)]
                }
                for pred in predictions
            ]
            for result, race_result in zip(results, race_results):
                result.update(race_result)
    return results
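
For reference, a minimal usage sketch (editor's illustration; the image path matches the benchmark below):

# Analyze every detected face in one image with a single batched pass per model.
results = batch_analyze(
    img_path="./tests/dataset/img4.jpg",
    actions=("age", "gender"),
    silent=True,
)
for face in results:
    print(face["region"], face.get("age"), face.get("dominant_gender"))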

My benchmark script is shown below:

import time
from typing import List

import cv2
from deepface.DeepFace import analyze as single_analyze

def run_benchmark(img_path: str, num_faces: List[int], iterations: int = 3) -> None:
    """
    Run performance benchmark comparing single_analyze vs batch_analyze
    
    Args:
        img_path: Path to the base image
        num_faces: List of number of faces to test (will concatenate base image n times)
        iterations: Number of iterations to run for each test for more stable results
    """
    print("\n=== DeepFace Analysis Performance Benchmark ===")
    print(f"Base image: {img_path}")
    print(f"Iterations per test: {iterations}")
    print("-" * 50)

    # Load base image
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Could not load image from {img_path}")

    # Warm up (first run is usually slower due to model loading)
    print("Warming up...", end="", flush=True)
    batch_analyze(img_path=img, silent=True)
    print("done\n")

    results = []
    
    for n in num_faces:
        print(f"\nTesting with {n} face{'s' if n > 1 else ''}")
        print("-" * 30)
        
        # Create test image with n faces
        test_img = cv2.hconcat([img] * n) if n > 1 else img
        
        # Test single_analyze
        single_times = []
        for i in range(iterations):
            start = time.time()
            single_analyze(img_path=test_img, silent=True)
            single_times.append(time.time() - start)
        single_avg = sum(single_times) / len(single_times)
        
        # Test batch_analyze
        batch_times = []
        for i in range(iterations):
            start = time.time()
            batch_analyze(img_path=test_img, silent=True)
            batch_times.append(time.time() - start)
        batch_avg = sum(batch_times) / len(batch_times)
        
        results.append({
            'faces': n,
            'single_avg': single_avg,
            'batch_avg': batch_avg,
            'speedup': single_avg / batch_avg
        })
        
        print(f"Single analyze: {single_avg:.2f}s (avg)")
        print(f"Batch analyze:  {batch_avg:.2f}s (avg)")
        print(f"Speedup:       {single_avg/batch_avg:.2f}x")

    # Print summary table
    print("\n=== Summary ===")
    print("Faces  | Single (s) | Batch (s) | Speedup")
    print("-" * 40)
    for r in results:
        print(f"{r['faces']:5d} | {r['single_avg']:9.2f} | {r['batch_avg']:8.2f} | {r['speedup']:7.2f}x")

if __name__ == "__main__":
    # Run benchmark with different numbers of faces
    run_benchmark(
        img_path="./tests/dataset/img4.jpg",
        num_faces=[1, 2, 5, 10],
        iterations=3
    )

Summary:

Faces | Single (s) | Batch (s) | Speedup
------|------------|-----------|--------
    1 |       0.39 |      0.33 |   1.19x
    2 |       0.71 |      0.71 |   1.00x
    5 |       1.79 |      1.62 |   1.11x
   10 |       4.46 |      3.87 |   1.15x

The results indicate that batch processing improves efficiency, with the largest gain (1.19x) for a single face and consistent gains (1.11x to 1.15x) at higher face counts. While the speedup varies, the overall trend suggests that batch processing reduces processing time, which is promising for scaling to larger workloads.

@h-alice (Contributor) commented Dec 6, 2024

Bump.
We really need this feature.

@serengil (Owner) commented Dec 6, 2024

I don't support having another predict function. Instead, you can add that logic under predict.

1- predict should accept both a single image and a list of images:

img: Union[np.ndarray, List[np.ndarray]]

2- in the predict function, you can check the type of img and redirect it to your batch logic if it is a list:

if isinstance(img, np.ndarray):
   # put old predict logic here
elif isinstance(img, list):
   # put your batch processing logic here

3- this new logic is worth having its own unit tests; you could add some here.

4- return type of predict should be Union[np.float64, np.ndarray]

5- You should also update the interface in DeepFace.py
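
For illustration, a minimal sketch of the suggested dispatch (editor's sketch, not the PR's final code; _predict_single and _predict_batch are hypothetical helper names):

from typing import List, Union

import numpy as np

def predict(img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64, np.ndarray]:
    # Sketch only: _predict_single / _predict_batch are hypothetical helpers.
    if isinstance(img, np.ndarray):
        # old single-image predict logic
        return _predict_single(img)
    if isinstance(img, list):
        # new batch logic: stack into (n, H, W, C) and predict once
        return _predict_batch(np.stack(img, axis=0))
    raise TypeError("img must be a numpy array or a list of numpy arrays")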

@serengil (Owner) commented:

Actions failed because of linting - link

************* Module deepface.models.demography.Age
pylint: Command line or configuration file:1: UserWarning: 'Exception' is not a proper value for the 'overgeneral-exceptions' option. Use fully qualified name (maybe 'builtins.Exception' ?) instead. This will cease to be checked at runtime in 3.1.0.
deepface/models/demography/Age.py:70:0: C0303: Trailing whitespace (trailing-whitespace)
************* Module deepface.models.demography.Emotion
deepface/models/demography/Emotion.py:88:0: C0303: Trailing whitespace (trailing-whitespace)

imgs = np.expand_dims(imgs, axis=0)

# Batch prediction
age_predictions = self.model.predict_on_batch(imgs)
@serengil (Owner) commented Dec 31, 2024

model.predict causes a memory issue when it is called in a for loop; that is why we call it as self.model(img, training=False).numpy()[0, :]

in your design, if this is called in a for loop, it will still cause the memory problem.

IMO, if it is a single image we should call self.model(img, training=False).numpy()[0, :]; if it is many faces, then call self.model.predict_on_batch

(Contributor) replied:

Thank you for sharing your perspective on this matter.

We found that the issue you mention is also discussed in tensorflow/tensorflow#44711. We believe it is being resolved.

Furthermore, if we can utilize the batch prediction method provided in this PR, we may be able to avoid repeatedly calling the predict function within a loop of unrolled batch images, which is the root cause of the memory issue you described.

We recommend that you consider retaining our batch prediction method.

@serengil (Owner) replied:

hey, even though this is sorted in newer tf versions, many users on old tf versions raise tickets about this problem. so we should consider the people using older tf versions. that is why i recommend using self.model(img, training=False).numpy()[0, :] for single images, and self.model.predict_on_batch for batches.
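
A minimal sketch of this recommendation (editor's sketch; predict_faces is a hypothetical helper, and the 224x224 crop size is DeepFace's usual model input):

import numpy as np

def predict_faces(model, img_batch: np.ndarray) -> np.ndarray:
    """Hypothetical helper; img_batch is expected to be (n, 224, 224, 3)."""
    if img_batch.shape[0] == 1:
        # single face: a direct call avoids the memory growth that
        # tf.keras Model.predict() shows when invoked inside a loop
        return model(img_batch, training=False).numpy()[0, :]
    # many faces: one vectorized call over the whole batch
    return model.predict_on_batch(img_batch)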

@h-alice (Contributor) commented Jan 13, 2025

Hi! 👋
Please take a look at our prediction function, which uses the legacy single prediction method you suggested, and also provides batch prediction if a batch of images is provided.

Please let us know if there’s anything else we can improve. Any advice you have is greatly appreciated.

def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:

img = "dataset/img4.jpg"
# Copy and combine the same image to create multiple faces
img = cv2.imread(img)
img = cv2.hconcat([img, img])
@serengil (Owner) commented:

hconcat makes a single image

input image shape before hconcat is (1728, 2500, 3)
input image shape after hconcat is (1728, 5000, 3)

to have a numpy array with (2, 1728, 2500, 3) shape, you should do something like:

img = np.stack((img, img), axis=0)
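
A short shape check (editor's sketch using the shapes quoted above):

import cv2
import numpy as np

img = cv2.imread("dataset/img4.jpg")  # e.g. (1728, 2500, 3)
wide = cv2.hconcat([img, img])        # (1728, 5000, 3): still a single image
batch = np.stack((img, img), axis=0)  # (2, 1728, 2500, 3): a true batch of two
print(wide.shape, batch.shape)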

@serengil (Owner) added:

Also, please check that img now has a (2, x, x, x) shape.

@serengil (Owner) continued:

Finally, unit tests failed for that input. The case you tested did not exercise the batch path; it is still a single image.

(Contributor) replied:

Hi @serengil 👋
We have implemented batched-image support in DeepFace::analyze, and the test cases have been modified per your request. Please check whether this matches your requirements.

Due to the complexity of designing a more efficient flow for analysis, we will prioritize extending the functionality of models that can accept batched images for now.

We will discuss enhancing the performance of the analysis function in separate threads or through pull requests. We would invite you to participate in these discussions once we are ready.

Please help us merge this PR if all the requirements are met.

"""
image_batch = np.array(img)
# Remove batch dimension in advance if exists
image_batch = image_batch.squeeze()
@serengil (Owner) commented:

i was initially confused about why we squeeze first and expand dimensions second

would you please add a comment here, something like:

we perform squeeze and expand dimensions sequentially to have the same behaviour for (224, 224, 3), (1, 224, 224, 3) and (n, 224, 224, 3)
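
For clarity, a sketch of the behaviour being requested here (editor's illustration; normalize_batch is a hypothetical helper, not the PR's code):

import numpy as np

def normalize_batch(img: np.ndarray) -> np.ndarray:
    # Collapse a singleton batch dimension if present, then re-add a
    # batch dimension for 3-D inputs, so (224, 224, 3), (1, 224, 224, 3)
    # and (n, 224, 224, 3) all come out as 4-D batches.
    img = img.squeeze()                    # (1, 224, 224, 3) -> (224, 224, 3)
    if img.ndim == 3:
        img = np.expand_dims(img, axis=0)  # (224, 224, 3) -> (1, 224, 224, 3)
    return img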

(Contributor) replied:

We took a look at the processing flow and discovered that the squeeze operation is unnecessary. Every single-image input already arrives with an expanded batch dimension, i.e. shape (1, 224, 224, 3), so there is no need to handle 3-D inputs here.

The redundant squeeze process has been removed.

@serengil (Owner) commented:

I am not available to review it until early Feb.

@serengil (Owner) commented:

Tests failed

@NatLee (Contributor, Author) commented Feb 5, 2025

The tests seem to pass now. If you're okay with it, it's ready to merge.

@@ -253,6 +256,29 @@ def analyze(
- 'middle eastern': Confidence score for Middle Eastern ethnicity.
- 'white': Confidence score for White ethnicity.
"""

if isinstance(img_path, np.ndarray) and len(img_path.shape) == 4:
@serengil (Owner) commented:

this check should not be done in DeepFace.py

as you can see, we keep no logic in this file.

@NatLee (Contributor, Author) replied:

Alright, I've moved the control logic into the analyze function.
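
A minimal sketch of the relocated check (editor's sketch, not the PR's exact code; analyze_single_image is a hypothetical name for the existing single-image path):

# Inside the module-level analyze function:
if isinstance(img_path, np.ndarray) and img_path.ndim == 4:
    # a 4-D array is a batch of images: run the single-image
    # path per image and collect the results
    return [analyze_single_image(img) for img in img_path]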

@serengil (Owner) commented:

LGTM

thank you for your contribution

@serengil merged commit 112d189 into serengil:master on Feb 16, 2025 (2 checks passed)
@h-alice (Contributor) commented Feb 17, 2025

Thank you so much @serengil !

@NatLee (Contributor, Author) commented Feb 17, 2025

@serengil Many thanks for your reviews!
