diff --git a/README.md b/README.md index 25b4bab..c1bf42e 100644 --- a/README.md +++ b/README.md @@ -143,3 +143,86 @@ For detailed API documentation, please refer to the [API Documentation](API_DOCU This project is licensed under the **APACHE License** - see the [LICENSE](LICENSE) file for details. +## deeptune: A Python Package for Model Fine-Tuning and Training + +`deeptune` is a Python package designed to help you fine-tune and train models of siamese architecture. It provides different backend options and loss functions like triplet loss and arcface loss. + +### Features + +- **Model Fine-Tuning:** Fine-tune pre-trained models with ease. +- **Training:** Train models with different backend options and loss functions. +- **Evaluation:** Evaluate models using various metrics. +- **CLI Support:** Interact with the package through the command line. + +### Installation + +To install the package, clone the repository and install the required packages: + +```bash +git clone https://github.com/Devasy23/FaceRec.git +cd FaceRec/deeptune +pip install -r requirements.txt +``` + +### Usage + +#### Command-Line Interface (CLI) + +The package provides a CLI to interact with the model fine-tuning and training functionalities. + +To evaluate a model, use the following command: + +```bash +python -m deeptune.cli.cli evaluate_model +``` + +Replace `` with the path to your model file and `` with the path to your dataset. + +#### Example + +Here is an example of how to use the package in your Python code: + +```python +from deeptune.evaluation.eval_mark_I import ( + load_and_preprocess_image, + generate_embeddings, + calculate_intra_cluster_distances, +) +from keras.models import load_model +import numpy as np + +# Load the pre-trained model +model_path = "path_to_your_model.h5" +model = load_model(model_path) + +# Path to the dataset +dataset_path = "path_to_your_dataset" + +# Generate embeddings +embeddings, labels = generate_embeddings(model, dataset_path) + +# Calculate intra-cluster distances +intra_distances = calculate_intra_cluster_distances(embeddings, labels) + +# Output the results +print(f"Intra-Cluster Distances: {intra_distances}") +print(f"Mean Distance: {np.mean(intra_distances)}") +``` + +### Project Structure + +- `deeptune/`: Main package directory. + - `__init__.py`: Makes `deeptune` a Python package. + - `data/`: Sub-package for data-related functionalities. + - `models/`: Sub-package for model-related functionalities. + - `training/`: Sub-package for training-related functionalities. + - `evaluation/`: Sub-package for evaluation-related functionalities. + - `utils/`: Sub-package for utility functions. + - `cli/`: Sub-package for CLI-related functionalities. + - `config.py`: Configuration file for storing settings or parameters. + - `requirements.txt`: Lists the dependencies for the package. + - `cli/cli.py`: CLI script to interact with the package. + +### License + +This project is licensed under the **APACHE License** - see the [LICENSE](LICENSE) file for details. diff --git a/deeptune/__init__.py b/deeptune/__init__.py new file mode 100644 index 0000000..8c6e313 --- /dev/null +++ b/deeptune/__init__.py @@ -0,0 +1 @@ +# deeptune package initialization diff --git a/deeptune/cli/__init__.py b/deeptune/cli/__init__.py new file mode 100644 index 0000000..385aad1 --- /dev/null +++ b/deeptune/cli/__init__.py @@ -0,0 +1 @@ +# This file makes `cli` a sub-package diff --git a/deeptune/cli/cli.py b/deeptune/cli/cli.py new file mode 100644 index 0000000..77f6db3 --- /dev/null +++ b/deeptune/cli/cli.py @@ -0,0 +1,32 @@ +import click +import os +from deeptune.evaluation.eval_mark_I import ( + load_and_preprocess_image, + generate_embeddings, + calculate_intra_cluster_distances, +) +from keras.models import load_model +import numpy as np + + +@click.group() +def cli(): + pass + + +@click.command() +@click.argument("model_path") +@click.argument("dataset_path") +def evaluate_model(model_path, dataset_path): + """Evaluate the model with the given dataset.""" + model = load_model(model_path) + embeddings, labels = generate_embeddings(model, dataset_path) + intra_distances = calculate_intra_cluster_distances(embeddings, labels) + print(f"Intra-Cluster Distances: {intra_distances}") + print(f"Mean Distance: {np.mean(intra_distances)}") + + +cli.add_command(evaluate_model) + +if __name__ == "__main__": + cli() diff --git a/deeptune/config.py b/deeptune/config.py new file mode 100644 index 0000000..0ca524e --- /dev/null +++ b/deeptune/config.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import os + +basedir = os.path.abspath(os.path.dirname(__file__)) + + +class Config: + SECRET_KEY = os.environ.get("SECRET_KEY") + DEBUG = True + upload_image_path = os.path.join(basedir, "static/Images/uploads") + ALLOWED_EXTENSIONS = ["jpg", "jpeg", "png", "jfif"] + image_data_file = os.path.join(basedir, "static/Images/image_data.json") diff --git a/deeptune/data/__init__.py b/deeptune/data/__init__.py new file mode 100644 index 0000000..7387250 --- /dev/null +++ b/deeptune/data/__init__.py @@ -0,0 +1 @@ +# deeptune data sub-package initialization diff --git a/deeptune/evaluation/__init__.py b/deeptune/evaluation/__init__.py new file mode 100644 index 0000000..2615185 --- /dev/null +++ b/deeptune/evaluation/__init__.py @@ -0,0 +1 @@ +# deeptune evaluation sub-package initialization diff --git a/Model-Training/eval-mark-I.py b/deeptune/evaluation/eval_mark_I.py similarity index 100% rename from Model-Training/eval-mark-I.py rename to deeptune/evaluation/eval_mark_I.py diff --git a/deeptune/models/__init__.py b/deeptune/models/__init__.py new file mode 100644 index 0000000..20a7923 --- /dev/null +++ b/deeptune/models/__init__.py @@ -0,0 +1 @@ +# deeptune models sub-package initialization diff --git a/deeptune/requirements.txt b/deeptune/requirements.txt new file mode 100644 index 0000000..002fa92 --- /dev/null +++ b/deeptune/requirements.txt @@ -0,0 +1,17 @@ +deepface==0.0.92 +fastapi==0.115.0 +keras==2.15.0 +matplotlib>=3.8.2 +numpy==1.26.0 +Pillow==10.4.0 +pydantic==2.9.2 +pymongo==4.6.3 +python-dotenv==1.0.1 +tensorflow==2.15.0 +uvicorn==0.31.0 +pytest==7.4.4 +httpx +sphinx +sphinx-rtd-theme +myst-parser +mypy-extensions diff --git a/deeptune/training/__init__.py b/deeptune/training/__init__.py new file mode 100644 index 0000000..c5d1940 --- /dev/null +++ b/deeptune/training/__init__.py @@ -0,0 +1 @@ +# deeptune training sub-package initialization diff --git a/deeptune/training/data_generators.py b/deeptune/training/data_generators.py new file mode 100644 index 0000000..4a68f4f --- /dev/null +++ b/deeptune/training/data_generators.py @@ -0,0 +1,41 @@ +import os +import random +import numpy as np +from keras.preprocessing import image + +class TripletGenerator: + def __init__(self, dataset_path, batch_size=32, target_size=(160, 160)): + self.dataset_path = dataset_path + self.batch_size = batch_size + self.target_size = target_size + self.classes = os.listdir(dataset_path) + self.class_indices = {cls: i for i, cls in enumerate(self.classes)} + self.image_paths = {cls: [os.path.join(dataset_path, cls, img) for img in os.listdir(os.path.join(dataset_path, cls))] for cls in self.classes} + + def __len__(self): + return len(self.classes) + + def __getitem__(self, idx): + cls = self.classes[idx] + positive_images = random.sample(self.image_paths[cls], 2) + negative_cls = random.choice([c for c in self.classes if c != cls]) + negative_image = random.choice(self.image_paths[negative_cls]) + + anchor = self.load_and_preprocess_image(positive_images[0]) + positive = self.load_and_preprocess_image(positive_images[1]) + negative = self.load_and_preprocess_image(negative_image) + + return np.array([anchor, positive, negative]), np.array([1, 0]) + + def load_and_preprocess_image(self, img_path): + img = image.load_img(img_path, target_size=self.target_size) + img_array = image.img_to_array(img) + img_array = np.expand_dims(img_array, axis=0) + img_array /= 255.0 + return img_array + + def generate(self): + while True: + for i in range(len(self)): + yield self[i] + diff --git a/deeptune/training/losses.py b/deeptune/training/losses.py new file mode 100644 index 0000000..3174704 --- /dev/null +++ b/deeptune/training/losses.py @@ -0,0 +1,46 @@ +import tensorflow as tf +import tensorflow.keras.backend as K +from tensorflow.keras.losses import Loss +import numpy as np + + +class TripletLoss(Loss): + def __init__(self, margin=1.0, **kwargs): + super().__init__(**kwargs) + self.margin = margin + + def call(self, y_true, y_pred): + anchor, positive, negative = y_pred[:, 0], y_pred[:, 1], y_pred[:, 2] + pos_dist = K.sum(K.square(anchor - positive), axis=-1) + neg_dist = K.sum(K.square(anchor - negative), axis=-1) + loss = K.maximum(pos_dist - neg_dist + self.margin, 0.0) + return K.mean(loss) + + +class ContrastiveLoss(Loss): + def __init__(self, margin=1.0, **kwargs): + super().__init__(**kwargs) + self.margin = margin + + def call(self, y_true, y_pred): + y_true = K.cast(y_true, y_pred.dtype) + pos_dist = K.sum(K.square(y_pred[:, 0] - y_pred[:, 1]), axis=-1) + neg_dist = K.sum(K.square(y_pred[:, 0] - y_pred[:, 2]), axis=-1) + loss = y_true * pos_dist + (1 - y_true) * K.maximum(self.margin - neg_dist, 0.0) + return K.mean(loss) + + +class ArcFaceLoss(Loss): + def __init__(self, scale=64.0, margin=0.5, **kwargs): + super().__init__(**kwargs) + self.scale = scale + self.margin = margin + + def call(self, y_true, y_pred): + y_true = K.cast(y_true, y_pred.dtype) + cosine = K.clip(y_pred, -1.0, 1.0) + theta = tf.acos(cosine) + target_logits = tf.cos(theta + self.margin) + logits = y_true * target_logits + (1 - y_true) * cosine + logits *= self.scale + return tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=logits) diff --git a/deeptune/utils/__init__.py b/deeptune/utils/__init__.py new file mode 100644 index 0000000..c324858 --- /dev/null +++ b/deeptune/utils/__init__.py @@ -0,0 +1 @@ +# deeptune utils sub-package initialization