Skip to content

Commit

Permalink
Merge pull request #1438 from serengil/feat-task-1802-post-batch-changes
Browse files Browse the repository at this point in the history
post batch changes
  • Loading branch information
serengil authored Feb 18, 2025
2 parents ca73032 + 3037e4e commit 6c714a8
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 88 deletions.
13 changes: 7 additions & 6 deletions deepface/DeepFace.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def analyze(
anti_spoofing (boolean): Flag to enable anti spoofing (default is False).
Returns:
(List[List[Dict[str, Any]]]): A list of analysis results if received batched image,
(List[List[Dict[str, Any]]]): A list of analysis results, one per image, if a batch
of images is received; each per-image result is explained below.
(List[Dict[str, Any]]): A list of dictionaries, where each dictionary represents
the analysis results for a detected face. Each dictionary in the list contains the
following keys:
Expand Down Expand Up @@ -385,12 +385,12 @@ def represent(
normalization: str = "base",
anti_spoofing: bool = False,
max_faces: Optional[int] = None,
) -> List[Dict[str, Any]]:
) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
"""
Represent facial images as multi-dimensional vector embeddings.
Args:
img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
img_path (str, np.ndarray, IO[bytes], or Sequence[Union[str, np.ndarray, IO[bytes]]]):
The exact path to the image, a numpy array
in BGR format, a file object that supports at least `.read` and is opened in binary
mode, or a base64 encoded image. If the source image contains multiple faces,
Expand Down Expand Up @@ -423,8 +423,9 @@ def represent(
max_faces (int): Set a limit on the number of faces to be processed (default is None).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, each containing the
following fields:
results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries.
The result becomes a list of lists of dictionaries if a batch input is passed.
Each dictionary contains the following fields:
- embedding (List[float]): Multidimensional vector representing facial features.
The number of dimensions varies based on the reference model
Expand Down
10 changes: 4 additions & 6 deletions deepface/models/Demography.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray,
def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:
"""
Predict for single image or batched images.
This method uses legacy method while receiving single image as input.
This method uses the legacy method when it receives a single image as input,
and switches to batch prediction if it receives batched images.
Args:
Expand All @@ -35,11 +35,11 @@ def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:
with x = image width, y = image height and c = channel
The channel dimension will be 1 if input is grayscale. (For emotion model)
"""
if not self.model_name: # Check if called from derived class
if not self.model_name: # Check if called from derived class
raise NotImplementedError("no model selected")
assert img_batch.ndim == 4, "expected 4-dimensional tensor input"

if img_batch.shape[0] == 1: # Single image
if img_batch.shape[0] == 1: # Single image
# Predict with legacy method.
return self.model(img_batch, training=False).numpy()[0, :]

Expand All @@ -48,10 +48,8 @@ def _predict_internal(self, img_batch: np.ndarray) -> np.ndarray:
return self.model.predict_on_batch(img_batch)

def _preprocess_batch_or_single_input(
self,
img: Union[np.ndarray, List[np.ndarray]]
self, img: Union[np.ndarray, List[np.ndarray]]
) -> np.ndarray:

"""
Preprocess single or batch of images, return as 4-D numpy array.
Args:
Expand Down
16 changes: 8 additions & 8 deletions deepface/models/demography/Age.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

logger = Logger()

# ----------------------------------------
# dependency configurations

tf_version = package_utils.get_tf_major_version()
Expand All @@ -25,12 +24,11 @@
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Convolution2D, Flatten, Activation

# ----------------------------------------

WEIGHTS_URL = (
"https://github.com/serengil/deepface_models/releases/download/v1.0/age_model_weights.h5"
)


# pylint: disable=too-few-public-methods
class ApparentAgeClient(Demography):
"""
Expand All @@ -49,7 +47,7 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64,
List of images as List[np.ndarray] or
Batch of images as np.ndarray (n, 224, 224, 3)
Returns:
np.ndarray (age_classes,) if single image,
np.ndarray (age_classes,) if single image,
np.ndarray (n, age_classes) if batched images.
"""
# Preprocessing input image or image list.
Expand All @@ -59,11 +57,10 @@ def predict(self, img: Union[np.ndarray, List[np.ndarray]]) -> Union[np.float64,
age_predictions = self._predict_internal(imgs)

# Calculate apparent ages
if len(age_predictions.shape) == 1: # Single prediction list
if len(age_predictions.shape) == 1: # Single prediction list
return find_apparent_age(age_predictions)

return np.array([
find_apparent_age(age_prediction) for age_prediction in age_predictions])
return np.array([find_apparent_age(age_prediction) for age_prediction in age_predictions])


def load_model(
Expand Down Expand Up @@ -100,6 +97,7 @@ def load_model(

return age_model


def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
"""
Find the apparent age prediction from the given age probabilities
Expand All @@ -108,7 +106,9 @@ def find_apparent_age(age_predictions: np.ndarray) -> np.float64:
Returns:
apparent_age (float)
"""
assert len(age_predictions.shape) == 1, f"Input should be a list of predictions, \
assert (
len(age_predictions.shape) == 1
), f"Input should be a list of predictions, \
not batched. Got shape: {age_predictions.shape}"
output_indexes = np.arange(0, 101)
apparent_age = np.sum(age_predictions * output_indexes)
Expand Down
1 change: 0 additions & 1 deletion deepface/modules/demography.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def analyze(
batch_resp_obj.append(resp_obj)
return batch_resp_obj


# if actions is passed as tuple with single item, interestingly it becomes str here
if isinstance(actions, str):
actions = (actions,)
Expand Down
49 changes: 25 additions & 24 deletions deepface/modules/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ def represent(
normalization: str = "base",
anti_spoofing: bool = False,
max_faces: Optional[int] = None,
) -> List[Dict[str, Any]]:
) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
"""
Represent facial images as multi-dimensional vector embeddings.
Args:
img_path (str, np.ndarray, or Sequence[Union[str, np.ndarray]]):
img_path (str, np.ndarray, or Sequence[Union[str, np.ndarray]]):
The exact path to the image, a numpy array in BGR format,
a base64 encoded image, or a sequence of these.
If the source image contains multiple faces,
Expand Down Expand Up @@ -53,8 +53,9 @@ def represent(
max_faces (int): Set a limit on the number of faces to be processed (default is None).
Returns:
results (List[Dict[str, Any]]): A list of dictionaries, each containing the
following fields:
results (List[Dict[str, Any]] or List[List[Dict[str, Any]]]): A list of dictionaries.
The result becomes a list of lists of dictionaries if a batch input is passed.
Each dictionary contains the following fields:
- embedding (List[float]): Multidimensional vector representing facial features.
The number of dimensions varies based on the reference model
Expand All @@ -80,14 +81,10 @@ def represent(
else:
images = [img_path]

batch_images = []
batch_regions = []
batch_confidences = []
batch_images, batch_regions, batch_confidences, batch_indexes = [], [], [], []

for single_img_path in images:
# ---------------------------------
# we have run pre-process in verification.
# so, this can be skipped if it is coming from verify.
for idx, single_img_path in enumerate(images):
# we have run pre-process in verification. so, skip if it is coming from verify.
target_size = model.input_shape
if detector_backend != "skip":
img_objs = detection.extract_faces(
Expand Down Expand Up @@ -130,6 +127,7 @@ def represent(
for img_obj in img_objs:
if anti_spoofing is True and img_obj.get("is_real", True) is False:
raise ValueError("Spoof detected in the given image.")

img = img_obj["face"]

# bgr to rgb
Expand All @@ -151,22 +149,25 @@ def represent(
batch_images.append(img)
batch_regions.append(region)
batch_confidences.append(confidence)
batch_indexes.append(idx)

# Convert list of images to a numpy array for batch processing
batch_images = np.concatenate(batch_images, axis=0)

# Forward pass through the model for the entire batch
embeddings = model.forward(batch_images)
if len(batch_images) == 1:
embeddings = [embeddings]

for embedding, region, confidence in zip(embeddings, batch_regions, batch_confidences):
resp_objs.append(
{
"embedding": embedding,
"facial_area": region,
"face_confidence": confidence,
}
)

return resp_objs

for idx in range(0, len(images)):
resp_obj = []
for idy, batch_index in enumerate(batch_indexes):
if idx == batch_index:
resp_obj.append(
{
"embedding": embeddings if len(batch_images) == 1 else embeddings[idy],
"facial_area": batch_regions[idy],
"face_confidence": batch_confidences[idy],
}
)
resp_objs.append(resp_obj)

return resp_objs[0] if len(images) == 1 else resp_objs
42 changes: 29 additions & 13 deletions tests/test_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,26 +144,39 @@ def test_analyze_for_different_detectors():
else:
assert result["gender"]["Man"] < result["gender"]["Woman"]

def test_analyze_for_batched_image():
img = "dataset/img4.jpg"

def test_analyze_for_numpy_batched_image():
img1_path = "dataset/img4.jpg"
img2_path = "dataset/couple.jpg"

# Copy and combine the same image to create multiple faces
img = cv2.imread(img)
img = np.stack([img, img])
assert len(img.shape) == 4 # Check dimension.
assert img.shape[0] == 2 # Check batch size.
img1 = cv2.imread(img1_path)
img2 = cv2.imread(img2_path)

expected_num_faces = [1, 2]

img1 = cv2.resize(img1, (500, 500))
img2 = cv2.resize(img2, (500, 500))

img = np.stack([img1, img2])
assert len(img.shape) == 4 # Check dimension.
assert img.shape[0] == 2 # Check batch size.

demography_batch = DeepFace.analyze(img, silent=True)
# 2 images in batch, so 2 demography objects.
assert len(demography_batch) == 2
assert len(demography_batch) == 2

for demography_objs in demography_batch:
assert len(demography_objs) == 1 # 1 face in each image
for demography in demography_objs: # Iterate over faces
assert type(demography) == dict # Check type
for i, demography_objs in enumerate(demography_batch):

assert len(demography_objs) == expected_num_faces[i]
for demography in demography_objs: # Iterate over faces
assert isinstance(demography, dict) # Check type
assert demography["age"] > 20 and demography["age"] < 40
assert demography["dominant_gender"] == "Woman"
assert demography["dominant_gender"] in ["Woman", "Man"]

logger.info("✅ test analyze for multiple faces done")


def test_batch_detect_age_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
Expand All @@ -176,6 +189,7 @@ def test_batch_detect_age_for_multiple_faces():
assert np.array_equal(int(results[0]), int(results[1]))
logger.info("✅ test batch detect age for multiple faces done")


def test_batch_detect_emotion_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
Expand All @@ -187,6 +201,7 @@ def test_batch_detect_emotion_for_multiple_faces():
assert np.array_equal(results[0], results[1])
logger.info("✅ test batch detect emotion for multiple faces done")


def test_batch_detect_gender_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
Expand All @@ -198,6 +213,7 @@ def test_batch_detect_gender_for_multiple_faces():
assert np.array_equal(results[0], results[1])
logger.info("✅ test batch detect gender for multiple faces done")


def test_batch_detect_race_for_multiple_faces():
# Load test image and resize to model input size
img = cv2.resize(cv2.imread("dataset/img1.jpg"), (224, 224))
Expand All @@ -207,4 +223,4 @@ def test_batch_detect_race_for_multiple_faces():
assert len(results) == 2
# Check two races are the same
assert np.array_equal(results[0], results[1])
logger.info("✅ test batch detect race for multiple faces done")
logger.info("✅ test batch detect race for multiple faces done")
Loading

0 comments on commit 6c714a8

Please sign in to comment.